Skip to content

Commit

Permalink
Add OpenVINO backend support (#148)
Browse files Browse the repository at this point in the history
* Add OpenVINO backend support

* fix pybind

* fix python library path
  • Loading branch information
jiangjiajun authored Aug 24, 2022
1 parent a1260d7 commit cf4afa4
Show file tree
Hide file tree
Showing 20 changed files with 479 additions and 38 deletions.
11 changes: 10 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ option(WITH_GPU "Whether WITH_GPU=ON, will enable onnxruntime-gpu/paddle-infernc
option(ENABLE_ORT_BACKEND "Whether to enable onnxruntime backend." OFF)
option(ENABLE_TRT_BACKEND "Whether to enable tensorrt backend." OFF)
option(ENABLE_PADDLE_BACKEND "Whether to enable paddle backend." OFF)
option(ENABLE_OPENVINO_BACKEND "Whether to enable paddle backend." OFF)
option(CUDA_DIRECTORY "If build tensorrt backend, need to define path of cuda library.")
option(TRT_DIRECTORY "If build tensorrt backend, need to define path of tensorrt library.")
option(ENABLE_VISION "Whether to enable vision models usage." OFF)
Expand Down Expand Up @@ -117,10 +118,11 @@ file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fas
file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
file(GLOB_RECURSE DEPLOY_TRT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/tensorrt/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/tensorrt/*.cpp)
file(GLOB_RECURSE DEPLOY_OPENVINO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/openvino/*.cc)
file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc)
file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc)
file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc)
list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS} ${FDTENSOR_FUNC_SRCS})
list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS} ${FDTENSOR_FUNC_SRCS})

set(DEPEND_LIBS "")

Expand Down Expand Up @@ -157,6 +159,13 @@ if(ENABLE_PADDLE_BACKEND)
endif()
endif()

if(ENABLE_OPENVINO_BACKEND)
add_definitions(-DENABLE_OPENVINO_BACKEND)
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OPENVINO_SRCS})
include(external/openvino.cmake)
list(APPEND DEPEND_LIBS external_openvino)
endif()

if(WITH_GPU)
if(APPLE)
message(FATAL_ERROR "Cannot enable GPU while compling in Mac OSX.")
Expand Down
11 changes: 11 additions & 0 deletions FastDeploy.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
set(WITH_GPU @WITH_GPU@)
set(ENABLE_ORT_BACKEND @ENABLE_ORT_BACKEND@)
set(ENABLE_PADDLE_BACKEND @ENABLE_PADDLE_BACKEND@)
set(ENABLE_OPENVINO_BACKEND @ENABLE_OPENVINO_BACKEND@)
set(PADDLEINFERENCE_VERSION @PADDLEINFERENCE_VERSION@)
set(OPENVINO_VERSION @OPENVINO_VERSION@)
set(ENABLE_TRT_BACKEND @ENABLE_TRT_BACKEND@)
set(ENABLE_PADDLE_FRONTEND @ENABLE_PADDLE_FRONTEND@)
set(ENABLE_VISION @ENABLE_VISION@)
Expand Down Expand Up @@ -45,6 +47,11 @@ if(ENABLE_PADDLE_BACKEND)
endif()
endif()

if(ENABLE_OPENVINO_BACKEND)
find_library(OPENVINO_LIB openvino ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/openvino/lib/ NO_DEFAULT_PATH)
list(APPEND FASTDEPLOY_LIBS ${OPENVINO_LIB})
endif()

if(WITH_GPU)
if (NOT CUDA_DIRECTORY)
set(CUDA_DIRECTORY "/usr/local/cuda")
Expand Down Expand Up @@ -124,6 +131,10 @@ message(STATUS " ENABLE_PADDLE_BACKEND : ${ENABLE_PADDLE_BACKEND}")
if(ENABLE_PADDLE_BACKEND)
message(STATUS " Paddle Inference version : ${PADDLEINFERENCE_VERSION}")
endif()
message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}")
if(ENABLE_OPENVINO_BACKEND)
message(STATUS " OpenVINO version : ${OPENVINO_VERSION}")
endif()
message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")
Expand Down
2 changes: 1 addition & 1 deletion VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.0
0.2.1
18 changes: 16 additions & 2 deletions csrc/fastdeploy/backends/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,33 @@

#pragma once

#include "fastdeploy/backends/common/multiclass_nms.h"
#include "fastdeploy/core/fd_tensor.h"
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "fastdeploy/backends/common/multiclass_nms.h"
#include "fastdeploy/core/fd_tensor.h"

namespace fastdeploy {

struct TensorInfo {
std::string name;
std::vector<int> shape;
FDDataType dtype;

friend std::ostream& operator<<(std::ostream& output,
const TensorInfo& info) {
output << "TensorInfo(name: " << info.name << ", shape: [";
for (size_t i = 0; i < info.shape.size(); ++i) {
if (i == info.shape.size() - 1) {
output << info.shape[i];
} else {
output << info.shape[i] << ", ";
}
}
output << "], dtype: " << Str(info.dtype) << ")";
return output;
}
};

class BaseBackend {
Expand Down
199 changes: 199 additions & 0 deletions csrc/fastdeploy/backends/openvino/ov_backend.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/backends/openvino/ov_backend.h"

namespace fastdeploy {

std::vector<int64_t> PartialShapeToVec(const ov::PartialShape& shape) {
std::vector<int64_t> res;
for (int i = 0; i < shape.size(); ++i) {
auto dim = shape[i];
if (dim.is_dynamic()) {
res.push_back(-1);
} else {
res.push_back(dim.get_length());
}
}
return res;
}

FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
if (type == ov::element::f32) {
return FDDataType::FP32;
} else if (type == ov::element::f64) {
return FDDataType::FP64;
} else if (type == ov::element::i8) {
return FDDataType::INT8;
} else if (type == ov::element::i32) {
return FDDataType::INT32;
} else if (type == ov::element::i64) {
return FDDataType::INT64;
} else {
FDASSERT(false, "Only support float/double/int8/int32/int64 now.");
}
return FDDataType::FP32;
}

ov::element::Type FDDataTypeToOV(const FDDataType& type) {
if (type == FDDataType::FP32) {
return ov::element::f32;
} else if (type == FDDataType::FP64) {
return ov::element::f64;
} else if (type == FDDataType::INT8) {
return ov::element::i8;
} else if (type == FDDataType::INT32) {
return ov::element::i32;
} else if (type == FDDataType::INT64) {
return ov::element::i64;
}
FDASSERT(false, "Only support float/double/int8/int32/int64 now.");
return ov::element::f32;
}

bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
const std::string& params_file,
const OpenVINOBackendOption& option) {
if (initialized_) {
FDERROR << "OpenVINOBackend is already initlized, cannot initialize again."
<< std::endl;
return false;
}
option_ = option;
ov::AnyMap properties;
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}

std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);

// Get inputs/outputs information from loaded model
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
TensorInfo info;
auto partial_shape = PartialShapeToVec(inputs[i].get_partial_shape());
info.shape.assign(partial_shape.begin(), partial_shape.end());
info.name = inputs[i].get_any_name();
info.dtype = OpenVINODataTypeToFD(inputs[i].get_element_type());
input_infos_.emplace_back(info);
}
const std::vector<ov::Output<ov::Node>> outputs = model->outputs();
for (size_t i = 0; i < outputs.size(); ++i) {
TensorInfo info;
auto partial_shape = PartialShapeToVec(outputs[i].get_partial_shape());
info.shape.assign(partial_shape.begin(), partial_shape.end());
info.name = outputs[i].get_any_name();
info.dtype = OpenVINODataTypeToFD(outputs[i].get_element_type());
output_infos_.emplace_back(info);
}

compiled_model_ = core_.compile_model(model, "CPU", properties);
request_ = compiled_model_.create_infer_request();
initialized_ = true;
return true;
}

TensorInfo OpenVINOBackend::GetInputInfo(int index) {
FDASSERT(index < NumInputs(),
"The index: %d should less than the number of outputs: %d.", index,
NumOutputs());
return input_infos_[index];
}

TensorInfo OpenVINOBackend::GetOutputInfo(int index) {
FDASSERT(index < NumOutputs(),
"The index: %d should less than the number of outputs: %d.", index,
NumOutputs());
return output_infos_[index];
}

bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
const OpenVINOBackendOption& option) {
if (initialized_) {
FDERROR << "OpenVINOBackend is already initlized, cannot initialize again."
<< std::endl;
return false;
}
option_ = option;
ov::AnyMap properties;
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}

std::shared_ptr<ov::Model> model = core_.read_model(model_file);

// Get inputs/outputs information from loaded model
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
TensorInfo info;
auto partial_shape = PartialShapeToVec(inputs[i].get_partial_shape());
info.shape.assign(partial_shape.begin(), partial_shape.end());
info.name = inputs[i].get_any_name();
info.dtype = OpenVINODataTypeToFD(inputs[i].get_element_type());
input_infos_.emplace_back(info);
}
const std::vector<ov::Output<ov::Node>> outputs = model->outputs();
for (size_t i = 0; i < outputs.size(); ++i) {
TensorInfo info;
auto partial_shape = PartialShapeToVec(outputs[i].get_partial_shape());
info.shape.assign(partial_shape.begin(), partial_shape.end());
info.name = outputs[i].get_any_name();
info.dtype = OpenVINODataTypeToFD(outputs[i].get_element_type());
output_infos_.emplace_back(info);
}

compiled_model_ = core_.compile_model(model, "CPU", properties);
request_ = compiled_model_.create_infer_request();
initialized_ = true;
return true;
}

int OpenVINOBackend::NumInputs() const { return input_infos_.size(); }

int OpenVINOBackend::NumOutputs() const { return output_infos_.size(); }

bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs) {
if (inputs.size() != input_infos_.size()) {
FDERROR << "[OpenVINOBackend] Size of the inputs(" << inputs.size()
<< ") should keep same with the inputs of this model("
<< input_infos_.size() << ")." << std::endl;
return false;
}

for (size_t i = 0; i < inputs.size(); ++i) {
ov::Shape shape(inputs[i].shape.begin(), inputs[i].shape.end());
ov::Tensor ov_tensor(FDDataTypeToOV(inputs[i].dtype), shape,
inputs[i].Data());
request_.set_tensor(inputs[i].name, ov_tensor);
}

request_.infer();

outputs->resize(output_infos_.size());
for (size_t i = 0; i < output_infos_.size(); ++i) {
auto out_tensor = request_.get_output_tensor(i);
auto out_tensor_shape = out_tensor.get_shape();
std::vector<int64_t> shape(out_tensor_shape.begin(),
out_tensor_shape.end());
(*outputs)[i].Allocate(shape,
OpenVINODataTypeToFD(out_tensor.get_element_type()),
output_infos_[i].name);
memcpy((*outputs)[i].MutableData(), out_tensor.data(),
(*outputs)[i].Nbytes());
}
return true;
}

} // namespace fastdeploy
62 changes: 62 additions & 0 deletions csrc/fastdeploy/backends/openvino/ov_backend.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "fastdeploy/backends/backend.h"
#include "openvino/openvino.hpp"

namespace fastdeploy {

struct OpenVINOBackendOption {
int cpu_thread_num = 8;
std::map<std::string, std::vector<int64_t>> shape_infos;
};

class OpenVINOBackend : public BaseBackend {
public:
OpenVINOBackend() {}
virtual ~OpenVINOBackend() = default;

bool
InitFromPaddle(const std::string& model_file, const std::string& params_file,
const OpenVINOBackendOption& option = OpenVINOBackendOption());

bool
InitFromOnnx(const std::string& model_file,
const OpenVINOBackendOption& option = OpenVINOBackendOption());

bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs);

int NumInputs() const;

int NumOutputs() const;

TensorInfo GetInputInfo(int index);
TensorInfo GetOutputInfo(int index);

private:
ov::Core core_;
ov::CompiledModel compiled_model_;
ov::InferRequest request_;
OpenVINOBackendOption option_;
std::vector<TensorInfo> input_infos_;
std::vector<TensorInfo> output_infos_;
};
} // namespace fastdeploy
4 changes: 2 additions & 2 deletions csrc/fastdeploy/backends/ort/ort_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ void OrtBackend::BuildOption(const OrtBackendOption& option) {
session_options_.SetGraphOptimizationLevel(
GraphOptimizationLevel(option.graph_optimization_level));
}
if (option.intra_op_num_threads >= 0) {
if (option.intra_op_num_threads > 0) {
session_options_.SetIntraOpNumThreads(option.intra_op_num_threads);
}
if (option.inter_op_num_threads >= 0) {
if (option.inter_op_num_threads > 0) {
session_options_.SetInterOpNumThreads(option.inter_op_num_threads);
}
if (option.execution_mode >= 0) {
Expand Down
6 changes: 5 additions & 1 deletion csrc/fastdeploy/backends/paddle/paddle_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
if (!option.enable_log_info) {
config_.DisableGlogInfo();
}
config_.SetCpuMathLibraryNumThreads(option.cpu_thread_num);
if (option.cpu_thread_num <= 0) {
config_.SetCpuMathLibraryNumThreads(8);
} else {
config_.SetCpuMathLibraryNumThreads(option.cpu_thread_num);
}
}

bool PaddleBackend::InitFromPaddle(const std::string& model_file,
Expand Down
Loading

0 comments on commit cf4afa4

Please sign in to comment.