Skip to content

Commit

Permalink
Fix neuron feature ablation pyre fixme issues (pytorch#1462)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#1462

Differential Revision: D67705096
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 7e37201 commit 6c1b2da
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions captum/attr/_core/neuron/neuron_feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import torch
from captum._utils.common import _verify_select_neuron
from captum._utils.gradient import _forward_layer_eval
from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
from captum._utils.typing import (
BaselineType,
SliceIntType,
TensorOrTupleOfTensorsGeneric,
)
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._utils.attribution import NeuronAttribution, PerturbationAttribution
from captum.log import log_usage
Expand All @@ -31,8 +35,7 @@ class NeuronFeatureAblation(NeuronAttribution, PerturbationAttribution):

def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
forward_func: Callable[..., Union[int, float, Tensor]],
layer: Module,
device_ids: Union[None, List[int]] = None,
) -> None:
Expand Down Expand Up @@ -61,8 +64,11 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
neuron_selector: Union[
int,
Tuple[Union[int, SliceIntType], ...],
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
],
baselines: BaselineType = None,
additional_forward_args: Optional[object] = None,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
Expand Down Expand Up @@ -250,8 +256,7 @@ def attribute(
>>> feature_mask=feature_mask)
"""

# pyre-fixme[3]: Return type must be annotated.
def neuron_forward_func(*args: Any):
def neuron_forward_func(*args: Any) -> Tensor:
with torch.no_grad():
layer_eval = _forward_layer_eval(
self.forward_func,
Expand Down

0 comments on commit 6c1b2da

Please sign in to comment.