From 2fd49d2b2833029907d1c34c439b7c99a8a19dd1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 4 Oct 2024 21:47:08 +0100 Subject: [PATCH] Cache: revert DynamicCache init for BC (#33861) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tmp commit * tmp commit * make fixup * missing removal * fix condition * fix end-to-end compilation * if -> elif * BC * BC * use @deprecate_kwarg("num_hidden_layers", version="4.47.0") * wups the import * 🥴 --------- Co-authored-by: Arthur Zucker --- tests/generation/test_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e4844cd4aed82c..7cd8c35d2ace6f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1849,11 +1849,18 @@ def test_new_cache_format(self, num_beams, do_sample): if config.is_encoder_decoder: cache_cls = EncoderDecoderCache past_key_values = cache_cls(DynamicCache(), DynamicCache()) + past_key_values = cache_cls(DynamicCache(), DynamicCache()) else: cache_cls = DynamicCache past_key_values = cache_cls() - new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict) + new_results = model.generate( + input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + **generation_kwargs, + **inputs_dict, + ) # The two sets of generated sequences must match, despite the cache format between forward passes being # different