
Commit

fix tensors on different devices in WhisperGenerationMixin (huggingface#32316)

* fix

* enable on xpu

* no manual remove

* move to device

* remove to

* add move to
faaany authored and alaskar-10r committed Aug 13, 2024
1 parent a0d14cd commit 38e9401
Showing 1 changed file with 2 additions and 3 deletions.
@@ -1513,7 +1513,7 @@ def test_with_local_lm_fast(self):
     def test_whisper_prompted(self):
         processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
-        model = model.to("cuda")
+        model = model.to(torch_device)

         pipe = pipeline(
             "automatic-speech-recognition",
@@ -1523,15 +1523,14 @@ def test_whisper_prompted(self):
             max_new_tokens=128,
             chunk_length_s=30,
             batch_size=16,
-            device="cuda:0",
         )

         dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
         sample = dataset[0]["audio"]

         # prompt the model to misspell "Mr Quilter" as "Mr Quillter"
         whisper_prompt = "Mr. Quillter."
-        prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt")
+        prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)

         unprompted_result = pipe(sample.copy())["text"]
         prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]
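The change above drops the hardcoded "cuda" placement in favor of the test suite's torch_device, so the test can also run on XPU or CPU, and moves prompt_ids to the same device as the model to avoid "tensors on different devices" errors at generate time. A minimal sketch of that device-agnostic pattern follows; the get_torch_device helper is a hypothetical stand-in for the torch_device value that transformers' testing utilities provide, and a plain Linear module stands in for the Whisper model.

```python
import torch

def get_torch_device() -> str:
    # Resolve the best available accelerator at runtime instead of
    # hardcoding "cuda" (hypothetical stand-in for transformers' torch_device).
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

torch_device = get_torch_device()

# Move both the model weights and any extra input tensors (e.g. prompt_ids)
# to the same device so generation never mixes devices.
model = torch.nn.Linear(4, 2).to(torch_device)
prompt_ids = torch.tensor([1, 2, 3]).to(torch_device)

assert model.weight.device.type == prompt_ids.device.type
```

The key point of the fix is that every tensor handed to the pipeline lives on the same device as the model, whatever that device turns out to be.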
