-
Notifications
You must be signed in to change notification settings - Fork 1
/
ov_model_splitter.cpp
148 lines (133 loc) · 5.51 KB
/
ov_model_splitter.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <algorithm>
#include <chrono>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <openvino/core/preprocess/pre_post_process.hpp>
#include <openvino/op/sink.hpp>
#include <openvino/op/util/variable.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/serialize.hpp>
#include "openvino/openvino.hpp"
using namespace ov;
// Smoke-test helper: compile `model` on CPU and run one inference with a
// dummy all-ones input, to verify the extracted subgraph is executable.
// NOTE(review): assumes the model's first input is f32 — TODO confirm.
void run_model(std::shared_ptr<ov::Model> model) {
    ov::Core core;
    auto exeNetwork = core.compile_model(model, "CPU");
    auto infer_request = exeNetwork.create_infer_request();
    ov::Tensor input_tensor = infer_request.get_input_tensor();
    std::vector<float> fake_input(10, 1.0f);
    // BUGFIX: the original passed fake_input.size() (element count) as the
    // memcpy byte count, copying only 10 bytes instead of 10 floats.
    // Copy the full buffer, clamped to the tensor's real byte size so we
    // never write past the end of the input tensor.
    const size_t bytes = std::min(input_tensor.get_byte_size(),
                                  fake_input.size() * sizeof(float));
    std::memcpy(input_tensor.data<float>(), fake_input.data(), bytes);
    infer_request.infer();
}
int main(int args, char *argv[]) {
if (args < 4)
return -1;
std::string inputs(argv[2]);
std::string outputs(argv[3]);
std::string assigns = "";
if (args >= 5)
assigns = std::string(argv[4]);
bool hasInputConfig = false;
bool hasOutputConfig = false;
if (inputs.find(".config") != std::string::npos) {
hasInputConfig = true;
}
if (outputs.find(".config") != std::string::npos) {
hasOutputConfig = true;
}
ov::Core core;
auto model = core.read_model(argv[1]);
auto ordered_ops = model->get_ordered_ops();
std::unordered_map<std::string, std::shared_ptr<Node>> name2op = {};
//collect op mapps
for(auto& op : ordered_ops) {
name2op.emplace(op->get_friendly_name(), op);
}
auto readFile = [](std::string& name) {
std::cout << "Read Config " << name << std::endl;
std::fstream newfile;
std::vector<std::string> names;
newfile.open(name, std::ios::in);
if (newfile.is_open()){
std::string tp;
while(getline(newfile, tp)){
names.push_back(tp);
}
newfile.close(); //close the file object.
}
return names;
};
std::vector<std::string> target_input = hasInputConfig ? readFile(inputs) : std::vector<std::string>{inputs}; // "offset", "att_cache", "cnn_cache"};
std::vector<std::string> target_output = hasOutputConfig ? readFile(outputs) : std::vector<std::string>{outputs};
std::vector<std::string> target_assign = !assigns.empty() ? readFile(assigns) : std::vector<std::string>();
std::cout << "Start " << target_input[0] << std::endl;
std::cout << "End " << target_output[0] << std::endl;
std::vector<std::shared_ptr<opset8::Parameter> > subgraph_parameters = {};
std::vector<std::shared_ptr<opset8::Result> > subgraph_results = {};
for(auto& input_name : target_input) {
std::cout << "process input " << input_name << std::endl;
auto input_op = name2op.at(input_name);
if (auto node = ov::as_type_ptr<opset8::Parameter>(input_op)) {
std::cout << "keep original parameter " << input_name << std::endl;
subgraph_parameters.push_back(node);
continue;
}
size_t num_const = 0;
std::vector<int> index2non_const = {};
for(size_t i = 0; i < input_op->get_input_size(); i++) {
auto parent = input_op->get_input_node_shared_ptr(i);
if(ov::as_type_ptr<ov::opset8::Constant>(parent)) {
num_const++;
} else {
index2non_const.push_back(i);
}
}
for(auto& index : index2non_const) {
auto new_param = std::make_shared<opset8::Parameter>(input_op->get_input_element_type(index),
input_op->get_input_partial_shape(index));
new_param->output(0).set_names({new_param->get_friendly_name()});
input_op->input(index).replace_source_output(new_param->output(0));
subgraph_parameters.push_back(new_param);
}
}
for(auto& output_name : target_output) {
std::cout << "process output " << output_name << std::endl;
auto output_op = name2op.at(output_name);
if (auto node = ov::as_type_ptr<opset8::Result>(output_op)) {
std::cout << "keep original result " << output_name << std::endl;
subgraph_results.push_back(node);
continue;
}
if (output_op->get_output_size() !=1) {
throw std::runtime_error("input must has 1 child");
}
auto node_copy = output_op->clone_with_new_inputs(output_op->input_values());
auto new_result = std::make_shared<opset8::Result>(node_copy);
ov::replace_node(output_op, new_result);
subgraph_results.push_back(new_result);
}
auto subgraph = std::make_shared<ov::Model>(subgraph_results, subgraph_parameters);
SinkVector sinks;
// std::vector<std::string> target_thinks = {"Assign_196986", "Assign_196988"};
for (auto& assign_name : target_assign) {
std::cout << "process sink" << assign_name << std::endl;
auto sink_op = name2op.at(assign_name);
auto node = ov::as_type_ptr<opset8::Assign>(sink_op);
if (node) {
sinks.push_back(node);
}
}
if (!sinks.empty()){
std::cout << "Add sinks size " << sinks.size() << std::endl;
subgraph->add_sinks(sinks);
}
using namespace std::chrono;
auto ms = duration_cast< milliseconds >(
system_clock::now().time_since_epoch()).count();
ov::serialize(subgraph, "simple_model_" + std::to_string(ms) + ".xml");
}