diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 5e3912bab261..c60d31cc9578 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -797,7 +797,8 @@ def write_all_str(module_file, module_all_list): _NP_EXT_OP_IMPLEMENTED_SET = {'_npx_softmax', '_npx_log_softmax', '_npx_masked_softmax', '_npx_masked_log_softmax', '_npx_activation', '_npx_batch_norm', '_npx_fully_connected', '_npx_pick', - '_npx_convolution', '_npx_deconvolution'} + '_npx_convolution', '_npx_deconvolution', '_npx_pooling', + '_npx_dropout', '_npx_one_hot', '_npx_rnn'} _NP_INTERNAL_OP_PREFIX = '_npi_' diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index 346b85d42fc8..718022dc5b8c 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -20,13 +20,14 @@ import numpy as _np from .. import numpy as np # pylint: disable=reimported +from .._internal import NDArrayBase from . import _api_internal from ...util import set_module __all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax', 'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution', - 'deconvolution'] + 'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn'] # pylint: disable=too-many-arguments @@ -708,3 +709,351 @@ def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, d return _api_internal.deconvolution(data, weight, bias, kernel, stride, dilate, pad, adj, target_shape, num_filter, num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout) + + +# pylint: disable=too-many-arguments, unused-argument +@set_module('mxnet.ndarray.numpy_extension') +def pooling(data=None, kernel=None, stride=None, pad=None, pool_type="max", + pooling_convention="valid", global_pool=False, cudnn_off=False, + p_value=None, count_include_pad=None, layout=None, **kwargs): + r"""Performs pooling on the input. + + The shapes for 1-D pooling are + + - **data** and **out**: *(batch_size, channel, width)* (NCW layout) or + *(batch_size, width, channel)* (NWC layout), + + The shapes for 2-D pooling are + + - **data** and **out**: *(batch_size, channel, height, width)* (NCHW layout) or + *(batch_size, height, width, channel)* (NHWC layout), + + out_height = f(height, kernel[0], pad[0], stride[0]) + out_width = f(width, kernel[1], pad[1], stride[1]) + + The definition of *f* depends on ``pooling_convention``, which has two options: + + - **valid** (default):: + + f(x, k, p, s) = floor((x+2*p-k)/s)+1 + + - **full**, which is compatible with Caffe:: + + f(x, k, p, s) = ceil((x+2*p-k)/s)+1 + + When ``global_pool`` is set to be true, then global pooling is performed. It will reset + ``kernel=(height, width)`` and set the appropiate padding to 0. + + Three pooling options are supported by ``pool_type``: + + - **avg**: average pooling + - **max**: max pooling + - **sum**: sum pooling + - **lp**: Lp pooling + + For 3-D pooling, an additional *depth* dimension is added before + *height*. Namely the input data and output will have shape *(batch_size, channel, depth, + height, width)* (NCDHW layout) or *(batch_size, depth, height, width, channel)* (NDHWC layout). + + Notes on Lp pooling: + + Lp pooling was first introduced by this paper: https://arxiv.org/pdf/1204.3968.pdf. + L-1 pooling is simply sum pooling, while L-inf pooling is simply max pooling. + We can see that Lp pooling stands between those two, in practice the most common value for p is 2. 
+
+    For each window ``X``, the mathematical expression for Lp pooling is:
+
+    :math:`f(X) = \sqrt[p]{\sum_{x}^{X} x^p}`
+
+    Parameters
+    ----------
+    data : NDArray
+        Input data to the pooling operator.
+    kernel : Shape(tuple), optional, default=[]
+        Pooling kernel size: (y, x) or (d, y, x)
+    pool_type : {'avg', 'lp', 'max', 'sum'}, optional, default='max'
+        Pooling type to be applied.
+    global_pool : boolean, optional, default=0
+        Ignore kernel size, do global pooling based on current input feature map.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn pooling and use MXNet pooling operator.
+    pooling_convention : {'full', 'same', 'valid'}, optional, default='valid'
+        Pooling convention to be applied.
+    stride : Shape(tuple), optional, default=[]
+        Stride: for pooling (y, x) or (d, y, x). Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.
+    p_value : int or None, optional, default='None'
+        Value of p for Lp pooling, can be 1 or 2, required for Lp pooling.
+    count_include_pad : boolean or None, optional, default=None
+        Only used for AvgPool, specifies whether to count padding elements for average calculation.
+        For example, with a 5*5 kernel on a 3*3 corner of an image, the sum of the 9 valid elements will
+        be divided by 25 if this is set to true, or it will be divided by 9 if this is set to false.
+        Defaults to true.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC', 'NWC'}, optional, default='None'
+        Set layout for input and output. Empty for
+        default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    assert data is not None and kernel is not None, "Missing input data or kernel"
+    out = _api_internal.pooling(data, kernel, stride, pad, pool_type, pooling_convention,
+                                global_pool, cudnn_off, p_value, count_include_pad, layout)
+    if isinstance(out, NDArrayBase):
+        return out
+    else:
+        return list(out)
+
+
+# pylint: disable=too-many-arguments, unused-argument
+@set_module('mxnet.ndarray.numpy_extension')
+def dropout(data, p=0.5, mode="training", axes=None, cudnn_off=True, **kwargs):
+    r"""Applies dropout operation to the input array.
+
+    - During training, each element of the input is set to zero with probability p.
+      The whole array is rescaled by :math:`1/(1-p)` to keep the expected
+      sum of the input unchanged.
+
+    - During testing, this operator does not change the input if mode is 'training'.
+      If mode is 'always', the same computation as during training will be applied.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input array to which dropout will be applied.
+    p : float, optional, default=0.5
+        Fraction of the input that gets dropped out during training time.
+    mode : {'always', 'training'}, optional, default='training'
+        Whether to only turn on dropout during training or to also turn on for inference.
+    axes : Shape(tuple), optional, default=[]
+        Axes for variational dropout kernel.
+    cudnn_off : boolean or None, optional, default=0
+        Whether to turn off cudnn in dropout operator. This option is ignored if axes is specified.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    return _api_internal.dropout(data, p, mode, axes, cudnn_off)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def one_hot(data, depth=None, on_value=1.0, off_value=0.0, dtype="float32"):
+    r"""Returns a one-hot array.
+ + The locations represented by `indices` take value `on_value`, while all + other locations take value `off_value`. + + `one_hot` operation with `indices` of shape ``(i0, i1)`` and `depth` of ``d`` would result + in an output array of shape ``(i0, i1, d)`` with:: + + output[i,j,:] = off_value + output[i,j,indices[i,j]] = on_value + + Parameters + ---------- + indices : NDArray + array of locations where to set on_value + depth : long, required + Depth of the one hot dimension. + on_value : double, optional, default=1 + The value assigned to the locations represented by indices. + off_value : double, optional, default=0 + The value assigned to the locations not represented by indices. + dtype : {'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, + optional, default='float32' + DType of the output + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Example + ------- + >>> data = np.array([1,0,2,0]) + >>> npx.one_hot(data, 3) + array([[0., 1., 0.], + [1., 0., 0.], + [0., 0., 1.], + [1., 0., 0.]], dtype=float64) + >>> npx.one_hot(data, 3, on_value=8, off_value=1, dtype='int32') + array([[1, 8, 1], + [8, 1, 1], + [1, 1, 8], + [8, 1, 1]], dtype=int32) + >>> data = np.array([[1,0],[1,0],[2,0]]) + >>> npx.one_hot(data, 3) + array([[[0., 1., 0.], + [1., 0., 0.]], + + [[0., 1., 0.], + [1., 0., 0.]], + + [[0., 0., 1.], + [1., 0., 0.]]], dtype=float64) + """ + assert depth is not None, "Please provide the depth of one hot dimension." + if not isinstance(dtype, str): + dtype = _np.dtype(dtype).name + return _api_internal.one_hot(data, depth, on_value, off_value, dtype) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def rnn(data=None, parameters=None, state=None, state_cell=None, sequence_length=None, + mode=None, state_size=None, num_layers=None, bidirectional=False, + state_outputs=False, p=0.0, use_sequence_length=False, projection_size=None, + lstm_state_clip_min=None, lstm_state_clip_max=None, lstm_state_clip_nan=None): + r"""Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are + implemented, with both multi-layer and bidirectional support. + + When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE + and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use + pseudo-float16 precision (float32 math with float16 I/O) precision in order to use + Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups. + + **Vanilla RNN** + + Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: + ReLU and Tanh. + + With ReLU activation function: + + .. math:: + h_t = relu(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) + + With Tanh activtion function: + + .. math:: + h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) + + Reference paper: Finding structure in time - Elman, 1988. + https://crl.ucsd.edu/~elman/Papers/fsit.pdf + + **LSTM** + + Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf + + .. 
math:: + \begin{array}{ll} + i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ + f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\ + o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ + c_t = f_t * c_{(t-1)} + i_t * g_t \\ + h_t = o_t * \tanh(c_t) + \end{array} + + With the projection size being set, LSTM could use the projection feature to reduce the parameters + size and give some speedups without significant damage to the accuracy. + + Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech + Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128 + + .. math:: + \begin{array}{ll} + i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\ + f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rc} r_{(t-1)} + b_{rg}) \\ + o_t = \mathrm{sigmoid}(W_{io} x_t + b_{o} + W_{ro} r_{(t-1)} + b_{ro}) \\ + c_t = f_t * c_{(t-1)} + i_t * g_t \\ + h_t = o_t * \tanh(c_t) + r_t = W_{hr} h_t + \end{array} + + **GRU** + + Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078 + + The definition of GRU here is slightly different from paper but compatible with CUDNN. + + .. math:: + \begin{array}{ll} + r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ + \end{array} + + Parameters + ---------- + data : NDArray + Input data to RNN + parameters : NDArray + Vector of all RNN trainable parameters concatenated + state : NDArray + initial hidden state of the RNN + state_cell : NDArray + initial cell state for LSTM networks (only for LSTM) + sequence_length : NDArray + Vector of valid sequence lengths for each element in batch. + (Only used if use_sequence_length kwarg is True) + state_size : int (non-negative), required + size of the state for each layer + num_layers : int (non-negative), required + number of stacked layers + bidirectional : boolean, optional, default=0 + whether to use bidirectional recurrent layers + mode : {'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required + the type of RNN to compute + p : float, optional, default=0 + drop rate of the dropout on the outputs of each RNN layer, except the last layer. + state_outputs : boolean, optional, default=0 + Whether to have the states as symbol outputs. + projection_size : int or None, optional, default='None' + size of project size + lstm_state_clip_min : double or None, optional, default=None + Minimum clip value of LSTM states. This option must be used together with lstm_state_clip_max. + lstm_state_clip_max : double or None, optional, default=None + Maximum clip value of LSTM states. This option must be used together with lstm_state_clip_min. + lstm_state_clip_nan : boolean, optional, default=0 + Whether to stop NaN from propagating in state by clipping it to min/max. + If clipping range is not specified, this option is ignored. + use_sequence_length : boolean, optional, default=0 + If set to true, this layer takes in an extra input parameter `sequence_length` + to specify variable length sequence + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + assert mode is not None, "Please provide rnn type to compute. e.g. 
rnn_relu, rnn_tanh, lstm, gru" + assert data is not None and parameters is not None and state is not None, \ + "Missing input data/parameters/state." + assert state_size is not None, "Please provide state_size" + assert num_layers is not None, "Please provide num_layers" + if use_sequence_length: + assert sequence_length is not None, \ + "use_sequence_length is set True, but no sequence_length provided." + if mode == "lstm": + assert state_cell is not None, \ + "RNN computing mode is lstm, but no state_cell is provided" + return _api_internal.rnn(data, parameters, state, state_cell, sequence_length, + state_size, num_layers, bidirectional, state_outputs, + mode, p, use_sequence_length, projection_size, + lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan) + else: + return _api_internal.rnn(data, parameters, state, sequence_length, + state_size, num_layers, bidirectional, state_outputs, + mode, p, use_sequence_length, projection_size, + lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan) + else: + if mode == "lstm": + assert state_cell is not None, \ + "RNN computing mode is lstm, but no state_cell is provided" + return _api_internal.rnn(data, parameters, state, state_cell, + state_size, num_layers, bidirectional, state_outputs, + mode, p, use_sequence_length, projection_size, + lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan) + else: + return _api_internal.rnn(data, parameters, state, + state_size, num_layers, bidirectional, state_outputs, + mode, p, use_sequence_length, projection_size, + lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan) diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index 6ca2248348d7..b7d75ffdc6d0 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -23,7 +23,7 @@ __all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax', 'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution', - 'deconvolution'] + 'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn'] # pylint: disable=too-many-arguments @@ -656,3 +656,320 @@ def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, d target_shape=target_shape, num_filter=num_filter, num_group=num_group, workspace=workspace, no_bias=no_bias, cudnn_tune=cudnn_tune, cudnn_off=cudnn_off, layout=layout) + + +# pylint: disable=too-many-arguments, unused-argument +@set_module('mxnet.numpy_extension') +def pooling(data=None, kernel=None, stride=None, pad=None, pool_type="max", + pooling_convention="valid", global_pool=False, cudnn_off=False, + p_value=None, count_include_pad=None, layout=None, **kwargs): + r"""Performs pooling on the input. + + The shapes for 1-D pooling are + + - **data** and **out**: *(batch_size, channel, width)* (NCW layout) or + *(batch_size, width, channel)* (NWC layout), + + The shapes for 2-D pooling are + + - **data** and **out**: *(batch_size, channel, height, width)* (NCHW layout) or + *(batch_size, height, width, channel)* (NHWC layout), + + out_height = f(height, kernel[0], pad[0], stride[0]) + out_width = f(width, kernel[1], pad[1], stride[1]) + + The definition of *f* depends on ``pooling_convention``, which has two options: + + - **valid** (default):: + + f(x, k, p, s) = floor((x+2*p-k)/s)+1 + + - **full**, which is compatible with Caffe:: + + f(x, k, p, s) = ceil((x+2*p-k)/s)+1 + + When ``global_pool`` is set to be true, then global pooling is performed. 
It will reset
+    ``kernel=(height, width)`` and set the appropriate padding to 0.
+
+    Four pooling options are supported by ``pool_type``:
+
+    - **avg**: average pooling
+    - **max**: max pooling
+    - **sum**: sum pooling
+    - **lp**: Lp pooling
+
+    For 3-D pooling, an additional *depth* dimension is added before
+    *height*. Namely, the input data and output will have shape *(batch_size, channel, depth,
+    height, width)* (NCDHW layout) or *(batch_size, depth, height, width, channel)* (NDHWC layout).
+
+    Notes on Lp pooling:
+
+    Lp pooling was first introduced by this paper: https://arxiv.org/pdf/1204.3968.pdf.
+    L-1 pooling is simply sum pooling, while L-inf pooling is simply max pooling.
+    Lp pooling stands between those two; in practice the most common value for p is 2.
+
+    For each window ``X``, the mathematical expression for Lp pooling is:
+
+    :math:`f(X) = \sqrt[p]{\sum_{x}^{X} x^p}`
+
+    Parameters
+    ----------
+    data : NDArray
+        Input data to the pooling operator.
+    kernel : Shape(tuple), optional, default=[]
+        Pooling kernel size: (y, x) or (d, y, x)
+    pool_type : {'avg', 'lp', 'max', 'sum'}, optional, default='max'
+        Pooling type to be applied.
+    global_pool : boolean, optional, default=0
+        Ignore kernel size, do global pooling based on current input feature map.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn pooling and use MXNet pooling operator.
+    pooling_convention : {'full', 'same', 'valid'}, optional, default='valid'
+        Pooling convention to be applied.
+    stride : Shape(tuple), optional, default=[]
+        Stride: for pooling (y, x) or (d, y, x). Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        Pad for pooling: (y, x) or (d, y, x). Defaults to no padding.
+    p_value : int or None, optional, default='None'
+        Value of p for Lp pooling, can be 1 or 2, required for Lp pooling.
+    count_include_pad : boolean or None, optional, default=None
+        Only used for AvgPool, specifies whether to count padding elements for average calculation.
+        For example, with a 5*5 kernel on a 3*3 corner of an image, the sum of the 9 valid elements will
+        be divided by 25 if this is set to true, or it will be divided by 9 if this is set to false.
+        Defaults to true.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC', 'NWC'}, optional, default='None'
+        Set layout for input and output. Empty for
+        default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    return _mx_nd_npx.pooling(data=data, kernel=kernel, stride=stride, pad=pad,
+                              pool_type=pool_type, pooling_convention=pooling_convention,
+                              global_pool=global_pool, cudnn_off=cudnn_off, p_value=p_value,
+                              count_include_pad=count_include_pad, layout=layout)
+
+
+# pylint: disable=too-many-arguments, unused-argument
+@set_module('mxnet.numpy_extension')
+def dropout(data, p=0.5, mode="training", axes=None, cudnn_off=True, **kwargs):
+    r"""Applies dropout operation to the input array.
+
+    - During training, each element of the input is set to zero with probability p.
+      The whole array is rescaled by :math:`1/(1-p)` to keep the expected
+      sum of the input unchanged.
+
+    - During testing, this operator does not change the input if mode is 'training'.
+      If mode is 'always', the same computation as during training will be applied.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input array to which dropout will be applied.
+    p : float, optional, default=0.5
+        Fraction of the input that gets dropped out during training time.
+ mode : {'always', 'training'},optional, default='training' + Whether to only turn on dropout during training or to also turn on for inference. + axes : Shape(tuple), optional, default=[] + Axes for variational dropout kernel. + cudnn_off : boolean or None, optional, default=0 + Whether to turn off cudnn in dropout operator. This option is ignored if axes is specified. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _mx_nd_npx.dropout(data=data, p=p, mode=mode, axes=axes, cudnn_off=cudnn_off) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def one_hot(data, depth=None, on_value=1.0, off_value=0.0, dtype="float32"): + r"""Returns a one-hot array. + + The locations represented by `indices` take value `on_value`, while all + other locations take value `off_value`. + + `one_hot` operation with `indices` of shape ``(i0, i1)`` and `depth` of ``d`` would result + in an output array of shape ``(i0, i1, d)`` with:: + + output[i,j,:] = off_value + output[i,j,indices[i,j]] = on_value + + Parameters + ---------- + indices : NDArray + array of locations where to set on_value + depth : long, required + Depth of the one hot dimension. + on_value : double, optional, default=1 + The value assigned to the locations represented by indices. + off_value : double, optional, default=0 + The value assigned to the locations not represented by indices. + dtype : {'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, + optional, default='float32' + DType of the output + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Example + ------- + >>> data = np.array([1,0,2,0]) + >>> npx.one_hot(data, 3) + array([[0., 1., 0.], + [1., 0., 0.], + [0., 0., 1.], + [1., 0., 0.]], dtype=float64) + >>> npx.one_hot(data, 3, on_value=8, off_value=1, dtype='int32') + array([[1, 8, 1], + [8, 1, 1], + [1, 1, 8], + [8, 1, 1]], dtype=int32) + >>> data = np.array([[1,0],[1,0],[2,0]]) + >>> npx.one_hot(data, 3) + array([[[0., 1., 0.], + [1., 0., 0.]], + + [[0., 1., 0.], + [1., 0., 0.]], + + [[0., 0., 1.], + [1., 0., 0.]]], dtype=float64) + """ + return _mx_nd_npx.one_hot(data=data, depth=depth, on_value=on_value, off_value=off_value, + dtype=dtype) + + +# pylint: disable=too-many-arguments, unused-argument +@set_module('mxnet.numpy_extension') +def rnn(data=None, parameters=None, state=None, state_cell=None, sequence_length=None, + mode=None, state_size=None, num_layers=None, bidirectional=False, + state_outputs=False, p=0.0, use_sequence_length=False, projection_size=None, + lstm_state_clip_min=None, lstm_state_clip_max=None, lstm_state_clip_nan=None): + r"""Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are + implemented, with both multi-layer and bidirectional support. + + When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE + and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use + pseudo-float16 precision (float32 math with float16 I/O) precision in order to use + Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups. + + **Vanilla RNN** + + Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: + ReLU and Tanh. + + With ReLU activation function: + + .. math:: + h_t = relu(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) + + With Tanh activtion function: + + .. 
math:: + h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) + + Reference paper: Finding structure in time - Elman, 1988. + https://crl.ucsd.edu/~elman/Papers/fsit.pdf + + **LSTM** + + Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf + + .. math:: + \begin{array}{ll} + i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ + f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\ + o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ + c_t = f_t * c_{(t-1)} + i_t * g_t \\ + h_t = o_t * \tanh(c_t) + \end{array} + + With the projection size being set, LSTM could use the projection feature to reduce the parameters + size and give some speedups without significant damage to the accuracy. + + Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech + Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128 + + .. math:: + \begin{array}{ll} + i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\ + f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rc} r_{(t-1)} + b_{rg}) \\ + o_t = \mathrm{sigmoid}(W_{io} x_t + b_{o} + W_{ro} r_{(t-1)} + b_{ro}) \\ + c_t = f_t * c_{(t-1)} + i_t * g_t \\ + h_t = o_t * \tanh(c_t) + r_t = W_{hr} h_t + \end{array} + + **GRU** + + Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078 + + The definition of GRU here is slightly different from paper but compatible with CUDNN. + + .. math:: + \begin{array}{ll} + r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ + \end{array} + + Parameters + ---------- + data : NDArray + Input data to RNN + parameters : NDArray + Vector of all RNN trainable parameters concatenated + state : NDArray + initial hidden state of the RNN + state_cell : NDArray + initial cell state for LSTM networks (only for LSTM) + sequence_length : NDArray + Vector of valid sequence lengths for each element in batch. + (Only used if use_sequence_length kwarg is True) + state_size : int (non-negative), required + size of the state for each layer + num_layers : int (non-negative), required + number of stacked layers + bidirectional : boolean, optional, default=0 + whether to use bidirectional recurrent layers + mode : {'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required + the type of RNN to compute + p : float, optional, default=0 + drop rate of the dropout on the outputs of each RNN layer, except the last layer. + state_outputs : boolean, optional, default=0 + Whether to have the states as symbol outputs. + projection_size : int or None, optional, default='None' + size of project size + lstm_state_clip_min : double or None, optional, default=None + Minimum clip value of LSTM states. This option must be used together with lstm_state_clip_max. + lstm_state_clip_max : double or None, optional, default=None + Maximum clip value of LSTM states. This option must be used together with lstm_state_clip_min. + lstm_state_clip_nan : boolean, optional, default=0 + Whether to stop NaN from propagating in state by clipping it to min/max. + If clipping range is not specified, this option is ignored. 
+ use_sequence_length : boolean, optional, default=0 + If set to true, this layer takes in an extra input parameter `sequence_length` + to specify variable length sequence + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _mx_nd_npx.rnn(data=data, parameters=parameters, state=state, state_cell=state_cell, + sequence_length=sequence_length, mode=mode, state_size=state_size, + num_layers=num_layers, bidirectional=bidirectional, + state_outputs=state_outputs, p=p, use_sequence_length=use_sequence_length, + projection_size=projection_size, lstm_state_clip_min=lstm_state_clip_min, + lstm_state_clip_max=lstm_state_clip_max, + lstm_state_clip_nan=lstm_state_clip_nan) diff --git a/src/api/operator/numpy_extension/npx_dropout_op.cc b/src/api/operator/numpy_extension/npx_dropout_op.cc new file mode 100644 index 000000000000..e17320f30a2e --- /dev/null +++ b/src/api/operator/numpy_extension/npx_dropout_op.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file npx_dropout_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_dropout_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/nn/dropout-inl.h" + +namespace mxnet { + +inline int String2Mode(const std::string& s) { + using namespace op; + if (s == "training") { + return dropout::kTraining; + } else if (s == "always") { + return dropout::kAlways; + } else { + LOG(FATAL) << "unknown dropout mode " << s; + } + LOG(FATAL) << "should not reach here "; + return 0; +} + +MXNET_REGISTER_API("_npx.dropout") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_dropout"); + op::DropoutParam param; + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + // p + param.p = args[1].operator double(); + // mode + param.mode = String2Mode(args[2].operator std::string()); + // axes + if (args[3].type_code() == kNull) { + param.axes = TShape(0, 0); + } else if (args[3].type_code() == kDLInt) { + param.axes = TShape(1, args[3].operator int64_t()); + } else { + param.axes = TShape(args[3].operator ObjectRef()); + } + // cudnn_off + if (args[4].type_code() == kNull) { + param.cudnn_off = false; + } else { + param.cudnn_off = args[4].operator bool(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_one_hot_op.cc b/src/api/operator/numpy_extension/npx_one_hot_op.cc new file mode 100644 index 000000000000..090d56e3b22e --- /dev/null +++ b/src/api/operator/numpy_extension/npx_one_hot_op.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file npx_one_hot_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_one_hot_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/tensor/indexing_op.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npx.one_hot") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_one_hot"); + op::OneHotParam param; + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + // depth + param.depth = args[1].operator int64_t(); + // on_value + if (args[2].type_code() == kNull) { + param.on_value = 1.0; + } else { + param.on_value = args[2].operator double(); + } + // off_value + if (args[3].type_code() == kNull) { + param.off_value = 0.0; + } else { + param.off_value = args[3].operator double(); + } + // dtype + if (args[4].type_code() != kNull) { + param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_pooling_op.cc b/src/api/operator/numpy_extension/npx_pooling_op.cc new file mode 100644 index 000000000000..5e8ab8c3435b --- /dev/null +++ b/src/api/operator/numpy_extension/npx_pooling_op.cc @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file npx_pooling_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_pooling_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/nn/pooling-inl.h" + +namespace mxnet { + +inline int String2Layout(const std::string& s) { + using namespace op; + if (s == "NCW") { + return mshadow::kNCW; + } else if (s == "NCHW") { + return mshadow::kNCHW; + } else if (s == "NCDHW") { + return mshadow::kNCDHW; + } else if (s == "NWC") { + return mshadow::kNWC; + } else if (s == "NHWC") { + return mshadow::kNHWC; + } else if (s == "NDHWC") { + return mshadow::kNDHWC; + } else { + LOG(FATAL) << "unknown layout type " << s; + } + LOG(FATAL) << "should not reach here "; + return 0; +} + +inline int String2PoolType(const std::string& s) { + using namespace op; + if (s == "max") { + return pool_enum::kMaxPooling; + } else if (s == "avg") { + return pool_enum::kAvgPooling; + } else if (s == "sum") { + return pool_enum::kSumPooling; + } else if (s == "lp") { + return pool_enum::kLpPooling; + } else { + LOG(FATAL) << "unknown pooling type type " << s; + } + LOG(FATAL) << "should not reach here "; + return 0; +} + +inline int String2Convention(const std::string& s) { + using namespace op; + if (s == "full") { + return pool_enum::kFull; + } else if (s == "valid") { + return pool_enum::kValid; + } else if (s == "same") { + return pool_enum::kSame; + } else { + LOG(FATAL) << "unknown pooling convention type " << s; + } + LOG(FATAL) << "should not reach here "; + return 0; +} + +MXNET_REGISTER_API("_npx.pooling") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_pooling"); + op::PoolingParam param; + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + + // kernel + if (args[1].type_code() == kDLInt) { + param.kernel = TShape(1, args[1].operator int64_t()); + } else { + param.kernel = TShape(args[1].operator ObjectRef()); + } + + // stride + if (args[2].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.stride = mshadow::Shape1(1); + } else if (param.kernel.ndim() == 2) { + param.stride = mshadow::Shape2(1, 1); + } else { + param.stride = mshadow::Shape3(1, 1, 1); + } + } else if (args[2].type_code() == kDLInt) { + param.stride = TShape(1, args[2].operator int64_t()); + } else { + param.stride = TShape(args[2].operator ObjectRef()); + } + // pad + if (args[3].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.pad = mshadow::Shape1(0); + } else if (param.kernel.ndim() == 2) { + param.pad = mshadow::Shape2(0, 0); + } else { + param.pad = mshadow::Shape3(0, 0, 0); + } + } else if (args[3].type_code() == kDLInt) { + param.pad = TShape(1, args[3].operator int64_t()); + } else { + param.pad = TShape(args[3].operator ObjectRef()); + } + // pool type + param.pool_type = String2PoolType(args[4].operator std::string()); + // pooling convention + param.pooling_convention = String2Convention(args[5].operator std::string()); + // global pool + param.global_pool = args[6].operator bool(); + // cudnn_off + if (args[7].type_code() == kNull) { + param.cudnn_off = false; + } else { + param.cudnn_off = args[7].operator bool(); + } + // p_value + if (args[8].type_code() == kNull) { + param.p_value = dmlc::nullopt; + } else { + param.p_value = args[8].operator int(); + } + // count_include_pad + if (args[9].type_code() == kNull) { + param.count_include_pad = dmlc::nullopt; + } else { + 
param.count_include_pad = args[9].operator bool();
+  }
+  // layout
+  if (args[10].type_code() == kNull) {
+    param.layout = dmlc::nullopt;
+  } else {
+    param.layout = String2Layout(args[10].operator std::string());
+  }
+
+  if (param.global_pool == false) {
+    CHECK_LE(param.kernel.ndim(), 3U) << param.kernel.ndim()
+                                      << "D pooling not supported";
+  }
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict(&attrs);
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
+  if (num_outputs == 1) {
+    *ret = ndoutputs[0];
+  } else {
+    std::vector ndarray_handles;
+    ndarray_handles.reserve(num_outputs);
+    for (int i = 0; i < num_outputs; ++i) {
+      ndarray_handles.emplace_back(ndoutputs[i]);
+    }
+    *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
+  }
+});
+
+} // namespace mxnet
diff --git a/src/api/operator/numpy_extension/npx_rnn_op.cc b/src/api/operator/numpy_extension/npx_rnn_op.cc
new file mode 100644
index 000000000000..6d94b390c4d2
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_rnn_op.cc
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_rnn_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_rnn_op.cc
+ */
+#include
+#include
+#include "../utils.h"
+#include "../../../operator/rnn-inl.h"
+
+namespace mxnet {
+
+inline int String2ComputeMode(const std::string& s) {
+  using namespace op;
+  if (s == "rnn_relu") {
+    return rnn_enum::kRnnRelu;
+  } else if (s == "rnn_tanh") {
+    return rnn_enum::kRnnTanh;
+  } else if (s == "lstm") {
+    return rnn_enum::kLstm;
+  } else if (s == "gru") {
+    return rnn_enum::kGru;
+  } else {
+    LOG(FATAL) << "unknown compute mode " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+MXNET_REGISTER_API("_npx.rnn")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_rnn");
+  op::RNNParam param;
+  int args_size = args.size();
+  int num_inputs = 0;
+
+  // mode
+  param.mode = String2ComputeMode(args[args_size - 7].operator std::string());
+  num_inputs = (param.mode == op::rnn_enum::kLstm) ?
4 : 3; + // use_sequence_length + if (args[args_size - 5].type_code() == kNull) { + param.use_sequence_length = false; + } else { + param.use_sequence_length = args[args_size - 5].operator bool(); + } + if (param.use_sequence_length) num_inputs += 1; + // inputs + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // state_size + param.state_size = (uint32_t) (args[args_size - 11].operator int()); + // num_layers + param.num_layers = (uint32_t) (args[args_size - 10].operator int()); + // bidirectional + if (args[args_size - 9].type_code() == kNull) { + param.bidirectional = false; + } else { + param.bidirectional = args[args_size - 9].operator bool(); + } + // state_outputs + if (args[args_size - 8].type_code() == kNull) { + param.state_outputs = false; + } else { + param.state_outputs = args[args_size - 8].operator bool(); + } + // p + if (args[args_size - 6].type_code() == kNull) { + param.p = 0.0; + } else { + param.p = args[args_size - 6].operator double(); + } + // projection_size + if (args[args_size - 4].type_code() == kNull) { + param.projection_size = dmlc::nullopt; + } else { + param.projection_size = args[args_size - 4].operator int(); + } + // lstm_state_clip_min + if (args[args_size - 3].type_code() == kNull) { + param.lstm_state_clip_min = dmlc::nullopt; + } else { + param.lstm_state_clip_min = args[args_size - 3].operator double(); + } + // lstm_state_clip_max + if (args[args_size - 2].type_code() == kNull) { + param.lstm_state_clip_max = dmlc::nullopt; + } else { + param.lstm_state_clip_max = args[args_size - 2].operator double(); + } + // lstm_state_clip_nan + if (args[args_size - 1].type_code() == kNull) { + param.lstm_state_clip_nan = false; + } else { + param.lstm_state_clip_nan = args[args_size - 1].operator bool(); + } + // initialize + param.seq_length_ = 0; + param.batch_size_ = 0; + param.input_size_ = 0; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } +}); + +} // namespace mxnet diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 31011c5aed3b..b32e47460e0a 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -84,6 +84,29 @@ struct DropoutParam : public dmlc::Parameter { .describe("Whether to turn off cudnn in dropout operator. 
" "This option is ignored if axes is specified."); } + std::string Mode2String(int mode) { + switch (mode) { + case dropout::kTraining: + return "training"; + case dropout::kAlways: + return "always"; + default: + LOG(FATAL) << "Unknown mode enum " << mode; + } + LOG(FATAL) << "should not reach here "; + return ""; + } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream p_s, mode_s, axes_s, cudnn_off_s; + p_s << p; + mode_s << mode; + axes_s << axes; + cudnn_off_s << cudnn_off; + (*dict)["p"] = p_s.str(); + (*dict)["mode"] = Mode2String(mode); + (*dict)["axes"] = axes_s.str(); + (*dict)["cudnn_off"] = cudnn_off_s.str(); + } }; // struct DropoutParam template diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h index 03f0fa8edd6c..d8193ec95e60 100644 --- a/src/operator/nn/pooling-inl.h +++ b/src/operator/nn/pooling-inl.h @@ -138,6 +138,86 @@ struct PoolingParam : public dmlc::Parameter { } return ret_val; } + + std::string PoolType2String(int pool_type) { + switch (pool_type) { + case pool_enum::kMaxPooling: + return "max"; + case pool_enum::kAvgPooling: + return "avg"; + case pool_enum::kSumPooling: + return "sum"; + case pool_enum::kLpPooling: + return "lp"; + default: + LOG(FATAL) << "Unknown pool type enum " << pool_type; + } + LOG(FATAL) << "should not reach here "; + return ""; + } + std::string Convention2String(int pool_convention) { + switch (pool_convention) { + case pool_enum::kFull: + return "full"; + case pool_enum::kValid: + return "valid"; + case pool_enum::kSame: + return "same"; + default: + LOG(FATAL) << "Unknown pool convention enum " << pool_convention; + } + LOG(FATAL) << "should not reach here "; + return ""; + } + std::string Layout2String(int layout) { + switch (layout) { + case mshadow::kNCW: + return "NCW"; + case mshadow::kNCHW: + return "NCHW"; + case mshadow::kNCDHW: + return "NCDHW"; + case mshadow::kNWC: + return "NWC"; + case mshadow::kNHWC: + return "NHWC"; + case mshadow::kNDHWC: + return "NDHWC"; + default: + LOG(FATAL) << "Unknown layout enum " << layout; + } + LOG(FATAL) << "should not reach here "; + return ""; + } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream kernel_s, stride_s, pad_s, pool_type_s, + pooling_convention_s, global_pool_s, cudnn_off_s, + p_value_s, count_include_pad_s, layout_s; + kernel_s << kernel; + stride_s << stride; + pad_s << pad; + pool_type_s << pool_type; + pooling_convention_s << pooling_convention; + global_pool_s << global_pool; + cudnn_off_s << cudnn_off; + p_value_s << p_value; + count_include_pad_s << count_include_pad; + layout_s << layout; + (*dict)["kernel"] = kernel_s.str(); + (*dict)["stride"] = stride_s.str(); + (*dict)["pad"] = pad_s.str(); + (*dict)["pool_type"] = PoolType2String(pool_type); + (*dict)["pooling_convention"] = Convention2String(pooling_convention); + (*dict)["global_pool"] = global_pool_s.str(); + (*dict)["cudnn_off"] = cudnn_off_s.str(); + (*dict)["p_value"] = p_value_s.str(); + (*dict)["count_include_pad"] = count_include_pad_s.str(); + if (layout.has_value()) { + (*dict)["layout"] = Layout2String(layout.value()); + } else { + (*dict)["layout"] = layout_s.str(); + } + } }; } // namespace op diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index aa1226e640c4..74362e33ac7b 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -65,7 +65,12 @@ struct RNNParam : public dmlc::Parameter { bool bidirectional, state_outputs; int mode; float p; +#pragma GCC diagnostic push +#if __GNUC__ >= 6 +#pragma GCC diagnostic ignored 
"-Wmaybe-uninitialized" +#endif index_t seq_length_, batch_size_, input_size_; +#pragma GCC diagnostic pop bool use_sequence_length; dmlc::optional projection_size; @@ -122,6 +127,51 @@ struct RNNParam : public dmlc::Parameter { "`sequence_length` " "to specify variable length sequence"); } + std::string ComputeMode2String(int mode) { + switch (mode) { + case rnn_enum::kRnnRelu: + return "rnn_relu"; + case rnn_enum::kRnnTanh: + return "rnn_tanh"; + case rnn_enum::kLstm: + return "lstm"; + case rnn_enum::kGru: + return "gru"; + default: + LOG(FATAL) << "Unknown mode enum " << mode; + } + LOG(FATAL) << "should not reach here "; + return ""; + } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream state_size_s, num_layers_s, bidirectional_s, + state_outputs_s, mode_s, p_s, + use_sequence_length_s, projection_size_s, + lstm_state_clip_min_s, lstm_state_clip_max_s, + lstm_state_clip_nan_s; + state_size_s << state_size; + num_layers_s << num_layers; + bidirectional_s << bidirectional; + state_outputs_s << state_outputs; + mode_s << mode; + p_s << p; + use_sequence_length_s << use_sequence_length; + projection_size_s << projection_size; + lstm_state_clip_min_s << lstm_state_clip_min; + lstm_state_clip_max_s << lstm_state_clip_max; + lstm_state_clip_nan_s << lstm_state_clip_nan; + (*dict)["state_size"] = state_size_s.str(); + (*dict)["num_layers"] = num_layers_s.str(); + (*dict)["bidirectional"] = bidirectional_s.str(); + (*dict)["state_outputs"] = state_outputs_s.str(); + (*dict)["mode"] = ComputeMode2String(mode); + (*dict)["p"] = p_s.str(); + (*dict)["use_sequence_length"] = use_sequence_length_s.str(); + (*dict)["projection_size"] = projection_size_s.str(); + (*dict)["lstm_state_clip_min"] = lstm_state_clip_min_s.str(); + (*dict)["lstm_state_clip_max"] = lstm_state_clip_max_s.str(); + (*dict)["lstm_state_clip_nan"] = lstm_state_clip_nan_s.str(); + } }; inline index_t GetRnnParamSize(int num_layer, diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index b06f5f92a2d8..95f4abc4bc68 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -1160,6 +1160,17 @@ struct OneHotParam : public dmlc::Parameter { MXNET_ADD_ALL_TYPES .describe("DType of the output"); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream depth_s, on_value_s, off_value_s, axis_s, dtype_s; + depth_s << depth; + on_value_s << on_value; + off_value_s << off_value; + dtype_s << dtype; + (*dict)["depth"] = depth_s.str(); + (*dict)["on_value"] = on_value_s.str(); + (*dict)["off_value"] = off_value_s.str(); + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype); + } }; inline void GetOneHotParams(const OneHotParam& param, index_t* depth, double* on_value,