Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov authored and mryzhov committed Feb 21, 2023
1 parent 86b5004 commit d7bff1f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/plugins/intel_gna/src/gna_data_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ using ConcatConnection = std::unordered_map<std::string, GNAConcatLayer>;
using SplitConnection = std::unordered_map<std::string, GNASplitLayer>;
using CropConnection = std::unordered_map<std::string, GNACropLayer>;
using ConstConnections = std::unordered_map<std::string, void*>;
// Maps an input/output port name to a CPU-executed ov::Model subgraph that is
// evaluated in place on that port's data buffer (see GNAPlugin::QueueInference
// and GNAPlugin::WaitFor).
using SubgraphCPUMap = std::unordered_map<std::string, std::shared_ptr<ov::Model>>;

} // namespace intel_gna
} // namespace ov
45 changes: 44 additions & 1 deletion src/plugins/intel_gna/src/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,29 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer
importedElements,
importedElements);

{
    // Optional CPU fallback for this input: if a CPU-executed ov::Model
    // subgraph is registered under the input's name, evaluate it in place on
    // the raw input buffer before the data is handed to the GNA device.
    // Inputs with no registered subgraph are passed through untouched.
    auto subgraph_it = subgraph_cpu_map.find(input.first);
    if (subgraph_it != subgraph_cpu_map.end()) {
        const std::shared_ptr<ov::Model>& model = subgraph_it->second;

        // The subgraph's first parameter describes the shape and element type
        // of the raw input buffer.
        // NOTE(review): assumes the model has exactly one parameter — verify
        // against the code that populates subgraph_cpu_map.
        const ngraph::Shape& model_input_shape = model->get_parameters()[0]->get_output_shape(0);
        const ngraph::element::Type& model_input_type = model->get_parameters()[0]->get_element_type();

        // Wrap the existing input buffer without copying.
        void* input_ptr = inputs_ptr_->at(input.first).ptrs[index];
        ov::Tensor input_tensor(model_input_type, model_input_shape, input_ptr);

        ov::TensorVector model_result(1);
        if (!model->evaluate(model_result, ov::TensorVector{input_tensor})) {
            // Fixed typo ("evalutate") and wrong direction ("output") in the
            // original message — this is the input path.
            THROW_GNA_EXCEPTION << "Error evaluating ov::Model for input " << input.first;
        }

        // Copy the result back over the input buffer (in-place update).
        // NOTE(review): the destination capacity is not checked here — the
        // subgraph must not produce more bytes than the original buffer holds;
        // confirm this invariant where the subgraph is created.
        const size_t model_result_size = model_result[0].get_byte_size();
        ie_memcpy(input_ptr, model_result_size, model_result[0].data(), model_result_size);
    }
    // Removed leftover debug std::cout on the (normal) no-subgraph path.
}

auto transpose_info = transpose_inputs_info.find(input.first);
if (transpose_info != std::end(transpose_inputs_info)) {
size_t batchSize = (dims.size() > 1) ? dims[0] : 1;
Expand Down Expand Up @@ -1252,7 +1275,27 @@ RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
auto elementsPerBatch =
isScalar ? 1
: (is1D ? dims.front() : InferenceEngine::details::product(++std::begin(dims), std::end(dims)));

{
    // Optional CPU fallback for this output: if a CPU-executed ov::Model
    // subgraph is registered under the output's name, evaluate it in place on
    // the raw output buffer after the GNA inference result is available.
    // Outputs with no registered subgraph are passed through untouched.
    auto subgraph_it = subgraph_cpu_map.find(outputBlobIt.first);
    if (subgraph_it != subgraph_cpu_map.end()) {
        const std::shared_ptr<ov::Model>& model = subgraph_it->second;

        // The subgraph's first parameter describes the shape and element type
        // of the raw buffer GNA wrote for this request.
        // NOTE(review): assumes the model has exactly one parameter — verify
        // against the code that populates subgraph_cpu_map.
        void* output_ptr = outputDesc.ptrs[request_idx];
        const ngraph::Shape& model_input_shape = model->get_parameters()[0]->get_output_shape(0);
        const ngraph::element::Type& model_input_type = model->get_parameters()[0]->get_element_type();

        // Wrap the existing output buffer without copying.
        ov::Tensor input_tensor(model_input_type, model_input_shape, output_ptr);

        ov::TensorVector model_result(1);
        if (!model->evaluate(model_result, ov::TensorVector{input_tensor})) {
            // Fixed typo ("evalutate") in the original message.
            THROW_GNA_EXCEPTION << "Error evaluating ov::Model for output " << outputBlobIt.first;
        }

        // Copy the result back over the output buffer (in-place update).
        // NOTE(review): the destination capacity is not checked here — the
        // subgraph must not produce more bytes than the original buffer holds;
        // confirm this invariant where the subgraph is created.
        const size_t model_result_size = model_result[0].get_byte_size();
        ie_memcpy(output_ptr, model_result_size, model_result[0].data(), model_result_size);
    }
    // Removed leftover debug std::cout on the (normal) no-subgraph path.
}
auto transpose_output_info = transpose_outputs_info.find(outputBlobIt.first);
if (transpose_output_info != std::end(transpose_outputs_info) &&
FoundPartToTranspose(transpose_output_info->second)) {
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_gna/src/gna_plugin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
TranspositionInfoMap transpose_inputs_info;
TranspositionInfoMap transpose_outputs_info;

SubgraphCPUMap subgraph_cpu_map;

uint32_t dnn_dump_write_index = 0;
intel_dnn_number_type_t output_type = kDnnInt;

Expand Down

0 comments on commit d7bff1f

Please sign in to comment.