fix a bug in cudnn softmax activation. (apache#10918)
zheng-da committed Jun 28, 2018
1 parent e632fb5 commit cb5c8d3
Showing 2 changed files with 30 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/operator/nn/cudnn/cudnn_softmax_activation-inl.h
@@ -48,7 +48,7 @@ class CuDNNSoftmaxActivationOp {
   }
 
   void Forward(const OpContext &ctx, const TBlob &in_data,
-      const OpReqType &req, const TBlob &out_data) {
+               const OpReqType &req, const TBlob &out_data) {
     using namespace mshadow;
     using namespace mshadow::expr;
     Stream<gpu> *s = ctx.get_stream<gpu>();
@@ -102,14 +102,14 @@ class CuDNNSoftmaxActivationOp {
   }
 
   void Backward(const OpContext &ctx, const TBlob &out_grad,
-      const TBlob &out_data, const OpReqType &req, const TBlob &in_grad) {
+                const TBlob &out_data, const OpReqType &req,
+                const TBlob &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
     float alpha = 1.0f;
     float beta = 0.0f;
     Stream<gpu> *s = ctx.get_stream<gpu>();
     Tensor<gpu, 4> grad;
-    Tensor<gpu, 4> data;
     Tensor<gpu, 4> output_data;
     Tensor<gpu, 4> input_grad;
     cudnnSoftmaxMode_t softmax_mode;
@@ -141,6 +141,13 @@ class CuDNNSoftmaxActivationOp {
       softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL;
     }
     CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(shape_desc_,
+                                          CUDNN_TENSOR_NCHW,
+                                          dtype_,
+                                          input_grad.shape_[0],
+                                          input_grad.shape_[1],
+                                          input_grad.shape_[2],
+                                          input_grad.shape_[3]));
     CUDNN_CALL(cudnnSoftmaxBackward(s->dnn_handle_,
                                     CUDNN_SOFTMAX_ACCURATE,
                                     softmax_mode,
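For context on what the added descriptor call fixes: cudnnSoftmaxBackward computes the softmax input gradient over whatever extents shape_desc_ currently describes, and before this change Backward relied on the descriptor state it found (presumably set during Forward). Setting the descriptor from input_grad immediately before the call makes the backward pass independent of that state. The math cuDNN evaluates here can be sketched in NumPy for instance mode (one softmax per row); this is an illustrative reference only, not MXNet or cuDNN API:

import numpy as np

def softmax(x):
    # instance mode: one softmax per row of a (batch, channel) array
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_backward(y, dy):
    # dx_i = y_i * (dy_i - sum_j dy_j * y_j): the softmax input gradient,
    # which is what cudnnSoftmaxBackward writes into the gradient tensor
    return y * (dy - (dy * y).sum(axis=-1, keepdims=True))

If the descriptor describes the wrong shape, the row-wise reduction effectively runs over the wrong extents, which is the kind of silent GPU/CPU gradient mismatch the new test below guards against.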
20 changes: 20 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -1855,6 +1855,26 @@ def test_create_sparse_ndarray_gpu_to_cpu():
     assert(same(rsp_copy.asnumpy(), rsp_created.asnumpy()))
 
 
+@with_seed()
+def test_softmax_activation():
+    gpu_a = mx.nd.array([[3., 0.5, -0.5, 2., 7.],
+                         [2., -.4, 7., 3., 0.2]], ctx=mx.gpu(0))
+    cpu_a = mx.nd.array([[3., 0.5, -0.5, 2., 7.],
+                         [2., -.4, 7., 3., 0.2]], ctx=mx.cpu())
+
+    cpu_a.attach_grad()
+    gpu_a.attach_grad()
+    with mx.autograd.record():
+        gpu_y = mx.nd.SoftmaxActivation(data = gpu_a)
+        cpu_y = mx.nd.SoftmaxActivation(data = cpu_a)
+        assert_almost_equal(cpu_y.asnumpy(), gpu_y.asnumpy(), atol = 1e-3, rtol = 1e-3)
+
+        gpu_y.backward()
+        cpu_y.backward()
+        assert_almost_equal(cpu_a.grad.asnumpy(), gpu_a.grad.asnumpy(),
+                            atol = 1e-3, rtol = 1e-3)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
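On a machine with a CUDA/cuDNN build of MXNet and at least one visible GPU, the new case can be run on its own with nose's file:function selector; a minimal sketch, assuming the working directory is the repository root:

# run only the new GPU softmax-activation test
import nose
nose.run(argv=['nosetests', 'tests/python/gpu/test_operator_gpu.py:test_softmax_activation'])

The bare nosetests command line with the same selector works equivalently.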
