diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ab8e6019062b78..149ce144e66272 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1097,6 +1097,25 @@ def _validate_model_class(self): exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}" raise TypeError(exception_message) + def _validate_assistant(self, assistant_model): + if assistant_model is None: + return + + if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: + attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] + attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] + are_equal = all( + getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check + ) + if not are_equal: + raise ValueError( + "The main model and the assistant don't have compatible encoder-dependent input shapes. " + "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." + ) + + if not self.config.vocab_size == assistant_model.config.vocab_size: + raise ValueError("Make sure the main and assistant model use the same tokenizer") + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" # If a `Cache` instance is passed, checks whether the model is compatible with it @@ -1547,6 +1566,7 @@ def generate( tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) self._validate_model_kwargs(model_kwargs.copy()) + self._validate_assistant(assistant_model) # 2. Set generation parameters if not already defined if synced_gpus is None: diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 2e8682d96a65e0..01faab6d74adac 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -474,7 +474,6 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): raise ValueError("num_frames must be used only when stride is None") if self.type in {"seq2seq", "seq2seq_whisper"}: - encoder = self.model.get_encoder() # Consume values so we can let extra information flow freely through # the pipeline (important for `partial` in microphone) if "input_features" in model_inputs: @@ -499,16 +498,11 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length else: generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride] - else: generate_kwargs["num_frames"] = num_frames - if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames: - generate_kwargs["input_features"] = inputs - else: - generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask) - tokens = self.model.generate( + inputs=inputs, attention_mask=attention_mask, **generate_kwargs, ) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b8e90a5b8ed18e..840b64e17db010 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -45,6 +45,7 @@ AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, + AutoProcessor, AutoTokenizer, BartForCausalLM, BartForConditionalGeneration, @@ -2919,6 +2920,67 @@ def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(sel # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5 self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5) + @slow + def test_validate_assistant(self): + # Generate a random sample: + inputs = np.random.rand(160000) + + # Load a main encoder-decoder model: + model_id = "openai/whisper-large-v2" + processor = AutoProcessor.from_pretrained(model_id) + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, + low_cpu_mem_usage=True, + use_safetensors=True, + ) + model.to(torch_device) + + # process the input: + features = processor(inputs, return_tensors="pt").to(torch_device) + + # Load an encoder-decoder assistant with same encoder as the main model: + assistant_distil_model_id = "distil-whisper/distil-large-v2" + assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( + assistant_distil_model_id, + use_safetensors=True, + ).to(torch_device) + self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum()) + + # Load its decoder only version: + assistant_causal_lm = AutoModelForCausalLM.from_pretrained( + assistant_distil_model_id, + low_cpu_mem_usage=True, + use_safetensors=True, + ).to(torch_device) + self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum()) + + # Load an encoder-decoder assistant with a different encoder than the main model: + assistant_distil_model_id = "openai/whisper-tiny" + assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( + assistant_distil_model_id, + use_safetensors=True, + ).to(torch_device) + self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum()) + + # Load its decoder only version: + assistant_causal_lm = AutoModelForCausalLM.from_pretrained( + assistant_distil_model_id, + low_cpu_mem_usage=True, + use_safetensors=True, + ).to(torch_device) + # It will raise an error as the encoder of the main and assistant model are not compatible: + with self.assertRaises(ValueError): + model.generate(**features, assistant_model=assistant_causal_lm) + + # Load an encoder-decoder model with a different tokenizer than the main model: + assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText" + assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( + assistant_distil_model_id, + ).to(torch_device) + # This should raise an error as the main and assistant model don't use the same tokenizer: + with self.assertRaises(ValueError): + model.generate(**features, assistant_model=assistant_seq_to_seq) + def test_compare_unprocessed_logit_scores(self): # Get unprocessed logit scores back from model generate function. # Assert that unprocessed logits from generate() are same as those from modal eval() diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 5ab18e81d56854..430666990fe5c2 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time import unittest import numpy as np @@ -23,6 +24,8 @@ MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, AutoFeatureExtractor, + AutoModelForCausalLM, + AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, Speech2TextForConditionalGeneration, @@ -1138,6 +1141,94 @@ def test_whisper_language(self): {"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."}, ) + @slow + def test_speculative_decoding_whisper_non_distil(self): + # Load data: + dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]") + sample = dataset[0]["audio"] + + # Load model: + model_id = "openai/whisper-large-v2" + processor = AutoProcessor.from_pretrained(model_id) + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, + use_safetensors=True, + ) + + # Load assistant: + assistant_model_id = "openai/whisper-tiny" + assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained( + assistant_model_id, + use_safetensors=True, + ) + + # Load pipeline: + pipe = AutomaticSpeechRecognitionPipeline( + model=model, + tokenizer=processor.tokenizer, + feature_extractor=processor.feature_extractor, + generate_kwargs={"language": "en"}, + ) + + start_time = time.time() + transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"] + total_time_assist = time.time() - start_time + + start_time = time.time() + transcription_ass = pipe(sample)["text"] + total_time_non_assist = time.time() - start_time + + self.assertEqual(transcription_ass, transcription_non_ass) + self.assertEqual( + transcription_ass, + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + ) + self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster") + + @slow + def test_speculative_decoding_whisper_distil(self): + # Load data: + dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]") + sample = dataset[0]["audio"] + + # Load model: + model_id = "openai/whisper-large-v2" + processor = AutoProcessor.from_pretrained(model_id) + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, + use_safetensors=True, + ) + + # Load assistant: + assistant_model_id = "distil-whisper/distil-large-v2" + assistant_model = AutoModelForCausalLM.from_pretrained( + assistant_model_id, + use_safetensors=True, + ) + + # Load pipeline: + pipe = AutomaticSpeechRecognitionPipeline( + model=model, + tokenizer=processor.tokenizer, + feature_extractor=processor.feature_extractor, + generate_kwargs={"language": "en"}, + ) + + start_time = time.time() + transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"] + total_time_assist = time.time() - start_time + + start_time = time.time() + transcription_ass = pipe(sample)["text"] + total_time_non_assist = time.time() - start_time + + self.assertEqual(transcription_ass, transcription_non_ass) + self.assertEqual( + transcription_ass, + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + ) + self.assertEqual(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster") + @slow @require_torch @require_torchaudio