Skip to content

Commit d133a1d

Browse files
Alex Nikulkov and facebook-github-bot
authored and committed
Add support for specifying loss type in ReAgent Neural LinUCB
Summary: Add support for specifying which loss function to use in Neural LinUCB. MSE, MAE and binary cross-entropy (BCE) are supported. The default is MSE.

Reviewed By: BerenLuthien

Differential Revision: D50992907

fbshipit-source-id: 997c08a39c487ac68c62dea561e4b0d43d0e821f
1 parent c790065 commit d133a1d

File tree

3 files changed: 9 additions (+9) and 4 deletions (-4)

reagent/test/training/cb/test_deep_represent_linucb.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,9 @@ class TestDeepRepresentLinUCB(unittest.TestCase):
2323

2424
def setUp(self):
2525

26-
self.params = DeepRepresentLinUCBTrainerParameters(lr=1e-1)
26+
self.params = DeepRepresentLinUCBTrainerParameters(
27+
lr=1e-1, loss_type="cross_entropy"
28+
)
2729

2830
input_dim = 100
2931
sizes = [20]
@@ -43,14 +45,15 @@ def setUp(self):
4345
sizes=sizes + [linucb_inp_dim],
4446
activations=activations,
4547
mlp_layers=customized_layers,
48+
output_activation="sigmoid",
4649
)
4750

4851
self.policy = Policy(scorer=policy_network, sampler=GreedyActionSampler())
4952
self.trainer = DeepRepresentLinUCBTrainer(self.policy, **self.params.asdict())
5053
self.batch = CBInput(
5154
context_arm_features=torch.rand(2, 2, input_dim),
5255
action=torch.tensor([[0], [1]], dtype=torch.long),
53-
reward=torch.tensor([[1.5], [-2.3]]),
56+
reward=torch.tensor([[0.3], [0.1]]),
5457
) # random Gaussian features
5558

5659
def test_linucb_training_step(self):

reagent/training/cb/deep_represent_linucb_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from reagent.gym.policies.policy import Policy
99
from reagent.models.deep_represent_linucb import DeepRepresentLinearRegressionUCB
1010
from reagent.training.cb.linucb_trainer import LinUCBTrainer
11+
from reagent.training.cb.supervised_trainer import LOSS_TYPES
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -29,6 +30,7 @@ def __init__(
2930
policy: Policy,
3031
lr: float = 1e-3,
3132
weight_decay: float = 0.0,
33+
loss_type: str = "mse", # one of the LOSS_TYPES names
3234
**kwargs,
3335
):
3436
super().__init__(
@@ -40,7 +42,7 @@ def __init__(
4042
policy.scorer, DeepRepresentLinearRegressionUCB
4143
), "Trainer requires the policy scorer to be DeepRepresentLinearRegressionUCB"
4244
self.scorer = policy.scorer
43-
self.loss_fn = torch.nn.functional.mse_loss
45+
self.loss_fn = LOSS_TYPES[loss_type]
4446
self.lr = lr
4547
self.weight_decay = weight_decay
4648

reagent/training/cb/supervised_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ class SupervisedTrainer(BaseCBTrainerWithEval):
2828
def __init__(
2929
self,
3030
policy: Policy,
31-
loss_type: str = "mse", # one of the LossTypes names
31+
loss_type: str = "mse", # one of the LOSS_TYPES names
3232
lr: float = 1e-3,
3333
weight_decay: float = 0.0,
3434
*args,

0 commit comments

Comments
 (0)