From ef953fe90e5b36f98037b59b83b4ffa7939cc5d0 Mon Sep 17 00:00:00 2001
From: Yiyan66 <57363390+Yiyan66@users.noreply.github.com>
Date: Mon, 27 Apr 2020 09:38:04 +0800
Subject: [PATCH 1/3] [v1.x] Backport of fix npx.softmax for 0-sized inputs
 (#18158)

Co-authored-by: Hao Jin
---
 src/operator/nn/softmax-inl.h                | 56 +++++++++++---------
 src/operator/numpy/np_boolean_mask_assign.cc |  6 ++-
 tests/python/unittest/test_numpy_op.py       | 38 +++++++++++++
 3 files changed, 72 insertions(+), 28 deletions(-)

diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index f8a3fe429c53..f1f41778a9bd 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -71,6 +71,7 @@ template
                     *s, DType *in, OType *out, IType *length,
                     Shape shape, int axis, const DType temperature) {
   index_t M = shape[axis];
+  if (M == 0) return;
   index_t N = shape.Size()/M;
   Shape stride = calc_stride(shape);
   Shape sshape = shape;
@@ -186,6 +187,7 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd,
                         DType *igrad, IType *length, Shape shape,
                         int axis, const DType temperature) {
   index_t M = shape[axis];
+  if (M == 0) return;
   index_t N = shape.Size()/M;
   Shape stride = calc_stride(shape);
   Shape sshape = shape;
@@ -402,6 +404,7 @@ inline void Softmax(Stream *s, DType *in, OType *out, IType *length,
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
   index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
   index_t N = shape.Size()/M;
   Shape stride = calc_stride(shape);
   Shape sshape = shape;
@@ -555,6 +558,7 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd,
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
   index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
   index_t N = shape.Size()/M;
   Shape stride = calc_stride(shape);
   Shape sshape = shape;
@@ -798,35 +802,35 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
         type = inputs[1].type_flag_;
       }
       MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
-      IType* mask_ptr = nullptr;
-      if (param.use_length.value()) {
-        mask_ptr = inputs[1].dptr();
+        IType* mask_ptr = nullptr;
+        if (param.use_length.value()) {
+          mask_ptr = inputs[1].dptr();
+        }
+        if (safe_acc) {
+          if (shape.ndim() == 2) {
+            Softmax(
+                ctx.get_stream(), inputs[0].dptr(),
+                outputs[0].dptr(), mask_ptr, shape.get<2>(),
+                axis, static_cast(temperature));
+          } else {
+            Softmax(
+                ctx.get_stream(), inputs[0].dptr(),
+                outputs[0].dptr(), mask_ptr, shape.get<3>(),
+                axis, static_cast(temperature));
          }
-      if (safe_acc) {
-        if (shape.ndim() == 2) {
-          Softmax(
-              ctx.get_stream(), inputs[0].dptr(),
-              outputs[0].dptr(), mask_ptr, shape.get<2>(),
-              axis, static_cast(temperature));
-        } else {
-          Softmax(
-              ctx.get_stream(), inputs[0].dptr(),
-              outputs[0].dptr(), mask_ptr, shape.get<3>(),
-              axis, static_cast(temperature));
-        }
+        } else {
+          if (shape.ndim() == 2) {
+            Softmax(
+                ctx.get_stream(), inputs[0].dptr(),
+                outputs[0].dptr(), mask_ptr, shape.get<2>(),
+                axis, static_cast(temperature));
          } else {
-        if (shape.ndim() == 2) {
-          Softmax(
-              ctx.get_stream(), inputs[0].dptr(),
-              outputs[0].dptr(), mask_ptr, shape.get<2>(),
-              axis, static_cast(temperature));
-        } else {
-          Softmax(
-              ctx.get_stream(), inputs[0].dptr(),
-              outputs[0].dptr(), mask_ptr, shape.get<3>(),
-              axis, static_cast(temperature));
-        }
+            Softmax(
+                ctx.get_stream(), inputs[0].dptr(),
+                outputs[0].dptr(), mask_ptr, shape.get<3>(),
+                axis, static_cast(temperature));
          }
+        }
       });
     });
   });
diff --git a/src/operator/numpy/np_boolean_mask_assign.cc b/src/operator/numpy/np_boolean_mask_assign.cc
index ef7cce4d2491..cc58b5afc015 100644
--- a/src/operator/numpy/np_boolean_mask_assign.cc
+++ b/src/operator/numpy/np_boolean_mask_assign.cc
@@ -221,10 +221,9 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
   // If there's no True in mask, return directly
   if (valid_num == 0) return;
 
-  const TShape& vshape = inputs[2].shape_;
-
   if (inputs.size() == 3U) {
     // tensor case
+    const TShape& vshape = inputs.at(2).shape_;
     if (inputs[2].shape_.Size() != 1) {
       auto vndim = vshape.ndim();
       auto dndim = dshape.ndim();
@@ -254,6 +253,8 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
   }
 
   if (inputs.size() == 3U) {
+    // tensor case
+    const TShape& vshape = inputs.at(2).shape_;
     MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
       if (inputs[2].shape_.Size() == 1) {
         Kernel, cpu>::Launch(
@@ -269,6 +270,7 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
       }
     });
   } else {
+    // scalar case
     CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value needs be provided";
     MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
       Kernel, cpu>::Launch(
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 5d3c03e53963..7dcaf72a4e75 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1558,6 +1558,44 @@ def _test_batchnorm_impl(shape, fix_gamma, cudnn_off, output_mean_var,
                              data_grad_req, gamma_grad_req, beta_grad_req)
 
 
+@with_seed()
+@use_np
+def test_npx_softmax():
+    class TestSoftmax(HybridBlock):
+        def __init__(self, axis):
+            super(TestSoftmax, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a):
+            return F.npx.softmax(a, axis=self._axis)
+
+    def np_softmax(x, axis=-1):
+        if (x.shape[axis] == 0):
+            return _np.sum(x, axis=axis, keepdims=True)
+        x = x - _np.max(x, axis=axis, keepdims=True)
+        x = _np.exp(x)
+        x /= _np.sum(x, axis=axis, keepdims=True)
+        return x
+
+    # only testing 0-size shaped inputs here, other input cases have been tested in test_operator.py
+    for hybridize in [True, False]:
+        for shape in [(3, 0, 4), (0, 0)]:
+            mx_a = np.random.uniform(size=shape)
+            mx_a.attach_grad()
+            for axis in range(-len(shape), len(shape)):
+                test_softmax = TestSoftmax(axis)
+                if hybridize:
+                    test_softmax.hybridize()
+
+                with mx.autograd.record():
+                    mx_out = test_softmax(mx_a)
+
+                np_out = np_softmax(mx_a.asnumpy(), axis)
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
+
+                mx_out.backward()
+                assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
+
 @with_seed()
 @use_np
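Illustrative sketch (not part of the patch above): the forward behaviour exercised by the new test_npx_softmax test can be reproduced directly. The snippet assumes a build containing this fix, with numpy semantics enabled via npx.set_np(), and mirrors the zero-size shapes used in the test.

import mxnet as mx
from mxnet import np, npx

npx.set_np()                            # enable numpy (zero-size) shape semantics

a = np.random.uniform(size=(3, 0, 4))   # a zero-size ndarray, as in the new test
out = npx.softmax(a, axis=-1)           # the case this patch fixes
print(out.shape)                        # (3, 0, 4): same shape, zero elements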
From 0787b56f6bd70dbb719352906b8a6c0a9765e348 Mon Sep 17 00:00:00 2001
From: bgawrych
Date: Wed, 1 Jul 2020 16:43:06 +0200
Subject: [PATCH 2/3] Fix softmax, logsoftmax failed on empty ndarray (#18602)

* Fix failing empty array (log_)softmax

* Modify test for npx (log_)softmax
---
 src/operator/nn/log_softmax.cc         |  1 +
 src/operator/nn/softmax-inl.h          |  2 +-
 src/operator/nn/softmax.cc             |  1 +
 tests/python/unittest/test_numpy_op.py | 46 ++++++++++++++++++--------
 4 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 16324b51c322..f3ef4abb9f6d 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -40,6 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector& inputs,
                                    const std::vector& req,
                                    const std::vector& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get(attrs.parsed);
   if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index f1f41778a9bd..018d851336d2 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -779,7 +779,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
                     const std::vector& req,
                     const std::vector& outputs) {
   using namespace mxnet_op;
-  if (req[0] == kNullOp) return;
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
   CHECK_NE(req[0], kAddTo);
   const SoftmaxParam& param = nnvm::get(attrs.parsed);
   int axis = CheckAxis(param.axis, inputs[0].ndim());
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 50cfc2f713f4..b95e159f9862 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -41,6 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector& inputs,
                                 const std::vector& req,
                                 const std::vector& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get(attrs.parsed);
   if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 7dcaf72a4e75..91f84bb27eb0 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1569,6 +1569,14 @@ def __init__(self, axis):
         def hybrid_forward(self, F, a):
             return F.npx.softmax(a, axis=self._axis)
 
+    class TestLogSoftmax(HybridBlock):
+        def __init__(self, axis):
+            super(TestLogSoftmax, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a):
+            return F.npx.log_softmax(a, axis=self._axis)
+
     def np_softmax(x, axis=-1):
         if (x.shape[axis] == 0):
             return _np.sum(x, axis=axis, keepdims=True)
@@ -1577,24 +1585,34 @@ def np_softmax(x, axis=-1):
         x /= _np.sum(x, axis=axis, keepdims=True)
         return x
 
+    def np_log_softmax(x, axis=-1):
+        return _np.log(np_softmax(x, axis))
+
+    # (operator, function) tuples
+    tested_ops = [(TestSoftmax, np_softmax),
+                  (TestLogSoftmax, np_log_softmax)]
+
     # only testing 0-size shaped inputs here, other input cases have been tested in test_operator.py
-    for hybridize in [True, False]:
-        for shape in [(3, 0, 4), (0, 0)]:
-            mx_a = np.random.uniform(size=shape)
-            mx_a.attach_grad()
-            for axis in range(-len(shape), len(shape)):
-                test_softmax = TestSoftmax(axis)
-                if hybridize:
-                    test_softmax.hybridize()
+    for SoftmaxOp, softmax_function in tested_ops:
+        for hybridize in [True, False]:
+            for shape in [(3, 0, 4), (0, 0)]:
+                mx_a = np.random.uniform(size=shape)
+                mx_a.attach_grad()
+                for axis in range(-len(shape), len(shape)):
+                    test_softmax_op = SoftmaxOp(axis)
+                    if hybridize:
+                        test_softmax_op.hybridize()
 
-                with mx.autograd.record():
-                    mx_out = test_softmax(mx_a)
+                    with mx.autograd.record():
+                        mx_out = test_softmax_op(mx_a)
 
-                np_out = np_softmax(mx_a.asnumpy(), axis)
-                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
+                    mx_out.wait_to_read()
 
-                mx_out.backward()
-                assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
+                    np_out = softmax_function(mx_a.asnumpy(), axis)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
+
+                    mx_out.backward()
+                    assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
 
 @with_seed()
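Illustrative sketch (not part of the patch above): the forward case this fix covers is (log_)softmax on a completely empty ndarray, which should now return an empty result instead of failing. Again this assumes npx.set_np() and one of the shapes used by the test.

import mxnet as mx
from mxnet import np, npx

npx.set_np()

x = np.random.uniform(size=(0, 0))        # entirely empty ndarray, one of the test shapes
print(npx.softmax(x, axis=-1).shape)      # (0, 0)
print(npx.log_softmax(x, axis=-1).shape)  # (0, 0)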
From 55cb01288602448ea169c2a65dfd04ddcb46a5d0 Mon Sep 17 00:00:00 2001
From: Bart Gawrych
Date: Mon, 20 Jul 2020 12:06:05 +0200
Subject: [PATCH 3/3] Fix softmax, logsoftmax backward failed on empty ndarray
 (#18710)

---
 src/operator/nn/log_softmax.cc         | 1 +
 src/operator/nn/softmax.cc             | 1 +
 tests/python/unittest/test_numpy_op.py | 1 +
 3 files changed, 3 insertions(+)

diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index f3ef4abb9f6d..28ae8cf361ec 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -58,6 +58,7 @@ static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                        const std::vector& inputs,
                                        const std::vector& req,
                                        const std::vector& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get(attrs.parsed);
   if (SupportMKLDNNLogSoftmax(param, inputs[1], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index b95e159f9862..9b28b71560bd 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -59,6 +59,7 @@ static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const std::vector& inputs,
                                     const std::vector& req,
                                     const std::vector& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get(attrs.parsed);
   if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 91f84bb27eb0..97c7d8675495 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1612,6 +1612,7 @@ def np_log_softmax(x, axis=-1):
                     assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
 
                     mx_out.backward()
+                    mx_a.grad.wait_to_read()
                     assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
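Illustrative sketch (not part of the series): the backward case addressed by this last patch, taking gradients through (log_)softmax on an empty ndarray, which the updated test expects to yield an empty, all-zero gradient. As above, this assumes npx.set_np() and mirrors the test's use of autograd.

import mxnet as mx
from mxnet import np, npx

npx.set_np()

x = np.random.uniform(size=(3, 0, 4))
x.attach_grad()
with mx.autograd.record():
    y = npx.log_softmax(x, axis=-1)
y.backward()                 # previously failed for empty inputs
print(x.grad.shape)          # (3, 0, 4): empty gradient, same shape as x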