Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick] support data_format='NHWC' for prelu channel mode #38495

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions paddle/fluid/inference/tensorrt/convert/prelu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class PReluOpConverter : public OpConverter {
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get attrs
std::string mode = BOOST_GET_CONST(std::string, op_desc.GetAttr("mode"));
std::string data_format = "NCHW";
if (op_desc.HasAttr("data_format")) {
data_format =
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_format"));
}
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();

Expand All @@ -47,7 +52,7 @@ class PReluOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic(
alpha_data, alpha_tensor_temp->numel(), mode);
alpha_data, alpha_tensor_temp->numel(), mode, data_format);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else {
#if IS_TRT_VERSION_GE(7000)
Expand All @@ -74,8 +79,8 @@ class PReluOpConverter : public OpConverter {
layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input,
*alpha_layer_output);
#else
plugin::PReluPlugin* plugin =
new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode);
plugin::PReluPlugin* plugin = new plugin::PReluPlugin(
alpha_data, alpha_tensor_temp->numel(), mode, data_format);
layer = engine_->AddPlugin(&input, input_num, plugin);
#endif
}
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
}

if (mode_ == "channel") {
bool channel_last = data_format_ == "NHWC";
operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel);
input_dims.d[1], channel_last, numel);
} else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise;
Expand Down Expand Up @@ -168,10 +169,11 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
}

if (mode_ == "channel") {
bool channel_last = data_format_ == "NHWC";
operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel);
input_dims.d[1], channel_last, numel);
} else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise;
Expand Down
22 changes: 15 additions & 7 deletions paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ class PReluPlugin : public PluginTensorRT {
std::vector<float> weight_;
float* p_gpu_weight_;
std::string mode_;
std::string data_format_;

public:
size_t getSerializationSize() const TRT_NOEXCEPT override {
return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
SerializedSize(weight_);
SerializedSize(data_format_.c_str()) + SerializedSize(weight_);
}

// TRT will call this func when we need to serialize the configuration of
Expand All @@ -46,11 +47,12 @@ class PReluPlugin : public PluginTensorRT {
serializeBase(buffer);
SerializeValue(&buffer, weight_);
SerializeValue(&buffer, mode_.c_str());
SerializeValue(&buffer, data_format_.c_str());
}

PReluPlugin(const float* weight, const int weight_num,
std::string const& mode)
: mode_(mode) {
std::string const& mode, std::string const& data_format)
: mode_(mode), data_format_(data_format) {
weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data());
}
Expand All @@ -63,13 +65,17 @@ class PReluPlugin : public PluginTensorRT {
const char* prelu_mode;
DeserializeValue(&serialData, &serialLength, &prelu_mode);
mode_ = std::string(prelu_mode);
const char* prelu_data_format;
DeserializeValue(&serialData, &serialLength, &prelu_data_format);
data_format_ = std::string(prelu_data_format);
}
~PReluPlugin() {}
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;

PReluPlugin* clone() const TRT_NOEXCEPT override {
auto* ptr = new PReluPlugin(weight_.data(), weight_.size(), mode_);
auto* ptr =
new PReluPlugin(weight_.data(), weight_.size(), mode_, data_format_);
ptr->p_gpu_weight_ = p_gpu_weight_;
return ptr;
}
Expand Down Expand Up @@ -108,16 +114,17 @@ REGISTER_TRT_PLUGIN_V2(PReluPluginCreator);
class PReluPluginDynamic : public DynamicPluginTensorRT {
public:
PReluPluginDynamic(const float* weight, const int weight_num,
std::string const& mode)
: mode_(mode) {
std::string const& mode, std::string const& data_format)
: mode_(mode), data_format_(data_format) {
weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data());
}

PReluPluginDynamic(void const* serialData, size_t serialLength);
~PReluPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_);
auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_,
data_format_);
ptr->p_gpu_weight_ = p_gpu_weight_;
return ptr;
}
Expand Down Expand Up @@ -167,6 +174,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
std::vector<float> weight_;
float* p_gpu_weight_;
std::string mode_;
std::string data_format_;
};
#endif

Expand Down
33 changes: 26 additions & 7 deletions paddle/fluid/operators/math/prelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ inline static int PADDLE_GET_BLOCKS(const int N) {
}

template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num,
size_t plane_size, size_t numel) {
__global__ void PReluChannelFirstWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num,
size_t plane_size, size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
Expand All @@ -38,6 +38,19 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
}
}

// PReLU kernel for channel-wise slopes when the channel is the innermost
// dimension of the flat buffer (the "channel last" case, selected by callers
// for data_format == "NHWC").
//
// For every flat position i visited by the loop:
//   output[i] = input[i]                            if input[i] > 0
//   output[i] = alpha[i % channel_num] * input[i]   otherwise
//
// input       - source values, indexed by flat position.
// alpha       - slopes, indexed by (flat position % channel_num).
// output      - destination, indexed by the same flat position as input.
// channel_num - number of channels; used as the modulus to pick the slope.
// numel       - bound handed to CUDA_KERNEL_LOOP.
//
// NOTE(review): CUDA_KERNEL_LOOP is defined outside this view; presumably it
// iterates `idx` over [0, numel) across the launched threads - confirm.
template <typename T>
__global__ void PReluChannelLastWiseKernel(const T *input, const T *alpha,
                                           T *output, size_t channel_num,
                                           size_t numel) {
  CUDA_KERNEL_LOOP(idx, numel) {
    // Channel is the fastest-varying index, so a plain modulo recovers it.
    const T slope = alpha[idx % channel_num];
    const T value = input[idx];
    output[idx] = (value > static_cast<T>(0)) ? value : slope * value;
  }
}

template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha,
T *output, size_t spatial_size,
Expand Down Expand Up @@ -65,10 +78,16 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
gpuStream_t stream, const T *input, const T *alpha, T *output,
size_t batch_size, size_t channel, size_t numel) {
PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, channel,
numel / batch_size / channel, numel);
size_t batch_size, size_t channel, bool channel_last, size_t numel) {
if (channel_last) {
PReluChannelLastWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, channel,
numel);
} else {
PReluChannelFirstWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(
input, alpha, output, channel, numel / batch_size / channel, numel);
}
}

template <typename T>
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/math/prelu.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ template <typename T>
class PreluChannelWiseDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output,
size_t batch_size, size_t channel, size_t numel);
size_t batch_size, size_t channel, bool channel_last,
size_t numel);
};

template <typename T>
Expand Down
19 changes: 14 additions & 5 deletions paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PReluMKLDNNHandler
const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* weights,
const std::string& uniq_name, const std::string& mode,
bool is_test = false)
const std::string& data_format, bool is_test = false)
: platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
Expand All @@ -49,8 +49,13 @@ class PReluMKLDNNHandler
if (weights->dims().size() != x->dims().size()) {
auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
if (mode == "channel") {
new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
if (data_format == "NHWC") {
new_weights_dims[x->dims().size() - 1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
} else {
new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
}
}
weights_dims = std::move(new_weights_dims);
}
Expand Down Expand Up @@ -110,9 +115,11 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out");
const bool is_test = ctx.Attr<bool>("is_test");
const auto mode = ctx.Attr<std::string>("mode");
const auto data_format = ctx.Attr<std::string>("data_format");

PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, ctx.InputName("X"), mode, is_test);
alpha, ctx.InputName("X"), mode, data_format,
is_test);

auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p =
Expand Down Expand Up @@ -149,9 +156,11 @@ class PReluGradMKLDNNKernel : public framework::OpKernel<T> {
auto* alpha = ctx.Input<Tensor>("Alpha");
const bool is_test = ctx.Attr<bool>("is_test");
const auto mode = ctx.Attr<std::string>("mode");
const auto data_format = ctx.Attr<std::string>("data_format");

PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, framework::GradVarName("X"), mode);
alpha, framework::GradVarName("X"), mode,
data_format);

auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p =
Expand Down
36 changes: 30 additions & 6 deletions paddle/fluid/operators/prelu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,40 @@ class PReluOp : public framework::OperatorWithKernel {
"But recevied alpha's size: %d.",
product(ctx->GetInputDim("Alpha"))));
} else if (mode == "channel") {
PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1],
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank, 2,
platform::errors::InvalidArgument(
"For mode 'channel', rank of input X must be "
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
const std::string data_format_str =
ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ(data_format_str == "NCHW" || data_format_str == "NHWC",
true,
platform::errors::InvalidArgument(
"For mode 'channel', data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format_str));
if (data_format_str == "NCHW") {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
} else {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[x_rank - 1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[%d]: %d",
product(ctx->GetInputDim("Alpha")), x_rank - 1,
x_dim[x_rank - 1]));
}

} else if (mode == "element") {
auto alpha_dim = ctx->GetInputDim("Alpha");
auto alpha_rank = alpha_dim.size();
Expand Down Expand Up @@ -134,6 +155,9 @@ There are modes:
)DOC");
AddAttr<std::string>("mode", "The mode for inputs to share weights.")
.SetDefault("all");
AddAttr<std::string>("data_format",
"Data format that specifies the layout of input")
.SetDefault("NCHW");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
Expand Down
Loading