Consistency improvements and tests (#121)
gsarti authored Mar 2, 2022
1 parent 745caf9 commit ac07669
Showing 10 changed files with 209 additions and 263 deletions.
11 changes: 10 additions & 1 deletion inseq/__init__.py
@@ -2,6 +2,7 @@

from importlib import metadata as importlib_metadata

from .attr import list_feature_attribution_methods
from .data import load_attributions, save_attributions, show_attributions
from .models import AttributionModel, load_model

@@ -15,4 +16,12 @@ def get_version() -> str:

version: str = get_version()

__all__ = ["AttributionModel", "load_model", "show_attributions", "save_attributions", "load_attributions" "version"]
__all__ = [
"AttributionModel",
"load_model",
"show_attributions",
"save_attributions",
"load_attributions",
"list_feature_attribution_methods",
"version",
]
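The one-line __all__ removed above is missing a comma between "load_attributions" and "version", so Python's implicit string concatenation silently fuses them into a single bogus export; the reformatted list fixes this. A minimal standalone illustration of the pitfall (not part of the commit):

exports = ["save_attributions", "load_attributions" "version"]  # missing comma, as in the old line
print(exports)  # ['save_attributions', 'load_attributionsversion'] -- two names fused into one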
4 changes: 2 additions & 2 deletions inseq/attr/__init__.py
@@ -1,4 +1,4 @@
from .feat import FeatureAttribution
from .feat import FeatureAttribution, list_feature_attribution_methods


__all__ = ["FeatureAttribution"]
__all__ = ["FeatureAttribution", "list_feature_attribution_methods"]
3 changes: 2 additions & 1 deletion inseq/attr/feat/__init__.py
@@ -1,4 +1,4 @@
from .feature_attribution import FeatureAttribution
from .feature_attribution import FeatureAttribution, list_feature_attribution_methods
from .gradient_attribution import (
DiscretizedIntegratedGradientsAttribution,
GradientAttribution,
@@ -10,6 +10,7 @@

__all__ = [
"FeatureAttribution",
"list_feature_attribution_methods",
"GradientAttribution",
"InputXGradientAttribution",
"IntegratedGradientsAttribution",
38 changes: 22 additions & 16 deletions inseq/attr/feat/feature_attribution.py
@@ -42,6 +42,7 @@
drop_padding,
extract_signature_args,
find_char_indexes,
get_available_methods,
logits2probs,
pretty_tensor,
)
@@ -230,15 +231,12 @@ def prepare(
if isinstance(sources, str) or isinstance(sources, list):
sources: BatchEncoding = self.attribution_model.encode(sources, return_baseline=True)
if isinstance(sources, BatchEncoding):
# Layer attribution methods use token ids as inputs instead of embeddings since they
# don't need to attribute scores to embedding vectors.
if self.is_layer_attribution:
embeds = BatchEmbedding()
else:
embeds = BatchEmbedding(
input_embeds=self.attribution_model.embed(sources.input_ids),
baseline_embeds=self.attribution_model.embed(sources.baseline_ids),
)
# Even when we are performing layer attribution, we might need the embeddings
# to compute step probabilities.
embeds = BatchEmbedding(
input_embeds=self.attribution_model.embed(sources.input_ids),
baseline_embeds=self.attribution_model.embed(sources.baseline_ids),
)
sources = Batch(sources, embeds)
if isinstance(targets, str) or isinstance(targets, list):
targets: BatchEncoding = self.attribution_model.encode(
@@ -295,6 +293,10 @@ def attribute(
:class:`~inseq.data.FeatureAttributionSequenceOutput` depending on the number of inputs, with an
optional added list of single :class:`~inseq.data.FeatureAttributionOutput` for each step.
"""
if self.is_layer_attribution and attribute_target:
raise ValueError(
"Layer attribution methods do not support attribute_target=True. Use regular ones instead."
)
max_generated_length = batch.targets.input_ids.shape[1]
attr_pos_start, attr_pos_end = self.check_attribute_positions(
max_generated_length,
@@ -514,15 +516,12 @@ def format_attribute_args(
if self.is_layer_attribution:
inputs = (batch.sources.input_ids,)
baselines = (batch.sources.baseline_ids,)
if attribute_target:
inputs = inputs + (batch.targets.input_ids,)
baselines = baselines + (batch.targets.baseline_ids,)
else:
inputs = (batch.sources.input_embeds,)
baselines = (batch.sources.baseline_embeds,)
if attribute_target:
inputs = inputs + (batch.targets.input_embeds,)
baselines = baselines + (batch.targets.baseline_embeds,)
if attribute_target:
inputs = inputs + (batch.targets.input_embeds,)
baselines = baselines + (batch.targets.baseline_embeds,)
attribute_args = {
"inputs": inputs,
"target": target_ids,
@@ -550,7 +549,7 @@ def get_step_prediction_probabilities(self, batch: EncoderDecoderBatch, target_i
raise ValueError("Attribution model is not set.")
logits = self.attribution_model.score_func(
encoder_tensors=batch.sources.input_embeds,
decoder_tensors=batch.targets.input_embeds,
decoder_embeds=batch.targets.input_embeds,
encoder_attention_mask=batch.sources.attention_mask,
decoder_attention_mask=batch.targets.attention_mask,
use_embeddings=True,
@@ -608,3 +607,10 @@ def unhook(self, **kwargs) -> NoReturn:
Abstract method, must be implemented by subclasses.
"""
pass


def list_feature_attribution_methods():
"""
Lists all available feature attribution methods.
"""
return get_available_methods(FeatureAttribution)
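A brief usage sketch for the new helper (not part of the diff; the exact identifiers returned depend on which attribution classes are registered):

import inseq

# Returns the method_name identifiers of the registered FeatureAttribution subclasses,
# e.g. "integrated_gradients" or "saliency" (exact contents depend on the registry).
print(inseq.list_feature_attribution_methods())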
25 changes: 4 additions & 21 deletions inseq/attr/feat/gradient_attribution.py
@@ -17,7 +17,6 @@

from captum.attr import (
DeepLift,
GradientShap,
InputXGradient,
IntegratedGradients,
LayerDeepLift,
@@ -51,8 +50,8 @@ def hook(self, **kwargs):
logger.debug(f"target_layer={self.target_layer}")
if isinstance(self.target_layer, str):
self.target_layer = rgetattr(self.attribution_model.model, self.target_layer)
# For now only encoder attribution is supported
self.attribution_model.configure_interpretable_embeddings(do_encoder=not self.is_layer_attribution)
if not self.is_layer_attribution:
self.attribution_model.configure_interpretable_embeddings()

@unset_hook
def unhook(self, **kwargs):
Expand All @@ -61,7 +60,8 @@ def unhook(self, **kwargs):
"""
if self.is_layer_attribution:
self.target_layer = None
self.attribution_model.remove_interpretable_embeddings(do_encoder=not self.is_layer_attribution)
else:
self.attribution_model.remove_interpretable_embeddings()

def attribute_step(
self,
@@ -254,23 +254,6 @@ def __init__(self, attribution_model):
self.method = Saliency(self.attribution_model.score_func)


class GradientShapAttribution(GradientAttribution):
"""GradientShap attribution method.
Reference implementation:
`https://captum.ai/api/gradient_shap.html <https://captum.ai/api/gradient_shap.html>`__.
"""

method_name = "gradient_shap"

def __init__(self, attribution_model, **kwargs):
super().__init__(attribution_model)
multiply_by_inputs = kwargs.pop("multiply_by_inputs", True)
self.method = GradientShap(self.attribution_model.score_func, multiply_by_inputs)
self.use_baseline = True


# Layer methods


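For context on the hook changes above: layer attribution methods skip the interpretable-embedding setup because they attribute to a target layer and consume token ids directly. A standalone Captum sketch of that pattern (the model name and setup are illustrative assumptions, not inseq code):

import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-it")
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-it")

def forward_fn(input_ids, decoder_input_ids):
    # Logits for the next target token.
    return model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits[:, -1, :]

# The target layer receives the interpolation; the inputs stay plain token ids.
lig = LayerIntegratedGradients(forward_fn, model.get_encoder().embed_tokens)
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
decoder_input_ids = torch.full((1, 1), model.config.decoder_start_token_id)
next_token = forward_fn(input_ids, decoder_input_ids).argmax(dim=-1).item()
attributions = lig.attribute(
    inputs=input_ids,
    additional_forward_args=(decoder_input_ids,),
    target=next_token,
)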
19 changes: 19 additions & 0 deletions inseq/models/attribution_model.py
@@ -231,6 +231,25 @@ def vocabulary_embeddings(self) -> VocabularyEmbeddingsTensor:
def get_embedding_layer(self) -> torch.nn.Module:
pass

def configure_interpretable_embeddings(self, **kwargs) -> None:
"""Configure the model with interpretable embeddings for gradient attribution.
This method needs to be defined for models that cannot receive embeddings directly from their
forward method parameters, to allow the usage of interpretable embeddings as a surrogate for
feature attribution methods. Models that support precomputed embedding inputs by design can
skip this method.
"""
pass

def remove_interpretable_embeddings(self, **kwargs) -> None:
"""Removes interpretable embeddings used for gradient attribution.
If the configure_interpretable_embeddings method is defined, this method needs to be defined
to allow restoring the original embeddings of the model. This is required for methods using the
@unhooked decorator, since they require the original model capabilities.
"""
pass


class HookableModelWrapper(torch.nn.Module):
"""Module to wrap the AttributionModel class
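A sketch of how a concrete model class might override the two new no-op hooks (hypothetical subclass; "embed_tokens" is an assumed layer name, and the other abstract methods are omitted for brevity):

from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

from inseq.models import AttributionModel


class CustomAttributionModel(AttributionModel):
    def configure_interpretable_embeddings(self, **kwargs) -> None:
        # Wrap the embedding layer so precomputed embeddings can be fed through forward().
        self.int_embeds = configure_interpretable_embedding_layer(self.model, "embed_tokens")

    def remove_interpretable_embeddings(self, **kwargs) -> None:
        # Restore the original embedding layer for @unhooked methods such as generate().
        remove_interpretable_embedding_layer(self.model, self.int_embeds)
        self.int_embeds = None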
63 changes: 2 additions & 61 deletions inseq/models/huggingface_model.py
@@ -1,11 +1,9 @@
""" HuggingFace Seq2seq model """
from typing import List, Literal, NoReturn, Optional, Tuple, Union, overload
from typing import List, NoReturn, Optional, Tuple, Union

import logging
import warnings

import torch
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
from torch import long
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.generation_utils import BeamSampleOutput, BeamSearchOutput, GreedySearchOutput, SampleOutput
@@ -118,26 +116,21 @@ def configure_embeddings_scale(self):
def score_func(
self,
encoder_tensors: AttributionForwardInputs,
decoder_tensors: AttributionForwardInputs,
decoder_embeds: AttributionForwardInputs,
encoder_attention_mask: Optional[IdsTensor] = None,
decoder_attention_mask: Optional[IdsTensor] = None,
use_embeddings: bool = True,
) -> FullLogitsTensor:
if use_embeddings:
encoder_embeds = encoder_tensors
decoder_embeds = decoder_tensors
encoder_ids = None
decoder_ids = None
else:
encoder_embeds = None
decoder_embeds = None
encoder_ids = encoder_tensors
decoder_ids = encoder_tensors
output = self.model(
input_ids=encoder_ids,
inputs_embeds=encoder_embeds,
attention_mask=encoder_attention_mask,
decoder_input_ids=decoder_ids,
decoder_inputs_embeds=decoder_embeds,
decoder_attention_mask=decoder_attention_mask,
)
@@ -147,26 +140,6 @@ def score_func(
logger.debug(f"logits: {pretty_tensor(logits)}")
return logits

@overload
@unhooked
def generate(
self,
encodings: Union[TextInput, BatchEncoding],
return_generation_output: Literal[False] = False,
**kwargs,
) -> List[str]:
...

@overload
@unhooked
def generate(
self,
encodings: Union[TextInput, BatchEncoding],
return_generation_output: Literal[True],
**kwargs,
) -> Tuple[List[str], GenerationOutput]:
...

@unhooked
def generate(
self,
@@ -316,35 +289,3 @@ def vocabulary_embeddings(self) -> VocabularyEmbeddingsTensor:

def get_embedding_layer(self) -> torch.nn.Module:
return self.model.get_encoder().embed_tokens

def configure_interpretable_embeddings(self, do_encoder: bool = True, do_decoder: bool = True) -> None:
"""Configure the model with interpretable embeddings for gradient attribution."""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
try:
if do_encoder:
encoder = self.model.get_encoder()
self.encoder_int_embeds = configure_interpretable_embedding_layer(encoder, "embed_tokens")
if do_decoder:
decoder = self.model.get_decoder()
self.decoder_int_embeds = configure_interpretable_embedding_layer(decoder, "embed_tokens")
except AssertionError:
logger.warn("Interpretable embeddings were already configured for layer embed_tokens")

def remove_interpretable_embeddings(self, do_encoder: bool = True, do_decoder: bool = True) -> None:
warn_msg = (
"Cannot remove interpretable embedding wrapper from {model}."
"No interpretable embedding layer was configured."
)
if do_encoder:
if not self.encoder_int_embeds:
logger.warn(warn_msg.format(model="encoder"))
encoder = self.model.get_encoder()
remove_interpretable_embedding_layer(encoder, self.encoder_int_embeds)
self.encoder_int_embeds = None
if do_decoder:
if not self.decoder_int_embeds:
logger.warn(warn_msg.format(model="decoder"))
decoder = self.model.get_decoder()
remove_interpretable_embedding_layer(decoder, self.decoder_int_embeds)
self.decoder_int_embeds = None
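The deleted @overload stubs earlier in this file used typing.Literal to tie generate's return type to the return_generation_output flag; a generic, self-contained sketch of that typing pattern (illustrative names, not inseq's API):

from typing import List, Literal, Tuple, Union, overload

@overload
def generate(text: str, return_output: Literal[False] = False) -> List[str]: ...
@overload
def generate(text: str, return_output: Literal[True]) -> Tuple[List[str], dict]: ...

def generate(text: str, return_output: bool = False) -> Union[List[str], Tuple[List[str], dict]]:
    decoded = [text.upper()]  # stand-in for real generation
    return (decoded, {"scores": None}) if return_output else decoded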
3 changes: 2 additions & 1 deletion inseq/utils/__init__.py
@@ -12,7 +12,7 @@
pretty_tensor,
rgetattr,
)
from .registry import Registry
from .registry import Registry, get_available_methods
from .torch_utils import euclidean_distance, logits2probs, remap_from_filtered, sum_normalize


@@ -26,6 +26,7 @@
"pretty_tensor",
"pretty_dict",
"rgetattr",
"get_available_methods",
"isnotebook",
"find_char_indexes",
"extract_signature_args",
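The newly exported get_available_methods is what list_feature_attribution_methods builds on; a rough standalone sketch of what such a registry helper can look like (an assumption about the implementation, not the actual code in inseq/utils/registry.py):

def get_available_methods(base_cls) -> list:
    """Collect the method_name identifiers of all (transitive) subclasses of base_cls."""
    methods = []
    for subclass in base_cls.__subclasses__():
        name = getattr(subclass, "method_name", None)
        if name:
            methods.append(name)
        methods.extend(get_available_methods(subclass))  # recurse into deeper subclasses
    return methods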
22 changes: 0 additions & 22 deletions tests/attr/feat/test_gradient_attribution.py

This file was deleted.
