From d5806eca7661c36191655f1f431251b188eb37fa Mon Sep 17 00:00:00 2001 From: GaoYuYang <zdlnlhmj@126.com> Date: Fri, 24 Feb 2023 07:36:52 +0800 Subject: [PATCH] [Frontend][Paddle] Add where_index op and add vm for paddle frontend's unitest (#14099) * Add where_index for paddle frontend and add vm for paddle's unitest * change using_tvm to use_tvm --- python/tvm/relay/frontend/paddlepaddle.py | 9 +++ .../frontend/paddlepaddle/test_forward.py | 55 +++++++++++++++---- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 75fecf217851..e688369a072a 100755 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -2074,6 +2074,14 @@ def convert_unsqueeze(g, op, block): g.add_node(op.output("Out")[0], x) +def convert_where_index(g, op, block): + """Operator converter for where_index.""" + + condition = g.get_node(op.input("Condition")[0]) + out = _op.argwhere(condition) + g.add_node(op.output("Out")[0], out) + + _convert_map = { "abs": convert_unary_op, "acos": convert_unary_op, @@ -2211,6 +2219,7 @@ def convert_unsqueeze(g, op, block): "top_k_v2": convert_topk, "transpose2": convert_transpose, "unsqueeze2": convert_unsqueeze, + "where_index": convert_where_index, } diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index cd2c0be7ef36..70fbf6aee554 100755 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -57,7 +57,7 @@ def get_paddle_model(func, input_spec): return baseline_model -def verify_model(func, input_data, rtol=1e-5, atol=1e-5): +def verify_model(func, input_data, use_vm=False, rtol=1e-5, atol=1e-5): if not (isinstance(input_data, (tuple, list))): input_data = [input_data] @@ -93,19 +93,44 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5): if arg.name_hint in input_names: compiled_names.append(arg.name_hint) - with tvm.transform.PassContext(opt_level=3): + if use_vm: + tvm_vm_input = [] + for idx, data in enumerate(input_data): + if isinstance(data, np.ndarray): + tvm_vm_input.append(data) + else: + tvm_vm_input.append(data.numpy()) for target, dev in tvm.testing.enabled_targets(): - lib = relay.build(mod, target=target, params=params) - gmod = graph_executor.GraphModule(lib["default"](dev)) - for name in compiled_names: - gmod.set_input(name, compiled_input[name]) - gmod.run() + result = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()( + *tvm_vm_input, **params + ) + tvm_vm_output = [] + if isinstance(result, tvm.runtime.NDArray): + tvm_vm_output = result.numpy() + else: + tvm_vm_output = [r.numpy() for r in result] + if not isinstance(tvm_vm_output, list): + tvm_vm_output = [tvm_vm_output] for i, baseline_output in enumerate(baseline_outputs): - compiled_output = gmod.get_output(i).numpy() + assert_shapes_match(baseline_output, tvm_vm_output[i]) + tvm.testing.assert_allclose(baseline_output, tvm_vm_output[i], rtol=rtol, atol=atol) + else: + with tvm.transform.PassContext(opt_level=3): + for target, dev in tvm.testing.enabled_targets(): + lib = relay.build(mod, target=target, params=params) + gmod = graph_executor.GraphModule(lib["default"](dev)) + for name in compiled_names: + gmod.set_input(name, compiled_input[name]) + gmod.run() + + for i, baseline_output in enumerate(baseline_outputs): + compiled_output = gmod.get_output(i).numpy() - assert_shapes_match(baseline_output, compiled_output) - tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol) + assert_shapes_match(baseline_output, compiled_output) + tvm.testing.assert_allclose( + baseline_output, compiled_output, rtol=rtol, atol=atol + ) @tvm.testing.uses_gpu @@ -1749,5 +1774,15 @@ def norm_2(inputs): verify_model(norm_2, input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_where_index(): + @paddle.jit.to_static + def where_index_1(inputs): + return paddle.nonzero(inputs) + + input_data = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]) + verify_model(where_index_1, input_data=input_data, use_vm=True) + + if __name__ == "__main__": tvm.testing.main()