diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5f31724ec3ea..74ac74e67620 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1024,6 +1024,9 @@ def _impl_v10(cls, inputs, attr, params): attrs = {"starts": inputs[1], "ends": inputs[2]} if len(inputs) >= 4: attrs["axes"] = inputs[3] + if len(inputs) >= 5: + attrs["steps"] = inputs[4] + attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()} attrs = { k: params[v[1]].asnumpy() @@ -1033,12 +1036,23 @@ def _impl_v10(cls, inputs, attr, params): } # Update the starts and ends according to axes if required. - if "axes" in attrs: - if max(attrs["axes"] + 1) != len(attrs["axes"]): - new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"]) - attrs["starts"] = new_starts - attrs["ends"] = new_ends - return _op.strided_slice(inputs[0], begin=list(attrs["starts"]), end=list(attrs["ends"])) + if "axes" in attrs and max(attrs["axes"] + 1) != len(attrs["axes"]): + new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"]) + attrs["starts"] = new_starts + attrs["ends"] = new_ends + + begins = list(attrs["starts"]) + ends = list(attrs["ends"]) + strides = [1] * len(begins) + + if "steps" in attrs: + steps = list(attrs["steps"]) + axes = attrs["axes"] if "axes" in attrs else list(range(len(steps))) + assert len(steps) == len(axes) + for axis, step in zip(axes, steps): + strides[axis] = step + + return _op.strided_slice(inputs[0], begin=begins, end=ends, strides=strides) class Gather(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 894a6b6d40ce..81c8e77f0537 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -20,10 +20,8 @@ from onnx import helper, TensorProto, mapping import torch import torchvision -from tvm import topi import tvm.topi.testing import tvm -from tvm import te from tvm import relay from tvm.contrib 
import graph_runtime import scipy @@ -52,9 +50,10 @@ def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None): mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) - indata = tvm.nd.array(input_data) - result = ex.evaluate()(indata) - return result.asnumpy() + result = ex.evaluate()(*input_data) + if isinstance(result, tvm.runtime.NDArray): + return result.asnumpy() + return [r.asnumpy() for r in result] def get_tvm_output( @@ -104,21 +103,71 @@ def get_onnxruntime_output(model, inputs, dtype="float32"): rep = onnxruntime.backend.prepare(model, "CPU") if isinstance(inputs, list) and len(inputs) > 1: - ort_out = rep.run(inputs) + return rep.run(inputs) + elif isinstance(inputs, list) and len(inputs) == 1: + inp = inputs[0] else: - x = inputs.astype(dtype) - ort_out = rep.run(x)[0] - return ort_out + inp = inputs + return rep.run(inp.astype(dtype))[0] + + +def verify_with_ort_with_inputs( + model, + inputs, + out_shape=None, + targets=None, + use_vm=False, + opset=None, + dtype="float32", + rtol=1e-5, + atol=1e-5, +): + def flatten(out): + if isinstance(out, list) and len(out) == 1: + out = out[0] + if isinstance(out, np.ndarray): + return out.flatten() + return out + ort_out = get_onnxruntime_output(model, inputs, dtype) -def verify_onnx_forward_impl(graph_file, data_shape, out_shape): - dtype = "float32" - x = np.random.uniform(size=data_shape) - model = onnx.load_model(graph_file) - c2_out = get_onnxruntime_output(model, x, dtype) - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype) - tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) + if targets is None: + targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] + + for target in targets: + ctx = tvm.context(target, 0) + + if use_vm: + tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=opset) + else: + 
tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, dtype, opset=opset) + + tvm.testing.assert_allclose(flatten(ort_out), flatten(tvm_out), rtol=rtol, atol=atol) + + +def verify_with_ort( + model, + input_shapes, + out_shape=None, + targets=None, + use_vm=False, + opset=None, + dtype="float32", + rtol=1e-5, + atol=1e-5, +): + inputs = [np.random.uniform(size=ishape).astype(dtype) for ishape in input_shapes] + verify_with_ort_with_inputs( + model, + inputs, + out_shape=out_shape, + targets=targets, + use_vm=use_vm, + opset=opset, + dtype=dtype, + rtol=rtol, + atol=atol, + ) def make_constant_node(name, data_type, dims, vals): @@ -161,8 +210,7 @@ def test_reshape(): for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype("int32") tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32") - - tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) @tvm.testing.uses_gpu @@ -193,8 +241,7 @@ def _test_expand(name, data, shape, ref_data): for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, "float32") - - tvm.testing.assert_allclose(ref_data, tvm_out) + tvm.testing.assert_allclose(ref_data, tvm_out) in_shape = (3, 1) shape = (3, 4) @@ -221,11 +268,7 @@ def verify_depth_to_space(inshape, outshape, mode, blockSize): model = helper.make_model(graph, producer_name="depth_to_space_test") - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=inshape).astype("float32") - tvm_out = get_tvm_output(model, x, target, ctx, outshape, "float32") - onnx_out = get_onnxruntime_output(model, x, "float32") - tvm.testing.assert_allclose(onnx_out, tvm_out) + verify_with_ort(model, [inshape], outshape) @tvm.testing.uses_gpu @@ -248,11 +291,7 @@ def verify_space_to_depth(inshape, outshape, blockSize): model = helper.make_model(graph, producer_name="space_to_depth_test") - for target, ctx in 
tvm.testing.enabled_targets(): - x = np.random.uniform(size=inshape).astype("float32") - tvm_out = get_tvm_output(model, x, target, ctx, outshape, "float32") - onnx_out = get_onnxruntime_output(model, x, "float32") - tvm.testing.assert_allclose(onnx_out, tvm_out) + verify_with_ort(model, [inshape], outshape) @tvm.testing.uses_gpu @@ -293,8 +332,7 @@ def test_shape(): for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype("int32") tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "int32") - - tvm.testing.assert_allclose(ref_shape, tvm_out) + tvm.testing.assert_allclose(ref_shape, tvm_out) def _test_power_iteration(x_shape, y_shape): @@ -350,8 +388,7 @@ def test_squeeze(): for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype("float32") tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32") - - tvm.testing.assert_allclose(out_shape, tvm_out.shape) + tvm.testing.assert_allclose(out_shape, tvm_out.shape) @tvm.testing.uses_gpu @@ -375,8 +412,7 @@ def test_flatten(): for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype("int32") tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32") - - tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) @tvm.testing.uses_gpu @@ -398,8 +434,7 @@ def test_unsqueeze(): for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype("float32") tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32") - - tvm.testing.assert_allclose(out_shape, tvm_out.shape) + tvm.testing.assert_allclose(out_shape, tvm_out.shape) def verify_gather(in_shape, indices, axis, dtype): @@ -450,11 +485,8 @@ def verify_gatherelements(in_shape, indices, axis): outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], ) model = helper.make_model(graph, 
producer_name="gather_elements_test") - onnx_out = get_onnxruntime_output(model, [x, indices]) - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x, indices], target, ctx, onnx_out[0].shape) - tvm.testing.assert_allclose(onnx_out[0], tvm_out) + verify_with_ort_with_inputs(model, [x, indices]) @tvm.testing.uses_gpu @@ -491,11 +523,7 @@ def verify_scatter(in_shape, indices, axis): outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], ) model = helper.make_model(graph, producer_name="scatter_test") - onnx_out = get_onnxruntime_output(model, [x, indices, updates]) - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x, indices, updates], target, ctx, onnx_out[0].shape) - tvm.testing.assert_allclose(onnx_out[0], tvm_out) + verify_with_ort_with_inputs(model, [x, indices, updates]) @tvm.testing.uses_gpu @@ -525,14 +553,14 @@ def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=1) - - tvm.testing.assert_allclose(outdata, tvm_out) + tvm.testing.assert_allclose(outdata, tvm_out) def _test_slice_iteration_v10(indata, outdata, **attrs): starts = attrs["starts"] ends = attrs["ends"] axes = None if "axes" not in attrs else attrs["axes"] + steps = None if "steps" not in attrs else attrs["steps"] starts = np.asarray(starts) ends = np.asarray(ends) inputs = [ @@ -589,8 +617,8 @@ def add_noop_to_input_attr(attr_name, attr): return [ref_node, ref_node2, reshape1_node, reshape2_node] slice_inputs = [] - for attr_name in ["starts", "ends", "axes"]: - if attr_name == "axes" and not axes: + for attr_name in ["starts", "ends", "axes", "steps"]: + if attr_name not in attrs: continue if "add_noop_to_input_attrs" in attrs and attr_name in attrs["add_noop_to_input_attrs"]: nodes.extend(add_noop_to_input_attr(attr_name, attrs[attr_name])) 
@@ -602,6 +630,13 @@ def add_noop_to_input_attr(attr_name, attr): axes = np.asarray(axes) inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT32, list(axes.shape))) initializer.append(helper.make_tensor("axes", TensorProto.INT32, list(axes.shape), axes)) + + if steps: + assert axes is not None and len(axes) == len(steps) + steps = np.asarray(steps) + inputs.append(helper.make_tensor_value_info("steps", TensorProto.INT32, list(steps.shape))) + initializer.append(helper.make_tensor("steps", TensorProto.INT32, list(steps.shape), steps)) + y = helper.make_node("Slice", ["data", *slice_inputs], ["out"]) nodes.append(y) @@ -616,8 +651,7 @@ def add_noop_to_input_attr(attr_name, attr): for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=10) - - tvm.testing.assert_allclose(outdata, tvm_out) + tvm.testing.assert_allclose(outdata, tvm_out) @tvm.testing.uses_gpu @@ -681,6 +715,19 @@ def test_slice(): x, x, starts=(0, 0), ends=(9223372036854775807, 9223372036854775807), axes=(0, 3) ) + x = np.random.randn(4, 4).astype(np.float32) + _test_slice_iteration_v10( + x, x[:, 1::2], starts=(1,), ends=(9223372036854775807,), axes=(1,), steps=(2,) + ) + _test_slice_iteration_v10( + x, + x[0::1, 1::2], + starts=(0, 1), + ends=(4, 4), + axes=(0, 1), + steps=(1, 2), + ) + def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): indata = np.random.uniform(-1, 1, size=inshape).astype(dtype) @@ -699,8 +746,7 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype) - - tvm.testing.assert_allclose(outdata, tvm_out) + tvm.testing.assert_allclose(outdata, tvm_out) @tvm.testing.uses_gpu @@ -742,11 +788,7 @@ def test_clip_min_max_as_inputs(): ) model = helper.make_model(graph, producer_name="clip_test") - indata = 
np.random.uniform(-1, 7, size=input_shape).astype("float32") - onnx_out = get_onnxruntime_output(model, indata, "float32") - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, indata, target, ctx, input_shape, "float32") - tvm.testing.assert_allclose(onnx_out, tvm_out) + verify_with_ort(model, [input_shape], input_shape) @tvm.testing.uses_gpu @@ -771,8 +813,7 @@ def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs): for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype) - - tvm.testing.assert_allclose(outdata, tvm_out) + tvm.testing.assert_allclose(outdata, tvm_out) @tvm.testing.uses_gpu @@ -1574,10 +1615,7 @@ def verify_reduce_func(func, data, axis, keepdims): model = helper.make_model(graph, producer_name="reduce_test") - onnx_out = get_onnxruntime_output(model, data, "float32") - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, data, target, ctx, outshape, "float32") - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [data], outshape) @tvm.testing.uses_gpu @@ -1815,15 +1853,7 @@ def verify_prelu(x_shape, a_shape): model = helper.make_model(graph, producer_name="prelu_test") - indata = np.random.uniform(-10, 10, x_shape).astype(np.float32) - slopedata = np.random.uniform(-10, 10, a_shape).astype(np.float32) - onnx_out = get_onnxruntime_output(model, [indata, slopedata]) - - for target, ctx in [("llvm", tvm.cpu())]: - tvm_out = get_tvm_output( - model, [indata, slopedata], target, ctx, list(x_shape), output_dtype="float32" - ) - tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) + verify_with_ort(model, [x_shape, a_shape], list(x_shape)) verify_prelu([3, 4, 5, 6], [1, 4, 1, 1]) verify_prelu([1, 8, 5, 6], [1, 8, 1, 1]) @@ -1900,11 +1930,8 @@ def check_torch_conversion(model, input_size): # Set verbose=True for more output 
torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False) onnx_model = onnx.load(file_name) - for target, ctx in tvm.testing.enabled_targets(): - input_data = np.random.uniform(size=input_size).astype("int32") - c2_out = get_onnxruntime_output(onnx_model, input_data) - tvm_out = get_tvm_output(onnx_model, input_data, target, ctx) - tvm.testing.assert_allclose(c2_out, tvm_out) + input_data = np.random.uniform(size=input_size).astype("int32") + verify_with_ort_with_inputs(onnx_model, [input_data]) @tvm.testing.uses_gpu @@ -2244,18 +2271,9 @@ def verify_batch_norm(in_shape): ) model = helper.make_model(graph, producer_name="batchnorm_test") - - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype("float32") - scale = np.random.uniform(size=in_shape[1]).astype("float32") - b = np.random.uniform(size=in_shape[1]).astype("float32") - mean = np.random.uniform(size=in_shape[1]).astype("float32") - var = np.random.uniform(size=in_shape[1]).astype("float32") - onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], "float32")[0] - tvm_out = get_tvm_output( - model, [x, scale, b, mean, var], target, ctx, in_shape, "float32" - ) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + # X, scale, b, mean, var + inshapes = [in_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]] + verify_with_ort(model, inshapes, in_shape) verify_batch_norm([1, 3, 224, 224]) verify_batch_norm([1, 3, 24, 24]) @@ -2288,19 +2306,9 @@ def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): ) model = helper.make_model(graph, producer_name="batchnorm_test") - - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype("float32") - inp = np.random.uniform(size=o_shape).astype("float32") - scale = np.random.uniform(size=in_shape[1]).astype("float32") - b = np.random.uniform(size=in_shape[1]).astype("float32") - mean = 
np.random.uniform(size=in_shape[1]).astype("float32") - var = np.random.uniform(size=in_shape[1]).astype("float32") - onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], "float32")[0] - tvm_out = get_tvm_output( - model, [x, inp, scale, b, mean, var], target, ctx, in_shape, "float32" - ) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + # X, inp, scale, b, mean, var + inshapes = [in_shape, o_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]] + verify_with_ort(model, inshapes, in_shape, use_vm=False) verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160]) @@ -2364,12 +2372,7 @@ def verify_conv( model = helper.make_model(graph, producer_name="conv_test") - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=x_shape).astype("float32") - W = np.random.uniform(size=w_shape).astype("float32") - tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape) - onnx_out = get_onnxruntime_output(model, [x, W], "float32")[0] - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort(model, [x_shape, w_shape], y_shape) @tvm.testing.uses_gpu @@ -2476,13 +2479,7 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p): ) model = helper.make_model(graph, producer_name="convtranspose_trest") - - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=x_shape).astype("float32") - W = np.random.uniform(size=w_shape).astype("float32") - tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape) - onnx_out = get_onnxruntime_output(model, [x, W], "float32")[0] - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort(model, [x_shape, w_shape], y_shape) @tvm.testing.uses_gpu @@ -2548,11 +2545,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p ) model = helper.make_model(graph, producer_name="pooling_test") - - for target, ctx in tvm.testing.enabled_targets(): - onnx_out = 
get_onnxruntime_output(model, x_np, "float32") - tvm_out = get_tvm_output(model, [x_np], target, ctx, out_shape) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort(model, [x_shape], out_shape) @tvm.testing.uses_gpu @@ -2657,12 +2650,7 @@ def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32"): outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], ) model = helper.make_model(graph, producer_name="mod_test") - - onnx_out = get_onnxruntime_output(model, [x_np, y_np], dtype)[0] - - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort_with_inputs(model, [x_np, y_np], out_shape) @tvm.testing.uses_gpu @@ -2731,9 +2719,6 @@ def test_xor(): def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_shape): - x_np = np.random.uniform(size=x_shape).astype("float32") - rois_np = np.random.uniform(size=rois_shape).astype("float32") - if spatial_scale is None: pool_node = helper.make_node( "MaxRoiPool", inputs=["x", "rois"], outputs=["y"], pooled_shape=pooled_shape @@ -2758,11 +2743,7 @@ def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_sh ) model = helper.make_model(graph, producer_name="pool_test") - - onnx_out = get_onnxruntime_output(model, [x_np, rois_np], "float32")[0] - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [x_np, rois_np], target, ctx, out_shape) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort(model, [x_shape, rois_shape], out_shape) @tvm.testing.uses_gpu @@ -2785,8 +2766,6 @@ def test_max_roi_pool(): def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"): - x_np = np.random.uniform(size=x_shape).astype("float32") - if pads is None: pool_node = helper.make_node( "LpPool", @@ 
-2816,11 +2795,7 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad=" ) model = helper.make_model(graph, producer_name="lppool_test") - - for target, ctx in tvm.testing.enabled_targets(): - onnx_out = get_onnxruntime_output(model, x_np, "float32") - tvm_out = get_tvm_output(model, [x_np], target, ctx, out_shape) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + verify_with_ort(model, [x_shape], out_shape) @tvm.testing.uses_gpu @@ -3228,12 +3203,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): model = helper.make_model(graph, producer_name="resize_test") - for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=ishape).astype("float32") - onnx_out = get_onnxruntime_output(model, x, "float32") - tvm_out = get_tvm_output(model, x, target, ctx, oshape, "float32", opset=11) - - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) + verify_with_ort(model, [ishape], oshape, use_vm=False, opset=11) # upsampling verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") @@ -3266,11 +3236,9 @@ def verify_nonzero(indata, outdata, dtype): model = helper.make_model(graph, producer_name="nonzero_test") - onnx_out = get_onnxruntime_output(model, indata, dtype) - - for target, ctx in [("llvm", tvm.cpu())]: - tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9) - tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) + verify_with_ort_with_inputs( + model, [indata], targets=["llvm"], dtype="int64", use_vm=True, opset=9 + ) input_data = np.array([[1, 0], [1, 1]], dtype=np.int64) result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 1], [0, 0, 1]] @@ -3378,17 +3346,7 @@ def verify_roi_align( np_rois = np.random.uniform(size=[num_roi, 4]).astype("float32") * input_dims[2] np_batch_indicies = np.random.randint(low=0, high=input_dims[0], size=num_roi) - onnx_out = get_onnxruntime_output(model, [np_data, np_rois, 
np_batch_indicies]) - for target, ctx in [("llvm", tvm.cpu())]: - tvm_out = get_tvm_output( - model, - [np_data, np_rois, np_batch_indicies], - target, - ctx, - output_dims, - output_dtype="float32", - ) - tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) + verify_with_ort_with_inputs(model, [np_data, np_rois, np_batch_indicies], output_dims) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0) verify_roi_align((4, 4, 16, 32), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0)