[Qwen2Audio] handle input ids expansion during processing (#35534)
* add audio_token attribute to proc

* expand input_ids

* and legacy and expanded input_ids

* test update

* split lines

* add possibility not to provide eos and bos audio tokens

* raise errors

* test incorrect number of audio tokens

* add example

* fmt

* typo
eustlb authored Jan 7, 2025
1 parent 628cd83 commit 7f76773
Showing 4 changed files with 159 additions and 48 deletions.
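
The heart of the change: instead of leaving a single `<|AUDIO|>` placeholder for the model to merge with audio features later, the processor now repeats the placeholder once per encoder output frame, so the model can splice audio embeddings straight into the text embeddings. A minimal, hypothetical sketch of that expansion (the token strings, helper name, and counts below are illustrative, not taken from the diff):

```python
# Hypothetical sketch: one "<|AUDIO|>" placeholder becomes N audio tokens,
# where N depends on the clip's feature length; bos/eos are added if missing.
def expand_audio_placeholder(
    prompt: str,
    num_audio_tokens: int,
    audio_token: str = "<|AUDIO|>",
    bos: str = "<|audio_bos|>",
    eos: str = "<|audio_eos|>",
) -> str:
    expanded = audio_token * num_audio_tokens
    # Simplified check: wrap with bos/eos only if the prompt did not already provide them
    if bos + audio_token + eos not in prompt:
        expanded = bos + expanded + eos
    return prompt.replace(audio_token, expanded, 1)


print(expand_audio_placeholder("<|AUDIO|>Generate the caption in English:", 3))
# <|audio_bos|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|>Generate the caption in English:
```
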
31 changes: 31 additions & 0 deletions docs/source/en/model_doc/qwen2_audio.md
@@ -34,6 +34,37 @@ The abstract from the paper is the following:

`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)

### Inference

```python
from io import BytesIO
from urllib.request import urlopen
import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True, device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True)

prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
audio, sr = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)

generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]

response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

# We can also omit the audio_bos and audio_eos tokens
prompt = "<|AUDIO|>Generate the caption in English:"
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)

generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]

response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
```

In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for inference, supporting both voice chat and audio analysis modes. Note that we use the ChatML format for dialogue; in this demo we show how to leverage `apply_chat_template` for this purpose.
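
Before the mode-specific examples, here is a minimal, hypothetical sketch of the `apply_chat_template` flow (the conversation structure follows the instruct model's ChatML convention; the audio URL and generation settings are illustrative):

```python
from io import BytesIO
from urllib.request import urlopen
import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": audio_url},
            {"type": "text", "text": "What's that sound?"},
        ],
    },
]

# Render the ChatML prompt, then let the processor expand the audio placeholder
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audio, _ = librosa.load(BytesIO(urlopen(audio_url).read()), sr=processor.feature_extractor.sampling_rate)
inputs = processor(text=text, audios=[audio], return_tensors="pt").to(model.device)

generate_ids = model.generate(**inputs, max_new_tokens=64)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
```
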

### Voice Chat Inference
31 changes: 28 additions & 3 deletions src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
Original file line number Diff line number Diff line change
@@ -1197,9 +1197,34 @@ def forward(
selected_audio_feature = audio_outputs.last_hidden_state
audio_features = self.multi_modal_projector(selected_audio_feature)

inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
)
# if we have consecutive audio tokens, then it means we expanded input_ids in processing
audio_tokens = input_ids == self.config.audio_token_index
legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0

if legacy_processing:
logger.warning_once(
"Expanding inputs for audio tokens in Qwen2Audio should be done in processing."
)
inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
)
else:
num_audios, max_audio_tokens, embed_dim = audio_features.shape
audio_features_mask = torch.arange(max_audio_tokens, device=audio_output_lengths.device)[None, :]
audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
audio_features = audio_features[audio_features_mask]

n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
n_audio_features = audio_features.shape[0]

if n_audio_tokens != n_audio_features:
raise ValueError(
f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
)
special_audio_mask = (input_ids == self.config.audio_token_index).to(inputs_embeds.device)
special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)

outputs = self.language_model(
attention_mask=attention_mask,
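To make the non-legacy path above concrete, here is a standalone sketch of the `masked_scatter` step with toy shapes (the dimensions and token id are made up, not the model's real configuration):

```python
import torch

embed_dim = 4
audio_token_index = 99

# Toy batch: positions 2-4 of the sequence hold the audio placeholder token
input_ids = torch.tensor([[10, 11, 99, 99, 99, 12]])
inputs_embeds = torch.zeros(1, 6, embed_dim)
audio_features = torch.arange(3 * embed_dim, dtype=torch.float).reshape(3, embed_dim)

# Broadcast the placeholder mask over the embedding dimension, then scatter the
# flattened audio features into exactly those positions, row by row
special_audio_mask = (input_ids == audio_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)

print(inputs_embeds[0, 2:5])  # the three rows now equal the audio feature vectors
```
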
72 changes: 70 additions & 2 deletions src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -40,15 +40,32 @@ class Qwen2AudioProcessor(ProcessorMixin):
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used.
audio_token (`str`, *optional*, defaults to `"<|AUDIO|>"`):
The token to use for audio tokens.
audio_bos_token (`str`, *optional*, defaults to `"<|audio_bos|>"`):
The token to use for audio bos tokens.
audio_eos_token (`str`, *optional*, defaults to `"<|audio_eos|>"`):
The token to use for audio eos tokens.
"""

attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "WhisperFeatureExtractor"
tokenizer_class = "AutoTokenizer"

def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
def __init__(
self,
feature_extractor=None,
tokenizer=None,
chat_template=None,
audio_token="<|AUDIO|>",
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
):
if chat_template is None:
chat_template = self.default_chat_template
self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
self.audio_bos_token = tokenizer.audio_bos_token if hasattr(tokenizer, "audio_bos_token") else audio_bos_token
self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)

def __call__(
@@ -88,7 +105,18 @@ def __call__(

if text is None:
raise ValueError("You need to specify either a `text` input to process.")
inputs = self.tokenizer(text, padding=padding, **kwargs)
elif isinstance(text, str):
text = [text]
elif not isinstance(text, list) or not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

# ensure we have as many audio inputs as audio tokens
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
num_audios = 1 if type(audios) == np.ndarray else len(audios)
if num_audio_tokens != num_audios:
raise ValueError(
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
)

if audios is not None:
audio_inputs = self.feature_extractor(
@@ -97,6 +125,46 @@
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask"
) # rename attention_mask to prevent conflicts later on

expanded_text = []
audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()

for sample in text:
replace_str = []
while self.audio_token in sample:
audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1

expanded_audio_token = self.audio_token * num_audio_tokens

audio_token_start_idx = sample.find(self.audio_token)
audio_token_end_idx = audio_token_start_idx + len(self.audio_token)

has_bos = (
sample[audio_token_start_idx - len(self.audio_bos_token) : audio_token_start_idx]
== self.audio_bos_token
)
has_eos = (
sample[audio_token_end_idx : audio_token_end_idx + len(self.audio_eos_token)]
== self.audio_eos_token
)

# Check if this audio token is surrounded by bos/eos tokens
if not has_bos and not has_eos:
expanded_audio_token = self.audio_bos_token + expanded_audio_token + self.audio_eos_token

replace_str.append(expanded_audio_token)
sample = sample.replace(self.audio_token, "<placeholder>", 1)

while "<placeholder>" in sample:
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
expanded_text.append(sample)
text = expanded_text

inputs = self.tokenizer(text, padding=padding, **kwargs)

if audios is not None:
inputs.update(audio_inputs)

return BatchFeature(data={**inputs})
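A worked example of the length arithmetic in the expansion loop above, with an assumed frame count chosen to match the 101-token expectation in the updated test:

```python
# Assumed value: ~4 s of audio kept by feature_attention_mask (404 mel frames)
audio_length = 404
input_length = (audio_length - 1) // 2 + 1      # 202, after the first stride-2 downsampling in the audio encoder
num_audio_tokens = (input_length - 2) // 2 + 1  # 101, after the second stride-2 downsampling (pooling)
print(num_audio_tokens)  # 101 -> "<|AUDIO|>" is repeated 101 times in the prompt
```
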
73 changes: 30 additions & 43 deletions tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -49,7 +49,7 @@ def __init__(
parent,
ignore_index=-100,
audio_token_index=0,
seq_length=7,
seq_length=25,
feat_seq_length=60,
text_config={
"model_type": "qwen2",
@@ -91,7 +91,7 @@ def __init__(
self.is_training = is_training

self.batch_size = 3
self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1
self.encoder_seq_length = seq_length

def get_config(self):
return Qwen2AudioConfig(
@@ -116,11 +116,13 @@ def prepare_config_and_inputs(self):
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_features_values, feature_attention_mask = config_and_inputs
input_length = (input_features_values.shape[-1] - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 audios, let's make sure we pass in audio tokens for each of them
input_ids[:, 1] = config.audio_token_index
input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
inputs_dict = {
"input_features": input_features_values,
"feature_attention_mask": feature_attention_mask,
@@ -237,54 +239,39 @@ def test_small_model_integration_test_single(self):

output = model.generate(**inputs, max_new_tokens=32)

EXPECTED_INPUT_IDS = torch.tensor(
[
[
151644,
8948,
198,
2610,
525,
264,
10950,
17847,
13,
151645,
198,
151644,
872,
198,
14755,
220,
16,
25,
220,
151647,
151646,
151648,
198,
3838,
594,
429,
5112,
30,
151645,
198,
151644,
77091,
198,
]
]
)
# fmt: off
EXPECTED_INPUT_IDS = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 101,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
EXPECTED_DECODED_TEXT = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
+ "<|AUDIO|>" * 101
+ "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
)

self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=False),
EXPECTED_DECODED_TEXT,
)

# test the error when incorrect number of audio tokens
# fmt: off
inputs["input_ids"] = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 200,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
with self.assertRaisesRegex(
ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
):
model.generate(**inputs, max_new_tokens=32)

@slow
def test_small_model_integration_test_batch(self):
# Let's make sure we test the preprocessing to replace what is used
