Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_sparse_reshape_bug
Browse files Browse the repository at this point in the history
  • Loading branch information
risemeup1 committed May 21, 2024
2 parents 77a15d1 + 15888d1 commit 4fba58f
Show file tree
Hide file tree
Showing 168 changed files with 1,231 additions and 919 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ paddle/fluid/operators/generated_sparse_op.cc
paddle/fluid/operators/generated_static_op.cc
paddle/fluid/operators/generated_fused_op.cc
paddle/fluid/operators/ops_signature/generated_*.cc
paddle/phi/api/yaml/parsed_apis/
paddle/fluid/operators/generator/parsed_ops/
paddle/fluid/pybind/tmp_eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
Expand Down
37 changes: 29 additions & 8 deletions paddle/cinn/frontend/op_mappers/paddle/relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ReluOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Relu op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Relu op must be 1."));
auto out_name = op_desc.Output("Out").front();
auto x = ctx.GetVar(x_name);
auto out = ctx.Builder()->Relu(x);
Expand All @@ -34,9 +40,15 @@ void ReluOpMapper(const paddle::cpp::OpDesc& op_desc,

void Relu6OpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Relu6 op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Relu6 op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto threshold = utils::GetAttrOrDefault<float>(op_desc, "threshold", 6.0f);
Expand All @@ -49,11 +61,20 @@ void Relu6OpMapper(const paddle::cpp::OpDesc& op_desc,

void ReluGradOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input(paddle::GradVarName("Out")).size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input(paddle::GradVarName("Out")).size(),
1UL,
phi::errors::InvalidArgument("The input of ReluGrad op must be 1."));
auto dout_name = op_desc.Input(paddle::GradVarName("Out")).front();
CHECK_EQ(op_desc.Input("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Out").size(),
1UL,
phi::errors::InvalidArgument("The input of ReluGrad op must be 1."));
auto out_name = op_desc.Input("Out").front();
CHECK_EQ(op_desc.Output(paddle::GradVarName("X")).size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output(paddle::GradVarName("X")).size(),
1UL,
phi::errors::InvalidArgument("The output of ReluGrad op must be 1."));
auto dx_name = op_desc.Output(paddle::GradVarName("X")).front();

auto dout = ctx.GetVar(dout_name);
Expand Down
47 changes: 37 additions & 10 deletions paddle/cinn/frontend/op_mappers/paddle/reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Reshape op must be 1."));
auto x_name = op_desc.Input("X").front();
auto x = ctx.GetVar(x_name);

Expand All @@ -33,7 +36,10 @@ void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc,

auto out = ctx.Builder()->Reshape(x, shape);

CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Reshape op must be 1."));
auto out_name = op_desc.Output("Out").front();
ctx.AddVar(out_name, out);
ctx.AddVarModelToProgram(out_name, out->id);
Expand All @@ -42,13 +48,19 @@ void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc,
void ReshapeGradOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
auto get_input_var = [&op_desc, &ctx](const std::string& op_name) {
CHECK_EQ(op_desc.Input(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input(op_name).size(),
1UL,
phi::errors::InvalidArgument("The input of ReshapeGrad op must be 1."));
auto var_name = op_desc.Input(op_name).front();
return ctx.GetVar(var_name);
};

auto get_output_name = [&op_desc](const std::string& op_name) {
CHECK_EQ(op_desc.Output(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The output of ReshapeGrad op must be 1."));
return op_desc.Output(op_name).front();
};

Expand All @@ -67,7 +79,10 @@ void ReshapeGradOpMapper(const paddle::cpp::OpDesc& op_desc,

void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Reshape2 op must be 1."));
auto x_name = op_desc.Input("X").front();
auto x = ctx.GetVar(x_name);

Expand All @@ -78,7 +93,10 @@ void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc,

auto out = ctx.Builder()->Reshape(x, shape);

CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Reshape2 op must be 1."));
auto out_name = op_desc.Output("Out").front();
ctx.AddVar(out_name, out);
ctx.AddVarModelToProgram(out_name, out->id);
Expand All @@ -89,7 +107,10 @@ void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc,
// will be used in Reshape_grad, in this way, the framework can reuse
// the memory of X immediately the Reshape2_op is finished.
// Considering compatibility issues, we could not fix Reshape2_op
CHECK_EQ(op_desc.Output("XShape").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("XShape").size(),
1UL,
phi::errors::InvalidArgument("The output of Reshape2 op must be 1."));
auto xshape_name = op_desc.Output("XShape").front();

auto xshape = ctx.Builder()->Identity(x);
Expand All @@ -102,13 +123,19 @@ void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc,
void Reshape2GradOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
auto get_input_var = [&op_desc, &ctx](const std::string& op_name) {
CHECK_EQ(op_desc.Input(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Input(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The input of Reshape2Grad op must be 1."));
auto var_name = op_desc.Input(op_name).front();
return ctx.GetVar(var_name);
};

auto get_output_name = [&op_desc](const std::string& op_name) {
CHECK_EQ(op_desc.Output(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The output of Reshape2Grad op must be 1."));
return op_desc.Output(op_name).front();
};

Expand Down
12 changes: 9 additions & 3 deletions paddle/cinn/frontend/op_mappers/paddle/reverse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ReverseOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Reverse op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Reverse op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto axes = utils::GetAttrOrDefault<std::vector<int>>(
Expand Down
47 changes: 34 additions & 13 deletions paddle/cinn/frontend/op_mappers/paddle/roll.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,35 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/frontend/var_type_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void RollOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
// input
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Roll op must be 1."));
auto x_name = op_desc.Input("X").front();
// output
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Roll op must be 1."));
auto out_name = op_desc.Output("Out").front();

// attr shifts and axis
CHECK(op_desc.HasAttr("shifts"));
CHECK(op_desc.HasAttr("axis"));
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("shifts"),
true,
phi::errors::InvalidArgument("Roll op must have shifts attribute"));
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("axis"),
true,
phi::errors::InvalidArgument("Roll op must have axis attribute"));
std::vector<int> shifts = utils::ToShapeType(
utils::GetAttrOrDefault<std::vector<int64_t>>(op_desc, "shifts", {1}));
std::vector<int> axis = utils::ToShapeType(
Expand All @@ -44,8 +56,11 @@ void RollOpMapper(const paddle::cpp::OpDesc& op_desc,
// check axis and shifts and when axis is None, we should flatten x.
bool axis_None = false;
if (axis.size() == 0) {
CHECK_EQ(shifts.size(), 1)
<< "shifts.size() should be equal to 1 when axis is None";
PADDLE_ENFORCE_EQ(
shifts.size(),
1UL,
phi::errors::InvalidArgument(
"shifts.size() should be equal to 1 when axis is None"));
axis.push_back(0);
axis_None = true;
int reshape_num = 1;
Expand All @@ -55,19 +70,25 @@ void RollOpMapper(const paddle::cpp::OpDesc& op_desc,
vec_x_dims = std::vector<int>{reshape_num};
x = ctx.Builder()->Reshape(x, vec_x_dims);
} else {
CHECK_EQ(shifts.size(), axis.size())
<< "shifts.size() should be equal to axis.size()";
PADDLE_ENFORCE_EQ(shifts.size(),
axis.size(),
phi::errors::InvalidArgument(
"shifts.size() should be equal to axis.size()"));
}

// preprocessing the shifts and axis
int shifts_size = shifts.size();
std::unordered_map<int, int> axis_to_shifts;
for (int i = 0; i < shifts_size; ++i) {
int vec_x_dims_size = vec_x_dims.size();
CHECK_GE(axis[i], -vec_x_dims_size)
<< "axis value should be >= " << -vec_x_dims_size;
CHECK_LT(axis[i], vec_x_dims_size)
<< "axis value should be < " << vec_x_dims_size;
PADDLE_ENFORCE_GE(axis[i],
-vec_x_dims_size,
phi::errors::InvalidArgument(
"axis value should be >= -vec_x_dims_size"));
PADDLE_ENFORCE_LT(
axis[i],
vec_x_dims_size,
phi::errors::InvalidArgument("axis value should be < vec_x_dims_size"));
if (axis[i] < 0) {
axis[i] += vec_x_dims_size;
}
Expand Down
17 changes: 13 additions & 4 deletions paddle/cinn/frontend/op_mappers/paddle/scale.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/utils/string.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ScaleOpMapper(const paddle::cpp::OpDesc& op_desc,
const cinn::frontend::OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Scale op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Scale op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto bias = utils::GetAttrOrDefault<float>(op_desc, "bias", 0.0f);
Expand All @@ -38,7 +44,10 @@ void ScaleOpMapper(const paddle::cpp::OpDesc& op_desc,
absl::optional<Variable> out;
if (op_desc.HasInput("ScaleTensor") &&
!op_desc.Input("ScaleTensor").empty()) {
CHECK_EQ(op_desc.Input("ScaleTensor").size(), 1);
PADDLE_ENFORCE_EQ(
op_desc.Input("ScaleTensor").size(),
1UL,
phi::errors::InvalidArgument("The input of ScaleTensor must be 1."));
auto scale_name = op_desc.Input("ScaleTensor").front();
auto scale_tensor = ctx.GetVar(scale_name);

Expand Down
Loading

0 comments on commit 4fba58f

Please sign in to comment.