Skip to content

Commit

Permalink
Merge pull request #4 from nosovmik/mnosov/reader
Browse files Browse the repository at this point in the history
FrontEndManager - private implementation
  • Loading branch information
slyalin authored Apr 1, 2021
2 parents 3e3e9e3 + bf6324f commit 658ae1e
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,19 @@ enum FrontEndCapabilities {
class NGRAPH_API FrontEndManager
{
public:
FrontEndManager () {}
FrontEnd::Ptr loadByFramework (const std::string& framework, FrontEndCapabilities fec = FEC_DEFAULT);
FrontEnd::Ptr loadByModel (const std::string& path, FrontEndCapabilities fec = FEC_DEFAULT);
std::vector<std::string> availableFrontEnds () const;
FrontEndManager();
~FrontEndManager();
FrontEnd::Ptr loadByFramework(const std::string& framework, FrontEndCapabilities fec = FEC_DEFAULT);
FrontEnd::Ptr loadByModel(const std::string& path, FrontEndCapabilities fec = FEC_DEFAULT);
std::vector<std::string> availableFrontEnds() const;

using FrontEndFactory = std::function<FrontEnd::Ptr(FrontEndCapabilities fec)>;
void registerFrontEnd(const std::string& name, FrontEndFactory creator);
private:
class Impl;
std::unique_ptr<Impl> m_impl;
};

} // namespace frontend

} // namespace ngraph
66 changes: 53 additions & 13 deletions ngraph/frontend/generic/src/frontend_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace ngraph
{

#define FRONT_END_NOT_IMPLEMENTED(NAME) throw #NAME " is not implemented for this FrontEnd class";
#define FRONT_END_ASSERT(EXPRESSION) \
{ if (!(EXPRESSION)) throw "AssertionFailed"; }

std::vector<Place::Ptr> InputModel::getInputs () const
{
Expand Down Expand Up @@ -429,26 +431,64 @@ namespace ngraph
FRONT_END_NOT_IMPLEMENTED(normalize);
}

FrontEnd::Ptr FrontEndManager::loadByFramework (const std::string& framework, FrontEndCapabilities fec)
//////////////////////////////////////////////////////////////
class FrontEndManager::Impl
{
if (framework == "onnx")
return std::make_shared<FrontEndONNX>();
else if (framework == "pdpd")
return std::make_shared<FrontEndPDPD>();
else if (framework == "tf")
return std::make_shared<FrontEndTensorflow>();
else
throw "Framework " + framework + " is unknown for FrontEnd manager; cannot load it.";
std::map<std::string, FrontEndFactory> m_factories;

void registerDefault() {
registerFrontEnd("onnx", [](FrontEndCapabilities){return std::make_shared<FrontEndONNX>();});
registerFrontEnd("pdpd", [](FrontEndCapabilities){return std::make_shared<FrontEndPDPD>();});
registerFrontEnd("tf", [](FrontEndCapabilities){return std::make_shared<FrontEndTensorflow>();});
}
public:
Impl() {
registerDefault();
}
~Impl() = default;
FrontEnd::Ptr loadByFramework(const std::string& framework, FrontEndCapabilities fec) {
FRONT_END_ASSERT(m_factories.count(framework))
return m_factories[framework](fec);
}

std::vector<std::string> availableFrontEnds() const {
std::vector<std::string> keys;

std::transform(m_factories.begin(), m_factories.end(),
std::back_inserter(keys),
[](const std::pair<std::string, FrontEndFactory>& item) {
return item.first;
});
return keys;
}

FrontEnd::Ptr loadByModel (const std::string& path, FrontEndCapabilities fec)
{
FRONT_END_NOT_IMPLEMENTED(loadByModel);
}

void registerFrontEnd(const std::string& name, FrontEndFactory creator) {
m_factories.insert({name, creator});
}
};

FrontEndManager::FrontEndManager(): m_impl(new Impl()) {
}
FrontEndManager::~FrontEndManager() = default;

FrontEnd::Ptr FrontEndManager::loadByFramework(const std::string& framework, FrontEndCapabilities fec)
{
return m_impl->loadByFramework(framework, fec);
}

FrontEnd::Ptr FrontEndManager::loadByModel (const std::string& path, FrontEndCapabilities fec)
FrontEnd::Ptr FrontEndManager::loadByModel(const std::string& path, FrontEndCapabilities fec)
{
FRONT_END_NOT_IMPLEMENTED(loadByModel);
return m_impl->loadByModel(path, fec);
}

std::vector<std::string> FrontEndManager::availableFrontEnds () const
std::vector<std::string> FrontEndManager::availableFrontEnds() const
{
return {"onnx", "pdpd", "tf"};
return m_impl->availableFrontEnds();
}
} // namespace frontend

Expand Down
3 changes: 2 additions & 1 deletion ngraph/frontend/paddlepaddle/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


#include <algorithm>
#include <numeric>
#include <chrono>
#include <memory>
#include <map>
Expand Down Expand Up @@ -203,4 +204,4 @@ std::shared_ptr<ngraph::Function> ngraph::frontend::FrontEndPDPD::convert(InputM
}

}
}
}
1 change: 1 addition & 0 deletions ngraph/frontend/tensorflow/src/ngraph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*******************************************************************************/

#include <numeric>
#include "graph.pb.h"
#include "tensor.pb.h"

Expand Down

0 comments on commit 658ae1e

Please sign in to comment.