ONNX export: add two Reshape special cases and GPT-2 export tests.

* convert_reshape: every special-case branch now sets a `special_case` flag
  instead of returning early; a single `if special_case: return nodes` follows
  the last branch.  Two new target shapes are handled:
  [0, -4, 12, -1, 0] and [0, -4, 16, -1, 0].
* test_onnxruntime.py: new test exporting pretrained gpt2_117m / gpt2_345m and
  comparing the MXNet outputs and states with onnxruntime.  The downloaded
  archive is written inside tmp_path so the `finally` cleanup removes it.
* test_operators.py: reshape_spec_10 / reshape_spec_11 cover the new shapes.

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0903376c11d9..22e6282b43a4 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1645,7 +1645,9 @@ def convert_reshape(node, **kwargs):
         targ_shape = [-1, 0]
         reverse = 'True'
 
+    special_case = False
     if targ_shape == [0, 0, -3, -3] and reverse != 'True':
+        special_case = True
         nodes = [
             make_node('Shape', [input_nodes[0]], [name+'_shape']),
             make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2',
@@ -1657,9 +1659,9 @@ def convert_reshape(node, **kwargs):
                       [name+'_shape_new'], axis=0),
             make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
         ]
-        return nodes
 
     if targ_shape == [0, -4, -1, 4, 0, 0] and reverse != 'True':
+        special_case = True
         create_tensor([4], name+'_4', kwargs['initializer'])
         nodes = [
             make_node('Shape', [input_nodes[0]], [name+'_shape']),
@@ -1670,9 +1672,9 @@ def convert_reshape(node, **kwargs):
                       [name+'_shape_new'], axis=0),
             make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
         ]
-        return nodes
 
     if targ_shape == [0, 0, -4, 2, 2, 0, 0] and reverse != 'True':
+        special_case = True
         create_tensor([2], name+'_2', kwargs['initializer'])
         nodes = [
             make_node('Shape', [input_nodes[0]], [name+'_shape']),
@@ -1682,9 +1684,9 @@ def convert_reshape(node, **kwargs):
                                  name+'_dim3', name+'_dim4'], [name+'_shape_new'], axis=0),
             make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
         ]
-        return nodes
 
     if targ_shape == [-4, 1, -1, 0, 0, 0] and reverse != 'True':
+        special_case = True
         create_tensor([1], name+'_1', kwargs['initializer'])
         create_tensor([-1], name+'_m1', kwargs['initializer'])
         nodes = [
@@ -1695,9 +1697,9 @@ def convert_reshape(node, **kwargs):
                       [name+'_shape_new'], axis=0),
             make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
         ]
-        return nodes
 
     if targ_shape == [-4, 1, 1000, 0, 0] and reverse != 'True':
+        special_case = True
         create_tensor([1], name+'_1', kwargs['initializer'])
         create_tensor([1000], name+'_1000', kwargs['initializer'])
         nodes = [
@@ -1707,6 +1709,32 @@ def convert_reshape(node, **kwargs):
                       [name+'_shape_new'], axis=0),
             make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
         ]
+
+    if targ_shape == [0, -4, 12, -1, 0] and reverse != 'True':
+        special_case = True
+        create_tensor([-1], name+'_m1', kwargs['initializer'])
+        create_tensor([12], name+'_12', kwargs['initializer'])
+        nodes = [
+            make_node('Shape', [input_nodes[0]], [name+'_shape']),
+            make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2'], axis=0),
+            make_node('Concat', [name+'_dim0', name+'_12', name+'_m1', name+'_dim2'],
+                      [name+'_shape_new'], axis=0),
+            make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
+        ]
+
+    if targ_shape == [0, -4, 16, -1, 0] and reverse != 'True':
+        special_case = True
+        create_tensor([-1], name+'_m1', kwargs['initializer'])
+        create_tensor([16], name+'_16', kwargs['initializer'])
+        nodes = [
+            make_node('Shape', [input_nodes[0]], [name+'_shape']),
+            make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2'], axis=0),
+            make_node('Concat', [name+'_dim0', name+'_16', name+'_m1', name+'_dim2'],
+                      [name+'_shape_new'], axis=0),
+            make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
+        ]
+
+    if special_case:
         return nodes
 
     not_supported_shape = [-2, -3, -4]
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index ca4c0fed3fd9..6ad0794f875d 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -1155,3 +1155,79 @@ def verify_one_step_ahead_decoder():
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_params', [('gpt2_117m', 24), ('gpt2_345m', 48)])
+def test_gpt_pretrained_inference_onnxruntime(tmp_path, model_params):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        import urllib.request
+        from zipfile import ZipFile
+        import importlib.util
+        import sys
+
+        url = 'https://nlp.gluon.ai/_downloads/77d227fbc8f1613e6802acc7253cc090/text_generation.zip'
+        urllib.request.urlretrieve(url, tmp_path + '/text_generation.zip')
+
+        with ZipFile(tmp_path + '/text_generation.zip', 'r') as zipObj:
+            zipObj.extractall(tmp_path)
+
+        # load in the text_generation module, refer to:
+        # https://github.com/dmlc/gluon-nlp/tree/v0.10.x/scripts/text_generation
+        spec = importlib.util.spec_from_file_location(
+            'text_generation',
+            tmp_path + '/text_generation/__init__.py')
+        mod = importlib.util.module_from_spec(spec)
+        sys.modules[spec.name] = mod
+        spec.loader.exec_module(mod)
+
+        ctx = mx.cpu(0)
+        model_name = model_params[0]
+        dataset = 'openai_webtext'
+        # get_model() is overridden in here:
+        # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/scripts/text_generation/model/__init__.py#L23
+        model, _ = mod.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize()
+
+        batch = 4
+        seq_length = 64
+        inputs = mx.nd.random.uniform(0, 50257, shape=(batch, seq_length), dtype='float32',
+                                      ctx=ctx)
+
+        pred = model(inputs)
+
+        prefix = "%s/%s" % (tmp_path, model_name)
+        model.export(prefix)
+        sym_file = "%s-symbol.json" % prefix
+        params_file = "%s-0000.params" % prefix
+        onnx_file = "%s.onnx" % prefix
+
+        input_shapes = [(batch, seq_length)]
+        input_types = [np.float32]
+        converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
+                                                            input_types, onnx_file)
+
+        ses_opt = onnxruntime.SessionOptions()
+        ses_opt.log_severity_level = 3
+        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+        onnx_inputs = [inputs]
+        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
+        pred_onx = session.run(None, input_dict)
+
+        # check output
+        assert_almost_equal(pred[0], pred_onx[0])
+        # check states
+        num_states = model_params[1]
+        for i in range(num_states):
+            assert_almost_equal(pred[1][i], pred_onx[i+1])
+
+    finally:
+        shutil.rmtree(tmp_path)
+
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 220f259beb46..520ac407e49b 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -276,6 +276,14 @@ def test_onnx_export_reshape_special_cases(tmp_path, dtype):
     M9 = def_model('reshape', shape=(-4, 1, 1000, 0, 0))
     op_export_test('reshape_spec_9', M9, [x7], tmp_path)
 
+    x8 = mx.nd.ones((3, 96, 5), dtype=dtype)
+    M10 = def_model('reshape', shape=(0, -4, 12, -1, 0))
+    op_export_test('reshape_spec_10', M10, [x8], tmp_path)
+
+    x9 = mx.nd.ones((3, 96, 5), dtype=dtype)
+    M11 = def_model('reshape', shape=(0, -4, 16, -1, 0))
+    op_export_test('reshape_spec_11', M11, [x9], tmp_path)
+
 
 @pytest.mark.parametrize('dtype', ['int32', 'int64'])
 def test_onnx_export_embedding(tmp_path, dtype):