Skip to content

Commit

Permalink
add squeeze (apache#52)
Browse files Browse the repository at this point in the history
* add transform

* fix

* update doc

* Update tvm
  • Loading branch information
sxjscience authored and tqchen committed May 29, 2018
1 parent 5541a27 commit 3f599a6
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 2 deletions.
2 changes: 2 additions & 0 deletions nnvm/docs/top.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.expand_dims
nnvm.symbol.squeeze
nnvm.symbol.split
nnvm.symbol.dropout
nnvm.symbol.batch_norm
Expand Down Expand Up @@ -112,6 +113,7 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
.. autofunction:: nnvm.symbol.squeeze
.. autofunction:: nnvm.symbol.split
.. autofunction:: nnvm.symbol.dropout
.. autofunction:: nnvm.symbol.batch_norm
Expand Down
10 changes: 10 additions & 0 deletions nnvm/include/nnvm/top/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
}
};

/*! \brief Parameters for the squeeze operator: which axes to remove. */
struct SqueezeParam : public dmlc::Parameter<SqueezeParam> {
// Axes to squeeze; an empty TShape (the default) means "all size-1 axes".
TShape axis;

DMLC_DECLARE_PARAMETER(SqueezeParam) {
DMLC_DECLARE_FIELD(axis).set_default(TShape())
.describe("The axis to squeeze in the input tensor."
" If set to None, all size=1 axes will be squeezed");
}
};

struct ScalarParam : public dmlc::Parameter<ScalarParam> {
double scalar;

Expand Down
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/compiler/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def items(self):
"""
res = _list_cache_items()
assert len(res) % 2 == 0
return [(res[2*i], res[2*i+1]) for i in range(len(res)/2)]
return [(res[2*i], res[2*i+1]) for i in range(len(res) // 2)]

def clear_cache(self):
"""Clear the existing cached functions."""
Expand Down
10 changes: 10 additions & 0 deletions nnvm/python/nnvm/top/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ def compute_reshape(attrs, inputs, out_info):
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_injective)

# squeeze
@reg.register_compute("squeeze")
def compute_squeeze(attrs, inputs, out_info):
    """Compute definition of squeeze.

    Removes axes of size 1 from inputs[0]. An empty "axis" attribute
    means all size-1 axes are removed.
    """
    axis = attrs.get_int_tuple("axis")
    # get_int_tuple returns an empty tuple when no axis was given;
    # topi.squeeze expects None in that case to squeeze every size-1 axis.
    axis = tuple(axis) if axis else None
    return topi.squeeze(inputs[0], axis)
reg.register_pattern("squeeze", OpPattern.INJECTIVE)
reg.register_schedule("squeeze", _fschedule_injective)

# concatenate
@reg.register_compute("concatenate")
def compute_concatenate(attrs, inputs, out_info):
Expand Down
1 change: 1 addition & 0 deletions nnvm/src/top/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <dmlc/parameter.h>
#include <string>
#include <vector>
#include <unordered_set>

namespace nnvm {
namespace top {
Expand Down
74 changes: 74 additions & 0 deletions nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,80 @@ The significance of each is explained below:
.set_num_outputs(1)
.set_support_level(3);

// squeeze
DMLC_REGISTER_PARAMETER(SqueezeParam);

/*!
 * \brief Shape inference for squeeze.
 *
 * Removes the requested axes (which must all have extent 1) from the
 * input shape; with no axes given, removes every size-1 axis. If every
 * axis is removed, the result is the 1-D shape (1,).
 *
 * \return true if the output shape could be inferred, false when the
 *         input shape is still unknown.
 */
inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
                         std::vector<TShape>* in_attrs,
                         std::vector<TShape>* out_attrs) {
  const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  const TShape& shp = (*in_attrs)[0];
  if (shp.ndim() == 0) return false;  // input shape not known yet

  std::vector<int64_t> oshape;
  if (param.axis.ndim() == 0) {
    // No explicit axes: keep every dimension whose extent is not 1.
    for (dim_t i = 0; i < shp.ndim(); ++i) {
      if (shp[i] != 1) {
        oshape.emplace_back(shp[i]);
      }
    }
  } else {
    // Normalize each requested axis (negative values count from the end)
    // and record it. NOTE: the previous code inserted into axis_checker
    // only for negative values, so positive axes were silently ignored.
    std::unordered_set<dim_t> axis_checker;
    for (size_t i = 0; i < param.axis.ndim(); ++i) {
      int real_axis;
      if (param.axis[i] < 0) {
        real_axis = param.axis[i] + static_cast<int>(shp.ndim());
      } else {
        real_axis = param.axis[i];
      }
      CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
      axis_checker.insert(real_axis);
    }
    for (size_t i = 0; i < shp.ndim(); ++i) {
      if (axis_checker.find(i) == axis_checker.end()) {
        oshape.emplace_back(shp[i]);
      } else {
        // Squeezing is only legal on size-1 dimensions.
        CHECK_EQ(shp[i], 1) << "The squeezed axis must have shape 1!"
                            << "Want to squeeze " << i
                            << ", which has shape" << shp[i];
      }
    }
  }
  if (oshape.empty()) {
    // Every axis was squeezed away; represent the scalar as shape (1,).
    oshape.push_back(1);
  }
  TShape out_shape(oshape.begin(), oshape.end());
  // Squeeze never changes the number of elements.
  CHECK_EQ(out_shape.Size(), shp.Size())
      << "Target shape size is different to source. "
      << "Target: " << out_shape
      << "\nSource: " << shp;
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, out_shape);
  return true;
}

// Operator registration for squeeze: 1 input -> 1 output, element type
// preserved, shape inferred by SqueezeShape above.
NNVM_REGISTER_OP(squeeze)
.describe(R"code(Squeeze axes in the array.
Examples::
x = [[[0], [1], [2]]]
squeeze(x) = [0, 1, 2]
squeeze(x, 0) = [[0], [1], [2]]
squeeze(x, (0, 2)) = [0, 1, 2]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Source input")
.add_arguments(SqueezeParam::__FIELDS__())
.set_attr_parser(ParamParser<SqueezeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(1);

// transpose
DMLC_REGISTER_PARAMETER(TransposeParam);

Expand Down
25 changes: 25 additions & 0 deletions nnvm/tests/python/compiler/test_top_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,30 @@ def test_split():
verify_split((5, 3), [3], axis=0)
verify_split((5, 9, 3), [3, 4], axis=1)


def verify_squeeze(dshape, axis):
    """Build squeeze(x) + 1 over ``dshape``, run it on every available
    target, and compare the result against numpy's squeeze.

    Parameters
    ----------
    dshape : tuple of int
        Shape of the random input tensor.
    axis : int, tuple of int, or None
        Axis argument forwarded to both sym.squeeze and np.squeeze;
        None squeezes all size-1 axes.
    """
    x = sym.Variable("x")
    # BUGFIX: "if axis:" treated axis=0 as falsy, so the axis=0 case
    # silently exercised the squeeze-all path while numpy squeezed only
    # axis 0.  Compare against None explicitly instead.
    if axis is not None:
        y = sym.squeeze(x, axis=axis)
    else:
        y = sym.squeeze(x)
    y = y + 1
    dtype = "float32"
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = graph_runtime.create(graph, lib, ctx)
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        m.run(x=data)
        # numpy reference: same squeeze followed by the same +1.
        out_np = np.squeeze(data.asnumpy(), axis=axis) + 1
        out = m.get_output(0, tvm.nd.empty(out_np.shape))
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def test_squeeze():
    """Run verify_squeeze on representative shape/axis combinations."""
    cases = [
        ((1, 3, 2, 5), None),    # squeeze every size-1 axis
        ((1, 3, 1), 0),          # explicit leading axis
        ((1, 3, 2, 5, 1), -1),   # negative axis counted from the end
    ]
    for shape, ax in cases:
        verify_squeeze(shape, ax)

if __name__ == "__main__":
test_split()
test_concatenate()
Expand All @@ -232,3 +256,4 @@ def test_split():
test_tanh()
test_sigmoid()
test_softmax()
test_squeeze()
1 change: 0 additions & 1 deletion nnvm/tests/python/compiler/test_top_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def test_reshape():
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))


if __name__ == "__main__":
test_reshape()
test_reduce()
Expand Down

0 comments on commit 3f599a6

Please sign in to comment.