
[v1.x Backport] Fix softmax, logsoftmax failed on empty ndarray (#18602) #18708

Merged 3 commits on Aug 3, 2020
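Context for the change: before this fix, calling softmax or log_softmax on an ndarray with a zero-size dimension could fail instead of returning an empty result. Below is a minimal repro sketch of the case this backport addresses, assuming MXNet 1.x with the numpy/npx frontend enabled; the shapes mirror those used in the new test, and the snippet is illustrative rather than taken from this PR.

import mxnet as mx
from mxnet import np, npx

npx.set_np()

a = np.random.uniform(size=(3, 0, 4))  # zero-size dimension along axis 1
out = npx.softmax(a, axis=1)           # previously failed on empty input
out.wait_to_read()
print(out.shape)                       # expected: (3, 0, 4), an empty array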
2 changes: 2 additions & 0 deletions src/operator/nn/log_softmax.cc
@@ -40,6 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -57,6 +58,7 @@ static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[1], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
58 changes: 31 additions & 27 deletions src/operator/nn/softmax-inl.h
@@ -71,6 +71,7 @@ template<typename OP, bool negate, typename AType, typename DType, typename OTyp
inline void Softmax(Stream<cpu> *s, DType *in, OType *out, IType *length,
Shape<ndim> shape, int axis, const DType temperature) {
index_t M = shape[axis];
if (M == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -186,6 +187,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
DType *igrad, IType *length, Shape<ndim> shape,
int axis, const DType temperature) {
index_t M = shape[axis];
if (M == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -402,6 +404,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
if (M == 0 || shape.Size() == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -555,6 +558,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
if (M == 0 || shape.Size() == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -775,7 +779,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (req[0] == kNullOp) return;
if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
CHECK_NE(req[0], kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
@@ -798,35 +802,35 @@ type = inputs[1].type_flag_;
type = inputs[1].type_flag_;
}
MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
  IType* mask_ptr = nullptr;
  if (param.use_length.value()) {
    mask_ptr = inputs[1].dptr<IType>();
  }
  if (safe_acc) {
    if (shape.ndim() == 2) {
      Softmax<OP, negate, AType>(
          ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
          outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
          axis, static_cast<DType>(temperature));
    } else {
      Softmax<OP, negate, AType>(
          ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
          outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
          axis, static_cast<DType>(temperature));
    }
  } else {
    if (shape.ndim() == 2) {
      Softmax<OP, negate, DType>(
          ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
          outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
          axis, static_cast<DType>(temperature));
    } else {
      Softmax<OP, negate, DType>(
          ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
          outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
          axis, static_cast<DType>(temperature));
    }
  }
});
});
});
2 changes: 2 additions & 0 deletions src/operator/nn/softmax.cc
@@ -41,6 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -58,6 +59,7 @@ static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
6 changes: 4 additions & 2 deletions src/operator/numpy/np_boolean_mask_assign.cc
@@ -221,10 +221,9 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
// If there's no True in mask, return directly
if (valid_num == 0) return;

const TShape& vshape = inputs[2].shape_;

if (inputs.size() == 3U) {
// tensor case
const TShape& vshape = inputs.at(2).shape_;
if (inputs[2].shape_.Size() != 1) {
auto vndim = vshape.ndim();
auto dndim = dshape.ndim();
@@ -254,6 +253,8 @@
}

if (inputs.size() == 3U) {
// tensor case
const TShape& vshape = inputs.at(2).shape_;
MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
if (inputs[2].shape_.Size() == 1) {
Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
@@ -269,6 +270,7 @@
}
});
} else {
// scalar case
CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value needs be provided";
MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
57 changes: 57 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -1558,6 +1558,63 @@ def _test_batchnorm_impl(shape, fix_gamma, cudnn_off, output_mean_var,
data_grad_req,
gamma_grad_req, beta_grad_req)

@with_seed()
@use_np
def test_npx_softmax():
class TestSoftmax(HybridBlock):
def __init__(self, axis):
super(TestSoftmax, self).__init__()
self._axis = axis

def hybrid_forward(self, F, a):
return F.npx.softmax(a, axis=self._axis)

class TestLogSoftmax(HybridBlock):
def __init__(self, axis):
super(TestLogSoftmax, self).__init__()
self._axis = axis

def hybrid_forward(self, F, a):
return F.npx.log_softmax(a, axis=self._axis)

def np_softmax(x, axis=-1):
if (x.shape[axis] == 0):
return _np.sum(x, axis=axis, keepdims=True)
x = x - _np.max(x, axis=axis, keepdims=True)
x = _np.exp(x)
x /= _np.sum(x, axis=axis, keepdims=True)
return x

def np_log_softmax(x, axis=-1):
return _np.log(np_softmax(x, axis))

# (operator, function) tuples
tested_ops = [(TestSoftmax, np_softmax),
(TestLogSoftmax, np_log_softmax)]

# only testing zero-size inputs here; other input cases are covered in test_operator.py
for SoftmaxOp, softmax_function in tested_ops:
for hybridize in [True, False]:
for shape in [(3, 0, 4), (0, 0)]:
mx_a = np.random.uniform(size=shape)
mx_a.attach_grad()
for axis in range(-len(shape), len(shape)):
test_softmax_op = SoftmaxOp(axis)
if hybridize:
test_softmax_op.hybridize()

with mx.autograd.record():
mx_out = test_softmax_op(mx_a)

mx_out.wait_to_read()

np_out = softmax_function(mx_a.asnumpy(), axis)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)

mx_out.backward()
mx_a.grad.wait_to_read()
assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)


@with_seed()
@use_np