diff --git a/captum/attr/_core/layer/layer_conductance.py b/captum/attr/_core/layer/layer_conductance.py index 1f1a5f467..2d15d2527 100644 --- a/captum/attr/_core/layer/layer_conductance.py +++ b/captum/attr/_core/layer/layer_conductance.py @@ -2,7 +2,7 @@ # pyre-strict import typing -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -44,8 +44,7 @@ class LayerConductance(LayerAttribution, GradientAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], layer: Module, device_ids: Union[None, List[int]] = None, ) -> None: @@ -73,8 +72,6 @@ def has_convergence_delta(self) -> bool: return True @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `75`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -91,8 +88,6 @@ def attribute( ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `91`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -376,7 +371,7 @@ def _attribute( layer_evals, ) = compute_layer_gradients_and_eval( forward_fn=self.forward_func, - layer=self.layer, + layer=cast(Module, self.layer), inputs=scaled_features_tpl, additional_forward_args=input_additional_args, target_ind=expanded_target, @@ -389,8 +384,6 @@ def _attribute( # This approximates the total input gradient of each step multiplied # by the step size. grad_diffs = tuple( - # pyre-fixme[58]: `-` is not supported for operand types `Tuple[Tensor, - # ...]` and `Tuple[Tensor, ...]`. layer_eval[num_examples:] - layer_eval[:-num_examples] for layer_eval in layer_evals ) @@ -403,8 +396,7 @@ def _attribute( grad_diff * layer_gradient[:-num_examples], n_steps, num_examples, - # pyre-fixme[16]: `tuple` has no attribute `shape`. - layer_eval.shape[1:], + tuple(layer_eval.shape[1:]), ) for layer_gradient, layer_eval, grad_diff in zip( layer_gradients, layer_evals, grad_diffs