
Commit 43e0d39

generatedunixname89002005307016 authored and facebook-github-bot committed
suppress errors in reagent
Differential Revision: D47142615
fbshipit-source-id: 5a6e3a5e3fc202dd5cdc890d1508a57d58b6cef4
1 parent 24c6dfb commit 43e0d39
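
For context: a `# pyre-fixme[...]` or `# pyre-ignore` comment suppresses the Pyre type error reported on the line directly below it, and this commit deletes suppressions that Pyre no longer requires. The sketch below shows the general pattern behind the most common error in this diff, [16] on `Union`-typed module attributes; the class and attribute names are hypothetical, not taken from ReAgent.

    from typing import Union

    import torch
    import torch.nn as nn


    class HypotheticalTrainer:
        # Attribute annotated as a Union: Pyre cannot prove that the Tensor member
        # of the Union has a `.train()` method, hence error [16] on the call below.
        q_network: Union[torch.Tensor, nn.Module]

        def __init__(self, q_network: nn.Module) -> None:
            self.q_network = q_network

        def set_eval_mode(self) -> None:
            # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute `train`.
            self.q_network.train(False)


    trainer = HypotheticalTrainer(nn.Linear(4, 2))
    trainer.set_eval_mode()  # runs fine; the comment only silences the static checker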


12 files changed: +0 -39 lines changed


reagent/evaluation/evaluation_data_page.py

Lines changed: 0 additions & 13 deletions
@@ -317,24 +317,17 @@ def create_from_tensors_dqn(
         metrics: Optional[torch.Tensor] = None,
     ):
         old_q_train_state = trainer.q_network.training
-        # pyre-fixme[16]: `DQNTrainer` has no attribute `reward_network`.
         old_reward_train_state = trainer.reward_network.training
-        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
-        # `training`.
         old_q_cpe_train_state = trainer.q_network_cpe.training
         trainer.q_network.train(False)
-        # pyre-fixme[16]: `Tensor` has no attribute `train`.
         trainer.reward_network.train(False)
-        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
-        # `train`.
         trainer.q_network_cpe.train(False)
 
         num_actions = trainer.num_actions
         action_mask = actions.float()
 
         # pyre-fixme[6]: Expected `Tensor` for 2nd param but got `FeatureData`.
         rewards = trainer.boost_rewards(rewards, actions)
-        # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
         model_values = trainer.q_network_cpe(states)[:, 0:num_actions]
         # TODO: make generic get_action_idxs for each trainer class
         # Note: model_outputs are obtained from the q_network for DQN algorithms
@@ -360,7 +353,6 @@ def create_from_tensors_dqn(
                 + str(possible_actions_mask.shape)
             )
 
-        # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
         rewards_and_metric_rewards = trainer.reward_network(states)
 
         # In case we reuse the modular for Q-network
@@ -390,7 +382,6 @@ def create_from_tensors_dqn(
             model_metrics_for_logged_action = None
             model_metrics_values_for_logged_action = None
         else:
-            # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
             model_metrics_values = trainer.q_network_cpe(states)
             # Backward compatility
             if hasattr(model_metrics_values, "q_values"):
@@ -430,12 +421,8 @@ def create_from_tensors_dqn(
                 model_metrics_values_for_logged_action_list, dim=1
             )
 
-        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
-        # `train`.
         trainer.q_network_cpe.train(old_q_cpe_train_state)
         trainer.q_network.train(old_q_train_state)
-        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
-        # `train`.
         trainer.reward_network.train(old_reward_train_state)
 
         return cls(
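
The `pyre-fixme[29]` comments deleted above sat on calls such as `trainer.q_network_cpe(states)`, where the attribute's declared type is a `Union` that includes `torch.Tensor`, which Pyre does not treat as callable. Below is a sketch of one common way such an error stops firing, narrowing the type before the call; the class is hypothetical, and this is not necessarily how the error was resolved upstream.

    from typing import Union

    import torch
    import torch.nn as nn


    class HypotheticalCPETrainer:
        q_network_cpe: Union[torch.Tensor, nn.Module]

        def __init__(self, q_network_cpe: nn.Module) -> None:
            self.q_network_cpe = q_network_cpe

        def model_values(self, states: torch.Tensor) -> torch.Tensor:
            # Calling a Union-typed attribute used to need:
            #   # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
            # Narrowing the type first lets Pyre see a callable nn.Module.
            assert isinstance(self.q_network_cpe, nn.Module)
            return self.q_network_cpe(states)


    trainer = HypotheticalCPETrainer(nn.Linear(8, 4))
    print(trainer.model_values(torch.randn(2, 8)).shape)  # torch.Size([2, 4])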

reagent/gym/policies/scorers/continuous_scorer.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@ def sac_scorer(actor_network: ModelBase) -> Scorer:
     @torch.no_grad()
     def score(preprocessed_obs: rlt.FeatureData) -> GaussianSamplerScore:
         actor_network.eval()
-        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
         loc, scale_log = actor_network._get_loc_and_scale_log(preprocessed_obs)
         actor_network.train()
         return GaussianSamplerScore(loc=loc, scale_log=scale_log)

reagent/model_managers/discrete_dqn_base.py

Lines changed: 0 additions & 2 deletions
@@ -86,8 +86,6 @@ def create_policy(
             )
         else:
             sampler = GreedyActionSampler()
-            # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
-            # `Union[torch.Tensor, torch.nn.Module]`.
             scorer = discrete_dqn_scorer(trainer_module.q_network)
         return Policy(scorer=scorer, sampler=sampler)

reagent/model_managers/parametric_dqn_base.py

Lines changed: 0 additions & 3 deletions
@@ -72,7 +72,6 @@ def create_policy(
         """Create an online DiscreteDQN Policy from env."""
 
         # FIXME: this only works for one-hot encoded actions
-        # pyre-fixme[16]: `Tensor` has no attribute `input_prototype`.
         action_dim = trainer_module.q_network.input_prototype()[1].float_features.shape[
             1
         ]
@@ -87,8 +86,6 @@ def create_policy(
         sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
         scorer = parametric_dqn_scorer(
             max_num_actions=action_dim,
-            # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
-            # `Union[torch.Tensor, torch.nn.Module]`.
             q_network=trainer_module.q_network,
         )
         return Policy(scorer=scorer, sampler=sampler)

reagent/model_managers/slate_q_base.py

Lines changed: 0 additions & 2 deletions
@@ -72,8 +72,6 @@ def create_policy(
         else:
             scorer = slate_q_scorer(
                 num_candidates=self.num_candidates,
-                # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
-                # `Union[torch.Tensor, torch.nn.Module]`.
                 q_network=trainer_module.q_network,
             )
         sampler = TopKSampler(k=self.slate_size)

reagent/models/seq2slate.py

Lines changed: 0 additions & 4 deletions
@@ -856,7 +856,6 @@ class Seq2SlateNet(ModelBase):
 
     def __post_init_post_parse__(self) -> None:
         super().__init__()
-        # pyre-fixme[16]: `Seq2SlateNet` has no attribute `seq2slate`.
         self.seq2slate = self._build_model()
 
     def _build_model(self):
@@ -879,7 +878,6 @@ def forward(
         greedy: Optional[bool] = None,
     ):
         if mode == Seq2SlateMode.RANK_MODE:
-            # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
             res = self.seq2slate(
                 mode=mode.value,
                 state=input.state.float_features,
@@ -899,7+897,6 @@ def forward(
             assert input.tgt_in_seq is not None
             assert input.tgt_in_idx is not None
             assert input.tgt_out_idx is not None
-            # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
             res = self.seq2slate(
                 mode=mode.value,
                 state=input.state.float_features,
@@ -915,7+912,6 @@ def forward(
             return rlt.RankingOutput(log_probs=log_probs)
         elif mode == Seq2SlateMode.ENCODER_SCORE_MODE:
             assert input.tgt_out_idx is not None
-            # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
             res = self.seq2slate(
                 mode=mode.value,
                 state=input.state.float_features,

reagent/net_builder/synthetic_reward_net_builder.py

Lines changed: 0 additions & 2 deletions
@@ -68,8 +68,6 @@ def build_serving_module(
             seq_len,
             state_preprocessor,
             action_preprocessor,
-            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a
-            # function.
             synthetic_reward_network.export_mlp().cpu().eval(),
         )
     else:

reagent/training/cfeval/bayes_by_backprop_trainer.py

Lines changed: 0 additions & 2 deletions
@@ -16,7 +16,6 @@ def train_step_gen(
     ):
         weight = self._get_sample_weight(training_batch)
 
-        # pyre-ignore seems to be pyre bug for pytorch
         loss = self.reward_net.sample_elbo(
             torch.cat([training_batch.action, training_batch.state.float_features], 1),
             training_batch.reward,
@@ -48,7 +47,6 @@ def validation_step(self, batch: rlt.BanditRewardModelInput, batch_idx: int):
         batch = self._training_batch_type.from_dict(batch)
 
         weight = self._get_sample_weight(batch)
-        # pyre-ignore
         loss = self.reward_net.sample_elbo(
             torch.cat([batch.action, batch.state.float_features], 1),
             batch.reward,
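
The comments removed here use `# pyre-ignore` rather than `# pyre-fixme`; both silence the error reported on the next line, differing only in intent (`ignore` is a deliberate, open-ended suppression, `fixme` marks a known issue to revisit). A small hypothetical sketch of the syntax:

    from typing import Optional


    def string_length(s: Optional[str]) -> int:
        # Both forms silence the error on the following line; the bracketed code
        # (here [6]) limits the suppression to that specific error type.
        # pyre-ignore[6]: `Optional[str]` is not always a valid argument for `len`.
        return len(s)


    print(string_length("reagent"))  # 7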

reagent/training/dqn_trainer.py

Lines changed: 0 additions & 1 deletion
@@ -231,7 +231,6 @@ def compute_td_loss(
 
         # Get Q-value of action taken
         all_q_values = self.q_network(batch.state)
-        # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
         self.all_action_scores = all_q_values.detach()
         q_values = torch.sum(all_q_values * batch.action, 1, keepdim=True)
         td_loss = self.q_network_loss(q_values, target_q_values.detach())

reagent/training/dqn_trainer_base.py

Lines changed: 0 additions & 6 deletions
@@ -266,22 +266,17 @@ def _initialize_cpe(
             optimizer: an optimizer object for training q_network_cpe.
         """
         if not self.calc_cpe_in_training:
-            # pyre-fixme[16]: `DQNTrainerBase` has no attribute `reward_network`.
             self.reward_network = None
             return
 
         assert reward_network is not None, "reward_network is required for CPE"
         self.reward_network = reward_network
-        # pyre-fixme[16]: `DQNTrainerBase` has no attribute `reward_network_optimizer`.
         self.reward_network_optimizer = optimizer
         assert (
             q_network_cpe is not None and q_network_cpe_target is not None
         ), "q_network_cpe and q_network_cpe_target are required for CPE"
-        # pyre-fixme[16]: `DQNTrainerBase` has no attribute `q_network_cpe`.
         self.q_network_cpe = q_network_cpe
-        # pyre-fixme[16]: `DQNTrainerBase` has no attribute `q_network_cpe_target`.
         self.q_network_cpe_target = q_network_cpe_target
-        # pyre-fixme[16]: `DQNTrainerBase` has no attribute `q_network_cpe_optimizer`.
         self.q_network_cpe_optimizer = optimizer
         num_output_nodes = len(self.metrics_to_score) * self.num_actions
         reward_idx_offsets = torch.arange(
@@ -295,7 +290,6 @@ def _initialize_cpe(
         reward_stripped_metrics_to_score = (
             self.metrics_to_score[:-1] if len(self.metrics_to_score) > 1 else None
         )
-        # pyre-fixme[16]: `DQNTrainerBase` has no attribute `evaluator`.
         self.evaluator = Evaluator(
             self._actions,
             self.rl_parameters.gamma,
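
The `pyre-fixme[16]` comments deleted above flagged attributes that are first assigned inside `_initialize_cpe` rather than directly in `__init__`, so Pyre could not see them on the class. The sketch below shows one typical way such errors go away, declaring the attributes up front; the names are hypothetical and this may not match the actual upstream fix.

    from typing import Optional

    import torch.nn as nn


    class HypotheticalTrainerBase:
        # Class-level annotations declare the attributes, so Pyre knows about them
        # even though the assignments happen in a helper rather than in __init__.
        reward_network: Optional[nn.Module]
        q_network_cpe: Optional[nn.Module]

        def __init__(
            self,
            reward_network: Optional[nn.Module] = None,
            q_network_cpe: Optional[nn.Module] = None,
        ) -> None:
            self._initialize_cpe(reward_network, q_network_cpe)

        def _initialize_cpe(
            self,
            reward_network: Optional[nn.Module],
            q_network_cpe: Optional[nn.Module],
        ) -> None:
            # Without the declarations above, each assignment here would draw a
            # "pyre-fixme[16]: `...` has no attribute `reward_network`"-style error.
            self.reward_network = reward_network
            self.q_network_cpe = q_network_cpe


    trainer = HypotheticalTrainerBase(reward_network=nn.Linear(4, 1))
    print(trainer.reward_network)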
