Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant attr to visitor #3540

Merged
merged 8 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include <array>
#include <cstdint>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
Expand All @@ -18,53 +19,88 @@ using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::Serialize, "Serialize", 0);

namespace { // helpers
template <typename T, typename A>
std::string joinVec(const std::vector<T, A>& vec,
const std::string& glue = std::string(",")) {
if (vec.empty()) return "";
template <typename Container>
std::string join(const Container& c, const char* glue = ", ") {
pelszkow marked this conversation as resolved.
Show resolved Hide resolved
std::stringstream oss;
oss << vec[0];
for (size_t i = 1; i < vec.size(); i++) oss << glue << vec[i];
const char* s = "";
for (const auto& v : c) {
oss << s << v;
s = glue;
}
return oss.str();
}
} // namespace

namespace { // implementation details
struct Edge {
int from_layer = 0;
int from_port = 0;
int to_layer = 0;
int to_port = 0;
};

struct ConstantAtributes {
int size = 0;
int offset = 0;
};
// Here operation type names are translated from ngraph convention to IR
// convention. Most of them are the same, but there are exceptions, e.g
// Constant (ngraph name) and Const (IR name). If there will be more
// discrepancies discoverd, translations needs to be added here.
const std::unordered_map<std::string, std::string> translate_type_name_translator = {
{"Constant", "Const"},
{"Relu", "ReLU"},
{"Softmax", "SoftMax"}};

std::string translate_type_name(const std::string& name) {
auto found = translate_type_name_translator.find(name);
if (found != end(translate_type_name_translator)) {
return found->second;
}
return name;
}

class XmlVisitor : public ngraph::AttributeVisitor {
pugi::xml_node& m_data;
class XmlSerializer : public ngraph::AttributeVisitor {
pugi::xml_node& m_xml_node;
std::ostream& m_bin_data;
std::string& m_node_type_name;

template <typename T>
std::string create_atribute_list(
ngraph::ValueAccessor<std::vector<T>>& adapter) {
return joinVec(adapter.get(), std::string(","));
return join(adapter.get());
}

public:
XmlVisitor(pugi::xml_node& data, std::string& node_type_name)
: m_data(data), m_node_type_name(node_type_name) {}
XmlSerializer(pugi::xml_node& data,
std::ostream& bin_data,
std::string& node_type_name)
: m_xml_node(data)
, m_bin_data(bin_data)
, m_node_type_name(node_type_name) {
}

void on_adapter(const std::string& name,
ngraph::ValueAccessor<void>& adapter) override {
#if 0 // TODO: remove when Constant will support VisitorAPI
m_data.append_attribute(name.c_str());
#endif
(void)name;
(void)adapter;
}

void on_adapter(const std::string& name,
ngraph::ValueAccessor<void*>& adapter) override {
if (name == "value" && translate_type_name(m_node_type_name) == "Const") {
using AlignedBufferAdapter =
ngraph::AttributeAdapter<std::shared_ptr<runtime::AlignedBuffer>>;
if (auto a = ngraph::as_type<AlignedBufferAdapter>(&adapter)) {
const int64_t size = a->size();
const int64_t offset = m_bin_data.tellp();

m_xml_node.append_attribute("offset").set_value(offset);
m_xml_node.append_attribute("size").set_value(size);

auto data = static_cast<const char*>(a->get_ptr());
m_bin_data.write(data, size);
}
}
}

void on_adapter(const std::string& name,
ngraph::ValueAccessor<bool>& adapter) override {
m_data.append_attribute(name.c_str()).set_value(adapter.get());
m_xml_node.append_attribute(name.c_str()).set_value(adapter.get());
}
void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::string>& adapter) override {
Expand All @@ -75,40 +111,40 @@ class XmlVisitor : public ngraph::AttributeVisitor {
// it is a WA to not introduce dependency on plugin_api library
m_node_type_name = adapter.get();
} else {
m_data.append_attribute(name.c_str())
m_xml_node.append_attribute(name.c_str())
.set_value(adapter.get().c_str());
}
}
void on_adapter(const std::string& name,
ngraph::ValueAccessor<int64_t>& adapter) override {
m_data.append_attribute(name.c_str()).set_value(adapter.get());
m_xml_node.append_attribute(name.c_str()).set_value(adapter.get());
}
void on_adapter(const std::string& name,
ngraph::ValueAccessor<double>& adapter) override {
m_data.append_attribute(name.c_str()).set_value(adapter.get());
m_xml_node.append_attribute(name.c_str()).set_value(adapter.get());
}
void on_adapter(
const std::string& name,
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override {
m_data.append_attribute(name.c_str())
m_xml_node.append_attribute(name.c_str())
.set_value(create_atribute_list(adapter).c_str());
}
void on_adapter(
const std::string& name,
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override {
m_data.append_attribute(name.c_str())
m_xml_node.append_attribute(name.c_str())
.set_value(create_atribute_list(adapter).c_str());
}
void on_adapter(
const std::string& name,
ngraph::ValueAccessor<std::vector<float>>& adapter) override {
m_data.append_attribute(name.c_str())
m_xml_node.append_attribute(name.c_str())
.set_value(create_atribute_list(adapter).c_str());
}
void on_adapter(
const std::string& name,
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override {
m_data.append_attribute(name.c_str())
m_xml_node.append_attribute(name.c_str())
.set_value(create_atribute_list(adapter).c_str());
}
};
Expand Down Expand Up @@ -175,19 +211,7 @@ const std::vector<Edge> create_edge_mapping(
return edges;
}

// TODO: refactor to Vistor API when Constant will be supporting it
ConstantAtributes dump_constant_data(std::vector<uint8_t>& bin,
const ngraph::op::Constant& c) {
NGRAPH_CHECK(c.get_output_partial_shape(0.).is_static(),
"Unsupported dynamic output shape in ", c);

ConstantAtributes attr;
const uint8_t* p = reinterpret_cast<const uint8_t*>(c.get_data_ptr());
attr.size = ngraph::shape_size(c.get_shape()) * c.get_element_type().size();
attr.offset = bin.size();
bin.insert(end(bin), p, p + attr.size);
return attr;
}


std::string get_opset_name(
const ngraph::Node* n,
Expand All @@ -214,20 +238,6 @@ std::string get_opset_name(
return "experimental";
}

// Here operation type names are translated from ngraph convention to IR
// convention. Most of them are the same, but there are exceptions, e.g
// Constant (ngraph name) and Const (IR name). If there will be more
// discrepancies discoverd, translations needs to be added here.
std::string translate_type_name(std::string name) {
const std::unordered_map<std::string, std::string> translator = {
{"Constant", "Const"},
{"Relu", "ReLU"},
{"Softmax", "SoftMax"}};
if (translator.count(name) > 0) {
name = translator.at(name);
}
return name;
}

std::string get_output_precision_name(ngraph::Output<Node>& o) {
auto elem_type = o.get_element_type();
Expand Down Expand Up @@ -364,10 +374,10 @@ bool resolve_dynamic_shapes(const ngraph::Function& f) {
return true;
}

void ngfunction_2_irv10(
pugi::xml_document& doc, std::vector<uint8_t>& bin,
ngraph::Function& f,
const std::map<std::string, ngraph::OpSet>& custom_opsets) {
void ngfunction_2_irv10(pugi::xml_document& doc,
std::ostream& bin_file,
ngraph::Function& f,
pelszkow marked this conversation as resolved.
Show resolved Hide resolved
const std::map<std::string, ngraph::OpSet>& custom_opsets) {
const bool exec_graph = is_exec_graph(f);

pugi::xml_node netXml = doc.append_child("net");
Expand Down Expand Up @@ -403,7 +413,7 @@ void ngfunction_2_irv10(
if (exec_graph) {
visit_exec_graph_node(data, node_type_name, node);
} else {
XmlVisitor visitor(data, node_type_name);
XmlSerializer visitor(data, bin_file, node_type_name);
NGRAPH_CHECK(node->visit_attributes(visitor),
"Visitor API is not supported in ", node);
}
Expand All @@ -416,13 +426,6 @@ void ngfunction_2_irv10(
layer.remove_child(data);
}

// <layers/data> constant atributes (special case)
if (auto constant = dynamic_cast<ngraph::op::Constant*>(node)) {
ConstantAtributes attr = dump_constant_data(bin, *constant);
data.append_attribute("offset").set_value(attr.offset);
data.append_attribute("size").set_value(attr.size);
}

int port_id = 0;
// <layers/input>
if (node->get_input_size() > 0) {
Expand Down Expand Up @@ -482,10 +485,11 @@ void ngfunction_2_irv10(
bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
// prepare data
pugi::xml_document xml_doc;
std::vector<uint8_t> constants;
std::ofstream bin_file(m_binPath, std::ios::out | std::ios::binary);
NGRAPH_CHECK(bin_file, "Can't open bin file");
switch (m_version) {
case Version::IR_V10:
ngfunction_2_irv10(xml_doc, constants, *f, m_custom_opsets);
ngfunction_2_irv10(xml_doc, bin_file, *f, m_custom_opsets);
break;
default:
NGRAPH_UNREACHABLE("Unsupported version");
Expand All @@ -494,12 +498,10 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {

// create xml file
std::ofstream xml_file(m_xmlPath, std::ios::out);
NGRAPH_CHECK(xml_file, "Can't open xml file");
xml_doc.save(xml_file);

// create bin file
std::ofstream bin_file(m_binPath, std::ios::out | std::ios::binary);
bin_file.write(reinterpret_cast<const char*>(constants.data()),
constants.size() * sizeof(constants[0]));
xml_file.flush();
bin_file.flush();

// Return false because we didn't change nGraph Function
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,16 @@ typedef std::tuple<std::string> SerializationParams;
class SerializationTest: public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<SerializationParams> {
public:
std::string m_model_path;
std::string m_out_xml_path;
std::string m_out_bin_path;

void SetUp() override {
const auto & model_path = IR_SERIALIZATION_MODELS_PATH + std::get<0>(GetParam());
m_model_path = IR_SERIALIZATION_MODELS_PATH + std::get<0>(GetParam());

const std::string test_name = "test"; // ::testing::UnitTest::GetInstance()->current_test_info()->name();
const std::string test_name = GetTestName() + "_" + GetTimestamp();
m_out_xml_path = test_name + ".xml";
m_out_bin_path = test_name + ".bin";

InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model_path);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);

bool success;
std::string message;
std::tie(success, message) = compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}

void TearDown() override {
Expand All @@ -45,6 +36,15 @@ class SerializationTest: public CommonTestUtils::TestsCommon,
};

TEST_P(SerializationTest, CompareFunctions) {
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(m_model_path);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);

bool success;
std::string message;
std::tie(success, message) = compare_functions(result.getFunction(), expected.getFunction(), true);
ASSERT_TRUE(success) << message;
}

INSTANTIATE_TEST_CASE_P(IRSerialization, SerializationTest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,6 @@ class LayerTestsCommon : public CommonTestUtils::TestsCommon {

private:
RefMode refMode = RefMode::INTERPRETER;
static std::string GetTimestamp();
const std::string GetTestName();
};

} // namespace LayerTestsUtils
Original file line number Diff line number Diff line change
Expand Up @@ -473,18 +473,4 @@ std::map<std::string, std::string> &LayerTestsCommon::GetConfiguration() {
return configuration;
}

std::string LayerTestsCommon::GetTimestamp() {
auto now = std::chrono::system_clock::now();
auto epoch = now.time_since_epoch();
auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(epoch);
return std::to_string(ns.count());
}

const std::string LayerTestsCommon::GetTestName() {
std::string test_name =
::testing::UnitTest::GetInstance()->current_test_info()->name();
std::replace_if(test_name.begin(), test_name.end(),
[](char c) { return !std::isalnum(c); }, '_');
return test_name;
}
} // namespace LayerTestsUtils
Loading