Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attribute eos default false, fix referenceless case #125

Merged
merged 3 commits into from
Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions inseq/attr/feat/feature_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def prepare_and_attribute(
output_step_attributions: bool = False,
attribute_target: bool = False,
output_step_probabilities: bool = False,
- include_eos_baseline: bool = True,
+ include_eos_baseline: bool = False,
**kwargs,
) -> OneOrMoreFeatureAttributionSequenceOutputsWithStepOutputs:
r"""
Expand All @@ -179,7 +179,7 @@ def prepare_and_attribute(
output_step_probabilities (:obj:`bool`, optional): Whether to output the prediction probabilities for the
current generation step or not. Defaults to False.
include_eos_baseline (:obj:`bool`, `optional`): Whether to include the EOS token in the baseline for
- attribution. By default the EOS token is used for attribution. Defaults to True.
+ attribution. By default the EOS token is not used for attribution. Defaults to False.

Returns:
:obj:`OneOrMoreFeatureAttributionSequenceOutputsWithStepOutputs`: One or more
Expand All @@ -205,7 +205,7 @@ def prepare(
sources: FeatureAttributionInput,
targets: FeatureAttributionInput,
prepend_bos_token: bool = True,
- include_eos_baseline: bool = True,
+ include_eos_baseline: bool = False,
) -> EncoderDecoderBatch:
r"""
Prepares sources and target to produce an :class:`~inseq.data.EncoderDecoderBatch`.
Expand All @@ -228,7 +228,7 @@ def prepare(
prepend_bos_token (:obj:`bool`, `optional`): Whether to prepend a BOS token to the
targets, if they are to be encoded. Defaults to True.
include_eos_baseline (:obj:`bool`, `optional`): Whether to include the EOS token in the baseline for
- attribution. By default the EOS token is used for attribution. Defaults to True.
+ attribution. By default the EOS token is not used for attribution. Defaults to False.

Returns:
:obj:`EncoderDecoderBatch`: An :class:`~inseq.data.EncoderDecoderBatch` object containing sources
Expand Down
4 changes: 2 additions & 2 deletions inseq/models/attribution_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def attribute(
output_step_attributions: bool = False,
attribute_target: bool = False,
output_step_probabilities: bool = False,
- include_eos_baseline: bool = True,
+ include_eos_baseline: bool = False,
device: Optional[str] = None,
**kwargs,
) -> OneOrMoreFeatureAttributionSequenceOutputsWithStepOutputs:
Expand All @@ -125,7 +125,7 @@ def attribute(
self.device = device
texts, reference_texts = self.format_input_texts(texts, reference_texts)
if not reference_texts:
- texts = self.encode(texts, return_baseline=True)
+ texts = self.encode(texts, return_baseline=True, include_eos_baseline=include_eos_baseline)
generation_args = kwargs.pop("generation_args", {})
reference_texts = self.generate(texts, return_generation_output=False, **generation_args)
logger.debug(f"reference_texts={reference_texts}")
Expand Down
8 changes: 5 additions & 3 deletions inseq/models/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, *tokenizer_inputs, **tokenizer_kwargs)
self.model_name = self.model.config.name_or_path
self.pad_id = self.model.config.pad_token_id
+ self.unk_id = self.tokenizer.unk_token_id
self.eos_id = self.model.config.eos_token_id
self.bos_id = self.model.config.decoder_start_token_id
+ self.unk_token = self.tokenizer.unk_token
self.pad_token = self.tokenizer.convert_ids_to_tokens(self.pad_id)
self.bos_token = self.tokenizer.convert_ids_to_tokens(self.bos_id)
self.encoder_embed_scale = 1.0
Expand Down Expand Up @@ -173,7 +175,7 @@ def encode(
as_targets: bool = False,
prepend_bos_token: bool = True,
return_baseline: bool = False,
- include_eos_baseline: bool = True,
+ include_eos_baseline: bool = False,
) -> BatchEncoding:
"""Encode one or multiple texts, producing a BatchEncoding

Expand Down Expand Up @@ -201,9 +203,9 @@ def encode(
baseline_ids = None
if return_baseline:
if include_eos_baseline:
- baseline_ids = torch.ones_like(batch["input_ids"]).long() * self.pad_id
+ baseline_ids = torch.ones_like(batch["input_ids"]).long() * self.unk_id
else:
- baseline_ids = batch["input_ids"].ne(self.eos_id).long() * self.pad_id
+ baseline_ids = batch["input_ids"].ne(self.eos_id).long() * self.unk_id
# We prepend a BOS token only when tokenizing target texts.
if as_targets and prepend_bos_token:
ones_mask = torch.ones((batch["input_ids"].shape[0], 1), device=self.device, dtype=long)
Expand Down