From 526db3a7b3b1377dc0370470a1d4f63ea63b33a4 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Thu, 30 Apr 2020 10:00:27 -0700 Subject: [PATCH] [Fix] Add ConstantNode to IsAtomic (#5457) * add constantnode to atomic * Add ToANormalForm to FoldConstant --- src/relay/transforms/fold_constant.cc | 1 + tests/python/relay/test_pass_fold_constant.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index a52f42054c3e..fab184cfb5a6 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -203,6 +203,7 @@ class ConstantFolder : public ExprMutator { // Constant evaluate a expression. Expr ConstEvaluate(Expr expr) { std::vector passes = {transform::FuseOps(0), + transform::ToANormalForm(), transform::InferType()}; Function func; if (expr.as()) { diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index b212b26c99a7..a981667219cd 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass): return entry if isinstance(expr, relay.Function) else entry.body +def test_concatenate_const(): + def before(): + data = tvm.nd.array(np.array([1.0, 2.0, 3.0])) + const = relay.const(data) + concat = relay.op.concatenate([const, const], axis=0) + func = relay.Function([], concat) + return func + + def expected(): + data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])) + const = relay.const(data) + func = relay.Function([], const) + return func + + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(zz, zexpected) + + def test_fold_const(): c_data = np.array([1, 2, 3]).astype("float32") t = relay.TensorType([1, 2, 3], "float32")