Skip to content

Commit

Permalink
support paddle elementwise_floordiv
Browse files Browse the repository at this point in the history
  • Loading branch information
taixiurong committed Sep 26, 2022
1 parent 2bc32a1 commit 6f54f6f
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/core/tests/frontend/paddle/op_fuzzy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ static const std::vector<std::string> models{
std::string("elementwise_mul4"),
std::string("elementwise_pow4"),
std::string("elementwise_sub4"),
std::string("elementwise_floordiv_int32_1"),
std::string("elementwise_floordiv_int32_2"),
std::string("elementwise_floordiv_int32_3"),
std::string("elementwise_floordiv_int64_1"),
std::string("elementwise_floordiv_int64_2"),
std::string("elementwise_floordiv_int64_3"),
std::string("embedding_0"),
std::string("embedding_sparse"),
std::string("embedding_none_weight"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,34 @@ def elementwise_pow(name : str, x, y, axis, in_dtype):

return outs[0]


def elementwise_floordiv(name : str, x, y, axis, in_dtype):
import paddle
paddle.enable_static()

with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
node_x = paddle.static.data(name = 'x', shape = x.shape, dtype = in_dtype)
node_y = paddle.static.data(name = 'y', shape = y.shape, dtype = in_dtype)
if paddle.__version__ == "1.8":
out = paddle.fluid.layers.nn.elementwise_floordiv(node_x, node_y, axis=axis)
else:
if axis != -1:
pass
out = paddle.floor_divide(node_x, node_y)

cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])

# startup program will call initializer to initialize the parameters.
exe.run(paddle.static.default_startup_program())
outs = exe.run(
feed={'x': x, 'y': y},
fetch_list=[out])
saveModel(name, exe, feedkeys=['x', 'y'], fetchlist=[out], inputs=[x, y], outputs=[outs[0]], target_dir=sys.argv[1])

return outs[0]


def elementwise_ops(name : str, data_x, data_y, axis, in_dtype):
elementwise_add("elementwise_add" + name, data_x, data_y, axis, in_dtype)
elementwise_sub("elementwise_sub" + name, data_x, data_y, axis, in_dtype)
Expand Down Expand Up @@ -193,5 +221,29 @@ def main():
axis = 0
elementwise_ops("4", data_x, data_y, axis, in_dtype)

# test for elementwise_floordiv, support int and int64
# paddle1.8 support axis = [0, x_last_dims]
# paddle2.x only support axis = -1
floordiv_support_dtype = ['int64', 'int32']
data_x = np.array([-4, 0, -8])

data_y = np.array([3, 5, 3])
axis = -1
for dtype in floordiv_support_dtype:
elementwise_floordiv("elementwise_floordiv_" + dtype + "_1",
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)

data_x = np.random.randint(-10, 10, [2, 5, 3, 4])
data_y = np.random.randint(1, 5, [3, 4])
for dtype in floordiv_support_dtype:
elementwise_floordiv("elementwise_floordiv_" + dtype + "_2",
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)

data_y = np.random.randint(1, 5, [5, 3, 4])
for dtype in floordiv_support_dtype:
elementwise_floordiv("elementwise_floordiv_" + dtype + "_3",
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)


if __name__ == "__main__":
main()
15 changes: 15 additions & 0 deletions src/frontends/paddle/src/op/elementwise_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ NamedOutputs elementwise_greater_equal(const NodeContext& node_context) {
return elementwise_ops<default_opset::GreaterEqual>(node_context);
}

NamedOutputs elementwise_floordiv(const NodeContext& node_context) {
auto x = node_context.get_input("X");
auto y = node_context.get_input("Y");
auto axis = -1;
if (node_context.has_attribute("axis")) {
axis = node_context.get_attribute<int>("axis");
}
return node_context.default_single_output_mapping(
{std::make_shared<default_opset::Divide>(x,
y,
false,
ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::PDPD, axis))},
{"Out"});
}

} // namespace op
} // namespace paddle
} // namespace frontend
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/paddle/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ OP_CONVERTER(dropout);
OP_CONVERTER(elementwise_add);
OP_CONVERTER(elementwise_div);
OP_CONVERTER(elementwise_equal);
OP_CONVERTER(elementwise_floordiv);
OP_CONVERTER(elementwise_greater_equal);
OP_CONVERTER(elementwise_max);
OP_CONVERTER(elementwise_min);
Expand Down Expand Up @@ -123,6 +124,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"dropout", op::dropout},
{"elementwise_add", op::elementwise_add},
{"elementwise_div", op::elementwise_div},
{"elementwise_floordiv", op::elementwise_floordiv},
{"elementwise_max", op::elementwise_max},
{"elementwise_min", op::elementwise_min},
{"elementwise_mul", op::elementwise_mul},
Expand Down

0 comments on commit 6f54f6f

Please sign in to comment.