Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI,RELAY][TFLITE] Sparse to dense operator #5447

Merged
merged 11 commits into from
Jun 4, 2020
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ List of operators
topi.expand_dims
topi.reshape
topi.unravel_index
topi.sparse_to_dense
topi.squeeze
topi.concatenate
topi.split
Expand Down Expand Up @@ -154,6 +155,7 @@ topi
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
.. autofunction:: topi.unravel_index
.. autofunction:: topi.sparse_to_dense
.. autofunction:: topi.squeeze
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
Expand Down
1 change: 1 addition & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ This level enables additional math and transform operators.
tvm.relay.tile
tvm.relay.reverse
tvm.relay.unravel_index
tvm.relay.sparse_to_dense


**Level 4: Broadcast and Reductions**
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,15 @@ struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
}
}; // struct SequenceMaskAttrs.

/*! \brief Attributes used in sparse_to_dense operator */
struct SparseToDenseAttrs : public tvm::AttrsNode<SparseToDenseAttrs> {
Array<Integer> output_shape;

TVM_DECLARE_ATTRS(SparseToDenseAttrs, "relay.attrs.SparseToDenseAttrs") {
TVM_ATTR_FIELD(output_shape).describe("Shape of the dense output tensor");
}
}; // struct SparseToDenseAttrs

/*! \brief Attributes for ndarray_size operator */
struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
DataType dtype;
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .common import infer_shape as _infer_shape
from .tflite_flexbuffer import FlexBufferDecoder


__all__ = ['from_tflite']

class TensorWrapper(object):
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(self, model, subgraph, exp_tab):
'SOFTMAX': self.convert_softmax,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'SPACE_TO_DEPTH': self.convert_space_to_depth,
'SPARSE_TO_DENSE': self.convert_sparse_to_dense,
'SPLIT': self.convert_split,
'SPLIT_V': self.convert_split_v,
'SQRT': self.convert_sqrt,
Expand Down Expand Up @@ -2267,6 +2269,36 @@ def convert_space_to_depth(self, op):

return out

def convert_sparse_to_dense(self, op):
"""Convert TFLite SPARSE_TO_DENSE"""
try:
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 4, "input tensors length should be 4"

indices, values = input_tensors[0], input_tensors[2]
default_value = input_tensors[3]
output_shape = input_tensors[1]

dhruvaray marked this conversation as resolved.
Show resolved Hide resolved
for t in input_tensors:
assert not t.qnn_params, "Quantized input is not expected."

for t in [indices, output_shape]:
t_type = t.tensor.Type()
assert t_type in (TensorType.INT32, TensorType.INT64)

out = _op.sparse_to_dense(
self.get_tensor_expr(indices),
list(self.get_tensor_value(output_shape)),
self.get_tensor_expr(values),
self.get_tensor_expr(default_value)
)

return out

def convert_prelu(self, op):
"""Convert TFLite PReLU"""
input_tensors = self.get_input_tensors(op)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
_reg.register_injective_schedule("unravel_index")
_reg.register_injective_schedule("sparse_to_dense")

# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,3 +890,34 @@ def unravel_index(indices, shape):
"""

return _make.unravel_index(indices, shape)

def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0):
"""Converts a sparse representation into a dense tensor.

Example::
- sparse_to_dense([[0, 0], [1, 1]], [2, 2], [3, 3], 0) = [[3, 0], [0, 3]]

Parameters
----------
sparse_indices : relay.Expr
A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values.

output_shape : relay.Expr
A list of integers. Shape of the dense output tensor.

sparse_values : relay.Expr
A 0-D or 1-D tensor containing the sparse values for the sparse indices.

default_value : relay.Expr
A 0-D tensor containing the default value for the remaining locations.
Defaults to 0.

Returns
-------
result : relay.Expr
Dense tensor of shape output_shape. Has the same type as sparse_values.
"""

if default_value == 0:
default_value = const(0)
return _make.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value)
74 changes: 74 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2524,5 +2524,79 @@ Example::
.set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// sparse_to_dense
TVM_REGISTER_NODE_TYPE(SparseToDenseAttrs);

bool SparseToDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 3);
auto sparse_indices = types[0].as<TensorTypeNode>();
auto sparse_values = types[1].as<TensorTypeNode>();
auto default_value = types[2].as<TensorTypeNode>();
CHECK(sparse_indices != nullptr && sparse_values != nullptr && default_value != nullptr);

CHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be tensor of integers";

CHECK_LE(sparse_indices->shape.size(), 3)
<< "sparse_indices must be a tensor of either 0D, 1D or 2D";

CHECK_LE(sparse_values->shape.size(), 2) << "sparse_values must be a tensor of either 0D, 1D";

CHECK_EQ(default_value->shape.size(), 0) << "default_value should be a scalar";

const auto* param = attrs.as<SparseToDenseAttrs>();
CHECK(param != nullptr);

Array<IndexExpr> oshape;
for (auto i : param->output_shape) {
oshape.push_back(i);
}
reporter->Assign(types[3], TensorType(oshape, sparse_values->dtype));
return true;
}

Array<te::Tensor> SparseToDenseCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 3);
const auto* param = attrs.as<SparseToDenseAttrs>();
CHECK(param != nullptr);
return {topi::sparse_to_dense(inputs[0], param->output_shape, inputs[1], inputs[2]())};
}

TVM_REGISTER_GLOBAL("relay.op._make.sparse_to_dense")
.set_body_typed([](Expr indices, Array<Integer> output_shape, Expr values, Expr default_value) {
auto attrs = make_object<SparseToDenseAttrs>();
attrs->output_shape = std::move(output_shape);
static const Op& op = Op::Get("sparse_to_dense");
return Call(op, {indices, values, default_value}, Attrs(attrs));
});

RELAY_REGISTER_OP("sparse_to_dense")
.describe(R"code(A dense tensor from a sparse representation.

- **sparse_indices**: A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values

- **output_shape**: A list of integers. Shape of the dense output tensor.

- **sparse_values**: A 0-D or 1-D tensor containing the sparse values for the sparse indices.

- **default_value**: A 0-D tensor containing the default value for the remaining locations. Defaults to 0.

Example::
- sparse_to_dense([0, 0], [1, 2]], [3, 4], [1, 2], 0) = [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]

)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.set_support_level(3)
.set_attrs_type<SparseToDenseAttrs>()
.add_argument("sparse_indices", "Tensor", "Contains sparse indices.")
.add_argument("sparse_values", "Tensor", "Contains values for sparse indices.")
.add_argument("default_value", "Tensor", "Value to set for non-sparse indices. Defaults to 0.")
.add_type_rel("SparseToDense", SparseToDenseRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);

} // namespace relay
} // namespace tvm
76 changes: 75 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,6 @@ def test_all_resize():
if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
_test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)


#######################################################################
# Concatenation
# -------------
Expand Down Expand Up @@ -1862,6 +1861,80 @@ def test_forward_spacetodepth():
_test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
_test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)

#######################################################################
# Sparse To Dense
# ---------------
def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
# tflite 1.13 convert method does not accept empty shapes
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
with tf.Graph().as_default():
indices = tf.placeholder(shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices")
values = tf.placeholder(shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values")
oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype))

if default_value == None:
output = tf.sparse_to_dense(indices, oshape, values)
compare_tflite_with_tvm(
[sparse_indices, sparse_values],
["indices", "values"],
[indices, values],
[output]
)
else:
dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value")
output = tf.sparse_to_dense(indices, oshape, values, dv)
compare_tflite_with_tvm(
[sparse_indices, sparse_values, default_value],
["indices", "values", "default_value"],
[indices, values, dv],
[output]
)

def test_forward_sparse_to_dense():
'''
Works in tvm/topi/tensorflow. But tflite converter breaks this test case
_test_sparse_to_dense(
np.int32(1),
np.int32(3),
np.int32(0),
np.array([5]).astype("int32")
)
'''
# vector
_test_sparse_to_dense(
np.array([0, 1, 4]).astype("int32"),
np.array([3, 3, 3]).astype("int32"),
np.int32(0),
np.array([5]).astype("int32")
)
# vector nXd
_test_sparse_to_dense(
np.array([[0, 0], [1, 2]]).astype("int32"),
np.array([1, 2]).astype("int32"),
np.int32(0),
np.array([3, 4]).astype("int32")
)
_test_sparse_to_dense(
np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"),
np.array([1, 2]).astype("int32"),
np.int32(4),
np.array([2, 3, 4]).astype("int32")
)
# floats
_test_sparse_to_dense(
np.array([0, 1, 4]).astype("int32"),
np.array([3.1, 3.1, 3.1]).astype("float32"),
np.float32(3.5),
np.array([5]).astype("int32")
)
# default value not specified
_test_sparse_to_dense(
np.array([0, 1, 4]).astype("int32"),
np.array([3.1, 3.1, 3.1]).astype("float32"),
None,
np.array([5]).astype("int32")
)

#######################################################################
# Fully Connected
# ---------------
Expand Down Expand Up @@ -2305,6 +2378,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_sparse_to_dense()
test_forward_select()
test_forward_quantize_dequantize()

Expand Down
55 changes: 54 additions & 1 deletion tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,58 @@ def verify_unravel_index(indices, shape, dtype):
# output which is inline with Tensorflow
# verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)

def test_sparse_to_dense():
def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected):
sparse_indices_data = np.array(sparse_indices)
sparse_values_data = np.array(sparse_values)
default_value_data = np.array(default_value)

a = relay.var("a", relay.TensorType(sparse_indices_data.shape, str(sparse_indices_data.dtype)))
b = relay.var("b", relay.TensorType(sparse_values_data.shape, str(sparse_values_data.dtype)))
if default_value is None:
args = [a, b]
d = relay.sparse_to_dense(a, output_shape, b)
else:
c = relay.var("c", relay.TensorType(default_value_data.shape, str(default_value_data.dtype)))
args = [a, b, c]
d = relay.sparse_to_dense(a, output_shape, b, c)

zz = run_infer_type(d)
assert zz.checked_type == relay.ty.TensorType(output_shape, str(sparse_values_data.dtype))

func = relay.Function(args, d)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
if default_value is None:
op_res = intrp.evaluate(func)(sparse_indices_data, sparse_values_data)
else:
op_res = intrp.evaluate(func)(
sparse_indices_data, sparse_values_data, default_value_data
)
tvm.testing.assert_allclose(op_res.asnumpy(), xpected, rtol=1e-5)


verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) # scalar
verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3]) # vector
verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]) # nXd
verify_sparse_to_dense(
[[0, 0, 0], [1, 2, 3]],
[1, 2],
4,
[2, 3, 4],
[[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]]
) # nXd
verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) # floats
verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified

#negative test cases
#sparse indices should be ints
#verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
#sparse_values should be 0d or 1d only
#verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
#sparse_indices should not be > 2d tensor
#verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])

if __name__ == "__main__":
test_arange()
Expand Down Expand Up @@ -780,4 +832,5 @@ def verify_unravel_index(indices, shape, dtype):
test_gather_nd()
test_isfinite()
test_isinf()
test_unravel_index()
test_unravel_index()
test_sparse_to_dense()
Loading