@@ -317,24 +317,17 @@ def create_from_tensors_dqn(
317
317
metrics : Optional [torch .Tensor ] = None ,
318
318
):
319
319
old_q_train_state = trainer .q_network .training
320
- # pyre-fixme[16]: `DQNTrainer` has no attribute `reward_network`.
321
320
old_reward_train_state = trainer .reward_network .training
322
- # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
323
- # `training`.
324
321
old_q_cpe_train_state = trainer .q_network_cpe .training
325
322
trainer .q_network .train (False )
326
- # pyre-fixme[16]: `Tensor` has no attribute `train`.
327
323
trainer .reward_network .train (False )
328
- # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
329
- # `train`.
330
324
trainer .q_network_cpe .train (False )
331
325
332
326
num_actions = trainer .num_actions
333
327
action_mask = actions .float ()
334
328
335
329
# pyre-fixme[6]: Expected `Tensor` for 2nd param but got `FeatureData`.
336
330
rewards = trainer .boost_rewards (rewards , actions )
337
- # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
338
331
model_values = trainer .q_network_cpe (states )[:, 0 :num_actions ]
339
332
# TODO: make generic get_action_idxs for each trainer class
340
333
# Note: model_outputs are obtained from the q_network for DQN algorithms
@@ -360,7 +353,6 @@ def create_from_tensors_dqn(
360
353
+ str (possible_actions_mask .shape )
361
354
)
362
355
363
- # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
364
356
rewards_and_metric_rewards = trainer .reward_network (states )
365
357
366
358
# In case we reuse the modular for Q-network
@@ -390,7 +382,6 @@ def create_from_tensors_dqn(
390
382
model_metrics_for_logged_action = None
391
383
model_metrics_values_for_logged_action = None
392
384
else :
393
- # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
394
385
model_metrics_values = trainer .q_network_cpe (states )
395
386
# Backward compatility
396
387
if hasattr (model_metrics_values , "q_values" ):
@@ -430,12 +421,8 @@ def create_from_tensors_dqn(
430
421
model_metrics_values_for_logged_action_list , dim = 1
431
422
)
432
423
433
- # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
434
- # `train`.
435
424
trainer .q_network_cpe .train (old_q_cpe_train_state )
436
425
trainer .q_network .train (old_q_train_state )
437
- # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
438
- # `train`.
439
426
trainer .reward_network .train (old_reward_train_state )
440
427
441
428
return cls (
0 commit comments