From 9a2a18ff5737ad6ed9b85225b4acbef31936cfe3 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 24 Feb 2023 18:32:46 +0000 Subject: [PATCH 1/3] Revert "[Compatible] Fix batch_matmul from Relay (#155)" This reverts commit 51ec502231e38ee6dbd621bebfa265e293eb3f5a. --- src/op/from_relay/nn.cc | 22 +----------------- tests/python/pass/test_pass_from_relay.py | 27 ----------------------- 2 files changed, 1 insertion(+), 48 deletions(-) diff --git a/src/op/from_relay/nn.cc b/src/op/from_relay/nn.cc index 46301801..576f662e 100644 --- a/src/op/from_relay/nn.cc +++ b/src/op/from_relay/nn.cc @@ -15,29 +15,9 @@ namespace raf { namespace op { namespace from_relay { +RAF_GENERIC_ATTR_OP_FROM_RELAY("nn.batch_matmul", "raf.op.batch_matmul_nt"); RAF_GENERIC_ATTR_OP_FROM_RELAY("nn.dense", "raf.op.dense"); -// TVM's nn.batch_matmul has a transpose_a and transpose_b attribute, but RAF's -// batch_matmul_nt does not. Instead, RAF has 4 variants of batch_matmul for -// different combinations of transpose_a and transpose_b. This function -// converts the batch_matmul with transpose_a and transpose_b attributes to -// the appropriate batch_matmul variant. -RELAY_REGISTER_OP("nn.batch_matmul") - .set_attr("FRAFFromRelay", [](const Attrs& attrs, const Array& args, - const VarValueMap& val_map) { - const auto* relay_attrs = attrs.as(); - auto transpose_a = relay_attrs->transpose_a; - auto transpose_b = relay_attrs->transpose_b; - if (transpose_a && transpose_b) { - return Call(Op::Get("raf.op.batch_matmul_tt"), args); - } else if (transpose_a && !transpose_b) { - return Call(Op::Get("raf.op.batch_matmul_tn"), args); - } else if (!transpose_a && transpose_b) { - return Call(Op::Get("raf.op.batch_matmul_nt"), args); - } - return Call(Op::Get("raf.op.batch_matmul"), args); - }); - RAF_OP_FROM_RELAY("nn.conv2d", "raf.op.conv2d", [&](const Attrs& attrs, const Array& args, const VarValueMap& val_map) { Array raf_args = args; diff --git a/tests/python/pass/test_pass_from_relay.py b/tests/python/pass/test_pass_from_relay.py index d7360e23..e80e6cd4 100644 --- a/tests/python/pass/test_pass_from_relay.py +++ b/tests/python/pass/test_pass_from_relay.py @@ -1068,33 +1068,6 @@ def forward(self, m_x, m_y): check_from_relay(model, r_func, [m_x, m_y]) -@pytest.mark.parametrize("trans", [[False, False], [False, True], [True, False], [True, True]]) -def test_batch_matmul(trans): - class TransposeBatchMatmul(raf.Model): - def build(self, trans): - self.op_name = "batch_matmul" - if trans[0] or trans[1]: - self.op_name += f"_{'t' if trans[0] else 'n'}{'t' if trans[1] else 'n'}" - - @raf.model.trace - def forward(self, m_x, m_y): - x = raf.relu(m_x) - return getattr(raf, self.op_name)(x, m_y) - - model = TransposeBatchMatmul(trans) - m_x, _ = randn((4, 10, 10), dtype="float32") - m_y, _ = randn((4, 10, 10), dtype="float32") - - r_x = raf.ir.var("x", shape=(4, 10, 10), dtype="float32") - r_y = raf.ir.var("x", shape=(4, 10, 10), dtype="float32") - r_out = _relay.nn.batch_matmul( - _relay.nn.relu(r_x), r_y, transpose_a=trans[0], transpose_b=trans[1] - ) - r_func = _relay.Function(params=[r_x, r_y], body=r_out) - - check_from_relay(model, r_func, [m_x, m_y]) - - @pytest.mark.parametrize("device", get_testable_devices()) @pytest.mark.parametrize("shape", [(), (1,), (1, 2, 3, 4)]) @pytest.mark.parametrize("dtype", ["float64", "float32", "float16"]) From 34fa0a0317edbca77d22a3d5e46446061ddeddbb Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 24 Feb 2023 18:32:55 +0000 Subject: [PATCH 2/3] Revert "[TVM] Update Submodule (#154)" This reverts commit fc48da9ba0e4af0a8ce1baeda0e98d77a660720c. --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 697c724e..266ff51d 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 697c724e92d43ddd9d3123b9b1ae581703ac7b84 +Subproject commit 266ff51d2ad59a590b1645007890322c25468a58 From 89a462b1eca35699b6b6ce0d1138b4777591a08c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 24 Feb 2023 18:33:04 +0000 Subject: [PATCH 3/3] Revert "[TVM] Update Submodule (#153)" This reverts commit b33f2ac5d82a3bf5bfd55eead0f20897769ca573. --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 266ff51d..f7aeaf1d 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 266ff51d2ad59a590b1645007890322c25468a58 +Subproject commit f7aeaf1d389881e408d29585ea62c6bb5ea65843