
[FFI] part4: npx.embedding, npx.topk, npx.layer_norm, npx.leaky_relu (#20105)

* ffi: npx.embedding, npx.topk, npx.layer_norm, npx.leaky_relu

* fix lint

* fix build

* fix website

* fix build

* fix leaky_relu

* add test cases

* update
barry-jin authored Apr 20, 2021
1 parent 7dba11a commit 3e4b121
Showing 12 changed files with 1,027 additions and 9 deletions.
3 changes: 2 additions & 1 deletion python/mxnet/base.py
@@ -798,7 +798,8 @@ def write_all_str(module_file, module_all_list):
'_npx_masked_log_softmax', '_npx_activation',
'_npx_batch_norm', '_npx_fully_connected', '_npx_pick',
'_npx_convolution', '_npx_deconvolution', '_npx_pooling',
'_npx_dropout', '_npx_one_hot', '_npx_rnn', '_npx_batch_dot',
'_npx_dropout', '_npx_one_hot', '_npx_rnn', '_npx_embedding',
'_npx_topk', '_npx_layer_norm', '_npx_leaky_relu', '_npx_batch_dot',
'_npx_broadcast_like', '_npx_arange_like'}

_NP_INTERNAL_OP_PREFIX = '_npi_'
288 changes: 283 additions & 5 deletions python/mxnet/ndarray/numpy_extension/_op.py
@@ -26,8 +26,9 @@

__all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution',
'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn',
'batch_dot', 'broadcast_like', 'arange_like']
'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn', 'embedding',
'topk', 'layer_norm', 'leaky_relu', 'batch_dot', 'broadcast_like',
'arange_like']


# pylint: disable=too-many-arguments
@@ -333,9 +334,12 @@ def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9
out : NDArray or list of NDArrays
The output of this function.
"""
return _api_internal.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum,
fix_gamma, use_global_stats, output_mean_var, axis,
cudnn_off, min_calib_range, max_calib_range)
out = _api_internal.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum,
fix_gamma, use_global_stats, output_mean_var, axis,
cudnn_off, min_calib_range, max_calib_range)
if isinstance(out, NDArrayBase):
return out
return list(out)
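
Each new wrapper in this commit normalizes the FFI return value the same way: the call can hand back either a single NDArray or a tuple of NDArrays (for example when ``output_mean_var`` is True), so the wrapper returns it directly in the single-output case and converts to a list otherwise. A minimal restatement of the idiom as a hypothetical helper (the commit inlines this check in each wrapper rather than defining one):

def _normalize_ffi_output(out):
    # Hypothetical helper: NDArrayBase means a single output;
    # anything else is a tuple of outputs, returned as a list.
    if isinstance(out, NDArrayBase):
        return out
    return list(out)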


# pylint: disable=too-many-arguments, unused-argument
@@ -1036,6 +1040,280 @@ def rnn(data=None, parameters=None, state=None, state_cell=None, sequence_length
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan)


# pylint: disable=too-many-arguments, unused-argument
@set_module('mxnet.ndarray.numpy_extension')
def embedding(data, weight, input_dim=None, output_dim=None, dtype="float32", sparse_grad=False,
**kwargs):
r"""Maps integer indices to vector representations (embeddings).
This operator maps words to real-valued vectors in a high-dimensional space,
called word embeddings. These embeddings can capture semantic and syntactic properties of the words.
For example, it has been noted that in the learned embedding spaces, similar words tend
to be close to each other and dissimilar words far apart.
For an input array of shape (d1, ..., dK),
the shape of an output array is (d1, ..., dK, output_dim).
All the input values should be integers in the range [0, input_dim).
If the input_dim is ip0 and output_dim is op0, then the shape of the embedding weight matrix must be
(ip0, op0).
When "sparse_grad" is False, if any index mentioned is too large, it is replaced by the index that
addresses the last vector in an embedding matrix.
When "sparse_grad" is True, an error will be raised if invalid indices are found.
The storage type of weight can be either row_sparse or default.
.. Note::
If "sparse_grad" is set to True, the storage type of gradient w.r.t weights will be
"row_sparse". Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
and Adam. Note that by default lazy updates is turned on, which may perform differently
from standard updates. For more details, please check the Optimization API at:
https://mxnet.incubator.apache.org/api/python/optimization/optimization.html
Parameters
----------
data : NDArray
The input array to the embedding operator.
weight : NDArray
The embedding weight matrix.
input_dim : long, required
Vocabulary size of the input indices.
output_dim : long, required
Dimension of the embedding vectors.
dtype : {'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},
optional, default='float32'
Data type of weight.
sparse_grad : boolean, optional, default=0
Compute row sparse gradient in the backward calculation.
If set to True, the grad's storage type is row_sparse.
Returns
-------
out : NDArray or list of NDArrays
The output of this function.
Example
-------
>>> input_dim = 4
>>> output_dim = 5
Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3)
>>> y = np.arange(input_dim * output_dim).reshape(input_dim, output_dim)
>>> y
array([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.]])
Input array x represents n-grams (2-gram). So, x = [(w1,w3), (w0,w2)]
>>> x = np.array([[1., 3.], [0., 2.]])
>>> x
array([[1., 3.],
[0., 2.]])
Map input x to its vector representations in y.
>>> npx.embedding(x, y, input_dim, output_dim)
array([[[ 5., 6., 7., 8., 9.],
[15., 16., 17., 18., 19.]],
[[ 0., 1., 2., 3., 4.],
[10., 11., 12., 13., 14.]]])
"""
assert input_dim > 1, "Vocabulary size of the input indices should be greater than 1."
assert output_dim > 1, "Dimension of the embedding vectors should be greater than 1."
return _api_internal.embedding(data, weight, input_dim, output_dim, dtype, sparse_grad)
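
For intuition, the forward lookup is equivalent to NumPy fancy indexing, as in the sketch below (reference semantics only, assuming in-range integer indices; the out-of-range clipping and ``sparse_grad`` behaviors described above are not reproduced):

import numpy as np

def embedding_ref(data, weight):
    # Each integer index selects one row of `weight`, so an input of
    # shape (d1, ..., dK) yields an output of shape (d1, ..., dK, output_dim).
    return weight[data.astype(np.int64)]

weight = np.arange(20.0).reshape(4, 5)   # input_dim=4, output_dim=5
x = np.array([[1, 3], [0, 2]])
print(embedding_ref(x, weight).shape)    # (2, 2, 5)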


# pylint: disable=too-many-arguments
@set_module('mxnet.ndarray.numpy_extension')
def topk(data, axis=-1, k=1, ret_typ="indices", is_ascend=False, dtype="float32"):
r"""Returns the indices of the top *k* elements in an input array along the given
axis (by default).
If ret_typ is set to 'value', the values of the top *k* elements are returned (instead of the indices).
If ret_typ is set to 'both', both values and indices are returned.
The returned elements will be sorted.
Parameters
----------
data : NDArray
The input array
axis : int or None, optional, default='-1'
Axis along which to choose the top k indices.
If not given, the flattened array is used. Default is -1.
k : int, optional, default='1'
Number of top elements to select; should always be smaller than or equal to
the number of elements in the given axis. A global sort is performed if k is set to a value < 1.
ret_typ : {'both', 'indices', 'mask', 'value'}, optional, default='indices'
The return type.
"value" means to return the top k values,
"indices" means to return the indices of the top k values,
"mask" means to return a mask array containing 0 and 1. 1 means the top k values.
"both" means to return a list of both values and indices of top k elements.
is_ascend : boolean, optional, default=0
Whether to choose the k smallest (True) or the k largest (False) elements.
The k largest elements are chosen if set to False.
dtype : {'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'},
optional, default='float32'
DType of the output indices when ret_typ is "indices" or "both".
An error will be raised if the selected data type cannot precisely represent the indices.
Returns
-------
out : NDArray or list of NDArrays
The output of this function.
Example
-------
>>> x = np.array([[0.3, 0.2, 0.4], [0.1, 0.3, 0.2]])
returns an index of the largest element on last axis
>>> npx.topk(x)
array([[2.],
[1.]])
returns the value of top-2 largest elements on last axis
>>> npx.topk(x, ret_typ='value', k=2)
array([[0.4, 0.3],
[0.3, 0.2]])
returns the value of top-2 smallest elements on last axis
>>> npx.topk(x, ret_typ='value', k=2, is_ascend=1)
array([[0.2, 0.3],
[0.1, 0.2]])
returns the value of top-2 largest elements on axis 0
>>> npx.topk(x, axis=0, ret_typ='value', k=2)
array([[0.3, 0.3, 0.4],
[0.1, 0.2, 0.2]])
flattens and then returns list of both values and indices
>>> npx.topk(x, ret_typ='both', k=2)
[array([[0.4, 0.3], [0.3, 0.2]]),
array([[2., 0.], [1., 2.]])]
"""
out = _api_internal.topk(data, axis, k, ret_typ, is_ascend, dtype)
if isinstance(out, NDArrayBase):
return out
return list(out)
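
As a plain-NumPy reference for the semantics above (a sketch only, assuming an integer ``axis`` and k >= 1; the real operator also supports 'mask' and a flattened axis=None):

import numpy as np

def topk_ref(data, axis=-1, k=1, ret_typ="indices", is_ascend=False):
    # Sort along `axis`, optionally descending, and keep the first k entries.
    order = np.argsort(data, axis=axis)
    if not is_ascend:
        order = np.flip(order, axis=axis)
    idx = np.take(order, np.arange(k), axis=axis)
    if ret_typ == "indices":
        return idx
    vals = np.take_along_axis(data, idx, axis=axis)
    return [vals, idx] if ret_typ == "both" else vals

x = np.array([[0.3, 0.2, 0.4], [0.1, 0.3, 0.2]])
print(topk_ref(x, ret_typ="value", k=2))   # [[0.4 0.3] [0.3 0.2]]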


# pylint: disable=too-many-arguments
@set_module('mxnet.ndarray.numpy_extension')
def layer_norm(data=None, gamma=None, beta=None, axis=None, eps=None, output_mean_var=None):
r"""Layer normalization.
Normalizes the channels of the input tensor by mean and variance, and applies a scale ``gamma`` as
well as offset ``beta``.
Assume the input has more than one dimension and we normalize along axis 1.
We first compute the mean and variance along this axis and then
compute the normalized output, which has the same shape as the input, as follows:
.. math::
out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta
Both ``gamma`` and ``beta`` are learnable parameters.
Unlike BatchNorm and InstanceNorm, the *mean* and *var* are computed along the channel dimension.
Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*. If ``output_mean_var`` is set to True, then it outputs both ``data_mean`` and
``data_std``. Note that no gradient will be passed through these two outputs.
The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel' (separately normalized groups). The default is -1, which sets the channel
axis to be the last item in the input shape.
Parameters
----------
data : NDArray
Input data to layer normalization
gamma : NDArray
gamma array
beta : NDArray
beta array
axis : int, optional, default='-1'
The axis to perform layer normalization.
Usually, this should be the axis of the channel dimension.
Negative values mean indexing from right to left.
eps : float, optional, default=9.99999975e-06
An `epsilon` parameter to prevent division by 0.
output_mean_var : boolean, optional, default=0
Output the mean and std calculated along the given axis.
Returns
-------
out : NDArray or list of NDArrays
The output of this function.
"""
out = _api_internal.layer_norm(data, gamma, beta, axis, eps, output_mean_var)
if isinstance(out, NDArrayBase):
return out
return list(out)
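
The normalization itself is small enough to restate in NumPy; a minimal sketch of the formula above, assuming ``axis=-1`` so the *(k,)*-shaped ``gamma`` and ``beta`` broadcast correctly:

import numpy as np

def layer_norm_ref(data, gamma, beta, eps=1e-5):
    # Normalize along the last axis, then apply the learned scale and shift.
    mean = data.mean(axis=-1, keepdims=True)
    var = data.var(axis=-1, keepdims=True)
    return (data - mean) / np.sqrt(var + eps) * gamma + beta

x = np.arange(6.0).reshape(2, 3)
print(layer_norm_ref(x, gamma=np.ones(3), beta=np.zeros(3)))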


# pylint: disable=too-many-arguments, unused-argument
@set_module('mxnet.ndarray.numpy_extension')
def leaky_relu(data=None, gamma=None, act_type="leaky", slope=0.25, lower_bound=0.125,
upper_bound=0.334, **kwargs):
r"""Applies Leaky rectified linear unit activation element-wise to the input.
Leaky ReLUs attempt to fix the "dying ReLU" problem by allowing a small `slope`
when the input is negative, while keeping a slope of one when the input is positive.
The following modified ReLU Activation functions are supported:
- *elu*: Exponential Linear Unit. `y = x > 0 ? x : slope * (exp(x)-1)`
- *gelu*: Gaussian Error Linear Unit. `y = 0.5 * x * (1 + erf(x / sqrt(2)))`
- *selu*: Scaled Exponential Linear Unit. `y = lambda * (x > 0 ? x : alpha * (exp(x) - 1))` where
*lambda = 1.0507009873554804934193349852946* and *alpha = 1.6732632423543772848170429916717*.
- *leaky*: Leaky ReLU. `y = x > 0 ? x : slope * x`
- *prelu*: Parametric ReLU. This is the same as *leaky* except that `slope` is learned during training.
- *rrelu*: Randomized ReLU. Same as *leaky*, but the `slope` is chosen uniformly at random from
*[lower_bound, upper_bound)* during training, and fixed at
*(lower_bound+upper_bound)/2* for inference.
Parameters
----------
data : NDArray
Input data to activation function.
gamma : NDArray
Learnable slope parameter, used only when ``act_type`` is 'prelu'.
act_type : {'elu', 'gelu', 'leaky', 'prelu', 'rrelu', 'selu'}, optional, default='leaky'
Activation function to be applied.
slope : float, optional, default=0.25
Init slope for the activation. (For leaky and elu only)
lower_bound : float, optional, default=0.125
Lower bound of random slope. (For rrelu only)
upper_bound : float, optional, default=0.333999991
Upper bound of random slope. (For rrelu only)
Returns
-------
out : NDArray or list of NDArrays
The output of this function.
"""
if act_type == "prelu":
assert gamma is not None, "If activation function is prelu, please provide input gamma"
out = _api_internal.leaky_relu(data, gamma, act_type, slope, lower_bound, upper_bound)
if isinstance(out, NDArrayBase):
return out
return list(out)
else:
return _api_internal.leaky_relu(data, act_type, slope, lower_bound, upper_bound)
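
For the deterministic variants listed above, the element-wise definitions can be sketched in NumPy as follows (reference semantics only; *prelu*, *rrelu*, and *gelu* are omitted since they involve a learned or random slope or the erf-based form):

import numpy as np

def leaky_relu_ref(x, act_type="leaky", slope=0.25):
    # Element-wise reference for the closed-form variants.
    if act_type == "leaky":
        return np.where(x > 0, x, slope * x)
    if act_type == "elu":
        return np.where(x > 0, x, slope * (np.exp(x) - 1))
    if act_type == "selu":
        lam, alpha = 1.0507009873554805, 1.6732632423543772
        return lam * np.where(x > 0, x, alpha * (np.exp(x) - 1))
    raise ValueError("variant not covered in this sketch")

print(leaky_relu_ref(np.array([-1.0, 0.5])))   # [-0.25  0.5 ]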


# pylint: disable=too-many-arguments
@set_module('mxnet.ndarray.numpy_extension')
def batch_dot(a, b, transpose_a=False, transpose_b=False, forward_stype="default"):