diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index bbfb97c56dc2..d893d941576f 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -161,6 +161,9 @@ std::tuple, Array, bool> CallInfer( const Array& old_in_layouts, const Array > &old_in_shapes) { static auto finfer_layout = Op::GetAttr("FInferCorrectLayout"); + if (!call->op.as()) { + return std::make_tuple<>(Array(nullptr), Array(nullptr), false); + } Op op = Downcast(call->op); if (finfer_layout.count(op)) { diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index c1941c970bf3..9ab582d5b3e2 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -931,6 +931,47 @@ def expected_nhwc(): assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) +def test_alter_op_with_global_var(): + """Test directly replacing an operator with a new one""" + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + mod = relay.Module() + foo = relay.GlobalVar('foo') + mod[foo] = relay.Function([x, weight], y) + mod["main"] = relay.Function([x, weight], foo(x, weight)) + return mod + + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + weight = relay.multiply(weight, relay.const(2.0, "float32")) + return relay.nn.conv2d(data, weight, **attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")), + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + mod = relay.Module() + foo = relay.GlobalVar('foo') + mod[foo] = relay.Function([x, weight], y) + mod["main"] = relay.Function([x, weight], foo(x, weight)) + return mod + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = before() + a = transform.AlterOpLayout()(a) + b = transform.InferType()(expected()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) if __name__ == "__main__": test_alter_op() @@ -949,3 +990,4 @@ def expected_nhwc(): test_alter_layout_pool() test_alter_layout_sum() test_alter_layout_nhwc_nchw_arm() + test_alter_op_with_global_var()