Using assistant in AutomaticSpeechRecognitionPipeline with different encoder size #30637

Merged May 23, 2024

Changes from all commits (25 commits)
864df8d: fix input to generate in pipeline (kamilakesbi, May 2, 2024)
ff0c638: fixup (kamilakesbi, May 2, 2024)
749cfaa: pass input_features to generate with assistant (kamilakesbi, May 3, 2024)
f3011b0: error if model and assistant with different enc size (kamilakesbi, May 3, 2024)
404f67b: fix (kamilakesbi, May 3, 2024)
fd492a7: apply review suggestions (kamilakesbi, May 6, 2024)
e41d519: use self.config.is_encoder_decoder (kamilakesbi, May 10, 2024)
27242a6: pass inputs to generate directly (kamilakesbi, May 10, 2024)
726f53f: add slow tests (kamilakesbi, May 10, 2024)
c7f3f1c: Update src/transformers/generation/utils.py (kamilakesbi, May 10, 2024)
405606c: Update tests/pipelines/test_pipelines_automatic_speech_recognition.py (kamilakesbi, May 10, 2024)
2c8c039: Update tests/pipelines/test_pipelines_automatic_speech_recognition.py (kamilakesbi, May 10, 2024)
5b6f297: Update tests/pipelines/test_pipelines_automatic_speech_recognition.py (kamilakesbi, May 10, 2024)
03d2c3e: Update tests/pipelines/test_pipelines_automatic_speech_recognition.py (kamilakesbi, May 10, 2024)
f1c8c8a: Update tests/pipelines/test_pipelines_automatic_speech_recognition.py (kamilakesbi, May 10, 2024)
d1571a9: apply review (kamilakesbi, May 10, 2024)
83e17f6: Update src/transformers/generation/utils.py (kamilakesbi, May 15, 2024)
29046c6: Update tests/pipelines/test_pipelines_automatic_speech_recognition.py (kamilakesbi, May 15, 2024)
d216376: apply code review (kamilakesbi, May 15, 2024)
87b08e9: update attributes encoder_xyz to check (kamilakesbi, May 15, 2024)
a43e202: Update src/transformers/generation/utils.py (kamilakesbi, May 20, 2024)
b23f1f3: Update src/transformers/generation/utils.py (kamilakesbi, May 20, 2024)
5547aef: Update src/transformers/generation/utils.py (kamilakesbi, May 20, 2024)
e2bdde1: add slow test (kamilakesbi, May 21, 2024)
3a35145: solve conflicts (kamilakesbi, May 22, 2024)
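For context, a minimal usage sketch (not part of the diff) of what this PR enables: an assistant whose encoder is smaller than the main model's can now drive speculative decoding through the ASR pipeline. The checkpoints mirror the ones used in the slow tests below; "sample.wav" is a placeholder file name.

from transformers import AutoModelForSpeechSeq2Seq, pipeline

# Assistant with a smaller encoder than the main model. Before this PR the
# pipeline pre-computed encoder outputs with the main encoder only, so such
# an assistant could not be used.
assistant = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")

pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2")

# `input_features` are now forwarded to `generate`, so the assistant encodes
# the audio with its own encoder.
result = pipe("sample.wav", generate_kwargs={"assistant_model": assistant})
print(result["text"])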
src/transformers/generation/utils.py (20 changes: 20 additions & 0 deletions)

@@ -1097,6 +1097,25 @@ def _validate_model_class(self):
            exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
        raise TypeError(exception_message)

    def _validate_assistant(self, assistant_model):
        if assistant_model is None:
            return

        if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
            attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
            attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
            are_equal = all(
                getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
            )
            if not are_equal:
                raise ValueError(
                    "The main model and the assistant don't have compatible encoder-dependent input shapes. "
                    "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
                )

        if self.config.vocab_size != assistant_model.config.vocab_size:
            raise ValueError("Make sure the main and assistant model use the same tokenizer")

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # If a `Cache` instance is passed, checks whether the model is compatible with it
@@ -1547,6 +1566,7 @@ def generate(
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
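A minimal sketch (not part of the diff) of how the new check surfaces to users. It assumes Whisper's 80-mel, 3000-frame feature shape and relies on `WhisperForCausalLM` keeping the tiny encoder's dimensions in its config, which differ from large-v2's:

import torch
from transformers import AutoModelForCausalLM, AutoModelForSpeechSeq2Seq

main_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v2")

# whisper-tiny loaded as a decoder-only model: its config still carries the
# tiny encoder's `encoder_layers` / `encoder_attention_heads`, which do not
# match large-v2's, so `_validate_assistant` rejects the pair.
assistant = AutoModelForCausalLM.from_pretrained("openai/whisper-tiny")

dummy_features = torch.zeros(1, 80, 3000)  # placeholder input_features
try:
    main_model.generate(input_features=dummy_features, assistant_model=assistant)
except ValueError as err:
    print(err)  # "The main model and the assistant don't have compatible encoder-dependent input shapes. ..."

The error is raised before any forward pass, since `_validate_assistant` runs at the top of `generate`.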
src/transformers/pipelines/automatic_speech_recognition.py (8 changes: 1 addition & 7 deletions)

@@ -474,7 +474,6 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
raise ValueError("num_frames must be used only when stride is None")

if self.type in {"seq2seq", "seq2seq_whisper"}:
encoder = self.model.get_encoder()
# Consume values so we can let extra information flow freely through
# the pipeline (important for `partial` in microphone)
if "input_features" in model_inputs:
Expand All @@ -499,16 +498,11 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
else:
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]

else:
generate_kwargs["num_frames"] = num_frames

if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
else:
generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)

tokens = self.model.generate(
inputs=inputs,
attention_mask=attention_mask,
**generate_kwargs,
)
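A simplified before/after sketch (names abbreviated, not part of the diff) of why `_forward` changed:

# Before: the pipeline encoded the audio itself, so an assistant with a
# different encoder size could not reuse the main model's `encoder_outputs`:
#     encoder_outputs = self.model.get_encoder()(inputs, attention_mask=attention_mask)
#     tokens = self.model.generate(encoder_outputs=encoder_outputs, **generate_kwargs)
#
# After: the raw features go to `generate`, and each model, main or assistant,
# runs its own encoder internally:
#     tokens = self.model.generate(inputs=inputs, attention_mask=attention_mask, **generate_kwargs)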
tests/generation/test_utils.py (62 changes: 62 additions & 0 deletions)

@@ -45,6 +45,7 @@
    AutoModelForSeq2SeqLM,
    AutoModelForSpeechSeq2Seq,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    BartForCausalLM,
    BartForConditionalGeneration,
@@ -2919,6 +2920,67 @@ def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self):
        # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5
        self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)

    @slow
    def test_validate_assistant(self):
        # Generate a random audio sample:
        inputs = np.random.rand(160000)

        # Load a main encoder-decoder model:
        model_id = "openai/whisper-large-v2"
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        model.to(torch_device)

        # Process the input:
        features = processor(inputs, return_tensors="pt").to(torch_device)

        # Load an encoder-decoder assistant with the same encoder as the main model:
        assistant_distil_model_id = "distil-whisper/distil-large-v2"
        assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
            assistant_distil_model_id,
            use_safetensors=True,
        ).to(torch_device)
        self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())

        # Load its decoder-only version:
        assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
            assistant_distil_model_id,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(torch_device)
        self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum())

        # Load an encoder-decoder assistant with a different encoder than the main model:
        assistant_distil_model_id = "openai/whisper-tiny"
        assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
            assistant_distil_model_id,
            use_safetensors=True,
        ).to(torch_device)
        self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())

        # Load its decoder-only version:
        assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
            assistant_distil_model_id,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(torch_device)
        # This raises an error because the encoders of the main and assistant models are not compatible:
        with self.assertRaises(ValueError):
            model.generate(**features, assistant_model=assistant_causal_lm)

        # Load an encoder-decoder assistant with a different tokenizer than the main model:
        assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText"
        assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
            assistant_distil_model_id,
        ).to(torch_device)
        # This raises an error because the main and assistant models don't use the same tokenizer:
        with self.assertRaises(ValueError):
            model.generate(**features, assistant_model=assistant_seq_to_seq)
    def test_compare_unprocessed_logit_scores(self):
        # Get unprocessed logit scores back from the model's generate function.
        # Assert that unprocessed logits from generate() are the same as those from model.eval()
tests/pipelines/test_pipelines_automatic_speech_recognition.py (91 changes: 91 additions & 0 deletions)

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import unittest

import numpy as np
@@ -23,6 +24,8 @@
    MODEL_FOR_CTC_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    AutoFeatureExtractor,
    AutoModelForCausalLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    Speech2TextForConditionalGeneration,
@@ -1138,6 +1141,94 @@ def test_whisper_language(self):
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
)

    @slow
    def test_speculative_decoding_whisper_non_distil(self):
        # Load data:
        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
        sample = dataset[0]["audio"]

        # Load model:
        model_id = "openai/whisper-large-v2"
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            use_safetensors=True,
        )

        # Load assistant:
        assistant_model_id = "openai/whisper-tiny"
        assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            assistant_model_id,
            use_safetensors=True,
        )

        # Load pipeline:
        pipe = AutomaticSpeechRecognitionPipeline(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            generate_kwargs={"language": "en"},
        )

        start_time = time.time()
        transcription_assist = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
        total_time_assist = time.time() - start_time

        start_time = time.time()
        transcription_non_assist = pipe(sample)["text"]
        total_time_non_assist = time.time() - start_time

        self.assertEqual(transcription_assist, transcription_non_assist)
        self.assertEqual(
            transcription_assist,
            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
        )
        self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")

    @slow
    def test_speculative_decoding_whisper_distil(self):
        # Load data:
        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
        sample = dataset[0]["audio"]

        # Load model:
        model_id = "openai/whisper-large-v2"
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            use_safetensors=True,
        )

        # Load assistant:
        assistant_model_id = "distil-whisper/distil-large-v2"
        assistant_model = AutoModelForCausalLM.from_pretrained(
            assistant_model_id,
            use_safetensors=True,
        )

        # Load pipeline:
        pipe = AutomaticSpeechRecognitionPipeline(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            generate_kwargs={"language": "en"},
        )

        start_time = time.time()
        transcription_assist = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
        total_time_assist = time.time() - start_time

        start_time = time.time()
        transcription_non_assist = pipe(sample)["text"]
        total_time_non_assist = time.time() - start_time

        self.assertEqual(transcription_assist, transcription_non_assist)
        self.assertEqual(
            transcription_assist,
            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
        )
        self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")

@slow
@require_torch
@require_torchaudio