From 3695349f09512a199d5d9f31871951d5b2ba9878 Mon Sep 17 00:00:00 2001
From: raushan
Date: Thu, 6 Jun 2024 13:01:48 +0200
Subject: [PATCH 01/10] fix

---
 src/transformers/generation/utils.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index c6819090892594..30e205da7c4e68 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -3658,6 +3658,7 @@ def _assisted_decoding(
         """
         # init values
         do_sample = logits_warper is not None
+        eos_token_id = generation_config.eos_token_id
         output_attentions = generation_config.output_attentions
         output_hidden_states = generation_config.output_hidden_states
         output_scores = generation_config.output_scores
@@ -3696,11 +3697,22 @@
 
             # 1. Fetch candidate sequences from a `CandidateGenerator`
             candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
-            candidate_input_ids = candidate_input_ids.to(self.device)
             if candidate_logits is not None:
                 candidate_logits = candidate_logits.to(self.device)
-
             candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
+
+            # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
+            # accept eos and the rest as valid, thus not stopping generation after "eos"
+            # NOTE: below code is written based on the fact that assisted decoding supports only bs=1
+            mask = torch.isin(candidate_input_ids[:, -candidate_length:], eos_token_id)
+            match_indices = torch.nonzero(mask, as_tuple=True)[1]
+            if match_indices.numel() > 0:
+                first_eos_index = match_indices[0].item() + input_ids.shape[1]
+                candidate_input_ids = candidate_input_ids[:, :first_eos_index]
+                candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
+                if candidate_logits is not None:
+                    candidate_logits = candidate_logits[:, :first_eos_index]
+            candidate_input_ids = candidate_input_ids.to(self.device)
             is_done_candidate = stopping_criteria(candidate_input_ids, None)
 
             # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
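For readers skimming the series, the hunk above boils down to: look only at the freshly drafted part of `candidate_input_ids`, find the first EOS in it, and cut the candidate sequence there so the target model is never asked to validate draft tokens that come after an EOS. The snippet below is a minimal standalone sketch of that idea; the `truncate_at_eos` helper and the toy tensors are invented for illustration, the real logic lives inline in `_assisted_decoding` and, as the NOTE in the diff says, assumes batch size 1.

```python
import torch

def truncate_at_eos(candidate_input_ids, prompt_length, eos_token_id):
    # Look only at the drafted continuation, i.e. everything past the original prompt.
    drafted = candidate_input_ids[:, prompt_length:]
    mask = torch.isin(drafted, eos_token_id)
    match_indices = torch.nonzero(mask, as_tuple=True)[1]
    if match_indices.numel() > 0:
        # Absolute position of the first drafted EOS; slicing up to (not including) it
        # drops the EOS and everything after, mirroring the patch above.
        first_eos_index = match_indices[0].item() + prompt_length
        candidate_input_ids = candidate_input_ids[:, :first_eos_index]
    return candidate_input_ids

prompt = torch.tensor([[11, 12, 13]])                # original input_ids, bs=1
draft = torch.tensor([[11, 12, 13, 14, 0, 15, 16]])  # drafted tokens run past EOS (id 0)
print(truncate_at_eos(draft, prompt.shape[1], torch.tensor([0])))
# tensor([[11, 12, 13, 14]]) -- the target model will emit the EOS itself and stop there
```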
From f4c45de2a10f35aa1cdeba1e7af1fd19287707b0 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 10 Jul 2024 08:26:29 +0200
Subject: [PATCH 02/10] move changes to prompt lookup

---
 .../generation/candidate_generator.py | 11 +++++++++++
 src/transformers/generation/utils.py  | 16 ++--------------
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index e735d0a2ca7f5a..c35c606b89464a 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -267,6 +267,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
 
     def __init__(
         self,
+        eos_token_id: torch.Tensor = None,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
         max_length: int = 20,
@@ -274,6 +275,7 @@
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
         self.max_length = max_length
+        self.eos_token_id = eos_token_id
 
         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -319,6 +321,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
                 if start_idx < end_idx:
                     chosen_ids = input_ids[0, start_idx:end_idx]
                     match_found = True
+
+                    # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
+                    # accept eos and the rest as valid, thus not stopping generation after "eos"
+                    # NOTE: below code is written based on the fact that assisted decoding supports only bs=1
+                    mask = torch.isin(chosen_ids, self.eos_token_id)
+                    match_indices_eos = torch.nonzero(mask)
+                    if match_indices_eos.numel() > 0:
+                        first_eos_index = match_indices_eos[0].item()
+                        chosen_ids = chosen_ids[:first_eos_index]
                     break
             if match_found:
                 break
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 61cf82cd985f44..c3f72d3377c44e 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -716,6 +716,7 @@ def _get_candidate_generator(
         """
         if generation_config.prompt_lookup_num_tokens is not None:
            candidate_generator = PromptLookupCandidateGenerator(
+                eos_token_id=generation_config.eos_token_id,
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
                 max_matching_ngram_size=generation_config.max_matching_ngram_size,
                 max_length=generation_config.max_length,
@@ -3712,7 +3713,6 @@ def _assisted_decoding(
         """
         # init values
         do_sample = logits_warper is not None
-        eos_token_id = generation_config.eos_token_id
         output_attentions = generation_config.output_attentions
         output_hidden_states = generation_config.output_hidden_states
         output_scores = generation_config.output_scores
@@ -3756,20 +3756,8 @@ def _assisted_decoding(
             candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
             if candidate_logits is not None:
                 candidate_logits = candidate_logits.to(self.device)
-            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
 
-            # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
-            # accept eos and the rest as valid, thus not stopping generation after "eos"
-            # NOTE: below code is written based on the fact that assisted decoding supports only bs=1
-            mask = torch.isin(candidate_input_ids[:, -candidate_length:], eos_token_id)
-            match_indices = torch.nonzero(mask, as_tuple=True)[1]
-            if match_indices.numel() > 0:
-                first_eos_index = match_indices[0].item() + input_ids.shape[1]
-                candidate_input_ids = candidate_input_ids[:, :first_eos_index]
-                candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
-                if candidate_logits is not None:
-                    candidate_logits = candidate_logits[:, :first_eos_index]
-            candidate_input_ids = candidate_input_ids.to(self.device)
+            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
             is_done_candidate = stopping_criteria(candidate_input_ids, None)
 
             # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
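With the truncation moved into `PromptLookupCandidateGenerator`, the user-visible contract is simply that prompt lookup decoding stops at EOS exactly like plain decoding does. Below is a rough usage sketch of that contract; the checkpoint name, prompt and generation lengths are placeholders chosen for the example, not taken from the PR.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Plain greedy decoding vs. greedy decoding with prompt lookup candidates enabled.
out_greedy = model.generate(**inputs, do_sample=False, max_new_tokens=40)
out_pld = model.generate(**inputs, do_sample=False, max_new_tokens=40, prompt_lookup_num_tokens=10)

# With the fix, the two outputs should agree, including where generation stops after an EOS.
print(tokenizer.batch_decode(out_greedy) == tokenizer.batch_decode(out_pld))
```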
From 49f7ca31d5b8ce6847a6562b495eeb8c1022bbb9 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 10 Jul 2024 08:43:32 +0200
Subject: [PATCH 03/10] add test

---
 tests/generation/test_utils.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 8fa41fbdbe2b07..862933cfc19bf6 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -77,6 +77,7 @@
         MaxLengthCriteria,
         MinLengthLogitsProcessor,
         PhrasalConstraint,
+        PromptLookupCandidateGenerator,
         SampleDecoderOnlyOutput,
         SampleEncoderDecoderOutput,
         StoppingCriteria,
@@ -1323,6 +1324,24 @@
 
             self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
 
+    def test_prompt_lookup_decoding_stops_at_eos(self):
+        # This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens
+        # (see https://github.com/huggingface/transformers/pull/31301)
+
+        input_ids = torch.randint(1, 50, (1, 10), device=torch_device)  # generate inputs in range from 1-50
+        input_ids[:, 4] = 0  # arbitrarily inject eos-token-id in input ids so that PLD can copy it
+        input_ids[:, -1] = input_ids[:, 3]  # put pre-eos token in the end for match to happen
+        eos_token_id = torch.tensor([0], device=torch_device)
+
+        # init cand generator with max_matching_ngram_size=1 to match per-token
+        cand_generator = PromptLookupCandidateGenerator(
+            eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
+        )
+        output_prompt_lookup = cand_generator.get_candidates(input_ids)[0]
+        self.assertTrue(
+            output_prompt_lookup.shape[-1] == 10
+        )  # PLD shouldn't propose any new tokens based on eos-match
+
     def test_generate_with_head_masking(self):
         """Test designed for encoder-decoder models to ensure the attention head masking is used."""
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]

From 5a8ac7356a4ac6769855f0ee7caa21b1ee104696 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 10 Jul 2024 08:50:24 +0200
Subject: [PATCH 04/10] set eos in assistant model

---
 src/transformers/generation/candidate_generator.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index c35c606b89464a..3bc4a0de22c670 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -107,6 +107,9 @@ def __init__(
         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
         self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
+
+        # Set eos in assistant same as in target model
+        self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
 
         # Prepare the kwargs for the assistant model
         assistant_kwargs = {}
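Patch 04 is a one-liner, but worth spelling out: the `generation_config` passed to `AssistedCandidateGenerator.__init__` belongs to the target model, so copying its `eos_token_id` onto the assistant's own `generation_config` keeps both models agreeing on what counts as end-of-sequence; presumably this stops the draft model from drafting past a token the target treats as EOS. A hedged sketch of the same idea done by hand follows — the checkpoints are placeholders, and after this patch the candidate generator performs the assignment itself.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
target = AutoModelForCausalLM.from_pretrained("gpt2-medium")
assistant = AutoModelForCausalLM.from_pretrained("gpt2")

# What the patched AssistedCandidateGenerator.__init__ now does internally:
assistant.generation_config.eos_token_id = target.generation_config.eos_token_id

inputs = tokenizer("The quick brown fox", return_tensors="pt")
out = target.generate(**inputs, assistant_model=assistant, max_new_tokens=40)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```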
From 92fe7b2559f133558b51db80b32473bb82f94062 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 10 Jul 2024 08:50:41 +0200
Subject: [PATCH 05/10] style

---
 src/transformers/generation/candidate_generator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 3bc4a0de22c670..d662b866e73671 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -107,8 +107,8 @@ def __init__(
         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
         self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
-        
-        # Set eos in assistant same as in target model 
+
+        # Set eos in assistant same as in target model
         self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
 
         # Prepare the kwargs for the assistant model
         assistant_kwargs = {}

From b102cd8d5a07631e477087816be4349756f2342a Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 10 Jul 2024 09:09:53 +0200
Subject: [PATCH 06/10] fix flakiness

---
 tests/generation/test_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 862933cfc19bf6..97620556853410 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1330,7 +1330,9 @@ def test_prompt_lookup_decoding_stops_at_eos(self):
 
         input_ids = torch.randint(1, 50, (1, 10), device=torch_device)  # generate inputs in range from 1-50
         input_ids[:, 4] = 0  # arbitrarily inject eos-token-id in input ids so that PLD can copy it
-        input_ids[:, -1] = input_ids[:, 3]  # put pre-eos token in the end for match to happen
+        pre_eos_token = 51
+        input_ids[:, 3] = pre_eos_token  # set pre-eos to arbitrary id which is for sure not present in inputs
+        input_ids[:, -1] = pre_eos_token  # put pre-eos token in the end for the necessary match to happen
         eos_token_id = torch.tensor([0], device=torch_device)
 
         # init cand generator with max_matching_ngram_size=1 to match per-token
         cand_generator = PromptLookupCandidateGenerator(

From 4020372d3adc1047aa7dc8e6aeec6e9eedd1a1c2 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 24 Jul 2024 09:42:15 +0200
Subject: [PATCH 07/10] changes for new `main`

---
 src/transformers/generation/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d2e40f4e56e7ff..b7bfeaf40d8c89 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -725,7 +725,7 @@ def _get_candidate_generator(
         """
         if generation_config.prompt_lookup_num_tokens is not None:
             candidate_generator = PromptLookupCandidateGenerator(
-                eos_token_id=generation_config.eos_token_id,
+                eos_token_id=generation_config._eos_token_tensor,
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
                 max_matching_ngram_size=generation_config.max_matching_ngram_size,
                 max_length=generation_config.max_length,
From 1ea3eb8ac41a8d49ea13f118c3bc20bcc0c4247e Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Wed, 24 Jul 2024 17:40:45 +0500
Subject: [PATCH 08/10] Update tests/generation/test_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 tests/generation/test_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 8d75aceda6afe6..62443b5bc55edd 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1385,7 +1385,7 @@ def test_prompt_lookup_decoding_stops_at_eos(self):
         eos_token_id = torch.tensor([0], device=torch_device)
 
         # init cand generator with max_matching_ngram_size=1 to match per-token
-        cand_generator = PromptLookupCandidateGenerator(
+        candidate_generator = PromptLookupCandidateGenerator(
             eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
         )
         output_prompt_lookup = cand_generator.get_candidates(input_ids)[0]

From fd95ff37362b4023f40a8b8e98a12502a3c123f1 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Wed, 24 Jul 2024 17:40:55 +0500
Subject: [PATCH 09/10] Update tests/generation/test_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 tests/generation/test_utils.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 62443b5bc55edd..a97f3b878bcf74 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1389,9 +1389,8 @@ def test_prompt_lookup_decoding_stops_at_eos(self):
             eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
         )
         output_prompt_lookup = cand_generator.get_candidates(input_ids)[0]
-        self.assertTrue(
-            output_prompt_lookup.shape[-1] == 10
-        )  # PLD shouldn't propose any new tokens based on eos-match
+        # PLD shouldn't propose any new tokens based on eos-match
+        self.assertTrue(output_prompt_lookup.shape[-1] == 10)
 
     def test_generate_with_head_masking(self):
         """Test designed for encoder-decoder models to ensure the attention head masking is used."""
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
From 8df32caa0a97fa78cacbc42d98ea5a9dc03135bf Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 24 Jul 2024 15:01:06 +0200
Subject: [PATCH 10/10] add comment to explain

---
 tests/generation/test_utils.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index a97f3b878bcf74..2c440bbd71ae66 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1377,18 +1377,27 @@ def test_prompt_lookup_decoding_stops_at_eos(self):
         # This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens
         # (see https://github.com/huggingface/transformers/pull/31301)
+        # The main idea is to have an ngram (unigram in our case) that is repeated twice in the input ids.
+        # First time at the very end, so input ends with the unigram, and second at an arbitrary location.
+        # Also, we need an EOS token which will be injected just after the arbitrarily located ngram.
+        # We verify that PLD will not copy and propose candidates that contain an EOS token, even if there are overlapping ngrams
+        # in input ids. Otherwise a proposed EOS along with the trailing (ngram_size - 1) tokens might be accepted by the target model.
+        # That seems as if the model "generated" an EOS but didn't stop from the user's perspective.
 
         input_ids = torch.randint(1, 50, (1, 10), device=torch_device)  # generate inputs in range from 1-50
-        input_ids[:, 4] = 0  # arbitrarily inject eos-token-id in input ids so that PLD can copy it
-        pre_eos_token = 51
-        input_ids[:, 3] = pre_eos_token  # set pre-eos to arbitrary id which is for sure not present in inputs
-        input_ids[:, -1] = pre_eos_token  # put pre-eos token in the end for the necessary match to happen
+        arbitrary_ngram = 51  # this is the arbitrary ngram, specifically chosen OOV to prevent flaky tests
+        input_ids[:, 3] = arbitrary_ngram  # set pre-eos to arbitrary_ngram which is for sure not present in inputs
+        input_ids[:, -1] = arbitrary_ngram  # put arbitrary_ngram in the end for the necessary match to happen
+
         eos_token_id = torch.tensor([0], device=torch_device)
+        input_ids[:, 4] = eos_token_id  # inject eos-token-id in input ids so that it is located after arbitrary_ngram
 
         # init cand generator with max_matching_ngram_size=1 to match per-token
         candidate_generator = PromptLookupCandidateGenerator(
             eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
         )
-        output_prompt_lookup = cand_generator.get_candidates(input_ids)[0]
+        output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0]
+
         # PLD shouldn't propose any new tokens based on eos-match
         self.assertTrue(output_prompt_lookup.shape[-1] == 10)
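To make the test's token layout concrete, here is a hedged walk-through of what it constructs and why the assertion holds. The `naive_candidates` slice only illustrates what a plain unigram copy would have proposed before the fix; it is not library code. The generator is imported directly from its defining module (the file patched earlier in this series).

```python
import torch
from transformers.generation.candidate_generator import PromptLookupCandidateGenerator

input_ids = torch.randint(1, 50, (1, 10))  # random tokens in [1, 50)
input_ids[:, 3] = 51    # out-of-vocabulary unigram at an arbitrary position
input_ids[:, -1] = 51   # ...and again at the very end, so PLD finds a match
eos_token_id = torch.tensor([0])
input_ids[:, 4] = eos_token_id  # EOS sits right after the first occurrence

# Pre-fix behaviour (illustration only): a unigram match at position 3 would copy the
# continuation starting at position 4, i.e. the EOS plus the ordinary tokens after it.
naive_candidates = input_ids[0, 4:8]
print(naive_candidates)  # tensor([0, ...]) -- starts with EOS

# Post-fix behaviour: the generator truncates at the EOS, ends up with nothing left to
# propose, and returns the prompt unchanged -- hence shape[-1] == 10 in the test.
candidate_generator = PromptLookupCandidateGenerator(
    eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
)
candidates, _ = candidate_generator.get_candidates(input_ids)
print(candidates.shape)  # torch.Size([1, 10])
```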