Skip to content

Commit

Permalink
[Relay] Handle ndarray_size in FoldConstant (#6156)
Browse files Browse the repository at this point in the history
* [Relay] Handle ndarray_size in FoldConstant

* Use Optional
  • Loading branch information
lixiaoquan authored Jul 28, 2020
1 parent 8cd53e0 commit 44ff1f3
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
75 changes: 68 additions & 7 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ class ConstantFolder : public ExprMutator {
shape_func_op_(Op::Get("vm.shape_func")),
alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
alloc_storage_op_(Op::Get("memory.alloc_storage")),
cast_op_(Op::Get("cast")) {}
cast_op_(Op::Get("cast")),
ndarray_size_op_(Op::Get("ndarray_size")) {}

Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
Expand Down Expand Up @@ -128,6 +129,10 @@ class ConstantFolder : public ExprMutator {
return EvaluateShapeOf(res, origin_args, call->attrs);
}

if (call->op == ndarray_size_op_) {
return EvaluateNdarraySize(res, origin_args, call->attrs);
}

// We should think about potentially constant evaluation over these ops too.
if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ ||
call->op == alloc_storage_op_) {
Expand Down Expand Up @@ -173,6 +178,7 @@ class ConstantFolder : public ExprMutator {
const Op& alloc_tensor_op_;
const Op& alloc_storage_op_;
const Op& cast_op_;
const Op& ndarray_size_op_;

// Convert value to expression.
Expr ObjectToExpr(const ObjectRef& value) {
Expand Down Expand Up @@ -223,10 +229,8 @@ class ConstantFolder : public ExprMutator {
CHECK(param != nullptr);

tvm::Array<IndexExpr> ishape;
if (const ConstantNode* op = input.as<ConstantNode>()) {
ishape = op->tensor_type()->shape;
} else if (input->checked_type_.defined()) {
ishape = input->checked_type().as<TensorTypeNode>()->shape;
if (auto opt = GetConstantShape(input)) {
ishape = opt.value();
} else {
return expr;
}
Expand Down Expand Up @@ -261,12 +265,69 @@ class ConstantFolder : public ExprMutator {
shape = Constant(ndarray);
}

return CastValue(shape, param->dtype);
}

// Evaluate a call to the ndarray_size operator for tensors with constant
// shapes.
Expr EvaluateNdarraySize(Expr expr, Array<Expr> args, Attrs attrs) {
Expr input = args[0];
const auto* param = attrs.as<NdarraySizeAttrs>();
CHECK(param != nullptr);

tvm::Array<IndexExpr> ishape;
if (auto opt = GetConstantShape(input)) {
ishape = opt.value();
} else {
return expr;
}

// Get the constant size
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
runtime::NDArray value;
DLDataType cdtype = DataType::Int(32);
value = runtime::NDArray::Empty({1}, cdtype, ctx);
int32_t* data = static_cast<int32_t*>(value->data);
if (ishape.size() == 0) {
*data = 0;
} else {
*data = 1;
using ::tvm::tir::IntImmNode;
for (size_t i = 0; i < ishape.size(); ++i) {
if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
*data *= dim->value;
} else {
return expr;
}
}
}

Constant size = Downcast<Constant>(ObjectToExpr(value));
return CastValue(size, param->dtype);
}

Expr CastValue(const Expr& value, DataType dtype) {
// Cast the constant into correct dtype
auto cast_attrs = make_object<CastAttrs>();
cast_attrs->dtype = param->dtype;
Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {});
cast_attrs->dtype = dtype;
Expr ret = Call(cast_op_, {value}, Attrs(cast_attrs), {});
return ConstEvaluate(ret);
}

Optional<tvm::Array<IndexExpr>> GetConstantShape(const Expr& input) {
tvm::Array<IndexExpr> ishape;
if (const ConstantNode* op = input.as<ConstantNode>()) {
ishape = op->tensor_type()->shape;
} else if (input->checked_type_.defined()) {
ishape = input->checked_type().as<TensorTypeNode>()->shape;
} else {
return Optional<tvm::Array<IndexExpr>>(nullptr);
}

return Optional<tvm::Array<IndexExpr>>(ishape);
}
};

Expr FoldConstant(const Expr& expr, const IRModule& mod) {
Expand Down
22 changes: 22 additions & 0 deletions tests/python/relay/test_pass_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,27 @@ def expected(dtype):
assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_ndarray_size():
c_shape = (8, 9, 10)
def before(dtype):
x = relay.var("x", shape=c_shape, dtype="float32")
y = relay.var("y", shape=c_shape, dtype="float32")
z = relay.ndarray_size(x + y, dtype)
return relay.Function([x, y], z)

def expected(dtype):
x = relay.var("x", shape=c_shape, dtype="float32")
y = relay.var("y", shape=c_shape, dtype="float32")
z = relay.const([np.size(np.zeros(c_shape))], dtype=dtype)
func = relay.Function([x, y], z)
return func

for dtype in ["int32", "float32"]:
zz = run_opt_pass(before(dtype), transform.FoldConstant())
zexpected = run_opt_pass(expected(dtype), transform.InferType())
assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_full():
c_shape = (8, 9, 10)
def before():
Expand Down Expand Up @@ -228,3 +249,4 @@ def initializer(_, param):
test_fold_shape_of()
test_fold_full()
test_fold_batch_norm()
test_fold_ndarray_size()

0 comments on commit 44ff1f3

Please sign in to comment.