From d5806eca7661c36191655f1f431251b188eb37fa Mon Sep 17 00:00:00 2001
From: GaoYuYang <zdlnlhmj@126.com>
Date: Fri, 24 Feb 2023 07:36:52 +0800
Subject: [PATCH] [Frontend][Paddle] Add where_index op and add a VM path for the
 Paddle frontend's unit tests (#14099)

* Add where_index for the Paddle frontend and add a VM execution path to Paddle's unit tests

* Rename the VM flag to use_vm
---
 python/tvm/relay/frontend/paddlepaddle.py     |  9 +++
 .../frontend/paddlepaddle/test_forward.py     | 55 +++++++++++++++----
 2 files changed, 54 insertions(+), 10 deletions(-)
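
Note: Paddle's where_index op (exposed in the Python API as paddle.nonzero)
returns the coordinates of the non-zero elements, so its output shape depends
on the input data. The converter below maps it to Relay's argwhere, and the
new test runs through the VM executor because the graph executor cannot run
ops with data-dependent output shapes. A minimal sketch of the same mapping
outside the frontend (the 2x2 boolean input is only an illustration):

    import numpy as np
    import tvm
    from tvm import relay

    # Mirror convert_where_index: condition -> argwhere(condition).
    cond = relay.var("condition", shape=(2, 2), dtype="bool")
    mod = tvm.IRModule.from_expr(relay.Function([cond], relay.argwhere(cond)))

    # argwhere yields a (?, 2) tensor, so evaluate it with the VM executor.
    data = np.array([[True, False], [False, True]])
    out = relay.create_executor("vm", mod=mod, target="llvm").evaluate()(data)
    print(out.numpy())  # coordinates of the non-zero entries: [[0 0], [1 1]]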

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index 75fecf217851..e688369a072a 100755
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -2074,6 +2074,14 @@ def convert_unsqueeze(g, op, block):
     g.add_node(op.output("Out")[0], x)
 
 
+def convert_where_index(g, op, block):
+    """Operator converter for where_index."""
+
+    condition = g.get_node(op.input("Condition")[0])
+    out = _op.argwhere(condition)
+    g.add_node(op.output("Out")[0], out)
+
+
 _convert_map = {
     "abs": convert_unary_op,
     "acos": convert_unary_op,
@@ -2211,6 +2219,7 @@ def convert_unsqueeze(g, op, block):
     "top_k_v2": convert_topk,
     "transpose2": convert_transpose,
     "unsqueeze2": convert_unsqueeze,
+    "where_index": convert_where_index,
 }
 
 
diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py
index cd2c0be7ef36..70fbf6aee554 100755
--- a/tests/python/frontend/paddlepaddle/test_forward.py
+++ b/tests/python/frontend/paddlepaddle/test_forward.py
@@ -57,7 +57,7 @@ def get_paddle_model(func, input_spec):
     return baseline_model
 
 
-def verify_model(func, input_data, rtol=1e-5, atol=1e-5):
+def verify_model(func, input_data, use_vm=False, rtol=1e-5, atol=1e-5):
     if not (isinstance(input_data, (tuple, list))):
         input_data = [input_data]
 
@@ -93,19 +93,44 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5):
         if arg.name_hint in input_names:
             compiled_names.append(arg.name_hint)
 
-    with tvm.transform.PassContext(opt_level=3):
+    if use_vm:
+        tvm_vm_input = []
+        for idx, data in enumerate(input_data):
+            if isinstance(data, np.ndarray):
+                tvm_vm_input.append(data)
+            else:
+                tvm_vm_input.append(data.numpy())
         for target, dev in tvm.testing.enabled_targets():
-            lib = relay.build(mod, target=target, params=params)
-            gmod = graph_executor.GraphModule(lib["default"](dev))
-            for name in compiled_names:
-                gmod.set_input(name, compiled_input[name])
-            gmod.run()
+            result = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()(
+                *tvm_vm_input, **params
+            )
+            tvm_vm_output = []
+            if isinstance(result, tvm.runtime.NDArray):
+                tvm_vm_output = result.numpy()
+            else:
+                tvm_vm_output = [r.numpy() for r in result]
+            if not isinstance(tvm_vm_output, list):
+                tvm_vm_output = [tvm_vm_output]
 
             for i, baseline_output in enumerate(baseline_outputs):
-                compiled_output = gmod.get_output(i).numpy()
+                assert_shapes_match(baseline_output, tvm_vm_output[i])
+                tvm.testing.assert_allclose(baseline_output, tvm_vm_output[i], rtol=rtol, atol=atol)
+    else:
+        with tvm.transform.PassContext(opt_level=3):
+            for target, dev in tvm.testing.enabled_targets():
+                lib = relay.build(mod, target=target, params=params)
+                gmod = graph_executor.GraphModule(lib["default"](dev))
+                for name in compiled_names:
+                    gmod.set_input(name, compiled_input[name])
+                gmod.run()
+
+                for i, baseline_output in enumerate(baseline_outputs):
+                    compiled_output = gmod.get_output(i).numpy()
 
-                assert_shapes_match(baseline_output, compiled_output)
-                tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
+                    assert_shapes_match(baseline_output, compiled_output)
+                    tvm.testing.assert_allclose(
+                        baseline_output, compiled_output, rtol=rtol, atol=atol
+                    )
 
 
 @tvm.testing.uses_gpu
@@ -1749,5 +1774,15 @@ def norm_2(inputs):
     verify_model(norm_2, input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_where_index():
+    @paddle.jit.to_static
+    def where_index_1(inputs):
+        return paddle.nonzero(inputs)
+
+    input_data = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]])
+    verify_model(where_index_1, input_data=input_data, use_vm=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
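
For reference, the baseline the new test is checked against, assuming the same
input tensor as in test_forward_where_index (requires PaddlePaddle installed):

    import paddle

    x = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]])
    # paddle.nonzero lowers to the where_index op; the result lists the
    # (row, col) coordinates of the non-zero entries.
    print(paddle.nonzero(x).numpy())  # [[0 0], [1 1], [2 2]]

The test itself can be selected with pytest's node-id syntax on
tests/python/frontend/paddlepaddle/test_forward.py::test_forward_where_index.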