Pipeline's "num_return_sequences" greater than 1 causes a runtime error with Gemma-2-9B #31965

Closed · OsamaS99 opened this issue Jul 15, 2024 · 6 comments · Fixed by #32163
OsamaS99 (Contributor) commented Jul 15, 2024

System Info

  • transformers version: 4.43.0.dev0
  • Platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.36
  • Python version: 3.11.9
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.32.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA RTX A6000

Who can help?

@ArthurZucker @Narsil

Reproduction

(Screenshot of the reproduction code; the code is posted in a comment below.)

  1. Load Gemma-2-9B into a pipeline.
  2. Set num_return_sequences to any value greater than 1.

Expected behavior

"name": "RuntimeError",
"message": "shape mismatch: value tensor of shape [2, 8, 2, 256] cannot be broadcast to indexing result of shape [1, 8, 2, 256]",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 32
     22 tokenizer.padding_side = "left"
     24 pl: TextGenerationPipeline = transformers.pipeline(
     25     "text-generation",
     26     model=model,
   (...)
     29     device_map="auto",
     30 )
---> 32 sequences_list = pl(
     33     "hello",
     34     do_sample=True,
     35     batch_size=16,
     36     top_k=10,
     37     top_p=0.9,
     38     num_return_sequences=2,
     39     repetition_penalty=1.1,
     40     max_new_tokens=1024,
     41 )

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/pipelines/text_generation.py:262, in TextGenerationPipeline.__call__(self, text_inputs, **kwargs)
    260         return super().__call__(chats, **kwargs)
    261 else:
--> 262     return super().__call__(text_inputs, **kwargs)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/pipelines/base.py:1254, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1246     return next(
   1247         iter(
   1248             self.get_iterator(
   (...)
   1251         )
   1252     )
   1253 else:
-> 1254     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/pipelines/base.py:1261, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
   1259 def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
   1260     model_inputs = self.preprocess(inputs, **preprocess_params)
-> 1261     model_outputs = self.forward(model_inputs, **forward_params)
   1262     outputs = self.postprocess(model_outputs, **postprocess_params)
   1263     return outputs

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/pipelines/base.py:1161, in Pipeline.forward(self, model_inputs, **forward_params)
   1159     with inference_context():
   1160         model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1161         model_outputs = self._forward(model_inputs, **forward_params)
   1162         model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
   1163 else:

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/pipelines/text_generation.py:351, in TextGenerationPipeline._forward(self, model_inputs, **generate_kwargs)
    348         generate_kwargs["min_length"] += prefix_length
    350 # BS x SL
--> 351 generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
    352 out_b = generated_sequence.shape[0]
    353 if self.framework == \"pt\":

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/generation/utils.py:1969, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1961     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1962         input_ids=input_ids,
   1963         expand_size=generation_config.num_return_sequences,
   1964         is_encoder_decoder=self.config.is_encoder_decoder,
   1965         **model_kwargs,
   1966     )
   1968     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1969     result = self._sample(
   1970         input_ids,
   1971         logits_processor=prepared_logits_processor,
   1972         logits_warper=prepared_logits_warper,
   1973         stopping_criteria=prepared_stopping_criteria,
   1974         generation_config=generation_config,
   1975         synced_gpus=synced_gpus,
   1976         streamer=streamer,
   1977         **model_kwargs,
   1978     )
   1980 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   1981     # 11. prepare logits warper
   1982     prepared_logits_warper = (
   1983         self._get_logits_warper(generation_config, device=input_ids.device)
   1984         if generation_config.do_sample
   1985         else None
   1986     )

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/generation/utils.py:2912, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2909 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   2911 # forward pass to get next token
-> 2912 outputs = self(**model_inputs, return_dict=True)
   2914 if synced_gpus and this_peer_finished:
   2915     continue  # don't waste resources running the code we don't need

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py:944, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    941 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    943 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 944 outputs = self.model(
    945     input_ids=input_ids,
    946     attention_mask=attention_mask,
    947     position_ids=position_ids,
    948     past_key_values=past_key_values,
    949     inputs_embeds=inputs_embeds,
    950     use_cache=use_cache,
    951     output_attentions=output_attentions,
    952     output_hidden_states=output_hidden_states,
    953     return_dict=return_dict,
    954     cache_position=cache_position,
    955 )
    957 hidden_states = outputs[0]
    958 logits = self.lm_head(hidden_states)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py:784, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    773     layer_outputs = self._gradient_checkpointing_func(
    774         decoder_layer.__call__,
    775         hidden_states,
   (...)
    781         cache_position,
    782     )
    783 else:
--> 784     layer_outputs = decoder_layer(
    785         hidden_states,
    786         attention_mask=causal_mask,
    787         position_ids=position_ids,
    788         past_key_value=past_key_values,
    789         output_attentions=output_attentions,
    790         use_cache=use_cache,
    791         cache_position=cache_position,
    792     )
    794 hidden_states = layer_outputs[0]
    796 if output_attentions:

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py:526, in Gemma2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    523 hidden_states = self.input_layernorm(hidden_states)
    525 # Self Attention
--> 526 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    527     hidden_states=hidden_states,
    528     attention_mask=attention_mask,
    529     position_ids=position_ids,
    530     past_key_value=past_key_value,
    531     output_attentions=output_attentions,
    532     use_cache=use_cache,
    533     cache_position=cache_position,
    534 )
    535 hidden_states = self.post_attention_layernorm(hidden_states)
    536 hidden_states = residual + hidden_states

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py:236, in Gemma2Attention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    228 if past_key_value is not None:
    229     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    230     cache_kwargs = {
    231         "sin": sin,
    232         "cos": cos,
    233         "sliding_window": self.sliding_window,
    234         "cache_position": cache_position,
    235     }
--> 236     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    238 key_states = repeat_kv(key_states, self.num_key_value_groups)
    239 value_states = repeat_kv(value_states, self.num_key_value_groups)

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/cache_utils.py:1228, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1225 else:
   1226     update_fn = self._static_update
-> 1228 return update_fn(
   1229     cache_position,
   1230     layer_idx,
   1231     key_states,
   1232     value_states,
   1233     k_out,
   1234     v_out,
   1235     k_out.shape[2],
   1236 )

File ~/miniconda3/envs/st-ft/lib/python3.11/site-packages/transformers/cache_utils.py:1192, in HybridCache._sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
   1189 k_out = k_out[:, :, indices]
   1190 v_out = v_out[:, :, indices]
-> 1192 k_out[:, :, cache_position] = key_states
   1193 v_out[:, :, cache_position] = value_states
   1194 # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)

RuntimeError: shape mismatch: value tensor of shape [2, 8, 2, 256] cannot be broadcast to indexing result of shape [1, 8, 2, 256]
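
For reference, the failing assignment can be reproduced in isolation with plain PyTorch (a minimal sketch; the shapes are taken from the error message above, the cache length here is arbitrary):

import torch

k_out = torch.zeros(1, 8, 16, 256)       # cache slice preallocated for batch size 1
key_states = torch.zeros(2, 8, 2, 256)   # key states after expansion to batch size 2
cache_position = torch.tensor([0, 1])

# Mirrors "k_out[:, :, cache_position] = key_states" in HybridCache._sliding_update
k_out[:, :, cache_position] = key_states
# RuntimeError: shape mismatch: value tensor of shape [2, 8, 2, 256] cannot be
# broadcast to indexing result of shape [1, 8, 2, 256]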
ArthurZucker (Collaborator) commented:

Hey! Could you share a reproducer as well as the output of transformers-cli env?

OsamaS99 (Contributor, Author) commented:

Hello @ArthurZucker,
I have updated my issue.

ArthurZucker (Collaborator) commented:

Thanks! Would you mind copy-pasting the code rather than a screenshot? 🤗

OsamaS99 (Contributor, Author) commented Jul 16, 2024

Sure, here it is.

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.pipelines.text_generation import TextGenerationPipeline
import torch

# Placeholder: the original snippet assumes a Hugging Face access token defined elsewhere.
token = "<your Hugging Face access token>"

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
    trust_remote_code=True,
    token=token,
)

tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2-9b-it",
    trust_remote_code=True,
    token=token,
)
tokenizer.pad_token = tokenizer.bos_token
tokenizer.padding_side = "left"

pl: TextGenerationPipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

sequences_list = pl(
    "hello",
    do_sample=True,
    num_return_sequences=2,
)
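
As a possible workaround in the meantime (a sketch reusing the pl pipeline above; it simply samples each prompt independently instead of relying on the num_return_sequences expansion inside generate):

# Hypothetical workaround, not verified for every configuration: pass the same
# prompt twice as a batch so no input expansion happens inside generate().
sequences_list = pl(
    ["hello", "hello"],
    do_sample=True,
    num_return_sequences=1,
)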

OsamaS99 (Contributor, Author) commented:

The same issue was also reported in the model's community discussion on the HF Hub: https://huggingface.co/google/gemma-2-27b-it/discussions/28

OsamaS99 (Contributor, Author) commented:

I think it comes down to Gemma 2's "cache_implementation", which is of type "hybrid". Maybe this can be fixed by changing line 1767 in generation/utils.py from

getattr(generation_config, "num_beams", 1) * batch_size,

to

getattr(generation_config, "num_beams", 1) * getattr(generation_config, "num_return_sequences", 1) * batch_size,
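
For illustration only (a standalone PyTorch sketch, not the actual transformers code), sizing the preallocated cache's batch dimension this way makes the sliding-window update broadcast correctly:

import torch

batch_size, num_beams, num_return_sequences = 1, 1, 2

# Proposed sizing: scale the cache batch dimension the same way
# _expand_inputs_for_generation scales the inputs.
cache_batch = num_beams * num_return_sequences * batch_size  # 2 instead of 1
k_out = torch.zeros(cache_batch, 8, 16, 256)   # illustrative sliding-window cache slice

key_states = torch.zeros(2, 8, 2, 256)         # expanded key states for one step
cache_position = torch.tensor([0, 1])

k_out[:, :, cache_position] = key_states       # no shape mismatch anymore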
