From 47ac276dc4f992295c79532e1dce3c5020327c8b Mon Sep 17 00:00:00 2001 From: jiwaszki Date: Fri, 16 Feb 2024 12:40:09 +0100 Subject: [PATCH] Port changes of Remote API --- src/bindings/python/src/openvino/__init__.py | 7 + .../python/src/pyopenvino/CMakeLists.txt | 2 + .../python/src/pyopenvino/core/core.cpp | 31 ++++ .../src/pyopenvino/core/remote_context.cpp | 96 +++++++++++++ .../src/pyopenvino/core/remote_context.hpp | 40 ++++++ .../src/pyopenvino/core/remote_tensor.cpp | 86 ++++++++++++ .../src/pyopenvino/core/remote_tensor.hpp | 45 ++++++ .../python/src/pyopenvino/pyopenvino.cpp | 7 + .../python/src/pyopenvino/utils/utils.cpp | 6 + .../python/src/pyopenvino/utils/utils.hpp | 2 + .../tests/test_runtime/test_remote_api.py | 132 ++++++++++++++++++ tools/benchmark_tool/openvino/__init__.py | 7 + tools/mo/openvino/__init__.py | 7 + tools/openvino_dev/src/openvino/__init__.py | 7 + tools/ovc/openvino/__init__.py | 7 + 15 files changed, 482 insertions(+) create mode 100644 src/bindings/python/src/pyopenvino/core/remote_context.cpp create mode 100644 src/bindings/python/src/pyopenvino/core/remote_context.hpp create mode 100644 src/bindings/python/src/pyopenvino/core/remote_tensor.cpp create mode 100644 src/bindings/python/src/pyopenvino/core/remote_tensor.hpp create mode 100644 src/bindings/python/tests/test_runtime/test_remote_api.py diff --git a/src/bindings/python/src/openvino/__init__.py b/src/bindings/python/src/openvino/__init__.py index 9db717409cf327..e6700c56c61616 100644 --- a/src/bindings/python/src/openvino/__init__.py +++ b/src/bindings/python/src/openvino/__init__.py @@ -50,6 +50,13 @@ from openvino.runtime import save_model from openvino.runtime import layout_helpers +from openvino._pyopenvino import RemoteContext +from openvino._pyopenvino import RemoteTensor + +# libva related: +from openvino._pyopenvino import VAContext +from openvino._pyopenvino import VASurfaceTensor + # Set version for openvino package from openvino.runtime import get_version __version__ = get_version() diff --git a/src/bindings/python/src/pyopenvino/CMakeLists.txt b/src/bindings/python/src/pyopenvino/CMakeLists.txt index 6e11c915c7baf2..420ab28cf329c5 100644 --- a/src/bindings/python/src/pyopenvino/CMakeLists.txt +++ b/src/bindings/python/src/pyopenvino/CMakeLists.txt @@ -63,7 +63,9 @@ list(FILTER SOURCES EXCLUDE REGEX ".*(frontend/(onnx|tensorflow|paddle|pytorch)) pybind11_add_module(${PROJECT_NAME} MODULE NO_EXTRAS ${SOURCES}) target_include_directories(${PROJECT_NAME} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/..") + target_link_libraries(${PROJECT_NAME} PRIVATE openvino::core::dev openvino::runtime openvino::offline_transformations) + set_target_properties(${PROJECT_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO} OUTPUT_NAME "_pyopenvino") diff --git a/src/bindings/python/src/pyopenvino/core/core.cpp b/src/bindings/python/src/pyopenvino/core/core.cpp index 5252660efe4c98..78ddebb8ed6009 100644 --- a/src/bindings/python/src/pyopenvino/core/core.cpp +++ b/src/bindings/python/src/pyopenvino/core/core.cpp @@ -11,6 +11,7 @@ #include #include "common.hpp" +#include "pyopenvino/core/remote_context.hpp" #include "pyopenvino/utils/utils.hpp" namespace py = pybind11; @@ -231,6 +232,36 @@ void regclass_Core(py::module m) { :rtype: openvino.runtime.CompiledModel )"); + cls.def( + "compile_model", + [](ov::Core& self, + const std::shared_ptr& model, + const RemoteContextWrapper& context, + const std::map& properties) { + auto _properties = Common::utils::properties_to_any_map(properties); + py::gil_scoped_release release; + return self.compile_model(model, context.context, _properties); + }, + py::arg("model"), + py::arg("context"), + py::arg("properties")); + + cls.def( + "create_context", + [](ov::Core& self, const std::string& device_name, const std::map& properties) { + auto _properties = Common::utils::properties_to_any_map(properties); + return RemoteContextWrapper(self.create_context(device_name, _properties)); + }, + py::arg("device_name"), + py::arg("properties")); + + cls.def( + "get_default_context", + [](ov::Core& self, const std::string& device_name) { + return RemoteContextWrapper(self.get_default_context(device_name)); + }, + py::arg("device_name")); + cls.def("get_versions", &ov::Core::get_versions, py::arg("device_name"), diff --git a/src/bindings/python/src/pyopenvino/core/remote_context.cpp b/src/bindings/python/src/pyopenvino/core/remote_context.cpp new file mode 100644 index 00000000000000..e10bb3512900c9 --- /dev/null +++ b/src/bindings/python/src/pyopenvino/core/remote_context.cpp @@ -0,0 +1,96 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pyopenvino/core/remote_context.hpp" + +#include + +#include + +#include "common.hpp" +#include "pyopenvino/utils/utils.hpp" + +namespace py = pybind11; + +void regclass_RemoteContext(py::module m) { + py::class_> cls(m, "RemoteContext"); + + cls.def("get_device_name", [](RemoteContextWrapper& self) { + return self.context.get_device_name(); + }); + + cls.def("get_params", [](RemoteContextWrapper& self) { + return self.context.get_params(); + }); + + cls.def( + "create_tensor", + [](RemoteContextWrapper& self, + const ov::element::Type& type, + const ov::Shape& shape, + const std::map& properties) { + auto _properties = Common::utils::properties_to_any_map(properties); + return RemoteTensorWrapper(self.context.create_tensor(type, shape, _properties)); + }, + py::arg("type"), + py::arg("shape"), + py::arg("properties")); + + cls.def( + "create_host_tensor", + [](RemoteContextWrapper& self, const ov::element::Type& type, const ov::Shape& shape) { + return self.context.create_host_tensor(type, shape); + }, + py::arg("type"), + py::arg("shape")); +} + +void regclass_VAContext(py::module m) { + py::class_> cls(m, "VAContext"); + + cls.def(py::init([](ov::Core& core, void* display, int target_tile_id) { + ov::AnyMap context_params = { + {ov::intel_gpu::context_type.name(), ov::intel_gpu::ContextType::VA_SHARED}, + {ov::intel_gpu::va_device.name(), display}, + {ov::intel_gpu::tile_id.name(), target_tile_id}}; + auto ctx = core.create_context("GPU", context_params); + return VAContextWrapper(ctx); + }), + py::arg("core"), + py::arg("display"), + py::arg("target_tile_id") = -1); + + cls.def( + "create_tensor_nv12", + [](VAContextWrapper& self, const size_t height, const size_t width, const uint32_t nv12_surface) { + ov::AnyMap tensor_params = { + {ov::intel_gpu::shared_mem_type.name(), ov::intel_gpu::SharedMemType::VA_SURFACE}, + {ov::intel_gpu::dev_object_handle.name(), nv12_surface}, + {ov::intel_gpu::va_plane.name(), uint32_t(0)}}; + auto y_tensor = self.context.create_tensor(ov::element::u8, {1, height, width, 1}, tensor_params); + tensor_params[ov::intel_gpu::va_plane.name()] = uint32_t(1); + auto uv_tensor = self.context.create_tensor(ov::element::u8, {1, height / 2, width / 2, 2}, tensor_params); + return py::make_tuple(VASurfaceTensorWrapper(y_tensor), VASurfaceTensorWrapper(uv_tensor)); + }, + py::arg("height"), + py::arg("width"), + py::arg("nv12_surface")); + + cls.def( + "create_tensor", + [](VAContextWrapper& self, + const ov::element::Type& type, + const ov::Shape shape, + const uint32_t surface, + const uint32_t plane) { + ov::AnyMap params = {{ov::intel_gpu::shared_mem_type.name(), ov::intel_gpu::SharedMemType::VA_SURFACE}, + {ov::intel_gpu::dev_object_handle.name(), surface}, + {ov::intel_gpu::va_plane.name(), plane}}; + return VASurfaceTensorWrapper(self.context.create_tensor(type, shape, params)); + }, + py::arg("type"), + py::arg("shape"), + py::arg("surface"), + py::arg("plane") = 0); +} diff --git a/src/bindings/python/src/pyopenvino/core/remote_context.hpp b/src/bindings/python/src/pyopenvino/core/remote_context.hpp new file mode 100644 index 00000000000000..d5695e591650d3 --- /dev/null +++ b/src/bindings/python/src/pyopenvino/core/remote_context.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "openvino/core/except.hpp" +#include "pyopenvino/core/remote_tensor.hpp" + +namespace py = pybind11; + +class RemoteContextWrapper { +public: + RemoteContextWrapper() {} + + RemoteContextWrapper(ov::RemoteContext& _context): context{_context} {} + + RemoteContextWrapper(ov::RemoteContext&& _context): context{std::move(_context)} {} + + ov::RemoteContext context; +}; + +void regclass_RemoteContext(py::module m); + +class VAContextWrapper : public RemoteContextWrapper { +public: + VAContextWrapper(ov::RemoteContext& _context): RemoteContextWrapper{_context} {} + + VAContextWrapper(ov::RemoteContext&& _context): RemoteContextWrapper{std::move(_context)} {} +}; + +void regclass_VAContext(py::module m); diff --git a/src/bindings/python/src/pyopenvino/core/remote_tensor.cpp b/src/bindings/python/src/pyopenvino/core/remote_tensor.cpp new file mode 100644 index 00000000000000..dbe50135590976 --- /dev/null +++ b/src/bindings/python/src/pyopenvino/core/remote_tensor.cpp @@ -0,0 +1,86 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pyopenvino/core/remote_tensor.hpp" + +#include + +#include "common.hpp" +#include "pyopenvino/utils/utils.hpp" + +namespace py = pybind11; + +void regclass_RemoteTensor(py::module m) { + py::class_> cls(m, + "RemoteTensor", + py::base()); + + cls.def("get_device_name", [](RemoteTensorWrapper& self) { + return self.tensor.get_device_name(); + }); + + cls.def("get_params", [](RemoteTensorWrapper& self) { + return self.tensor.get_params(); + }); + + cls.def("copy_to", [](RemoteTensorWrapper& self, py::object& dst) { + Common::utils::raise_not_implemented(); + }); + + cls.def_property_readonly("data", [](RemoteTensorWrapper& self) { + Common::utils::raise_not_implemented(); + }); + + cls.def_property( + "bytes_data", + [](RemoteTensorWrapper& self) { + Common::utils::raise_not_implemented(); + }, + [](RemoteTensorWrapper& self, py::object& other) { + Common::utils::raise_not_implemented(); + }); + + cls.def_property( + "str_data", + [](RemoteTensorWrapper& self) { + Common::utils::raise_not_implemented(); + }, + [](RemoteTensorWrapper& self, py::object& other) { + Common::utils::raise_not_implemented(); + }); + + cls.def("__repr__", [](const RemoteTensorWrapper& self) { + std::stringstream ss; + + ss << "shape" << self.tensor.get_shape() << " type: " << self.tensor.get_element_type(); + + return "<" + Common::get_class_name(self) + ": " + ss.str() + ">"; + }); +} + +void regclass_VASurfaceTensor(py::module m) { + py::class_> cls( + m, + "VASurfaceTensor"); + + cls.def_property_readonly("surface_id", [](VASurfaceTensorWrapper& self) { + return self.surface_id(); + }); + + cls.def_property_readonly("plane_id", [](VASurfaceTensorWrapper& self) { + return self.plane_id(); + }); + + cls.def_property_readonly("data", [](VASurfaceTensorWrapper& self) { + Common::utils::raise_not_implemented(); + }); + + cls.def("__repr__", [](const VASurfaceTensorWrapper& self) { + std::stringstream ss; + + ss << "shape" << self.tensor.get_shape() << " type: " << self.tensor.get_element_type(); + + return "<" + Common::get_class_name(self) + ": " + ss.str() + ">"; + }); +} diff --git a/src/bindings/python/src/pyopenvino/core/remote_tensor.hpp b/src/bindings/python/src/pyopenvino/core/remote_tensor.hpp new file mode 100644 index 00000000000000..4fa40aff49615d --- /dev/null +++ b/src/bindings/python/src/pyopenvino/core/remote_tensor.hpp @@ -0,0 +1,45 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +class RemoteTensorWrapper { +public: + RemoteTensorWrapper() {} + + RemoteTensorWrapper(ov::RemoteTensor& _tensor): tensor{_tensor} {} + + RemoteTensorWrapper(ov::RemoteTensor&& _tensor): tensor{std::move(_tensor)} {} + + ov::RemoteTensor tensor; +}; + +void regclass_RemoteTensor(py::module m); + +class VASurfaceTensorWrapper : public RemoteTensorWrapper { +public: + VASurfaceTensorWrapper(ov::RemoteTensor& _tensor): RemoteTensorWrapper{_tensor} {} + + VASurfaceTensorWrapper(ov::RemoteTensor&& _tensor): RemoteTensorWrapper{std::move(_tensor)} {} + + uint32_t surface_id() { + return tensor.get_params().at(ov::intel_gpu::dev_object_handle.name()).as(); + } + + uint32_t plane_id() { + return tensor.get_params().at(ov::intel_gpu::va_plane.name()).as(); + } +}; + +void regclass_VASurfaceTensor(py::module m); diff --git a/src/bindings/python/src/pyopenvino/pyopenvino.cpp b/src/bindings/python/src/pyopenvino/pyopenvino.cpp index bb6447a60357fa..8fdea8a8bde0ed 100644 --- a/src/bindings/python/src/pyopenvino/pyopenvino.cpp +++ b/src/bindings/python/src/pyopenvino/pyopenvino.cpp @@ -30,6 +30,8 @@ #include "pyopenvino/core/offline_transformations.hpp" #include "pyopenvino/core/profiling_info.hpp" #include "pyopenvino/core/properties/properties.hpp" +#include "pyopenvino/core/remote_context.hpp" +#include "pyopenvino/core/remote_tensor.hpp" #include "pyopenvino/core/tensor.hpp" #include "pyopenvino/core/variable_state.hpp" #include "pyopenvino/core/version.hpp" @@ -258,6 +260,11 @@ PYBIND11_MODULE(_pyopenvino, m) { regclass_ProfilingInfo(m); regclass_Extension(m); + regclass_RemoteContext(m); + regclass_RemoteTensor(m); + regclass_VAContext(m); + regclass_VASurfaceTensor(m); + // Properties and hints regmodule_properties(m); diff --git a/src/bindings/python/src/pyopenvino/utils/utils.cpp b/src/bindings/python/src/pyopenvino/utils/utils.cpp index 52f7099530523f..87f6c36576a1ca 100644 --- a/src/bindings/python/src/pyopenvino/utils/utils.cpp +++ b/src/bindings/python/src/pyopenvino/utils/utils.cpp @@ -264,6 +264,12 @@ void deprecation_warning(const std::string& function_name, PyErr_WarnEx(PyExc_DeprecationWarning, ss.str().data(), stacklevel); } +void raise_not_implemented() { + auto error_message = py::detail::c_str(std::string("This function is not implemented.")); + PyErr_SetString(PyExc_NotImplementedError, error_message); + throw py::error_already_set(); +} + bool py_object_is_any_map(const py::object& py_obj) { if (!py::isinstance(py_obj)) { return false; diff --git a/src/bindings/python/src/pyopenvino/utils/utils.hpp b/src/bindings/python/src/pyopenvino/utils/utils.hpp index 6d65bce29acac6..e2e5c6f5e886ff 100644 --- a/src/bindings/python/src/pyopenvino/utils/utils.hpp +++ b/src/bindings/python/src/pyopenvino/utils/utils.hpp @@ -43,6 +43,8 @@ namespace utils { void deprecation_warning(const std::string& function_name, const std::string& version = std::string(), const std::string& message = std::string(), int stacklevel=2); + void raise_not_implemented(); + bool py_object_is_any_map(const py::object& py_obj); ov::AnyMap py_object_to_any_map(const py::object& py_obj); diff --git a/src/bindings/python/tests/test_runtime/test_remote_api.py b/src/bindings/python/tests/test_runtime/test_remote_api.py new file mode 100644 index 00000000000000..2c6201c7b4674f --- /dev/null +++ b/src/bindings/python/tests/test_runtime/test_remote_api.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import pytest + +import numpy as np + +import openvino as ov +import openvino.runtime.opset13 as ops + +from tests.utils.helpers import generate_image, get_relu_model, generate_model_with_memory + + +@pytest.mark.skipif( + "CPU" not in os.environ.get("TEST_DEVICE", ""), + reason="Test can be only performed on CPU device!", +) +def test_get_default_context_cpu(): + core = ov.Core() + with pytest.raises(RuntimeError) as cpu_error: + _ = core.get_default_context("CPU") + possible_errors = ["is not supported by CPU plugin!", "Not Implemented"] + assert any(error in str(cpu_error.value) for error in possible_errors) + + +@pytest.mark.skipif( + "GPU" not in os.environ.get("TEST_DEVICE", ""), + reason="Test can be only performed on GPU device!", +) +def test_get_default_context_gpu(): + core = ov.Core() + context = core.get_default_context("GPU") + assert isinstance(context, ov.RemoteContext) + assert "GPU" in context.get_device_name() + + context_params = context.get_params() + + assert isinstance(context_params, dict) + assert list(context_params.keys()) == ["CONTEXT_TYPE", "OCL_CONTEXT", "OCL_QUEUE"] + + +@pytest.mark.skipif( + "GPU" not in os.environ.get("TEST_DEVICE", ""), + reason="Test can be only performed on GPU device!", +) +def test_create_host_tensor_gpu(): + core = ov.Core() + context = core.get_default_context("GPU") + assert isinstance(context, ov.RemoteContext) + assert "GPU" in context.get_device_name() + + tensor = context.create_host_tensor(ov.Type.f32, ov.Shape([1, 2, 3])) + + assert isinstance(tensor, ov.Tensor) + assert not isinstance(tensor, ov.RemoteTensor) + + +@pytest.mark.skipif( + "GPU" not in os.environ.get("TEST_DEVICE", ""), + reason="Test can be only performed on GPU device!", +) +def test_create_device_tensor_gpu(): + core = ov.Core() + context = core.get_default_context("GPU") + assert isinstance(context, ov.RemoteContext) + assert "GPU" in context.get_device_name() + + tensor = context.create_tensor(ov.Type.f32, ov.Shape([1, 2, 3]), {}) + tensor_params = tensor.get_params() + + assert isinstance(tensor_params, dict) + assert list(tensor_params.keys()) == ["MEM_HANDLE", "OCL_CONTEXT", "SHARED_MEM_TYPE"] + + assert isinstance(tensor, ov.Tensor) + assert isinstance(tensor, ov.RemoteTensor) + assert "GPU" in tensor.get_device_name() + assert tensor.get_shape() == ov.Shape([1, 2, 3]) + assert tensor.get_element_type() == ov.Type.f32 + assert tensor.get_size() == 6 + assert tensor.get_byte_size() == 24 + assert list(tensor.get_strides()) == [24, 12, 4] + + tensor.set_shape([1, 1, 1]) + assert tensor.get_shape() + assert tensor.get_size() == 1 + assert tensor.get_byte_size() == 4 + assert list(tensor.get_strides()) == [4, 4, 4] + + with pytest.raises(TypeError) as constructor_error: + _ = ov.RemoteTensor(np.ones((1, 2, 3))) + assert "No constructor defined!" in str(constructor_error.value) + + with pytest.raises(RuntimeError) as copy_to_error: + _ = tensor.copy_to(None) + assert "This function is not implemented." in str(copy_to_error.value) + + with pytest.raises(RuntimeError) as data_error: + _ = tensor.data + assert "This function is not implemented." in str(data_error.value) + + with pytest.raises(RuntimeError) as bytes_data_error: + _ = tensor.bytes_data + assert "This function is not implemented." in str(bytes_data_error.value) + + with pytest.raises(RuntimeError) as str_data_error: + _ = tensor.str_data + assert "This function is not implemented." in str(str_data_error.value) + + +@pytest.mark.skipif( + "GPU" not in os.environ.get("TEST_DEVICE", ""), + reason="Test can be only performed on GPU device!", +) +def test_compile_with_context(): + core = ov.Core() + context = core.get_default_context("GPU") + model = get_relu_model() + compiled = core.compile_model(model, context) + assert isinstance(compiled, ov.CompiledModel) + + +@pytest.mark.skipif( + "GPU" not in os.environ.get("TEST_DEVICE", ""), + reason="Test can be only performed on GPU device!", +) +def test_va_context(): + core = ov.Core() + with pytest.raises(RuntimeError) as context_error: + _ = ov.VAContext(core, None) + assert "user handle is nullptr!" in str(context_error.value) diff --git a/tools/benchmark_tool/openvino/__init__.py b/tools/benchmark_tool/openvino/__init__.py index 9db717409cf327..e6700c56c61616 100644 --- a/tools/benchmark_tool/openvino/__init__.py +++ b/tools/benchmark_tool/openvino/__init__.py @@ -50,6 +50,13 @@ from openvino.runtime import save_model from openvino.runtime import layout_helpers +from openvino._pyopenvino import RemoteContext +from openvino._pyopenvino import RemoteTensor + +# libva related: +from openvino._pyopenvino import VAContext +from openvino._pyopenvino import VASurfaceTensor + # Set version for openvino package from openvino.runtime import get_version __version__ = get_version() diff --git a/tools/mo/openvino/__init__.py b/tools/mo/openvino/__init__.py index 9701fe0a9a45a6..3cbe7805f4c686 100644 --- a/tools/mo/openvino/__init__.py +++ b/tools/mo/openvino/__init__.py @@ -47,6 +47,13 @@ from openvino.runtime import save_model from openvino.runtime import layout_helpers + from openvino._pyopenvino import RemoteContext + from openvino._pyopenvino import RemoteTensor + + # libva related: + from openvino._pyopenvino import VAContext + from openvino._pyopenvino import VASurfaceTensor + # Set version for openvino package from openvino.runtime import get_version __version__ = get_version() diff --git a/tools/openvino_dev/src/openvino/__init__.py b/tools/openvino_dev/src/openvino/__init__.py index 9701fe0a9a45a6..3cbe7805f4c686 100644 --- a/tools/openvino_dev/src/openvino/__init__.py +++ b/tools/openvino_dev/src/openvino/__init__.py @@ -47,6 +47,13 @@ from openvino.runtime import save_model from openvino.runtime import layout_helpers + from openvino._pyopenvino import RemoteContext + from openvino._pyopenvino import RemoteTensor + + # libva related: + from openvino._pyopenvino import VAContext + from openvino._pyopenvino import VASurfaceTensor + # Set version for openvino package from openvino.runtime import get_version __version__ = get_version() diff --git a/tools/ovc/openvino/__init__.py b/tools/ovc/openvino/__init__.py index 9db717409cf327..e6700c56c61616 100644 --- a/tools/ovc/openvino/__init__.py +++ b/tools/ovc/openvino/__init__.py @@ -50,6 +50,13 @@ from openvino.runtime import save_model from openvino.runtime import layout_helpers +from openvino._pyopenvino import RemoteContext +from openvino._pyopenvino import RemoteTensor + +# libva related: +from openvino._pyopenvino import VAContext +from openvino._pyopenvino import VASurfaceTensor + # Set version for openvino package from openvino.runtime import get_version __version__ = get_version()