From 3dbce0f87fb48207114fcb263270708be5ed1ee6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 19 Jan 2020 16:57:11 +0900 Subject: [PATCH] fixed test on simple net --- tests/python/relay/test_pass_partition_graph.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 69be0e602d439..45951a8f3b151 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -477,11 +477,11 @@ def test_partition_conv_bias_relu(): def get_layers(prefix, data, in_channel, out_channel, include_bn=True, include_sigmoid=False): - weight = relay.const(np.random.randn(out_channel, in_channel, 3, 3)) - bn_gamma = relay.const(np.random.randn(out_channel)) - bn_beta = relay.const(np.random.randn(out_channel)) - bn_mmean = relay.const(np.random.randn(out_channel)) - bn_mvar = relay.const(np.random.randn(out_channel)) + weight = relay.var(prefix + "weight") + bn_gamma = relay.var(prefix + "bn_gamma") + bn_beta = relay.var(prefix + "bn_beta") + bn_mmean = relay.var(prefix + "bn_mean") + bn_mvar = relay.var(prefix + "bn_var") layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)) @@ -511,7 +511,7 @@ def pre_optimize(mod, params): ]) if params != {}: - # This is required for constant folding on mobilenet + # This is required for constant folding mod["main"] = bind_params_by_name(mod["main"], params) with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): @@ -572,7 +572,7 @@ def test_exec(mod, params, ref_mod, ref_params, out_shape): net = get_net() mod, params = tvm.relay.testing.create_workload(net) ref_mod, ref_params = tvm.relay.testing.create_workload(net) - # test_exec(mod, params, ref_mod, ref_params, (1, 16, 224, 224)) + test_exec(mod, params, ref_mod, ref_params, (1, 16, 224, 224)) mod, params = relay.testing.mobilenet.get_workload() ref_mod, ref_params = relay.testing.mobilenet.get_workload()