From 51ec502231e38ee6dbd621bebfa265e293eb3f5a Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 23 Feb 2023 13:01:54 -0800
Subject: [PATCH] [Compatible] Fix batch_matmul from Relay (#155)

* [Compatible] Fix batch_matmul from Relay

* test

* lint
---
 src/op/from_relay/nn.cc                   | 22 +++++++++++++++++-
 tests/python/pass/test_pass_from_relay.py | 27 +++++++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/src/op/from_relay/nn.cc b/src/op/from_relay/nn.cc
index 576f662e..46301801 100644
--- a/src/op/from_relay/nn.cc
+++ b/src/op/from_relay/nn.cc
@@ -15,9 +15,29 @@ namespace raf {
 namespace op {
 namespace from_relay {
 
-RAF_GENERIC_ATTR_OP_FROM_RELAY("nn.batch_matmul", "raf.op.batch_matmul_nt");
 RAF_GENERIC_ATTR_OP_FROM_RELAY("nn.dense", "raf.op.dense");
 
+// TVM's nn.batch_matmul has transpose_a and transpose_b attributes, but RAF's
+// batch_matmul ops do not. Instead, RAF provides four batch_matmul variants,
+// one for each combination of transpose_a and transpose_b. This converter
+// dispatches nn.batch_matmul to the appropriate RAF variant based on the
+// values of its transpose attributes.
+RELAY_REGISTER_OP("nn.batch_matmul")
+    .set_attr<op::FRAFFromRelay>("FRAFFromRelay", [](const Attrs& attrs, const Array<Expr>& args,
+                                                     const VarValueMap& val_map) {
+      const auto* relay_attrs = attrs.as<BatchMatmulAttrs>();
+      auto transpose_a = relay_attrs->transpose_a;
+      auto transpose_b = relay_attrs->transpose_b;
+      if (transpose_a && transpose_b) {
+        return Call(Op::Get("raf.op.batch_matmul_tt"), args);
+      } else if (transpose_a && !transpose_b) {
+        return Call(Op::Get("raf.op.batch_matmul_tn"), args);
+      } else if (!transpose_a && transpose_b) {
+        return Call(Op::Get("raf.op.batch_matmul_nt"), args);
+      }
+      return Call(Op::Get("raf.op.batch_matmul"), args);
+    });
+
 RAF_OP_FROM_RELAY("nn.conv2d", "raf.op.conv2d",
                   [&](const Attrs& attrs, const Array<Expr>& args, const VarValueMap& val_map) {
                     Array<Expr> raf_args = args;
diff --git a/tests/python/pass/test_pass_from_relay.py b/tests/python/pass/test_pass_from_relay.py
index e80e6cd4..d7360e23 100644
--- a/tests/python/pass/test_pass_from_relay.py
+++ b/tests/python/pass/test_pass_from_relay.py
@@ -1068,6 +1068,33 @@ def forward(self, m_x, m_y):
     check_from_relay(model, r_func, [m_x, m_y])
 
 
+@pytest.mark.parametrize("trans", [[False, False], [False, True], [True, False], [True, True]])
+def test_batch_matmul(trans):
+    class TransposeBatchMatmul(raf.Model):
+        def build(self, trans):
+            self.op_name = "batch_matmul"
+            if trans[0] or trans[1]:
+                self.op_name += f"_{'t' if trans[0] else 'n'}{'t' if trans[1] else 'n'}"
+
+        @raf.model.trace
+        def forward(self, m_x, m_y):
+            x = raf.relu(m_x)
+            return getattr(raf, self.op_name)(x, m_y)
+
+    model = TransposeBatchMatmul(trans)
+    m_x, _ = randn((4, 10, 10), dtype="float32")
+    m_y, _ = randn((4, 10, 10), dtype="float32")
+
+    r_x = raf.ir.var("x", shape=(4, 10, 10), dtype="float32")
+    r_y = raf.ir.var("y", shape=(4, 10, 10), dtype="float32")
+    r_out = _relay.nn.batch_matmul(
+        _relay.nn.relu(r_x), r_y, transpose_a=trans[0], transpose_b=trans[1]
+    )
+    r_func = _relay.Function(params=[r_x, r_y], body=r_out)
+
+    check_from_relay(model, r_func, [m_x, m_y])
+
+
 @pytest.mark.parametrize("device", get_testable_devices())
 @pytest.mark.parametrize("shape", [(), (1,), (1, 2, 3, 4)])
 @pytest.mark.parametrize("dtype", ["float64", "float32", "float16"])
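
Note (illustrative sketch, not part of the patch): the converter above reduces to a
small lookup from the (transpose_a, transpose_b) flags to a RAF op name, which is
also how the test constructs the expected op. A minimal Python restatement of that
mapping, using a hypothetical helper name, is shown below.

# Hypothetical helper mirroring the C++ dispatch above; the op names are the
# ones referenced in the patch ("raf.op.batch_matmul", "_nt", "_tn", "_tt").
def raf_batch_matmul_op_name(transpose_a: bool, transpose_b: bool) -> str:
    suffix = {
        (False, False): "",
        (False, True): "_nt",
        (True, False): "_tn",
        (True, True): "_tt",
    }[(transpose_a, transpose_b)]
    return "raf.op.batch_matmul" + suffix

# The (False, True) case corresponds to TVM's classic x @ y^T batch_matmul
# semantics, which is what the removed one-to-one mapping to batch_matmul_nt assumed.
assert raf_batch_matmul_op_name(False, True) == "raf.op.batch_matmul_nt"
assert raf_batch_matmul_op_name(True, True) == "raf.op.batch_matmul_tt"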