
Commit

Custom alignment of contrast_targets for contrastive attribution methods (#195)

* Add support for alignment of contrast targets

* Fix attribution positions

* Update deps

* Alignment utils and tests, todo auto align

* Started auto align logic

* Auto align working, tests missing

* Add tests for auto align
gsarti authored Jun 30, 2023
1 parent 77aa4fc commit c0cc551
Showing 21 changed files with 1,404 additions and 622 deletions.
52 changes: 15 additions & 37 deletions inseq/attr/feat/attribution_utils.py
@@ -4,7 +4,7 @@

import torch

from ...utils import extract_signature_args
from ...utils import extract_signature_args, get_aligned_idx
from ...utils.typing import (
OneOrMoreAttributionSequences,
OneOrMoreIdSequences,
@@ -13,7 +13,7 @@
TextInput,
TokenWithId,
)
from ..step_functions import STEP_SCORES_MAP
from ..step_functions import get_step_scores_args

if TYPE_CHECKING:
from ...models import AttributionModel
@@ -87,36 +87,29 @@ def check_attribute_positions(
return attr_pos_start, attr_pos_end


def get_step_scores(
score_identifier: str = "probability",
step_scores_args: Dict[str, Any] = {},
) -> SingleScorePerStepTensor:
"""Returns step scores for the target tokens in the batch."""
if score_identifier not in STEP_SCORES_MAP:
raise AttributeError(
f"Step score {score_identifier} not found. Available step scores are: "
f"{', '.join(list(STEP_SCORES_MAP.keys()))}. Use the inseq.register_step_function"
"function to register a custom step score."
)
return STEP_SCORES_MAP[score_identifier](**step_scores_args)


def join_token_ids(
tokens: OneOrMoreTokenSequences,
ids: OneOrMoreIdSequences,
contrast_tokens: Optional[OneOrMoreTokenSequences] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
) -> List[TokenWithId]:
"""Joins tokens and ids into a list of TokenWithId objects."""
if contrast_tokens is None:
contrast_tokens = tokens
# 1:1 alignment between target and contrast tokens
if contrast_targets_alignments is None:
contrast_targets_alignments = [[(idx, idx) for idx, _ in enumerate(seq)] for seq in tokens]
sequences = []
for target_tokens_seq, contrast_target_tokens_seq, input_ids_seq in zip(tokens, contrast_tokens, ids):
for target_tokens_seq, contrast_target_tokens_seq, input_ids_seq, alignments_seq in zip(
tokens, contrast_tokens, ids, contrast_targets_alignments
):
curr_seq = []
for token, contrast_token, idx in zip(target_tokens_seq, contrast_target_tokens_seq, input_ids_seq):
if token != contrast_token:
curr_seq.append(TokenWithId(f"{contrast_token}{token}", -1))
for pos_idx, (token, token_idx) in enumerate(zip(target_tokens_seq, input_ids_seq)):
contrast_pos_idx = get_aligned_idx(pos_idx, alignments_seq)
if token != contrast_target_tokens_seq[contrast_pos_idx]:
curr_seq.append(TokenWithId(f"{contrast_target_tokens_seq[contrast_pos_idx]}{token}", -1))
else:
curr_seq.append(TokenWithId(token, idx))
curr_seq.append(TokenWithId(token, token_idx))
sequences.append(curr_seq)
return sequences

@@ -135,22 +128,7 @@ def extract_args(
extra_attributed_fn_args, attributed_fn_unused_args = extract_signature_args(
kwargs, attributed_fn, exclude_args=default_args, return_remaining=True
)
extra_step_scores_args = {}
for step_score in step_scores:
if step_score not in STEP_SCORES_MAP:
raise AttributeError(
f"Step score {step_score} not found. Available step scores are: "
f"{', '.join(list(STEP_SCORES_MAP.keys()))}. Use the inseq.register_step_function"
"function to register a custom step score."
)
extra_step_scores_args.update(
**extract_signature_args(
kwargs,
STEP_SCORES_MAP[step_score],
exclude_args=default_args,
return_remaining=False,
)
)
extra_step_scores_args = get_step_scores_args(step_scores, kwargs, default_args)
step_scores_unused_args = {k: v for k, v in kwargs.items() if k not in extra_step_scores_args}
unused_args = {
k: v
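For reference, the alignment lookup now used by join_token_ids is get_aligned_idx, added to inseq/utils in this PR but not shown in these hunks. A minimal standalone sketch of the assumed semantics (identity mapping when no alignments are given, otherwise a lookup of the (target_idx, contrast_idx) pair for the requested position; the fallback and error behavior here are assumptions, not taken from the diff):

from typing import List, Optional, Tuple

def get_aligned_idx_sketch(pos_idx: int, alignments: Optional[List[Tuple[int, int]]] = None) -> int:
    # No alignments provided: assume a 1:1 mapping between target and contrast positions.
    if not alignments:
        return pos_idx
    # Otherwise return the contrast index paired with this target position.
    for tgt_idx, contrast_idx in alignments:
        if tgt_idx == pos_idx:
            return contrast_idx
    raise ValueError(f"No alignment found for target position {pos_idx}")

# Target position 1 aligned to contrast position 2 (e.g. after an inserted token in the contrast).
assert get_aligned_idx_sketch(1, [(0, 0), (1, 2)]) == 2
assert get_aligned_idx_sketch(1, None) == 1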
48 changes: 39 additions & 9 deletions inseq/attr/feat/feature_attribution.py
@@ -31,6 +31,7 @@
FeatureAttributionOutput,
FeatureAttributionSequenceOutput,
FeatureAttributionStepOutput,
get_batch_from_inputs,
)
from ...data.viz import close_progress_bar, get_progress_bar, update_progress_bar
from ...utils import (
@@ -44,7 +45,12 @@
)
from ...utils.typing import ModelIdentifier, SingleScorePerStepTensor
from ..attribution_decorators import batched, set_hook, unset_hook
from .attribution_utils import check_attribute_positions, get_source_target_attributions, get_step_scores, tok2string
from ..step_functions import get_step_scores
from .attribution_utils import (
check_attribute_positions,
get_source_target_attributions,
tok2string,
)

if TYPE_CHECKING:
from ...models import AttributionModel
@@ -302,13 +308,36 @@ def attribute(
# Sources are empty for decoder-only models
sequences = self.attribution_model.formatter.get_text_sequences(self.attribution_model, batch)
contrast_targets = attributed_fn_args.get("contrast_targets", None)
contrast_targets_alignments = attributed_fn_args.get("contrast_targets_alignments", None)
contrast_targets = [contrast_targets] if isinstance(contrast_targets, str) else contrast_targets
target_tokens_with_ids = self.attribution_model.tokenize_with_ids(
sequences.targets,
as_targets=True,
skip_special_tokens=False,
contrast_inputs=contrast_targets,
contrast_batch = None
if contrast_targets is not None:
as_targets = self.attribution_model.is_encoder_decoder
contrast_batch = get_batch_from_inputs(
attribution_model=self.attribution_model,
inputs=contrast_targets,
as_targets=as_targets,
)
contrast_batch = DecoderOnlyBatch.from_batch(contrast_batch)
contrast_targets_alignments = self.attribution_model.formatter.format_contrast_targets_alignments(
contrast_targets_alignments=contrast_targets_alignments,
target_sequences=sequences.targets,
target_tokens=self.attribution_model.clean_tokens(batch.target_tokens, as_targets=as_targets),
contrast_sequences=contrast_targets,
contrast_tokens=self.attribution_model.clean_tokens(
contrast_batch.target_tokens, as_targets=as_targets
),
special_tokens=self.attribution_model.special_tokens,
)
attributed_fn_args["contrast_targets_alignments"] = contrast_targets_alignments
if "contrast_targets" in step_scores_args:
step_scores_args["contrast_targets_alignments"] = contrast_targets_alignments
target_tokens_with_ids = self.attribution_model.get_token_with_ids(
batch,
contrast_target_tokens=contrast_batch.target_tokens if contrast_batch is not None else None,
contrast_targets_alignments=contrast_targets_alignments,
)

# Manages front padding for decoder-only models, using 0 as lower bound
# when attr_pos_start exceeds target length.
targets_lengths = [
@@ -359,15 +388,16 @@
batch[:step],
self.attribution_model.convert_ids_to_tokens(tgt_ids.unsqueeze(1), skip_special_tokens=False),
tgt_ids.detach().to("cpu"),
attributed_fn_args,
contrast_batch=contrast_batch,
contrast_targets_alignments=contrast_targets_alignments,
)
attribution_outputs.append(step_output)
if pretty_progress:
tgt_tokens = batch.target_tokens
skipped_prefixes = tok2string(self.attribution_model, tgt_tokens, end=attr_pos_start)
attributed_sentences = tok2string(self.attribution_model, tgt_tokens, attr_pos_start, step + 1)
unattributed_suffixes = tok2string(self.attribution_model, tgt_tokens, step + 1, iter_pos_end)
skipped_suffixes = tok2string(self.attribution_model, tgt_tokens, start=iter_pos_end)
unattributed_suffixes = tok2string(self.attribution_model, tgt_tokens, step + 1, attr_pos_end)
skipped_suffixes = tok2string(self.attribution_model, tgt_tokens, start=attr_pos_end)
update_progress_bar(
pbar,
skipped_prefixes,
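With this wiring, contrast_targets_alignments becomes a regular attributed-fn argument of attribute(), forwarded both to the contrastive step functions and to token display. A hedged usage sketch: the model, texts and alignment pairs are illustrative only, "contrast_prob_diff" is assumed to be the identifier under which contrast_prob_diff_fn is registered, and the automatic alignment mentioned in the commit messages is handled inside format_contrast_targets_alignments, which is not shown in these hunks.

import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "input_x_gradient")
out = model.attribute(
    "The director met the interns.",
    "Le directeur a rencontré les stagiaires.",
    attributed_fn="contrast_prob_diff",
    contrast_targets="La directrice a rencontré les stagiaires.",
    # (original_target_idx, contrast_target_idx) pairs over the tokenized targets;
    # the values below are placeholders, not a real subword-level alignment.
    contrast_targets_alignments=[(0, 0), (1, 1), (2, 2), (3, 3)],
)
out.show()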
68 changes: 61 additions & 7 deletions inseq/attr/step_functions.py
@@ -1,14 +1,15 @@
import logging
from inspect import getfullargspec
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol, Tuple, Union

import torch
from torch.nn.functional import kl_div, log_softmax
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import ModelOutput

from ..data import DecoderOnlyBatch, FeatureAttributionInput, get_batch_from_inputs
from ..data import DecoderOnlyBatch, FeatureAttributionInput, get_batch_from_inputs, slice_batch_from_position
from ..data.aggregation_functions import DEFAULT_ATTRIBUTION_AGGREGATE_DICT
from ..utils import extract_signature_args
from ..utils.typing import EmbeddingsTensor, IdsTensor, SingleScorePerStepTensor, TargetIdsTensor

if TYPE_CHECKING:
@@ -101,6 +102,7 @@ def _get_contrast_output(
contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
return_contrastive_target_ids: bool = False,
) -> ModelOutput:
"""Utility function to return the output of the model for given contrastive inputs.
@@ -115,6 +117,11 @@
contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
target text. If not specified, the original target text is used as contrastive target (will result in same
output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`.
contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
first element is the index of the original target token and the second element is the index of the
contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
available tokens. Defaults to :obj:`None`.
return_contrastive_target_ids (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to return the
contrastive target ids as well as the model output. Defaults to :obj:`False`.
"""
@@ -128,11 +135,9 @@
)
)
curr_prefix_len = decoder_input_ids.size(1)

# We select the next contrastive token as target and truncate contrastive ids
# and their attention map to the current generation step.
c_tgt_ids = c_batch.target_ids[:, curr_prefix_len]
c_batch = c_batch[:curr_prefix_len].to(attribution_model.device)
if len(contrast_targets_alignments) > 0 and isinstance(contrast_targets_alignments[0], list):
contrast_targets_alignments = contrast_targets_alignments[0]
c_batch, c_tgt_ids = slice_batch_from_position(c_batch, curr_prefix_len, contrast_targets_alignments)

if decoder_input_ids.size(0) != c_batch.target_ids.size(0):
raise ValueError(
@@ -194,6 +199,7 @@ def contrast_prob_fn(
contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
**kwargs,
):
"""Returns the probability of a generation target given contrastive context or target prediction alternative.
@@ -211,13 +217,19 @@
contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
target text. If not specified, the original target text is used as contrastive target (will result in same
output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`.
contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
first element is the index of the original target token and the second element is the index of the
contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
available tokens. Defaults to :obj:`None`.
"""
kwargs.pop("forward_output", None)
c_output, c_tgt_ids = _get_contrast_output(
attribution_model=attribution_model,
contrast_sources=contrast_sources,
contrast_target_prefixes=contrast_target_prefixes,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
return_contrastive_target_ids=True,
**kwargs,
)
@@ -317,6 +329,7 @@ def contrast_prob_diff_fn(
contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
**kwargs,
):
"""Returns the difference between next step probability for a candidate generation target vs. a contrastive
@@ -336,6 +349,11 @@
contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
target text. If not specified, the original target text is used as contrastive target (will result in same
output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`.
contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
first element is the index of the original target token and the second element is the index of the
contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
available tokens. Defaults to :obj:`None`.
"""
model_probs = probability_fn(attribution_model, forward_output, target_ids)
contrast_probs = contrast_prob_fn(
@@ -344,6 +362,7 @@
contrast_sources=contrast_sources,
contrast_target_prefixes=contrast_target_prefixes,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
**kwargs,
)
# Return the prob difference as target for attribution
@@ -418,6 +437,41 @@ def mc_dropout_prob_avg_fn(
}


def check_is_step_function(identifier: str) -> None:
if identifier not in STEP_SCORES_MAP:
raise AttributeError(
f"Step score {identifier} not found. Available step scores are: "
f"{', '.join(list(STEP_SCORES_MAP.keys()))}. Use the inseq.register_step_function"
"function to register a custom step score."
)


def get_step_scores(
score_identifier: str = "probability",
step_scores_args: Dict[str, Any] = {},
) -> SingleScorePerStepTensor:
"""Returns step scores for the target tokens in the batch."""
check_is_step_function(score_identifier)
return STEP_SCORES_MAP[score_identifier](**step_scores_args)


def get_step_scores_args(
score_identifiers: List[str], kwargs: Dict[str, Any], default_args: Dict[str, Any]
) -> Dict[str, Any]:
step_scores_args = {}
for step_score in score_identifiers:
check_is_step_function(step_score)
step_scores_args.update(
**extract_signature_args(
kwargs,
STEP_SCORES_MAP[step_score],
exclude_args=default_args,
return_remaining=False,
)
)
return step_scores_args


def list_step_functions() -> List[str]:
"""Lists identifiers for all available step scores. One or more step scores identifiers can be passed to the
:meth:`~inseq.models.AttributionModel.attribute` method either to compute scores while attributing (``step_scores``
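get_step_scores and the argument extraction that previously lived in attribution_utils now sit next to the step functions themselves. A self-contained sketch of the filtering idea behind get_step_scores_args, mirroring what extract_signature_args is used for (function and argument names below are illustrative, not the library's API):

import inspect
from typing import Any, Callable, Dict, Sequence

def filter_step_fn_kwargs(
    fns: Sequence[Callable], kwargs: Dict[str, Any], exclude: Sequence[str] = ()
) -> Dict[str, Any]:
    # Keep only kwargs that at least one requested step function declares in its
    # signature, excluding the default arguments supplied by the attribution loop.
    selected: Dict[str, Any] = {}
    for fn in fns:
        params = inspect.signature(fn).parameters
        selected.update({k: v for k, v in kwargs.items() if k in params and k not in exclude})
    return selected

def contrast_prob_diff(attribution_model, forward_output, target_ids,
                       contrast_targets=None, contrast_targets_alignments=None, **kwargs):
    ...

extra = filter_step_fn_kwargs(
    [contrast_prob_diff],
    {"contrast_targets": "La directrice ...", "contrast_targets_alignments": [(0, 0)], "typo_arg": 1},
    exclude=("attribution_model", "forward_output", "target_ids"),
)
assert set(extra) == {"contrast_targets", "contrast_targets_alignments"}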
10 changes: 9 additions & 1 deletion inseq/data/__init__.py
@@ -22,7 +22,14 @@
MultiDimensionalFeatureAttributionStepOutput,
get_batch_from_inputs,
)
from .batch import Batch, BatchEmbedding, BatchEncoding, DecoderOnlyBatch, EncoderDecoderBatch
from .batch import (
Batch,
BatchEmbedding,
BatchEncoding,
DecoderOnlyBatch,
EncoderDecoderBatch,
slice_batch_from_position,
)
from .viz import show_attributions

__all__ = [
@@ -54,4 +61,5 @@
"MultiDimensionalFeatureAttributionStepOutput",
"get_batch_from_inputs",
"list_aggregators",
"slice_batch_from_position",
]
5 changes: 5 additions & 0 deletions inseq/data/attribution.py
@@ -164,6 +164,11 @@ def from_step_attributions(
]
if tokenized_target_sentences is None:
tokenized_target_sentences = targets
if has_bos_token:
tokenized_target_sentences = [tok_seq[1:] for tok_seq in tokenized_target_sentences]
tokenized_target_sentences = [
drop_padding(tokenized_target_sentences[seq_id], pad_id) for seq_id in range(num_sequences)
]
if attr_pos_end is None:
attr_pos_end = max([len(t) for t in tokenized_target_sentences])
pos_start = [
9 changes: 9 additions & 0 deletions inseq/data/batch.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

from ..utils import get_aligned_idx
from ..utils.typing import EmbeddingsTensor, ExpandedTargetIdsTensor, IdsTensor, OneOrMoreTokenSequences
from .data_utils import TensorWrapper

@@ -229,3 +230,11 @@ def from_batch(self, batch: Batch) -> "DecoderOnlyBatch":
encoding=batch.encoding,
embedding=batch.embedding,
)


def slice_batch_from_position(
batch: DecoderOnlyBatch, curr_idx: int, alignments: Optional[List[Tuple[int, int]]] = None
) -> Tuple[DecoderOnlyBatch, IdsTensor]:
truncate_idx = get_aligned_idx(curr_idx, alignments)
tgt_ids = batch.target_ids[:, truncate_idx]
return batch[:truncate_idx], tgt_ids
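slice_batch_from_position replaces the fixed c_batch[:curr_prefix_len] truncation in _get_contrast_output: the current prefix length on the original target side is first remapped to the contrast side before slicing. A small illustration of the intended indexing (tensors and alignment values are made up, and the real function operates on a DecoderOnlyBatch rather than a bare tensor):

import torch

contrast_ids = torch.arange(10).unsqueeze(0)             # (1, 10) contrastive target ids, illustrative
alignments = [(0, 0), (1, 1), (2, 2), (3, 4), (4, 6)]    # original position -> contrast position

curr_idx = 4                                              # prefix length of the original target at this step
truncate_idx = dict(alignments)[curr_idx]                 # 6, what get_aligned_idx would return here
tgt_ids = contrast_ids[:, truncate_idx]                   # contrastive target id compared at this step
contrast_prefix = contrast_ids[:, :truncate_idx]          # contrast prefix kept for the forward pass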