matmul and matmul_v2 refactor #42732

Merged: 5 commits, May 18, 2022

Changes from 2 commits
71 changes: 3 additions & 68 deletions paddle/fluid/operators/matmul_op.cc
@@ -686,76 +686,11 @@ class MatMulOp : public framework::OperatorWithKernel {
         context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
 
     if (!reshape_out.empty() && !transpose_out.empty()) {
-      auto reshape_out_size = reshape_out.size();
-      auto transpose_out_size = transpose_out.size();
-      PADDLE_ENFORCE_EQ(transpose_out_size, 4,
-                        platform::errors::InvalidArgument(
-                            "transpose_out supported rank is 4, "
-                            "received %d",
-                            transpose_out_size));
-      const std::vector<int> supported_axis{0, 2, 1, 3};
-      const bool supported_transpose_axis = std::equal(
-          transpose_out.begin(), transpose_out.end(), supported_axis.begin());
-      PADDLE_ENFORCE_EQ(
-          supported_transpose_axis, true,
-          platform::errors::InvalidArgument(
-              "supported transpose axis for the fuse are {0, 2, 1, 3}"));
-      PADDLE_ENFORCE_EQ(
-          reshape_out_size, 3,
-          platform::errors::InvalidArgument("reshape_out supported rank is 3, "
-                                            "received %d",
-                                            reshape_out_size));
-
-      // int num_negative = std::count(reshape_out.begin(), reshape_out.end(),
-      // -1);
-      // PADDLE_ENFORCE_LE(num_negative, 1,
-      //                   platform::errors::InvalidArgument(
-      //                       "The max number of -1 in fused_reshape_Out is 1 "
-      //                       "but received %d.",
-      //                       num_negative));
-
-      // auto it_zero = std::find(reshape_out.begin(), reshape_out.end(), 0);
-      // if (it_zero != reshape_out.end()) {
-      //   for (uint64_t i = 0; i < reshape_out.size(); i++) {
-      //     if (reshape_out[i] == 0) {
-      //       PADDLE_ENFORCE_LT(
-      //           i, ddim_out.size(),
-      //           platform::errors::InvalidArgument(
-      //               "The index of 0 in fused_reshape_Out ",
-      //               "should be less than output dim size, ",
-      //               "but the index is %d and output dim size is %d", i,
-      //               ddim_out.size()));
-      //       reshape_out[i] = ddim_out.at(i);
-      //     }
-      //   }
-      // }
-
-      // if "-1" is present then one of reshape dims must be inferred
-      auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);
-      if (it != reshape_out.end()) {
-        int index = std::distance(reshape_out.begin(), it);
-
-        auto ddim_out_vec = phi::vectorize(ddim_out);
-
-        int ddim_out_product =
-            std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
-                            std::multiplies<int>());
-        int reshape_out_product = std::accumulate(
-            reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());
-
-        reshape_out[index] = ddim_out_product / reshape_out_product;
-      }
-
-      framework::DDim shape_out =
-          ddim_out.transpose(transpose_out).reshape(reshape_out);
-      context->SetOutputDim("Out", shape_out);
-    } else {
-      context->SetOutputDim("Out", ddim_out);
+      ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out);
     }
-#else
-    context->SetOutputDim("Out", ddim_out);
 #endif
-    context->ShareLoD("X", /*->*/ "Out");
+    context->SetOutputDim("Out", ddim_out);
+    context->ShareLoD("X", "Out");
   }

   framework::OpKernelType GetExpectedKernelType(
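Editor's note, a standalone sketch (not Paddle code; shapes are invented for illustration): the deleted block above computed the fused output shape by hand, and both operators now delegate the same arithmetic to DDim::transpose().reshape(). Plain std::vector stands in for DDim here:

#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Mimics DDim::transpose: permute dims by the given axis order.
std::vector<int> transpose(const std::vector<int>& dims,
                           const std::vector<int>& axis) {
  std::vector<int> out(dims.size());
  for (std::size_t i = 0; i < axis.size(); ++i) out[i] = dims[axis[i]];
  return out;
}

// Mimics the -1 inference that now lives in DDim::reshape.
std::vector<int> reshape(const std::vector<int>& dims, std::vector<int> shape) {
  auto it = std::find(shape.begin(), shape.end(), -1);
  if (it != shape.end()) {
    int total = std::accumulate(dims.begin(), dims.end(), 1,
                                std::multiplies<int>());
    int known = std::accumulate(shape.begin(), shape.end(), -1,
                                std::multiplies<int>());
    *it = total / known;  // infer the -1 entry from the element count
  }
  return shape;
}

int main() {
  // e.g. a matmul output [2, 12, 9, 64] with fused_transpose_Out = {0, 2, 1, 3}
  // and fused_reshape_Out = {2, 9, -1}: the -1 resolves to 12 * 64 = 768.
  auto out = reshape(transpose({2, 12, 9, 64}, {0, 2, 1, 3}), {2, 9, -1});
  assert((out == std::vector<int>{2, 9, 768}));
  return 0;
}

The std::accumulate call seeded with -1 is the same trick the real code uses: the initial -1 cancels the sign of the -1 placeholder, so known ends up as the product of the explicitly given dims.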
71 changes: 3 additions & 68 deletions paddle/fluid/operators/matmul_v2_op.cc
@@ -176,77 +176,12 @@ class MatMulV2Op : public framework::OperatorWithKernel {
         ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
 
     if (!reshape_out.empty() && !transpose_out.empty()) {
-      auto reshape_out_size = reshape_out.size();
-      auto transpose_out_size = transpose_out.size();
-      PADDLE_ENFORCE_EQ(transpose_out_size, 4,
-                        platform::errors::InvalidArgument(
-                            "transpose_out supported rank is 4, "
-                            "received %d",
-                            transpose_out_size));
-      const std::vector<int> supported_axis{0, 2, 1, 3};
-      const bool supported_transpose_axis = std::equal(
-          transpose_out.begin(), transpose_out.end(), supported_axis.begin());
-      PADDLE_ENFORCE_EQ(
-          supported_transpose_axis, true,
-          platform::errors::InvalidArgument(
-              "supported transpose axis for the fuse are {0, 2, 1, 3}"));
-      PADDLE_ENFORCE_EQ(
-          reshape_out_size, 3,
-          platform::errors::InvalidArgument("reshape_out supported rank is 3, "
-                                            "received %d",
-                                            reshape_out_size));
-
-      // int num_negative = std::count(reshape_out.begin(), reshape_out.end(),
-      // -1);
-      // PADDLE_ENFORCE_LE(num_negative, 1,
-      //                   platform::errors::InvalidArgument(
-      //                       "The max number of -1 in fused_reshape_Out is 1 "
-      //                       "but received %d.",
-      //                       num_negative));
-
-      // auto it_zero = std::find(reshape_out.begin(), reshape_out.end(), 0);
-      // if (it_zero != reshape_out.end()) {
-      //   for (uint64_t i = 0; i < reshape_out.size(); i++) {
-      //     if (reshape_out[i] == 0) {
-      //       PADDLE_ENFORCE_LT(
-      //           i, ddim_out.size(),
-      //           platform::errors::InvalidArgument(
-      //               "The index of 0 in fused_reshape_Out ",
-      //               "should be less than output dim size, ",
-      //               "but the index is %d and output dim size is %d", i,
-      //               ddim_out.size()));
-      //       reshape_out[i] = ddim_out.at(i);
-      //     }
-      //   }
-      // }
-
-      // if "-1" is present then one of reshape dims must be inferred
-      auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);
-      if (it != reshape_out.end()) {
-        int index = std::distance(reshape_out.begin(), it);
-
-        auto ddim_out_vec = phi::vectorize(ddim_out);
-
-        int ddim_out_product =
-            std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
-                            std::multiplies<int>());
-        int reshape_out_product = std::accumulate(
-            reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());
-
-        reshape_out[index] = ddim_out_product / reshape_out_product;
-      }
-
-      framework::DDim shape_out =
-          ddim_out.transpose(transpose_out).reshape(reshape_out);
-      ctx->SetOutputDim("Out", shape_out);
-    } else {
-      ctx->SetOutputDim("Out", ddim_out);
+      ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out);
     }
-#else
-    ctx->SetOutputDim("Out", ddim_out);
 #endif
 
-    ctx->ShareLoD("X", /* --> */ "Out");
+    ctx->SetOutputDim("Out", ddim_out);
+    ctx->ShareLoD("X", "Out");
   }

  protected:
17 changes: 16 additions & 1 deletion paddle/phi/core/ddim.cc
@@ -171,11 +171,26 @@ DDim stride_numel(const DDim& ddim) {
   return strides;
 }
 
-DDim DDim::reshape(const std::vector<int>& shape) const {
+DDim DDim::reshape(const std::vector<int>& new_shape) const {
+  std::vector<int> shape = new_shape;
Contributor (review comment on the line above): Creating a new vector here and doing an if/else branch in core/ddim.cc will influence each CPU/GPU op, I think?
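Editor's note to make the concern concrete (hypothetical sketch, not PR code): the copy in std::vector<int> shape = new_shape; plus the std::find scan below now run on every DDim::reshape call, fused or not. One conceivable variant confines the rewrite to the rare -1 case and avoids the extra copy, shown here over plain vectors and assuming any 0 copy-dim entries were already resolved:

#include <algorithm>
#include <vector>

// Hypothetical alternative, for discussion only: resolve the -1 placeholder
// in place, so the common no-placeholder path does one scan and no copy.
void infer_placeholder_in_place(std::vector<int>* shape, int total_elems) {
  auto it = std::find(shape->begin(), shape->end(), -1);
  if (it == shape->end()) return;  // common path: nothing to infer
  int known = 1;
  for (int d : *shape) {
    if (d != -1) known *= d;  // product of the explicitly given dims
  }
  *it = total_elems / known;  // resolve the placeholder in place
}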

   const int64_t copy_dim_val = 0;
   const DDim& in_dims = *this;
   DDim out_dims;
   out_dims.rank_ = shape.size();
+
+  // dim marked as "-1" must be inferred
+  auto it = std::find(shape.begin(), shape.end(), -1);
+  if (it != shape.end()) {
+    int index = std::distance(shape.begin(), it);
+    auto ddim_out_vec = phi::vectorize(in_dims);
+    int ddim_out_product = std::accumulate(
+        ddim_out_vec.begin(), ddim_out_vec.end(), 1, std::multiplies<int>());
+    int reshape_out_product =
+        std::accumulate(shape.begin(), shape.end(), -1, std::multiplies<int>());
+
+    shape[index] = ddim_out_product / reshape_out_product;
+  }

   for (size_t i = 0; i < shape.size(); ++i) {
     if (shape[i] == copy_dim_val) {
       PADDLE_ENFORCE_LT(static_cast<int>(i),
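For completeness (editor's sketch with invented shapes): the loop that follows in DDim::reshape, truncated above, handles the other placeholder, where a 0 entry (copy_dim_val) copies the input dimension at the same index:

#include <vector>

// Sketch of the copy-dim rule enforced by the truncated loop above: a 0 in
// the target shape means "keep the input dim at this index".
std::vector<int> resolve_copy_dims(const std::vector<int>& in_dims,
                                   std::vector<int> shape) {
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) shape[i] = in_dims[i];
  }
  return shape;
}

// Example: resolve_copy_dims({2, 9, 12, 64}, {0, 0, 768}) yields {2, 9, 768},
// matching the reshape_out = [0, 0, ...] pattern in the tests below.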
@@ -518,44 +518,6 @@ class TestMatMulOpTransposeReshapeOtherDimInt(
     def init_data_type(self):
         self.data_type_ = np.int8
 

-class TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException(
-        TestMatMulOpTransposeReshapeBasicFloat):

Contributor (review comment on the deleted test classes): Why are these tests deleted?

Member (author): These tests only checked that the op would raise after incorrect values were set for fuse_reshape_out and fuse_transpose_out. I deleted them because the only way to set these parameters is via the fuse pass, which has all the PADDLE_ENFORCEs needed to prevent such a case (a sketch of that kind of guard follows this file's diff). Also, we have a separate unit test dedicated to the matmul+transpose+reshape fuse pass.

Contributor: Really nice that you have spotted that! Thanks!
-    def init_params_and_out(self):
-        self.transpose_out = [0, 1, 2, 3]
-        self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
-        self.out = np.matmul(self.x, self.y)
-
-    def test_check_output(self):
-        self.assertRaises(AttributeError, self.check_raise_error,
-                          'supported transpose axis '
-                          'for the fuse are {0, 2, 1, 3}')
-
-
-class TestMatMulOpTransposeReshapeTransposeRankNotSupportedException(
-        TestMatMulOpTransposeReshapeBasicFloat):
-    def init_params_and_out(self):
-        self.transpose_out = [0, 2, 1]
-        self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
-        self.out = np.matmul(self.x, self.y)
-
-    def test_check_output(self):
-        self.assertRaises(AttributeError, self.check_raise_error,
-                          'transpose_out supported rank is 4')
-
-
-class TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException(
-        TestMatMulOpTransposeReshapeBasicFloat):
-    def init_params_and_out(self):
-        self.transpose_out = [0, 2, 1, 3]
-        self.reshape_out = [0, 0]
-        self.out = np.matmul(self.x, self.y)
-
-    def test_check_output(self):
-        self.assertRaises(AttributeError, self.check_raise_error,
-                          'reshape_out supported rank is 3')
-
-
 if __name__ == "__main__":
     from paddle import enable_static
     enable_static()
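A minimal sketch of the kind of guard the author refers to (assumed variable names transpose_axis and reshape_shape; the macros and error type are the ones visible in the deleted operator code above, but this is not the actual fuse-pass source):

const std::vector<int> supported_axis{0, 2, 1, 3};
PADDLE_ENFORCE_EQ(
    transpose_axis == supported_axis, true,
    platform::errors::InvalidArgument(
        "supported transpose axis for the fuse are {0, 2, 1, 3}"));
PADDLE_ENFORCE_EQ(
    reshape_shape.size(), 3UL,
    platform::errors::InvalidArgument(
        "reshape_out supported rank is 3, received %d", reshape_shape.size()));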
@@ -27,9 +27,6 @@
     TestMatMulOpTransposeReshapeEmptyFloat,
     TestMatMulOpTransposeReshapeBasicFloat,
     TestMatMulOpTransposeReshapeOtherDimFloat,
-    TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException,
-    TestMatMulOpTransposeReshapeTransposeRankNotSupportedException,
-    TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException,
     TestReshapeTransposeMatMulOp, TestReshapeTransposeMatMulOp4DXFloat,
     TestReshapeTransposeMatMulOp4DYFloat, TestReshapeTransposeMatMulOp4DXYFloat,
     TestReshapeTransposeMatMulOp2DXFloat, TestReshapeTransposeMatMulOp2DYFloat,
@@ -457,24 +454,6 @@ def set_op_type(self):
         self.op_type = "matmul_v2"


-class TestMatMulV2OpTransposeReshapeTransposeAxisNotSupportedException(
-        TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException):
-    def set_op_type(self):
-        self.op_type = "matmul_v2"
-
-
-class TestMatMulV2OpTransposeReshapeRankOfReshapeNotSupportedException(
-        TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException):
-    def set_op_type(self):
-        self.op_type = "matmul_v2"
-
-
-class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException(
-        TestMatMulOpTransposeReshapeTransposeRankNotSupportedException):
-    def set_op_type(self):
-        self.op_type = "matmul_v2"
-
 
 class TestMatMulV2OpReshapeTranspose(TestReshapeTransposeMatMulOp):
     def set_op_type_and_transpose_y_name(self):
         self.op_type = "matmul_v2"