diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index a52f42054c3e..fab184cfb5a6 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -203,6 +203,7 @@ class ConstantFolder : public ExprMutator {
   // Constant evaluate a expression.
   Expr ConstEvaluate(Expr expr) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0),
+                                           transform::ToANormalForm(),
                                            transform::InferType()};
     Function func;
     if (expr.as<FunctionNode>()) {
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index b212b26c99a7..a981667219cd 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass):
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
+def test_concatenate_const():
+    def before():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        concat = relay.op.concatenate([const, const], axis=0)
+        func = relay.Function([], concat)
+        return func
+
+    def expected():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        func = relay.Function([], const)
+        return func
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(zz, zexpected)
+
+
 def test_fold_const():
     c_data = np.array([1, 2, 3]).astype("float32")
     t = relay.TensorType([1, 2, 3], "float32")
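
Note (not part of the patch): the regression test added above can also be exercised on its own. The following is a minimal standalone sketch that mirrors test_concatenate_const, assuming a Relay-era TVM build where tvm.relay and tvm.relay.transform are available; with ToANormalForm inserted into ConstEvaluate, FoldConstant should collapse the concatenate of two constants into a single constant tensor.

# Standalone reproduction sketch mirroring test_concatenate_const above.
# Assumption: a TVM build with the Relay API; names are taken from the diff.
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform

# Build a zero-argument function whose body is concatenate([const, const], axis=0).
const = relay.const(tvm.nd.array(np.array([1.0, 2.0, 3.0])))
func = relay.Function([], relay.op.concatenate([const, const], axis=0))

# Apply FoldConstant the same way run_opt_pass does in the test file:
# wrap the function in an IRModule and call the pass on it directly.
mod = tvm.IRModule.from_expr(func)
mod = transform.FoldConstant()(mod)

# With the fix, the body of "main" is a single 6-element constant.
print(mod["main"].body)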