[Relay][Dyn] Dynamic full operator #6260

Merged: 5 commits, Aug 13, 2020
Changes from 4 commits
6 changes: 3 additions & 3 deletions python/tvm/relay/op/_tensor.py
@@ -201,11 +201,11 @@ def elemwise_shape_func(attrs, inputs, _):
    return [topi.math.identity(inputs[0])]

register_shape_func("cast", False, elemwise_shape_func)
register_shape_func("zeros", True, no_data_full_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", True, no_data_full_shape_func)
register_shape_func("ones", False, full_shape_func)
Comment on lines +204 to +206

Contributor:
:/ Is this a rebase error?

Contributor Author:
No, I was going through my code and realized I forgot to change this in the zeros and ones PR -- for static-shape ones and zeros, the shape func is not data dependent, since the target shape is carried in the attributes, so they use the normal full_shape_func. The tests for static ones and zeros pass.

Contributor Author:
Perhaps I should just remove the shape func registration for the static ops; this change doesn't seem to have affected any behavior.
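
For context: the boolean flag in register_shape_func marks whether the shape function is data dependent. With True, the function receives the runtime values of its inputs (here, the shape tensor); with False, it only sees their static shapes and the op attributes. The hybrid-script kernel below is a rough sketch of the helper that full_shape_func and no_data_full_shape_func wrap in python/tvm/relay/op/_tensor.py; treat it as illustrative rather than the exact code in this PR.

from tvm.te.hybrid import script

@script
def _full_shape_func(shape):
    # Copy each element of the 1-D shape tensor into the output shape.
    out_ndim = shape.shape[0]
    out = output_tensor((out_ndim,), "int64")
    for i in const_range(out_ndim):
        out[i] = int64(shape[i])
    return out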

register_shape_func("ones_like", False, elemwise_shape_func)
register_shape_func("full", True, full_shape_func)
register_shape_func("full", False, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)
register_shape_func("broadcast_to", True, full_shape_func)

1 change: 1 addition & 0 deletions python/tvm/relay/op/dyn/_tensor.py
@@ -44,3 +44,4 @@ def zeros_compute(attrs, inputs, output_type):
register_shape_func("dyn.broadcast_to", True, full_shape_func)
register_shape_func("dyn.ones", True, no_data_full_shape_func)
register_shape_func("dyn.zeros", True, no_data_full_shape_func)
register_shape_func("dyn.full", True, full_shape_func)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/dyn/_transform.py
@@ -26,7 +26,7 @@
_reg.register_injective_schedule("dyn.reshape")
_reg.register_broadcast_schedule("dyn.tile")
_reg.register_injective_schedule("dyn.one_hot")

+_reg.register_injective_schedule("dyn.full")

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/op/transform.py
@@ -376,8 +376,12 @@ def full(fill_value, shape=(), dtype=""):
    result : relay.Expr
        The resulting tensor.
    """
+    if isinstance(shape, Expr):
+        return _dyn_make.full(fill_value, shape, dtype)
+    if isinstance(shape, int):
+        shape = [shape]
    if isinstance(shape, (list, tuple)):
-        shape = const(list(shape), "int32")
+        shape = list(shape)
    return _make.full(fill_value, shape, dtype)


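The python wrapper now dispatches on the type of shape: a relay.Expr routes to the new dynamic op, while python ints, lists, and tuples stay on the static path. A short sketch of both paths (variable names are illustrative):

import tvm
from tvm import relay

fill = relay.const(4.0, "float32")
static_full = relay.full(fill, (2, 3), "float32")  # tuple -> static "full", shape stored in attrs
shape_var = relay.var("s", relay.TensorType((2,), "int64"))
dyn_full = relay.full(fill, shape_var, "float32")  # Expr -> "dyn.full", shape is a runtime input
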
56 changes: 56 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
@@ -28,6 +28,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/registry.h>
#include <tvm/topi/broadcast.h>
+#include <tvm/topi/elemwise.h>
#include <tvm/topi/transform.h>

#include <utility>
@@ -374,6 +375,61 @@ RELAY_REGISTER_OP("dyn.one_hot")
    .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

+bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+             const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const InitOpAttrs* param = attrs.as<InitOpAttrs>();
+  const auto* fill_value = types[0].as<TensorTypeNode>();
+  const auto* fill_shape = types[1].as<TensorTypeNode>();
+  if (fill_value == nullptr) {
+    return false;
+  }
+
+  DataType out_dtype = param->dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = fill_value->dtype;
+  }
+
+  CHECK_EQ(fill_value->shape.size(), 0)
+      << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";
+
+  const IntImmNode* rank = fill_shape->shape[0].as<IntImmNode>();
+  CHECK(rank) << "Parameter shape must have static rank";
+
+  std::vector<IndexExpr> oshape;
+  for (int i = 0; i < rank->value; ++i) {
+    oshape.push_back(Any());
+  }
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
+  auto attrs = make_object<InitOpAttrs>();
+  attrs->dtype = std::move(dtype);
+  static const Op& op = Op::Get("dyn.full");
+  return Call(op, {fill_value, shape}, Attrs(attrs), {});
+}
+
+Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                              const Type& out_type) {
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.full").set_body_typed(MakeFull);
+
+RELAY_REGISTER_OP("dyn.full")
+    .describe(R"code(Fill array with scalar value.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<InitOpAttrs>()
+    .set_num_inputs(2)
+    .add_argument("fill_value", "double", "The value to fill.")
+    .add_argument("shape", "Tensor", "Target shape.")
+    .set_support_level(3)
+    .add_type_rel("DynamicFull", FullRel)
+    .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise);
+
}  // namespace dyn
}  // namespace relay
}  // namespace tvm
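
Because FullRel above can recover only the output rank (from the static extent of the 1-D shape tensor), type inference reports every output dimension as Any. A quick sketch of what that looks like, assuming a CPU build of TVM:

import tvm
from tvm import relay

x = relay.var("x", relay.scalar_type("float32"))
s = relay.var("s", relay.TensorType((3,), "int64"))
mod = tvm.ir.IRModule.from_expr(relay.Function([x, s], relay.full(x, s, "float32")))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # expect Tensor[(?, ?, ?), float32]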
2 changes: 1 addition & 1 deletion src/relay/op/make_op.h
@@ -48,7 +48,7 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

-Expr MakeFull(Expr fill_value, Expr shape, DataType dtype);
+Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);

Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);

41 changes: 14 additions & 27 deletions src/relay/op/tensor/transform.cc
@@ -994,10 +994,9 @@ TVM_REGISTER_NODE_TYPE(InitOpAttrs);

bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
             const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
+  CHECK_EQ(types.size(), 2);
  const InitOpAttrs* param = attrs.as<InitOpAttrs>();
  const auto* fill_value = types[0].as<TensorTypeNode>();
-  const auto* fill_shape = types[1].as<TensorTypeNode>();
  if (fill_value == nullptr) {
    return false;
  }
@@ -1010,50 +1009,38 @@ bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
  CHECK_EQ(fill_value->shape.size(), 0)
      << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";

-  const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
-  CHECK(shape_shape) << "Parameter shape must have static shape";
-
  std::vector<IndexExpr> oshape;
-  if (param->shape) {
-    const Array<Integer>& cshape_array = param->shape.value();
-    for (size_t i = 0; i < cshape_array.size(); ++i) {
-      oshape.push_back(cshape_array[i]);
-    }
-  } else {
-    for (int i = 0; i < shape_shape->value; ++i) {
-      oshape.push_back(Any());
-    }
-  }
+  const Array<Integer>& cshape_array = param->shape.value();
+  for (size_t i = 0; i < cshape_array.size(); ++i) {
+    oshape.push_back(cshape_array[i]);
+  }
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  reporter->Assign(types[1], TensorType(oshape, out_dtype));
  return true;
}

+Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype) {
+  auto attrs = make_object<InitOpAttrs>();
+  attrs->dtype = std::move(dtype);
+  attrs->shape = std::move(shape);
+  static const Op& op = Op::Get("full");
+  return Call(op, {fill_value}, Attrs(attrs), {});
+}

Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                              const Type& out_type) {
  const auto* out_ttype = out_type.as<TensorTypeNode>();
  return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
}

-Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
-  auto attrs = make_object<InitOpAttrs>();
-  if (const auto* cshape = shape.as<ConstantNode>()) {
-    attrs->shape = ToVector(cshape->data);
-  }
-  attrs->dtype = std::move(dtype);
-  static const Op& op = Op::Get("full");
-  return Call(op, {fill_value, shape}, Attrs(attrs), {});
-}

TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull);

RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value.

)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
.set_num_inputs(2)
.set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.")
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("Full", FullRel)
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
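After this change the static full op takes only the fill value as a call argument and carries its target shape in InitOpAttrs. A brief sketch of what user code observes:

import tvm
from tvm import relay

z = relay.full(relay.const(1.0), (2, 3), "float32")
print(z.op.name)      # "full"
print(z.attrs.shape)  # [2, 3] -- the shape lives in the attributes
print(len(z.args))    # 1 -- only fill_value is an input
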
10 changes: 10 additions & 0 deletions src/relay/transforms/dynamic_to_static.cc
@@ -114,6 +114,16 @@ class DynamicToStaticMutator : public MixedModeMutator {
          }
          return Expr(nullptr);
        }},
+       {Op::Get("dyn.full"),
+        [](const CallNode* call_node) {
+          if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+            CHECK_EQ(shape->data->ndim, 1);
+            const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+            CHECK(param);
+            return MakeFull(call_node->args[0], ToVector(shape->data), param->dtype);
+          }
+          return Expr(nullptr);
+        }},
    };
  }

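A sketch of the rewrite this mapping enables: a dyn.full whose shape argument is (or folds to) a constant becomes a static full, so downstream passes see fixed shapes again. This assumes the usual pass prerequisites (type inference runs as part of the pass pipeline):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", relay.scalar_type("float32"))
z = relay.full(x, relay.const(np.array([2, 3], dtype="int64")), "float32")
mod = tvm.ir.IRModule.from_expr(relay.Function([x], z))
mod = relay.transform.DynamicToStatic()(mod)
print(mod["main"])  # expect a call to "full" with the shape folded into attrs
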
2 changes: 1 addition & 1 deletion src/relay/transforms/pattern_util.h
@@ -596,7 +596,7 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
}

static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
-  return MakeFull(fill_value, CheckConstantShape(shape), dtype);
+  return MakeFull(fill_value, CheckConstantShapeArrayInteger(shape), dtype);
}

static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
24 changes: 16 additions & 8 deletions tests/python/relay/dyn/test_dynamic_op_level3.py
@@ -103,19 +103,27 @@ def verify_zeros_ones(shape, dtype):

        func = relay.Function([dyn_shape], y)
        ref_res = ref(shape, dtype)
-        for target, ctx in ctx_list():
-            if (target != 'cuda'): #skip cuda because no dynamic support for GPU
-                for kind in ["vm", "debug"]:
-                    mod = tvm.ir.IRModule.from_expr(func)
-                    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                    op_res = intrp.evaluate(func)(np.array(shape).astype('int64'))
-                    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+        verify_func(func, [np.array(shape).astype('int64')], ref_res.astype('int64'))
    verify_zeros_ones((1, 3), 'int64')
-    verify_zeros_ones((124, 50), 'float64')
+    verify_zeros_ones((8, 9, 1, 2), 'float32')
+
+def test_dyn_full():
+    def verify_full(fill_value, src_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        rank = len(src_shape)
+        dyn_src_shape = relay.var("dyn_scr_shape", relay.ty.TensorType((rank,), 'int64'))
+        z = relay.full(x, dyn_src_shape, dtype)
+        func = relay.Function([x, dyn_src_shape], z)
+        ref_res = np.full(src_shape, fill_value).astype(dtype)
+
+        verify_func(func, [fill_value, np.array(src_shape).astype('int64')], ref_res)
+    verify_full(4, (1, 3, 4, 4), 'int32')
+    verify_full(4, (1, 3, 4, 4), 'int64')  # does not pass, fix me
+    verify_full(4.0, (2, 50), 'float32')

if __name__ == "__main__":
    test_dyn_reshape()
    test_dyn_shape_reshape()
    test_dyn_tile()
    test_dyn_zeros_ones()
+    test_dyn_full()
2 changes: 1 addition & 1 deletion tests/python/relay/dyn/test_dynamic_op_level6.py
@@ -73,4 +73,4 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):


if __name__ == "__main__":
-    test_topk()
+    test_dynamic_topk()
1 change: 1 addition & 0 deletions tests/python/relay/test_op_level3.py
@@ -460,6 +460,7 @@ def verify_full(fill_value, src_shape, dtype):
        op_res = intrp.evaluate(func)(np.array(fill_value, dtype))
        tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    verify_full(4, (1, 3, 4, 4), "int32")
+    # verify_full(4, (1, 3, 4, 4), "int64")  # This does not pass; the python int fill value is not upcast to int64, not sure how to fix it.
    verify_full(4.0, (1, 4), "float32")


20 changes: 20 additions & 0 deletions tests/python/relay/test_pass_dynamic_to_static.py
@@ -301,6 +301,25 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
    _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
    _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")

+def test_dynamic_to_static_full():
+    def verify_full(fill_value, fill_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        y = relay.var("y", relay.TensorType(fill_shape, 'int64'))
+        z = relay.full(x, relay.shape_of(y), dtype)
+
+        func = run_infer_type(relay.Function([x, y], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.checked_type == relay.TensorType(fill_shape, dtype)
+
+        ref_res = np.full(fill_shape, fill_value).astype(dtype)
+        y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64')
+        verify_func(func2, [fill_value, y_data], ref_res)
+
+    verify_full(4, (1, 2, 3, 4), 'int32')
+    verify_full(4.0, (1, 2, 8, 10), 'float32')

if __name__ == "__main__":
    test_dynamic_to_static_reshape()
@@ -312,3 +331,4 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
    test_dynamic_to_static_zeros_ones()
    test_dynamic_to_static_resize()
    test_dynamic_to_static_one_hot()
+    test_dynamic_to_static_full()