Skip to content

Commit

Permalink
[Relay][Fix] Fix alter op layout when calling a global var (#4454) (#…
Browse files Browse the repository at this point in the history
…5904)

* [Relay][Fix] Fix alter op layout when calling a global var

* add test case
  • Loading branch information
icemelon authored Jun 24, 2020
1 parent 804b7fa commit 3cd42c9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/relay/pass/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr> > &old_in_shapes) {
static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
if (!call->op.as<OpNode>()) {
return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
}

Op op = Downcast<Op>(call->op);
if (finfer_layout.count(op)) {
Expand Down
42 changes: 42 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,47 @@ def expected_nhwc():

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

def test_alter_op_with_global_var():
"""Test directly replacing an operator with a new one"""
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var('weight', shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.nn.relu(y)
mod = relay.Module()
foo = relay.GlobalVar('foo')
mod[foo] = relay.Function([x, weight], y)
mod["main"] = relay.Function([x, weight], foo(x, weight))
return mod

def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)

def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var('weight', shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.nn.relu(y)
mod = relay.Module()
foo = relay.GlobalVar('foo')
mod[foo] = relay.Function([x, weight], y)
mod["main"] = relay.Function([x, weight], foo(x, weight))
return mod

with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = transform.AlterOpLayout()(a)
b = transform.InferType()(expected())

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

if __name__ == "__main__":
test_alter_op()
Expand All @@ -949,3 +990,4 @@ def expected_nhwc():
test_alter_layout_pool()
test_alter_layout_sum()
test_alter_layout_nhwc_nchw_arm()
test_alter_op_with_global_var()

0 comments on commit 3cd42c9

Please sign in to comment.