Skip to content

Commit

Permalink
[Relay][Dyn] Dynamic full operator (apache#6260)
Browse files Browse the repository at this point in the history
* moved full from other branch

* fixed some typos

* fix lint

* add final newline

* fix int64 test
  • Loading branch information
Lily Orth-Smith authored and Trevor Morris committed Sep 2, 2020
1 parent f515eae commit 6d8e359
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 43 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ def elemwise_shape_func(attrs, inputs, _):
return [topi.math.identity(inputs[0])]

register_shape_func("cast", False, elemwise_shape_func)
register_shape_func("zeros", True, no_data_full_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", True, no_data_full_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones_like", False, elemwise_shape_func)
register_shape_func("full", True, full_shape_func)
register_shape_func("full", False, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)
register_shape_func("broadcast_to", True, full_shape_func)

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/dyn/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ def zeros_compute(attrs, inputs, output_type):
register_shape_func("dyn.broadcast_to", True, full_shape_func)
register_shape_func("dyn.ones", True, no_data_full_shape_func)
register_shape_func("dyn.zeros", True, no_data_full_shape_func)
register_shape_func("dyn.full", True, full_shape_func)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
_reg.register_injective_schedule("dyn.reshape")
_reg.register_broadcast_schedule("dyn.tile")
_reg.register_injective_schedule("dyn.one_hot")

_reg.register_injective_schedule("dyn.full")

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,12 @@ def full(fill_value, shape=(), dtype=""):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Expr):
return _dyn_make.full(fill_value, shape, dtype)
if isinstance(shape, int):
shape = [shape]
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
shape = list(shape)
return _make.full(fill_value, shape, dtype)


Expand Down
56 changes: 56 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/registry.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/elemwise.h>
#include <tvm/topi/transform.h>

#include <utility>
Expand Down Expand Up @@ -374,6 +375,61 @@ RELAY_REGISTER_OP("dyn.one_hot")
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
const auto* fill_value = types[0].as<TensorTypeNode>();
const auto* fill_shape = types[1].as<TensorTypeNode>();
if (fill_value == nullptr) {
return false;
}

DataType out_dtype = param->dtype;
if (out_dtype.bits() == 0) {
out_dtype = fill_value->dtype;
}

CHECK_EQ(fill_value->shape.size(), 0)
<< "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";

const IntImmNode* rank = fill_shape->shape[0].as<IntImmNode>();
CHECK(rank) << "Parameter shape must have static rank";

std::vector<IndexExpr> oshape;
for (int i = 0; i < rank->value; ++i) {
oshape.push_back(Any());
}
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("dyn.full");
return Call(op, {fill_value, shape}, Attrs(attrs), {});
}
Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
}
TVM_REGISTER_GLOBAL("relay.op.dyn._make.full").set_body_typed(MakeFull);

RELAY_REGISTER_OP("dyn.full")
.describe(R"code(Fill array with scalar value.
)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
.set_num_inputs(2)
.add_argument("fill_value", "double", "The value to fill.")
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("DynamicFull", FullRel)
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);

} // namespace dyn
} // namespace relay
} // namespace tvm
2 changes: 1 addition & 1 deletion src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

Expr MakeFull(Expr fill_value, Expr shape, DataType dtype);
Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);

Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);

Expand Down
41 changes: 14 additions & 27 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -994,10 +994,9 @@ TVM_REGISTER_NODE_TYPE(InitOpAttrs);

bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
CHECK_EQ(types.size(), 2);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
const auto* fill_value = types[0].as<TensorTypeNode>();
const auto* fill_shape = types[1].as<TensorTypeNode>();
if (fill_value == nullptr) {
return false;
}
Expand All @@ -1010,50 +1009,38 @@ bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK_EQ(fill_value->shape.size(), 0)
<< "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";

const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
CHECK(shape_shape) << "Parameter shape must have static shape";

std::vector<IndexExpr> oshape;
if (param->shape) {
const Array<Integer>& cshape_array = param->shape.value();
for (size_t i = 0; i < cshape_array.size(); ++i) {
oshape.push_back(cshape_array[i]);
}
} else {
for (int i = 0; i < shape_shape->value; ++i) {
oshape.push_back(Any());
}
const Array<Integer>& cshape_array = param->shape.value();
for (size_t i = 0; i < cshape_array.size(); ++i) {
oshape.push_back(cshape_array[i]);
}
reporter->Assign(types[2], TensorType(oshape, out_dtype));
reporter->Assign(types[1], TensorType(oshape, out_dtype));
return true;
}

Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->dtype = std::move(dtype);
attrs->shape = std::move(shape);
static const Op& op = Op::Get("full");
return Call(op, {fill_value}, Attrs(attrs), {});
}

Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
}

Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
if (const auto* cshape = shape.as<ConstantNode>()) {
attrs->shape = ToVector(cshape->data);
}
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("full");
return Call(op, {fill_value, shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull);

RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value.
)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
.set_num_inputs(2)
.set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.")
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("Full", FullRel)
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
Expand Down
10 changes: 10 additions & 0 deletions src/relay/transforms/dynamic_to_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ class DynamicToStaticMutator : public MixedModeMutator {
}
return Expr(nullptr);
}},
{Op::Get("dyn.full"),
[](const CallNode* call_node) {
if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
CHECK_EQ(shape->data->ndim, 1);
const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
CHECK(param);
return MakeFull(call_node->args[0], ToVector(shape->data), param->dtype);
}
return Expr(nullptr);
}},
};
}

Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
}

static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
return MakeFull(fill_value, CheckConstantShape(shape), dtype);
return MakeFull(fill_value, CheckConstantShapeArrayInteger(shape), dtype);
}

static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
Expand Down
24 changes: 16 additions & 8 deletions tests/python/relay/dyn/test_dynamic_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,27 @@ def verify_zeros_ones(shape, dtype):

func = relay.Function([dyn_shape], y)
ref_res = ref(shape, dtype)
for target, ctx in ctx_list():
if (target != 'cuda'): #skip cuda because no dynamic support for GPU
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(np.array(shape).astype('int64'))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_func(func, [np.array(shape).astype('int64')], ref_res.astype('int64'))
verify_zeros_ones((1, 3), 'int64')
verify_zeros_ones((8, 9, 1, 2), 'float32')

def test_dyn_full():
def verify_full(fill_value, src_shape, dtype):
x = relay.var("x", relay.scalar_type(dtype))
rank = len(src_shape)
dyn_src_shape = relay.var("dyn_scr_shape", relay.ty.TensorType((rank,), 'int64'))
z = relay.full(x, dyn_src_shape, dtype)
func = relay.Function([x, dyn_src_shape], z)
ref_res = np.full(src_shape, fill_value).astype(dtype)

verify_zeros_ones((124, 50), 'float64')
verify_func(func, [np.array(fill_value).astype(dtype), np.array(src_shape).astype('int64')], ref_res)
verify_full(4, (1, 3, 4, 4), 'int32')
verify_full(4, (1, 3, 4, 4), 'int64')
verify_full(4.0, (2, 50), 'float32')

if __name__ == "__main__":
test_dyn_reshape()
test_dyn_shape_reshape()
test_dyn_tile()
test_dyn_zeros_ones()
test_dyn_full()
2 changes: 1 addition & 1 deletion tests/python/relay/dyn/test_dynamic_op_level6.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):


if __name__ == "__main__":
test_topk()
test_dynamic_topk()
1 change: 1 addition & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def verify_full(fill_value, src_shape, dtype):
op_res = intrp.evaluate(func)(np.array(fill_value, dtype))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_full(4, (1, 3, 4, 4), "int32")
#verify_full(4, (1, 3, 4, 4), "int64") # This does not pass, python int32 is not upcast to int64, not sure how to fix it.
verify_full(4.0, (1, 4), "float32")


Expand Down
20 changes: 20 additions & 0 deletions tests/python/relay/test_pass_dynamic_to_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,25 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
_verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
_verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")

def test_dynamic_to_static_full():
def verify_full(fill_value, fill_shape, dtype):
x = relay.var("x", relay.scalar_type(dtype))
y = relay.var("y", relay.TensorType(fill_shape, 'int64'))
z = relay.full(x, relay.shape_of(y), dtype)

func = run_infer_type(relay.Function([x, y], z))
func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

zz = func2.body
assert isinstance(zz, relay.Call)
assert zz.checked_type == relay.TensorType(fill_shape, dtype)

ref_res = np.full(fill_shape, fill_value).astype(dtype)
y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64')
verify_func(func2, [fill_value, y_data], ref_res)

verify_full(4, (1, 2, 3, 4), 'int32')
verify_full(4.0, (1, 2, 8, 10), 'float32')

if __name__ == "__main__":
test_dynamic_to_static_reshape()
Expand All @@ -312,3 +331,4 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
test_dynamic_to_static_zeros_ones()
test_dynamic_to_static_resize()
test_dynamic_to_static_one_hot()
test_dynamic_to_static_full()

0 comments on commit 6d8e359

Please sign in to comment.