Skip to content

Commit

Permalink
[Relay][Op] Add unbiased variance op and corresponding support in pyt…
Browse files Browse the repository at this point in the history
…orch frontend (apache#6232)
  • Loading branch information
shiwenloong authored and wjliu1998 committed Aug 13, 2020
1 parent 5df5698 commit e32d759
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 38 deletions.
31 changes: 31 additions & 0 deletions include/tvm/relay/attrs/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,37 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
"Whether to perform reduction on axis that are NOT in axis instead.");
}
};

struct VarianceAttrs : public tvm::AttrsNode<VarianceAttrs> {
Array<Integer> axis;
bool keepdims;
bool exclude;
bool unbiased;

TVM_DECLARE_ATTRS(VarianceAttrs, "relay.attrs.VarianceAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Array<Integer>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is int, a reduction is performed on a particular axis.
If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");

TVM_ATTR_FIELD(keepdims).set_default(false).describe(
"If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
TVM_ATTR_FIELD(exclude).set_default(false).describe(
"Whether to perform reduction on axis that are NOT in axis instead.");
TVM_ATTR_FIELD(unbiased).set_default(false).describe("Whether to use the unbiased estimation.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_REDUCE_H_
25 changes: 10 additions & 15 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,28 +1262,23 @@ def _impl(inputs, input_types):
keepdims = bool(inputs[3])
unbiased = bool(inputs[2])

if unbiased:
msg = "Currently only supports standard-deviation calculated via the biased "\
"estimator. PyTorch's Bessel's correction is not supported."
raise NotImplementedError(msg)

return _op.reduce.std(data, axis=axis, keepdims=keepdims)
return _op.reduce.std(data, axis=axis, keepdims=keepdims, unbiased=unbiased)

return _impl

def _variance():
def _impl(inputs, input_types):
data = inputs[0]
axis = list(_infer_shape(inputs[1]))
keepdims = bool(inputs[3])
unbiased = bool(inputs[2])

if unbiased:
msg = "Currently only supports standard-deviation calculated via the biased "\
"estimator. PyTorch's Bessel's correction is not supported."
raise NotImplementedError(msg)
if len(inputs) == 2:
axis = None
keepdims = False
unbiased = bool(inputs[1])
else:
axis = list(_infer_shape(inputs[1]))
keepdims = bool(inputs[3])
unbiased = bool(inputs[2])

return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
return _op.reduce.variance(data, axis=axis, keepdims=keepdims, unbiased=unbiased)

return _impl

Expand Down
15 changes: 11 additions & 4 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,16 +589,23 @@ def mean_grad(orig, grad):
def variance_grad(orig, grad):
"""Note that we take mean as an argument in the variance node"""
data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig)
unbiased = orig.attrs.unbiased
shape = data.checked_type.concrete_shape
if axis is None:
axis = list(range(len(data.checked_type.concrete_shape)))
if not orig.attrs.keepdims:
grad = _unreduce_expand(grad, axis)
mult = 2.0
mult1 = 2.0
mult2 = -2.0
count = 1
for a in axis:
mult /= shape[a]
return [(grad * const(mult, dtype=data.checked_type.dtype)) * data,
const(-2, dtype=data.checked_type.dtype) * grad * data_mean]
count *= shape[a]
if unbiased:
mult2 = mult2 * count / (count - 1)
count -= 1
mult1 /= count
return [(grad * const(mult1, dtype=data.checked_type.dtype)) * data,
const(mult2, dtype=data.checked_type.dtype) * grad * data_mean]


@register_gradient("copy")
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ class ReduceAttrs(Attrs):
"""Attributes used in reduction operators (e.g. sum)"""


@tvm._ffi.register_object("relay.attrs.VarianceAttrs")
class VarianceAttrs(Attrs):
"""Attributes used in reduction operators (e.g. sum)"""


@tvm._ffi.register_object("relay.attrs.RequantizeAttrs")
class RequantizeAttrs(Attrs):
"""Attributes used in requantize operators"""
Expand Down
18 changes: 12 additions & 6 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
return _make.mean(data, axis, keepdims, exclude)


def variance(data, axis=None, keepdims=False, exclude=False):
def variance(data, axis=None, keepdims=False, exclude=False, unbiased=False):
"""Computes the variance of data over given axes.
Parameters
Expand All @@ -334,17 +334,20 @@ def variance(data, axis=None, keepdims=False, exclude=False):
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
unbiased : bool
If this is set to True, the unbiased estimation will be used.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
return _make._variance(data, m, axis, keepdims, exclude)
return _make._variance(data, m, axis, keepdims, exclude, unbiased)


def std(data, axis=None, keepdims=False, exclude=False):
def std(data, axis=None, keepdims=False, exclude=False, unbiased=False):
"""Computes the standard deviation of data over given axes.
Parameters
Expand All @@ -366,14 +369,17 @@ def std(data, axis=None, keepdims=False, exclude=False):
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
unbiased : bool
If this is set to True, the unbiased estimation will be used.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
return sqrt(_make._variance(data, m, axis, keepdims, exclude))
return sqrt(_make._variance(data, m, axis, keepdims, exclude, unbiased))


def mean_variance(data, axis=None, keepdims=False, exclude=False):
Expand Down Expand Up @@ -405,7 +411,7 @@ def mean_variance(data, axis=None, keepdims=False, exclude=False):
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
var = _make._variance(data, m, axis, keepdims, exclude)
var = _make._variance(data, m, axis, keepdims, exclude, False)
if not keepdims:
m = squeeze(m)
return TupleWrapper(Tuple((m, var)), 2)
Expand Down Expand Up @@ -440,7 +446,7 @@ def mean_std(data, axis=None, keepdims=False, exclude=False):
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
s = sqrt(_make._variance(data, m, axis, keepdims, exclude))
s = sqrt(_make._variance(data, m, axis, keepdims, exclude, False))
if not keepdims:
m = squeeze(m)
return TupleWrapper(Tuple((m, s)), 2)
Expand Down
3 changes: 2 additions & 1 deletion src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ Expr MakeTile(Expr data, Array<Integer> reps);

Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype);

Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude);
Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
bool unbiased);

Expr MakeZeros(Array<Integer> shape, DataType dtype);

Expand Down
31 changes: 22 additions & 9 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(ReduceAttrs);
TVM_REGISTER_NODE_TYPE(VarianceAttrs);

/*!
* \brief GetReduceAxes, get the new axis from indim and other arguments
Expand Down Expand Up @@ -193,12 +194,14 @@ Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inp
/*!
* \brief ReduceShapeImpl get the outshape for the reduction operator
* \param in_shape Shape of input data.
* \param param ReduceAttrs details.
* \param param Attrs details.
* \param reporter The reporter to report solution to.
* \return oshape Output shape inferred.
* \tparam AttrsType The attribute type.
*/
template <typename AttrsType>
inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
const ReduceAttrs* param,
const AttrsType* param,
const TypeReporter& reporter) {
uint32_t indim = in_shape.size();
auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
Expand Down Expand Up @@ -542,7 +545,7 @@ bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
std::vector<IndexExpr> mean_shape(mean->shape.begin(), mean->shape.end());
CHECK_EQ(in_shape.size(), mean_shape.size());

const ReduceAttrs* param = attrs.as<ReduceAttrs>();
const VarianceAttrs* param = attrs.as<VarianceAttrs>();
CHECK(param != nullptr);

// assign output type and shape
Expand All @@ -554,39 +557,49 @@ bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
const VarianceAttrs* param = attrs.as<VarianceAttrs>();
CHECK(param != nullptr);
auto axes = param->axis;
bool unbiased = param->unbiased;
auto data = inputs[0];
auto mean = inputs[1];
for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) {
count *= data->shape[i];
}
if (unbiased) {
count -= 1;
}
std::vector<Integer> expand_shape;
auto sq_diff = topi::power(topi::subtract(data, mean), 2);
auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, topi::sum)[0], count);
if (param->exclude) {
axes = GetExcludeAxes(sq_diff->shape.size(), param->axis);
CHECK_NE(axes.size(), 0);
}
auto var = topi::divide(topi::sum(sq_diff, axes, param->keepdims, false), count);

return {var};
}

Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
auto attrs = make_object<ReduceAttrs>();
Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
bool unbiased = false) {
auto attrs = make_object<VarianceAttrs>();
attrs->axis = std::move(axis);
attrs->keepdims = keepdims;
attrs->exclude = exclude;
attrs->unbiased = unbiased;
static const Op& op = Op::Get("variance");
return Call(op, {data, mean}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeVariance, args, rv);
runtime::detail::unpack_call<Expr, 6>(MakeVariance, args, rv);
});

RELAY_REGISTER_OP("variance")
.describe(R"code(Computes the variance of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_attrs_type<ReduceAttrs>()
.set_attrs_type<VarianceAttrs>()
.set_support_level(4)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,9 @@ inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
return MakeReduce(data, axis, keepdims, exclude, "mean");
}

inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
return MakeVariance(data, mean, axis, keepdims, exclude);
inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
bool unbiased = false) {
return MakeVariance(data, mean, axis, keepdims, exclude, unbiased);
}

static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
Expand Down
35 changes: 35 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,13 +1873,28 @@ class Std6(Module):
def forward(self, *args):
return args[0].std(unbiased=False)

class Std7(Module):
def forward(self, *args):
return args[0].std(dim=1, keepdim=False, unbiased=True)

class Std8(Module):
def forward(self, *args):
return args[0].std(dim=(2,3), keepdim=True, unbiased=True)

class Std9(Module):
def forward(self, *args):
return args[0].std(unbiased=True)

input_data = torch.rand(input_shape).float()
verify_model(Std1().float().eval(), input_data=input_data)
verify_model(Std2().float().eval(), input_data=input_data)
verify_model(Std3().float().eval(), input_data=input_data)
verify_model(Std4().float().eval(), input_data=input_data)
verify_model(Std5().float().eval(), input_data=input_data)
verify_model(Std6().float().eval(), input_data=input_data)
verify_model(Std7().float().eval(), input_data=input_data)
verify_model(Std8().float().eval(), input_data=input_data)
verify_model(Std9().float().eval(), input_data=input_data)


def test_forward_variance():
Expand All @@ -1906,12 +1921,32 @@ class Variance5(Module):
def forward(self, *args):
return args[0].var(dim=(2,3), keepdim=False, unbiased=False)

class Variance6(Module):
def forward(self, *args):
return args[0].var(unbiased=False)

class Variance7(Module):
def forward(self, *args):
return args[0].var(dim=1, keepdim=False, unbiased=True)

class Variance8(Module):
def forward(self, *args):
return args[0].var(dim=(2,3), keepdim=True, unbiased=True)

class Variance9(Module):
def forward(self, *args):
return args[0].var(unbiased=True)

input_data = torch.rand(input_shape).float()
verify_model(Variance1().float().eval(), input_data=input_data)
verify_model(Variance2().float().eval(), input_data=input_data)
verify_model(Variance3().float().eval(), input_data=input_data)
verify_model(Variance4().float().eval(), input_data=input_data)
verify_model(Variance5().float().eval(), input_data=input_data)
verify_model(Variance6().float().eval(), input_data=input_data)
verify_model(Variance7().float().eval(), input_data=input_data)
verify_model(Variance8().float().eval(), input_data=input_data)
verify_model(Variance9().float().eval(), input_data=input_data)


def test_forward_rsub():
Expand Down
5 changes: 4 additions & 1 deletion tests/python/relay/test_op_grad_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=Fa


def test_reduction_grad():
for op in (relay.sum, relay.variance, relay.mean):
def _unbiased_variance(x, axis=None, keepdims=False, exclude=False):
return relay.variance(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)

for op in (relay.sum, relay.variance, _unbiased_variance, relay.mean):
verify_reduction_grad(op, (4, 2))
verify_reduction_grad(op, (4, 2), axis=-1, keepdims=True)
verify_reduction_grad(op, (4, 2, 1), axis=(1, 2), exclude=True)
Expand Down
12 changes: 12 additions & 0 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,26 @@ def _np_log_sum_exp(x, axis, keepdims=False):
if not keepdims:
x = np.squeeze(x, axis=axis)
return x

def _unbiased_relay_wrapper(f):
def _unbiased_func(x, axis=None, keepdims=False, exclude=False):
return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)
return _unbiased_func

def _unbiased_np_wrapper(f):
def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)
return _unbiased_func

d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")
for func in [[relay.sum, np.sum],
[relay.max, np.max],
[relay.min, np.min],
[relay.mean, np.mean],
[relay.variance, np.var],
[_unbiased_relay_wrapper(relay.variance), _unbiased_np_wrapper(np.var)],
[relay.std, np.std],
[_unbiased_relay_wrapper(relay.std), _unbiased_np_wrapper(np.std)],
[relay.prod, np.prod],
[relay.all, np.all],
[relay.any, np.any],
Expand Down

0 comments on commit e32d759

Please sign in to comment.