Skip to content

Commit

Permalink
Attribute eos default false, fix referenceless case (#125)
Browse files: browse the repository at this point in the history
  • Loading branch information
gsarti authored Mar 7, 2022
1 parent c77810b commit e0d22b9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
8 changes: 4 additions & 4 deletions inseq/attr/feat/feature_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def prepare_and_attribute(
output_step_attributions: bool = False,
attribute_target: bool = False,
output_step_probabilities: bool = False,
include_eos_baseline: bool = True,
include_eos_baseline: bool = False,
**kwargs,
) -> OneOrMoreFeatureAttributionSequenceOutputsWithStepOutputs:
r"""
Expand All @@ -179,7 +179,7 @@ def prepare_and_attribute(
output_step_probabilities (:obj:`bool`, optional): Whether to output the prediction probabilities for the
current generation step or not. Defaults to False.
include_eos_baseline (:obj:`bool`, `optional`): Whether to include the EOS token in the baseline for
attribution. By default the EOS token is used for attribution. Defaults to True.
attribution. By default the EOS token is not used for attribution. Defaults to False.
Returns:
:obj:`OneOrMoreFeatureAttributionSequenceOutputsWithStepOutputs`: One or more
Expand All @@ -205,7 +205,7 @@ def prepare(
sources: FeatureAttributionInput,
targets: FeatureAttributionInput,
prepend_bos_token: bool = True,
include_eos_baseline: bool = True,
include_eos_baseline: bool = False,
) -> EncoderDecoderBatch:
r"""
Prepares sources and target to produce an :class:`~inseq.data.EncoderDecoderBatch`.
Expand All @@ -228,7 +228,7 @@ def prepare(
prepend_bos_token (:obj:`bool`, `optional`): Whether to prepend a BOS token to the
targets, if they are to be encoded. Defaults to True.
include_eos_baseline (:obj:`bool`, `optional`): Whether to include the EOS token in the baseline for
attribution. By default the EOS token is used for attribution. Defaults to True.
attribution. By default the EOS token is not used for attribution. Defaults to False.
Returns:
:obj:`EncoderDecoderBatch`: An :class:`~inseq.data.EncoderDecoderBatch` object containing sources
Expand Down
4 changes: 2 additions & 2 deletions inseq/models/attribution_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def attribute(
output_step_attributions: bool = False,
attribute_target: bool = False,
output_step_probabilities: bool = False,
include_eos_baseline: bool = True,
include_eos_baseline: bool = False,
device: Optional[str] = None,
**kwargs,
) -> OneOrMoreFeatureAttributionSequenceOutputsWithStepOutputs:
Expand All @@ -125,7 +125,7 @@ def attribute(
self.device = device
texts, reference_texts = self.format_input_texts(texts, reference_texts)
if not reference_texts:
texts = self.encode(texts, return_baseline=True)
texts = self.encode(texts, return_baseline=True, include_eos_baseline=include_eos_baseline)
generation_args = kwargs.pop("generation_args", {})
reference_texts = self.generate(texts, return_generation_output=False, **generation_args)
logger.debug(f"reference_texts={reference_texts}")
Expand Down
8 changes: 5 additions & 3 deletions inseq/models/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, *tokenizer_inputs, **tokenizer_kwargs)
self.model_name = self.model.config.name_or_path
self.pad_id = self.model.config.pad_token_id
self.unk_id = self.tokenizer.unk_token_id
self.eos_id = self.model.config.eos_token_id
self.bos_id = self.model.config.decoder_start_token_id
self.unk_token = self.tokenizer.unk_token
self.pad_token = self.tokenizer.convert_ids_to_tokens(self.pad_id)
self.bos_token = self.tokenizer.convert_ids_to_tokens(self.bos_id)
self.encoder_embed_scale = 1.0
Expand Down Expand Up @@ -173,7 +175,7 @@ def encode(
as_targets: bool = False,
prepend_bos_token: bool = True,
return_baseline: bool = False,
include_eos_baseline: bool = True,
include_eos_baseline: bool = False,
) -> BatchEncoding:
"""Encode one or multiple texts, producing a BatchEncoding
Expand Down Expand Up @@ -201,9 +203,9 @@ def encode(
baseline_ids = None
if return_baseline:
if include_eos_baseline:
baseline_ids = torch.ones_like(batch["input_ids"]).long() * self.pad_id
baseline_ids = torch.ones_like(batch["input_ids"]).long() * self.unk_id
else:
baseline_ids = batch["input_ids"].ne(self.eos_id).long() * self.pad_id
baseline_ids = batch["input_ids"].ne(self.eos_id).long() * self.unk_id
# We prepend a BOS token only when tokenizing target texts.
if as_targets and prepend_bos_token:
ones_mask = torch.ones((batch["input_ids"].shape[0], 1), device=self.device, dtype=long)
Expand Down

0 comments on commit e0d22b9

Please sign in to comment.