diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 0f348ad7d168..e174d348fd18 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -798,7 +798,8 @@ def write_all_str(module_file, module_all_list): '_npx_masked_log_softmax', '_npx_activation', '_npx_batch_norm', '_npx_fully_connected', '_npx_pick', '_npx_convolution', '_npx_deconvolution', '_npx_pooling', - '_npx_dropout', '_npx_one_hot', '_npx_rnn', '_npx_batch_dot', + '_npx_dropout', '_npx_one_hot', '_npx_rnn', '_npx_embedding', + '_npx_topk', '_npx_layer_norm', '_npx_leaky_relu', '_npx_batch_dot', '_npx_broadcast_like', '_npx_arange_like'} _NP_INTERNAL_OP_PREFIX = '_npi_' diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index 99196f661c63..5d5ca1aad78c 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -26,8 +26,9 @@ __all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax', 'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution', - 'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn', - 'batch_dot', 'broadcast_like', 'arange_like'] + 'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn', 'embedding', + 'topk', 'layer_norm', 'leaky_relu', 'batch_dot', 'broadcast_like', + 'arange_like'] # pylint: disable=too-many-arguments @@ -333,9 +334,12 @@ def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9 out : NDArray or list of NDArrays The output of this function. """ - return _api_internal.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum, - fix_gamma, use_global_stats, output_mean_var, axis, - cudnn_off, min_calib_range, max_calib_range) + out = _api_internal.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum, + fix_gamma, use_global_stats, output_mean_var, axis, + cudnn_off, min_calib_range, max_calib_range) + if isinstance(out, NDArrayBase): + return out + return list(out) # pylint: disable=too-many-arguments, unused-argument @@ -1036,6 +1040,280 @@ def rnn(data=None, parameters=None, state=None, state_cell=None, sequence_length lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan) +# pylint: disable=too-many-arguments, unused-argument +@set_module('mxnet.ndarray.numpy_extension') +def embedding(data, weight, input_dim=None, output_dim=None, dtype="float32", sparse_grad=False, + **kwargs): + r"""Maps integer indices to vector representations (embeddings). + + This operator maps words to real-valued vectors in a high-dimensional space, + called word embeddings. These embeddings can capture semantic and syntactic properties of the words. + For example, it has been noted that in the learned embedding spaces, similar words tend + to be close to each other and dissimilar words far apart. + + For an input array of shape (d1, ..., dK), + the shape of an output array is (d1, ..., dK, output_dim). + All the input values should be integers in the range [0, input_dim). + + If the input_dim is ip0 and output_dim is op0, then shape of the embedding weight matrix must be + (ip0, op0). + + When "sparse_grad" is False, if any index mentioned is too large, it is replaced by the index that + addresses the last vector in an embedding matrix. + When "sparse_grad" is True, an error will be raised if invalid indices are found. + + The storage type of weight can be either row_sparse or default. + + .. Note:: + + If "sparse_grad" is set to True, the storage type of gradient w.r.t weights will be + "row_sparse". 
Only a subset of optimizers support sparse gradients, including SGD, AdaGrad + and Adam. Note that by default lazy updates is turned on, which may perform differently + from standard updates. For more details, please check the Optimization API at: + https://mxnet.incubator.apache.org/api/python/optimization/optimization.html + + Parameters + ---------- + data : NDArray + The input array to the embedding operator. + weight : NDArray + The embedding weight matrix. + input_dim : long, required + Vocabulary size of the input indices. + output_dim : long, required + Dimension of the embedding vectors. + dtype : {'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, + optional, default='float32' + Data type of weight. + sparse_grad : boolean, optional, default=0 + Compute row sparse gradient in the backward calculation. + If set to True, the grad's storage type is row_sparse. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Example + ------- + >>> input_dim = 4 + >>> output_dim = 5 + + Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3) + + >>> y = np.arange(input_dim * output_dim).reshape(input_dim, output_dim) + >>> y + array([[ 0., 1., 2., 3., 4.], + [ 5., 6., 7., 8., 9.], + [10., 11., 12., 13., 14.], + [15., 16., 17., 18., 19.]]) + + Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)] + + >>> x = np.array([[1., 3.], [0., 2.]]) + >>> x + array([[1., 3.], + [0., 2.]]) + + Mapped input x to its vector representation y. + + >>> npx.embedding(x, y, input_dim, output_dim) + array([[[ 5., 6., 7., 8., 9.], + [15., 16., 17., 18., 19.]], + + [[ 0., 1., 2., 3., 4.], + [10., 11., 12., 13., 14.]]]) + """ + assert input_dim > 1, "Vocabulary size of the input indices should be greater than 1." + assert output_dim > 1, "Dimension of the embedding vectors should greater than 1." + return _api_internal.embedding(data, weight, input_dim, output_dim, dtype, sparse_grad) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def topk(data, axis=-1, k=1, ret_typ="indices", is_ascend=False, dtype="float32"): + r"""Returns the indices of the top *k* elements in an input array along the given + axis (by default). + If ret_type is set to 'value' returns the value of top *k* elements (instead of indices). + In case of ret_type = 'both', both value and index would be returned. + The returned elements will be sorted. + + Parameters + ---------- + data : NDArray + The input array + axis : int or None, optional, default='-1' + Axis along which to choose the top k indices. + If not given, the flattened array is used. Default is -1. + k : int, optional, default='1' + Number of top elements to select, should be always smaller than or equal to + the element number in the given axis. A global sort is performed if set k < 1. + ret_typ : {'both', 'indices', 'mask', 'value'},optional, default='indices' + The return type. + "value" means to return the top k values, + "indices" means to return the indices of the top k values, + "mask" means to return a mask array containing 0 and 1. 1 means the top k values. + "both" means to return a list of both values and indices of top k elements. + is_ascend : boolean, optional, default=0 + Whether to choose k largest or k smallest elements. + Top K largest elements will be chosen if set to false. 
+ dtype : {'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'}, + optional, default='float32' + DType of the output indices when ret_typ is "indices" or "both". + An error will be raised if the selected data type cannot precisely represent the indices. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Example + ------- + >>> x = np.array([[0.3, 0.2, 0.4], [0.1, 0.3, 0.2]]) + + returns an index of the largest element on last axis + + >>> npx.topk(x) + array([[2.], + [1.]]) + + returns the value of top-2 largest elements on last axis + + >>> npx.topk(x, ret_typ='value', k=2) + array([[0.4, 0.3], + [0.3, 0.2]]) + + returns the value of top-2 smallest elements on last axis + + >>> npx.topk(x, ret_typ='value', k=2, is_ascend=1) + array([[0.2, 0.3], + [0.1, 0.2]]) + + returns the value of top-2 largest elements on axis 0 + + >>> npx.topk(x, axis=0, ret_typ='value', k=2) + array([[0.3, 0.3, 0.4], + [0.1, 0.2, 0.2]]) + + flattens and then returns list of both values and indices + + >>> npx.topk(x, ret_typ='both', k=2) + [array([[0.4, 0.3], [0.3, 0.2]]), + array([[2., 0.], [1., 2.]])] + """ + out = _api_internal.topk(data, axis, k, ret_typ, is_ascend, dtype) + if isinstance(out, NDArrayBase): + return out + return list(out) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.ndarray.numpy_extension') +def layer_norm(data=None, gamma=None, beta=None, axis=None, eps=None, output_mean_var=None): + r"""Layer normalization. + + Normalizes the channels of the input tensor by mean and variance, and applies a scale ``gamma`` as + well as offset ``beta``. + + Assume the input has more than one dimension and we normalize along axis 1. + We first compute the mean and variance along this axis and then + compute the normalized output, which has the same shape as input, as following: + + .. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta + + Both ``gamma`` and ``beta`` are learnable parameters. + + Unlike BatchNorm and InstanceNorm, the *mean* and *var* are computed along the channel dimension. + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and + ``data_std``. Note that no gradient will be passed through these two outputs. + + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel' (separately normalized groups). The default is -1, which sets the channel + axis to be the last item in the input shape. + + Parameters + ---------- + data : NDArray + Input data to layer normalization + gamma : NDArray + gamma array + beta : NDArray + beta array + axis : int, optional, default='-1' + The axis to perform layer normalization. + Usually, this should be be axis of the channel dimension. + Negative values means indexing from right to left. + eps : float, optional, default=9.99999975e-06 + An `epsilon` parameter to prevent division by 0. + output_mean_var : boolean, optional, default=0 + Output the mean and std calculated along the given axis. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. 
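+
+    Example
+    -------
+    An illustrative call on a toy input (a sketch; printed values are rounded).
+    With ``gamma`` of ones and ``beta`` of zeros, each row is normalized to zero
+    mean and unit variance along the last axis:
+
+    >>> data = np.array([[1., 2., 3.], [4., 5., 6.]])
+    >>> gamma = np.ones((3,))
+    >>> beta = np.zeros((3,))
+    >>> npx.layer_norm(data, gamma, beta, axis=-1, eps=1e-5)
+    array([[-1.2247,  0.    ,  1.2247],
+           [-1.2247,  0.    ,  1.2247]])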
+ """ + out = _api_internal.layer_norm(data, gamma, beta, axis, eps, output_mean_var) + if isinstance(out, NDArrayBase): + return out + return list(out) + + +# pylint: disable=too-many-arguments, unused-argument +@set_module('mxnet.ndarray.numpy_extension') +def leaky_relu(data=None, gamma=None, act_type="leaky", slope=0.25, lower_bound=0.125, + upper_bound=0.334, **kwargs): + r"""Applies Leaky rectified linear unit activation element-wise to the input. + + Leaky ReLUs attempt to fix the "dying ReLU" problem by allowing a small `slope` + when the input is negative and has a slope of one when input is positive. + + The following modified ReLU Activation functions are supported: + + - *elu*: Exponential Linear Unit. `y = x > 0 ? x : slope * (exp(x)-1)` + - *gelu*: Gaussian Error Linear Unit. `y = 0.5 * x * (1 + erf(x / sqrt(2)))` + - *selu*: Scaled Exponential Linear Unit. `y = lambda * (x > 0 ? x : alpha * (exp(x) - 1))` where + *lambda = 1.0507009873554804934193349852946* and *alpha = 1.6732632423543772848170429916717*. + - *leaky*: Leaky ReLU. `y = x > 0 ? x : slope * x` + - *prelu*: Parametric ReLU. This is same as *leaky* except that `slope` is learnt during training. + - *rrelu*: Randomized ReLU. same as *leaky* but the `slope` is uniformly and randomly chosen from + *[lower_bound, upper_bound)* for training, while fixed to be + *(lower_bound+upper_bound)/2* for inference. + + Parameters + ---------- + data : NDArray + Input data to activation function. + gamma : NDArray + Input data to activation function. + act_type : {'elu', 'gelu', 'leaky', 'prelu', 'rrelu', 'selu'},optional, default='leaky' + Activation function to be applied. + slope : float, optional, default=0.25 + Init slope for the activation. (For leaky and elu only) + lower_bound : float, optional, default=0.125 + Lower bound of random slope. (For rrelu only) + upper_bound : float, optional, default=0.333999991 + Upper bound of random slope. (For rrelu only) + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. 
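+
+    Example
+    -------
+    An illustrative call with the default ``leaky`` activation (a sketch; negative
+    inputs are scaled by ``slope`` while positive inputs pass through unchanged):
+
+    >>> data = np.array([-2., -1., 0., 1., 2.])
+    >>> npx.leaky_relu(data, act_type='leaky', slope=0.25)
+    array([-0.5 , -0.25,  0.  ,  1.  ,  2.  ])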
+ """ + if act_type == "prelu": + assert gamma is not None, "If activation function is prelu, please provide input gamma" + out = _api_internal.leaky_relu(data, gamma, act_type, slope, lower_bound, upper_bound) + if isinstance(out, NDArrayBase): + return out + return list(out) + else: + return _api_internal.leaky_relu(data, act_type, slope, lower_bound, upper_bound) + + # pylint: disable=too-many-arguments @set_module('mxnet.ndarray.numpy_extension') def batch_dot(a, b, transpose_a=False, transpose_b=False, forward_stype="default"): diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index 2df97ace623d..1d672f45ec54 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -23,8 +23,9 @@ __all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax', 'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution', - 'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn', - 'batch_dot', 'broadcast_like', 'arange_like'] + 'deconvolution', 'pooling', 'dropout', 'one_hot', 'rnn', 'embedding', + 'topk', 'layer_norm', 'leaky_relu', 'batch_dot', 'broadcast_like', + 'arange_like'] # pylint: disable=too-many-arguments @@ -968,6 +969,268 @@ def rnn(data=None, parameters=None, state=None, state_cell=None, sequence_length lstm_state_clip_nan=lstm_state_clip_nan) +# pylint: disable=too-many-arguments, unused-argument +@set_module('mxnet.numpy_extension') +def embedding(data, weight, input_dim=None, output_dim=None, dtype="float32", sparse_grad=False, + **kwargs): + r"""Maps integer indices to vector representations (embeddings). + + This operator maps words to real-valued vectors in a high-dimensional space, + called word embeddings. These embeddings can capture semantic and syntactic properties of the words. + For example, it has been noted that in the learned embedding spaces, similar words tend + to be close to each other and dissimilar words far apart. + + For an input array of shape (d1, ..., dK), + the shape of an output array is (d1, ..., dK, output_dim). + All the input values should be integers in the range [0, input_dim). + + If the input_dim is ip0 and output_dim is op0, then shape of the embedding weight matrix must be + (ip0, op0). + + When "sparse_grad" is False, if any index mentioned is too large, it is replaced by the index that + addresses the last vector in an embedding matrix. + When "sparse_grad" is True, an error will be raised if invalid indices are found. + + The storage type of weight can be either row_sparse or default. + + .. Note:: + + If "sparse_grad" is set to True, the storage type of gradient w.r.t weights will be + "row_sparse". Only a subset of optimizers support sparse gradients, including SGD, AdaGrad + and Adam. Note that by default lazy updates is turned on, which may perform differently + from standard updates. For more details, please check the Optimization API at: + https://mxnet.incubator.apache.org/api/python/optimization/optimization.html + + Parameters + ---------- + data : NDArray + The input array to the embedding operator. + weight : NDArray + The embedding weight matrix. + input_dim : long, required + Vocabulary size of the input indices. + output_dim : long, required + Dimension of the embedding vectors. + dtype : {'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'}, + optional, default='float32' + Data type of weight. + sparse_grad : boolean, optional, default=0 + Compute row sparse gradient in the backward calculation. 
+ If set to True, the grad's storage type is row_sparse. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + + Example + ------- + >>> input_dim = 4 + >>> output_dim = 5 + + Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3) + + >>> y = np.arange(input_dim * output_dim).reshape(input_dim, output_dim) + >>> y + array([[ 0., 1., 2., 3., 4.], + [ 5., 6., 7., 8., 9.], + [10., 11., 12., 13., 14.], + [15., 16., 17., 18., 19.]]) + + Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)] + + >>> x = np.array([[1., 3.], [0., 2.]]) + >>> x + array([[1., 3.], + [0., 2.]]) + + Mapped input x to its vector representation y. + + >>> npx.embedding(x, y, input_dim, output_dim) + array([[[ 5., 6., 7., 8., 9.], + [15., 16., 17., 18., 19.]], + + [[ 0., 1., 2., 3., 4.], + [10., 11., 12., 13., 14.]]]) + """ + return _mx_nd_npx.embedding(data=data, weight=weight, input_dim=input_dim, output_dim=output_dim, + dtype=dtype, sparse_grad=sparse_grad) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def topk(data, axis=-1, k=1, ret_typ="indices", is_ascend=False, dtype="float32"): + r"""Returns the indices of the top *k* elements in an input array along the given + axis (by default). + If ret_type is set to 'value' returns the value of top *k* elements (instead of indices). + In case of ret_type = 'both', both value and index would be returned. + The returned elements will be sorted. + + Parameters + ---------- + data : NDArray + The input array + axis : int or None, optional, default='-1' + Axis along which to choose the top k indices. + If not given, the flattened array is used. Default is -1. + k : int, optional, default='1' + Number of top elements to select, should be always smaller than or equal to + the element number in the given axis. A global sort is performed if set k < 1. + ret_typ : {'both', 'indices', 'mask', 'value'},optional, default='indices' + The return type. + "value" means to return the top k values, + "indices" means to return the indices of the top k values, + "mask" means to return a mask array containing 0 and 1. 1 means the top k values. + "both" means to return a list of both values and indices of top k elements. + is_ascend : boolean, optional, default=0 + Whether to choose k largest or k smallest elements. + Top K largest elements will be chosen if set to false. + dtype : {'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'}, + optional, default='float32' + DType of the output indices when ret_typ is "indices" or "both". + An error will be raised if the selected data type cannot precisely represent the indices. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. 
+ + Example + ------- + >>> x = np.array([[0.3, 0.2, 0.4], [0.1, 0.3, 0.2]]) + + returns an index of the largest element on last axis + + >>> npx.topk(x) + array([[2.], + [1.]]) + + returns the value of top-2 largest elements on last axis + + >>> npx.topk(x, ret_typ='value', k=2) + array([[0.4, 0.3], + [0.3, 0.2]]) + + returns the value of top-2 smallest elements on last axis + + >>> npx.topk(x, ret_typ='value', k=2, is_ascend=1) + array([[0.2, 0.3], + [0.1, 0.2]]) + + returns the value of top-2 largest elements on axis 0 + + >>> npx.topk(x, axis=0, ret_typ='value', k=2) + array([[0.3, 0.3, 0.4], + [0.1, 0.2, 0.2]]) + + flattens and then returns list of both values and indices + + >>> npx.topk(x, ret_typ='both', k=2) + [array([[0.4, 0.3], [0.3, 0.2]]), + array([[2., 0.], [1., 2.]])] + """ + return _mx_nd_npx.topk(data=data, axis=axis, k=k, ret_typ=ret_typ, is_ascend=is_ascend, dtype=dtype) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def layer_norm(data=None, gamma=None, beta=None, axis=None, eps=None, output_mean_var=None): + r"""Layer normalization. + + Normalizes the channels of the input tensor by mean and variance, and applies a scale ``gamma`` as + well as offset ``beta``. + + Assume the input has more than one dimension and we normalize along axis 1. + We first compute the mean and variance along this axis and then + compute the normalized output, which has the same shape as input, as following: + + .. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta + + Both ``gamma`` and ``beta`` are learnable parameters. + + Unlike BatchNorm and InstanceNorm, the *mean* and *var* are computed along the channel dimension. + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and + ``data_std``. Note that no gradient will be passed through these two outputs. + + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel' (separately normalized groups). The default is -1, which sets the channel + axis to be the last item in the input shape. + + Parameters + ---------- + data : NDArray + Input data to layer normalization + gamma : NDArray + gamma array + beta : NDArray + beta array + axis : int, optional, default='-1' + The axis to perform layer normalization. + Usually, this should be be axis of the channel dimension. + Negative values means indexing from right to left. + eps : float, optional, default=9.99999975e-06 + An `epsilon` parameter to prevent division by 0. + output_mean_var : boolean, optional, default=0 + Output the mean and std calculated along the given axis. + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _mx_nd_npx.layer_norm(data=data, gamma=gamma, beta=beta, axis=axis, eps=eps, + output_mean_var=output_mean_var) + + +# pylint: disable=too-many-arguments +@set_module('mxnet.numpy_extension') +def leaky_relu(data=None, gamma=None, act_type="leaky", slope=0.25, lower_bound=0.125, + upper_bound=0.334, **kwargs): + r"""Applies Leaky rectified linear unit activation element-wise to the input. + + Leaky ReLUs attempt to fix the "dying ReLU" problem by allowing a small `slope` + when the input is negative and has a slope of one when input is positive. + + The following modified ReLU Activation functions are supported: + + - *elu*: Exponential Linear Unit. `y = x > 0 ? 
x : slope * (exp(x)-1)` + - *gelu*: Gaussian Error Linear Unit. `y = 0.5 * x * (1 + erf(x / sqrt(2)))` + - *selu*: Scaled Exponential Linear Unit. `y = lambda * (x > 0 ? x : alpha * (exp(x) - 1))` where + *lambda = 1.0507009873554804934193349852946* and *alpha = 1.6732632423543772848170429916717*. + - *leaky*: Leaky ReLU. `y = x > 0 ? x : slope * x` + - *prelu*: Parametric ReLU. This is same as *leaky* except that `slope` is learnt during training. + - *rrelu*: Randomized ReLU. same as *leaky* but the `slope` is uniformly and randomly chosen from + *[lower_bound, upper_bound)* for training, while fixed to be + *(lower_bound+upper_bound)/2* for inference. + + Parameters + ---------- + data : NDArray + Input data to activation function. + gamma : NDArray + Input data to activation function. + act_type : {'elu', 'gelu', 'leaky', 'prelu', 'rrelu', 'selu'},optional, default='leaky' + Activation function to be applied. + slope : float, optional, default=0.25 + Init slope for the activation. (For leaky and elu only) + lower_bound : float, optional, default=0.125 + Lower bound of random slope. (For rrelu only) + upper_bound : float, optional, default=0.333999991 + Upper bound of random slope. (For rrelu only) + + Returns + ------- + out : NDArray or list of NDArrays + The output of this function. + """ + return _mx_nd_npx.leaky_relu(data=data, gamma=gamma, act_type=act_type, slope=slope, + lower_bound=lower_bound, upper_bound=upper_bound) + + # pylint: disable=too-many-arguments, unused-argument @set_module('mxnet.numpy_extension') def batch_dot(a, b, transpose_a=False, transpose_b=False, forward_stype="default"): diff --git a/src/api/operator/numpy_extension/npx_embedding_op.cc b/src/api/operator/numpy_extension/npx_embedding_op.cc new file mode 100644 index 000000000000..58b5e3ff740f --- /dev/null +++ b/src/api/operator/numpy_extension/npx_embedding_op.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file npx_embedding_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_embedding_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/tensor/indexing_op.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npx.embedding") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_embedding"); + op::EmbeddingParam param; + // inputs + int num_inputs = 2; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // input_dim + param.input_dim = args[2].operator int64_t(); + // output_dim + param.output_dim = args[3].operator int64_t(); + // dtype + param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); + // sparse_grad; + if (args[5].type_code() == kNull) { + param.sparse_grad = false; + } else { + param.sparse_grad = args[5].operator bool(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_layer_norm_op.cc b/src/api/operator/numpy_extension/npx_layer_norm_op.cc new file mode 100644 index 000000000000..b638088d328d --- /dev/null +++ b/src/api/operator/numpy_extension/npx_layer_norm_op.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file npx_layer_norm_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_layer_norm_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/nn/layer_norm-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npx.layer_norm") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_layer_norm"); + op::LayerNormParam param; + // inputs + int num_inputs = 3; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // axis + if (args[3].type_code() == kNull) { + param.axis = -1; + } else { + param.axis = args[3].operator int(); + } + // eps + if (args[4].type_code() == kNull) { + param.eps = 1e-5f; + } else { + param.eps = args[4].operator double(); + } + // output_mean_var + if (args[5].type_code() == kNull) { + param.output_mean_var = false; + } else { + param.output_mean_var = args[5].operator bool(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 3; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_leaky_relu_op.cc b/src/api/operator/numpy_extension/npx_leaky_relu_op.cc new file mode 100644 index 000000000000..7717cf79c8ab --- /dev/null +++ b/src/api/operator/numpy_extension/npx_leaky_relu_op.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file npx_leaky_relu_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_leaky_relu_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/leaky_relu-inl.h" + +namespace mxnet { + +inline int String2ActType(const std::string& s) { + using namespace op; + if (s == "rrelu") { + return leakyrelu::kRReLU; + } else if (s == "leaky") { + return leakyrelu::kLeakyReLU; + } else if (s == "prelu") { + return leakyrelu::kPReLU; + } else if (s == "elu") { + return leakyrelu::kELU; + } else if (s == "selu") { + return leakyrelu::kSELU; + } else if (s == "gelu") { + return leakyrelu::kGELU; + } else { + LOG(FATAL) << "unknown activation type " << s; + } + LOG(FATAL) << "should not reach here "; + return 0; +} + +MXNET_REGISTER_API("_npx.leaky_relu") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_leaky_relu"); + op::LeakyReLUParam param; + int args_size = args.size(); + // act_type + param.act_type = String2ActType(args[args_size - 4].operator std::string()); + // inputs + int num_inputs = param.act_type == op::leakyrelu::kPReLU ? 2 : 1; + int num_outputs = param.act_type == op::leakyrelu::kPReLU ? 2 : 1; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // slope + if (args[args_size - 3].type_code() == kNull) { + param.slope = 0.25f; + } else { + param.slope = args[args_size - 3].operator double(); + } + // lower_bound + if (args[args_size - 2].type_code() == kNull) { + param.lower_bound = 0.125f; + } else { + param.lower_bound = args[args_size - 2].operator double(); + } + // upper_bound + if (args[args_size - 1].type_code() == kNull) { + param.upper_bound = 0.334f; + } else { + param.upper_bound = args[args_size - 1].operator double(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_topk_op.cc b/src/api/operator/numpy_extension/npx_topk_op.cc new file mode 100644 index 000000000000..6fcea5ae5591 --- /dev/null +++ b/src/api/operator/numpy_extension/npx_topk_op.cc @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file npx_topk_op.cc + * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_topk_op.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/tensor/ordering_op-inl.h" + +namespace mxnet { + +inline int String2ReturnType(const std::string& s) { + using namespace op; + if (s == "value") { + return topk_enum::kReturnValue; + } else if (s == "indices") { + return topk_enum::kReturnIndices; + } else if (s == "mask") { + return topk_enum::kReturnMask; + } else if (s == "both") { + return topk_enum::kReturnBoth; + } else { + LOG(FATAL) << "unknown return type " << s; + } + LOG(FATAL) << "should not reach here "; + return 0; +} + +MXNET_REGISTER_API("_npx.topk") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_topk"); + op::TopKParam param; + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray *()}; + // axis + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[1].operator int(); + } + // k + if (args[2].type_code() == kNull) { + param.k = 1; + } else { + param.k = args[2].operator int(); + } + // ret_typ + param.ret_typ = String2ReturnType(args[3].operator std::string()); + // is_ascend + if (args[4].type_code() == kNull) { + param.is_ascend = false; + } else { + param.is_ascend = args[4].operator bool(); + } + // dtype + param.dtype = String2MXNetTypeWithBool(args[5].operator std::string()); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } +}); + +} // namespace mxnet diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index 87755ec0bcc5..0546142af338 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -73,6 +73,37 @@ struct LeakyReLUParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(upper_bound).set_default(0.334f) .describe("Upper bound of random slope. 
(For rrelu only)"); } + std::string ActType2String(int act_type) { + switch (act_type) { + case leakyrelu::kRReLU: + return "rrelu"; + case leakyrelu::kLeakyReLU: + return "leaky"; + case leakyrelu::kPReLU: + return "prelu"; + case leakyrelu::kELU: + return "elu"; + case leakyrelu::kSELU: + return "selu"; + case leakyrelu::kGELU: + return "gelu"; + default: + LOG(FATAL) << "Unknown act_type enum " << act_type; + } + LOG(FATAL) << "should not reach here "; + return ""; + } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream act_type_s, slope_s, lower_bound_s, upper_bound_s; + act_type_s << act_type; + slope_s << slope; + lower_bound_s << lower_bound; + upper_bound_s << upper_bound; + (*dict)["act_type"] = ActType2String(act_type); + (*dict)["slope"] = slope_s.str(); + (*dict)["lower_bound"] = lower_bound_s.str(); + (*dict)["upper_bound"] = upper_bound_s.str(); + } }; template diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index 2c309ffebd6d..d8c8dbc7a2f5 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -61,9 +61,18 @@ struct LayerNormParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(output_mean_var).set_default(false) .describe("Output the mean and std calculated along the given axis."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream axis_s, eps_s, output_mean_var_s; + axis_s << axis; + eps_s << eps; + output_mean_var_s << output_mean_var; + (*dict)["axis"] = axis_s.str(); + (*dict)["eps"] = eps_s.str(); + (*dict)["output_mean_var"] = output_mean_var_s.str(); + } }; -static int GetRealAxis(int axis, int ndim) { +inline int GetRealAxis(int axis, int ndim) { return axis < 0 ? (axis + ndim) : axis; } diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 95f4abc4bc68..c94ac437d789 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -82,6 +82,17 @@ struct EmbeddingParam: public dmlc::Parameter { .describe("Compute row sparse gradient in the backward calculation. If set to True, " "the grad's storage type is row_sparse."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream input_dim_s, output_dim_s, dtype_s, sparse_grad_s; + input_dim_s << input_dim; + output_dim_s << output_dim; + dtype_s << dtype; + sparse_grad_s << sparse_grad; + (*dict)["input_dim"] = input_dim_s.str(); + (*dict)["output_dim"] = output_dim_s.str(); + (*dict)["sparse_grad"] = sparse_grad_s.str(); + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype); + } }; /*! 
diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 26f24775f59e..db8d0625949b 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -95,6 +95,35 @@ struct TopKParam : public dmlc::Parameter { "An error will be raised if the selected data type cannot precisely represent the " "indices."); } + std::string ReturnType2String(int ret_typ) { + switch (ret_typ) { + case topk_enum::kReturnValue: + return "value"; + case topk_enum::kReturnIndices: + return "indices"; + case topk_enum::kReturnMask: + return "mask"; + case topk_enum::kReturnBoth: + return "both"; + default: + LOG(FATAL) << "Unknown return type enum " << ret_typ; + } + LOG(FATAL) << "should not reach here "; + return ""; + } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream axis_s, k_s, ret_typ_s, is_ascend_s, dtype_s; + axis_s << axis; + k_s << k; + dtype_s << dtype; + ret_typ_s << ret_typ; + is_ascend_s << is_ascend; + (*dict)["axis"] = axis_s.str(); + (*dict)["k"] = k_s.str(); + (*dict)["ret_typ"] = ReturnType2String(ret_typ); + (*dict)["is_ascend"] = is_ascend_s.str(); + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype); + } }; struct SortParam : public dmlc::Parameter { diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index b65610d05e1f..bcd018157791 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -602,3 +602,56 @@ def test_pixelshuffle3d(): [64, 88, 65, 89, 66, 90, 67, 91], [68, 92, 69, 93, 70, 94, 71, 95]]]]] ) + +@use_np +def test_embedding(): + def check_embedding(): + layer = gluon.nn.Embedding(10, 100) + layer.initialize() + x = mx.np.array([3,4,2,0,1]) + with mx.autograd.record(): + y = layer(x) + y.backward() + assert (layer.weight.grad().asnumpy()[:5] == 1).all() + assert (layer.weight.grad().asnumpy()[5:] == 0).all() + + def check_embedding_large_input(): + embedding = mx.gluon.nn.Embedding(10, 1) + embedding.initialize() + embedding.hybridize() + shape = (20481,) + with mx.autograd.record(): + emb_in = embedding(mx.np.ones(shape)) + loss = emb_in.sum() + loss.backward() + assert embedding.weight.grad().sum().item() == 20481 + + check_embedding() + check_embedding_large_input() + + +@use_np +@pytest.mark.parametrize('dshape', [(10, ), (2, 10, 10, 10)]) +def test_layernorm(dshape): + layer = nn.LayerNorm(in_channels=10) + print("checking layer {}\nshape: {}.".format(layer, dshape)) + layer.initialize() + x = mx.np.ones(shape=dshape) + x.attach_grad() + with mx.autograd.record(): + out = layer(x) + out.backward() + + np_out = out.asnumpy() + np_dx = x.grad.asnumpy() + + layer.hybridize() + + x = mx.np.ones(shape=dshape) + x.attach_grad() + with mx.autograd.record(): + out = layer(x) + out.backward() + + mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-5, atol=1e-6) + mx.test_utils.assert_almost_equal(np_dx, x.grad.asnumpy(), rtol=1e-5, atol=1e-6)
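+
+
+@use_np
+def test_npx_leaky_relu_topk_smoke():
+    # Sketch: minimal smoke checks for the npx.leaky_relu and npx.topk frontends
+    # registered above (illustrative only; expected values follow the documented
+    # semantics: slope * x for negative inputs, descending order for top-k).
+    x = mx.np.array([[-2.0, -1.0, 0.0, 1.0, 2.0]])
+    y = mx.npx.leaky_relu(x, act_type='leaky', slope=0.25)
+    assert (y.asnumpy() == [[-0.5, -0.25, 0.0, 1.0, 2.0]]).all()
+    # ret_typ='both' returns [values, indices]; indices default to float32.
+    val, idx = mx.npx.topk(x, ret_typ='both', k=2)
+    assert (val.asnumpy() == [[2.0, 1.0]]).all()
+    assert (idx.asnumpy() == [[4.0, 3.0]]).all()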