@@ -401,10 +401,10 @@ def attribute(
401401 if attr_progress is not None :
402402 attr_progress .close ()
403403
404- # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
405- # [Tensor, typing.Tuple[Tensor, ...]]]`
406- # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`.
407- return self . _generate_result ( total_attrib , weights , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
404+ return cast (
405+ TensorOrTupleOfTensorsGeneric ,
406+ self . _generate_result ( total_attrib , weights , is_inputs_tuple ),
407+ )
408408
409409 def _attribute_with_independent_feature_masks (
410410 self ,
@@ -629,8 +629,7 @@ def _should_skip_inputs_and_warn(
629629 all_empty = False
630630 if self ._min_examples_per_batch_grouped is not None and (
631631 formatted_inputs [tensor_idx ].shape [0 ]
632- # pyre-ignore[58]: Type has been narrowed to int
633- < self ._min_examples_per_batch_grouped
632+ < cast (int , self ._min_examples_per_batch_grouped )
634633 ):
635634 should_skip = True
636635 break
@@ -789,35 +788,35 @@ def attribute_future(
789788 )
790789
791790 if enable_cross_tensor_attribution :
792- # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
793- # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
794- # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
795- return self . _attribute_with_cross_tensor_feature_masks_future ( # type: ignore # noqa: E501 line too long
796- formatted_inputs = formatted_inputs ,
797- formatted_additional_forward_args = formatted_additional_forward_args ,
798- target = target ,
799- baselines = baselines ,
800- formatted_feature_mask = formatted_feature_mask ,
801- attr_progress = attr_progress ,
802- processed_initial_eval_fut = processed_initial_eval_fut ,
803- is_inputs_tuple = is_inputs_tuple ,
804- perturbations_per_eval = perturbations_per_eval ,
791+ return cast (
792+ Future [ TensorOrTupleOfTensorsGeneric ],
793+ self . _attribute_with_cross_tensor_feature_masks_future ( # type: ignore # noqa: E501 line too long
794+ formatted_inputs = formatted_inputs ,
795+ formatted_additional_forward_args = formatted_additional_forward_args ,
796+ target = target ,
797+ baselines = baselines ,
798+ formatted_feature_mask = formatted_feature_mask ,
799+ attr_progress = attr_progress ,
800+ processed_initial_eval_fut = processed_initial_eval_fut ,
801+ is_inputs_tuple = is_inputs_tuple ,
802+ perturbations_per_eval = perturbations_per_eval ,
803+ ) ,
805804 )
806805 else :
807- # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
808- # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
809- # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
810- return self . _attribute_with_independent_feature_masks_future ( # type: ignore # noqa: E501 line too long
811- formatted_inputs ,
812- formatted_additional_forward_args ,
813- target ,
814- baselines ,
815- formatted_feature_mask ,
816- perturbations_per_eval ,
817- attr_progress ,
818- processed_initial_eval_fut ,
819- is_inputs_tuple ,
820- ** kwargs ,
806+ return cast (
807+ Future [ TensorOrTupleOfTensorsGeneric ],
808+ self . _attribute_with_independent_feature_masks_future ( # type: ignore # noqa: E501 line too long
809+ formatted_inputs ,
810+ formatted_additional_forward_args ,
811+ target ,
812+ baselines ,
813+ formatted_feature_mask ,
814+ perturbations_per_eval ,
815+ attr_progress ,
816+ processed_initial_eval_fut ,
817+ is_inputs_tuple ,
818+ ** kwargs ,
819+ ) ,
821820 )
822821
823822 def _attribute_with_independent_feature_masks_future (
0 commit comments