Skip to content

Commit

Permalink
raise exception for missing op_handlers (PaddlePaddle#210)
Browse files Browse the repository at this point in the history
* raise exception for missing op_handlers

* rename var
  • Loading branch information
gglin001 authored Oct 8, 2021
1 parent c330a67 commit 075a074
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
12 changes: 11 additions & 1 deletion paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "Raw Graph: ";
VLOG(10) << DebugString(graph);

std::vector<std::string> missing_ops;
auto nodes = graph->Nodes();
for (auto* node : nodes) {
if (!node->IsOp()) {
Expand All @@ -46,10 +47,19 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
ipu::ClearNode(node);
graph->RemoveNode(node);
} else {
LOG(ERROR) << "Can not find OpHandler for op_type: " << op_type;
missing_ops.push_back(op_type);
}
}

if (!missing_ops.empty()) {
LOG(ERROR) << "Can not find OpHandler for op_type: ";
for (auto& op_type : missing_ops) {
LOG(ERROR) << op_type;
}
PADDLE_THROW(platform::errors::Unimplemented(
"Found unimplemented op_handler(s) for IPU"));
}

// post popart_canonicalization

VLOG(10) << "Post Graph: ";
Expand Down
102 changes: 102 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_ipu_missing_op_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler
import paddle.optimizer
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import (IPUOpTest,
np_dtype_to_fluid_str)

paddle.enable_static()


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_feed()
self.set_feed_attr()
self.set_attrs()

def set_feed(self):
self.feed = {
"x": np.random.uniform(size=[1, 3, 2, 2]).astype('float32'),
}

def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed.values()]
self.feed_list = list(self.feed.keys())
self.feed_dtype = [
np_dtype_to_fluid_str(x.dtype) for x in self.feed.values()
]

def set_attrs(self):
self.attrs = {"first_n": 1}

def _test_base(self, run_ipu=True):
scope = fluid.core.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
SEED = self.SEED
main_prog.random_seed = SEED
startup_prog.random_seed = SEED

with fluid.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(
name=self.feed_list[0],
shape=self.feed_shape[0],
dtype=self.feed_dtype[0])
# print op is unimplemented
out = paddle.fluid.layers.Print(x, **self.attrs)

fetch_list = [out.name]

if run_ipu:
place = paddle.IPUPlace()
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy()
ipu_strategy.is_training = self.is_training
program = compiler.IpuCompiler(
main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
else:
program = main_prog

result = exe.run(program, feed=self.feed, fetch_list=fetch_list)
return result[0]

def test_base(self):
res0 = self._test_base(False)
try:
res1 = self._test_base(True)
except NotImplementedError:
pass


if __name__ == "__main__":
unittest.main()

0 comments on commit 075a074

Please sign in to comment.