[Relay][Op] Dropout and batch_norm (#1870)
slyubomirsky authored and tqchen committed Oct 16, 2018
1 parent a5be8fd commit b4946e7
Showing 6 changed files with 428 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -38,6 +38,8 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.tanh
tvm.relay.sigmoid
tvm.relay.nn.relu
tvm.relay.nn.dropout
tvm.relay.nn.batch_norm


**Level 2: Convolutions**
35 changes: 35 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -237,6 +237,41 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
}
};

/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
double rate;
TVM_DECLARE_ATTRS(DropoutAttrs, "relay.attrs.DropoutAttrs") {
TVM_ATTR_FIELD(rate)
.describe("Fraction of the input that gets dropped out during training time")
.set_default(0.5);
}
}; // struct DropoutAttrs

/*! \brief Attributes used in batch_norm operator */
struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
int axis;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(BatchNormAttrs, "relay.attrs.BatchNormAttrs") {
TVM_ATTR_FIELD(axis)
.describe("Specify which shape axis denotes the channel.")
.set_default(1);
TVM_ATTR_FIELD(epsilon)
.describe("Small float added to variance to avoid dividing by zero")
.set_default(1e-5);
TVM_ATTR_FIELD(center)
.describe("If True, add offset of beta to normalized tensor. If False, beta is ignored")
.set_default(true);
TVM_ATTR_FIELD(scale)
.describe("If True, multiply by gamma. If False, gamma is not used. "
"When the next layer is piecewise linear (also, e.g., nn.relu), "
"this can be disabled since the scaling will be done by the next layer.")
.set_default(true);
}
}; // struct BatchNormAttrs
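
For orientation, the sketch below shows how these attribute fields surface on the Python side once the wrappers added later in this commit are in place. It is a hypothetical usage sketch, not part of the commit: the input expressions ``x``, ``gamma``, ``beta``, ``mean`` and ``var`` are assumed to have been constructed elsewhere as ``relay.Expr`` values (e.g. through ``tvm.relay.ir_builder``); only the ``nn.batch_norm`` call, the ``TupleWrapper.asnode`` accessor, and the attrs fields come from this change::

    # Hypothetical sketch -- x, gamma, beta, mean, var are assumed pre-built relay.Expr values.
    from tvm import relay

    result = relay.nn.batch_norm(x, gamma, beta, mean, var)  # defaults mirror BatchNormAttrs
    call = result.asnode()              # TupleWrapper -> the underlying nn.batch_norm call
    print(call.attrs.axis)              # 1      (BatchNormAttrs::axis default)
    print(call.attrs.epsilon)           # 1e-05  (BatchNormAttrs::epsilon default)
    print(call.attrs.center, call.attrs.scale)   # True True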

/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
IndexExpr size;
28 changes: 28 additions & 0 deletions python/tvm/relay/ir_builder.py
@@ -11,6 +11,32 @@
from .env import Environment


class TupleWrapper(tvm._ffi.node.NodeGeneric):
    """TupleWrapper.

    This class is a Python wrapper for a Relay tuple of known size.
    It allows for accessing the fields of the Relay tuple as though
    it were a Python tuple.
    """

    def __init__(self, tuple_value, size):
        self.tuple_value = tuple_value
        self.size = size

    def asnode(self):
        """Returns the underlying Relay tuple if this wrapper is passed
        as an argument to an FFI function."""
        return self.tuple_value

    def __getitem__(self, key):
        return self.tuple_value.fields[key]

    def __len__(self):
        return len(self.tuple_value.fields)
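
Because the wrapper only relies on the wrapped node exposing a ``fields`` sequence, its indexing protocol can be illustrated with a plain-Python stand-in (the ``FakeTuple`` class below is hypothetical and exists only for illustration; in practice the wrapped value is a Relay tuple-typed expression)::

    # Illustration only -- FakeTuple is a hypothetical stand-in that mimics the
    # `fields` attribute of a Relay tuple node.
    class FakeTuple:
        def __init__(self, fields):
            self.fields = fields

    wrapped = TupleWrapper(FakeTuple(["out", "mask"]), size=2)
    print(len(wrapped))      # 2
    print(wrapped[1])        # "mask" -- indexing forwards to the node's fields
    print(wrapped.asnode())  # the underlying (stand-in) node, as an FFI call would receive it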


def _convert_to_value(arg, ctxt=tvm.cpu(0)):
# type: (Any, tvm.Context) -> tvm.nd.NDArray
"""Convert Python values into the appropriate types
@@ -61,6 +87,8 @@ def convert(arg):
return relay.Tuple([convert(el) for el in arg])
elif isinstance(arg, PartialFunc):
return arg.to_func()
elif isinstance(arg, tvm._ffi.node.NodeGeneric):
return arg.asnode()
else:
value = _convert_to_value(arg)
return Constant(value)
102 changes: 102 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1,5 +1,6 @@
"""Neural network operations."""
from __future__ import absolute_import as _abs
from tvm.relay.ir_builder import TupleWrapper
from . import _make


@@ -484,6 +485,7 @@ def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
.. math::
(data / (bias + (alpha * sum_data^2 / size))^beta)
Parameters
----------
data : relay.Expr
@@ -535,3 +537,103 @@ def l2_normalize(data, eps, axis=None):
The computed result.
"""
return _make.l2_normalize(data, eps, axis)

def dropout(data, rate=0.5):
    """Applies the dropout operation to the input array.

    During training, each element of the input is set to zero with
    probability ``rate``. The whole array is rescaled by ``1/(1-rate)``
    to keep the expected sum of the input unchanged.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.
    rate : float, optional (default=0.5)
        The probability for an element to be reset to 0.

    Returns
    -------
    result : relay.Tuple([relay.Expr, relay.Expr])
        The first member of the tuple is the result of dropping elements from
        ``data`` and rescaling. The second member is a "mask" tensor, which has
        the same shape and data type as ``data`` and, for each element of
        ``data``, is 1.0 if the element was not dropped and 0.0 if it was.
    """
    result = _make.dropout(data, rate)
    return TupleWrapper(result, 2)
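
A hypothetical usage sketch of the wrapper defined above (exposed as ``tvm.relay.nn.dropout``); the input expression ``x`` is assumed to have been built elsewhere, e.g. with the ir_builder::

    # Hypothetical sketch -- x is assumed to be an existing relay.Expr.
    out_and_mask = dropout(x, rate=0.3)
    out, mask = out_and_mask[0], out_and_mask[1]   # TupleWrapper indexing into the result tuple
    # `out` is the rescaled data; `mask` is 1.0 where an element was kept, 0.0 where it was dropped.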

def batch_norm(data, gamma, beta, moving_mean, moving_var,
               axis=1, epsilon=1e-5, center=True, scale=True):
    r"""
    Batch normalization layer (Ioffe and Szegedy, 2015).
    Normalizes the input at each batch, i.e. applies a transformation
    that maintains the mean activation close to 0 and the activation
    standard deviation close to 1.

    .. math::

        data\_mean[i] = mean(data[:,i,:,...]) \\
        data\_var[i] = var(data[:,i,:,...])

    Then compute the normalized output, which has the same shape as input, as following:

    .. math::

        out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}}
            * gamma[i] + beta[i]

    Both *mean* and *var* return a scalar by treating the input as a vector.

    Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
    have shape *(k,)*.

    Besides the inputs and the outputs, this operator accepts two auxiliary
    states, ``moving_mean`` and ``moving_var``, which are *k*-length
    vectors. They are global statistics for the whole dataset, which are updated by::

        moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
        moving_var = moving_var * momentum + data_var * (1 - momentum)

    The parameter ``axis`` specifies which axis of the input shape denotes
    the 'channel' (separately normalized groups). The default is 1.
    Specifying -1 sets the channel axis to be the last item in the input shape.

    .. note::

        This operator can be optimized away for inference.

    Parameters
    ----------
    data : relay.Expr
        Input to which batch_norm will be applied.
    gamma : relay.Expr
        The gamma scale factor.
    beta : relay.Expr
        The beta offset factor.
    moving_mean : relay.Expr
        Running mean of input.
    moving_var : relay.Expr
        Running variance of input.
    axis : int, optional, default=1
        Specify along which shape axis the channel is specified.
    epsilon : double, optional, default=1e-5
        Small float added to variance to avoid dividing by zero.
    center : boolean, optional, default=True
        If True, add offset of beta to normalized tensor. If False,
        beta is ignored.
    scale : boolean, optional, default=True
        If True, multiply by gamma. If False, gamma is not used.
        When the next layer is piecewise linear (e.g., nn.relu),
        this can be disabled since the scaling will be done by the next layer.

    Returns
    -------
    result : relay.Tuple([relay.Expr, relay.Expr, relay.Expr])
        Tuple of the normalized data (same shape as input), the new running mean
        (*k*-length vector), and the new running variance (*k*-length vector).
    """
    result = _make.batch_norm(data, gamma, beta, moving_mean, moving_var,
                              axis, epsilon, center, scale)
    return TupleWrapper(result, 3)
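
As with dropout, the result is a fixed-size tuple; a hypothetical sketch of unpacking it (the five input expressions are assumed to be pre-built ``relay.Expr`` values of compatible shapes)::

    # Hypothetical sketch -- data, gamma, beta, moving_mean, moving_var are assumed
    # to be existing relay.Expr values.
    result = batch_norm(data, gamma, beta, moving_mean, moving_var, axis=1)
    normed, new_mean, new_var = result[0], result[1], result[2]
    # `normed` has the shape of `data`; `new_mean` and `new_var` are k-length vectors,
    # where k is the extent of `data` along `axis`.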
172 changes: 172 additions & 0 deletions src/relay/op/nn/nn.cc
@@ -217,5 +217,177 @@ Normalizes along dimension axis using an L2 norm
.set_support_level(2)
.add_type_rel("Identity", IdentityRel);

// Dropout
TVM_REGISTER_NODE_TYPE(DropoutAttrs);

bool DropoutRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

// dropout returns the original tensor with dropout applied
// and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
auto ret_type = TensorTypeNode::make(data->shape, data->dtype);
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>({ret_type, ret_type})));
return true;
}
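
The relation only needs the input's shape and dtype: it maps a tensor type to a pair of identical tensor types (value and mask). A small sketch of the resulting type, assuming the Relay type constructors are exposed as ``relay.TensorType`` and ``relay.TupleType`` on the Python side::

    # Sketch of the type DropoutRel assigns, assuming relay.TensorType/TupleType
    # are exposed from Python.
    from tvm import relay

    data_ty = relay.TensorType((2, 3, 4), "float32")
    out_ty = relay.TupleType([data_ty, data_ty])   # (dropped-and-rescaled data, mask)
    print(out_ty)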

Expr MakeDropout(Expr data, double rate) {
auto attrs = make_node<DropoutAttrs>();
attrs->rate = rate;
static const Op& op = Op::Get("nn.dropout");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.dropout")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeDropout, args, rv);
});

RELAY_REGISTER_OP("nn.dropout")
.describe(R"code(Applies the dropout operation to the input array.
During training, each element of the input is set to zero with probability ``rate``.
The whole array is rescaled by ``1/(1-rate)`` to keep the expected sum of the input unchanged.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1)
.add_type_rel("Dropout", DropoutRel);

// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);

bool CheckVectorLength(int64_t dim, const DataType& dtype, Type vector, const char* name) {
const auto* candidate = vector.as<TensorTypeNode>();
CHECK(candidate != nullptr)
<< name << " should be a vector but is not a tensor type,";
CHECK_EQ(dtype, candidate->dtype)
<< name << " should be of the same data type as the original but it is not.";
CHECK_EQ(candidate->shape.size(), 1)
<< name << " should be a vector but has a shape of "
<< candidate->shape.size() << " dimensions instead of 1.";

const int64_t* length = as_const_int(candidate->shape[0]);
if (length == nullptr) return false;
CHECK(*length == dim)
<< name << " should be as long as the channel but has length "
<< *length << " instead of " << dim << ".";
return true;
}

bool BatchNormRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 6);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
if (data->shape.size() == 0) return false;

const BatchNormAttrs* param = attrs.as<BatchNormAttrs>();

// axis of -1 means use the last dimension
CHECK(param->axis >= -1 && param->axis < static_cast<int>(data->shape.size()));
int axis = (param->axis != -1) ? param->axis : data->shape.size() - 1;

auto dim = as_const_int(data->shape[axis]);
if (dim == nullptr) return false;

// if we are using beta and gamma, they need to be of shape (dim,)
if (param->scale && !CheckVectorLength(*dim, data->dtype, types[1], "The gamma scale factor")) {
return false;
}

if (param->center && !CheckVectorLength(*dim, data->dtype, types[2], "The beta offset factor")) {
return false;
}

// the two running averages must also be vectors of length dim
if (!CheckVectorLength(*dim, data->dtype, types[3], "The moving mean")) {
return false;
}
if (!CheckVectorLength(*dim, data->dtype, types[4], "The moving variance")) {
return false;
}

// output is a tuple of the normed data (same shape as input), new running mean,
// and new running variance (the latter two are both vectors of length dim)
std::vector<Type> fields;
auto vec_ty = TensorTypeNode::make(Array<IndexExpr>({data->shape[axis]}),
data->dtype);
fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
fields.push_back(vec_ty);
fields.push_back(vec_ty);
reporter->Assign(types[5], TupleTypeNode::make(Array<Type>(fields)));
return true;
}
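
For NCHW input with the default ``axis=1``, the relation above resolves the channel extent *k* and assigns a three-field tuple type; a sketch of the result, under the same assumption that ``relay.TensorType`` and ``relay.TupleType`` are exposed from Python::

    # Sketch of the type BatchNormRel assigns for NCHW data with axis=1 (k = 16 here).
    from tvm import relay

    data_ty = relay.TensorType((8, 16, 32, 32), "float32")
    vec_ty = relay.TensorType((16,), "float32")
    out_ty = relay.TupleType([data_ty, vec_ty, vec_ty])   # (normed data, new mean, new variance)
    print(out_ty)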

Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,
int axis, double epsilon, bool center, bool scale) {
auto attrs = make_node<BatchNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.batch_norm");
return CallNode::make(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.batch_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 9>(MakeBatchNorm, args, rv);
});

RELAY_REGISTER_OP("nn.batch_norm")
.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
Normalizes the input at each batch, i.e. applies a transformation
that maintains the mean activation close to 0 and the activation
standard deviation close to 1.
.. math::
data\_mean[i] = mean(data[:,i,:,...]) \\
data\_var[i] = var(data[:,i,:,...])
Then compute the normalized output, which has the same shape as input, as following:
.. math::
out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} \
* gamma[i] + beta[i]
Both *mean* and *var* return a scalar by treating the input as a vector.
Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` have shape *(k,)*.
Besides the inputs and the outputs, this operator accepts two auxiliary
states, ``moving_mean`` and ``moving_var``, which are *k*-length
vectors. They are global statistics for the whole dataset, which are updated
by::
moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
moving_var = moving_var * momentum + data_var * (1 - momentum)
The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel
axis to be the last item in the input shape.
.. note::
This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
.set_num_inputs(5)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.add_argument("moving_mean", "Tensor", "Running mean of input.")
.add_argument("moving_var", "Tensor", "Running variance of input.")
.set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel);

} // namespace relay
} // namespace tvm