diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 0b873bf91bf1..3f5ecaa2cbce 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -86,7 +86,8 @@ class ConstantFolder : public ExprMutator {
         shape_func_op_(Op::Get("vm.shape_func")),
         alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
         alloc_storage_op_(Op::Get("memory.alloc_storage")),
-        cast_op_(Op::Get("cast")) {}
+        cast_op_(Op::Get("cast")),
+        ndarray_size_op_(Op::Get("ndarray_size")) {}

   Expr VisitExpr_(const LetNode* op) final {
     Expr value = this->Mutate(op->value);
@@ -128,6 +129,10 @@ class ConstantFolder : public ExprMutator {
       return EvaluateShapeOf(res, origin_args, call->attrs);
     }

+    if (call->op == ndarray_size_op_) {
+      return EvaluateNdarraySize(res, origin_args, call->attrs);
+    }
+
     // We should think about potentially constant evaluation over these ops too.
     if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ ||
         call->op == alloc_storage_op_) {
@@ -173,6 +178,7 @@ class ConstantFolder : public ExprMutator {
   const Op& alloc_tensor_op_;
   const Op& alloc_storage_op_;
   const Op& cast_op_;
+  const Op& ndarray_size_op_;

   // Convert value to expression.
   Expr ObjectToExpr(const ObjectRef& value) {
@@ -223,10 +229,8 @@ class ConstantFolder : public ExprMutator {
     CHECK(param != nullptr);

     tvm::Array<IndexExpr> ishape;
-    if (const ConstantNode* op = input.as<ConstantNode>()) {
-      ishape = op->tensor_type()->shape;
-    } else if (input->checked_type_.defined()) {
-      ishape = input->checked_type().as<TensorTypeNode>()->shape;
+    if (auto opt = GetConstantShape(input)) {
+      ishape = opt.value();
     } else {
       return expr;
     }
@@ -261,12 +265,69 @@ class ConstantFolder : public ExprMutator {
       shape = Constant(ndarray);
     }

+    return CastValue(shape, param->dtype);
+  }
+
+  // Evaluate a call to the ndarray_size operator for tensors with constant
+  // shapes.
+  Expr EvaluateNdarraySize(Expr expr, Array<Expr> args, Attrs attrs) {
+    Expr input = args[0];
+    const auto* param = attrs.as<NdarraySizeAttrs>();
+    CHECK(param != nullptr);
+
+    tvm::Array<IndexExpr> ishape;
+    if (auto opt = GetConstantShape(input)) {
+      ishape = opt.value();
+    } else {
+      return expr;
+    }
+
+    // Get the constant size
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    runtime::NDArray value;
+    DLDataType cdtype = DataType::Int(32);
+    value = runtime::NDArray::Empty({1}, cdtype, ctx);
+    int32_t* data = static_cast<int32_t*>(value->data);
+    if (ishape.size() == 0) {
+      *data = 0;
+    } else {
+      *data = 1;
+      using ::tvm::tir::IntImmNode;
+      for (size_t i = 0; i < ishape.size(); ++i) {
+        if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
+          *data *= dim->value;
+        } else {
+          return expr;
+        }
+      }
+    }
+
+    Constant size = Downcast<Constant>(ObjectToExpr(value));
+    return CastValue(size, param->dtype);
+  }
+
+  Expr CastValue(const Expr& value, DataType dtype) {
     // Cast the constant into correct dtype
     auto cast_attrs = make_object<CastAttrs>();
-    cast_attrs->dtype = param->dtype;
-    Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {});
+    cast_attrs->dtype = dtype;
+    Expr ret = Call(cast_op_, {value}, Attrs(cast_attrs), {});
     return ConstEvaluate(ret);
   }
+
+  Optional<tvm::Array<IndexExpr>> GetConstantShape(const Expr& input) {
+    tvm::Array<IndexExpr> ishape;
+    if (const ConstantNode* op = input.as<ConstantNode>()) {
+      ishape = op->tensor_type()->shape;
+    } else if (input->checked_type_.defined()) {
+      ishape = input->checked_type().as<TensorTypeNode>()->shape;
+    } else {
+      return Optional<tvm::Array<IndexExpr>>(nullptr);
+    }
+
+    return Optional<tvm::Array<IndexExpr>>(ishape);
+  }
 };

 Expr FoldConstant(const Expr& expr, const IRModule& mod) {
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index fcccab5c6b97..e9852689ee9d 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -164,6 +164,27 @@ def expected(dtype):
     assert tvm.ir.structural_equal(zz, zexpected)


+def test_fold_ndarray_size():
+    c_shape = (8, 9, 10)
+    def before(dtype):
+        x = relay.var("x", shape=c_shape, dtype="float32")
+        y = relay.var("y", shape=c_shape, dtype="float32")
+        z = relay.ndarray_size(x + y, dtype)
+        return relay.Function([x, y], z)
+
+    def expected(dtype):
+        x = relay.var("x", shape=c_shape, dtype="float32")
+        y = relay.var("y", shape=c_shape, dtype="float32")
+        z = relay.const([np.size(np.zeros(c_shape))], dtype=dtype)
+        func = relay.Function([x, y], z)
+        return func
+
+    for dtype in ["int32", "float32"]:
+        zz = run_opt_pass(before(dtype), transform.FoldConstant())
+        zexpected = run_opt_pass(expected(dtype), transform.InferType())
+        assert tvm.ir.structural_equal(zz, zexpected)
+
+
 def test_fold_full():
     c_shape = (8, 9, 10)
     def before():
@@ -228,3 +249,4 @@ def initializer(_, param):
     test_fold_shape_of()
     test_fold_full()
     test_fold_batch_norm()
+    test_fold_ndarray_size()
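For reviewers, a minimal usage sketch of what this fold enables (assuming a TVM build with this patch applied; the API calls below are standard TVM Python API, and the printed result is my expectation, not captured output):

    import tvm
    from tvm import relay

    # ndarray_size over a statically shaped tensor: all dims are IntImm,
    # so the new EvaluateNdarraySize path can fold it.
    x = relay.var("x", shape=(8, 9, 10), dtype="float32")
    func = relay.Function([x], relay.ndarray_size(x, "int32"))

    mod = tvm.IRModule.from_expr(func)
    # Run InferType first so checked_type_ is populated for the
    # GetConstantShape fallback, then fold.
    mod = tvm.transform.Sequential(
        [relay.transform.InferType(), relay.transform.FoldConstant()]
    )(mod)

    # The body should now be a constant tensor holding 720 (= 8 * 9 * 10),
    # already cast to int32 by CastValue, instead of an ndarray_size call.
    print(mod["main"].body)

If any dimension were symbolic (e.g. `relay.Any()`), the `as<IntImmNode>()` check fails and the call is left untouched, matching the early `return expr` above.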