
GPT-Neo ONNX Inference with past is broken #13175

Closed
whiteRa2bit opened this issue Aug 18, 2021 · 8 comments


whiteRa2bit commented Aug 18, 2021

Environment info

  • transformers version: 4.10.0.dev0 (1fec32a)
  • Platform: Linux
  • Python version: 3.8.8
  • PyTorch version (GPU?): 1.9.0a0+2ecb2c7, True
  • Tensorflow version (GPU?): Not Installed, False
  • Using GPU in script?: Yes (3090)
  • Using distributed or parallel set-up in script?: No

Who can help

The issue is connected with pull request #12911:
@michaelbenayoun @mfuntowicz @sgugger @LysandreJik

Information

The model I am using is GPT-Neo 1.3B.

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Model export
from pathlib import Path
from transformers import GPTNeoForCausalLM, GPT2TokenizerFast, GPTNeoConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.onnx.convert import export

MODEL_PATH = 'EleutherAI/gpt-neo-1.3B'
TASK = 'causal-lm'
ONNX_MODEL_PATH = Path("onnx_dir/gpt_neo_13b.onnx")
ONNX_MODEL_PATH.parent.mkdir(exist_ok=True, parents=True)


def main():
    tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
    config = GPTNeoConfig.from_pretrained(MODEL_PATH)
    onnx_config = GPTNeoOnnxConfig.with_past(config, task=TASK)

    model = GPTNeoForCausalLM.from_pretrained(MODEL_PATH)
    onnx_inputs, onnx_outputs = export(tokenizer=tokenizer, model=model, config=onnx_config, opset=12, output=ONNX_MODEL_PATH)
    print(f'Inputs: {onnx_inputs}')
    print(f'Outputs: {onnx_outputs}')


if __name__ == '__main__':
    main()
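
Before moving to inference, it can help to print the inputs the exported graph actually expects (names and symbolic shapes), since the past tensors are constructed by hand below. A small check, assuming the export above succeeded:

import onnxruntime as ort

session = ort.InferenceSession("onnx_dir/gpt_neo_13b.onnx")
for graph_input in session.get_inputs():
    # Prints e.g. input name, symbolic shape, and element type
    print(graph_input.name, graph_input.shape, graph_input.type)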
  2. Inference code
import numpy as np
import onnxruntime as ort
from transformers import GPT2TokenizerFast, GPTNeoConfig
from pathlib import Path

MODEL_PATH = 'EleutherAI/gpt-neo-1.3B'
ONNX_MODEL_PATH = Path("onnx_dir/gpt_neo_13b.onnx")
PROMPTS = ['Hello there']


def _get_inputs(prompts, tokenizer, config):
    encodings_dict = tokenizer.batch_encode_plus(prompts)
    # Shape: [batch_size, seq_length]
    input_ids = np.array(encodings_dict["input_ids"], dtype=np.int64)
    # Shape: [batch_size, seq_length]
    attention_mask = np.array(encodings_dict["attention_mask"], dtype=np.float32)

    batch_size, seq_length = input_ids.shape
    past_seq_length = 0
    num_attention_heads = config.num_attention_heads
    hidden_size = config.hidden_size

    even_present_state_shape = [
        batch_size, num_attention_heads, past_seq_length, hidden_size // num_attention_heads
    ]
    odd_present_state_shape = [batch_size, past_seq_length, hidden_size]

    onnx_inputs = {}
    for idx in range(config.num_layers):
        if idx % 2 == 0:
            onnx_inputs[f'past_key_values.{idx}.key'] = np.empty(even_present_state_shape, dtype=np.float32)
            onnx_inputs[f'past_key_values.{idx}.value'] = np.empty(even_present_state_shape, dtype=np.float32)
        else:
            onnx_inputs[f'past_key_values.{idx}.key_value'] = np.empty(odd_present_state_shape, dtype=np.float32)

    onnx_inputs['input_ids'] = input_ids
    onnx_inputs['attention_mask'] = attention_mask

    return onnx_inputs


def main():
    config = GPTNeoConfig.from_pretrained(MODEL_PATH)
    tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
    ort_session = ort.InferenceSession(str(ONNX_MODEL_PATH))

    onnx_inputs = _get_inputs(PROMPTS, tokenizer, config)
    outputs = ort_session.run(['logits'], onnx_inputs)


if __name__ == '__main__':
    main()

The inference code runs into the following error:

Traceback (most recent call last):
  ....
  File "inference.py", line 60, in main
    outputs = ort_session.run(['logits'], onnx_inputs)
  File "/opt/conda/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 188, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_501' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:42 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, std::vector<long int>&, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,1,1, 4096}, requested shape:{1,1,1,16,128}

Expected behavior

ONNX inference for a model exported with past states should work. When the model is converted without past states, inference works fine.
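
For comparison, a minimal sketch of the without-past export (same paths and constants as the export script above; it assumes GPTNeoOnnxConfig can be constructed directly instead of via with_past), which yields a model whose inference runs without this error:

from pathlib import Path
from transformers import GPTNeoForCausalLM, GPT2TokenizerFast, GPTNeoConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.onnx.convert import export

MODEL_PATH = 'EleutherAI/gpt-neo-1.3B'
ONNX_MODEL_PATH = Path("onnx_dir/gpt_neo_13b_no_past.onnx")
ONNX_MODEL_PATH.parent.mkdir(exist_ok=True, parents=True)

tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
config = GPTNeoConfig.from_pretrained(MODEL_PATH)
# No .with_past: the exported graph only takes input_ids and attention_mask
onnx_config = GPTNeoOnnxConfig(config, task='causal-lm')

model = GPTNeoForCausalLM.from_pretrained(MODEL_PATH)
export(tokenizer=tokenizer, model=model, config=onnx_config, opset=12, output=ONNX_MODEL_PATH)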

@patrickvonplaten

Gently pinging @mfuntowicz here

@whiteRa2bit

@michaelbenayoun @mfuntowicz @sgugger @LysandreJik would you be so kind as to assist in resolving this issue?

@LysandreJik

Hello @whiteRa2bit, thanks for testing out the experimental with-past feature of the ONNX export! @michaelbenayoun and @mfuntowicz are the best suited to answer, but they're off until early next week. We'll make sure to attend to this issue as soon as they're back! Thank you for your understanding.

@whiteRa2bit

@LysandreJik, thanks a lot for letting me know!


whiteRa2bit commented Aug 31, 2021

An update from my side:
Inference works fine when the sequence length equals 1, while for all other lengths it breaks with the error I described above.
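
A minimal probe sketch that reproduces this observation (reusing _get_inputs, the constants, and the imports from the inference script above; the prompts are arbitrary examples):

# Assumes _get_inputs, GPT2TokenizerFast, GPTNeoConfig, ort, MODEL_PATH and
# ONNX_MODEL_PATH from the inference script above.
config = GPTNeoConfig.from_pretrained(MODEL_PATH)
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
ort_session = ort.InferenceSession(str(ONNX_MODEL_PATH))

for prompt in ['Hello', 'Hello there', 'Hello there, how are you today?']:
    onnx_inputs = _get_inputs([prompt], tokenizer, config)
    seq_length = onnx_inputs['input_ids'].shape[1]
    try:
        ort_session.run(['logits'], onnx_inputs)
        print(f'seq_length={seq_length}: OK')
    except Exception as exc:
        # Only seq_length == 1 succeeds; longer prompts hit the Reshape error
        print(f'seq_length={seq_length}: {type(exc).__name__}')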

I tried to visualize the converted ONNX graph using Netron and found the node where the error occurs:
[Netron screenshot: the exported graph around the failing Reshape node]

@michaelbenayoun

Hi @whiteRa2bit,
I've actually made the same observation this morning; I am working on it!

@michaelbenayoun

#13491 along with #13524 solve the issue, but be careful of two things (a minimal sketch applying both fixes follows this list):

  • when the model is exported with past keys and values, the attention mask you feed at inference should have a sequence length of past_sequence_length + input_ids_sequence_length
  • ORT does not seem to like inputs produced by np.empty (on my end it produces NaN, whereas np.zeros or np.ones, for instance, give proper output)
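
For illustration, a minimal sketch of how the input construction from the reproduction script might look with both fixes applied (the helper name _get_inputs_with_past is hypothetical; the layer layout and input names are taken from the inference script above, not from the linked PRs):

import numpy as np


def _get_inputs_with_past(prompts, tokenizer, config, past_seq_length=0):
    encodings_dict = tokenizer.batch_encode_plus(prompts)
    # Shape: [batch_size, seq_length]
    input_ids = np.array(encodings_dict["input_ids"], dtype=np.int64)
    batch_size, seq_length = input_ids.shape

    # Fix 1: the attention mask covers past *and* current tokens,
    # i.e. shape [batch_size, past_seq_length + seq_length]
    token_mask = np.array(encodings_dict["attention_mask"], dtype=np.float32)
    past_mask = np.ones((batch_size, past_seq_length), dtype=np.float32)
    attention_mask = np.concatenate([past_mask, token_mask], axis=1)

    head_dim = config.hidden_size // config.num_attention_heads
    even_shape = [batch_size, config.num_attention_heads, past_seq_length, head_dim]
    odd_shape = [batch_size, past_seq_length, config.hidden_size]

    onnx_inputs = {}
    for idx in range(config.num_layers):
        # Fix 2: np.zeros instead of np.empty; uninitialized buffers made ORT return NaNs
        if idx % 2 == 0:
            onnx_inputs[f'past_key_values.{idx}.key'] = np.zeros(even_shape, dtype=np.float32)
            onnx_inputs[f'past_key_values.{idx}.value'] = np.zeros(even_shape, dtype=np.float32)
        else:
            onnx_inputs[f'past_key_values.{idx}.key_value'] = np.zeros(odd_shape, dtype=np.float32)

    onnx_inputs['input_ids'] = input_ids
    onnx_inputs['attention_mask'] = attention_mask
    return onnx_inputs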


github-actions bot commented Oct 8, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
