Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] add ONNX DistilBERT tests (#19999)
Browse files Browse the repository at this point in the history
* add roberta tests

* add distil bert test
  • Loading branch information
Zha0q1 authored Mar 10, 2021
1 parent f32310d commit d29fa5b
Showing 1 changed file with 51 additions and 4 deletions.
55 changes: 51 additions & 4 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,8 @@ def test_roberta_inference_onnxruntime(tmp_path, model_name):
onnx_file = "%s.onnx" % prefix
input_shapes = [(batch, seq_length), (batch,), (batch, num_masked_positions)]
converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
[np.float32, np.float32, np.int32],
onnx_file, verbose=True)
[np.float32, np.float32, np.int32],
onnx_file, verbose=True)

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
Expand Down Expand Up @@ -626,11 +626,11 @@ def test_bert_inference_onnxruntime(tmp_path, model):
name=model,
ctx=ctx,
dataset_name=dataset,
pretrained=False,
pretrained=True,
use_pooler=True,
use_decoder=False,
use_classifier=False)
model.initialize(ctx=ctx)

model.hybridize(static_alloc=True)

batch = 5
Expand Down Expand Up @@ -669,3 +669,50 @@ def test_bert_inference_onnxruntime(tmp_path, model):
shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', ['distilbert_6_768_12'])
def test_distilbert_inference_onnxruntime(tmp_path, model_name):
    """Check that an ONNX export of DistilBERT reproduces the MXNet output.

    A pretrained GluonNLP model is run once in MXNet, exported to
    symbol/params files, converted to ONNX with
    ``mx.contrib.onnx.export_model`` and executed with onnxruntime.  The
    first onnxruntime output is compared against the MXNet result.

    Parameters
    ----------
    tmp_path : path-like
        Scratch directory for the exported files; it is converted to ``str``
        and removed when the test finishes, whether it passed or failed.
    model_name : str
        GluonNLP model name passed to ``nlp.model.get_model``.
    """
    tmp_path = str(tmp_path)
    try:
        # Imported inside the ``try`` so the scratch directory is cleaned up
        # even when gluonnlp is not importable.
        import gluonnlp as nlp
        dataset = 'distilbert_book_corpus_wiki_en_uncased'
        ctx = mx.cpu(0)
        model, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=True)

        batch = 2
        seq_length = 32
        # NOTE(review): 30522 is presumably the vocabulary size of the
        # uncased dataset -- confirm against the gluonnlp vocabulary.
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length),
                                      dtype='float32', ctx=ctx)
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32', ctx=ctx)

        # Reference result from MXNet; this forward pass also runs the
        # hybridized graph once before ``export`` is called.
        sequence_outputs = model(inputs, valid_length)

        prefix = "%s/distilbert" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        # The model is called with two inputs only: tokens and valid_length.
        input_shapes = [(batch, seq_length), (batch,)]
        converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
                                                            [np.float32, np.float32],
                                                            onnx_file, verbose=True)

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, valid_length]
        # Query the session inputs once; indexing (rather than zip) keeps the
        # IndexError if the ONNX graph exposes fewer inputs than we feed.
        ort_inputs = sess.get_inputs()
        input_dict = {ort_inputs[i].name: tensor.asnumpy()
                      for i, tensor in enumerate(in_tensors)}
        pred = sess.run(None, input_dict)

        assert_almost_equal(sequence_outputs, pred[0])

    finally:
        shutil.rmtree(tmp_path)

0 comments on commit d29fa5b

Please sign in to comment.