Skip to content

Commit

Permalink
Implement paddle frontend methods for partial conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed Jul 25, 2021
1 parent 6e8b0e0 commit be9fba2
Show file tree
Hide file tree
Showing 13 changed files with 542 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ TEST(PDPD_Reader_Tests, ImportBasicModelToCore) {
"RefPDPDFunction");
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES);
const FunctionsComparator::Result res = func_comparator(function, reference);
ASSERT_TRUE(res.valid);
ASSERT_TRUE(res.valid) << res.message;
}

#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
Expand Down Expand Up @@ -79,6 +79,6 @@ TEST(PDPD_Reader_Tests, ImportBasicModelToCoreWstring) {
"RefPDPDFunction");
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES);
const FunctionsComparator::Result res = func_comparator(function, reference);
ASSERT_TRUE(res.valid);
ASSERT_TRUE(res.valid) << res.message;
}
#endif
2 changes: 1 addition & 1 deletion ngraph/frontend/paddlepaddle/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ endif()
link_system_libraries(${TARGET_NAME} PRIVATE ${Protobuf_LITE_LIBRARIES})

target_link_libraries(${TARGET_NAME} PRIVATE ngraph::frontend_manager::static
PRIVATE ngraph::builder)
PRIVATE ngraph::builder inference_engine_transformations)

add_clang_format_target(${TARGET_NAME}_clang FOR_TARGETS ${TARGET_NAME}
EXCLUDE_PATTERNS ${PROTO_SRCS} ${PROTO_HDRS})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <frontend_manager/frontend_manager.hpp>
#include "exceptions.hpp"
#include "model.hpp"
#include "place.hpp"

namespace ngraph
{
Expand All @@ -22,6 +23,32 @@ namespace ngraph
/// \return fully converted nGraph function
std::shared_ptr<Function> convert(InputModel::Ptr model) const override;

/// \brief Completely convert the remaining, not converted part of a function.
/// \param partiallyConverted partially converted nGraph function
/// \return fully converted nGraph function
std::shared_ptr<ngraph::Function>
convert(std::shared_ptr<ngraph::Function> partiallyConverted) const override;

/// \brief Convert only those parts of the model that can be converted leaving others
/// as-is. Converted parts are not normalized by additional transformations; normalize
/// function or another form of convert function should be called to finalize the
/// conversion process.
/// \param model Input model
/// \return partially converted nGraph function
std::shared_ptr<ngraph::Function>
convert_partially(InputModel::Ptr model) const override;

/// \brief Convert operations with one-to-one mapping with decoding nodes.
/// Each decoding node is an nGraph node representing a single FW operation node with
/// all attributes represented in FW-independent way.
/// \param model Input model
/// \return nGraph function after decoding
std::shared_ptr<ngraph::Function> decode(InputModel::Ptr model) const override;

/// \brief Runs normalization passes on function that was loaded with partial conversion
/// \param function partially converted nGraph function
void normalize(std::shared_ptr<ngraph::Function> function) const override;

protected:
/// \brief Check if FrontEndPDPD can recognize model from given parts
/// \param params Can be path to folder which contains __model__ file or path to
Expand All @@ -40,7 +67,10 @@ namespace ngraph

private:
static std::shared_ptr<Function>
convert_model(const std::shared_ptr<InputModelPDPD>& model);
convert_each_node(const std::shared_ptr<InputModelPDPD>& model,
std::function<std::map<std::string, OutputVector>(
const std::map<std::string, Output<Node>>&,
const std::shared_ptr<OpPlacePDPD>&)> func);
};

} // namespace frontend
Expand Down
45 changes: 45 additions & 0 deletions ngraph/frontend/paddlepaddle/src/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,16 @@ namespace ngraph
return output_names;
}

size_t DecoderPDPDProto::get_output_size() const
{
size_t res = 0;
for (const auto& output : op_place->get_desc().outputs())
{
res += output.arguments().size();
}
return res;
}

ngraph::element::Type
DecoderPDPDProto::get_out_port_type(const std::string& port_name) const
{
Expand Down Expand Up @@ -135,5 +145,40 @@ namespace ngraph
" Expected number: 0 or 1");
return attrs;
}

std::map<std::string, OutputVector> DecoderPDPDProto::map_for_each_input(
std::function<Output<Node>(const std::string&)> func) const
{
std::map<std::string, OutputVector> res;
for (const auto& port : op_place->get_desc().inputs())
{
std::vector<Output<Node>> v;
v.reserve(port.arguments_size());
for (const auto& inp : port.arguments())
{
v.push_back(func(inp));
}
res.emplace(std::make_pair(port.parameter(), v));
}
return res;
}

std::map<std::string, OutputVector> DecoderPDPDProto::map_for_each_output(
std::function<Output<Node>(const std::string&)> func) const
{
std::map<std::string, OutputVector> res;
for (const auto& port : op_place->get_desc().outputs())
{
std::vector<Output<Node>> v;
v.reserve(port.arguments_size());
for (const auto& out : port.arguments())
{
v.push_back(func(out));
}
res.emplace(std::make_pair(port.parameter(), v));
}
return res;
}

} // namespace frontend
} // namespace ngraph
8 changes: 8 additions & 0 deletions ngraph/frontend/paddlepaddle/src/decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,18 @@ namespace ngraph

std::vector<pdpd::OutPortName> get_output_names() const override;

size_t get_output_size() const override;

ngraph::element::Type get_out_port_type(const std::string& port_name) const override;

std::string get_op_type() const override;

std::map<std::string, OutputVector>
map_for_each_input(std::function<Output<Node>(const std::string&)> func) const;

std::map<std::string, OutputVector>
map_for_each_output(std::function<Output<Node>(const std::string&)> func) const;

private:
std::vector<paddle::framework::proto::OpDesc_Attr>
decode_attribute_helper(const std::string& name) const;
Expand Down
Loading

0 comments on commit be9fba2

Please sign in to comment.