move dgc_momentum InferShape to phi (#56358)
huangjiyi authored Aug 18, 2023
1 parent ee01d78 commit a533dae
Showing 3 changed files with 88 additions and 87 deletions.
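In outline, the commit deletes the hand-written DGCMomentumOp::InferShape and re-expresses the same checks as phi::DGCMomentumInferMeta, which is wired back into the operator through an InferShape functor. The condensed sketch below restates the three pieces from the diff that follows; the signature is abridged and the snippet is not meant to compile on its own.

// paddle/phi/infermeta/multiary.h -- the shape logic becomes a free function
// operating on MetaTensor views instead of an InferShapeContext (signature
// abridged here; the full declaration appears in the diff below).
void DGCMomentumInferMeta(const MetaTensor& param,
                          /* ...remaining inputs and attributes... */
                          MetaTensor* param_out,
                          MetaTensor* velocity_out,
                          MetaTensor* master_param_out,
                          MetaTensor* grad_out);

// paddle/fluid/operators/optimizers/dgc_momentum_op.cc -- a functor adapts
// the phi InferMeta back to the fluid InferShape interface ...
DECLARE_INFER_SHAPE_FUNCTOR(dgc_momentum,
                            DGCMomentumInferShapeFunctor,
                            PD_INFER_META(phi::DGCMomentumInferMeta));

// ... and replaces the removed InferShape override in the op registration.
REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum,
                             ops::DGCMomentumOp,
                             ops::DGCMomentumOpMaker,
                             DGCMomentumInferShapeFunctor);
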
95 changes: 8 additions & 87 deletions paddle/fluid/operators/optimizers/dgc_momentum_op.cc
@@ -14,7 +14,9 @@

#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {
@@ -24,92 +26,6 @@ class DGCMomentumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("current_step"),
"Input",
"current_step",
"DGCMomentumOp");
OP_INOUT_CHECK(ctx->HasInput("nranks"), "Input", "nranks", "DGCMomentumOp");
OP_INOUT_CHECK(
ctx->HasOutput("Grad_out"), "Output", "Grad_out", "DGCMomentumOp");

PADDLE_ENFORCE_EQ(ctx->HasInput("Param"),
true,
platform::errors::NotFound(
"Input(param) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"),
true,
platform::errors::NotFound(
"Input(grad) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"),
true,
platform::errors::NotFound(
"Input(velocity) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("LearningRate"),
true,
platform::errors::NotFound(
"Input(LearningRate) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be phi::DenseTensor, "
"but the received is %s",
ctx->GetInputsVarType("Param").front()));

PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"),
true,
platform::errors::NotFound(
"Output(ParamOut) of Momentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("VelocityOut"),
true,
platform::errors::NotFound(
"Output(VelocityOut) of Momentum should not be null."));

auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(phi::product(lr_dims),
0,
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dims)));

auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim,
ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and Grad input of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Grad's dim [%s].",
param_dim,
ctx->GetInputDim("Grad")));
PADDLE_ENFORCE_EQ(
param_dim,
ctx->GetInputDim("Velocity"),
platform::errors::InvalidArgument(
"Param and Velocity of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Velocity [%s].",
param_dim,
ctx->GetInputDim("Velocity")));
}

ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
if (ctx->HasOutput("MasterParamOut")) {
ctx->SetOutputDim("MasterParamOut", param_dim);
}
}

phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
@@ -199,7 +115,12 @@ DGC Momentum Operator.
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(dgc_momentum,
DGCMomentumInferShapeFunctor,
PD_INFER_META(phi::DGCMomentumInferMeta));

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum,
ops::DGCMomentumOp,
ops::DGCMomentumOpMaker);
ops::DGCMomentumOpMaker,
DGCMomentumInferShapeFunctor);
61 changes: 61 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -1305,6 +1305,67 @@ void DeformableConvInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void DGCMomentumInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& velocity,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
const MetaTensor& current_step_tensor,
const MetaTensor& nranks_tensor,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad,
float rampup_begin_step,
MetaTensor* param_out,
MetaTensor* velocity_out,
MetaTensor* master_param_out,
MetaTensor* grad_out) {
auto lr_dims = learning_rate.dims();

PADDLE_ENFORCE_NE(phi::product(lr_dims),
0,
phi::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
phi::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dims)));

auto param_dims = param.dims();
auto grad_dims = grad.dims();
auto velocity_dims = velocity.dims();
PADDLE_ENFORCE_EQ(
param_dims,
grad_dims,
phi::errors::InvalidArgument(
"Param and Grad input of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Grad's dim [%s].",
param_dims,
grad_dims));
PADDLE_ENFORCE_EQ(
param_dims,
velocity_dims,
phi::errors::InvalidArgument(
"Param and Velocity of MomentumOp should have the same "
"dimension. But received Param's dim [%s] and Velocity [%s].",
param_dims,
velocity_dims));

param_out->set_dims(param_dims);
velocity_out->set_dims(param_dims);
if (master_param != nullptr) {
master_param_out->set_dims(param_dims);
}
}

void EditDistanceInferMeta(const MetaTensor& hyps,
const MetaTensor& refs,
const MetaTensor& hypslength,
19 changes: 19 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -280,6 +280,25 @@ void DeformableConvInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void DGCMomentumInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& velocity,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
const MetaTensor& current_step_tensor,
const MetaTensor& nranks_tensor,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad,
float rampup_begin_step,
MetaTensor* param_out,
MetaTensor* velocity_out,
MetaTensor* master_param_out,
MetaTensor* grad_out);

void EditDistanceInferMeta(const MetaTensor& hyps,
const MetaTensor& refs,
const MetaTensor& hypslength,
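For context, the new phi::DGCMomentumInferMeta can be exercised directly by wrapping DenseTensors in MetaTensor views, in the style of phi's infermeta unit tests. The snippet below is an illustrative sketch only, not part of this commit: the function name SketchDGCMomentumInferMeta, the shapes, and the attribute values are invented, and it assumes the MetaTensor(TensorBase*) constructor and phi::make_ddim helper used elsewhere in phi.

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/multiary.h"

// Illustrative sketch (not part of the commit): run DGCMomentumInferMeta on
// made-up shapes and observe that the output dims follow Param, as the
// InferMeta added in multiary.cc specifies.
void SketchDGCMomentumInferMeta() {
  phi::DenseTensor param, grad, velocity, lr, master_param, step, nranks;
  param.Resize(phi::make_ddim({128, 64}));
  grad.Resize(phi::make_ddim({128, 64}));      // must match Param's dims
  velocity.Resize(phi::make_ddim({128, 64}));  // must match Param's dims
  lr.Resize(phi::make_ddim({1}));              // LearningRate must hold one element
  master_param.Resize(phi::make_ddim({128, 64}));
  step.Resize(phi::make_ddim({1}));
  nranks.Resize(phi::make_ddim({1}));

  phi::DenseTensor param_out, velocity_out, master_param_out, grad_out;
  phi::MetaTensor meta_param_out(&param_out);
  phi::MetaTensor meta_velocity_out(&velocity_out);
  phi::MetaTensor meta_master_param_out(&master_param_out);
  phi::MetaTensor meta_grad_out(&grad_out);

  phi::DGCMomentumInferMeta(phi::MetaTensor(&param),
                            phi::MetaTensor(&grad),
                            phi::MetaTensor(&velocity),
                            phi::MetaTensor(&lr),
                            phi::MetaTensor(&master_param),
                            phi::MetaTensor(&step),
                            phi::MetaTensor(&nranks),
                            /*mu=*/0.9f,
                            /*use_nesterov=*/false,
                            /*regularization_method=*/"",   // illustrative value
                            /*regularization_coeff=*/0.0f,  // illustrative value
                            /*multi_precision=*/false,
                            /*rescale_grad=*/1.0f,
                            /*rampup_begin_step=*/0.0f,
                            &meta_param_out,
                            &meta_velocity_out,
                            &meta_master_param_out,
                            &meta_grad_out);

  // ParamOut, VelocityOut and MasterParamOut now report dims [128, 64].
}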
