Use cuDNN for conv bias and bias grad
Vladimir Cherepanov committed Dec 8, 2021
1 parent 40359ce commit 9d1df0f
Showing 4 changed files with 167 additions and 38 deletions.
105 changes: 105 additions & 0 deletions src/operator/cudnn_ops.cc
@@ -758,6 +758,111 @@ void ConvWgrad::Exec(const cudnn_cxx::Descriptor& plan,
CUDNN_CALL(cudnnBackendExecute(s->dnn_handle_, plan.get(), var_pack.get()));
}

struct LegacyTensorDestroyer {
using pointer = cudnnTensorDescriptor_t;

void operator()(cudnnTensorDescriptor_t desc) {
CUDNN_CALL_NONFATAL(cudnnDestroyTensorDescriptor(desc));
}
};

using LegacyTensor = std::unique_ptr<cudnnTensorDescriptor_t, LegacyTensorDestroyer>;

LegacyTensor MakeLegacyTensor() {
cudnnTensorDescriptor_t desc{};
CUDNN_CALL(cudnnCreateTensorDescriptor(&desc));
return LegacyTensor(desc);
}
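
A side note on the idiom above: because the destroyer defines a nested `pointer` alias, std::unique_ptr stores the opaque cuDNN handle by value (no extra indirection) and releases it exactly once, even on early returns. A standalone sketch of the same idiom, using a hypothetical FakeDesc handle in place of the cuDNN descriptor type:

#include <iostream>
#include <memory>

// Stand-in for an opaque C handle such as cudnnTensorDescriptor_t
// (which is itself a pointer to an opaque struct).
struct FakeDesc { int id; };
using FakeHandle = FakeDesc*;

struct FakeDestroyer {
  // This nested alias makes unique_ptr hold a FakeHandle directly,
  // not a FakeHandle*.
  using pointer = FakeHandle;

  void operator()(FakeHandle h) {
    // Real code: CUDNN_CALL_NONFATAL(cudnnDestroyTensorDescriptor(h));
    std::cout << "destroying handle " << h->id << "\n";
    delete h;
  }
};

using OwnedHandle = std::unique_ptr<FakeHandle, FakeDestroyer>;

int main() {
  // Real code: cudnnCreateTensorDescriptor(&h);
  FakeHandle h = new FakeDesc{42};
  OwnedHandle owned(h);  // released exactly once when owned leaves scope
  std::cout << "using handle " << owned.get()->id << "\n";
}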

union ScalingParam {
double d;
float f;
};

std::pair<ScalingParam, ScalingParam> AlphaBeta(int type_flag, double init_a, double init_b) {
ScalingParam a, b;
switch (type_flag) {
case kFloat64:
a.d = init_a;
b.d = init_b;
break;
case kFloat32: // fallthrough
case kFloat16:
a.f = init_a;
b.f = init_b;
break;
default:
LOG(FATAL) << "Unexpected type: " << type_flag;
}
return {a, b};
}
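
The legacy cuDNN API receives alpha and beta through const void* and reads them according to the tensor compute type: double for FP64 tensors, float for FP32 and FP16. That is why a two-member union suffices. A standalone sketch of this contract, with a hypothetical consume() standing in for a cuDNN entry point:

#include <cstdio>

union Scaling {
  double d;
  float f;
};

// Mimics how a legacy cuDNN call interprets the scaling factor: the
// pointed-to type is dictated by the tensor's data type, not by C++.
void consume(const void* alpha, bool is_double) {
  if (is_double)
    std::printf("read as double: %f\n", *static_cast<const double*>(alpha));
  else
    std::printf("read as float: %f\n", *static_cast<const float*>(alpha));
}

int main() {
  Scaling a;
  a.d = 1.0;  // kFloat64 path
  consume(&a, true);
  a.f = 1.0f;  // kFloat32 / kFloat16 path (cuDNN expects float for FP16 too)
  consume(&a, false);
}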

void SetLegacyTensor(cudnnTensorDescriptor_t desc, const TBlob& blob, const LayoutInfo& li) {
std::vector<int> dims(blob.shape_.ndim());
CHECK_EQ(dims.size(), li.n_space_dims + 2);
auto rev_order = ReverseOrder(li.Order());
for (size_t i = 0; i < dims.size(); ++i)
dims[i] = blob.shape_[rev_order[i]];
auto strides64 = li.Strides(std::vector<int64_t>(dims.begin(), dims.end()));
std::vector<int> strides(strides64.begin(), strides64.end());

auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
}

void SetLegacyCTensorExpandDims(cudnnTensorDescriptor_t desc,
const TBlob& blob,
const LayoutInfo& li) {
std::vector<int> dims(li.n_space_dims + 2, 1);
dims[1] = blob.shape_[0];
std::vector<int> strides(dims.size(), 1);
strides[0] = blob.shape_[0];

auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
}
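
For a 2-D convolution, the function above describes a bias of shape (C) as a 4-D tensor with dims {1, C, 1, 1} and strides {C, 1, 1, 1}, so logical element (0, c, 0, 0) lands at offset c regardless of the activation layout. A standalone sketch of that computation, assuming li.n_space_dims == 2 and C == 8:

#include <cstdio>
#include <vector>

int main() {
  const int n_space_dims = 2;  // H and W for a 2-D convolution
  const int C = 8;             // bias length, i.e. blob.shape_[0]

  std::vector<int> dims(n_space_dims + 2, 1);
  dims[1] = C;     // dims    = {1, 8, 1, 1}
  std::vector<int> strides(dims.size(), 1);
  strides[0] = C;  // strides = {8, 1, 1, 1}

  // Offset of logical element (0, c, 0, 0) is c * strides[1] == c.
  for (int c = 0; c < C; ++c)
    std::printf("bias[%d] -> offset %d\n", c, c * strides[1]);
}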

bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b) {
thread_local auto y_desc = MakeLegacyTensor();
thread_local auto b_desc = MakeLegacyTensor();

auto s = ctx.get_stream<gpu>();
auto [alpha, beta] = AlphaBeta(y.type_flag_, 1.0, 1.0); // NOLINT(whitespace/braces)

SetLegacyTensor(y_desc.get(), y, li);
SetLegacyCTensorExpandDims(b_desc.get(), b, li);

auto err =
cudnnAddTensor(s->dnn_handle_, &alpha, b_desc.get(), b.dptr_, &beta, y_desc.get(), y.dptr_);
if (err == CUDNN_STATUS_NOT_SUPPORTED)
return false;
CHECK_EQ(err, CUDNN_STATUS_SUCCESS);
return true;
}
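
With alpha == beta == 1, the cudnnAddTensor call above computes y += b broadcast over the channel axis. A CPU reference sketch of the same arithmetic for a packed NCHW tensor (an illustration of the semantics, not of cuDNN's implementation):

#include <vector>

// y has shape (N, C, H, W), packed row-major; b has shape (C).
// Computes y[n][c][h][w] += b[c], i.e. cudnnAddTensor with alpha = beta = 1.
void AddBiasNCHW(std::vector<float>* y, const std::vector<float>& b,
                 int N, int C, int H, int W) {
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int i = 0; i < H * W; ++i)
        (*y)[(n * C + c) * H * W + i] += b[c];
}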

bool LegacyBiasGrad(const OpContext& ctx,
const LayoutInfo& li,
bool add_to,
const TBlob& db,
const TBlob& dy) {
thread_local auto db_desc = MakeLegacyTensor();
thread_local auto dy_desc = MakeLegacyTensor();

auto s = ctx.get_stream<gpu>();
// NOLINT_NEXT_LINE(whitespace/braces)
auto [alpha, beta] = AlphaBeta(dy.type_flag_, 1.0, add_to ? 1.0 : 0.0);

SetLegacyCTensorExpandDims(db_desc.get(), db, li);
SetLegacyTensor(dy_desc.get(), dy, li);

auto err = cudnnConvolutionBackwardBias(
s->dnn_handle_, &alpha, dy_desc.get(), dy.dptr_, &beta, db_desc.get(), db.dptr_);
if (err == CUDNN_STATUS_NOT_SUPPORTED)
return false;
CHECK_EQ(err, CUDNN_STATUS_SUCCESS);
return true;
}
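
cudnnConvolutionBackwardBias reduces dy over every axis except the channel: db[c] = alpha * sum over n,h,w of dy[n][c][h][w] + beta * db[c], where beta = 1 gives the accumulate (kAddTo) behavior selected by add_to above. A CPU reference sketch for a packed NCHW tensor:

#include <vector>

// dy has shape (N, C, H, W), packed row-major; db has shape (C).
// add_to == true corresponds to beta = 1 (accumulate into db).
void BiasGradNCHW(const std::vector<float>& dy, std::vector<float>* db,
                  int N, int C, int H, int W, bool add_to) {
  for (int c = 0; c < C; ++c) {
    float sum = 0.0f;
    for (int n = 0; n < N; ++n)
      for (int i = 0; i < H * W; ++i)
        sum += dy[(n * C + c) * H * W + i];
    (*db)[c] = add_to ? (*db)[c] + sum : sum;
  }
}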

} // namespace cudnn
} // namespace op
} // namespace mxnet
8 changes: 8 additions & 0 deletions src/operator/cudnn_ops.h
@@ -246,6 +246,14 @@ struct ConvWgrad {
const TBlob& dw);
};

bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b);

bool LegacyBiasGrad(const OpContext& ctx,
const LayoutInfo& li,
bool add_to,
const TBlob& db,
const TBlob& dy);

} // namespace cudnn
} // namespace op
} // namespace mxnet
44 changes: 25 additions & 19 deletions src/operator/nn/convolution.cu
@@ -57,14 +57,17 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
   if (ok && !param.no_bias) {
     CHECK_EQ(inputs[conv::kBias].shape_.ndim(), 1);
     auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
-    int k = inputs[conv::kBias].shape_.Size();
-    auto b = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
-    BinaryBroadcastRTCCompute{"add"}(  // NOLINT(whitespace/braces)
-        attrs,
-        ctx,
-        {outputs[conv::kOut], b},
-        {kWriteInplace},
-        {outputs[conv::kOut]});
+    auto li = cudnn::GetLayoutInfo(layout);
+    if (!cudnn::LegacyAddBias(ctx, li, outputs[conv::kOut], inputs[conv::kBias])) {
+      int k = inputs[conv::kBias].shape_.Size();
+      auto b = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
+      BinaryBroadcastRTCCompute{"add"}(  // NOLINT(whitespace/braces)
+          attrs,
+          ctx,
+          {outputs[conv::kOut], b},
+          {kWriteInplace},
+          {outputs[conv::kOut]});
+    }
   }
   if (!ok) {
     if (!param.cudnn_off)
@@ -137,17 +140,20 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
          cudnn::Exec<cudnn::ConvWgrad>(
              ctx, conv_param, inputs[1 + conv::kData], inputs[0], outputs[conv::kWeight]));
   if (ok && !param.no_bias && req[conv::kBias] != kNullOp) {
-    auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
-    if (li.channel_last) {
-      // This kernel should be faster.
-      auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
-      AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx);
-    } else {
-      TShape axes{static_cast<int>(li.ChannelIdx())};
-      TShape small =
-          ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
-      ReduceAxesRTCComputeImpl(
-          ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, small, "red::sum{}");
-    }
+    auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
+    auto add_to = req[conv::kBias] == kAddTo;
+    if (!cudnn::LegacyBiasGrad(ctx, li, add_to, outputs[conv::kBias], inputs[0])) {
+      if (li.channel_last) {
+        // This kernel should be faster.
+        auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
+        AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx);
+      } else {
+        TShape axes{static_cast<int>(li.ChannelIdx())};
+        TShape small = ReduceAxesShapeImpl(
+            inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
+        ReduceAxesRTCComputeImpl(
+            ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, small, "red::sum{}");
+      }
+    }
   }
   if (!ok) {
48 changes: 29 additions & 19 deletions src/operator/nn/deconvolution.cu
@@ -56,14 +56,17 @@ void DeconvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
   if (ok && !param.no_bias) {
     CHECK_EQ(inputs[deconv::kBias].shape_.ndim(), 1);
     auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
-    int k = inputs[deconv::kBias].shape_.Size();
-    auto b = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
-    BinaryBroadcastRTCCompute{"add"}(  // NOLINT(whitespace/braces)
-        attrs,
-        ctx,
-        {outputs[deconv::kOut], b},
-        {kWriteInplace},
-        {outputs[deconv::kOut]});
+    auto li = cudnn::GetLayoutInfo(layout);
+    if (!cudnn::LegacyAddBias(ctx, li, outputs[deconv::kOut], inputs[deconv::kBias])) {
+      int k = inputs[deconv::kBias].shape_.Size();
+      auto b = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
+      BinaryBroadcastRTCCompute{"add"}(  // NOLINT(whitespace/braces)
+          attrs,
+          ctx,
+          {outputs[deconv::kOut], b},
+          {kWriteInplace},
+          {outputs[deconv::kOut]});
+    }
   }
   if (!ok) {
     if (!param.cudnn_off)
@@ -115,17 +118,24 @@ void DeconvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
          cudnn::Exec<cudnn::ConvWgrad>(
              ctx, conv_param, inputs[0], inputs[1 + deconv::kData], outputs[deconv::kWeight]));
   if (ok && !param.no_bias && req[deconv::kBias] != kNullOp) {
-    auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
-    if (li.channel_last) {
-      // This kernel should be faster.
-      auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
-      AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx);
-    } else {
-      TShape axes{static_cast<int>(li.ChannelIdx())};
-      TShape small =
-          ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
-      ReduceAxesRTCComputeImpl(
-          ctx, {inputs[0]}, {req[deconv::kBias]}, {outputs[deconv::kBias]}, small, "red::sum{}");
-    }
+    auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
+    auto add_to = req[deconv::kBias] == kAddTo;
+    if (!cudnn::LegacyBiasGrad(ctx, li, add_to, outputs[deconv::kBias], inputs[0])) {
+      if (li.channel_last) {
+        // This kernel should be faster.
+        auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
+        AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx);
+      } else {
+        TShape axes{static_cast<int>(li.ChannelIdx())};
+        TShape small = ReduceAxesShapeImpl(
+            inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
+        ReduceAxesRTCComputeImpl(ctx,
+                                 {inputs[0]},
+                                 {req[deconv::kBias]},
+                                 {outputs[deconv::kBias]},
+                                 small,
+                                 "red::sum{}");
+      }
+    }
   }
   if (!ok) {
