From 0656e630fbc8d426af346fc7d9bed1d998f60210 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 30 Dec 2024 16:00:37 -0800 Subject: [PATCH] Fix neuron feature ablation pyre fixme issues (#1462) Summary: Fixing unresolved pyre fixme issues in corresponding file Reviewed By: craymichael Differential Revision: D67705096 --- .../_core/neuron/neuron_feature_ablation.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/captum/attr/_core/neuron/neuron_feature_ablation.py b/captum/attr/_core/neuron/neuron_feature_ablation.py index c72cf806a1..d391481ed4 100644 --- a/captum/attr/_core/neuron/neuron_feature_ablation.py +++ b/captum/attr/_core/neuron/neuron_feature_ablation.py @@ -6,7 +6,11 @@ import torch from captum._utils.common import _verify_select_neuron from captum._utils.gradient import _forward_layer_eval -from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric +from captum._utils.typing import ( + BaselineType, + SliceIntType, + TensorOrTupleOfTensorsGeneric, +) from captum.attr._core.feature_ablation import FeatureAblation from captum.attr._utils.attribution import NeuronAttribution, PerturbationAttribution from captum.log import log_usage @@ -31,8 +35,7 @@ class NeuronFeatureAblation(NeuronAttribution, PerturbationAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Union[int, float, Tensor]], layer: Module, device_ids: Union[None, List[int]] = None, ) -> None: @@ -61,8 +64,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + neuron_selector: Union[ + int, + Tuple[Union[int, SliceIntType], ...], + Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor], + ], baselines: BaselineType = None, additional_forward_args: Optional[object] = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, @@ -250,8 +256,7 @@ def attribute( >>> feature_mask=feature_mask) """ - # pyre-fixme[3]: Return type must be annotated. - def neuron_forward_func(*args: Any): + def neuron_forward_func(*args: Any) -> Tensor: with torch.no_grad(): layer_eval = _forward_layer_eval( self.forward_func,