diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc
index b034c5699fe3..5d3ec03f578f 100644
--- a/src/relay/pass/fold_constant.cc
+++ b/src/relay/pass/fold_constant.cc
@@ -184,6 +184,7 @@ class ConstantFolder : public ExprMutator {
   // Constant evaluate a expression.
   Expr ConstEvaluate(Expr expr) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0),
+                                           transform::ToANormalForm(),
                                            transform::InferType()};
     Function func;
     if (expr.as<FunctionNode>()) {
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 4752597c828e..5167ff2bcde4 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -29,6 +29,25 @@ def run_opt_pass(expr, opt_pass):
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
+def test_concatenate_const():
+    def before():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        concat = relay.op.concatenate([const, const], axis=0)
+        func = relay.Function([], concat)
+        return func
+
+    def expected():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        func = relay.Function([], const)
+        return func
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert relay.analysis.graph_equal(zz, zexpected)
+
+
 def test_fold_const():
     c_data = np.array([1, 2, 3]).astype("float32")
     t = relay.TensorType([1, 2, 3], "float32")
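
Not part of the patch: a minimal standalone sketch of the behaviour the new test exercises, namely that FoldConstant now folds concatenate over constant arguments into a single constant. The module construction (relay.Module.from_expr and the "main" global) assumes the Relay API of the TVM version this patch targets; newer releases use tvm.IRModule instead.

# Sketch only, not part of the patch. Assumes the relay.Module API of the
# TVM version this patch targets (later versions: tvm.IRModule.from_expr).
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform

# Build concatenate over two identical constants, mirroring before() above.
const = relay.const(tvm.nd.array(np.array([1.0, 2.0, 3.0])))
concat = relay.op.concatenate([const, const], axis=0)
func = relay.Function([], concat)

# Run FoldConstant on a module wrapping the function.
mod = relay.Module.from_expr(func)
mod = transform.FoldConstant()(mod)

# With this patch, the body of "main" should be a single constant
# [1. 2. 3. 1. 2. 3.] rather than a concatenate call.
print(mod["main"])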