This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FFI] Part3: npx.pooling, npx.dropout, npx.one_hot, npx.rnn #20102

Merged
merged 5 commits into from
Mar 31, 2021
3 changes: 2 additions & 1 deletion python/mxnet/base.py
@@ -797,7 +797,8 @@ def write_all_str(module_file, module_all_list):
_NP_EXT_OP_IMPLEMENTED_SET = {'_npx_softmax', '_npx_log_softmax', '_npx_masked_softmax',
'_npx_masked_log_softmax', '_npx_activation',
'_npx_batch_norm', '_npx_fully_connected', '_npx_pick',
'_npx_convolution', '_npx_deconvolution'}
'_npx_convolution', '_npx_deconvolution', '_npx_pooling',
'_npx_dropout', '_npx_one_hot', '_npx_rnn'}

_NP_INTERNAL_OP_PREFIX = '_npi_'

351 changes: 350 additions & 1 deletion python/mxnet/ndarray/numpy_extension/_op.py

Large diffs are not rendered by default.

319 changes: 318 additions & 1 deletion python/mxnet/numpy_extension/_op.py
@@ -23,7 +23,7 @@

__all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution',
'deconvolution']
'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn']


# pylint: disable=too-many-arguments
@@ -656,3 +656,320 @@ def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, d
target_shape=target_shape, num_filter=num_filter,
num_group=num_group, workspace=workspace, no_bias=no_bias,
cudnn_tune=cudnn_tune, cudnn_off=cudnn_off, layout=layout)


# pylint: disable=too-many-arguments, unused-argument
@set_module('mxnet.numpy_extension')
def pooling(data=None, kernel=None, stride=None, pad=None, pool_type="max",
pooling_convention="valid", global_pool=False, cudnn_off=False,
p_value=None, count_include_pad=None, layout=None, **kwargs):
r"""Performs pooling on the input.

The shapes for 1-D pooling are

- **data** and **out**: *(batch_size, channel, width)* (NCW layout) or
*(batch_size, width, channel)* (NWC layout),

The shapes for 2-D pooling are

- **data** and **out**: *(batch_size, channel, height, width)* (NCHW layout) or
*(batch_size, height, width, channel)* (NHWC layout),

out_height = f(height, kernel[0], pad[0], stride[0])
out_width = f(width, kernel[1], pad[1], stride[1])

The definition of *f* depends on ``pooling_convention``, which has two options:

- **valid** (default)::

f(x, k, p, s) = floor((x+2*p-k)/s)+1

- **full**, which is compatible with Caffe::

f(x, k, p, s) = ceil((x+2*p-k)/s)+1
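
For example (illustrative): with x=6, k=3, p=1 and s=2, ``valid`` gives
floor((6+2*1-3)/2)+1 = 3, while ``full`` gives ceil((6+2*1-3)/2)+1 = 4.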

When ``global_pool`` is set to true, global pooling is performed. It resets
``kernel=(height, width)`` and sets the padding to 0.

Four pooling options are supported by ``pool_type``:

- **avg**: average pooling
- **max**: max pooling
- **sum**: sum pooling
- **lp**: Lp pooling

For 3-D pooling, an additional *depth* dimension is added before
*height*. Namely the input data and output will have shape *(batch_size, channel, depth,
height, width)* (NCDHW layout) or *(batch_size, depth, height, width, channel)* (NDHWC layout).

Notes on Lp pooling:

Lp pooling was first introduced in this paper: https://arxiv.org/pdf/1204.3968.pdf.
L-1 pooling is simply sum pooling, while L-inf pooling is simply max pooling.
Lp pooling stands between those two; in practice, the most common value for p is 2.

For each window ``X``, the mathematical expression for Lp pooling is:

:math:`f(X) = \sqrt[p]{\sum_{x}^{X} x^p}`
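
For instance (illustrative), with p = 2 and window values {1, 2, 2}: :math:`f(X) = \sqrt{1 + 4 + 4} = 3`.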

Parameters
----------
data : NDArray
Input data to the pooling operator.
kernel : Shape(tuple), optional, default=[]
Pooling kernel size: (y, x) or (d, y, x)
pool_type : {'avg', 'lp', 'max', 'sum'}, optional, default='max'
Pooling type to be applied.
global_pool : boolean, optional, default=0
Ignore kernel size, do global pooling based on current input feature map.
cudnn_off : boolean, optional, default=0
Turn off cudnn pooling and use MXNet pooling operator.
pooling_convention : {'full', 'same', 'valid'}, optional, default='valid'
Pooling convention to be applied.
stride : Shape(tuple), optional, default=[]
Stride: for pooling (y, x) or (d, y, x). Defaults to 1 for each dimension.
pad : Shape(tuple), optional, default=[]
Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.
p_value : int or None, optional, default='None'
Value of p for Lp pooling; can be 1 or 2. Required for Lp pooling.
count_include_pad : boolean or None, optional, default=None
Only used for AvgPool; specifies whether to count padding elements in the average calculation.
For example, with a 5*5 kernel on a 3*3 corner of an image, the sum of the 9 valid elements will
be divided by 25 if this is set to true, or by 9 if this is set to false.
Defaults to true.
layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC', 'NWC'}, optional, default='None'
Set layout for input and output. Empty for
default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.

Returns
-------
out : NDArray or list of NDArrays
The output of this function.
"""
return _mx_nd_npx.pooling(data=data, kernel=kernel, stride=stride, pad=pad,
pool_type=pool_type, pooling_convention=pooling_convention,
global_pool=global_pool, cudnn_off=cudnn_off, p_value=p_value,
count_include_pad=count_include_pad, layout=layout)
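
A minimal usage sketch (illustrative, not part of this diff), assuming MXNet's NumPy semantics
are enabled via ``npx.set_np()``:

from mxnet import np, npx
npx.set_np()

x = np.arange(16, dtype="float32").reshape(1, 1, 4, 4)   # NCHW input
y = npx.pooling(data=x, kernel=(2, 2), stride=(2, 2), pool_type="max")
# y.shape == (1, 1, 2, 2); with the "valid" convention each 2x2 window
# contributes its maximum, i.e. [[5., 7.], [13., 15.]]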


# pylint: disable=too-many-arguments, unused-argument
@set_module('mxnet.numpy_extension')
def dropout(data, p=0.5, mode="training", axes=None, cudnn_off=True, **kwargs):
r"""Applies dropout operation to input array.

- During training, each element of the input is set to zero with probability p.
The whole array is rescaled by :math:`1/(1-p)` to keep the expected
sum of the input unchanged.

- During testing, this operator does not change the input if mode is 'training'.
If mode is 'always', the same computation as during training will be applied.

Parameters
----------
data : NDArray
Input array to which dropout will be applied.
p : float, optional, default=0.5
Fraction of the input that gets dropped out during training time.
mode : {'always', 'training'}, optional, default='training'
Whether to only turn on dropout during training or to also turn on for inference.
axes : Shape(tuple), optional, default=[]
Axes for variational dropout kernel.
cudnn_off : boolean or None, optional, default=True
Whether to turn off cudnn in the dropout operator. This option is ignored if axes is specified.

Returns
-------
out : NDArray or list of NDArrays
The output of this function.
"""
return _mx_nd_npx.dropout(data=data, p=p, mode=mode, axes=axes, cudnn_off=cudnn_off)
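
A minimal usage sketch (illustrative, not part of this diff), assuming NumPy semantics are enabled:

from mxnet import np, npx
npx.set_np()

x = np.ones((2, 4))
y = npx.dropout(x, p=0.5, mode="always")   # "always" applies dropout even outside of training
# Each element of y is either 0.0 (dropped) or 2.0 (kept and rescaled by 1/(1-p)).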


# pylint: disable=too-many-arguments
@set_module('mxnet.numpy_extension')
def one_hot(data, depth=None, on_value=1.0, off_value=0.0, dtype="float32"):
r"""Returns a one-hot array.

The locations represented by `indices` take value `on_value`, while all
other locations take value `off_value`.

`one_hot` operation with `indices` of shape ``(i0, i1)`` and `depth` of ``d`` would result
in an output array of shape ``(i0, i1, d)`` with::

output[i,j,:] = off_value
output[i,j,indices[i,j]] = on_value

Parameters
----------
data : NDArray
array of indices specifying the locations to set to on_value
depth : long, required
Depth of the one hot dimension.
on_value : double, optional, default=1
The value assigned to the locations represented by indices.
off_value : double, optional, default=0
The value assigned to the locations not represented by indices.
dtype : {'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'},
optional, default='float32'
DType of the output

Returns
-------
out : NDArray or list of NDArrays
The output of this function.

Example
-------
>>> data = np.array([1,0,2,0])
>>> npx.one_hot(data, 3)
array([[0., 1., 0.],
[1., 0., 0.],
[0., 0., 1.],
[1., 0., 0.]], dtype=float64)
>>> npx.one_hot(data, 3, on_value=8, off_value=1, dtype='int32')
array([[1, 8, 1],
[8, 1, 1],
[1, 1, 8],
[8, 1, 1]], dtype=int32)
>>> data = np.array([[1,0],[1,0],[2,0]])
>>> npx.one_hot(data, 3)
array([[[0., 1., 0.],
[1., 0., 0.]],

[[0., 1., 0.],
[1., 0., 0.]],

[[0., 0., 1.],
[1., 0., 0.]]], dtype=float64)
"""
return _mx_nd_npx.one_hot(data=data, depth=depth, on_value=on_value, off_value=off_value,
dtype=dtype)


# pylint: disable=too-many-arguments, unused-argument
@set_module('mxnet.numpy_extension')
def rnn(data=None, parameters=None, state=None, state_cell=None, sequence_length=None,
mode=None, state_size=None, num_layers=None, bidirectional=False,
state_outputs=False, p=0.0, use_sequence_length=False, projection_size=None,
lstm_state_clip_min=None, lstm_state_clip_max=None, lstm_state_clip_nan=None):
r"""Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
pseudo-float16 precision (float32 math with float16 I/O) in order to use
Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.

**Vanilla RNN**

Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
ReLU and Tanh.

With ReLU activation function:

.. math::
h_t = relu(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})

With Tanh activation function:

.. math::
h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})

Reference paper: Finding structure in time - Elman, 1988.
https://crl.ucsd.edu/~elman/Papers/fsit.pdf
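
A minimal sketch (illustrative, not part of this diff) of one step of the tanh recurrence above,
written with plain arrays rather than the fused ``npx.rnn`` operator:

from mxnet import np, npx
npx.set_np()

input_size, state_size = 4, 3
x_t = np.random.uniform(size=(input_size,))
h_prev = np.zeros((state_size,))
W_ih = np.random.uniform(size=(state_size, input_size))
W_hh = np.random.uniform(size=(state_size, state_size))
b_ih = np.zeros((state_size,))
b_hh = np.zeros((state_size,))

# h_t = tanh(W_ih * x_t + b_ih + W_hh * h_{t-1} + b_hh)
h_t = np.tanh(np.dot(W_ih, x_t) + b_ih + np.dot(W_hh, h_prev) + b_hh)   # shape (state_size,)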

**LSTM**

Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf

.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}

When the projection size is set, LSTM uses a projection layer to reduce the parameter count,
which can give some speedup without significantly hurting accuracy.

Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech
Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128

.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t) \\
r_t = W_{hr} h_t
\end{array}
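
As an illustration: with state_size = 1024 and projection_size = 256, each of the four recurrent
gate matrices W_{r*} shrinks from 1024 x 1024 to 1024 x 256, at the cost of one extra 256 x 1024
projection matrix W_{hr}.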

**GRU**

Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078

The definition of the GRU here is slightly different from the paper but compatible with cuDNN.

.. math::
\begin{array}{ll}
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
\end{array}

Parameters
----------
data : NDArray
Input data to RNN
parameters : NDArray
Vector of all RNN trainable parameters concatenated
state : NDArray
initial hidden state of the RNN
state_cell : NDArray
initial cell state for LSTM networks (only for LSTM)
sequence_length : NDArray
Vector of valid sequence lengths for each element in batch.
(Only used if use_sequence_length kwarg is True)
state_size : int (non-negative), required
size of the state for each layer
num_layers : int (non-negative), required
number of stacked layers
bidirectional : boolean, optional, default=0
whether to use bidirectional recurrent layers
mode : {'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required
the type of RNN to compute
p : float, optional, default=0
drop rate of the dropout on the outputs of each RNN layer, except the last layer.
state_outputs : boolean, optional, default=0
Whether to have the states as symbol outputs.
projection_size : int or None, optional, default='None'
size of the LSTM projection
lstm_state_clip_min : double or None, optional, default=None
Minimum clip value of LSTM states. This option must be used together with lstm_state_clip_max.
lstm_state_clip_max : double or None, optional, default=None
Maximum clip value of LSTM states. This option must be used together with lstm_state_clip_min.
lstm_state_clip_nan : boolean, optional, default=0
Whether to stop NaN from propagating in state by clipping it to min/max.
If clipping range is not specified, this option is ignored.
use_sequence_length : boolean, optional, default=0
If set to true, this layer takes in an extra input parameter `sequence_length`
to specify variable-length sequences

Returns
-------
out : NDArray or list of NDArrays
The output of this function.
"""
return _mx_nd_npx.rnn(data=data, parameters=parameters, state=state, state_cell=state_cell,
sequence_length=sequence_length, mode=mode, state_size=state_size,
num_layers=num_layers, bidirectional=bidirectional,
state_outputs=state_outputs, p=p, use_sequence_length=use_sequence_length,
projection_size=projection_size, lstm_state_clip_min=lstm_state_clip_min,
lstm_state_clip_max=lstm_state_clip_max,
lstm_state_clip_nan=lstm_state_clip_nan)
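
A minimal usage sketch (illustrative, not part of this diff). The ``parameters`` vector must match
the fused parameter packing; the count below assumes a cuDNN-style layout of
4 * state_size * (input_size + state_size + 2) for a single unidirectional LSTM layer, which is an
assumption to verify against your build:

from mxnet import np, npx
npx.set_np()

seq_len, batch, input_size, state_size = 5, 3, 10, 20
data = np.random.uniform(size=(seq_len, batch, input_size))    # (T, N, C) layout
state = np.zeros((1, batch, state_size))                       # (layers * dirs, N, state_size)
cell = np.zeros((1, batch, state_size))                        # LSTM cell state

# Assumed fused parameter count: 4 gates * state_size * (input_size + state_size + 2 bias vectors)
param_count = 4 * state_size * (input_size + state_size + 2)
params = np.random.uniform(size=(param_count,))

out = npx.rnn(data=data, parameters=params, state=state, state_cell=cell,
              mode="lstm", state_size=state_size, num_layers=1)
# out.shape == (seq_len, batch, state_size)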