From ad10086c73d75abe1e9e0b2cec1dac143a65f501 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sat, 25 May 2024 16:40:27 +0800 Subject: [PATCH 01/12] Fix --- paddle/fluid/operators/gru_unit_op.cc | 329 ---------------- paddle/fluid/operators/gru_unit_op.cu | 21 -- paddle/fluid/operators/gru_unit_op.h | 347 ----------------- paddle/phi/infermeta/backward.cc | 89 +++++ paddle/phi/infermeta/backward.h | 10 + paddle/phi/infermeta/multiary.cc | 84 +++++ paddle/phi/infermeta/multiary.h | 12 + .../phi/kernels/cpu/gru_unit_grad_kernel.cc | 19 + paddle/phi/kernels/cpu/gru_unit_kernel.cc | 19 + .../phi/kernels/gpu/gru_unit_grad_kernel.cu | 19 + paddle/phi/kernels/gpu/gru_unit_kernel.cu | 19 + .../phi/kernels/impl/gru_unit_kernel_impl.h | 355 ++++++++++++++++++ paddle/phi/ops/yaml/backward.yaml | 15 + paddle/phi/ops/yaml/op_compat.yaml | 7 + paddle/phi/ops/yaml/ops.yaml | 12 + 15 files changed, 660 insertions(+), 697 deletions(-) delete mode 100644 paddle/fluid/operators/gru_unit_op.cc delete mode 100644 paddle/fluid/operators/gru_unit_op.cu delete mode 100644 paddle/fluid/operators/gru_unit_op.h create mode 100644 paddle/phi/kernels/cpu/gru_unit_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/gru_unit_kernel.cc create mode 100644 paddle/phi/kernels/gpu/gru_unit_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/gru_unit_kernel.cu create mode 100644 paddle/phi/kernels/impl/gru_unit_kernel_impl.h diff --git a/paddle/fluid/operators/gru_unit_op.cc b/paddle/fluid/operators/gru_unit_op.cc deleted file mode 100644 index 5a29abda1f369..0000000000000 --- a/paddle/fluid/operators/gru_unit_op.cc +++ /dev/null @@ -1,329 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/gru_unit_op.h" - -#include - -namespace paddle { -namespace operators { - -class GRUUnitOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnit"); - OP_INOUT_CHECK( - ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev", "GRUUnit"); - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnit"); - OP_INOUT_CHECK(ctx->HasOutput("Gate"), "Output", "Gate", "GRUUnit"); - OP_INOUT_CHECK(ctx->HasOutput("ResetHiddenPrev"), - "Output", - "ResetHiddenPrev", - "GRUUnit"); - OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRUUnit"); - auto input_dims = ctx->GetInputDim("Input"); - auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); - auto weight_dims = ctx->GetInputDim("Weight"); - int batch_size = static_cast(input_dims[0]); - int input_size = static_cast(input_dims[1]); - int frame_size = static_cast(hidden_prev_dims[1]); - int weight_height = static_cast(weight_dims[0]); - int weight_width = static_cast(weight_dims[1]); - if (ctx->IsRuntime() || input_size >= 0) { - PADDLE_ENFORCE_EQ(input_size, - frame_size * 3, - phi::errors::InvalidArgument( - "The second dimension of Input(Input) must be 3 " - "times of frame_size in GRUUnitOp, but received %d " - "(Input) vs %d (frame_size).", - input_size, - frame_size)); - } - PADDLE_ENFORCE_EQ( - weight_height, - frame_size, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] " - "(frame_size).", - weight_height, - weight_width, - frame_size, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - weight_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] " - "(frame_size).", - weight_height, - weight_width, - frame_size, - frame_size * 3)); - - if (ctx->HasInput("Bias")) { - auto bias_dims = ctx->GetInputDim("Bias"); - int bias_height = static_cast(bias_dims[0]); - int bias_width = static_cast(bias_dims[1]); - PADDLE_ENFORCE_EQ( - bias_height, - 1, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - bias_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - } - ctx->SetOutputDim("Gate", {batch_size, frame_size * 3}); - ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size}); - ctx->SetOutputDim("Hidden", {batch_size, frame_size}); - } -}; - -class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Input", - "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the " - "input."); - AddInput("HiddenPrev", - "(Tensor) Matrix with shape [batch_size, frame_size] for the " - "states of previous time step."); - AddInput( - "Weight", - "(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. " - "The elements continuous in memory can be divided into two parts. " - "The first part are weights of the update gate and reset gate " - "with shape [frame_size, frame_size * 2], and the second part are " - "weights of output candidate with shape [frame_size, frame_size]."); - AddInput( - "Bias", - "(Tensor) Bias vector with shape [1, frame_size * 3] concatenating " - "bias of the update gate, reset gate and output candidate.") - .AsDispensable(); - AddOutput("Gate", - "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the " - "output of update gate, reset gate and output candidate.") - .AsIntermediate(); - AddOutput("ResetHiddenPrev", - "(Tensor) Matrix with shape [batch_size, frame_size] for the " - "reset hidden state of previous time step.") - .AsIntermediate(); - AddOutput("Hidden", - "(Tensor) The GRU hidden state of the current time step " - "with shape [batch_size, frame_size]."); - AddAttr("activation", - "(enum int, default tanh) " - "The activation type used for output candidate {h}_t.") - .SetDefault(tanh) - .InEnum({identity, sigmoid, tanh, relu}); - AddAttr("gate_activation", - "(enum int, default sigmoid) " - "The activation type used in update gate and reset gate.") - .SetDefault(sigmoid) - .InEnum({identity, sigmoid, tanh, relu}); - AddAttr("origin_mode", - "bool" - "use origin mode in article (https://arxiv.org/pdf/1406.1078.pdf)") - .SetDefault(false); - AddComment(R"DOC( -GRUUnit Operator implements partial calculations of the GRU unit as following: - -$$ -update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\ -reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\ -output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\ -output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t) -$$ - -which is same as one time step of GRU Operator. - -@note To implement the complete GRU unit, fully-connected operator must be -used before to feed xu, xr and xc as the Input of GRUUnit operator. - -)DOC"); - } -}; - -class GRUUnitGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnitGrad"); - OP_INOUT_CHECK( - ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev", "GRUUnitGrad"); - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnitGrad"); - OP_INOUT_CHECK(ctx->HasInput("Gate"), "Input", "Gate", "GRUUnitGrad"); - OP_INOUT_CHECK(ctx->HasInput("ResetHiddenPrev"), - "Input", - "ResetHiddenPrev", - "GRUUnitGrad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), - "Input", - "Hidden@GRAD", - "GRUUnitGrad"); - - auto input_dims = ctx->GetInputDim("Input"); - auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); - auto weight_dims = ctx->GetInputDim("Weight"); - // int batch_size = input_dims[0]; - int input_size = static_cast(input_dims[1]); - int frame_size = static_cast(hidden_prev_dims[1]); - int weight_height = static_cast(weight_dims[0]); - int weight_width = static_cast(weight_dims[1]); - if (ctx->IsRuntime() || input_size >= 0) { - PADDLE_ENFORCE_EQ( - input_size, - frame_size * 3, - phi::errors::InvalidArgument( - "The second dimension of Input(Input) must be 3 " - "times of frame_size in GRUUnitGradOp, but received %d " - "(Input) vs %d (frame_size).", - input_size, - frame_size)); - } - PADDLE_ENFORCE_EQ( - weight_height, - frame_size, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] " - "(frame_size).", - weight_height, - weight_width, - frame_size, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - weight_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] " - "(frame_size).", - weight_height, - weight_width, - frame_size, - frame_size * 3)); - if (ctx->HasInput("Bias")) { - auto bias_dims = ctx->GetInputDim("Bias"); - int bias_height = static_cast(bias_dims[0]); - int bias_width = static_cast(bias_dims[1]); - - PADDLE_ENFORCE_EQ( - bias_height, - 1, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - bias_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - auto bias_grad_name = framework::GradVarName("Bias"); - if (ctx->HasOutput(bias_grad_name)) - ctx->SetOutputDim(bias_grad_name, bias_dims); - } - auto input_grad_name = framework::GradVarName("Input"); - if (ctx->HasOutput(input_grad_name)) - ctx->SetOutputDim(input_grad_name, input_dims); - auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev"); - if (ctx->HasOutput(hidden_prev_grad_name)) - ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims); - auto weight_grad_name = framework::GradVarName("Weight"); - if (ctx->HasOutput(weight_grad_name)) - ctx->SetOutputDim(weight_grad_name, weight_dims); - } - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Hidden")), - ctx.device_context().GetPlace()); - } -}; - -template -class GRUUnitGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("gru_unit_grad"); - - op->SetInput("Input", this->Input("Input")); - op->SetInput("HiddenPrev", this->Input("HiddenPrev")); - op->SetInput("Weight", this->Input("Weight")); - op->SetInput("Bias", this->Input("Bias")); - - op->SetInput("Gate", this->Output("Gate")); - op->SetInput("ResetHiddenPrev", this->Output("ResetHiddenPrev")); - op->SetInput(framework::GradVarName("Hidden"), this->OutputGrad("Hidden")); - - op->SetAttrMap(this->Attrs()); - - op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); - op->SetOutput(framework::GradVarName("HiddenPrev"), - this->InputGrad("HiddenPrev")); - op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight")); - op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUUnitGradOpNoNeedBufferVarInferer, - "Bias"); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(gru_unit, - ops::GRUUnitOp, - ops::GRUUnitOpMaker, - ops::GRUUnitGradOpMaker, - ops::GRUUnitGradOpMaker); -REGISTER_OPERATOR(gru_unit_grad, - ops::GRUUnitGradOp, - ops::GRUUnitGradOpNoNeedBufferVarInferer); - -PD_REGISTER_STRUCT_KERNEL( - gru_unit, CPU, ALL_LAYOUT, ops::GRUUnitKernel, float, double) {} -PD_REGISTER_STRUCT_KERNEL( - gru_unit_grad, CPU, ALL_LAYOUT, ops::GRUUnitGradKernel, float, double) {} diff --git a/paddle/fluid/operators/gru_unit_op.cu b/paddle/fluid/operators/gru_unit_op.cu deleted file mode 100644 index 192594a09e86f..0000000000000 --- a/paddle/fluid/operators/gru_unit_op.cu +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include "paddle/fluid/operators/gru_unit_op.h" - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL( - gru_unit, GPU, ALL_LAYOUT, ops::GRUUnitKernel, float, double) {} -PD_REGISTER_STRUCT_KERNEL( - gru_unit_grad, GPU, ALL_LAYOUT, ops::GRUUnitGradKernel, float, double) {} diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h deleted file mode 100644 index e5b91abd144e5..0000000000000 --- a/paddle/fluid/operators/gru_unit_op.h +++ /dev/null @@ -1,347 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" - -namespace paddle { -namespace operators { - -enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 }; - -template -class GRUUnitKernel : public framework::OpKernel { - public: - template - void ActCompute( - const int act_type, const Device& d, X x, Y y, phi::Place place) const { - if (act_type == identity) { - y.device(d) = x; - } else if (act_type == sigmoid) { - SigmoidFunctor()(d, x, y); - } else if (act_type == tanh) { - TanhFunctor()(d, x, y); - } else if (act_type == relu) { - if (place == platform::CPUPlace()) - ReluCPUFunctor()(d, x, y); - else - ReluCUDAFunctor()(d, x, y); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported activation type, only supports identity, sigmoid, tanh " - "and relu.")); - } - } - - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - auto* hidden_prev = context.Input("HiddenPrev"); - auto* weight = context.Input("Weight"); - auto* bias = context.Input("Bias"); - auto* gate = context.Output("Gate"); - gate->mutable_data(context.GetPlace()); - auto* reset_hidden_prev = - context.Output("ResetHiddenPrev"); - reset_hidden_prev->mutable_data(context.GetPlace()); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - int batch_size = input->dims()[0]; - int frame_size = hidden_prev->dims()[1]; - - auto x = phi::EigenMatrix::From(*input); - auto h_p = phi::EigenMatrix::From(*hidden_prev); - auto g = phi::EigenMatrix::From(*gate); - auto r_h_p = phi::EigenMatrix::From(*reset_hidden_prev); - auto h = phi::EigenMatrix::From(*hidden); - auto& place = - *context.template device_context().eigen_device(); - - // calculate unactivated gate outputs - if (bias) { - auto b = phi::EigenMatrix::From(*bias); - g.device(place) = - x + b.reshape(Eigen::array({{1, frame_size * 3}})) - .broadcast(Eigen::array({{batch_size, 1}})); - } else { - g.device(place) = x; - } - const T* hidden_prev_data = hidden_prev->data(); - const T* weight_data = weight->data(); - T* gate_data = gate->data(); - T* reset_hidden_prev_data = reset_hidden_prev->data(); - auto& dev_ctx = context.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - blas.GEMM(false, - false, - batch_size, - 2 * frame_size, - frame_size, - 1, - hidden_prev_data, - frame_size, - weight_data, - frame_size * 2, - 1, - gate_data, - frame_size * 3); - - // calculate activated gate - Eigen::array extents{{batch_size, frame_size}}; - Eigen::array u_offsets{{0, 0}}; - ActCompute(context.Attr("gate_activation"), - place, - g.slice(u_offsets, extents), - g.slice(u_offsets, extents), - context.GetPlace()); - auto u = g.slice(u_offsets, extents); // update gate - Eigen::array r_offsets{{0, frame_size}}; - ActCompute(context.Attr("gate_activation"), - place, - g.slice(r_offsets, extents), - g.slice(r_offsets, extents), - context.GetPlace()); - auto r = g.slice(r_offsets, extents); // reset gate - r_h_p.device(place) = r * h_p; // reset previous hidden state - blas.GEMM(false, - false, - batch_size, - frame_size, - frame_size, - 1, - reset_hidden_prev_data, - frame_size, - weight_data + frame_size * frame_size * 2, - frame_size, - 1, - gate_data + frame_size * 2, - frame_size * 3); - - Eigen::array c_offsets{{0, frame_size * 2}}; - ActCompute(context.Attr("activation"), - place, - g.slice(c_offsets, extents), - g.slice(c_offsets, extents), - context.GetPlace()); - auto c = g.slice(c_offsets, extents); // output candidate - - // calculate final output - if (context.Attr("origin_mode")) { - h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p - } else { - h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p - } - } -}; - -template -class GRUUnitGradKernel : public framework::OpKernel { - public: - template - void ActGradCompute( - const int act_type, const Device& d, X x, Y y, DX dx, DY dy) const { - // x is dummy and won't be used even in Relu(use y instead) - if (act_type == identity) - dx.device(d) = dy; - else if (act_type == sigmoid) - SigmoidGradFunctor()(d, x, y, dy, dx); - else if (act_type == tanh) - TanhGradFunctor()(d, x, y, dy, dx); - else if (act_type == relu) - ReluGradFunctor()(d, x, y, dy, dx); - else - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported activation type, only supports identity, sigmoid, tanh " - "and relu.")); - } - - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - auto* hidden_prev = context.Input("HiddenPrev"); - auto* weight = context.Input("Weight"); - auto* gate = context.Input("Gate"); - auto* reset_hidden_prev = - context.Input("ResetHiddenPrev"); - auto* hidden_grad = - context.Input(framework::GradVarName("Hidden")); - auto* input_grad = - context.Output(framework::GradVarName("Input")); - auto* hidden_prev_grad = - context.Output(framework::GradVarName("HiddenPrev")); - auto* weight_grad = - context.Output(framework::GradVarName("Weight")); - auto* bias_grad = - context.Output(framework::GradVarName("Bias")); - phi::DenseTensor gate_grad; - phi::DenseTensor reset_hidden_prev_grad; - - const T* hidden_prev_data = hidden_prev->data(); - const T* weight_data = weight->data(); - T* gate_grad_data = - gate_grad.mutable_data(input->dims(), context.GetPlace()); - const T* reset_hidden_prev_data = reset_hidden_prev->data(); - T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.mutable_data( - reset_hidden_prev->dims(), context.GetPlace()); - - auto h_p = phi::EigenMatrix::From(*hidden_prev); - auto g = phi::EigenMatrix::From(*gate); - auto d_h = phi::EigenMatrix::From(*hidden_grad); - auto d_g = phi::EigenMatrix::From(gate_grad); - auto d_r_h_p = phi::EigenMatrix::From(reset_hidden_prev_grad); - auto& place = - *context.template device_context().eigen_device(); - - int batch_size = input->dims()[0]; - int frame_size = hidden_prev->dims()[1]; - - Eigen::array extents{{batch_size, frame_size}}; - Eigen::array u_offsets{{0, 0}}; - auto u = g.slice(u_offsets, extents); // update gate - Eigen::array r_offsets{{0, frame_size}}; - auto r = g.slice(r_offsets, extents); // reset gate - Eigen::array c_offsets{{0, frame_size * 2}}; - auto c = g.slice(c_offsets, extents); // output candidate - - // backward for unactivated update gate - if (context.Attr("origin_mode")) { - ActGradCompute(context.Attr("gate_activation"), - place, - u, - u, - d_g.slice(u_offsets, extents), - d_h * (h_p - c)); - // backward for unactivated output candidate - ActGradCompute(context.Attr("activation"), - place, - c, - c, - d_g.slice(c_offsets, extents), - d_h * (1 - u)); - } else { - ActGradCompute(context.Attr("gate_activation"), - place, - u, - u, - d_g.slice(u_offsets, extents), - d_h * (c - h_p)); - // backward for unactivated output candidate - ActGradCompute(context.Attr("activation"), - place, - c, - c, - d_g.slice(c_offsets, extents), - d_h * u); - } - // backward for reset_hidden_prev - auto& dev_ctx = context.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - blas.GEMM(false, - true, - batch_size, - frame_size, - frame_size, - 1, - gate_grad_data + frame_size * 2, - frame_size * 3, - weight_data + frame_size * frame_size * 2, - frame_size, - 0, - reset_hidden_prev_grad_data, - frame_size); - // backward for unactivated reset gate - ActGradCompute(context.Attr("gate_activation"), - place, - r, - r, - d_g.slice(r_offsets, extents), - d_r_h_p * h_p); - // backward for weight - if (weight_grad) { - T* weight_grad_data = weight_grad->mutable_data(context.GetPlace()); - // backward for state_weight - blas.GEMM(true, - false, - frame_size, - frame_size, - batch_size, - 1, - reset_hidden_prev_data, - frame_size, - gate_grad_data + frame_size * 2, - frame_size * 3, - 0, - weight_grad_data + frame_size * frame_size * 2, - frame_size); - - // backward for update_gate_weight and reset_gate_weight - blas.GEMM(true, - false, - frame_size, - frame_size * 2, - batch_size, - 1, - hidden_prev_data, - frame_size, - gate_grad_data, - frame_size * 3, - 0, - weight_grad_data, - frame_size * 2); - } - // backward for hidden_prev - if (hidden_prev_grad) { - T* hidden_prev_grad_data = - hidden_prev_grad->mutable_data(context.GetPlace()); - auto d_h_p = phi::EigenMatrix::From(*hidden_prev_grad); - if (context.Attr("origin_mode")) { - d_h_p.device(place) = d_r_h_p * r + d_h * u; - } else { - d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u); - } - blas.GEMM(false, - true, - batch_size, - frame_size, - frame_size * 2, - 1, - gate_grad_data, - frame_size * 3, - weight_data, - frame_size * 2, - 1, - hidden_prev_grad_data, - frame_size); - } - // backward for input - if (input_grad) { - input_grad->mutable_data(context.GetPlace()); - auto d_x = phi::EigenMatrix::From(*input_grad); - d_x.device(place) = d_g; - } - // backward for bias - if (bias_grad) { - bias_grad->mutable_data(context.GetPlace()); - auto d_b = phi::EigenVector::Flatten(*bias_grad); - d_b.device(place) = d_g.sum(Eigen::array({{0}})); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index c6c145019f40b..aa5cc3762bf6d 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -574,6 +574,95 @@ void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) { } } +void GruUnitGradInferMeta(const MetaTensor& input, + const MetaTensor& hidden_prev, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* input_grad, + MetaTensor* hidden_prev_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad, + MetaConfig config) { + const auto& input_dims = input.dims(); + const auto& hidden_prev_dims = hidden_prev.dims(); + const auto& weight_dims = weight.dims(); + // int batch_size = input_dims[0]; + int input_size = static_cast(input_dims[1]); + int frame_size = static_cast(hidden_prev_dims[1]); + int weight_height = static_cast(weight_dims[0]); + int weight_width = static_cast(weight_dims[1]); + if (config.is_runtime || input_size >= 0) { + PADDLE_ENFORCE_EQ( + input_size, + frame_size * 3, + phi::errors::InvalidArgument( + "The second dimension of Input(Input) must be 3 " + "times of frame_size in GRUUnitGradOp, but received %d " + "(Input) vs %d (frame_size).", + input_size, + frame_size)); + } + PADDLE_ENFORCE_EQ( + weight_height, + frame_size, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] " + "(frame_size).", + weight_height, + weight_width, + frame_size, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + weight_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] " + "(frame_size).", + weight_height, + weight_width, + frame_size, + frame_size * 3)); + if (bias.initialized()) { + const auto& bias_dims = bias.dims(); + int bias_height = static_cast(bias_dims[0]); + int bias_width = static_cast(bias_dims[1]); + + PADDLE_ENFORCE_EQ( + bias_height, + 1, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + bias_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + if (bias_grad->initialized()) { + bias_grad->set_dims(bias_dims); + } + } + + if (input_grad->initialized()) { + input_grad->set_dims(input_dims); + } + if (hidden_prev_grad->initialized()) { + hidden_prev_grad->set_dims(hidden_prev_dims); + } + if (weight_grad->initialized()) { + weight_grad->set_dims(weight_dims); + } +} + void GumbelSoftmaxGradInferMeta(const MetaTensor& out, const MetaTensor& dout, int axis, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 39b59958d6752..e21be2b535490 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -272,6 +272,16 @@ void GeneralQuinaryGradInferMeta(const MetaTensor& x, MetaTensor* dk, MetaTensor* dl); +void GruUnitGradInferMeta(const MetaTensor& input, + const MetaTensor& hidden_prev, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* input_grad, + MetaTensor* hidden_prev_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad, + MetaConfig config = MetaConfig()); + void GumbelSoftmaxGradInferMeta(const MetaTensor& out, const MetaTensor& dout, int axis, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 65de2d4e3ce21..4c550df98034d 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2425,6 +2425,90 @@ void GraphSampleNeighborsInferMeta(const MetaTensor& row, out_count->set_dtype(DataType::INT32); } +void GruUnitInferMeta(const MetaTensor& input, + const MetaTensor& hidden_prev, + const MetaTensor& weight, + const MetaTensor& bias, + int activation, + int gate_activation, + bool origin_mode, + MetaTensor* gate, + MetaTensor* reset_hidden_prev, + MetaTensor* hidden, + MetaConfig config) { + const auto& input_dims = input.dims(); + const auto& hidden_prev_dims = hidden_prev.dims(); + const auto& weight_dims = weight.dims(); + int batch_size = static_cast(input_dims[0]); + int input_size = static_cast(input_dims[1]); + int frame_size = static_cast(hidden_prev_dims[1]); + int weight_height = static_cast(weight_dims[0]); + int weight_width = static_cast(weight_dims[1]); + if (config.is_runtime || input_size >= 0) { + PADDLE_ENFORCE_EQ(input_size, + frame_size * 3, + phi::errors::InvalidArgument( + "The second dimension of Input(Input) must be 3 " + "times of frame_size in GRUUnitOp, but received %d " + "(Input) vs %d (frame_size).", + input_size, + frame_size)); + } + PADDLE_ENFORCE_EQ( + weight_height, + frame_size, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] " + "(frame_size).", + weight_height, + weight_width, + frame_size, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + weight_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] " + "(frame_size).", + weight_height, + weight_width, + frame_size, + frame_size * 3)); + + if (bias.initialized()) { + const auto& bias_dims = bias.dims(); + int bias_height = static_cast(bias_dims[0]); + int bias_width = static_cast(bias_dims[1]); + PADDLE_ENFORCE_EQ( + bias_height, + 1, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + bias_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + } + gate->set_dims({batch_size, frame_size * 3}); + reset_hidden_prev->set_dims({batch_size, frame_size}); + hidden->set_dims({batch_size, frame_size}); + + gate->set_dtype(input.dtype()); + reset_hidden_prev->set_dtype(input.dtype()); + hidden->set_dtype(input.dtype()); +} + void HSigmoidLossInferMeta(const MetaTensor& x, const MetaTensor& label, const MetaTensor& w, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index d60f8b0f3c443..4eb516bb76c8e 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -504,6 +504,18 @@ void GraphSampleNeighborsInferMeta(const MetaTensor& row, MetaTensor* out_count, MetaTensor* out_eids); +void GruUnitInferMeta(const MetaTensor& input, + const MetaTensor& hidden_prev, + const MetaTensor& weight, + const paddle::optional& bias, + int activation, + int gate_activation, + bool origin_mode, + MetaTensor* gate, + MetaTensor* reset_hidden_prev, + MetaTensor* hidden, + MetaConfig config = MetaConfig()); + void HSigmoidLossInferMeta(const MetaTensor& x, const MetaTensor& label, const MetaTensor& w, diff --git a/paddle/phi/kernels/cpu/gru_unit_grad_kernel.cc b/paddle/phi/kernels/cpu/gru_unit_grad_kernel.cc new file mode 100644 index 0000000000000..cb3ab4fabcb32 --- /dev/null +++ b/paddle/phi/kernels/cpu/gru_unit_grad_kernel.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_unit_kernel_impl.h" + +PD_REGISTER_KERNEL( + gru_unit_grad, CPU, ALL_LAYOUT, phi::GRUUnitGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/gru_unit_kernel.cc b/paddle/phi/kernels/cpu/gru_unit_kernel.cc new file mode 100644 index 0000000000000..4c1f3a6108bb1 --- /dev/null +++ b/paddle/phi/kernels/cpu/gru_unit_kernel.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_unit_kernel_impl.h" + +PD_REGISTER_KERNEL( + gru_unit, CPU, ALL_LAYOUT, phi::GRUUnitKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/gru_unit_grad_kernel.cu b/paddle/phi/kernels/gpu/gru_unit_grad_kernel.cu new file mode 100644 index 0000000000000..f58ee83bbfb49 --- /dev/null +++ b/paddle/phi/kernels/gpu/gru_unit_grad_kernel.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_unit_kernel_impl.h" + +PD_REGISTER_KERNEL( + gru_unit_grad, GPU, ALL_LAYOUT, phi::GRUUnitGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/gru_unit_kernel.cu b/paddle/phi/kernels/gpu/gru_unit_kernel.cu new file mode 100644 index 0000000000000..56695d228f0ca --- /dev/null +++ b/paddle/phi/kernels/gpu/gru_unit_kernel.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_unit_kernel_impl.h" + +PD_REGISTER_KERNEL( + gru_unit, GPU, ALL_LAYOUT, phi::GRUUnitKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/gru_unit_kernel_impl.h b/paddle/phi/kernels/impl/gru_unit_kernel_impl.h new file mode 100644 index 0000000000000..11fd1c5b5a637 --- /dev/null +++ b/paddle/phi/kernels/impl/gru_unit_kernel_impl.h @@ -0,0 +1,355 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/utils/optional.h" +namespace phi { + +enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 }; + +template +void ActCompute( + const int act_type, const Device& d, X x, Y y, phi::Place place) { + if (act_type == identity) { + y.device(d) = x; + } else if (act_type == sigmoid) { + phi::funcs::SigmoidFunctor()(d, x, y); + } else if (act_type == tanh) { + phi::funcs::TanhFunctor()(d, x, y); + } else if (act_type == relu) { + if (place == phi::CPUPlace()) + phi::funcs::ReluCPUFunctor()(d, x, y); + else + phi::funcs::ReluCUDAFunctor()(d, x, y); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported activation type, only supports identity, sigmoid, tanh " + "and relu.")); + } +} + +template +void GRUUnitKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& hidden_prev, + const DenseTensor& weight, + const paddle::optional& bias, + int activation, + int gate_activation, + bool origin_mode, + const DenseTensor* gate, + const DenseTensor* reset_hidden_prev, + const DenseTensor* hidden) { + auto* input_p = &input; + auto* hidden_prev_p = &hidden_prev; + auto* weight_p = &weight; + auto* bias_p = bias->get_ptr(); + + dev_ctx.template Alloc(gate); + dev_ctx.template Alloc(reset_hidden_prev); + dev_ctx.template Alloc(hidden); + + int batch_size = input_p->dims()[0]; + int frame_size = hidden_prev_p->dims()[1]; + + auto x = phi::EigenMatrix::From(input); + auto h_p = phi::EigenMatrix::From(hidden_prev); + auto g = phi::EigenMatrix::From(*gate); + auto r_h_p = phi::EigenMatrix::From(*reset_hidden_prev); + auto h = phi::EigenMatrix::From(*hidden); + auto& place = *dev_ctx.eigen_device(); + + // calculate unactivated gate outputs + if (bias) { + auto b = phi::EigenMatrix::From(bias.get()); + g.device(place) = + x + b.reshape(Eigen::array({{1, frame_size * 3}})) + .broadcast(Eigen::array({{batch_size, 1}})); + } else { + g.device(place) = x; + } + const T* hidden_prev_data = hidden_prev.data(); + const T* weight_data = weight.data(); + T* gate_data = gate->data(); + T* reset_hidden_prev_data = reset_hidden_prev->data(); + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.GEMM(false, + false, + batch_size, + 2 * frame_size, + frame_size, + 1, + hidden_prev_data, + frame_size, + weight_data, + frame_size * 2, + 1, + gate_data, + frame_size * 3); + + // calculate activated gate + Eigen::array extents{{batch_size, frame_size}}; + Eigen::array u_offsets{{0, 0}}; + ActCompute( + gate_activation, + place, + g.slice(u_offsets, extents), + g.slice(u_offsets, extents), + dev_ctx.GetPlace()); + auto u = g.slice(u_offsets, extents); // update gate + Eigen::array r_offsets{{0, frame_size}}; + ActCompute( + gate_activation, + place, + g.slice(r_offsets, extents), + g.slice(r_offsets, extents), + dev_ctx.GetPlace()); + auto r = g.slice(r_offsets, extents); // reset gate + r_h_p.device(place) = r * h_p; // reset previous hidden state + blas.GEMM(false, + false, + batch_size, + frame_size, + frame_size, + 1, + reset_hidden_prev_data, + frame_size, + weight_data + frame_size * frame_size * 2, + frame_size, + 1, + gate_data + frame_size * 2, + frame_size * 3); + + Eigen::array c_offsets{{0, frame_size * 2}}; + ActCompute( + activation, + place, + g.slice(c_offsets, extents), + g.slice(c_offsets, extents), + dev_ctx.GetPlace()); + auto c = g.slice(c_offsets, extents); // output candidate + + // calculate final output + if (origin_mode) { + h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p + } else { + h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p + } +} + +template +void ActGradCompute( + const int act_type, const Device& d, X x, Y y, DX dx, DY dy) { + // x is dummy and won't be used even in Relu(use y instead) + if (act_type == identity) + dx.device(d) = dy; + else if (act_type == sigmoid) + phi::funcs::SigmoidGradFunctor()(d, x, y, dy, dx); + else if (act_type == tanh) + phi::funcs::TanhGradFunctor()(d, x, y, dy, dx); + else if (act_type == relu) + phi::funcs::ReluGradFunctor()(d, x, y, dy, dx); + else + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported activation type, only supports identity, sigmoid, tanh " + "and relu.")); +} + +template +void GRUUnitGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& hidden_prev, + const DenseTensor& weight, + const paddle::optional& bias, + const DenseTensor& gate, + const DenseTensor& reset_hidden_prev, + const DenseTensor& hidden_grad, + int activation, + int gate_activation, + bool origin_mode, + const DenseTensor* input_grad, + const DenseTensor* hidden_prev_grad, + const DenseTensor* weight_grad, + const DenseTensor* bias_grad) { + phi::DenseTensor gate_grad; + phi::DenseTensor reset_hidden_prev_grad; + + const T* hidden_prev_data = hidden_prev.data(); + const T* weight_data = weight.data(); + gate_grad.Resize(input.dims()); + T* gate_grad_data = dev_ctx.template Alloc(&gate_grad); + const T* reset_hidden_prev_data = reset_hidden_prev.data(); + reset_hidden_prev_grad.Resize(reset_hidden_prev.dims()); + T* reset_hidden_prev_grad_data = + dev_ctx.template Alloc(&reset_hidden_prev_grad); + + auto h_p = phi::EigenMatrix::From(hidden_prev); + auto g = phi::EigenMatrix::From(gate); + auto d_h = phi::EigenMatrix::From(hidden_grad); + auto d_g = phi::EigenMatrix::From(gate_grad); + auto d_r_h_p = phi::EigenMatrix::From(reset_hidden_prev_grad); + auto& place = *dev_ctx.eigen_device(); + + int batch_size = input.dims()[0]; + int frame_size = hidden_prev.dims()[1]; + + Eigen::array extents{{batch_size, frame_size}}; + Eigen::array u_offsets{{0, 0}}; + auto u = g.slice(u_offsets, extents); // update gate + Eigen::array r_offsets{{0, frame_size}}; + auto r = g.slice(r_offsets, extents); // reset gate + Eigen::array c_offsets{{0, frame_size * 2}}; + auto c = g.slice(c_offsets, extents); // output candidate + + // backward for unactivated update gate + if (origin_mode) { + ActGradCompute(gate_activation, + place, + u, + u, + d_g.slice(u_offsets, extents), + d_h * (h_p - c)); + // backward for unactivated output candidate + ActGradCompute( + activation, place, c, c, d_g.slice(c_offsets, extents), d_h * (1 - u)); + } else { + ActGradCompute(gate_activation, + place, + u, + u, + d_g.slice(u_offsets, extents), + d_h * (c - h_p)); + // backward for unactivated output candidate + ActGradCompute( + activation, place, c, c, d_g.slice(c_offsets, extents), d_h * u); + } + // backward for reset_hidden_prev + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.GEMM(false, + true, + batch_size, + frame_size, + frame_size, + 1, + gate_grad_data + frame_size * 2, + frame_size * 3, + weight_data + frame_size * frame_size * 2, + frame_size, + 0, + reset_hidden_prev_grad_data, + frame_size); + // backward for unactivated reset gate + ActGradCompute(gate_activation, + place, + r, + r, + d_g.slice(r_offsets, extents), + d_r_h_p * h_p); + // backward for weight + if (weight_grad) { + T* weight_grad_data = dev_ctx.template Alloc(weight_grad); + // backward for state_weight + blas.GEMM(true, + false, + frame_size, + frame_size, + batch_size, + 1, + reset_hidden_prev_data, + frame_size, + gate_grad_data + frame_size * 2, + frame_size * 3, + 0, + weight_grad_data + frame_size * frame_size * 2, + frame_size); + + // backward for update_gate_weight and reset_gate_weight + blas.GEMM(true, + false, + frame_size, + frame_size * 2, + batch_size, + 1, + hidden_prev_data, + frame_size, + gate_grad_data, + frame_size * 3, + 0, + weight_grad_data, + frame_size * 2); + } + // backward for hidden_prev + if (hidden_prev_grad) { + T* hidden_prev_grad_data = dev_ctx.template Alloc(hidden_prev_grad); + auto d_h_p = phi::EigenMatrix::From(*hidden_prev_grad); + if (origin_mode) { + d_h_p.device(place) = d_r_h_p * r + d_h * u; + } else { + d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u); + } + blas.GEMM(false, + true, + batch_size, + frame_size, + frame_size * 2, + 1, + gate_grad_data, + frame_size * 3, + weight_data, + frame_size * 2, + 1, + hidden_prev_grad_data, + frame_size); + } + // backward for input + if (input_grad) { + dev_ctx.template Alloc(input_grad); + auto d_x = phi::EigenMatrix::From(*input_grad); + d_x.device(place) = d_g; + } + // backward for bias + if (bias_grad) { + dev_ctx.template Alloc(bias_grad); + auto d_b = phi::EigenVector::Flatten(*bias_grad); + d_b.device(place) = d_g.sum(Eigen::array({{0}})); + } +} +} // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 01b0c1025b8e9..6447419870edd 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -1167,6 +1167,21 @@ optional: scale, bias inplace : (y_grad -> x_grad) +- backward_op : gru_unit_grad + forward: gru_unit (Tensor input, Tensor hidden_prev, Tensor weight, Tensor bias, int activation + = 2, int gate_activation = 1, bool origin_mode = false) -> Tensor (gate), Tensor (reset_hidden_prev), Tensor (hidden) + args: (Tensor input, Tensor hidden_prev, Tensor weight, Tensor bias, Tensor gate, Tensor reset_hidden_prev, Tensor hidden_grad, + int activation, int gate_activation, bool origin_mode) + output: Tensor (input_grad), Tensor (hidden_prev_grad), Tensor (weight_grad), Tensor (bias_grad) + infer_meta: + func: GruUnitGradInferMeta + param : [input, hidden_prev, weight, bias] + kernel: + func: gru_unit_grad + data_type: hidden_grad + optional: bias + no_need_buffer: bias + - backward_op : gumbel_softmax_grad forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out) args : (Tensor out, Tensor out_grad, int axis) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index 3435ea6c46789..4f156d73d0b3b 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -4069,6 +4069,13 @@ outputs : out : Out +- op: gru_unit + backward: gru_unit_grad + inputs: + {input : Input, hidden_prev : HiddenPrev, weight : Weight, bias : Bias} + outputs: + {gate : Gate, reset_hidden_prev : ResetHiddenPrev, hidden : Hidden} + - op: identity_loss inputs : x: X diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index a84f7f337af37..8c0ff67ff2d28 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1548,6 +1548,18 @@ backward : group_norm_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : gru_unit + args: (Tensor input, Tensor hidden_prev, Tensor weight, Tensor bias, int activation + = 2, int gate_activation = 1, bool origin_mode = false) + output: Tensor (gate), Tensor (reset_hidden_prev), Tensor (hidden) + infer_meta: + func: GruUnitInferMeta + kernel: + func: gru_unit + optional: bias + intermediate: gate, reset_hidden_prev + backward: gru_unit_grad + - op : gumbel_softmax args : (Tensor x, float temperature = 1.0, bool hard = false, int axis = -1) output : Tensor From d59cb7544a299fd607155234c88001cccd59f04f Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 26 May 2024 07:19:13 +0800 Subject: [PATCH 02/12] Fix --- paddle/phi/infermeta/backward.cc | 4 ++++ paddle/phi/kernels/impl/gru_unit_kernel_impl.h | 14 +++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index aa5cc3762bf6d..0a21dbf73bb4e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -649,17 +649,21 @@ void GruUnitGradInferMeta(const MetaTensor& input, frame_size * 3)); if (bias_grad->initialized()) { bias_grad->set_dims(bias_dims); + bias_grad->set_dtype(bias.dtype()); } } if (input_grad->initialized()) { input_grad->set_dims(input_dims); + input_grad->set_dtype(input.dtype()); } if (hidden_prev_grad->initialized()) { hidden_prev_grad->set_dims(hidden_prev_dims); + hidden_prev_grad->set_dtype(hidden_prev.dtype()); } if (weight_grad->initialized()) { weight_grad->set_dims(weight_dims); + weight_grad->set_dtype(weight.dtype()); } } diff --git a/paddle/phi/kernels/impl/gru_unit_kernel_impl.h b/paddle/phi/kernels/impl/gru_unit_kernel_impl.h index 11fd1c5b5a637..7ca738ef1fdab 100644 --- a/paddle/phi/kernels/impl/gru_unit_kernel_impl.h +++ b/paddle/phi/kernels/impl/gru_unit_kernel_impl.h @@ -55,9 +55,9 @@ void GRUUnitKernel(const Context& dev_ctx, int activation, int gate_activation, bool origin_mode, - const DenseTensor* gate, - const DenseTensor* reset_hidden_prev, - const DenseTensor* hidden) { + DenseTensor* gate, + DenseTensor* reset_hidden_prev, + DenseTensor* hidden) { auto* input_p = &input; auto* hidden_prev_p = &hidden_prev; auto* weight_p = &weight; @@ -185,10 +185,10 @@ void GRUUnitGradKernel(const Context& dev_ctx, int activation, int gate_activation, bool origin_mode, - const DenseTensor* input_grad, - const DenseTensor* hidden_prev_grad, - const DenseTensor* weight_grad, - const DenseTensor* bias_grad) { + DenseTensor* input_grad, + DenseTensor* hidden_prev_grad, + DenseTensor* weight_grad, + DenseTensor* bias_grad) { phi::DenseTensor gate_grad; phi::DenseTensor reset_hidden_prev_grad; From 8437a795a955fa5b281ac1a21264b40f4642c53c Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 26 May 2024 12:00:11 +0800 Subject: [PATCH 03/12] Fix --- .../phi/kernels/impl/gru_unit_kernel_impl.h | 108 ++++++++---------- 1 file changed, 47 insertions(+), 61 deletions(-) diff --git a/paddle/phi/kernels/impl/gru_unit_kernel_impl.h b/paddle/phi/kernels/impl/gru_unit_kernel_impl.h index 7ca738ef1fdab..a36766bf71246 100644 --- a/paddle/phi/kernels/impl/gru_unit_kernel_impl.h +++ b/paddle/phi/kernels/impl/gru_unit_kernel_impl.h @@ -25,7 +25,7 @@ namespace phi { enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 }; -template +template void ActCompute( const int act_type, const Device& d, X x, Y y, phi::Place place) { if (act_type == identity) { @@ -46,6 +46,8 @@ void ActCompute( } } +#define ACT_COMPUTE ActCompute + template void GRUUnitKernel(const Context& dev_ctx, const DenseTensor& input, @@ -61,7 +63,7 @@ void GRUUnitKernel(const Context& dev_ctx, auto* input_p = &input; auto* hidden_prev_p = &hidden_prev; auto* weight_p = &weight; - auto* bias_p = bias->get_ptr(); + auto* bias_p = bias.get_ptr(); dev_ctx.template Alloc(gate); dev_ctx.template Alloc(reset_hidden_prev); @@ -108,20 +110,18 @@ void GRUUnitKernel(const Context& dev_ctx, // calculate activated gate Eigen::array extents{{batch_size, frame_size}}; Eigen::array u_offsets{{0, 0}}; - ActCompute( - gate_activation, - place, - g.slice(u_offsets, extents), - g.slice(u_offsets, extents), - dev_ctx.GetPlace()); + ACT_COMPUTE(gate_activation, + place, + g.slice(u_offsets, extents), + g.slice(u_offsets, extents), + dev_ctx.GetPlace()); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets{{0, frame_size}}; - ActCompute( - gate_activation, - place, - g.slice(r_offsets, extents), - g.slice(r_offsets, extents), - dev_ctx.GetPlace()); + ACT_COMPUTE(gate_activation, + place, + g.slice(r_offsets, extents), + g.slice(r_offsets, extents), + dev_ctx.GetPlace()); auto r = g.slice(r_offsets, extents); // reset gate r_h_p.device(place) = r * h_p; // reset previous hidden state blas.GEMM(false, @@ -139,12 +139,11 @@ void GRUUnitKernel(const Context& dev_ctx, frame_size * 3); Eigen::array c_offsets{{0, frame_size * 2}}; - ActCompute( - activation, - place, - g.slice(c_offsets, extents), - g.slice(c_offsets, extents), - dev_ctx.GetPlace()); + ACT_COMPUTE(activation, + place, + g.slice(c_offsets, extents), + g.slice(c_offsets, extents), + dev_ctx.GetPlace()); auto c = g.slice(c_offsets, extents); // output candidate // calculate final output @@ -155,7 +154,12 @@ void GRUUnitKernel(const Context& dev_ctx, } } -template +template void ActGradCompute( const int act_type, const Device& d, X x, Y y, DX dx, DY dy) { // x is dummy and won't be used even in Relu(use y instead) @@ -173,6 +177,8 @@ void ActGradCompute( "and relu.")); } +#define ACT_GRAD_COMPUTE ActGradCompute + template void GRUUnitGradKernel(const Context& dev_ctx, const DenseTensor& input, @@ -221,40 +227,24 @@ void GRUUnitGradKernel(const Context& dev_ctx, // backward for unactivated update gate if (origin_mode) { - ActGradCompute(gate_activation, - place, - u, - u, - d_g.slice(u_offsets, extents), - d_h * (h_p - c)); + ACT_GRAD_COMPUTE(gate_activation, + place, + u, + u, + d_g.slice(u_offsets, extents), + d_h * (h_p - c)); // backward for unactivated output candidate - ActGradCompute( + ACT_GRAD_COMPUTE( activation, place, c, c, d_g.slice(c_offsets, extents), d_h * (1 - u)); } else { - ActGradCompute(gate_activation, - place, - u, - u, - d_g.slice(u_offsets, extents), - d_h * (c - h_p)); + ACT_GRAD_COMPUTE(gate_activation, + place, + u, + u, + d_g.slice(u_offsets, extents), + d_h * (c - h_p)); // backward for unactivated output candidate - ActGradCompute( + ACT_GRAD_COMPUTE( activation, place, c, c, d_g.slice(c_offsets, extents), d_h * u); } // backward for reset_hidden_prev @@ -273,16 +263,12 @@ void GRUUnitGradKernel(const Context& dev_ctx, reset_hidden_prev_grad_data, frame_size); // backward for unactivated reset gate - ActGradCompute(gate_activation, - place, - r, - r, - d_g.slice(r_offsets, extents), - d_r_h_p * h_p); + ACT_GRAD_COMPUTE(gate_activation, + place, + r, + r, + d_g.slice(r_offsets, extents), + d_r_h_p * h_p); // backward for weight if (weight_grad) { T* weight_grad_data = dev_ctx.template Alloc(weight_grad); From 49002a56842f0cfdf2df12eb7b8952aca7fac1da Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 26 May 2024 14:00:00 +0800 Subject: [PATCH 04/12] Fix --- paddle/phi/kernels/impl/gru_unit_kernel_impl.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/phi/kernels/impl/gru_unit_kernel_impl.h b/paddle/phi/kernels/impl/gru_unit_kernel_impl.h index a36766bf71246..641ad0d5a27fc 100644 --- a/paddle/phi/kernels/impl/gru_unit_kernel_impl.h +++ b/paddle/phi/kernels/impl/gru_unit_kernel_impl.h @@ -62,8 +62,6 @@ void GRUUnitKernel(const Context& dev_ctx, DenseTensor* hidden) { auto* input_p = &input; auto* hidden_prev_p = &hidden_prev; - auto* weight_p = &weight; - auto* bias_p = bias.get_ptr(); dev_ctx.template Alloc(gate); dev_ctx.template Alloc(reset_hidden_prev); From 78f21de10f00d2f622e8ed1434e7311d1ac1dfe5 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 26 May 2024 14:07:41 +0800 Subject: [PATCH 05/12] Fix --- paddle/phi/infermeta/multiary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 4eb516bb76c8e..9988e0c1c5ce5 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -507,7 +507,7 @@ void GraphSampleNeighborsInferMeta(const MetaTensor& row, void GruUnitInferMeta(const MetaTensor& input, const MetaTensor& hidden_prev, const MetaTensor& weight, - const paddle::optional& bias, + const MetaTensor& bias, int activation, int gate_activation, bool origin_mode, From d4b2fea5b41b5d8eaf048a1307118b5b34cbf07e Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 26 May 2024 15:15:41 +0800 Subject: [PATCH 06/12] Fix --- paddle/fluid/operators/gru_op.cc | 594 ---------------------- paddle/fluid/operators/gru_op.cu.cc | 140 ----- paddle/fluid/operators/gru_op.h | 186 ------- paddle/phi/infermeta/backward.cc | 95 ++++ paddle/phi/infermeta/backward.h | 10 + paddle/phi/infermeta/multiary.cc | 85 ++++ paddle/phi/infermeta/multiary.h | 15 + paddle/phi/kernels/cpu/gru_grad_kernel.cc | 19 + paddle/phi/kernels/cpu/gru_kernel.cc | 238 +++++++++ paddle/phi/kernels/gpu/gru_grad_kernel.cu | 19 + paddle/phi/kernels/gpu/gru_kernel.cu | 122 +++++ paddle/phi/kernels/impl/gru_kernel_impl.h | 173 +++++++ paddle/phi/ops/yaml/backward.yaml | 19 + paddle/phi/ops/yaml/op_compat.yaml | 7 + paddle/phi/ops/yaml/ops.yaml | 14 + 15 files changed, 816 insertions(+), 920 deletions(-) delete mode 100644 paddle/fluid/operators/gru_op.cc delete mode 100644 paddle/fluid/operators/gru_op.cu.cc delete mode 100644 paddle/fluid/operators/gru_op.h create mode 100644 paddle/phi/kernels/cpu/gru_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/gru_kernel.cc create mode 100644 paddle/phi/kernels/gpu/gru_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/gru_kernel.cu create mode 100644 paddle/phi/kernels/impl/gru_kernel_impl.h diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc deleted file mode 100644 index c948315189a15..0000000000000 --- a/paddle/fluid/operators/gru_op.cc +++ /dev/null @@ -1,594 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/gru_op.h" - -#include -#include - -#include "paddle/common/flags.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h" -#include "paddle/phi/kernels/funcs/detail/gru_kernel.h" - -COMMON_DECLARE_int32(paddle_num_threads); - -namespace paddle { -namespace operators { - -class GRUOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU"); - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU"); - OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU"); - bool is_test = ctx->Attrs().Get("is_test"); - if (!is_test) { - OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU"); - OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), - "Output", - "BatchResetHiddenPrev", - "GRU"); - OP_INOUT_CHECK( - ctx->HasOutput("BatchHidden"), "Output", "BatchHidden", "GRU"); - } - auto input_dims = ctx->GetInputDim("Input"); - auto weight_dims = ctx->GetInputDim("Weight"); - int input_size = static_cast(input_dims[1]); - int frame_size = static_cast(weight_dims[0]); - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(input_size, - frame_size * 3, - phi::errors::InvalidArgument( - "The second dimension of Input(Input) must be 3 " - "times of frame_size in GRUOp, but received %d " - "(Input) vs %d (frame_size).", - input_size, - frame_size)); - } - PADDLE_ENFORCE_EQ( - weight_dims[1], - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", - weight_dims[0], - weight_dims[1], - frame_size, - frame_size * 3)); - if (ctx->HasInput("H0")) { - auto h0_dims = ctx->GetInputDim("H0"); - PADDLE_ENFORCE_EQ( - h0_dims[1], - frame_size, - phi::errors::InvalidArgument( - "The width of Input(H0) must be equal to frame_size, but " - "received %d (width of H0) vs %d (frame_size).", - h0_dims[1], - frame_size)); - } - if (ctx->HasInput("Bias")) { - auto bias_dims = ctx->GetInputDim("Bias"); - int bias_height = static_cast(bias_dims[0]); - int bias_width = static_cast(bias_dims[1]); - PADDLE_ENFORCE_EQ( - bias_height, - 1, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - bias_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - } - if (!is_test) { - ctx->SetOutputDim("BatchGate", input_dims); - ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size}); - ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size}); - } - ctx->SetOutputDim("Hidden", {input_dims[0], frame_size}); - ctx->ShareLoD("Input", "Hidden"); - } -}; - -class GRUOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "Input", - "(phi::DenseTensor) The first input is a LodTensor, which supports " - "variable-time length input sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T X 3D), where, T is " - "the total time steps in this mini-batch, D is the hidden size."); - AddInput("H0", - "(Tensor, optional) The initial hidden state is an optional " - "input. This is a tensor with shape (N x D), where N is the " - "batch size, D is the hidden size.") - .AsDispensable(); - AddInput( - "Weight", - "(Tensor) The learnable hidden-hidden weight matrix with shape " - "(D x 3D), where D is the hidden size. The elements continuous in " - "memory can be divided into two parts. The first part are weights of " - "the update gate and reset gate with shape (D x 2D), and the second " - "part are weights of output candidate with shape (D x D)."); - AddInput("Bias", - "(Tensor, optional) Bias vector with shape (1 x 3D) concatenating " - "bias of the update gate, reset gate and output candidate.") - .AsDispensable(); - AddOutput( - "BatchGate", - "(phi::DenseTensor) To compute with batches, sequence data will be " - "reorganized into several successive batches each containing " - "data from the same time step. The phi::DenseTensor BatchGate contains " - "the update gate, reset gate and output candidate values " - "organized in batches. The LoD size is 2. The first LoD contains " - "the batch offsets and the second LoD contains the indexes in " - "the raw sequence data.") - .AsIntermediate() - .AsExtra(); - AddOutput("BatchResetHiddenPrev", - "(phi::DenseTensor) The reset hidden state phi::DenseTensor " - "organized in batches. " - "This phi::DenseTensor is a matrix with shape (T X D) and has " - "the same LoD " - "with `BatchGate`.") - .AsIntermediate() - .AsExtra(); - AddOutput("BatchHidden", - "(phi::DenseTensor) The hidden state phi::DenseTensor organized " - "in batches. " - "This phi::DenseTensor is a matrix with shape (T X D) and has " - "the same LoD " - "with `BatchGate`.") - .AsIntermediate() - .AsExtra(); - AddOutput("Hidden", - "(phi::DenseTensor) the hidden state phi::DenseTensor organized " - "in sequences. " - "This phi::DenseTensor is a matrix with shape (T X D) and has " - "the same LoD with `BatchGate`."); - AddAttr("activation", - "(string, default tanh) " - "The activation type used for output candidate {h}_t.") - .SetDefault("tanh"); - AddAttr( - "gate_activation", - "(string, default sigmoid) " - "The activation type used in update gate and reset gate.") - .SetDefault("sigmoid"); - AddAttr("is_reverse", - "(bool, default: False) " - "whether to compute reversed GRU.") - .SetDefault(false); - AddAttr("origin_mode", - "bool" - "use origin mode in article https://arxiv.org/abs/1412.3555") - .SetDefault(false); - AddComment(R"DOC( -GRU Operator implements part calculations of the complete GRU as following: - -$$ -update\_gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\ -reset\_gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\ -output\_candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\ -output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t) -$$ - -@note To implement the complete GRU, fully-connected operator must be used -before to feed xu, xr and xc as the Input of GRU operator. -)DOC"); - } -}; - -class GRUGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU@Grad"); - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU@Grad"); - OP_INOUT_CHECK( - ctx->HasInput("BatchGate"), "Input", "BatchGate", "GRU@Grad"); - OP_INOUT_CHECK(ctx->HasInput("BatchResetHiddenPrev"), - "Input", - "BatchResetHiddenPrev", - "GRU@Grad"); - OP_INOUT_CHECK( - ctx->HasInput("BatchHidden"), "Input", "BatchHidden", "GRU@Grad"); - OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "GRU@Grad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), - "Input", - framework::GradVarName("Hidden"), - "GRU@Grad"); - - auto input_dims = ctx->GetInputDim("Input"); - auto weight_dims = ctx->GetInputDim("Weight"); - int input_size = static_cast(input_dims[1]); - int frame_size = static_cast(weight_dims[0]); - int weight_height = static_cast(weight_dims[0]); - int weight_width = static_cast(weight_dims[1]); - PADDLE_ENFORCE_EQ( - input_size, - frame_size * 3, - phi::errors::InvalidArgument( - "The second dimension of Input(Input) must be 3 times of " - "frame_size in GRUOp, but received %d (Input) vs %d (frame_size).", - input_size, - frame_size)); - PADDLE_ENFORCE_EQ( - weight_height, - frame_size, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", - weight_height, - weight_width, - frame_size, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - weight_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Input(Weight) matrix must be [frame_size, frame_size " - "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", - weight_height, - weight_width, - frame_size, - frame_size * 3)); - if (ctx->HasInput("H0")) { - auto h0_dims = ctx->GetInputDim("H0"); - PADDLE_ENFORCE_EQ( - h0_dims[1], - frame_size, - phi::errors::InvalidArgument( - "The width of Input(H0) must be equal to frame_size, but " - "received %d (width of H0) vs %d (frame_size).", - h0_dims[1], - frame_size)); - auto h0_grad_name = framework::GradVarName("H0"); - if (ctx->HasOutput(h0_grad_name)) - ctx->SetOutputDim(h0_grad_name, h0_dims); - } - if (ctx->HasInput("Bias")) { - auto bias_dims = ctx->GetInputDim("Bias"); - int bias_height = static_cast(bias_dims[0]); - int bias_width = static_cast(bias_dims[1]); - PADDLE_ENFORCE_EQ( - bias_height, - 1, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - PADDLE_ENFORCE_EQ( - bias_width, - frame_size * 3, - phi::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but received " - "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", - bias_height, - bias_width, - frame_size * 3)); - auto bias_grad_name = framework::GradVarName("Bias"); - if (ctx->HasOutput(bias_grad_name)) - ctx->SetOutputDim(bias_grad_name, bias_dims); - } - auto input_grad_name = framework::GradVarName("Input"); - if (ctx->HasOutput(input_grad_name)) - ctx->SetOutputDim(input_grad_name, input_dims); - auto weight_grad_name = framework::GradVarName("Weight"); - if (ctx->HasOutput(weight_grad_name)) - ctx->SetOutputDim(weight_grad_name, weight_dims); - } - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Hidden")), - ctx.device_context().GetPlace()); - } -}; - -template -class GRUCPUKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - using LodTensorPtr = phi::DenseTensor*; - bool is_test = context.Attr("is_test"); - - bool origin_mode = context.Attr("origin_mode"); - auto* input = context.Input("Input"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* bias = context.Input("Bias"); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - auto input_dims = input->dims(); - auto hidden_dims = hidden->dims(); - - LodTensorPtr batch_gate = nullptr, batch_reset_hidden_prev = nullptr, - batch_hidden = nullptr; - phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, - batch_hidden_tmp; - if (is_test) { - batch_gate = &batch_gate_tmp; - batch_gate->Resize(input_dims); - - batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; - batch_reset_hidden_prev->Resize(hidden_dims); - - batch_hidden = &batch_hidden_tmp; - batch_hidden->Resize(hidden_dims); - } else { - batch_gate = context.Output("BatchGate"); - batch_hidden = context.Output("BatchHidden"); - batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - } - batch_gate->mutable_data(context.GetPlace()); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - batch_hidden->mutable_data(context.GetPlace()); - - bool is_reverse = context.Attr("is_reverse"); - phi::funcs::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = context.template device_context(); - to_batch(dev_ctx, *input, batch_gate, true, is_reverse); - - if (bias) { - phi::funcs::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, *bias, batch_gate); - } - - int frame_size = static_cast(hidden_dims[1]); - phi::funcs::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - phi::DenseTensor ordered_h0; - - phi::Vector order(batch_gate->lod()[2]); - - if (h0) { - // Since the batch computing for GRU reorders the input sequences - // according to their length. The initialized cell state also needs - // to reorder. - ReorderInitState( - context.template device_context(), - *h0, - order, - &ordered_h0, - true); - gru_value.prev_out_value = ordered_h0.data(); - } else { - gru_value.prev_out_value = nullptr; - } - auto batch_starts = batch_gate->lod()[0]; - size_t seq_len = batch_starts.size() - 1; - auto active_node = phi::funcs::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = phi::funcs::detail::GetActivationType( - context.Attr("gate_activation")); - -#ifdef PADDLE_WITH_MKLML - // use MKL packed to speedup GEMM - if (FLAGS_paddle_num_threads >= 4) { - auto blas = phi::funcs::GetBlas(dev_ctx); - T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, - 1 /*height of C*/, - frame_size * 2 /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE_NOT_NULL( - packed_gate, - phi::errors::NotFound( - "The calculation result of packed_gate by " - "GEMM_ALLOC should not be null when using MKL.")); - blas.GEMM_PACK(CblasBMatrix, - CblasNoTrans, - 1 /*cur bs?*/, - frame_size * 2, - frame_size, - T(1.0), - gru_value.gate_weight, - frame_size * 2, - packed_gate); - T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, - 1 /*height of C*/, - frame_size /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE_NOT_NULL( - packed_state, - phi::errors::NotFound( - "The calculation result of packed_state by " - "GEMM_ALLOC should not be null when using MKL.")); - blas.GEMM_PACK(CblasBMatrix, - CblasNoTrans, - 1 /*cur bs?*/, - frame_size, - frame_size, - T(1.0), - gru_value.state_weight, - frame_size, - packed_state); - for (size_t n = 0; n < seq_len; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); - phi::DenseTensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, - CblasPacked, - cur_batch_size, - frame_size * 2, - frame_size, - gru_value.prev_out_value, - frame_size, - packed_gate, - frame_size * 2, - T(1), - gru_value.gate_value, - frame_size * 3); - } - - phi::funcs::detail::forward_reset_output( - phi::funcs::detail::forward::gru_resetOutput(), - gru_value, - frame_size, - cur_batch_size, - active_gate); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, - CblasPacked, - cur_batch_size, - frame_size, - frame_size, - gru_value.reset_output_value, - frame_size, - packed_state, - frame_size, - T(1), - gru_value.gate_value + frame_size * 2, - frame_size * 3); - } - - phi::funcs::detail::forward_final_output( - phi::funcs::detail::forward::gru_finalOutput(), - gru_value, - frame_size, - cur_batch_size, - active_node, - origin_mode); - - gru_value.prev_out_value = gru_value.output_value; - } - - blas.GEMM_FREE(packed_gate); - blas.GEMM_FREE(packed_state); - } else { -#endif - for (size_t n = 0; n < seq_len; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); - phi::DenseTensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - - phi::funcs::GRUUnitFunctor::compute( - dev_ctx, // NOLINT - gru_value, - frame_size, - cur_batch_size, - active_node, - active_gate, - origin_mode); - - gru_value.prev_out_value = gru_value.output_value; - } -#ifdef PADDLE_WITH_MKLML - } -#endif - phi::funcs::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batch_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden); - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - -template -class GRUGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("gru_grad"); - grad_op->SetInput("Input", this->Input("Input")); - grad_op->SetInput("H0", this->Input("H0")); - grad_op->SetInput("Bias", this->Input("Bias")); - grad_op->SetInput("Weight", this->Input("Weight")); - - grad_op->SetInput("BatchGate", this->Output("BatchGate")); - grad_op->SetInput("BatchResetHiddenPrev", - this->Output("BatchResetHiddenPrev")); - grad_op->SetInput("BatchHidden", this->Output("BatchHidden")); - grad_op->SetInput("Hidden", this->Output("Hidden")); - - grad_op->SetInput(framework::GradVarName("Hidden"), - this->OutputGrad("Hidden")); - - grad_op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0")); - grad_op->SetOutput(framework::GradVarName("Input"), - this->InputGrad("Input")); - grad_op->SetOutput(framework::GradVarName("Weight"), - this->InputGrad("Weight")); - grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); - - grad_op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUGradOpNoNeedBufferVarInferer, - "Input", - "Bias"); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(gru, - ops::GRUOp, - ops::GRUOpMaker, - ops::GRUGradOpMaker, - ops::GRUGradOpMaker); -REGISTER_OPERATOR(gru_grad, - ops::GRUGradOp, - ops::GRUGradOpNoNeedBufferVarInferer); - -PD_REGISTER_STRUCT_KERNEL( - gru, CPU, ALL_LAYOUT, ops::GRUCPUKernel, float, double) {} -PD_REGISTER_STRUCT_KERNEL( - gru_grad, CPU, ALL_LAYOUT, ops::GRUGradKernel, float, double) {} diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc deleted file mode 100644 index f7b4317832b0e..0000000000000 --- a/paddle/fluid/operators/gru_op.cu.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/gru_op.h" - -namespace paddle { -namespace operators { - -template -class GRUKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - using LodTensorPtr = phi::DenseTensor*; - - bool is_test = context.Attr("is_test"); - bool origin_mode = context.Attr("origin_mode"); - auto* input = context.Input("Input"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* bias = context.Input("Bias"); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - auto input_dims = input->dims(); - auto hidden_dims = hidden->dims(); - - LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden; - phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, - batch_hidden_tmp; - if (is_test) { - batch_gate = &batch_gate_tmp; - batch_gate->Resize(input_dims); - - batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; - batch_reset_hidden_prev->Resize(hidden_dims); - - batch_hidden = &batch_hidden_tmp; - batch_hidden->Resize(hidden_dims); - } else { - batch_gate = context.Output("BatchGate"); - batch_hidden = context.Output("BatchHidden"); - batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - } - batch_gate->mutable_data(context.GetPlace()); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - batch_hidden->mutable_data(context.GetPlace()); - - bool is_reverse = context.Attr("is_reverse"); - phi::funcs::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = context.template device_context(); - to_batch(dev_ctx, *input, batch_gate, true, is_reverse); - - if (bias) { - phi::funcs::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, *bias, batch_gate); - } - - int frame_size = hidden_dims[1]; - phi::funcs::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - phi::DenseTensor ordered_h0; - - phi::Vector order(batch_gate->lod()[2]); - - if (h0) { - // Since the batch computing for GRU reorders the input sequences - // according to their length. The initialized cell state also needs - // to reorder. - ReorderInitState( - context.template device_context(), - *h0, - order, - &ordered_h0, - true); - gru_value.prev_out_value = ordered_h0.data(); - } else { - gru_value.prev_out_value = nullptr; - } - auto batch_starts = batch_gate->lod()[0]; - size_t num_batch = batch_starts.size() - 1; - auto active_node = phi::funcs::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = phi::funcs::detail::GetActivationType( - context.Attr("gate_activation")); - for (size_t n = 0; n < num_batch; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); - phi::DenseTensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - phi::funcs::GRUUnitFunctor::compute(dev_ctx, // NOLINT - gru_value, - frame_size, - cur_batch_size, - active_node, - active_gate, - origin_mode); - gru_value.prev_out_value = gru_value.output_value; - } - - phi::funcs::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batch_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden); - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(gru, GPU, ALL_LAYOUT, ops::GRUKernel, float, double) { -} -PD_REGISTER_STRUCT_KERNEL( - gru_grad, GPU, ALL_LAYOUT, ops::GRUGradKernel, float, double) {} diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h deleted file mode 100644 index 773e9ff510852..0000000000000 --- a/paddle/fluid/operators/gru_op.h +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/detail/activation_functions.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/gru_compute.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/sequence2batch.h" - -namespace paddle { -namespace operators { - -template -inline void ReorderInitState(const DeviceContext& ctx, - const phi::DenseTensor& src, - phi::Vector index_lod, - phi::DenseTensor* dst, - bool indexed_src) { - phi::funcs::CopyMatrixRowsFunctor row_shuffle; - dst->mutable_data(src.dims(), ctx.GetPlace()); - row_shuffle(ctx, src, index_lod, dst, indexed_src); -} - -template -class GRUGradKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - bool origin_mode = context.Attr("origin_mode"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* batch_gate = context.Input("BatchGate"); - auto* batch_reset_hidden_prev = - context.Input("BatchResetHiddenPrev"); - auto* batch_hidden = context.Input("BatchHidden"); - auto* hidden = context.Input("Hidden"); - auto* hidden_grad = - context.Input(framework::GradVarName("Hidden")); - auto* input_grad = - context.Output(framework::GradVarName("Input")); - auto* h0_grad = - context.Output(framework::GradVarName("H0")); - auto* weight_grad = - context.Output(framework::GradVarName("Weight")); - auto* bias_grad = - context.Output(framework::GradVarName("Bias")); - - auto gate_dims = batch_gate->dims(); - auto hidden_dims = hidden->dims(); - int frame_size = hidden_dims[1]; - - phi::funcs::LoDTensor2BatchFunctor to_batch; - phi::DenseTensor batch_hidden_grad, batch_gate_grad, - batch_reset_hidden_prev_grad; - batch_hidden_grad.mutable_data(hidden_dims, context.GetPlace()); - batch_gate_grad.mutable_data(gate_dims, context.GetPlace()); - batch_reset_hidden_prev_grad.mutable_data(hidden_dims, - context.GetPlace()); - phi::funcs::SetConstant zero; - auto& dev_ctx = context.template device_context(); - zero(dev_ctx, &batch_hidden_grad, static_cast(0.0)); - zero(dev_ctx, &batch_gate_grad, static_cast(0.0)); - zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast(0.0)); - - phi::DenseTensor ordered_h0, ordered_h0_grad; - - phi::Vector order(batch_gate->lod()[2]); - - if (h0) { - ReorderInitState( - dev_ctx, *h0, order, &ordered_h0, true); - } - if (h0_grad) { - ordered_h0_grad.mutable_data(h0_grad->dims(), context.GetPlace()); - zero(context.template device_context(), - &ordered_h0_grad, - static_cast(0.0)); - } - - bool is_reverse = context.Attr("is_reverse"); - batch_hidden_grad.set_lod(batch_hidden->lod()); - to_batch(dev_ctx, *hidden_grad, &batch_hidden_grad, false, is_reverse); - - phi::funcs::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - - phi::funcs::GRUMetaGrad gru_grad; - if (weight_grad) { - gru_grad.gate_weight_grad = - weight_grad->mutable_data(context.GetPlace()); - zero(dev_ctx, weight_grad, static_cast(0.0)); - gru_grad.state_weight_grad = - weight_grad->data() + 2 * frame_size * frame_size; - } else { - gru_grad.gate_weight_grad = nullptr; - gru_grad.state_weight_grad = nullptr; - } - - auto batch_starts = batch_hidden_grad.lod()[0]; - size_t num_batch = batch_starts.size() - 1; - auto active_node = phi::funcs::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = phi::funcs::detail::GetActivationType( - context.Attr("gate_activation")); - for (int n = static_cast(num_batch) - 1; n >= 0; n--) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); - gru_value.gate_value = gate_t.data(); - phi::DenseTensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - - phi::DenseTensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend); - gru_grad.output_grad = hidden_grad_t.data(); - phi::DenseTensor gate_grad_t = batch_gate_grad.Slice(bstart, bend); - gru_grad.gate_grad = gate_grad_t.data(); - phi::DenseTensor reset_hidden_prev_grad_t = - batch_reset_hidden_prev_grad.Slice(bstart, bend); - gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data(); - if (n == 0) { - gru_value.prev_out_value = h0 ? ordered_h0.data() : nullptr; - gru_grad.prev_out_grad = - h0 && h0_grad ? ordered_h0_grad.data() : nullptr; - } else { - int bstart_pre = static_cast(batch_starts[n - 1]); - phi::DenseTensor hidden_prev_t = - batch_hidden->Slice(bstart_pre, bstart); - gru_value.prev_out_value = hidden_prev_t.data(); - phi::DenseTensor hidden_prev_grad_t = - batch_hidden_grad.Slice(bstart_pre, bstart); - gru_grad.prev_out_grad = hidden_prev_grad_t.data(); - } - gru_value.output_value = nullptr; - phi::funcs::GRUUnitGradFunctor::compute(dev_ctx, - gru_value, - gru_grad, - frame_size, - cur_batch_size, - active_node, - active_gate, - origin_mode); - } - if (input_grad) { - input_grad->mutable_data(context.GetPlace()); - phi::funcs::Batch2LoDTensorFunctor to_seq; - batch_gate_grad.set_lod(batch_gate->lod()); - to_seq(dev_ctx, batch_gate_grad, input_grad); - } - if (bias_grad) { - bias_grad->mutable_data(context.GetPlace()); - phi::funcs::ColwiseSum col_sum; - col_sum(dev_ctx, batch_gate_grad, bias_grad); - } - if (h0 && h0_grad) { - ReorderInitState( - dev_ctx, ordered_h0_grad, order, h0_grad, false); - } - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index c6c145019f40b..4fdf26168ea49 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -574,6 +574,101 @@ void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) { } } +void GruGradInferMeta(const MetaTensor& input, + const MetaTensor& h0, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* input_grad, + MetaTensor* h0_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad, + MetaConfig config) { + const auto& input_dims = input.dims(); + const auto& weight_dims = weight.dims(); + int input_size = static_cast(input_dims[1]); + int frame_size = static_cast(weight_dims[0]); + int weight_height = static_cast(weight_dims[0]); + int weight_width = static_cast(weight_dims[1]); + PADDLE_ENFORCE_EQ( + input_size, + frame_size * 3, + phi::errors::InvalidArgument( + "The second dimension of Input(Input) must be 3 times of " + "frame_size in GRUOp, but received %d (Input) vs %d (frame_size).", + input_size, + frame_size)); + PADDLE_ENFORCE_EQ( + weight_height, + frame_size, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", + weight_height, + weight_width, + frame_size, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + weight_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", + weight_height, + weight_width, + frame_size, + frame_size * 3)); + if (h0.initialized()) { + const auto& h0_dims = h0.dims(); + PADDLE_ENFORCE_EQ( + h0_dims[1], + frame_size, + phi::errors::InvalidArgument( + "The width of Input(H0) must be equal to frame_size, but " + "received %d (width of H0) vs %d (frame_size).", + h0_dims[1], + frame_size)); + if (h0_grad->initialized()) { + h0_grad->set_dims(h0_dims); + h0_grad->set_dtype(h0.dtype()); + } + } + if (bias.initialized()) { + const auto& bias_dims = bias.dims(); + int bias_height = static_cast(bias_dims[0]); + int bias_width = static_cast(bias_dims[1]); + PADDLE_ENFORCE_EQ( + bias_height, + 1, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + bias_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + if (bias_grad->initialized()) { + bias_grad->set_dims(bias_dims); + bias_grad->set_dtype(bias.dtype()); + } + } + if (input_grad->initialized()) { + input_grad->set_dims(input_dims); + input_grad->set_dtype(input.dtype()); + } + if (weight_grad->initialized()) { + weight_grad->set_dims(weight_dims); + weight_grad->set_dtype(weight.dtype()); + } +} + void GumbelSoftmaxGradInferMeta(const MetaTensor& out, const MetaTensor& dout, int axis, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 39b59958d6752..e538609d87c93 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -272,6 +272,16 @@ void GeneralQuinaryGradInferMeta(const MetaTensor& x, MetaTensor* dk, MetaTensor* dl); +void GruGradInferMeta(const MetaTensor& input, + const MetaTensor& h0, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* input_grad, + MetaTensor* h0_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad, + MetaConfig config = MetaConfig()); + void GumbelSoftmaxGradInferMeta(const MetaTensor& out, const MetaTensor& dout, int axis, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 65de2d4e3ce21..02b051b11d4e1 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2375,6 +2375,91 @@ void GraphReindexInferMeta(const MetaTensor& x, out_nodes->set_dtype(x.dtype()); } +void GruInferMeta(const MetaTensor& input, + const MetaTensor& h0, + const MetaTensor& weight, + const MetaTensor& bias, + const std::string& activation, + const std::string& gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + MetaTensor* batch_gate, + MetaTensor* batch_reset_hidden_prev, + MetaTensor* batch_hidden, + MetaTensor* hidden, + MetaConfig config) { + const auto& input_dims = input.dims(); + const auto& weight_dims = weight.dims(); + int input_size = static_cast(input_dims[1]); + int frame_size = static_cast(weight_dims[0]); + if (config.is_runtime) { + PADDLE_ENFORCE_EQ(input_size, + frame_size * 3, + phi::errors::InvalidArgument( + "The second dimension of Input(Input) must be 3 " + "times of frame_size in GRUOp, but received %d " + "(Input) vs %d (frame_size).", + input_size, + frame_size)); + } + PADDLE_ENFORCE_EQ( + weight_dims[1], + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Input(Weight) matrix must be [frame_size, frame_size " + "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).", + weight_dims[0], + weight_dims[1], + frame_size, + frame_size * 3)); + if (h0.initialized()) { + const auto& h0_dims = h0.dims(); + PADDLE_ENFORCE_EQ( + h0_dims[1], + frame_size, + phi::errors::InvalidArgument( + "The width of Input(H0) must be equal to frame_size, but " + "received %d (width of H0) vs %d (frame_size).", + h0_dims[1], + frame_size)); + } + if (bias.initialized()) { + const auto& bias_dims = bias.dims(); + int bias_height = static_cast(bias_dims[0]); + int bias_width = static_cast(bias_dims[1]); + PADDLE_ENFORCE_EQ( + bias_height, + 1, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + PADDLE_ENFORCE_EQ( + bias_width, + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but received " + "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", + bias_height, + bias_width, + frame_size * 3)); + } + if (!is_test) { + batch_gate->set_dims(input_dims); + batch_gate->set_dtype(input.dtype()); + batch_reset_hidden_prev->set_dims({input_dims[0], frame_size}); + batch_reset_hidden_prev->set_dtype(input.dtype()); + batch_hidden->set_dims({input_dims[0], frame_size}); + batch_hidden->set_dtype(input.dtype()); + } + hidden->set_dims({input_dims[0], frame_size}); + hidden->set_dtype(input.dtype()); + hidden->share_lod(input); +} + void GraphSampleNeighborsInferMeta(const MetaTensor& row, const MetaTensor& col_ptr, const MetaTensor& x, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index d60f8b0f3c443..a71d891ea70c9 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -492,6 +492,21 @@ void GraphReindexInferMeta(const MetaTensor& x, MetaTensor* reindex_dst, MetaTensor* out_nodes); +void GruInferMeta(const MetaTensor& input, + const MetaTensor& h0, + const MetaTensor& weight, + const MetaTensor& bias, + const std::string& activation, + const std::string& gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + MetaTensor* batch_gate, + MetaTensor* batch_reset_hidden_prev, + MetaTensor* batch_hidden, + MetaTensor* hidden, + MetaConfig config = MetaConfig()); + void GraphSampleNeighborsInferMeta(const MetaTensor& row, const MetaTensor& col_ptr, const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/gru_grad_kernel.cc b/paddle/phi/kernels/cpu/gru_grad_kernel.cc new file mode 100644 index 0000000000000..6581de6e8a01c --- /dev/null +++ b/paddle/phi/kernels/cpu/gru_grad_kernel.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_kernel_impl.h" + +PD_REGISTER_KERNEL( + gru_grad, CPU, ALL_LAYOUT, phi::GRUGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/gru_kernel.cc b/paddle/phi/kernels/cpu/gru_kernel.cc new file mode 100644 index 0000000000000..2850bf2b4a524 --- /dev/null +++ b/paddle/phi/kernels/cpu/gru_kernel.cc @@ -0,0 +1,238 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/funcs/detail/gru_kernel.h" +#include +#include +#include "paddle/common/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/detail/gru_cpu_kernel.h" +#include "paddle/phi/kernels/impl/gru_kernel_impl.h" + +COMMON_DECLARE_int32(paddle_num_threads); + +namespace phi { + +template +void GRUCPUKernel(const Context &dev_ctx, + const DenseTensor &input, + const paddle::optional &h0, + const DenseTensor &weight, + const paddle::optional &bias, + const std::string &activation, + const std::string &gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + DenseTensor *param_batch_gate, + DenseTensor *param_batch_reset_hidden_prev, + DenseTensor *param_batch_hidden, + DenseTensor *hidden) { + const T *weight_data = weight.data(); + dev_ctx.template Alloc(hidden); + + auto input_dims = input.dims(); + auto hidden_dims = hidden->dims(); + + phi::DenseTensor *batch_gate = nullptr; + phi::DenseTensor *batch_reset_hidden_prev = nullptr; + phi::DenseTensor *batch_hidden = nullptr; + phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, + batch_hidden_tmp; + if (is_test) { + batch_gate = &batch_gate_tmp; + batch_gate->Resize(input_dims); + + batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; + batch_reset_hidden_prev->Resize(hidden_dims); + + batch_hidden = &batch_hidden_tmp; + batch_hidden->Resize(hidden_dims); + } else { + batch_gate = param_batch_gate; + batch_hidden = param_batch_hidden; + batch_reset_hidden_prev = param_batch_reset_hidden_prev; + } + dev_ctx.template Alloc(batch_gate); + dev_ctx.template Alloc(batch_reset_hidden_prev); + dev_ctx.template Alloc(batch_hidden); + + phi::funcs::LoDTensor2BatchFunctor to_batch; + to_batch(dev_ctx, input, batch_gate, true, is_reverse); + + if (bias) { + phi::funcs::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, bias, batch_gate); + } + + int frame_size = static_cast(hidden_dims[1]); + phi::funcs::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + phi::DenseTensor ordered_h0; + + phi::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState(dev_ctx, *h0, order, &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t seq_len = batch_starts.size() - 1; + auto active_node = phi::funcs::detail::GetActivationType(activation); + auto active_gate = phi::funcs::detail::GetActivationType(gate_activation); + +#ifdef PADDLE_WITH_MKLML + // use MKL packed to speedup GEMM + if (FLAGS_paddle_num_threads >= 4) { + auto blas = phi::funcs::GetBlas(dev_ctx); + T *packed_gate = blas.GEMM_ALLOC(CblasBMatrix, + 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE_NOT_NULL( + packed_gate, + phi::errors::NotFound("The calculation result of packed_gate by " + "GEMM_ALLOC should not be null when using MKL.")); + blas.GEMM_PACK(CblasBMatrix, + CblasNoTrans, + 1 /*cur bs?*/, + frame_size * 2, + frame_size, + T(1.0), + gru_value.gate_weight, + frame_size * 2, + packed_gate); + T *packed_state = blas.GEMM_ALLOC(CblasBMatrix, + 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE_NOT_NULL( + packed_state, + phi::errors::NotFound("The calculation result of packed_state by " + "GEMM_ALLOC should not be null when using MKL.")); + blas.GEMM_PACK(CblasBMatrix, + CblasNoTrans, + 1 /*cur bs?*/, + frame_size, + frame_size, + T(1.0), + gru_value.state_weight, + frame_size, + packed_state); + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); + phi::DenseTensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE(CblasNoTrans, + CblasPacked, + cur_batch_size, + frame_size * 2, + frame_size, + gru_value.prev_out_value, + frame_size, + packed_gate, + frame_size * 2, + T(1), + gru_value.gate_value, + frame_size * 3); + } + + phi::funcs::detail::forward_reset_output( + phi::funcs::detail::forward::gru_resetOutput(), + gru_value, + frame_size, + cur_batch_size, + active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE(CblasNoTrans, + CblasPacked, + cur_batch_size, + frame_size, + frame_size, + gru_value.reset_output_value, + frame_size, + packed_state, + frame_size, + T(1), + gru_value.gate_value + frame_size * 2, + frame_size * 3); + } + + phi::funcs::detail::forward_final_output( + phi::funcs::detail::forward::gru_finalOutput(), + gru_value, + frame_size, + cur_batch_size, + active_node, + origin_mode); + + gru_value.prev_out_value = gru_value.output_value; + } + + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); + } else { +#endif + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); + phi::DenseTensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + phi::funcs::GRUUnitFunctor::compute(dev_ctx, // NOLINT + gru_value, + frame_size, + cur_batch_size, + active_node, + active_gate, + origin_mode); + + gru_value.prev_out_value = gru_value.output_value; + } +#ifdef PADDLE_WITH_MKLML + } +#endif + phi::funcs::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); +} + +} // namespace phi +PD_REGISTER_KERNEL(gru, CPU, ALL_LAYOUT, phi::GRUCPUKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/gru_grad_kernel.cu b/paddle/phi/kernels/gpu/gru_grad_kernel.cu new file mode 100644 index 0000000000000..b21707ff25f7f --- /dev/null +++ b/paddle/phi/kernels/gpu/gru_grad_kernel.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_kernel_impl.h" + +PD_REGISTER_KERNEL( + gru_grad, GPU, ALL_LAYOUT, phi::GRUGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/gru_kernel.cu b/paddle/phi/kernels/gpu/gru_kernel.cu new file mode 100644 index 0000000000000..c9558a05f503d --- /dev/null +++ b/paddle/phi/kernels/gpu/gru_kernel.cu @@ -0,0 +1,122 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gru_kernel_impl.h" + +namespace phi { + +template +void GRUKernel(const Context &dev_ctx, + const DenseTensor &input, + const paddle::optional &h0, + const DenseTensor &weight, + const paddle::optional &bias, + const std::string &activation, + const std::string &gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + DenseTensor *param_batch_gate, + DenseTensor *param_batch_reset_hidden_prev, + DenseTensor *param_batch_hidden, + DenseTensor *hidden) { + const T *weight_data = weight.data(); + dev_ctx.template Alloc(hidden); + + auto input_dims = input.dims(); + auto hidden_dims = hidden->dims(); + + phi::DenseTensor *batch_gate; + phi::DenseTensor *batch_reset_hidden_prev; + phi::DenseTensor *batch_hidden; + phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, + batch_hidden_tmp; + if (is_test) { + batch_gate = &batch_gate_tmp; + batch_gate->Resize(input_dims); + + batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; + batch_reset_hidden_prev->Resize(hidden_dims); + + batch_hidden = &batch_hidden_tmp; + batch_hidden->Resize(hidden_dims); + } else { + batch_gate = param_batch_gate; + batch_hidden = param_batch_hidden; + batch_reset_hidden_prev = param_batch_reset_hidden_prev; + } + dev_ctx.template Alloc(batch_gate); + dev_ctx.template Alloc(batch_reset_hidden_prev); + dev_ctx.template Alloc(batch_hidden); + + phi::funcs::LoDTensor2BatchFunctor to_batch; + to_batch(dev_ctx, input, batch_gate, true, is_reverse); + + if (bias) { + phi::funcs::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + phi::funcs::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + phi::DenseTensor ordered_h0; + + phi::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState(dev_ctx, *h0, order, &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = phi::funcs::detail::GetActivationType(activation); + auto active_gate = phi::funcs::detail::GetActivationType(gate_activation); + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + phi::DenseTensor gate_t = batch_gate->Slice(bstart, bend); + phi::DenseTensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + phi::DenseTensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + phi::funcs::GRUUnitFunctor::compute(dev_ctx, // NOLINT + gru_value, + frame_size, + cur_batch_size, + active_node, + active_gate, + origin_mode); + gru_value.prev_out_value = gru_value.output_value; + } + + phi::funcs::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); +} +} // namespace phi + +PD_REGISTER_KERNEL(gru, GPU, ALL_LAYOUT, phi::GRUKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/gru_kernel_impl.h b/paddle/phi/kernels/impl/gru_kernel_impl.h new file mode 100644 index 0000000000000..efff9056b0d47 --- /dev/null +++ b/paddle/phi/kernels/impl/gru_kernel_impl.h @@ -0,0 +1,173 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/kernels/funcs/detail/activation_functions.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/gru_compute.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +inline void ReorderInitState(const Context &dev_ctx, + const phi::DenseTensor &src, + phi::Vector index_lod, + phi::DenseTensor *dst, + bool indexed_src) { + phi::funcs::CopyMatrixRowsFunctor row_shuffle; + dst->Resize(src.dims()); + dev_ctx.template Alloc(dst); + row_shuffle(dev_ctx, src, index_lod, dst, indexed_src); +} + +template +void GRUGradKernel(const Context &dev_ctx, + const DenseTensor &input, + const paddle::optional &h0_param, + const DenseTensor &weight, + const paddle::optional &bias, + const DenseTensor &batch_gate, + const DenseTensor &batch_reset_hidden_prev, + const DenseTensor &batch_hidden, + const DenseTensor &hidden, + const DenseTensor &hidden_grad, + const std::string &activation, + const std::string &gate_activation, + bool is_reverse, + bool origin_mode, + bool is_test, + DenseTensor *input_grad, + DenseTensor *h0_grad, + DenseTensor *weight_grad, + DenseTensor *bias_grad) { + auto *h0 = h0_param.get_ptr(); + const T *weight_data = weight.data(); + + auto gate_dims = batch_gate.dims(); + auto hidden_dims = hidden.dims(); + int frame_size = hidden_dims[1]; + + phi::funcs::LoDTensor2BatchFunctor to_batch; + phi::DenseTensor batch_hidden_grad, batch_gate_grad, + batch_reset_hidden_prev_grad; + batch_hidden_grad.Resize(hidden_dims); + batch_gate_grad.Resize(gate_dims); + batch_reset_hidden_prev_grad.Resize(hidden_dims); + dev_ctx.template Alloc(&batch_hidden_grad); + dev_ctx.template Alloc(&batch_gate_grad); + dev_ctx.template Alloc(&batch_reset_hidden_prev_grad); + + phi::funcs::SetConstant zero; + zero(dev_ctx, &batch_hidden_grad, static_cast(0.0)); + zero(dev_ctx, &batch_gate_grad, static_cast(0.0)); + zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast(0.0)); + + phi::DenseTensor ordered_h0, ordered_h0_grad; + + phi::Vector order(batch_gate.lod()[2]); + + if (h0) { + ReorderInitState(dev_ctx, *h0, order, &ordered_h0, true); + } + if (h0_grad) { + ordered_h0_grad.Resize(h0_grad->dims()); + dev_ctx.template Alloc(&ordered_h0_grad); + zero(dev_ctx, &ordered_h0_grad, static_cast(0.0)); + } + + batch_hidden_grad.set_lod(batch_hidden.lod()); + to_batch(dev_ctx, hidden_grad, &batch_hidden_grad, false, is_reverse); + + phi::funcs::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + + phi::funcs::GRUMetaGrad gru_grad; + if (weight_grad) { + gru_grad.gate_weight_grad = dev_ctx.template Alloc(weight_grad); + zero(dev_ctx, weight_grad, static_cast(0.0)); + gru_grad.state_weight_grad = + weight_grad->data() + 2 * frame_size * frame_size; + } else { + gru_grad.gate_weight_grad = nullptr; + gru_grad.state_weight_grad = nullptr; + } + + auto batch_starts = batch_hidden_grad.lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = phi::funcs::detail::GetActivationType(activation); + auto active_gate = phi::funcs::detail::GetActivationType(gate_activation); + for (int n = static_cast(num_batch) - 1; n >= 0; n--) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + phi::DenseTensor gate_t = batch_gate.Slice(bstart, bend); + gru_value.gate_value = gate_t.data(); + phi::DenseTensor reset_hidden_prev_t = + batch_reset_hidden_prev.Slice(bstart, bend); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + phi::DenseTensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend); + gru_grad.output_grad = hidden_grad_t.data(); + phi::DenseTensor gate_grad_t = batch_gate_grad.Slice(bstart, bend); + gru_grad.gate_grad = gate_grad_t.data(); + phi::DenseTensor reset_hidden_prev_grad_t = + batch_reset_hidden_prev_grad.Slice(bstart, bend); + gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data(); + if (n == 0) { + gru_value.prev_out_value = h0 ? ordered_h0.data() : nullptr; + gru_grad.prev_out_grad = + h0 && h0_grad ? ordered_h0_grad.data() : nullptr; + } else { + int bstart_pre = static_cast(batch_starts[n - 1]); + phi::DenseTensor hidden_prev_t = batch_hidden.Slice(bstart_pre, bstart); + gru_value.prev_out_value = hidden_prev_t.data(); + phi::DenseTensor hidden_prev_grad_t = + batch_hidden_grad.Slice(bstart_pre, bstart); + gru_grad.prev_out_grad = hidden_prev_grad_t.data(); + } + gru_value.output_value = nullptr; + phi::funcs::GRUUnitGradFunctor::compute(dev_ctx, + gru_value, + gru_grad, + frame_size, + cur_batch_size, + active_node, + active_gate, + origin_mode); + } + if (input_grad) { + dev_ctx.template Alloc(input_grad); + phi::funcs::Batch2LoDTensorFunctor to_seq; + batch_gate_grad.set_lod(batch_gate.lod()); + to_seq(dev_ctx, batch_gate_grad, input_grad); + } + if (bias_grad) { + dev_ctx.template Alloc(bias_grad); + phi::funcs::ColwiseSum col_sum; + col_sum(dev_ctx, batch_gate_grad, bias_grad); + } + if (h0_param && h0_grad) { + ReorderInitState( + dev_ctx, ordered_h0_grad, order, h0_grad, false); + } +} +} // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 01b0c1025b8e9..9f767535df3ca 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -1167,6 +1167,25 @@ optional: scale, bias inplace : (y_grad -> x_grad) +- backward_op : gru_grad + forward: gru (Tensor input, Tensor h0, Tensor weight, Tensor bias, str activation = "tanh", + str gate_activation = "sigmoid", bool is_reverse = false, bool origin_mode = false, bool is_test=false) -> + Tensor (batch_gate), Tensor (batch_reset_hidden_prev), Tensor (batch_hidden), + Tensor (hidden) + args: (Tensor input, Tensor h0, Tensor weight, Tensor bias, Tensor batch_gate, + Tensor batch_reset_hidden_prev, Tensor batch_hidden, Tensor hidden, + Tensor hidden_grad, str activation = "tanh", + str gate_activation = "sigmoid", bool is_reverse = false, bool origin_mode = false, bool is_test=false) + output: Tensor(input_grad), Tensor(h0_grad), Tensor(weight_grad), Tensor(bias_grad) + infer_meta: + func: GruGradInferMeta + param: [input, h0, weight, bias] + kernel: + func: gru_grad + data_type: hidden_grad + optional: h0, bias + no_need_buffer: input, bias + - backward_op : gumbel_softmax_grad forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out) args : (Tensor out, Tensor out_grad, int axis) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index 3435ea6c46789..5a2fae503e834 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -4069,6 +4069,13 @@ outputs : out : Out +- op: gru + backward: gru_grad + inputs: + {input : Input, h0 : H0, weight : Weight, bias : Bias} + outputs: + {batch_gate : BatchGate, batch_reset_hidden_prev : BatchResetHiddenPrev, batch_hidden : BatchHidden, hidden : Hidden} + - op: identity_loss inputs : x: X diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index a84f7f337af37..813c1a8681921 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1548,6 +1548,20 @@ backward : group_norm_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : gru + args: (Tensor input, Tensor h0, Tensor weight, Tensor bias, str activation = "tanh", + str gate_activation = "sigmoid", bool is_reverse = false, bool origin_mode = false, bool is_test=false) + output: Tensor (batch_gate), Tensor (batch_reset_hidden_prev), Tensor (batch_hidden), + Tensor (hidden) + infer_meta: + func: GruInferMeta + kernel: + func: gru + data_type: input + optional: h0, bias + intermediate: batch_gate, batch_reset_hidden_prev, batch_hidden + backward: gru_grad + - op : gumbel_softmax args : (Tensor x, float temperature = 1.0, bool hard = false, int axis = -1) output : Tensor From cec6c6ab275ca773d96ab8231fa212d51c7ed529 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 26 May 2024 19:23:27 +0800 Subject: [PATCH 07/12] Fix --- paddle/phi/kernels/cpu/gru_kernel.cc | 2 +- paddle/phi/kernels/gpu/gru_kernel.cu | 4 ++-- paddle/phi/kernels/impl/gru_kernel_impl.h | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/cpu/gru_kernel.cc b/paddle/phi/kernels/cpu/gru_kernel.cc index 2850bf2b4a524..224a1e913ea01 100644 --- a/paddle/phi/kernels/cpu/gru_kernel.cc +++ b/paddle/phi/kernels/cpu/gru_kernel.cc @@ -74,7 +74,7 @@ void GRUCPUKernel(const Context &dev_ctx, if (bias) { phi::funcs::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, bias, batch_gate); + add_bias(dev_ctx, *batch_gate, bias.get(), batch_gate); } int frame_size = static_cast(hidden_dims[1]); diff --git a/paddle/phi/kernels/gpu/gru_kernel.cu b/paddle/phi/kernels/gpu/gru_kernel.cu index c9558a05f503d..a582b5fc12209 100644 --- a/paddle/phi/kernels/gpu/gru_kernel.cu +++ b/paddle/phi/kernels/gpu/gru_kernel.cu @@ -61,11 +61,11 @@ void GRUKernel(const Context &dev_ctx, dev_ctx.template Alloc(batch_reset_hidden_prev); dev_ctx.template Alloc(batch_hidden); - phi::funcs::LoDTensor2BatchFunctor to_batch; + phi::funcs::LoDTensor2BatchFunctor to_batch; to_batch(dev_ctx, input, batch_gate, true, is_reverse); if (bias) { - phi::funcs::RowwiseAdd add_bias; + phi::funcs::RowwiseAdd add_bias; add_bias(dev_ctx, *batch_gate, *bias, batch_gate); } diff --git a/paddle/phi/kernels/impl/gru_kernel_impl.h b/paddle/phi/kernels/impl/gru_kernel_impl.h index efff9056b0d47..07ec807513618 100644 --- a/paddle/phi/kernels/impl/gru_kernel_impl.h +++ b/paddle/phi/kernels/impl/gru_kernel_impl.h @@ -25,11 +25,11 @@ namespace phi { template -inline void ReorderInitState(const Context &dev_ctx, - const phi::DenseTensor &src, - phi::Vector index_lod, - phi::DenseTensor *dst, - bool indexed_src) { +void ReorderInitState(const Context &dev_ctx, + const phi::DenseTensor &src, + phi::Vector index_lod, + phi::DenseTensor *dst, + bool indexed_src) { phi::funcs::CopyMatrixRowsFunctor row_shuffle; dst->Resize(src.dims()); dev_ctx.template Alloc(dst); From 7fa5c5202a6299f5a807b65dd244ee60bcd56bb9 Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 27 May 2024 15:06:35 +0800 Subject: [PATCH 08/12] Fix --- paddle/phi/infermeta/backward.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 4fdf26168ea49..ea102bee4b47e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -627,7 +627,7 @@ void GruGradInferMeta(const MetaTensor& input, "received %d (width of H0) vs %d (frame_size).", h0_dims[1], frame_size)); - if (h0_grad->initialized()) { + if (h0_grad != nullptr) { h0_grad->set_dims(h0_dims); h0_grad->set_dtype(h0.dtype()); } @@ -654,16 +654,16 @@ void GruGradInferMeta(const MetaTensor& input, bias_height, bias_width, frame_size * 3)); - if (bias_grad->initialized()) { + if (bias_grad != nullptr) { bias_grad->set_dims(bias_dims); bias_grad->set_dtype(bias.dtype()); } } - if (input_grad->initialized()) { + if (input_grad != nullptr) { input_grad->set_dims(input_dims); input_grad->set_dtype(input.dtype()); } - if (weight_grad->initialized()) { + if (weight_grad != nullptr) { weight_grad->set_dims(weight_dims); weight_grad->set_dtype(weight.dtype()); } From b0ff9028e3d109b2fab2485ce407f3604da9bf4a Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 27 May 2024 15:10:28 +0800 Subject: [PATCH 09/12] Fix --- paddle/phi/infermeta/backward.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 0a21dbf73bb4e..9085e87bc2ac9 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -647,21 +647,21 @@ void GruUnitGradInferMeta(const MetaTensor& input, bias_height, bias_width, frame_size * 3)); - if (bias_grad->initialized()) { + if (bias_grad != nullptr) { bias_grad->set_dims(bias_dims); bias_grad->set_dtype(bias.dtype()); } } - if (input_grad->initialized()) { + if (input_grad != nullptr) { input_grad->set_dims(input_dims); input_grad->set_dtype(input.dtype()); } - if (hidden_prev_grad->initialized()) { + if (hidden_prev_grad != nullptr) { hidden_prev_grad->set_dims(hidden_prev_dims); hidden_prev_grad->set_dtype(hidden_prev.dtype()); } - if (weight_grad->initialized()) { + if (weight_grad != nullptr) { weight_grad->set_dims(weight_dims); weight_grad->set_dtype(weight.dtype()); } From a6af2f3acb2b64482c4aca799b87b4758cc9b16e Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 27 May 2024 15:48:45 +0800 Subject: [PATCH 10/12] Fix --- paddle/phi/infermeta/backward.cc | 12 ++-- paddle/phi/infermeta/backward.h | 20 +++---- paddle/phi/infermeta/multiary.cc | 100 +++++++++++++++---------------- paddle/phi/infermeta/multiary.h | 24 ++++---- 4 files changed, 78 insertions(+), 78 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 9085e87bc2ac9..88d91b24ebaec 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -568,12 +568,6 @@ void GeneralQuinaryGradInferMeta(const MetaTensor& x, } } -void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) { - if (dx) { - dx->share_meta(x); - } -} - void GruUnitGradInferMeta(const MetaTensor& input, const MetaTensor& hidden_prev, const MetaTensor& weight, @@ -667,6 +661,12 @@ void GruUnitGradInferMeta(const MetaTensor& input, } } +void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) { + if (dx) { + dx->share_meta(x); + } +} + void GumbelSoftmaxGradInferMeta(const MetaTensor& out, const MetaTensor& dout, int axis, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index e21be2b535490..fffff679da5e3 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -261,6 +261,16 @@ void GeneralQuaternaryGradInferMeta(const MetaTensor& x, MetaTensor* dz, MetaTensor* dk); +void GruUnitGradInferMeta(const MetaTensor& input, + const MetaTensor& hidden_prev, + const MetaTensor& weight, + const MetaTensor& bias, + MetaTensor* input_grad, + MetaTensor* hidden_prev_grad, + MetaTensor* weight_grad, + MetaTensor* bias_grad, + MetaConfig config = MetaConfig()); + void GeneralQuinaryGradInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& z, @@ -272,16 +282,6 @@ void GeneralQuinaryGradInferMeta(const MetaTensor& x, MetaTensor* dk, MetaTensor* dl); -void GruUnitGradInferMeta(const MetaTensor& input, - const MetaTensor& hidden_prev, - const MetaTensor& weight, - const MetaTensor& bias, - MetaTensor* input_grad, - MetaTensor* hidden_prev_grad, - MetaTensor* weight_grad, - MetaTensor* bias_grad, - MetaConfig config = MetaConfig()); - void GumbelSoftmaxGradInferMeta(const MetaTensor& out, const MetaTensor& dout, int axis, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 4c550df98034d..732ac73f228eb 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2375,56 +2375,6 @@ void GraphReindexInferMeta(const MetaTensor& x, out_nodes->set_dtype(x.dtype()); } -void GraphSampleNeighborsInferMeta(const MetaTensor& row, - const MetaTensor& col_ptr, - const MetaTensor& x, - const MetaTensor& eids, - const MetaTensor& perm_buffer, - int sample_size, - bool return_eids, - bool flag_perm_buffer, - MetaTensor* out, - MetaTensor* out_count, - MetaTensor* out_eids) { - // GSN: GraphSampleNeighbors - auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) { - if (dims.size() == 2) { - PADDLE_ENFORCE_EQ( - dims[1], - 1, - phi::errors::InvalidArgument("The last dim of %s should be 1 when it " - "is 2D, but we get %d", - tensor_name, - dims[1])); - } else { - PADDLE_ENFORCE_EQ( - dims.size(), - 1, - phi::errors::InvalidArgument( - "The %s should be 1D, when it is not 2D, but we get %d", - tensor_name, - dims.size())); - } - }; - - GSNShapeCheck(row.dims(), "Row"); - GSNShapeCheck(col_ptr.dims(), "Col_Ptr"); - GSNShapeCheck(x.dims(), "X"); - if (return_eids) { - GSNShapeCheck(eids.dims(), "Eids"); - out_eids->set_dims({-1}); - out_eids->set_dtype(row.dtype()); - } - if (flag_perm_buffer) { - GSNShapeCheck(perm_buffer.dims(), "Perm_Buffer"); - } - - out->set_dims({-1}); - out->set_dtype(row.dtype()); - out_count->set_dims({-1}); - out_count->set_dtype(DataType::INT32); -} - void GruUnitInferMeta(const MetaTensor& input, const MetaTensor& hidden_prev, const MetaTensor& weight, @@ -2509,6 +2459,56 @@ void GruUnitInferMeta(const MetaTensor& input, hidden->set_dtype(input.dtype()); } +void GraphSampleNeighborsInferMeta(const MetaTensor& row, + const MetaTensor& col_ptr, + const MetaTensor& x, + const MetaTensor& eids, + const MetaTensor& perm_buffer, + int sample_size, + bool return_eids, + bool flag_perm_buffer, + MetaTensor* out, + MetaTensor* out_count, + MetaTensor* out_eids) { + // GSN: GraphSampleNeighbors + auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) { + if (dims.size() == 2) { + PADDLE_ENFORCE_EQ( + dims[1], + 1, + phi::errors::InvalidArgument("The last dim of %s should be 1 when it " + "is 2D, but we get %d", + tensor_name, + dims[1])); + } else { + PADDLE_ENFORCE_EQ( + dims.size(), + 1, + phi::errors::InvalidArgument( + "The %s should be 1D, when it is not 2D, but we get %d", + tensor_name, + dims.size())); + } + }; + + GSNShapeCheck(row.dims(), "Row"); + GSNShapeCheck(col_ptr.dims(), "Col_Ptr"); + GSNShapeCheck(x.dims(), "X"); + if (return_eids) { + GSNShapeCheck(eids.dims(), "Eids"); + out_eids->set_dims({-1}); + out_eids->set_dtype(row.dtype()); + } + if (flag_perm_buffer) { + GSNShapeCheck(perm_buffer.dims(), "Perm_Buffer"); + } + + out->set_dims({-1}); + out->set_dtype(row.dtype()); + out_count->set_dims({-1}); + out_count->set_dtype(DataType::INT32); +} + void HSigmoidLossInferMeta(const MetaTensor& x, const MetaTensor& label, const MetaTensor& w, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 9988e0c1c5ce5..6f5ce5488310f 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -492,18 +492,6 @@ void GraphReindexInferMeta(const MetaTensor& x, MetaTensor* reindex_dst, MetaTensor* out_nodes); -void GraphSampleNeighborsInferMeta(const MetaTensor& row, - const MetaTensor& col_ptr, - const MetaTensor& x, - const MetaTensor& eids, - const MetaTensor& perm_buffer, - int sample_size, - bool return_eids, - bool flag_perm_buffer, - MetaTensor* out, - MetaTensor* out_count, - MetaTensor* out_eids); - void GruUnitInferMeta(const MetaTensor& input, const MetaTensor& hidden_prev, const MetaTensor& weight, @@ -516,6 +504,18 @@ void GruUnitInferMeta(const MetaTensor& input, MetaTensor* hidden, MetaConfig config = MetaConfig()); +void GraphSampleNeighborsInferMeta(const MetaTensor& row, + const MetaTensor& col_ptr, + const MetaTensor& x, + const MetaTensor& eids, + const MetaTensor& perm_buffer, + int sample_size, + bool return_eids, + bool flag_perm_buffer, + MetaTensor* out, + MetaTensor* out_count, + MetaTensor* out_eids); + void HSigmoidLossInferMeta(const MetaTensor& x, const MetaTensor& label, const MetaTensor& w, From cb0ee7c85c194b78e9c8cedd828e3c7fc2c20a76 Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 27 May 2024 15:51:57 +0800 Subject: [PATCH 11/12] Fix --- paddle/phi/infermeta/multiary.cc | 94 ++++++++++++++++---------------- paddle/phi/infermeta/multiary.h | 18 +++--- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 732ac73f228eb..183074783c926 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2328,53 +2328,6 @@ void GraphKhopSamplerInferMeta(const MetaTensor& row, reindex_x->set_dtype(x.dtype()); } -void GraphReindexInferMeta(const MetaTensor& x, - const MetaTensor& neighbors, - const MetaTensor& count, - const MetaTensor& hashtable_value, - const MetaTensor& hashtable_index, - MetaTensor* reindex_src, - MetaTensor* reindex_dst, - MetaTensor* out_nodes) { - bool flag_buffer_hashtable = - hashtable_value.initialized() && hashtable_index.initialized(); - auto GraphReindexShapeCheck = [](const phi::DDim& dims, - std::string tensor_name) { - if (dims.size() == 2) { - PADDLE_ENFORCE_EQ( - dims[1], - 1, - phi::errors::InvalidArgument("The last dim of %s should be 1 when it " - "is 2D, but we get %d", - tensor_name, - dims[1])); - } else { - PADDLE_ENFORCE_EQ( - dims.size(), - 1, - phi::errors::InvalidArgument( - "The %s should be 1D, when it is not 2D, but we get %d", - tensor_name, - dims.size())); - } - }; - - GraphReindexShapeCheck(x.dims(), "X"); - GraphReindexShapeCheck(neighbors.dims(), "Neighbors"); - GraphReindexShapeCheck(count.dims(), "Count"); - if (flag_buffer_hashtable) { - GraphReindexShapeCheck(hashtable_value.dims(), "HashTable_Value"); - GraphReindexShapeCheck(hashtable_index.dims(), "HashTable_Index"); - } - - reindex_src->set_dims({-1}); - reindex_src->set_dtype(neighbors.dtype()); - reindex_dst->set_dims({-1}); - reindex_dst->set_dtype(neighbors.dtype()); - out_nodes->set_dims({-1}); - out_nodes->set_dtype(x.dtype()); -} - void GruUnitInferMeta(const MetaTensor& input, const MetaTensor& hidden_prev, const MetaTensor& weight, @@ -2459,6 +2412,53 @@ void GruUnitInferMeta(const MetaTensor& input, hidden->set_dtype(input.dtype()); } +void GraphReindexInferMeta(const MetaTensor& x, + const MetaTensor& neighbors, + const MetaTensor& count, + const MetaTensor& hashtable_value, + const MetaTensor& hashtable_index, + MetaTensor* reindex_src, + MetaTensor* reindex_dst, + MetaTensor* out_nodes) { + bool flag_buffer_hashtable = + hashtable_value.initialized() && hashtable_index.initialized(); + auto GraphReindexShapeCheck = [](const phi::DDim& dims, + std::string tensor_name) { + if (dims.size() == 2) { + PADDLE_ENFORCE_EQ( + dims[1], + 1, + phi::errors::InvalidArgument("The last dim of %s should be 1 when it " + "is 2D, but we get %d", + tensor_name, + dims[1])); + } else { + PADDLE_ENFORCE_EQ( + dims.size(), + 1, + phi::errors::InvalidArgument( + "The %s should be 1D, when it is not 2D, but we get %d", + tensor_name, + dims.size())); + } + }; + + GraphReindexShapeCheck(x.dims(), "X"); + GraphReindexShapeCheck(neighbors.dims(), "Neighbors"); + GraphReindexShapeCheck(count.dims(), "Count"); + if (flag_buffer_hashtable) { + GraphReindexShapeCheck(hashtable_value.dims(), "HashTable_Value"); + GraphReindexShapeCheck(hashtable_index.dims(), "HashTable_Index"); + } + + reindex_src->set_dims({-1}); + reindex_src->set_dtype(neighbors.dtype()); + reindex_dst->set_dims({-1}); + reindex_dst->set_dtype(neighbors.dtype()); + out_nodes->set_dims({-1}); + out_nodes->set_dtype(x.dtype()); +} + void GraphSampleNeighborsInferMeta(const MetaTensor& row, const MetaTensor& col_ptr, const MetaTensor& x, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 6f5ce5488310f..5e445d6e90509 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -483,15 +483,6 @@ void GraphKhopSamplerInferMeta(const MetaTensor& row, MetaTensor* reindex_x, MetaTensor* out_eids); -void GraphReindexInferMeta(const MetaTensor& x, - const MetaTensor& neighbors, - const MetaTensor& count, - const MetaTensor& hashtable_value, - const MetaTensor& hashtable_index, - MetaTensor* reindex_src, - MetaTensor* reindex_dst, - MetaTensor* out_nodes); - void GruUnitInferMeta(const MetaTensor& input, const MetaTensor& hidden_prev, const MetaTensor& weight, @@ -504,6 +495,15 @@ void GruUnitInferMeta(const MetaTensor& input, MetaTensor* hidden, MetaConfig config = MetaConfig()); +void GraphReindexInferMeta(const MetaTensor& x, + const MetaTensor& neighbors, + const MetaTensor& count, + const MetaTensor& hashtable_value, + const MetaTensor& hashtable_index, + MetaTensor* reindex_src, + MetaTensor* reindex_dst, + MetaTensor* out_nodes); + void GraphSampleNeighborsInferMeta(const MetaTensor& row, const MetaTensor& col_ptr, const MetaTensor& x, From a44a57cf4ced8af62df4decc4fae72d07c25966c Mon Sep 17 00:00:00 2001 From: co63oc Date: Tue, 28 May 2024 13:49:51 +0800 Subject: [PATCH 12/12] Fix --- paddle/phi/ops/yaml/op_compat.yaml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index 8e6536ffed56b..e3d0561fb2b84 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -1783,11 +1783,6 @@ attrs: data_format: data_layout -- op : gru - backward : gru_grad - extra : - attrs : [bool is_test = false] - - op : gumbel_softmax inputs : x : X @@ -4081,6 +4076,9 @@ {input : Input, h0 : H0, weight : Weight, bias : Bias} outputs: {batch_gate : BatchGate, batch_reset_hidden_prev : BatchResetHiddenPrev, batch_hidden : BatchHidden, hidden : Hidden} + extra : + attrs : [bool is_test = false] + outputs : [batch_gate, batch_reset_hidden_prev, batch_hidden] - op: gru_unit backward: gru_unit_grad