[ONNXRuntimeError] RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. #13526

Closed
2 of 4 tasks
BenjaminWegener opened this issue Sep 11, 2021 · 3 comments

Comments

@BenjaminWegener

Environment info

Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

  • transformers version: 4.5.0.dev0
  • Platform: Windows-10-10.0.19041-SP0
  • Python version: 3.7.2
  • PyTorch version (GPU?): 1.9.0+cpu (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help

Information

Model I am using (Bert, XLNet ...): gptneo 125M

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

RuntimeException Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_13348/1934748927.py in
66
67 onnx_inputs = _get_inputs(PROMPTS, tokenizer, config)
---> 68 outputs = ort_session.run(['logits'], onnx_inputs)

c:\python37\lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
186 output_names = [output.name for output in self._outputs_meta]
187 try:
--> 188 return self._sess.run(output_names, input_feed, run_options)
189 except C.EPFail as err:
190 if self._enable_fallback:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_501' Status Message: D:\a_work\1\s\onnxruntime\core\providers\cpu\tensor\reshape_helper.h:42 onnxruntime::ReshapeHelper::ReshapeHelper gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,1,1,4096}, requested shape:{1,1,1,16,128}
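
The status message says the two shapes do not hold the same number of elements. A minimal sketch of that check, using only the shapes quoted in the error; `can_reshape` is a hypothetical helper, not part of ONNX Runtime, and it ignores the `-1` and `0` wildcard dimensions that the ONNX Reshape operator also accepts:

```python
from math import prod


def can_reshape(input_shape, requested_shape):
    """Return True when both shapes describe the same number of elements.

    Hypothetical helper for illustration; it does not handle the -1 / 0
    wildcard dimensions of the ONNX Reshape operator.
    """
    return prod(input_shape) == prod(requested_shape)


# Shapes copied from the error message above.
input_shape = (1, 1, 1, 4096)
requested_shape = (1, 1, 1, 16, 128)

print(prod(input_shape))      # 4096
print(prod(requested_shape))  # 2048
print(can_reshape(input_shape, requested_shape))  # False
```

The tensor reaching `Reshape_501` carries twice as many elements as the requested 16 x 128 layout can hold, which is why the node fails.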

To reproduce

Steps to reproduce the behavior:

from pathlib import Path
from transformers import GPTNeoForCausalLM, GPT2TokenizerFast, GPTNeoConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.onnx.convert import export
import numpy as np
import onnxruntime as ort

MODEL_PATH = 'EleutherAI/gpt-neo-1.3B'
#MODEL_PATH = 'EleutherAI/gpt-neo-125M'
TASK = 'causal-lm'
ONNX_MODEL_PATH = Path("gpt_neo_1.3B.onnx")
#ONNX_MODEL_PATH = Path("gpt_neo_125M.onnx")
ONNX_MODEL_PATH.parent.mkdir(exist_ok=True, parents=True)

tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
config = GPTNeoConfig.from_pretrained(MODEL_PATH)
onnx_config = GPTNeoOnnxConfig.with_past(config, task=TASK)

print(config)
print(onnx_config)
model = GPTNeoForCausalLM.from_pretrained(MODEL_PATH)
onnx_inputs, onnx_outputs = export(tokenizer=tokenizer, model=model, config=onnx_config, opset=12, output=ONNX_MODEL_PATH)
print(f'Inputs: {onnx_inputs}')
print(f'Outputs: {onnx_outputs}')

PROMPTS = ['Hello there']


def _get_inputs(prompts, tokenizer, config):
    encodings_dict = tokenizer.batch_encode_plus(prompts)
    # Shape: [batch_size, seq_length]
    input_ids = np.array(encodings_dict["input_ids"], dtype=np.int64)
    # Shape: [batch_size, seq_length]
    attention_mask = np.array(encodings_dict["attention_mask"], dtype=np.float32)

    batch_size, seq_length = input_ids.shape
    past_seq_length = 0
    num_attention_heads = config.num_attention_heads
    hidden_size = config.hidden_size

    even_present_state_shape = [
        batch_size, num_attention_heads, past_seq_length, hidden_size // num_attention_heads
    ]
    odd_present_state_shape = [batch_size, past_seq_length, hidden_size]

    onnx_inputs = {}
    for idx in range(config.num_layers):
        if idx % 2 == 0:
            onnx_inputs[f'past_key_values.{idx}.key'] = np.empty(even_present_state_shape, dtype=np.float32)
            onnx_inputs[f'past_key_values.{idx}.value'] = np.empty(even_present_state_shape, dtype=np.float32)
        else:
            onnx_inputs[f'past_key_values.{idx}.key_value'] = np.empty(odd_present_state_shape, dtype=np.float32)

    onnx_inputs['input_ids'] = input_ids
    onnx_inputs['attention_mask'] = attention_mask

    return onnx_inputs

config = GPTNeoConfig.from_pretrained(MODEL_PATH)
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
ort_session = ort.InferenceSession(str(ONNX_MODEL_PATH))

onnx_inputs = _get_inputs(PROMPTS, tokenizer, config)
outputs = ort_session.run(['logits'], onnx_inputs)
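
When debugging a feed dict like the one above, it can help to compare its keys with the inputs the exported graph actually declares. The sketch below assumes only that the session object exposes `get_inputs()` returning entries with a `name` attribute, as `onnxruntime.InferenceSession` does. `compare_feed` is a hypothetical helper, and the stand-in session and its input names are made up for illustration so the snippet runs without a model file.

```python
from types import SimpleNamespace


def compare_feed(session, feed):
    """Return (missing, unexpected) input names for a feed dict.

    `session` is anything exposing get_inputs() whose items have a `.name`,
    such as an onnxruntime.InferenceSession. Hypothetical helper.
    """
    expected = {arg.name for arg in session.get_inputs()}
    missing = sorted(expected - set(feed))
    unexpected = sorted(set(feed) - expected)
    return missing, unexpected


# Stand-in for a real session; the declared input names are illustrative only.
fake_session = SimpleNamespace(
    get_inputs=lambda: [
        SimpleNamespace(name='input_ids'),
        SimpleNamespace(name='attention_mask'),
        SimpleNamespace(name='past_key_values.1.key'),
        SimpleNamespace(name='past_key_values.1.value'),
    ]
)

feed = {
    'input_ids': None,
    'attention_mask': None,
    'past_key_values.1.key_value': None,
}

missing, unexpected = compare_feed(fake_session, feed)
print('missing:', missing)
print('unexpected:', unexpected)
```

With a real model, passing `ort_session` in place of `fake_session` and the dict returned by `_get_inputs` in place of `feed` should show whether the names line up before `run` is called.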

Expected behavior

The model should export, load, and run without a shape mismatch.

@BenjaminWegener
Author

@michaelbenayoun
Member

Hi @BenjaminWegener,
The local attention implementation was simplified, so every layer now takes separate key and value tensors of the same shape.
You no longer need to check the past_key_values idx value. Try changing the loop that creates the past_key_values tensors like this:

    for idx in range(config.num_layers):
        onnx_inputs[f'past_key_values.{idx}.key'] = np.empty(even_present_state_shape, dtype=np.float32)
        onnx_inputs[f'past_key_values.{idx}.value'] = np.empty(even_present_state_shape, dtype=np.float32)

    onnx_inputs['input_ids'] = input_ids
    onnx_inputs['attention_mask'] = attention_mask
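
The suggested loop can also be written as a standalone helper, which makes the resulting shapes easy to inspect. This is a minimal sketch, not code from the library: `make_past_inputs` is a hypothetical name, the input names follow the `past_key_values.{idx}.key` / `.value` pattern used above, and the sizes in the usage lines (16 heads, hidden size 2048) are illustrative values chosen to match the 16 x 128 layout in the error message.

```python
import numpy as np


def make_past_inputs(num_layers, batch_size, num_heads, hidden_size, past_seq_length=0):
    """Build empty past key/value tensors with one uniform shape for all layers.

    Hypothetical helper; names follow the past_key_values.{idx}.key/.value
    pattern from the thread.
    """
    head_dim = hidden_size // num_heads
    shape = (batch_size, num_heads, past_seq_length, head_dim)
    feed = {}
    for idx in range(num_layers):
        # With past_seq_length == 0 these arrays hold no elements.
        feed[f'past_key_values.{idx}.key'] = np.zeros(shape, dtype=np.float32)
        feed[f'past_key_values.{idx}.value'] = np.zeros(shape, dtype=np.float32)
    return feed


# Illustrative sizes only; in the script above they would come from `config`.
feed = make_past_inputs(num_layers=2, batch_size=1, num_heads=16, hidden_size=2048)
for name, array in feed.items():
    print(name, array.shape)
```

In the original script, the returned dict would be merged with `input_ids` and `attention_mask` before calling `ort_session.run`.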

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Nov 3, 2021