Commit 311eca4
Fix a bug with Alter Op Layout (#6626)
* Regression test for a Scalar type issue in Alter Op Layout

* Fix the regression by avoiding the Scalar optimization when types aren't defined
Matthew Brookhart authored Oct 5, 2020
1 parent 86122d1 commit 311eca4
Showing 2 changed files with 90 additions and 1 deletion.
src/relay/transforms/transform_layout.h (1 addition, 1 deletion)
@@ -126,7 +126,7 @@ class TransformMemorizer : public ObjectRef {
     if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
       // If scalar, then no need of layout transformation as scalar can be broadcasted easily even
       // if the other operand has a transformed layout.
-      if (IsScalar(input_expr)) {
+      if (input_expr->checked_type_.defined() && IsScalar(input_expr)) {
         return raw;
       }
       int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
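Why the one-line guard is the right fix: in Relay, checked_type_ is populated only by type inference, and IsScalar inspects the expression's checked type, so reaching this branch with an expression built mid-pipeline (whose type has not been re-inferred) aborted compilation. With the guard, an untyped operand simply falls through to the generic expand-dims plus layout_transform path below, which is harmless. The same invariant is visible from the Python API; a minimal sketch, not part of this commit, with arbitrary names and shapes:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16))
y = relay.add(x, relay.const(1.0))

# Before type inference the checked type is unpopulated, and reading it
# raises; this is the condition the new defined() guard screens out in C++.
try:
    _ = y.checked_type
except ValueError as err:
    print("unchecked:", err)

# After InferType the field is defined, so scalar checks are safe.
mod = relay.transform.InferType()(tvm.IRModule.from_expr(y))
print("checked:", mod["main"].body.checked_type)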
tests/python/relay/test_pass_alter_op_layout.py (89 additions, 0 deletions)
@@ -463,6 +463,95 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_alter_layout_scalar_regression():
+    """regression test where scalar fails"""
+
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 16))
+        bias = relay.var("bias", shape=(1, 1, 1, 16))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=16,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.add(y, bias)
+        mean = relay.mean(y, axis=3, exclude=True)
+        var = relay.variance(y, axis=3, exclude=True)
+        gamma = relay.var("gamma")
+        beta = relay.var("beta")
+        y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3)
+        y = y[0]
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW16c"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 16))
+        bias = relay.var("bias", shape=(1, 1, 1, 16))
+        x = relay.layout_transform(x, src_layout="NHWC", dst_layout="NCHW")
+        x = relay.layout_transform(x, src_layout="NCHW", dst_layout="NCHW16c")
+        weight = relay.layout_transform(weight, src_layout="HWIO", dst_layout="OIHW")
+        y = relay.nn.conv2d(
+            x, weight, channels=16, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
+        bias = relay.layout_transform(bias, src_layout="NHWC", dst_layout="NCHW")
+        bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
+        add = relay.add(y, bias)
+        y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
+        y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC")
+        mean = relay.mean(y, axis=3, exclude=True)
+        var = relay.variance(y, axis=3, exclude=True)
+        denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
+        gamma = relay.var("gamma", shape=(16,))
+        denom = denom * gamma
+        denom_expand1 = relay.expand_dims(denom, axis=1, num_newaxis=2)
+        denom_expand2 = relay.expand_dims(denom_expand1, axis=0)
+        denom_nchwc16 = relay.layout_transform(
+            denom_expand2, src_layout="NCHW", dst_layout="NCHW16c"
+        )
+        out = add * denom_nchwc16
+        beta = relay.var("beta", shape=(16,))
+        numerator = (-mean) * denom + beta
+        numerator_expand1 = relay.expand_dims(numerator, axis=1, num_newaxis=2)
+        numerator_expand2 = relay.expand_dims(numerator_expand1, axis=0)
+        numerator_nchwc16 = relay.layout_transform(
+            numerator_expand2, src_layout="NCHW", dst_layout="NCHW16c"
+        )
+        out = out + numerator_nchwc16
+        out = relay.layout_transform(out, src_layout="NCHW16c", dst_layout="NCHW")
+        y = relay.layout_transform(out, src_layout="NCHW", dst_layout="NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = before()
+        desired_layouts = {"nn.conv2d": ["NCHW", "default"], "nn.batch_norm": ["NHWC", "default"]}
+        a = run_opt_pass(
+            a,
+            [
+                transform.InferType(),
+                relay.transform.ConvertLayout(desired_layouts),
+                transform.SimplifyInference(),
+                transform.CanonicalizeOps(),
+                transform.AlterOpLayout(),
+            ],
+        )
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_alter_layout_concatenate():
     """ NCHW, NHWC and corner case concatenate layout transform."""
 
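For context on why this graph exercises the bug: SimplifyInference lowers batch_norm into the explicit scale/shift arithmetic visible in expected(), and the downstream passes hand AlterOpLayout freshly constructed sub-expressions whose types have not been re-inferred, which is where the unguarded IsScalar call used to fail. The run_opt_pass helper in this test file wraps the pass list in essentially the same Sequential; a standalone sketch of the pipeline follows (run_pipeline is a hypothetical name, assuming the standard TVM imports):

import tvm
from tvm import relay

def run_pipeline(func, desired_layouts):
    # Same pass order as the test above; Sequential runs the passes back
    # to back, so AlterOpLayout sees whatever ConvertLayout, SimplifyInference,
    # and CanonicalizeOps left behind, including untyped sub-expressions.
    mod = tvm.IRModule.from_expr(func)
    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.ConvertLayout(desired_layouts),
            relay.transform.SimplifyInference(),
            relay.transform.CanonicalizeOps(),
            relay.transform.AlterOpLayout(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)

# Usage mirroring the test, with before() defined as above and the
# alter_conv2d hook registered via the TempOpAttr context so that
# AlterOpLayout actually rewrites the conv2d:
# mod = run_pipeline(before(), {"nn.conv2d": ["NCHW", "default"]})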
