Skip to content

Commit

Permalink
Fix layer deeplift pyre fixme issues (pytorch#1470)
Browse files Browse the repository at this point in the history
Summary:

Fixing unresolved pyre fixme issues in corresponding file

Reviewed By: cyrjano

Differential Revision: D67705583
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 448360e commit a2ab2ee
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions captum/attr/_core/layer/layer_deep_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,9 @@ def attribute(
additional_forward_args,
)

# pyre-fixme[24]: Generic type `Sequence` expects 1 type parameter.
def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
def chunk_output_fn(
out: TensorOrTupleOfTensorsGeneric,
) -> Sequence[Union[Tensor, Sequence[Tensor]]]:
if isinstance(out, Tensor):
return out.chunk(2)
return tuple(out_sub.chunk(2) for out_sub in out)
Expand Down Expand Up @@ -434,8 +435,6 @@ def __init__(

# Ignoring mypy error for inconsistent signature with DeepLiftShap
@typing.overload # type: ignore
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `453`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand All @@ -450,9 +449,7 @@ def attribute(
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `439`.
@typing.overload # type: ignore
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand Down Expand Up @@ -654,7 +651,7 @@ def attribute(
) = DeepLiftShap._expand_inputs_baselines_targets(
self, baselines, inputs, target, additional_forward_args
)
attributions = LayerDeepLift.attribute.__wrapped__( # type: ignore
attribs_layer_deeplift = LayerDeepLift.attribute.__wrapped__( # type: ignore
self,
exp_inp,
exp_base,
Expand All @@ -667,8 +664,12 @@ def attribute(
attribute_to_layer_input=attribute_to_layer_input,
custom_attribution_func=custom_attribution_func,
)
delta: Tensor
attributions: Union[Tensor, Tuple[Tensor, ...]]
if return_convergence_delta:
attributions, delta = attributions
attributions, delta = attribs_layer_deeplift
else:
attributions = attribs_layer_deeplift
if isinstance(attributions, tuple):
attributions = tuple(
DeepLiftShap._compute_mean_across_baselines(
Expand All @@ -681,15 +682,17 @@ def attribute(
self, inp_bsz, base_bsz, attributions
)
if return_convergence_delta:
# pyre-fixme[61]: `delta` is undefined, or not always defined.
return attributions, delta
else:
# pyre-fixme[7]: Expected `Union[Tuple[Union[Tensor,
# typing.Tuple[Tensor, ...]], Tensor], Tensor, typing.Tuple[Tensor, ...]]`
# but got `Union[tuple[Tensor], Tensor]`.
return attributions
return cast(
Union[
Tensor,
Tuple[Tensor, ...],
Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor],
],
attributions,
)

@property
# pyre-fixme[3]: Return type must be annotated.
def multiplies_by_inputs(self) -> bool:
return self._multiply_by_inputs

0 comments on commit a2ab2ee

Please sign in to comment.