This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FFI] fix masked_softmax #20114

Merged · 6 commits · Apr 2, 2021
38 changes: 8 additions & 30 deletions python/mxnet/ndarray/numpy_extension/_op.py
@@ -19,7 +19,6 @@
used in Gluon dispatched by F=ndarray module."""

import numpy as _np
from .. import numpy as np # pylint: disable=reimported
from .._internal import NDArrayBase
from . import _api_internal
from ...util import set_module
@@ -134,7 +133,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False,

# pylint: disable=too-many-arguments
@set_module('mxnet.ndarray.numpy_extension')
def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
def masked_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
r"""Applies the softmax function masking elements according to the mask provided

Parameters
@@ -147,9 +146,6 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
The axis along which to compute softmax.
temperature : double or None, optional, default=1.0
Temperature parameter in softmax
dtype : {None, 'float16', 'float32', 'float64'},optional, default='None'
DType of the output in case this can't be inferred. Defaults to
the same as input's dtype if not defined (dtype=None).
normalize : boolean or None, optional, default=True
Whether to normalize input data x: x = x - max(x)

@@ -167,22 +163,15 @@
>>> data = np.arange(10).reshape((2, 5))
>>> npx.masked_softmax(data, mask, axis=0)
array([[0.00669285, 0. , 0.00669285, 0. , 0.00669285],
[0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]])
[0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]])
"""
if mask is not None:
neg = -1e18
if _np.dtype(dtype) == _np.float16:
neg = -1e4
data = np.where(mask, data, neg)
logits = (softmax(data, axis=axis) / temperature) * mask
else:
logits = softmax(data, axis=axis) / temperature
return logits
assert data is not None and mask is not None, "Missing input data and mask"
return _api_internal.masked_softmax(data, mask, axis, temperature, normalize)
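For reference (not part of the diff): the new one-line body delegates to the C++ operator through the FFI, replacing the removed Python fallback shown above. Below is a minimal NumPy sketch of the assumed semantics — the `-1e18` fill value mirrors the old fallback, and the helper name `masked_softmax_ref` is hypothetical:

```python
import numpy as np

def masked_softmax_ref(data, mask, axis=-1, temperature=1.0, normalize=True):
    """Minimal sketch of the assumed masked_softmax semantics."""
    x = np.where(mask, data, -1e18)              # large negative fill for masked slots
    if normalize:
        x = x - x.max(axis=axis, keepdims=True)  # stabilize: x = x - max(x)
    e = np.exp(x / temperature)
    return e / e.sum(axis=axis, keepdims=True) * mask  # masked positions become exactly 0
```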


# pylint: disable=too-many-arguments
@set_module('mxnet.ndarray.numpy_extension')
def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
def masked_log_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
r"""Computes the masked log softmax of the input.
This is equivalent to computing masked softmax followed by log.

@@ -196,9 +185,6 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
The axis along which to compute softmax.
temperature : double or None, optional, default=1.0
Temperature parameter in softmax
dtype : {None, 'float16', 'float32', 'float64'},optional, default='None'
DType of the output in case this can't be inferred. Defaults to
the same as input's dtype if not defined (dtype=None).
normalize : boolean or None, optional, default=True
Whether to normalize input data x: x = x - max(x)

@@ -216,18 +202,10 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
>>> data = np.arange(10).reshape((2, 5))
>>> npx.masked_log_softmax(data, mask, axis=0)
array([[-5.0067153 , -inf, -5.0067153 , -inf, -5.0067153 ],
[-0.00671535, -inf, -0.00671535, -inf, -0.00671535]])
[-0.00671535, -inf, -0.00671535, -inf, -0.00671535]])
"""
if mask is not None:
neg = -1e18
inf = -_np.inf
if _np.dtype(dtype) == _np.float16:
neg = -1e4
data = np.where(mask, data, neg)
logits = np.where(mask, log_softmax(data, axis=axis) / temperature, inf)
else:
logits = log_softmax(data, axis=axis) / temperature
return logits
assert data is not None and mask is not None, "Missing input data and mask"
return _api_internal.masked_log_softmax(data, mask, axis, temperature, normalize)
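Likewise, a standalone sketch of the assumed masked_log_softmax semantics; it computes the log-sum-exp directly, so masked positions come out as -inf without ever taking log(0). The helper name is again hypothetical:

```python
import numpy as np

def masked_log_softmax_ref(data, mask, axis=-1, temperature=1.0, normalize=True):
    """Minimal sketch of the assumed masked_log_softmax semantics."""
    x = np.where(mask, data, -1e18)
    if normalize:
        x = x - x.max(axis=axis, keepdims=True)
    x = x / temperature
    lse = np.log(np.exp(x).sum(axis=axis, keepdims=True))  # log-sum-exp along axis
    return np.where(mask, x - lse, -np.inf)                # -inf where masked
```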


# pylint: disable=too-many-arguments, unused-argument
18 changes: 6 additions & 12 deletions python/mxnet/numpy_extension/_op.py
@@ -118,7 +118,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False,

# pylint: disable=too-many-arguments
@set_module('mxnet.numpy_extension')
def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
def masked_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
r"""Applies the softmax function masking elements according to the mask provided

Parameters
@@ -131,9 +131,6 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
The axis along which to compute softmax.
temperature : double or None, optional, default=1.0
Temperature parameter in softmax
dtype : {None, 'float16', 'float32', 'float64'},optional, default='None'
DType of the output in case this can't be inferred. Defaults to
the same as input's dtype if not defined (dtype=None).
normalize : boolean or None, optional, default=True
Whether to normalize input data x: x = x - max(x)

@@ -151,15 +148,15 @@
>>> data = np.arange(10).reshape((2, 5))
>>> npx.masked_softmax(data, mask, axis=0)
array([[0.00669285, 0. , 0.00669285, 0. , 0.00669285],
[0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]])
[0.9933072 , 0. , 0.9933072 , 0. , 0.9933072 ]])
"""
return _mx_nd_npx.masked_softmax(data, mask, axis=axis, temperature=temperature,
dtype=dtype)
normalize=normalize)


# pylint: disable=too-many-arguments
@set_module('mxnet.numpy_extension')
def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
def masked_log_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
r"""Computes the masked log softmax of the input.
This is equivalent to computing masked softmax followed by log.

@@ -173,9 +170,6 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
The axis along which to compute softmax.
temperature : double or None, optional, default=1.0
Temperature parameter in softmax
dtype : {None, 'float16', 'float32', 'float64'},optional, default='None'
DType of the output in case this can't be inferred. Defaults to
the same as input's dtype if not defined (dtype=None).
normalize : boolean or None, optional, default=True
Whether to normalize input data x: x = x - max(x)

@@ -193,10 +187,10 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
>>> data = np.arange(10).reshape((2, 5))
>>> npx.masked_log_softmax(data, mask, axis=0)
array([[-5.0067153 , -inf, -5.0067153 , -inf, -5.0067153 ],
[-0.00671535, -inf, -0.00671535, -inf, -0.00671535]])
[-0.00671535, -inf, -0.00671535, -inf, -0.00671535]])
"""
return _mx_nd_npx.masked_log_softmax(data, mask, axis=axis, temperature=temperature,
dtype=dtype)
normalize=normalize)


# pylint: disable=too-many-arguments, unused-argument
100 changes: 96 additions & 4 deletions src/api/operator/numpy_extension/npx_softmax_op.cc
@@ -53,15 +53,17 @@ MXNET_REGISTER_API("_npx.softmax")
// parse axis
if (args[args_size - 4].type_code() == kDLInt) {
param.axis = args[args_size - 4].operator int();
} else {
} else if (args[args_size - 4].type_code() == kDLFloat) {
param.axis = static_cast<int>(args[args_size - 4].operator double());
} else {
param.axis = -1;
}

// parse temperature
if (args[args_size - 3].type_code() == kNull) {
param.temperature = dmlc::nullopt;
} else {
param.temperature = args[args_size - 3].operator int64_t();
param.temperature = args[args_size - 3].operator double();
}

// parse dtype
@@ -106,15 +108,17 @@ MXNET_REGISTER_API("_npx.log_softmax")
// parse axis
if (args[args_size - 4].type_code() == kDLInt) {
param.axis = args[args_size - 4].operator int();
} else {
} else if (args[args_size - 4].type_code() == kDLFloat) {
param.axis = static_cast<int>(args[args_size - 4].operator double());
} else {
param.axis = -1;
}

// parse temperature
if (args[args_size - 3].type_code() == kNull) {
param.temperature = dmlc::nullopt;
} else {
param.temperature = args[args_size - 3].operator int64_t();
param.temperature = args[args_size - 3].operator double();
}

// parse dtype
@@ -133,4 +137,92 @@ MXNET_REGISTER_API("_npx.log_softmax")
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npx.masked_softmax")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
nnvm::NodeAttrs attrs;
static const nnvm::Op* op = Op::Get("_npx_masked_softmax");
op::MaskedSoftmaxParam param;

// inputs
int num_inputs = 2;
std::vector<NDArray*> inputs;
inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
inputs.push_back(args[i].operator mxnet::NDArray*());
}
// parse axis
if (args[2].type_code() == kDLInt) {
param.axis = args[2].operator int();
} else if (args[2].type_code() == kDLFloat) {
param.axis = static_cast<int>(args[2].operator double());
} else {
param.axis = -1;
}
// parse temperature
if (args[3].type_code() == kNull) {
param.temperature = dmlc::nullopt;
} else {
param.temperature = args[3].operator double();
}
// parse normalize
if (args[4].type_code() == kNull) {
param.normalize = true;
} else {
param.normalize = args[4].operator bool();
}

attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::MaskedSoftmaxParam>(&attrs);

int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npx.masked_log_softmax")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
nnvm::NodeAttrs attrs;
static const nnvm::Op* op = Op::Get("_npx_masked_log_softmax");
op::MaskedSoftmaxParam param;

// inputs
int num_inputs = 2;
std::vector<NDArray*> inputs;
inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
inputs.push_back(args[i].operator mxnet::NDArray*());
}
// parse axis
if (args[2].type_code() == kDLInt) {
param.axis = args[2].operator int();
} else if (args[2].type_code() == kDLFloat) {
param.axis = static_cast<int>(args[2].operator double());
} else {
param.axis = -1;
}
// parse temperature
if (args[3].type_code() == kNull) {
param.temperature = dmlc::nullopt;
} else {
param.temperature = args[3].operator double();
}
// parse normalize
if (args[4].type_code() == kNull) {
param.normalize = true;
} else {
param.normalize = args[4].operator bool();
}

attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::MaskedSoftmaxParam>(&attrs);

int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
*ret = ndoutputs[0];
});
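These two registrations parse temperature with `operator double()` from the start — the same fix applied above to `_npx.softmax` and `_npx.log_softmax`, which previously used `operator int64_t()` — so fractional temperatures survive the FFI boundary, and axis is accepted as either an int or a float. A usage sketch under that assumption; the boolean mask construction is illustrative, not taken from the PR:

```python
from mxnet import np, npx

data = np.arange(10).reshape((2, 5))
mask = np.array([[1, 0, 1, 0, 1],
                 [1, 0, 1, 0, 1]], dtype=bool)

# A fractional temperature would previously have been truncated by an
# int64_t conversion in the FFI layer; here it is parsed as a double.
out = npx.masked_softmax(data, mask, axis=0, temperature=0.5)
log_out = npx.masked_log_softmax(data, mask, axis=0, temperature=0.5)
```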

} // namespace mxnet
10 changes: 9 additions & 1 deletion src/operator/nn/softmax-inl.h
@@ -1199,7 +1199,6 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
int axis;
dmlc::optional<double> temperature;
dmlc::optional<int> dtype;
dmlc::optional<bool> normalize;
DMLC_DECLARE_PARAMETER(MaskedSoftmaxParam) {
DMLC_DECLARE_FIELD(axis).set_default(-1)
@@ -1210,6 +1209,15 @@ struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
.set_default(dmlc::optional<bool>(true))
.describe("Whether to normalize input data x: x = x - max(x)");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, temperature_s, normalize_s;
axis_s << axis;
temperature_s << temperature;
normalize_s << normalize;
(*dict)["axis"] = axis_s.str();
(*dict)["temperature"] = temperature_s.str();
(*dict)["normalize"] = normalize_s.str();
}
};

static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {