Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement paddle frontend methods for partial conversion #6784

Merged
merged 18 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ TEST(PDPD_Reader_Tests, ImportBasicModelToCore) {
"RefPDPDFunction");
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES);
const FunctionsComparator::Result res = func_comparator(function, reference);
ASSERT_TRUE(res.valid);
ASSERT_TRUE(res.valid) << res.message;
}

#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
Expand Down Expand Up @@ -79,6 +79,6 @@ TEST(PDPD_Reader_Tests, ImportBasicModelToCoreWstring) {
"RefPDPDFunction");
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES);
const FunctionsComparator::Result res = func_comparator(function, reference);
ASSERT_TRUE(res.valid);
ASSERT_TRUE(res.valid) << res.message;
}
#endif
2 changes: 1 addition & 1 deletion ngraph/frontend/paddlepaddle/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ endif()
link_system_libraries(${TARGET_NAME} PRIVATE ${Protobuf_LITE_LIBRARIES})

target_link_libraries(${TARGET_NAME} PRIVATE ngraph::frontend_manager::static
PRIVATE ngraph::builder)
PRIVATE ngraph::builder inference_engine_transformations)

add_clang_format_target(${TARGET_NAME}_clang FOR_TARGETS ${TARGET_NAME}
EXCLUDE_PATTERNS ${PROTO_SRCS} ${PROTO_HDRS})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <frontend_manager/frontend_manager.hpp>
#include "exceptions.hpp"
#include "model.hpp"
#include "place.hpp"
mvafin marked this conversation as resolved.
Show resolved Hide resolved

namespace ngraph
{
Expand All @@ -22,6 +23,32 @@ namespace ngraph
/// \return fully converted nGraph function
std::shared_ptr<Function> convert(InputModel::Ptr model) const override;

/// \brief Completely convert the remaining, not converted part of a function.
/// \param partiallyConverted partially converted nGraph function
/// \return fully converted nGraph function
std::shared_ptr<ngraph::Function>
convert(std::shared_ptr<ngraph::Function> partiallyConverted) const override;
mvafin marked this conversation as resolved.
Show resolved Hide resolved

/// \brief Convert only those parts of the model that can be converted leaving others
/// as-is. Converted parts are not normalized by additional transformations; normalize
/// function or another form of convert function should be called to finalize the
/// conversion process.
/// \param model Input model
/// \return partially converted nGraph function
std::shared_ptr<ngraph::Function>
convert_partially(InputModel::Ptr model) const override;
mvafin marked this conversation as resolved.
Show resolved Hide resolved

/// \brief Convert operations with one-to-one mapping with decoding nodes.
/// Each decoding node is an nGraph node representing a single FW operation node with
/// all attributes represented in FW-independent way.
/// \param model Input model
/// \return nGraph function after decoding
std::shared_ptr<ngraph::Function> decode(InputModel::Ptr model) const override;

/// \brief Runs normalization passes on function that was loaded with partial conversion
/// \param function partially converted nGraph function
void normalize(std::shared_ptr<ngraph::Function> function) const override;
mvafin marked this conversation as resolved.
Show resolved Hide resolved

protected:
/// \brief Check if FrontEndPDPD can recognize model from given parts
/// \param params Can be path to folder which contains __model__ file or path to
Expand All @@ -40,7 +67,10 @@ namespace ngraph

private:
static std::shared_ptr<Function>
convert_model(const std::shared_ptr<InputModelPDPD>& model);
convert_each_node(const std::shared_ptr<InputModelPDPD>& model,
mvafin marked this conversation as resolved.
Show resolved Hide resolved
std::function<std::map<std::string, OutputVector>(
const std::map<std::string, Output<Node>>&,
const std::shared_ptr<OpPlacePDPD>&)> func);
};

} // namespace frontend
Expand Down
45 changes: 45 additions & 0 deletions ngraph/frontend/paddlepaddle/src/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,16 @@ namespace ngraph
return output_names;
}

size_t DecoderPDPDProto::get_output_size() const
{
size_t res = 0;
for (const auto& output : op_place->get_desc().outputs())
{
res += output.arguments().size();
}
return res;
}

ngraph::element::Type
DecoderPDPDProto::get_out_port_type(const std::string& port_name) const
{
Expand Down Expand Up @@ -135,5 +145,40 @@ namespace ngraph
" Expected number: 0 or 1");
return attrs;
}

std::map<std::string, OutputVector> DecoderPDPDProto::map_for_each_input(
std::function<Output<Node>(const std::string&)> func) const
mvafin marked this conversation as resolved.
Show resolved Hide resolved
{
std::map<std::string, OutputVector> res;
for (const auto& port : op_place->get_desc().inputs())
{
std::vector<Output<Node>> v;
v.reserve(port.arguments_size());
for (const auto& inp : port.arguments())
{
v.push_back(func(inp));
}
res.emplace(std::make_pair(port.parameter(), v));
}
return res;
}

std::map<std::string, OutputVector> DecoderPDPDProto::map_for_each_output(
std::function<Output<Node>(const std::string&)> func) const
{
std::map<std::string, OutputVector> res;
for (const auto& port : op_place->get_desc().outputs())
{
std::vector<Output<Node>> v;
v.reserve(port.arguments_size());
for (const auto& out : port.arguments())
{
v.push_back(func(out));
}
res.emplace(std::make_pair(port.parameter(), v));
}
return res;
}

} // namespace frontend
} // namespace ngraph
8 changes: 8 additions & 0 deletions ngraph/frontend/paddlepaddle/src/decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,18 @@ namespace ngraph

std::vector<pdpd::OutPortName> get_output_names() const override;

size_t get_output_size() const override;

ngraph::element::Type get_out_port_type(const std::string& port_name) const override;

std::string get_op_type() const override;

std::map<std::string, OutputVector>
map_for_each_input(std::function<Output<Node>(const std::string&)> func) const;

std::map<std::string, OutputVector>
map_for_each_output(std::function<Output<Node>(const std::string&)> func) const;

private:
std::vector<paddle::framework::proto::OpDesc_Attr>
decode_attribute_helper(const std::string& name) const;
Expand Down
Loading