Skip to content

Commit

Permalink
add Equal Mean ops (PaddlePaddle#68)
Browse files Browse the repository at this point in the history
* add equal,mean, fix mul

* add equal, mean, elementwis_miul tests

* include header
  • Loading branch information
gglin001 authored Aug 17, 2021
1 parent 24268f3 commit f468c40
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 1 deletion.
10 changes: 10 additions & 0 deletions paddle/fluid/framework/ipu/ipu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
auto outputs = op->Output("__outputs__");
popart::TensorId result = builder_->aiOnnxOpset11().add(inputs);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Mul") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
popart::TensorId result = builder_->aiOnnxOpset11().mul(inputs);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Conv") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
Expand All @@ -344,6 +349,11 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
popart::TensorId result = builder_->aiOnnxOpset11().conv(
inputs, dilations, group, {}, pads, strides);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Equal") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
popart::TensorId result = builder_->aiOnnxOpset11().equal(inputs);
tensors_.emplace(outputs[0], result);
} else if (op_type == "MatMul") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/framework/ipu/popart_canonicalization/logic_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,24 @@
// limitations under the License.

#include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/framework/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ipu {
namespace {

//
ir::Node *equal_handler(ir::Graph *graph, ir::Node *node) {
auto new_node = CreateBaseOp(
graph, "Equal", {GetInputNode("X", node), GetInputNode("Y", node)},
node->outputs);
ReplaceNodeInputs(node, new_node);
ReplaceNodeOutputs(node, new_node);
return new_node;
}

REGISTER_HANDLER(equal, equal_handler);

} // namespace
} // namespace ipu
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ ir::Node *reduce_mean_handler(ir::Graph *graph, ir::Node *node) {
return graph->CreateOpNode(op_desc.get());
}

ir::Node *mean_handler(ir::Graph *graph, ir::Node *node) {
auto new_node = CreateBaseOp(graph, "ReduceMean", {GetInputNode("X", node)},
{GetOutputNode("Out", node)},
{
{"keepdims", int64_t{0}},
});
ReplaceNodeInputs(node, new_node);
ReplaceNodeOutputs(node, new_node);
return new_node;
}

ir::Node *pow_handler(ir::Graph *graph, ir::Node *node) {
// Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow)
auto *op = node->Op();
Expand Down Expand Up @@ -111,6 +122,7 @@ ir::Node *softmax_handler(ir::Graph *graph, ir::Node *node) {
}

REGISTER_HANDLER(reduce_mean, reduce_mean_handler);
REGISTER_HANDLER(mean, mean_handler);
REGISTER_HANDLER(pow, pow_handler);
REGISTER_HANDLER(mul, mul_handler);
REGISTER_HANDLER(sum, sum_handler);
Expand Down
86 changes: 86 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_elemetwise_mul_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.fluid
import paddle.static
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestAdd(unittest.TestCase):
def _test_add(self, run_ipu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

np_a = np.random.rand(3, 3, 3).astype(np.float32)
np_b = np.arange(1, 4).reshape([3]).astype(np.float32)
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(
name="a",
shape=[3, 3, 3],
dtype='float32', )
b = paddle.static.data(
name="b",
shape=[3],
dtype='float32', )
# out = paddle.fluid.layers.elementwise_mul(a, b, axis=-1)
# out = paddle.fluid.layers.elementwise_mul(a, b, axis=0)
# out = paddle.fluid.layers.elementwise_mul(a, b, axis=1)
out = paddle.fluid.layers.elementwise_mul(a, b, axis=2)

if run_ipu:
place = paddle.IPUPlace()
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = [a.name, b.name]
fetch_list = [out.name]
ipu_strategy = compiler.get_ipu_strategy()
ipu_strategy.is_training = False
program = compiler.IpuCompiler(
main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
fetch_list)
else:
program = main_prog

result = exe.run(
program,
feed={'a': np_a,
'b': np_b},
fetch_list=[out], )
return result[0]

def test_add(self):
ipu_res = self._test_add(True)
cpu_res = self._test_add(False)
self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))


if __name__ == "__main__":
unittest.main()
79 changes: 79 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_equal_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.fluid
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestPow(unittest.TestCase):
def _test_pow(self, run_ipu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

np_data = np.random.uniform(
low=0, high=1, size=(1, 10)).astype(np.float32)
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="data",
shape=[1, 10],
dtype='float32', )
zero = paddle.fluid.layers.fill_constant(
shape=[1, 10], value=0.0, dtype='float32')
out = paddle.fluid.layers.equal(data, zero)

if run_ipu:
place = paddle.IPUPlace()
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = [data.name]
fetch_list = [out.name]
ipu_strategy = compiler.get_ipu_strategy()
ipu_strategy.is_training = False
program = compiler.IpuCompiler(
main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
fetch_list)
else:
program = main_prog

result = exe.run(program, feed={'data': np_data}, fetch_list=[out])
return result[0]

def test_pow(self):
ipu_res = self._test_pow(True)
cpu_res = self._test_pow(False)

self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))

print()


if __name__ == "__main__":
unittest.main()
77 changes: 77 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_mean_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.fluid
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestPow(unittest.TestCase):
def _test_pow(self, run_ipu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

np_data = np.random.uniform(
low=0, high=1, size=(1, 3, 10, 10)).astype(np.float32)
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="data",
shape=[1, 3, 10, 10],
dtype='float32', )
out = paddle.fluid.layers.mean(data)

if run_ipu:
place = paddle.IPUPlace()
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = [data.name]
fetch_list = [out.name]
ipu_strategy = compiler.get_ipu_strategy()
ipu_strategy.is_training = False
program = compiler.IpuCompiler(
main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
fetch_list)
else:
program = main_prog

result = exe.run(program, feed={'data': np_data}, fetch_list=[out])
return result[0]

def test_pow(self):
ipu_res = self._test_pow(True)
cpu_res = self._test_pow(False)

self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))

print()


if __name__ == "__main__":
unittest.main()

0 comments on commit f468c40

Please sign in to comment.