diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc deleted file mode 100644 index c948315189a15..0000000000000 --- a/paddle/fluid/operators/gru_op.cc +++ /dev/null @@ -1,594 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/gru_op.h" - -#include -#include - -#include "paddle/common/flags.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h" -#include "paddle/phi/kernels/funcs/detail/gru_kernel.h" - -COMMON_DECLARE_int32(paddle_num_threads); - -namespace paddle { -namespace operators { - -class GRUOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU"); - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU"); - OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU"); - bool is_test = ctx->Attrs().Get("is_test"); - if (!is_test) { - OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU"); - OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), - "Output", - "BatchResetHiddenPrev", - "GRU"); - OP_INOUT_CHECK( - ctx->HasOutput("BatchHidden"), "Output", "BatchHidden", "GRU"); - } - auto input_dims = ctx->GetInputDim("Input"); - auto weight_dims = ctx->GetInputDim("Weight"); - int input_size = static_cast(input_dims[1]); - int frame_size = static_cast(weight_dims[0]); - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(input_size, - frame_size * 3, - phi::errors::InvalidArgument( - "The second dimension of Input(Input) must be 3 " - "times of frame_size in GRUOp, but received %d " - "(Input) vs %d (frame_size).", - input_size, - frame_size)); - } - PADDLE_ENFORCE_EQ( - weight_dims[1], - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", - weight_dims[0], - weight_dims[1], - frame_size, - frame_size * 3)); - if (ctx->HasInput("H0")) { - auto h0_dims = ctx->GetInputDim("H0"); - PADDLE_ENFORCE_EQ( - h0_dims[1], - frame_size, - phi::errors::InvalidArgument( - "The width of Input(H0) must be equal to frame_size, but " - "received %d (width of H0) vs %d (frame_size).", - h0_dims[1], - frame_size)); - } - if (ctx->HasInput("Bias")) { - auto bias_dims = ctx->GetInputDim("Bias"); - int bias_height = static_cast(bias_dims[0]); - int bias_width = static_cast(bias_dims[1]); - PADDLE_ENFORCE_EQ( - bias_height, - 1, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - bias_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) 
vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - } - if (!is_test) { - ctx->SetOutputDim("BatchGate", input_dims); - ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size}); - ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size}); - } - ctx->SetOutputDim("Hidden", {input_dims[0], frame_size}); - ctx->ShareLoD("Input", "Hidden"); - } -}; - -class GRUOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "Input", - "(phi::DenseTensor) The first input is a LodTensor, which supports " - "variable-time length input sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T X 3D), where, T is " - "the total time steps in this mini-batch, D is the hidden size."); - AddInput("H0", - "(Tensor, optional) The initial hidden state is an optional " - "input. This is a tensor with shape (N x D), where N is the " - "batch size, D is the hidden size.") - .AsDispensable(); - AddInput( - "Weight", - "(Tensor) The learnable hidden-hidden weight matrix with shape " - "(D x 3D), where D is the hidden size. The elements continuous in " - "memory can be divided into two parts. The first part are weights of " - "the update gate and reset gate with shape (D x 2D), and the second " - "part are weights of output candidate with shape (D x D)."); - AddInput("Bias", - "(Tensor, optional) Bias vector with shape (1 x 3D) concatenating " - "bias of the update gate, reset gate and output candidate.") - .AsDispensable(); - AddOutput( - "BatchGate", - "(phi::DenseTensor) To compute with batches, sequence data will be " - "reorganized into several successive batches each containing " - "data from the same time step. The phi::DenseTensor BatchGate contains " - "the update gate, reset gate and output candidate values " - "organized in batches. The LoD size is 2. The first LoD contains " - "the batch offsets and the second LoD contains the indexes in " - "the raw sequence data.") - .AsIntermediate() - .AsExtra(); - AddOutput("BatchResetHiddenPrev", - "(phi::DenseTensor) The reset hidden state phi::DenseTensor " - "organized in batches. " - "This phi::DenseTensor is a matrix with shape (T X D) and has " - "the same LoD " - "with `BatchGate`.") - .AsIntermediate() - .AsExtra(); - AddOutput("BatchHidden", - "(phi::DenseTensor) The hidden state phi::DenseTensor organized " - "in batches. " - "This phi::DenseTensor is a matrix with shape (T X D) and has " - "the same LoD " - "with `BatchGate`.") - .AsIntermediate() - .AsExtra(); - AddOutput("Hidden", - "(phi::DenseTensor) the hidden state phi::DenseTensor organized " - "in sequences. 
" - "This phi::DenseTensor is a matrix with shape (T X D) and has " - "the same LoD with `BatchGate`."); - AddAttr("activation", - "(string, default tanh) " - "The activation type used for output candidate {h}_t.") - .SetDefault("tanh"); - AddAttr( - "gate_activation", - "(string, default sigmoid) " - "The activation type used in update gate and reset gate.") - .SetDefault("sigmoid"); - AddAttr("is_reverse", - "(bool, default: False) " - "whether to compute reversed GRU.") - .SetDefault(false); - AddAttr("origin_mode", - "bool" - "use origin mode in article https://arxiv.org/abs/1412.3555") - .SetDefault(false); - AddComment(R"DOC( -GRU Operator implements part calculations of the complete GRU as following: - -$$ -update\_gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\ -reset\_gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\ -output\_candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\ -output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t) -$$ - -@note To implement the complete GRU, fully-connected operator must be used -before to feed xu, xr and xc as the Input of GRU operator. -)DOC"); - } -}; - -class GRUGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU@Grad"); - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU@Grad"); - OP_INOUT_CHECK( - ctx->HasInput("BatchGate"), "Input", "BatchGate", "GRU@Grad"); - OP_INOUT_CHECK(ctx->HasInput("BatchResetHiddenPrev"), - "Input", - "BatchResetHiddenPrev", - "GRU@Grad"); - OP_INOUT_CHECK( - ctx->HasInput("BatchHidden"), "Input", "BatchHidden", "GRU@Grad"); - OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "GRU@Grad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), - "Input", - framework::GradVarName("Hidden"), - "GRU@Grad"); - - auto input_dims = ctx->GetInputDim("Input"); - auto weight_dims = ctx->GetInputDim("Weight"); - int input_size = static_cast(input_dims[1]); - int frame_size = static_cast(weight_dims[0]); - int weight_height = static_cast(weight_dims[0]); - int weight_width = static_cast(weight_dims[1]); - PADDLE_ENFORCE_EQ( - input_size, - frame_size * 3, - phi::errors::InvalidArgument( - "The second dimension of Input(Input) must be 3 times of " - "frame_size in GRUOp, but received %d (Input) vs %d (frame_size).", - input_size, - frame_size)); - PADDLE_ENFORCE_EQ( - weight_height, - frame_size, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", - weight_height, - weight_width, - frame_size, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - weight_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", - weight_height, - weight_width, - frame_size, - frame_size * 3)); - if (ctx->HasInput("H0")) { - auto h0_dims = ctx->GetInputDim("H0"); - PADDLE_ENFORCE_EQ( - h0_dims[1], - frame_size, - phi::errors::InvalidArgument( - "The width of Input(H0) must be equal to frame_size, but " - "received %d (width of H0) vs %d (frame_size).", - h0_dims[1], - frame_size)); - auto h0_grad_name = framework::GradVarName("H0"); - if (ctx->HasOutput(h0_grad_name)) - ctx->SetOutputDim(h0_grad_name, h0_dims); - } - if (ctx->HasInput("Bias")) { - auto 
bias_dims = ctx->GetInputDim("Bias"); - int bias_height = static_cast(bias_dims[0]); - int bias_width = static_cast(bias_dims[1]); - PADDLE_ENFORCE_EQ( - bias_height, - 1, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - bias_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - auto bias_grad_name = framework::GradVarName("Bias"); - if (ctx->HasOutput(bias_grad_name)) - ctx->SetOutputDim(bias_grad_name, bias_dims); - } - auto input_grad_name = framework::GradVarName("Input"); - if (ctx->HasOutput(input_grad_name)) - ctx->SetOutputDim(input_grad_name, input_dims); - auto weight_grad_name = framework::GradVarName("Weight"); - if (ctx->HasOutput(weight_grad_name)) - ctx->SetOutputDim(weight_grad_name, weight_dims); - } - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Hidden")), - ctx.device_context().GetPlace()); - } -}; - -template -class GRUCPUKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - using LodTensorPtr = phi::DenseTensor*; - bool is_test = context.Attr("is_test"); - - bool origin_mode = context.Attr("origin_mode"); - auto* input = context.Input("Input"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* bias = context.Input("Bias"); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - auto input_dims = input->dims(); - auto hidden_dims = hidden->dims(); - - LodTensorPtr batch_gate = nullptr, batch_reset_hidden_prev = nullptr, - batch_hidden = nullptr; - phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, - batch_hidden_tmp; - if (is_test) { - batch_gate = &batch_gate_tmp; - batch_gate->Resize(input_dims); - - batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; - batch_reset_hidden_prev->Resize(hidden_dims); - - batch_hidden = &batch_hidden_tmp; - batch_hidden->Resize(hidden_dims); - } else { - batch_gate = context.Output("BatchGate"); - batch_hidden = context.Output("BatchHidden"); - batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - } - batch_gate->mutable_data(context.GetPlace()); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - batch_hidden->mutable_data(context.GetPlace()); - - bool is_reverse = context.Attr("is_reverse"); - phi::funcs::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = context.template device_context(); - to_batch(dev_ctx, *input, batch_gate, true, is_reverse); - - if (bias) { - phi::funcs::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, *bias, batch_gate); - } - - int frame_size = static_cast(hidden_dims[1]); - phi::funcs::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - phi::DenseTensor ordered_h0; - - phi::Vector order(batch_gate->lod()[2]); - - if (h0) { - // Since the batch computing for GRU reorders the input sequences - // according to their length. The initialized cell state also needs - // to reorder. 
- ReorderInitState( - context.template device_context(), - *h0, - order, - &ordered_h0, - true); - gru_value.prev_out_value = ordered_h0.data(); - } else { - gru_value.prev_out_value = nullptr; - } - auto batch_starts = batch_gate->lod()[0]; - size_t seq_len = batch_starts.size() - 1; - auto active_node = phi::funcs::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = phi::funcs::detail::GetActivationType( - context.Attr("gate_activation")); - -#ifdef PADDLE_WITH_MKLML - // use MKL packed to speedup GEMM - if (FLAGS_paddle_num_threads >= 4) { - auto blas = phi::funcs::GetBlas(dev_ctx); - T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, - 1 /*height of C*/, - frame_size * 2 /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE_NOT_NULL( - packed_gate, - phi::errors::NotFound( - "The calculation result of packed_gate by " - "GEMM_ALLOC should not be null when using MKL.")); - blas.GEMM_PACK(CblasBMatrix, - CblasNoTrans, - 1 /*cur bs?*/, - frame_size * 2, - frame_size, - T(1.0), - gru_value.gate_weight, - frame_size * 2, - packed_gate); - T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, - 1 /*height of C*/, - frame_size /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE_NOT_NULL( - packed_state, - phi::errors::NotFound( - "The calculation result of packed_state by " - "GEMM_ALLOC should not be null when using MKL.")); - blas.GEMM_PACK(CblasBMatrix, - CblasNoTrans, - 1 /*cur bs?*/, - frame_size, - frame_size, - T(1.0), - gru_value.state_weight, - frame_size, - packed_state); - for (size_t n = 0; n < seq_len; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); - phi::DenseTensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, - CblasPacked, - cur_batch_size, - frame_size * 2, - frame_size, - gru_value.prev_out_value, - frame_size, - packed_gate, - frame_size * 2, - T(1), - gru_value.gate_value, - frame_size * 3); - } - - phi::funcs::detail::forward_reset_output( - phi::funcs::detail::forward::gru_resetOutput(), - gru_value, - frame_size, - cur_batch_size, - active_gate); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, - CblasPacked, - cur_batch_size, - frame_size, - frame_size, - gru_value.reset_output_value, - frame_size, - packed_state, - frame_size, - T(1), - gru_value.gate_value + frame_size * 2, - frame_size * 3); - } - - phi::funcs::detail::forward_final_output( - phi::funcs::detail::forward::gru_finalOutput(), - gru_value, - frame_size, - cur_batch_size, - active_node, - origin_mode); - - gru_value.prev_out_value = gru_value.output_value; - } - - blas.GEMM_FREE(packed_gate); - blas.GEMM_FREE(packed_state); - } else { -#endif - for (size_t n = 0; n < seq_len; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); - phi::DenseTensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - 
gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - - phi::funcs::GRUUnitFunctor::compute( - dev_ctx, // NOLINT - gru_value, - frame_size, - cur_batch_size, - active_node, - active_gate, - origin_mode); - - gru_value.prev_out_value = gru_value.output_value; - } -#ifdef PADDLE_WITH_MKLML - } -#endif - phi::funcs::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batch_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden); - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - -template -class GRUGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("gru_grad"); - grad_op->SetInput("Input", this->Input("Input")); - grad_op->SetInput("H0", this->Input("H0")); - grad_op->SetInput("Bias", this->Input("Bias")); - grad_op->SetInput("Weight", this->Input("Weight")); - - grad_op->SetInput("BatchGate", this->Output("BatchGate")); - grad_op->SetInput("BatchResetHiddenPrev", - this->Output("BatchResetHiddenPrev")); - grad_op->SetInput("BatchHidden", this->Output("BatchHidden")); - grad_op->SetInput("Hidden", this->Output("Hidden")); - - grad_op->SetInput(framework::GradVarName("Hidden"), - this->OutputGrad("Hidden")); - - grad_op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0")); - grad_op->SetOutput(framework::GradVarName("Input"), - this->InputGrad("Input")); - grad_op->SetOutput(framework::GradVarName("Weight"), - this->InputGrad("Weight")); - grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); - - grad_op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUGradOpNoNeedBufferVarInferer, - "Input", - "Bias"); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(gru, - ops::GRUOp, - ops::GRUOpMaker, - ops::GRUGradOpMaker, - ops::GRUGradOpMaker); -REGISTER_OPERATOR(gru_grad, - ops::GRUGradOp, - ops::GRUGradOpNoNeedBufferVarInferer); - -PD_REGISTER_STRUCT_KERNEL( - gru, CPU, ALL_LAYOUT, ops::GRUCPUKernel, float, double) {} -PD_REGISTER_STRUCT_KERNEL( - gru_grad, CPU, ALL_LAYOUT, ops::GRUGradKernel, float, double) {} diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc deleted file mode 100644 index f7b4317832b0e..0000000000000 --- a/paddle/fluid/operators/gru_op.cu.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/gru_op.h" - -namespace paddle { -namespace operators { - -template -class GRUKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - using LodTensorPtr = phi::DenseTensor*; - - bool is_test = context.Attr("is_test"); - bool origin_mode = context.Attr("origin_mode"); - auto* input = context.Input("Input"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* bias = context.Input("Bias"); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - auto input_dims = input->dims(); - auto hidden_dims = hidden->dims(); - - LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden; - phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, - batch_hidden_tmp; - if (is_test) { - batch_gate = &batch_gate_tmp; - batch_gate->Resize(input_dims); - - batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; - batch_reset_hidden_prev->Resize(hidden_dims); - - batch_hidden = &batch_hidden_tmp; - batch_hidden->Resize(hidden_dims); - } else { - batch_gate = context.Output("BatchGate"); - batch_hidden = context.Output("BatchHidden"); - batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - } - batch_gate->mutable_data(context.GetPlace()); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - batch_hidden->mutable_data(context.GetPlace()); - - bool is_reverse = context.Attr("is_reverse"); - phi::funcs::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = context.template device_context(); - to_batch(dev_ctx, *input, batch_gate, true, is_reverse); - - if (bias) { - phi::funcs::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, *bias, batch_gate); - } - - int frame_size = hidden_dims[1]; - phi::funcs::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - phi::DenseTensor ordered_h0; - - phi::Vector order(batch_gate->lod()[2]); - - if (h0) { - // Since the batch computing for GRU reorders the input sequences - // according to their length. The initialized cell state also needs - // to reorder. 
- ReorderInitState( - context.template device_context(), - *h0, - order, - &ordered_h0, - true); - gru_value.prev_out_value = ordered_h0.data(); - } else { - gru_value.prev_out_value = nullptr; - } - auto batch_starts = batch_gate->lod()[0]; - size_t num_batch = batch_starts.size() - 1; - auto active_node = phi::funcs::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = phi::funcs::detail::GetActivationType( - context.Attr("gate_activation")); - for (size_t n = 0; n < num_batch; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); - phi::DenseTensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - phi::funcs::GRUUnitFunctor::compute(dev_ctx, // NOLINT - gru_value, - frame_size, - cur_batch_size, - active_node, - active_gate, - origin_mode); - gru_value.prev_out_value = gru_value.output_value; - } - - phi::funcs::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batch_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden); - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(gru, GPU, ALL_LAYOUT, ops::GRUKernel, float, double) { -} -PD_REGISTER_STRUCT_KERNEL( - gru_grad, GPU, ALL_LAYOUT, ops::GRUGradKernel, float, double) {} diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h deleted file mode 100644 index 773e9ff510852..0000000000000 --- a/paddle/fluid/operators/gru_op.h +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/detail/activation_functions.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/gru_compute.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/sequence2batch.h" - -namespace paddle { -namespace operators { - -template -inline void ReorderInitState(const DeviceContext& ctx, - const phi::DenseTensor& src, - phi::Vector index_lod, - phi::DenseTensor* dst, - bool indexed_src) { - phi::funcs::CopyMatrixRowsFunctor row_shuffle; - dst->mutable_data(src.dims(), ctx.GetPlace()); - row_shuffle(ctx, src, index_lod, dst, indexed_src); -} - -template -class GRUGradKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - bool origin_mode = context.Attr("origin_mode"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* batch_gate = context.Input("BatchGate"); - auto* batch_reset_hidden_prev = - context.Input("BatchResetHiddenPrev"); - auto* batch_hidden = context.Input("BatchHidden"); - auto* hidden = context.Input("Hidden"); - auto* hidden_grad = - context.Input(framework::GradVarName("Hidden")); - auto* input_grad = - context.Output(framework::GradVarName("Input")); - auto* h0_grad = - context.Output(framework::GradVarName("H0")); - auto* weight_grad = - context.Output(framework::GradVarName("Weight")); - auto* bias_grad = - context.Output(framework::GradVarName("Bias")); - - auto gate_dims = batch_gate->dims(); - auto hidden_dims = hidden->dims(); - int frame_size = hidden_dims[1]; - - phi::funcs::LoDTensor2BatchFunctor to_batch; - phi::DenseTensor batch_hidden_grad, batch_gate_grad, - batch_reset_hidden_prev_grad; - batch_hidden_grad.mutable_data(hidden_dims, context.GetPlace()); - batch_gate_grad.mutable_data(gate_dims, context.GetPlace()); - batch_reset_hidden_prev_grad.mutable_data(hidden_dims, - context.GetPlace()); - phi::funcs::SetConstant zero; - auto& dev_ctx = context.template device_context(); - zero(dev_ctx, &batch_hidden_grad, static_cast(0.0)); - zero(dev_ctx, &batch_gate_grad, static_cast(0.0)); - zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast(0.0)); - - phi::DenseTensor ordered_h0, ordered_h0_grad; - - phi::Vector order(batch_gate->lod()[2]); - - if (h0) { - ReorderInitState( - dev_ctx, *h0, order, &ordered_h0, true); - } - if (h0_grad) { - ordered_h0_grad.mutable_data(h0_grad->dims(), context.GetPlace()); - zero(context.template device_context(), - &ordered_h0_grad, - static_cast(0.0)); - } - - bool is_reverse = context.Attr("is_reverse"); - batch_hidden_grad.set_lod(batch_hidden->lod()); - to_batch(dev_ctx, *hidden_grad, &batch_hidden_grad, false, is_reverse); - - phi::funcs::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - - phi::funcs::GRUMetaGrad gru_grad; - if (weight_grad) { - gru_grad.gate_weight_grad = - weight_grad->mutable_data(context.GetPlace()); - zero(dev_ctx, weight_grad, static_cast(0.0)); - gru_grad.state_weight_grad = - weight_grad->data() + 2 * frame_size * frame_size; - } else { - gru_grad.gate_weight_grad = nullptr; - gru_grad.state_weight_grad = nullptr; - } - - auto batch_starts = batch_hidden_grad.lod()[0]; - size_t num_batch = batch_starts.size() - 1; - auto active_node = phi::funcs::detail::GetActivationType( - 
context.Attr("activation")); - auto active_gate = phi::funcs::detail::GetActivationType( - context.Attr("gate_activation")); - for (int n = static_cast(num_batch) - 1; n >= 0; n--) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); - gru_value.gate_value = gate_t.data(); - phi::DenseTensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - - phi::DenseTensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend); - gru_grad.output_grad = hidden_grad_t.data(); - phi::DenseTensor gate_grad_t = batch_gate_grad.Slice(bstart, bend); - gru_grad.gate_grad = gate_grad_t.data(); - phi::DenseTensor reset_hidden_prev_grad_t = - batch_reset_hidden_prev_grad.Slice(bstart, bend); - gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data(); - if (n == 0) { - gru_value.prev_out_value = h0 ? ordered_h0.data() : nullptr; - gru_grad.prev_out_grad = - h0 && h0_grad ? ordered_h0_grad.data() : nullptr; - } else { - int bstart_pre = static_cast(batch_starts[n - 1]); - phi::DenseTensor hidden_prev_t = - batch_hidden->Slice(bstart_pre, bstart); - gru_value.prev_out_value = hidden_prev_t.data(); - phi::DenseTensor hidden_prev_grad_t = - batch_hidden_grad.Slice(bstart_pre, bstart); - gru_grad.prev_out_grad = hidden_prev_grad_t.data(); - } - gru_value.output_value = nullptr; - phi::funcs::GRUUnitGradFunctor::compute(dev_ctx, - gru_value, - gru_grad, - frame_size, - cur_batch_size, - active_node, - active_gate, - origin_mode); - } - if (input_grad) { - input_grad->mutable_data(context.GetPlace()); - phi::funcs::Batch2LoDTensorFunctor to_seq; - batch_gate_grad.set_lod(batch_gate->lod()); - to_seq(dev_ctx, batch_gate_grad, input_grad); - } - if (bias_grad) { - bias_grad->mutable_data(context.GetPlace()); - phi::funcs::ColwiseSum col_sum; - col_sum(dev_ctx, batch_gate_grad, bias_grad); - } - if (h0 && h0_grad) { - ReorderInitState( - dev_ctx, ordered_h0_grad, order, h0_grad, false); - } - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index c6c145019f40b..4fdf26168ea49 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -574,6 +574,101 @@ void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) { } } +void GruGradInferMeta(const MetaTensor& input, + const MetaTensor& h0, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* input_grad, + MetaTensor* h0_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad, + MetaConfig config) { + const auto& input_dims = input.dims(); + const auto& weight_dims = weight.dims(); + int input_size = static_cast(input_dims[1]); + int frame_size = static_cast(weight_dims[0]); + int weight_height = static_cast(weight_dims[0]); + int weight_width = static_cast(weight_dims[1]); + PADDLE_ENFORCE_EQ( + input_size, + frame_size * 3, + phi::errors::InvalidArgument( + "The second dimension of Input(Input) must be 3 times of " + "frame_size in GRUOp, but received %d (Input) vs %d (frame_size).", + input_size, + frame_size)); + PADDLE_ENFORCE_EQ( + weight_height, + frame_size, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3], but received [%d, %d] (Weight) vs 
[%d, %d] (frame_size).", + weight_height, + weight_width, + frame_size, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + weight_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", + weight_height, + weight_width, + frame_size, + frame_size * 3)); + if (h0.initialized()) { + const auto& h0_dims = h0.dims(); + PADDLE_ENFORCE_EQ( + h0_dims[1], + frame_size, + phi::errors::InvalidArgument( + "The width of Input(H0) must be equal to frame_size, but " + "received %d (width of H0) vs %d (frame_size).", + h0_dims[1], + frame_size)); + if (h0_grad->initialized()) { + h0_grad->set_dims(h0_dims); + h0_grad->set_dtype(h0.dtype()); + } + } + if (bias.initialized()) { + const auto& bias_dims = bias.dims(); + int bias_height = static_cast(bias_dims[0]); + int bias_width = static_cast(bias_dims[1]); + PADDLE_ENFORCE_EQ( + bias_height, + 1, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + bias_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + if (bias_grad->initialized()) { + bias_grad->set_dims(bias_dims); + bias_grad->set_dtype(bias.dtype()); + } + } + if (input_grad->initialized()) { + input_grad->set_dims(input_dims); + input_grad->set_dtype(input.dtype()); + } + if (weight_grad->initialized()) { + weight_grad->set_dims(weight_dims); + weight_grad->set_dtype(weight.dtype()); + } +} + void GumbelSoftmaxGradInferMeta(const MetaTensor& out, const MetaTensor& dout, int axis, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 39b59958d6752..e538609d87c93 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -272,6 +272,16 @@ void GeneralQuinaryGradInferMeta(const MetaTensor& x, MetaTensor* dk, MetaTensor* dl); +void GruGradInferMeta(const MetaTensor& input, + const MetaTensor& h0, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* input_grad, + MetaTensor* h0_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad, + MetaConfig config = MetaConfig()); + void GumbelSoftmaxGradInferMeta(const MetaTensor& out, const MetaTensor& dout, int axis, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 65de2d4e3ce21..02b051b11d4e1 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2375,6 +2375,91 @@ void GraphReindexInferMeta(const MetaTensor& x, out_nodes->set_dtype(x.dtype()); } +void GruInferMeta(const MetaTensor& input, + const MetaTensor& h0, + const MetaTensor& weight, + const MetaTensor& bias, + const std::string& activation, + const std::string& gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + MetaTensor* batch_gate, + MetaTensor* batch_reset_hidden_prev, + MetaTensor* batch_hidden, + MetaTensor* hidden, + MetaConfig config) { + const auto& input_dims = input.dims(); + const auto& weight_dims = weight.dims(); + int input_size = static_cast(input_dims[1]); + int frame_size = static_cast(weight_dims[0]); + if (config.is_runtime) { + PADDLE_ENFORCE_EQ(input_size, + frame_size * 3, + phi::errors::InvalidArgument( + "The second dimension of Input(Input) must be 3 " 
+ "times of frame_size in GRUOp, but received %d " + "(Input) vs %d (frame_size).", + input_size, + frame_size)); + } + PADDLE_ENFORCE_EQ( + weight_dims[1], + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", + weight_dims[0], + weight_dims[1], + frame_size, + frame_size * 3)); + if (h0.initialized()) { + const auto& h0_dims = h0.dims(); + PADDLE_ENFORCE_EQ( + h0_dims[1], + frame_size, + phi::errors::InvalidArgument( + "The width of Input(H0) must be equal to frame_size, but " + "received %d (width of H0) vs %d (frame_size).", + h0_dims[1], + frame_size)); + } + if (bias.initialized()) { + const auto& bias_dims = bias.dims(); + int bias_height = static_cast(bias_dims[0]); + int bias_width = static_cast(bias_dims[1]); + PADDLE_ENFORCE_EQ( + bias_height, + 1, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + bias_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + } + if (!is_test) { + batch_gate->set_dims(input_dims); + batch_gate->set_dtype(input.dtype()); + batch_reset_hidden_prev->set_dims({input_dims[0], frame_size}); + batch_reset_hidden_prev->set_dtype(input.dtype()); + batch_hidden->set_dims({input_dims[0], frame_size}); + batch_hidden->set_dtype(input.dtype()); + } + hidden->set_dims({input_dims[0], frame_size}); + hidden->set_dtype(input.dtype()); + hidden->share_lod(input); +} + void GraphSampleNeighborsInferMeta(const MetaTensor& row, const MetaTensor& col_ptr, const MetaTensor& x, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index d60f8b0f3c443..a71d891ea70c9 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -492,6 +492,21 @@ void GraphReindexInferMeta(const MetaTensor& x, MetaTensor* reindex_dst, MetaTensor* out_nodes); +void GruInferMeta(const MetaTensor& input, + const MetaTensor& h0, + const MetaTensor& weight, + const MetaTensor& bias, + const std::string& activation, + const std::string& gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + MetaTensor* batch_gate, + MetaTensor* batch_reset_hidden_prev, + MetaTensor* batch_hidden, + MetaTensor* hidden, + MetaConfig config = MetaConfig()); + void GraphSampleNeighborsInferMeta(const MetaTensor& row, const MetaTensor& col_ptr, const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/gru_grad_kernel.cc b/paddle/phi/kernels/cpu/gru_grad_kernel.cc new file mode 100644 index 0000000000000..6581de6e8a01c --- /dev/null +++ b/paddle/phi/kernels/cpu/gru_grad_kernel.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_kernel_impl.h" + +PD_REGISTER_KERNEL( + gru_grad, CPU, ALL_LAYOUT, phi::GRUGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/gru_kernel.cc b/paddle/phi/kernels/cpu/gru_kernel.cc new file mode 100644 index 0000000000000..2850bf2b4a524 --- /dev/null +++ b/paddle/phi/kernels/cpu/gru_kernel.cc @@ -0,0 +1,238 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/funcs/detail/gru_kernel.h" +#include +#include +#include "paddle/common/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h" +#include "paddle/phi/kernels/impl/gru_kernel_impl.h" + +COMMON_DECLARE_int32(paddle_num_threads); + +namespace phi { + +template +void GRUCPUKernel(const Context &dev_ctx, + const DenseTensor &input, + const paddle::optional &h0, + const DenseTensor &weight, + const paddle::optional &bias, + const std::string &activation, + const std::string &gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + DenseTensor *param_batch_gate, + DenseTensor *param_batch_reset_hidden_prev, + DenseTensor *param_batch_hidden, + DenseTensor *hidden) { + const T *weight_data = weight.data(); + dev_ctx.template Alloc(hidden); + + auto input_dims = input.dims(); + auto hidden_dims = hidden->dims(); + + phi::DenseTensor *batch_gate = nullptr; + phi::DenseTensor *batch_reset_hidden_prev = nullptr; + phi::DenseTensor *batch_hidden = nullptr; + phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, + batch_hidden_tmp; + if (is_test) { + batch_gate = &batch_gate_tmp; + batch_gate->Resize(input_dims); + + batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; + batch_reset_hidden_prev->Resize(hidden_dims); + + batch_hidden = &batch_hidden_tmp; + batch_hidden->Resize(hidden_dims); + } else { + batch_gate = param_batch_gate; + batch_hidden = param_batch_hidden; + batch_reset_hidden_prev = param_batch_reset_hidden_prev; + } + dev_ctx.template Alloc(batch_gate); + dev_ctx.template Alloc(batch_reset_hidden_prev); + dev_ctx.template Alloc(batch_hidden); + + phi::funcs::LoDTensor2BatchFunctor to_batch; + to_batch(dev_ctx, input, batch_gate, true, is_reverse); + + if (bias) { + phi::funcs::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, bias, batch_gate); + } + + int frame_size = static_cast(hidden_dims[1]); + phi::funcs::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + phi::DenseTensor ordered_h0; + + phi::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. 
The initialized cell state also needs + // to reorder. + ReorderInitState(dev_ctx, *h0, order, &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t seq_len = batch_starts.size() - 1; + auto active_node = phi::funcs::detail::GetActivationType(activation); + auto active_gate = phi::funcs::detail::GetActivationType(gate_activation); + +#ifdef PADDLE_WITH_MKLML + // use MKL packed to speedup GEMM + if (FLAGS_paddle_num_threads >= 4) { + auto blas = phi::funcs::GetBlas(dev_ctx); + T *packed_gate = blas.GEMM_ALLOC(CblasBMatrix, + 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE_NOT_NULL( + packed_gate, + phi::errors::NotFound("The calculation result of packed_gate by " + "GEMM_ALLOC should not be null when using MKL.")); + blas.GEMM_PACK(CblasBMatrix, + CblasNoTrans, + 1 /*cur bs?*/, + frame_size * 2, + frame_size, + T(1.0), + gru_value.gate_weight, + frame_size * 2, + packed_gate); + T *packed_state = blas.GEMM_ALLOC(CblasBMatrix, + 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE_NOT_NULL( + packed_state, + phi::errors::NotFound("The calculation result of packed_state by " + "GEMM_ALLOC should not be null when using MKL.")); + blas.GEMM_PACK(CblasBMatrix, + CblasNoTrans, + 1 /*cur bs?*/, + frame_size, + frame_size, + T(1.0), + gru_value.state_weight, + frame_size, + packed_state); + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); + phi::DenseTensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE(CblasNoTrans, + CblasPacked, + cur_batch_size, + frame_size * 2, + frame_size, + gru_value.prev_out_value, + frame_size, + packed_gate, + frame_size * 2, + T(1), + gru_value.gate_value, + frame_size * 3); + } + + phi::funcs::detail::forward_reset_output( + phi::funcs::detail::forward::gru_resetOutput(), + gru_value, + frame_size, + cur_batch_size, + active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE(CblasNoTrans, + CblasPacked, + cur_batch_size, + frame_size, + frame_size, + gru_value.reset_output_value, + frame_size, + packed_state, + frame_size, + T(1), + gru_value.gate_value + frame_size * 2, + frame_size * 3); + } + + phi::funcs::detail::forward_final_output( + phi::funcs::detail::forward::gru_finalOutput(), + gru_value, + frame_size, + cur_batch_size, + active_node, + origin_mode); + + gru_value.prev_out_value = gru_value.output_value; + } + + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); + } else { +#endif + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); + phi::DenseTensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = 
gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + phi::funcs::GRUUnitFunctor::compute(dev_ctx, // NOLINT + gru_value, + frame_size, + cur_batch_size, + active_node, + active_gate, + origin_mode); + + gru_value.prev_out_value = gru_value.output_value; + } +#ifdef PADDLE_WITH_MKLML + } +#endif + phi::funcs::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); +} + +} // namespace phi +PD_REGISTER_KERNEL(gru, CPU, ALL_LAYOUT, phi::GRUCPUKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/gru_grad_kernel.cu b/paddle/phi/kernels/gpu/gru_grad_kernel.cu new file mode 100644 index 0000000000000..b21707ff25f7f --- /dev/null +++ b/paddle/phi/kernels/gpu/gru_grad_kernel.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_kernel_impl.h" + +PD_REGISTER_KERNEL( + gru_grad, GPU, ALL_LAYOUT, phi::GRUGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/gru_kernel.cu b/paddle/phi/kernels/gpu/gru_kernel.cu new file mode 100644 index 0000000000000..c9558a05f503d --- /dev/null +++ b/paddle/phi/kernels/gpu/gru_kernel.cu @@ -0,0 +1,122 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/gru_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GRUKernel(const Context &dev_ctx,
+               const DenseTensor &input,
+               const paddle::optional<DenseTensor> &h0,
+               const DenseTensor &weight,
+               const paddle::optional<DenseTensor> &bias,
+               const std::string &activation,
+               const std::string &gate_activation,
+               bool is_reverse,
+               bool origin_mode,
+               bool is_test,
+               DenseTensor *param_batch_gate,
+               DenseTensor *param_batch_reset_hidden_prev,
+               DenseTensor *param_batch_hidden,
+               DenseTensor *hidden) {
+  const T *weight_data = weight.data<T>();
+  dev_ctx.template Alloc<T>(hidden);
+
+  auto input_dims = input.dims();
+  auto hidden_dims = hidden->dims();
+
+  phi::DenseTensor *batch_gate;
+  phi::DenseTensor *batch_reset_hidden_prev;
+  phi::DenseTensor *batch_hidden;
+  phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp,
+      batch_hidden_tmp;
+  if (is_test) {
+    batch_gate = &batch_gate_tmp;
+    batch_gate->Resize(input_dims);
+
+    batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
+    batch_reset_hidden_prev->Resize(hidden_dims);
+
+    batch_hidden = &batch_hidden_tmp;
+    batch_hidden->Resize(hidden_dims);
+  } else {
+    batch_gate = param_batch_gate;
+    batch_hidden = param_batch_hidden;
+    batch_reset_hidden_prev = param_batch_reset_hidden_prev;
+  }
+  dev_ctx.template Alloc<T>(batch_gate);
+  dev_ctx.template Alloc<T>(batch_reset_hidden_prev);
+  dev_ctx.template Alloc<T>(batch_hidden);
+
+  phi::funcs::LoDTensor2BatchFunctor<Context, T> to_batch;
+  to_batch(dev_ctx, input, batch_gate, true, is_reverse);
+
+  if (bias) {
+    phi::funcs::RowwiseAdd<Context, T> add_bias;
+    add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
+  }
+
+  int frame_size = hidden_dims[1];
+  phi::funcs::GRUMetaValue<T> gru_value;
+  gru_value.gate_weight = const_cast<T *>(weight_data);
+  gru_value.state_weight =
+      const_cast<T *>(weight_data + 2 * frame_size * frame_size);
+  phi::DenseTensor ordered_h0;
+
+  phi::Vector<size_t> order(batch_gate->lod()[2]);
+
+  if (h0) {
+    // Batch computation reorders the input sequences by length, so the
+    // initial hidden state has to be reordered the same way.
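+    // ReorderInitState gathers the rows of H0 into the length-sorted
+    // order recorded in `order`, so prev_out_value lines up row-for-row
+    // with the first batch produced by LoDTensor2BatchFunctor.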
+    ReorderInitState<Context, T>(dev_ctx, *h0, order, &ordered_h0, true);
+    gru_value.prev_out_value = ordered_h0.data<T>();
+  } else {
+    gru_value.prev_out_value = nullptr;
+  }
+  auto batch_starts = batch_gate->lod()[0];
+  size_t num_batch = batch_starts.size() - 1;
+  auto active_node = phi::funcs::detail::GetActivationType(activation);
+  auto active_gate = phi::funcs::detail::GetActivationType(gate_activation);
+  for (size_t n = 0; n < num_batch; n++) {
+    int bstart = static_cast<int>(batch_starts[n]);
+    int bend = static_cast<int>(batch_starts[n + 1]);
+    int cur_batch_size = bend - bstart;
+
+    phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend);
+    phi::DenseTensor reset_hidden_prev_t =
+        batch_reset_hidden_prev->Slice(bstart, bend);
+    phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend);
+    gru_value.output_value = hidden_t.data<T>();
+    gru_value.gate_value = gate_t.data<T>();
+    gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+    phi::funcs::GRUUnitFunctor<Context, T>::compute(dev_ctx,  // NOLINT
+                                                    gru_value,
+                                                    frame_size,
+                                                    cur_batch_size,
+                                                    active_node,
+                                                    active_gate,
+                                                    origin_mode);
+    gru_value.prev_out_value = gru_value.output_value;
+  }
+
+  phi::funcs::Batch2LoDTensorFunctor<Context, T> to_seq;
+  batch_hidden->set_lod(batch_gate->lod());
+  to_seq(dev_ctx, *batch_hidden, hidden);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gru, GPU, ALL_LAYOUT, phi::GRUKernel, float, double) {}
diff --git a/paddle/phi/kernels/impl/gru_kernel_impl.h b/paddle/phi/kernels/impl/gru_kernel_impl.h
new file mode 100644
index 0000000000000..efff9056b0d47
--- /dev/null
+++ b/paddle/phi/kernels/impl/gru_kernel_impl.h
@@ -0,0 +1,173 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/gru_compute.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +inline void ReorderInitState(const Context &dev_ctx, + const phi::DenseTensor &src, + phi::Vector index_lod, + phi::DenseTensor *dst, + bool indexed_src) { + phi::funcs::CopyMatrixRowsFunctor row_shuffle; + dst->Resize(src.dims()); + dev_ctx.template Alloc(dst); + row_shuffle(dev_ctx, src, index_lod, dst, indexed_src); +} + +template +void GRUGradKernel(const Context &dev_ctx, + const DenseTensor &input, + const paddle::optional &h0_param, + const DenseTensor &weight, + const paddle::optional &bias, + const DenseTensor &batch_gate, + const DenseTensor &batch_reset_hidden_prev, + const DenseTensor &batch_hidden, + const DenseTensor &hidden, + const DenseTensor &hidden_grad, + const std::string &activation, + const std::string &gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + DenseTensor *input_grad, + DenseTensor *h0_grad, + DenseTensor *weight_grad, + DenseTensor *bias_grad) { + auto *h0 = h0_param.get_ptr(); + const T *weight_data = weight.data(); + + auto gate_dims = batch_gate.dims(); + auto hidden_dims = hidden.dims(); + int frame_size = hidden_dims[1]; + + phi::funcs::LoDTensor2BatchFunctor to_batch; + phi::DenseTensor batch_hidden_grad, batch_gate_grad, + batch_reset_hidden_prev_grad; + batch_hidden_grad.Resize(hidden_dims); + batch_gate_grad.Resize(gate_dims); + batch_reset_hidden_prev_grad.Resize(hidden_dims); + dev_ctx.template Alloc(&batch_hidden_grad); + dev_ctx.template Alloc(&batch_gate_grad); + dev_ctx.template Alloc(&batch_reset_hidden_prev_grad); + + phi::funcs::SetConstant zero; + zero(dev_ctx, &batch_hidden_grad, static_cast(0.0)); + zero(dev_ctx, &batch_gate_grad, static_cast(0.0)); + zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast(0.0)); + + phi::DenseTensor ordered_h0, ordered_h0_grad; + + phi::Vector order(batch_gate.lod()[2]); + + if (h0) { + ReorderInitState(dev_ctx, *h0, order, &ordered_h0, true); + } + if (h0_grad) { + ordered_h0_grad.Resize(h0_grad->dims()); + dev_ctx.template Alloc(&ordered_h0_grad); + zero(dev_ctx, &ordered_h0_grad, static_cast(0.0)); + } + + batch_hidden_grad.set_lod(batch_hidden.lod()); + to_batch(dev_ctx, hidden_grad, &batch_hidden_grad, false, is_reverse); + + phi::funcs::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + + phi::funcs::GRUMetaGrad gru_grad; + if (weight_grad) { + gru_grad.gate_weight_grad = dev_ctx.template Alloc(weight_grad); + zero(dev_ctx, weight_grad, static_cast(0.0)); + gru_grad.state_weight_grad = + weight_grad->data() + 2 * frame_size * frame_size; + } else { + gru_grad.gate_weight_grad = nullptr; + gru_grad.state_weight_grad = nullptr; + } + + auto batch_starts = batch_hidden_grad.lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = phi::funcs::detail::GetActivationType(activation); + auto active_gate = phi::funcs::detail::GetActivationType(gate_activation); + for (int n = static_cast(num_batch) - 1; n >= 0; n--) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + 
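// Take per-step views of the batched forward results and gradients and + // wire them into gru_value / gru_grad before invoking the GRU unit + // backward functor. + 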
phi::DenseTensor gate_t = batch_gate.Slice(bstart, bend); + gru_value.gate_value = gate_t.data(); + phi::DenseTensor reset_hidden_prev_t = + batch_reset_hidden_prev.Slice(bstart, bend); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + phi::DenseTensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend); + gru_grad.output_grad = hidden_grad_t.data(); + phi::DenseTensor gate_grad_t = batch_gate_grad.Slice(bstart, bend); + gru_grad.gate_grad = gate_grad_t.data(); + phi::DenseTensor reset_hidden_prev_grad_t = + batch_reset_hidden_prev_grad.Slice(bstart, bend); + gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data(); + if (n == 0) { + gru_value.prev_out_value = h0 ? ordered_h0.data() : nullptr; + gru_grad.prev_out_grad = + h0 && h0_grad ? ordered_h0_grad.data() : nullptr; + } else { + int bstart_pre = static_cast(batch_starts[n - 1]); + phi::DenseTensor hidden_prev_t = batch_hidden.Slice(bstart_pre, bstart); + gru_value.prev_out_value = hidden_prev_t.data(); + phi::DenseTensor hidden_prev_grad_t = + batch_hidden_grad.Slice(bstart_pre, bstart); + gru_grad.prev_out_grad = hidden_prev_grad_t.data(); + } + gru_value.output_value = nullptr; + phi::funcs::GRUUnitGradFunctor::compute(dev_ctx, + gru_value, + gru_grad, + frame_size, + cur_batch_size, + active_node, + active_gate, + origin_mode); + } + if (input_grad) { + dev_ctx.template Alloc(input_grad); + phi::funcs::Batch2LoDTensorFunctor to_seq; + batch_gate_grad.set_lod(batch_gate.lod()); + to_seq(dev_ctx, batch_gate_grad, input_grad); + } + if (bias_grad) { + dev_ctx.template Alloc(bias_grad); + phi::funcs::ColwiseSum col_sum; + col_sum(dev_ctx, batch_gate_grad, bias_grad); + } + if (h0_param && h0_grad) { + ReorderInitState( + dev_ctx, ordered_h0_grad, order, h0_grad, false); + } +} +} // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 01b0c1025b8e9..9f767535df3ca 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -1167,6 +1167,25 @@ optional: scale, bias inplace : (y_grad -> x_grad) +- backward_op : gru_grad + forward: gru (Tensor input, Tensor h0, Tensor weight, Tensor bias, str activation = "tanh", + str gate_activation = "sigmoid", bool is_reverse = false, bool origin_mode = false, bool is_test=false) -> + Tensor (batch_gate), Tensor (batch_reset_hidden_prev), Tensor (batch_hidden), + Tensor (hidden) + args: (Tensor input, Tensor h0, Tensor weight, Tensor bias, Tensor batch_gate, + Tensor batch_reset_hidden_prev, Tensor batch_hidden, Tensor hidden, + Tensor hidden_grad, str activation = "tanh", + str gate_activation = "sigmoid", bool is_reverse = false, bool origin_mode = false, bool is_test=false) + output: Tensor(input_grad), Tensor(h0_grad), Tensor(weight_grad), Tensor(bias_grad) + infer_meta: + func: GruGradInferMeta + param: [input, h0, weight, bias] + kernel: + func: gru_grad + data_type: hidden_grad + optional: h0, bias + no_need_buffer: input, bias + - backward_op : gumbel_softmax_grad forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out) args : (Tensor out, Tensor out_grad, int axis) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index 3435ea6c46789..5a2fae503e834 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -4069,6 +4069,13 @@ outputs : out : Out +- op: gru + backward: gru_grad + inputs: + {input : Input, h0 : H0, weight : Weight, bias : Bias} + outputs: + {batch_gate : BatchGate, 
batch_reset_hidden_prev : BatchResetHiddenPrev, batch_hidden : BatchHidden, hidden : Hidden} + - op: identity_loss inputs : x: X diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index a84f7f337af37..813c1a8681921 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1548,6 +1548,20 @@ backward : group_norm_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : gru + args: (Tensor input, Tensor h0, Tensor weight, Tensor bias, str activation = "tanh", + str gate_activation = "sigmoid", bool is_reverse = false, bool origin_mode = false, bool is_test=false) + output: Tensor (batch_gate), Tensor (batch_reset_hidden_prev), Tensor (batch_hidden), + Tensor (hidden) + infer_meta: + func: GruInferMeta + kernel: + func: gru + data_type: input + optional: h0, bias + intermediate: batch_gate, batch_reset_hidden_prev, batch_hidden + backward: gru_grad + - op : gumbel_softmax args : (Tensor x, float temperature = 1.0, bool hard = false, int axis = -1) output : Tensor
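
For readers following the migration, the recurrence that both the deleted fluid kernels and the new phi kernels drive through GRUUnitFunctor is small enough to restate standalone. The sketch below evaluates one step with the default activations, under the layout conventions spelled out in the diff: BatchGate columns ordered update, reset, candidate; gate_weight of size D x 2D followed by state_weight of size D x D; the projections xu/xr/xc precomputed by a fully-connected op, as the @note in GRUOpMaker requires. The helper names and plain std::vector containers are illustrative only, not Paddle API; origin_mode is noted but not modeled.

#include <cmath>
#include <vector>

namespace {  // reference-only helpers, not Paddle API

double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

// One GRU step for a single sequence element, following the DOC block
// (all gate/state products below are elementwise):
//   u_t = actGate(xu_t + W_u h_{t-1} + b_u)
//   r_t = actGate(xr_t + W_r h_{t-1} + b_r)
//   c_t = actNode(xc_t + W_c (r_t * h_{t-1}) + b_c)
//   h_t = (1 - u_t) * h_{t-1} + u_t * c_t
// x_gates holds the precomputed projections [xu, xr, xc] with bias already
// added, length 3*D, matching one row of BatchGate. w uses the kernels'
// layout: the first 2*D*D entries are gate_weight (update columns, then
// reset columns), the last D*D entries are state_weight.
std::vector<double> GruStep(const std::vector<double> &x_gates,
                            const std::vector<double> &w,
                            const std::vector<double> &h_prev, int D) {
  std::vector<double> u(D), r(D), h(D);
  for (int i = 0; i < D; ++i) {
    double u_in = x_gates[i];
    double r_in = x_gates[D + i];
    for (int j = 0; j < D; ++j) {
      u_in += h_prev[j] * w[j * 2 * D + i];      // W_u column i
      r_in += h_prev[j] * w[j * 2 * D + D + i];  // W_r column i
    }
    u[i] = Sigmoid(u_in);
    r[i] = Sigmoid(r_in);
  }
  for (int i = 0; i < D; ++i) {
    double c_in = x_gates[2 * D + i];
    for (int j = 0; j < D; ++j) {
      c_in += r[j] * h_prev[j] * w[2 * D * D + j * D + i];  // state_weight
    }
    double c = std::tanh(c_in);
    // origin_mode (arXiv:1412.3555) changes which factor multiplies h_prev
    // versus c here; see gru_finalOutput for the exact convention.
    h[i] = (1.0 - u[i]) * h_prev[i] + u[i] * c;
  }
  return h;
}

}  // namespace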
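The per-step loops in every kernel above walk the batch_starts metadata produced by LoDTensor2BatchFunctor. A minimal sketch of what that metadata means, assuming the usual length-descending ordering: order maps sorted position to original sequence index (this is what ReorderInitState consumes for H0), and batch n groups timestep n of every sequence that is still alive, which is exactly the bstart/bend/cur_batch_size slice the kernels take. ToBatchMeta is a hypothetical helper for illustration, not the Paddle functor.

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

struct BatchMeta {
  std::vector<size_t> order;         // original index of each sorted sequence
  std::vector<size_t> batch_starts;  // batch_starts[n] = row offset of step n
};

BatchMeta ToBatchMeta(const std::vector<size_t> &seq_lens) {
  BatchMeta m;
  m.order.resize(seq_lens.size());
  std::iota(m.order.begin(), m.order.end(), 0);
  std::sort(m.order.begin(), m.order.end(),
            [&](size_t a, size_t b) { return seq_lens[a] > seq_lens[b]; });
  size_t max_len = seq_lens.empty() ? 0 : seq_lens[m.order[0]];
  m.batch_starts.push_back(0);
  for (size_t n = 0; n < max_len; ++n) {
    // cur_batch_size at step n = number of sequences longer than n.
    size_t alive = std::count_if(seq_lens.begin(), seq_lens.end(),
                                 [&](size_t l) { return l > n; });
    m.batch_starts.push_back(m.batch_starts.back() + alive);
  }
  return m;
}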
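Finally, the shape contract that GruInferMeta and GruGradInferMeta carry over unchanged from the old InferShape methods, restated as plain assertions with D = frame_size, T = total time steps across the mini-batch, N = batch size. The function is a sketch for reference, not Paddle code.

#include <cassert>
#include <cstdint>

void CheckGruShapes(int64_t T, int64_t N, int64_t D) {
  int64_t input[2] = {T, 3 * D};   // Input: projected gates (T x 3D)
  int64_t weight[2] = {D, 3 * D};  // Weight: (D x 2D) gates + (D x D) candidate
  int64_t h0[2] = {N, D};          // optional H0
  int64_t bias[2] = {1, 3 * D};    // optional Bias
  int64_t hidden[2] = {T, D};      // Hidden; BatchGate matches Input, and
                                   // BatchResetHiddenPrev / BatchHidden are
                                   // (T x D), set only when !is_test.
  assert(input[1] == 3 * weight[0]);   // the runtime check in GruInferMeta
  assert(weight[1] == 3 * weight[0]);  // weight width check
  assert(h0[1] == weight[0]);          // H0 width == frame_size
  assert(bias[0] == 1 && bias[1] == 3 * weight[0]);
  assert(hidden[0] == input[0] && hidden[1] == weight[0]);
}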