Skip to content

Commit

Permalink
LRN only supports 4D tensors, remove it from alter_op_layout (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Jun 18, 2020
1 parent d162f04 commit 2d31b3e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,6 @@ centered at that value (zero padding is added where necessary).
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Identity", IdentityRel);


Expand Down
51 changes: 51 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,56 @@ def expected():

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def test_alter_layout_lrn():
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
"""
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
bias = relay.var("bias")
weight = relay.var("weight")
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
y = relay.nn.max_pool2d(y, pool_size=(2, 2))
y = relay.nn.lrn(y)
y = relay.Function(analysis.free_vars(y), y)
return y

def alter_conv2d(attrs, inputs, tinfos, out_type):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
new_attrs['kernel_layout'] = 'OIHW16i'
return relay.nn.conv2d(data, weight, **new_attrs)


def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
bias = relay.var("bias", shape=(64,))
weight = relay.var("weight", shape=(64, 64, 3, 3))

y = relay.layout_transform(x, "NCHW", "NCHW16c")
w = relay.layout_transform(weight, "OIHW", "OIHW16i")
y = relay.nn.conv2d(y, w,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
kernel_layout="OIHW16i",
data_layout="NCHW16c")
y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c")
y = relay.layout_transform(y, "NCHW16c", "NCHW")
y = relay.nn.lrn(y)
y = relay.Function(analysis.free_vars(y), y)
return y

with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)



def test_alter_layout_dual_path():
"""
Expand Down Expand Up @@ -1027,6 +1077,7 @@ def expected():
test_alter_return_none()
test_alter_layout()
test_alter_layout_dual_path()
test_alter_layout_lrn()
test_alter_layout_resnet()
test_alter_layout_broadcast_op()
test_alter_layout_broadcast_scalar_op()
Expand Down

0 comments on commit 2d31b3e

Please sign in to comment.