diff --git a/ngraph/frontend/paddlepaddle/src/op/argmax.cpp b/ngraph/frontend/paddlepaddle/src/op/argmax.cpp new file mode 100644 index 00000000000000..7d8c069031d07f --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/argmax.cpp @@ -0,0 +1,57 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "argmax.hpp" +#include + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs argmax(const NodeContext& node) + { + auto data = node.get_ng_input("X"); + bool flatten = node.get_attribute("flatten"); + const element::Type& index_element_type = element::i64; + const Output k = + ngraph::opset6::Constant::create(ngraph::element::i64, {}, {1}); + + if (!flatten) + { + auto axis = node.get_attribute("axis"); + const auto axis_to_remove = + ngraph::opset6::Constant::create(element::u64, Shape{}, {axis}); + auto node_topk = std::make_shared( + data, k, axis, "max", "index", index_element_type); + const auto reshaped_indices = std::make_shared( + node_topk->output(1), axis_to_remove); + return node.default_single_output_mapping( + {std::make_shared(reshaped_indices, + element::i64)}, + {"Out"}); + } + else + { + int64_t axis = 0; + const Output reshape_flatten = + ngraph::opset6::Constant::create(ngraph::element::i64, {1}, {-1}); + auto node_reshape = + std::make_shared(data, reshape_flatten, true); + auto node_topk = std::make_shared( + node_reshape, k, axis, "max", "index", index_element_type); + return node.default_single_output_mapping( + {std::make_shared(node_topk->output(1), + element::i64)}, + {"Out"}); + } + } + + } // namespace op + } // namespace pdpd + } // namespace frontend +} // namespace ngraph diff --git a/ngraph/frontend/paddlepaddle/src/op/argmax.hpp b/ngraph/frontend/paddlepaddle/src/op/argmax.hpp new file mode 100644 index 00000000000000..20d9db406be0cf --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/argmax.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "node_context.hpp" + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs argmax(const NodeContext& node); + } + } // namespace pdpd + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/op/assign_value.cpp b/ngraph/frontend/paddlepaddle/src/op/assign_value.cpp new file mode 100644 index 00000000000000..fb503abbba80e8 --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/assign_value.cpp @@ -0,0 +1,66 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "assign_value.hpp" +#include +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs assign_value(const NodeContext& node) + { + std::vector shape = node.get_attribute>("shape"); + auto dtype = node.get_attribute("dtype"); + std::shared_ptr const_node; + + switch (dtype) + { + case element::i32: + { + auto values = node.get_attribute>("int32_values"); + const_node = {opset6::Constant::create( + dtype, Shape{shape.begin(), shape.end()}, values)}; + break; + } + case element::f32: + { + std::vector values = + node.get_attribute>("fp32_values"); + const_node = {opset6::Constant::create( + dtype, Shape{shape.begin(), shape.end()}, values)}; + break; + } + case element::boolean: + { + auto values = node.get_attribute>("bool_values"); + const_node = {opset6::Constant::create( + dtype, Shape{shape.begin(), shape.end()}, values)}; + break; + } + case element::i64: + { + auto values = node.get_attribute>("int64_values"); + const_node = {opset6::Constant::create( + dtype, Shape{shape.begin(), shape.end()}, values)}; + break; + } + default: + { + PDPD_OP_VALIDATION_CHECK( + node, false, "assign_value only supports int32, int64, float32, bool"); + break; + } + } + + return node.default_single_output_mapping({const_node}, {"Out"}); + } + + } // namespace op + } // namespace pdpd + } // namespace frontend +} // namespace ngraph diff --git a/ngraph/frontend/paddlepaddle/src/op/assign_value.hpp b/ngraph/frontend/paddlepaddle/src/op/assign_value.hpp new file mode 100644 index 00000000000000..b954b3a04cce50 --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/assign_value.hpp @@ -0,0 +1,21 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "node_context.hpp" + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs assign_value(const NodeContext& node); + } + } // namespace pdpd + } // namespace frontend +} // namespace ngraph diff --git a/ngraph/frontend/paddlepaddle/src/op/batch_norm.cpp b/ngraph/frontend/paddlepaddle/src/op/batch_norm.cpp new file mode 100644 index 00000000000000..c38c4189fa04a0 --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/batch_norm.cpp @@ -0,0 +1,64 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "batch_norm.hpp" +#include + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs batch_norm(const NodeContext& node) + { + auto data = node.get_ng_input("X"); + auto gamma = node.get_ng_input("Scale"); + auto beta = node.get_ng_input("Bias"); + auto mean = node.get_ng_input("Mean"); + auto variance = node.get_ng_input("Variance"); + auto data_layout = node.get_attribute("data_layout"); + + PDPD_ASSERT((data_layout == "NCHW" || data_layout == "NHWC"), + "Not supported input data layout!"); + if (data_layout == "NCHW") + { + return node.default_single_output_mapping( + {std::make_shared( + data, + gamma, + beta, + mean, + variance, + node.get_attribute("epsilon"))}, + {"Y"}); + } + else + { + auto input_order = ngraph::opset6::Constant::create( + ngraph::element::i64, {4}, {0, 3, 1, 2}); + auto data_nchw = + std::make_shared(data, input_order); + auto node_batch_norm = std::make_shared( + data_nchw, + gamma, + beta, + mean, + variance, + node.get_attribute("epsilon")); + auto output_order = ngraph::opset6::Constant::create( + ngraph::element::i64, {4}, {0, 2, 3, 1}); + return node.default_single_output_mapping( + {std::make_shared(node_batch_norm, + output_order)}, + {"Y"}); + } + } + + } // namespace op + } // namespace pdpd + } // namespace frontend +} // namespace ngraph diff --git a/ngraph/frontend/paddlepaddle/src/op/batch_norm.hpp b/ngraph/frontend/paddlepaddle/src/op/batch_norm.hpp new file mode 100644 index 00000000000000..3757421bba65f5 --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/batch_norm.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "node_context.hpp" + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs batch_norm(const NodeContext& node); + } + } // namespace pdpd + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/op/cast.cpp b/ngraph/frontend/paddlepaddle/src/op/cast.cpp new file mode 100644 index 00000000000000..2cb181f0b24158 --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/cast.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "cast.hpp" +#include + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs cast(const NodeContext& node) + { + auto data = node.get_ng_input("X"); + auto out_dtype = node.get_attribute("out_dtype"); + + return node.default_single_output_mapping( + {std::make_shared(data, out_dtype)}, {"Out"}); + } + + } // namespace op + } // namespace pdpd + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/op/cast.hpp b/ngraph/frontend/paddlepaddle/src/op/cast.hpp new file mode 100644 index 00000000000000..1e3a19aaf5975c --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/cast.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "node_context.hpp" + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs cast(const NodeContext& node); + } + } // namespace pdpd + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/op/clip.cpp b/ngraph/frontend/paddlepaddle/src/op/clip.cpp new file mode 100644 index 00000000000000..1909e392eaf2f8 --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/clip.cpp @@ -0,0 +1,31 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "clip.hpp" +#include + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs clip(const NodeContext& node) + { + auto data = node.get_ng_input("X"); + auto min = node.get_attribute("min"); + auto max = node.get_attribute("max"); + PDPD_OP_VALIDATION_CHECK( + node, max >= min, "clip: max value must greater than min value!"); + + return node.default_single_output_mapping( + {std::make_shared(data, min, max)}, {"Out"}); + } + + } // namespace op + } // namespace pdpd + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/op/clip.hpp b/ngraph/frontend/paddlepaddle/src/op/clip.hpp new file mode 100644 index 00000000000000..babfa2ccd95bfd --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/clip.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "node_context.hpp" + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs clip(const NodeContext& node); + } + } // namespace pdpd + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/op/concat.cpp b/ngraph/frontend/paddlepaddle/src/op/concat.cpp new file mode 100644 index 00000000000000..a9c6fa6388d848 --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/concat.cpp @@ -0,0 +1,27 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "concat.hpp" +#include + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs concat(const NodeContext& node) + { + auto data = node.get_ng_inputs("X"); + auto axis = node.get_attribute("axis"); + return node.default_single_output_mapping( + {std::make_shared(data, axis)}, {"Out"}); + } + + } // namespace op + } // namespace pdpd + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/op/concat.hpp b/ngraph/frontend/paddlepaddle/src/op/concat.hpp new file mode 100644 index 00000000000000..0d32fa22f6e3bd --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/op/concat.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "node_context.hpp" + +namespace ngraph +{ + namespace frontend + { + namespace pdpd + { + namespace op + { + NamedOutputs concat(const NodeContext& node); + } + } // namespace pdpd + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/op_table.cpp b/ngraph/frontend/paddlepaddle/src/op_table.cpp index 411cfe8ecbf6d2..916737fc0c2ede 100644 --- a/ngraph/frontend/paddlepaddle/src/op_table.cpp +++ b/ngraph/frontend/paddlepaddle/src/op_table.cpp @@ -1,7 +1,12 @@ // Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // - +#include "op/argmax.hpp" +#include "op/assign_value.hpp" +#include "op/batch_norm.hpp" +#include "op/cast.hpp" +#include "op/clip.hpp" +#include "op/concat.hpp" #include "op/conv2d.hpp" #include "op/elementwise_ops.hpp" #include "op/relu.hpp" @@ -18,7 +23,13 @@ namespace ngraph { std::map get_supported_ops() { - return {{"conv2d", op::conv2d}, + return {{"arg_max", op::argmax}, + {"assign_value", op::assign_value}, + {"batch_norm", op::batch_norm}, + {"cast", op::cast}, + {"clip", op::clip}, + {"concat", op::concat}, + {"conv2d", op::conv2d}, {"elementwise_add", op::elementwise_add}, {"elementwise_div", op::elementwise_div}, {"elementwise_max", op::elementwise_max}, diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 8dbf0f888bcf8c..ae12fb2e3e1610 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -579,6 +579,7 @@ add_executable(unit-test ${SRC}) target_include_directories(unit-test PRIVATE ".") target_include_directories(unit-test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime) +target_include_directories(unit-test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/frontend/shared/include) add_definitions("-DCURDIR=\"${CMAKE_CURRENT_SOURCE_DIR}\"") add_definitions("-DJSON_INCLUDES=\"${JSON_INCLUDE_DIR}\"") @@ -647,6 +648,7 @@ install(TARGETS unit-test ############ FRONTEND ############ target_include_directories(unit-test PRIVATE ${FRONTEND_INCLUDE_PATH}) target_link_libraries(unit-test PRIVATE frontend_manager) +target_link_libraries(unit-test PRIVATE cnpy) add_subdirectory(frontend) ### END FRONTEND ### diff --git a/ngraph/test/files/paddlepaddle/gen_scripts/generate_argmax.py b/ngraph/test/files/paddlepaddle/gen_scripts/generate_argmax.py new file mode 100644 index 00000000000000..54b24364b2d481 --- /dev/null +++ b/ngraph/test/files/paddlepaddle/gen_scripts/generate_argmax.py @@ -0,0 +1,60 @@ +# +# pool2d paddle model generator +# +import numpy as np +from save_model import saveModel +import sys +data_type = 'float32' + + +def pdpd_argmax(name : str, x, axis): + import paddle as pdpd + pdpd.enable_static() + + with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()): + node_x = pdpd.static.data(name='x', shape=x.shape, dtype='float32') + out = pdpd.argmax(x=node_x, axis=axis) + out = pdpd.cast(out, np.float32) + cpu = pdpd.static.cpu_places(1) + exe = pdpd.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(pdpd.static.default_startup_program()) + + outs = exe.run( + feed={'x': x}, + fetch_list=[out]) + + saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1]) + + return outs[0] + +def pdpd_argmax1(name : str, x): + import paddle as pdpd + pdpd.enable_static() + + with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()): + node_x = pdpd.static.data(name='x', shape=x.shape, dtype='float32') + out = pdpd.argmax(x=node_x) + out = pdpd.cast(out, np.float32) + cpu = pdpd.static.cpu_places(1) + exe = pdpd.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(pdpd.static.default_startup_program()) + + outs = exe.run( + feed={'x': x}, + fetch_list=[out]) + + saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1]) + + return outs[0] + +def main(): + data = np.random.random([3,5,7,2]).astype("float32") + axis = 0 + pdpd_argmax("argmax", data, axis) + pdpd_argmax1("argmax1", data) + + +if __name__ == "__main__": + main() diff --git a/ngraph/test/files/paddlepaddle/gen_scripts/generate_assign_value.py b/ngraph/test/files/paddlepaddle/gen_scripts/generate_assign_value.py new file mode 100644 index 00000000000000..7d29574b2a92b4 --- /dev/null +++ b/ngraph/test/files/paddlepaddle/gen_scripts/generate_assign_value.py @@ -0,0 +1,58 @@ +import numpy as np +from save_model import saveModel +import sys + + +def pdpd_assign_value(name, test_x): + import paddle as pdpd + pdpd.enable_static() + main_program = pdpd.static.Program() + startup_program = pdpd.static.Program() + with pdpd.static.program_guard(main_program, startup_program): + node_x = pdpd.static.data(name='x', shape=test_x.shape, dtype=test_x.dtype if test_x.dtype != np.bool else np.int32) + node_x = pdpd.cast(node_x, dtype=test_x.dtype) + const_value = pdpd.assign(test_x, output=None) + result = pdpd.cast(pdpd.concat([node_x, const_value], 0), dtype=np.float32) + cpu = pdpd.static.cpu_places(1) + exe = pdpd.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(pdpd.static.default_startup_program()) + if test_x.dtype == np.bool: + test_x = test_x.astype(np.int32) + + outs = exe.run( + feed={'x': test_x}, + fetch_list=[result] + ) + + saveModel(name, exe, feedkeys=['x'], fetchlist=[result], inputs=[test_x], outputs=[outs[0]], target_dir=sys.argv[1]) + + print(outs[0]) + + +def compare(): + + test_cases = [ + { + "name": "assign_value_fp32", + "input": np.ones([1, 1, 4, 4]).astype(np.float32) + }, + { + "name": "assign_value_int32", + "input": np.ones([1, 1, 4, 4]).astype(np.int32) + }, + { + "name": "assign_value_int64", + "input": np.ones([1, 1, 4, 4]).astype(np.int64) + }, + { + "name": "assign_value_boolean", + "input": np.array([False, True, False]) + } + ] + for test in test_cases: + pdpd_assign_value(test['name'], test['input']) + + +if __name__ == "__main__": + compare() diff --git a/ngraph/test/files/paddlepaddle/gen_scripts/generate_batch_norm.py b/ngraph/test/files/paddlepaddle/gen_scripts/generate_batch_norm.py new file mode 100644 index 00000000000000..fbbba99160c4da --- /dev/null +++ b/ngraph/test/files/paddlepaddle/gen_scripts/generate_batch_norm.py @@ -0,0 +1,89 @@ +# +# pool2d paddle model generator +# +import numpy as np +from save_model import saveModel +import sys + + +def batch_norm1(name : str, x, scale, bias, mean, var, data_layout): + import paddle as pdpd + pdpd.enable_static() + + node_x = pdpd.static.data(name='x', shape=x.shape, dtype='float32') + scale_attr = pdpd.ParamAttr(name="scale1", initializer=pdpd.nn.initializer.Assign(scale)) + bias_attr = pdpd.ParamAttr(name="bias1", initializer=pdpd.nn.initializer.Assign(bias)) + + out = pdpd.static.nn.batch_norm(node_x, epsilon=1e-5, + param_attr=scale_attr, + bias_attr=bias_attr, + moving_mean_name="bn_mean1", + moving_variance_name="bn_variance1", + use_global_stats=True, + data_layout=data_layout) + + cpu = pdpd.static.cpu_places(1) + exe = pdpd.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(pdpd.static.default_startup_program()) + pdpd.static.global_scope().var("bn_mean1").get_tensor().set(mean, pdpd.CPUPlace()) + pdpd.static.global_scope().var("bn_variance1").get_tensor().set(var, pdpd.CPUPlace()) + + outs = exe.run( + feed={'x': x}, + fetch_list=[out]) + + saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1]) + + return outs[0] + +def batch_norm2(name : str, x, scale, bias, mean, var, data_layout): + import paddle as pdpd + pdpd.enable_static() + + node_x = pdpd.static.data(name='x', shape=x.shape, dtype='float32') + scale_attr = pdpd.ParamAttr(name="scale2", initializer=pdpd.nn.initializer.Assign(scale)) + bias_attr = pdpd.ParamAttr(name="bias2", initializer=pdpd.nn.initializer.Assign(bias)) + + out = pdpd.static.nn.batch_norm(node_x, epsilon=1e-5, + param_attr=scale_attr, + bias_attr=bias_attr, + moving_mean_name="bn_mean2", + moving_variance_name="bn_variance2", + use_global_stats=True, + data_layout=data_layout) + + cpu = pdpd.static.cpu_places(1) + exe = pdpd.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(pdpd.static.default_startup_program()) + pdpd.static.global_scope().var("bn_mean2").get_tensor().set(mean, pdpd.CPUPlace()) + pdpd.static.global_scope().var("bn_variance2").get_tensor().set(var, pdpd.CPUPlace()) + + outs = exe.run( + feed={'x': x}, + fetch_list=[out]) + + saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1]) + + return outs[0] + +def main(): + import paddle as pdpd + data = np.array([[[[-1, 0, 1]], [[2, 3, 4]]]]).astype(np.float32) + # data layout is NCHW + scale = np.array([1.0, 1.5]).astype(np.float32) + bias = np.array([0, 1]).astype(np.float32) + mean = np.array([0, 3]).astype(np.float32) + var = np.array([1, 1.5]).astype(np.float32) + batch_norm1("batch_norm_nchw", data, scale, bias, mean, var, "NCHW") + + # data layout is NHWC + scale = np.array([1.0, 1.5, 2.0]).astype(np.float32) + bias = np.array([0, 1, 2]).astype(np.float32) + mean = np.array([0.5, 1.5, 1.5]).astype(np.float32) + var = np.array([1, 1.5, 2]).astype(np.float32) + batch_norm2("batch_norm_nhwc", data, scale, bias, mean, var, "NHWC") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ngraph/test/files/paddlepaddle/gen_scripts/generate_clip.py b/ngraph/test/files/paddlepaddle/gen_scripts/generate_clip.py new file mode 100644 index 00000000000000..55edd6c62dd0d2 --- /dev/null +++ b/ngraph/test/files/paddlepaddle/gen_scripts/generate_clip.py @@ -0,0 +1,39 @@ +# +# clip paddle model generator +# +import numpy as np +from save_model import saveModel +import sys + +def clip(name: str, x, min, max): + import paddle as pdpd + pdpd.enable_static() + + with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()): + node_x = pdpd.static.data(name='x', shape=x.shape, dtype='float32') + out = pdpd.fluid.layers.clip(node_x, min=min, max=max) + + cpu = pdpd.static.cpu_places(1) + exe = pdpd.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(pdpd.static.default_startup_program()) + + outs = exe.run( + feed={'x': x}, + fetch_list=[out]) + + saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1]) + + return outs[0] + + +def main(): + data = np.random.random([2, 3, 4]).astype('float32') + min = 0 + max = 0.8 + + clip("clip", data, min, max) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ngraph/test/frontend/paddlepaddle/op_fuzzy.cpp b/ngraph/test/frontend/paddlepaddle/op_fuzzy.cpp new file mode 100644 index 00000000000000..a2d29bfb6e6aa6 --- /dev/null +++ b/ngraph/test/frontend/paddlepaddle/op_fuzzy.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "util/engine/test_engines.hpp" +#include "util/test_control.hpp" +#include +#include "op_fuzzy.hpp" +#include "ngraph/ngraph.hpp" + +using namespace ngraph; +using namespace InferenceEngine; +using namespace ngraph; +using namespace ngraph::frontend; +using TestEngine = test::IE_CPU_Engine; + +static const std::string PDPD = "pdpd"; +using PDPDFuzzyOpTest = FrontEndFuzzyOpTest; + +static const std::vector models{ + std::string("argmax"), + std::string("argmax1"), + std::string("assign_value_boolean"), + std::string("assign_value_fp32"), + std::string("assign_value_int32"), + std::string("assign_value_int64"), + std::string("batch_norm_nchw"), + std::string("batch_norm_nhwc"), + std::string("clip"), + std::string("relu"), +}; + +INSTANTIATE_TEST_SUITE_P(PDPDFuzzyOpTest, + FrontEndFuzzyOpTest, + ::testing::Combine(::testing::Values(PDPD), + ::testing::Values(std::string(TEST_PDPD_MODELS)), + ::testing::ValuesIn(models)), + PDPDFuzzyOpTest::getTestCaseName); diff --git a/ngraph/test/frontend/shared/include/op_fuzzy.hpp b/ngraph/test/frontend/shared/include/op_fuzzy.hpp new file mode 100644 index 00000000000000..0d11df6c8b51b0 --- /dev/null +++ b/ngraph/test/frontend/shared/include/op_fuzzy.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include + +using FuzzyOpTestParam = std::tuple; // Model name + +class FrontEndFuzzyOpTest : public ::testing::TestWithParam +{ +public: + std::string m_feName; + std::string m_pathToModels; + std::string m_modelFile; + ngraph::frontend::FrontEndManager m_fem; + ngraph::frontend::FrontEnd::Ptr m_frontEnd; + ngraph::frontend::InputModel::Ptr m_inputModel; + + static std::string getTestCaseName(const testing::TestParamInfo& obj); + + void SetUp() override; + +protected: + void initParamTest(); + + void doLoadFromFile(); + + void runConvertedModel(const std::shared_ptr function, const std::string& model_file); +}; diff --git a/ngraph/test/frontend/shared/src/op_fuzzy.cpp b/ngraph/test/frontend/shared/src/op_fuzzy.cpp new file mode 100644 index 00000000000000..526207d25a66ce --- /dev/null +++ b/ngraph/test/frontend/shared/src/op_fuzzy.cpp @@ -0,0 +1,163 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "util/engine/test_engines.hpp" +#include "util/test_case.hpp" +#include "util/test_control.hpp" +#include "op_fuzzy.hpp" +#include "utils.hpp" + +using namespace ngraph; +using namespace InferenceEngine; + +using namespace ngraph; +using namespace ngraph::frontend; +using TestEngine = test::IE_CPU_Engine; + +std::string + FrontEndFuzzyOpTest::getTestCaseName(const testing::TestParamInfo& obj) +{ + std::string fe, path, fileName; + std::tie(fe, path, fileName) = obj.param; + return fe + "_" + FrontEndTestUtils::fileToTestName(fileName); +} + +void FrontEndFuzzyOpTest::SetUp() +{ + FrontEndTestUtils::setupTestEnv(); + m_fem = FrontEndManager(); // re-initialize after setting up environment + initParamTest(); +} + +void FrontEndFuzzyOpTest::initParamTest() +{ + std::tie(m_feName, m_pathToModels, m_modelFile) = GetParam(); + m_modelFile = m_pathToModels + m_modelFile; +} + +void FrontEndFuzzyOpTest::doLoadFromFile() +{ + std::vector frontends; + ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends()); + ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_feName)); + ASSERT_NE(m_frontEnd, nullptr); + ASSERT_NO_THROW(m_inputModel = m_frontEnd->load_from_file(m_modelFile)); + ASSERT_NE(m_inputModel, nullptr); +} + +template +inline void addInputOutput(cnpy::NpyArray& npy_array, + test::TestCase& test_case, + bool is_input = true) +{ + T* npy_begin = npy_array.data(); + std::vector data(npy_begin, npy_begin + npy_array.num_vals); + if (is_input) + test_case.add_input(data); + else + test_case.add_expected_output(data); +} + +static bool ends_with(std::string const& value, std::string const& ending) +{ + if (ending.size() > value.size()) + return false; + return std::equal(ending.rbegin(), ending.rend(), value.rbegin()); +} + +static std::string getModelFolder(const std::string& modelFile) +{ + if (!ends_with(modelFile, ".pdmodel")) + return modelFile; + size_t found = modelFile.find_last_of("/\\"); + return modelFile.substr(0, found); +}; + +void FrontEndFuzzyOpTest::runConvertedModel(const std::shared_ptr function, + const std::string& modelFile) +{ + auto modelFolder = getModelFolder(modelFile); + + // run test + auto testCase = test::TestCase(function); + + const auto parameters = function->get_parameters(); + for (size_t i = 0; i < parameters.size(); i++) + { + // read input npy file + std::string dataFile = + modelFolder + "/input" + std::to_string((parameters.size() - 1) - i) + ".npy"; + cnpy::NpyArray input = cnpy::npy_load(dataFile); + auto input_dtype = parameters[i]->get_element_type(); + + if (input_dtype == element::f32) + { + addInputOutput(input, testCase, true); + } + else if (input_dtype == element::i32) + { + addInputOutput(input, testCase, true); + } + else if (input_dtype == element::i64) + { + addInputOutput(input, testCase, true); + } + else + { + throw std::runtime_error("not supported dtype in" + input_dtype.get_type_name()); + } + } + + const auto results = function->get_results(); + bool useFloatTest = false; + for (size_t i = 0; i < results.size(); i++) + { + // read expected output npy file + std::string dataFile = modelFolder + "/output" + std::to_string(i) + ".npy"; + cnpy::NpyArray output = cnpy::npy_load(dataFile); + auto outputDtype = results[i]->get_element_type(); + if (outputDtype == element::f32) + { + addInputOutput(output, testCase, false); + useFloatTest = true; + } + else if (outputDtype == element::i32) + { + addInputOutput(output, testCase, false); + } + else if (outputDtype == element::i64) + { + addInputOutput(output, testCase, false); + } + else + { + throw std::runtime_error("not supported dtype out " + outputDtype.get_type_name()); + } + } + + if (useFloatTest) + { + testCase.run_with_tolerance_as_fp(); + } + else + { + testCase.run(); + } +} + +TEST_P(FrontEndFuzzyOpTest, testOpFuzzy) +{ + // load + ASSERT_NO_THROW(doLoadFromFile()); + + // convert + std::shared_ptr function; + function = m_frontEnd->convert(m_inputModel); + ASSERT_NE(function, nullptr); + + // run + runConvertedModel(function, m_modelFile); +}