From 9a2a18ff5737ad6ed9b85225b4acbef31936cfe3 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 24 Feb 2023 18:32:46 +0000 Subject: [PATCH] Revert "[Compatible] Fix batch_matmul from Relay (#155)" This reverts commit 51ec502231e38ee6dbd621bebfa265e293eb3f5a. --- src/op/from_relay/nn.cc | 22 +----------------- tests/python/pass/test_pass_from_relay.py | 27 ----------------------- 2 files changed, 1 insertion(+), 48 deletions(-) diff --git a/src/op/from_relay/nn.cc b/src/op/from_relay/nn.cc index 46301801..576f662e 100644 --- a/src/op/from_relay/nn.cc +++ b/src/op/from_relay/nn.cc @@ -15,29 +15,9 @@ namespace raf { namespace op { namespace from_relay { +RAF_GENERIC_ATTR_OP_FROM_RELAY("nn.batch_matmul", "raf.op.batch_matmul_nt"); RAF_GENERIC_ATTR_OP_FROM_RELAY("nn.dense", "raf.op.dense"); -// TVM's nn.batch_matmul has a transpose_a and transpose_b attribute, but RAF's -// batch_matmul_nt does not. Instead, RAF has 4 variants of batch_matmul for -// different combinations of transpose_a and transpose_b. This function -// converts the batch_matmul with transpose_a and transpose_b attributes to -// the appropriate batch_matmul variant. -RELAY_REGISTER_OP("nn.batch_matmul") - .set_attr("FRAFFromRelay", [](const Attrs& attrs, const Array& args, - const VarValueMap& val_map) { - const auto* relay_attrs = attrs.as(); - auto transpose_a = relay_attrs->transpose_a; - auto transpose_b = relay_attrs->transpose_b; - if (transpose_a && transpose_b) { - return Call(Op::Get("raf.op.batch_matmul_tt"), args); - } else if (transpose_a && !transpose_b) { - return Call(Op::Get("raf.op.batch_matmul_tn"), args); - } else if (!transpose_a && transpose_b) { - return Call(Op::Get("raf.op.batch_matmul_nt"), args); - } - return Call(Op::Get("raf.op.batch_matmul"), args); - }); - RAF_OP_FROM_RELAY("nn.conv2d", "raf.op.conv2d", [&](const Attrs& attrs, const Array& args, const VarValueMap& val_map) { Array raf_args = args; diff --git a/tests/python/pass/test_pass_from_relay.py b/tests/python/pass/test_pass_from_relay.py index d7360e23..e80e6cd4 100644 --- a/tests/python/pass/test_pass_from_relay.py +++ b/tests/python/pass/test_pass_from_relay.py @@ -1068,33 +1068,6 @@ def forward(self, m_x, m_y): check_from_relay(model, r_func, [m_x, m_y]) -@pytest.mark.parametrize("trans", [[False, False], [False, True], [True, False], [True, True]]) -def test_batch_matmul(trans): - class TransposeBatchMatmul(raf.Model): - def build(self, trans): - self.op_name = "batch_matmul" - if trans[0] or trans[1]: - self.op_name += f"_{'t' if trans[0] else 'n'}{'t' if trans[1] else 'n'}" - - @raf.model.trace - def forward(self, m_x, m_y): - x = raf.relu(m_x) - return getattr(raf, self.op_name)(x, m_y) - - model = TransposeBatchMatmul(trans) - m_x, _ = randn((4, 10, 10), dtype="float32") - m_y, _ = randn((4, 10, 10), dtype="float32") - - r_x = raf.ir.var("x", shape=(4, 10, 10), dtype="float32") - r_y = raf.ir.var("x", shape=(4, 10, 10), dtype="float32") - r_out = _relay.nn.batch_matmul( - _relay.nn.relu(r_x), r_y, transpose_a=trans[0], transpose_b=trans[1] - ) - r_func = _relay.Function(params=[r_x, r_y], body=r_out) - - check_from_relay(model, r_func, [m_x, m_y]) - - @pytest.mark.parametrize("device", get_testable_devices()) @pytest.mark.parametrize("shape", [(), (1,), (1, 2, 3, 4)]) @pytest.mark.parametrize("dtype", ["float64", "float32", "float16"])