Skip to content

Commit

Permalink
Adds missing numpy type when looking for the ort correspondance (#10943)
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre authored and chilo-ms committed Apr 15, 2022
1 parent ec00d28 commit c48922e
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 35 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_iobinding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void addIoBindingMethods(pybind11::module& m) {
Py_DECREF(dtype);

OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id());
auto ml_type = NumpyTypeToOnnxRuntimeType(type_num);
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num);
OrtValue ml_value;
Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast<void*>(data_ptr), info, ml_value);

Expand Down Expand Up @@ -124,7 +124,7 @@ void addIoBindingMethods(pybind11::module& m) {
Py_DECREF(dtype);

OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id());
auto ml_type = NumpyTypeToOnnxRuntimeType(type_num);
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num);
OrtValue ml_value;
Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast<void*>(data_ptr), info, ml_value);

Expand Down
29 changes: 3 additions & 26 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,30 +236,7 @@ int OnnxRuntimeTensorToNumpyType(const DataTypeImpl* tensor_type) {
}
}

MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type) {
static std::map<int, MLDataType> type_map{
{NPY_BOOL, DataTypeImpl::GetType<bool>()},
{NPY_FLOAT, DataTypeImpl::GetType<float>()},
{NPY_FLOAT16, DataTypeImpl::GetType<MLFloat16>()},
{NPY_DOUBLE, DataTypeImpl::GetType<double>()},
{NPY_INT8, DataTypeImpl::GetType<int8_t>()},
{NPY_UINT8, DataTypeImpl::GetType<uint8_t>()},
{NPY_INT16, DataTypeImpl::GetType<int16_t>()},
{NPY_UINT16, DataTypeImpl::GetType<uint16_t>()},
{NPY_INT, DataTypeImpl::GetType<int32_t>()},
{NPY_UINT, DataTypeImpl::GetType<uint32_t>()},
{NPY_LONGLONG, DataTypeImpl::GetType<int64_t>()},
{NPY_ULONGLONG, DataTypeImpl::GetType<uint64_t>()},
{NPY_OBJECT, DataTypeImpl::GetType<std::string>()}};
const auto it = type_map.find(numpy_type);
if (it == type_map.end()) {
throw std::runtime_error("No corresponding Numpy type for Tensor Type.");
} else {
return it->second;
}
}

MLDataType NumpyToOnnxRuntimeTensorType(int numpy_type) {
MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type) {
static std::map<int, MLDataType> type_map{
{NPY_BOOL, DataTypeImpl::GetType<bool>()},
{NPY_FLOAT, DataTypeImpl::GetType<float>()},
Expand Down Expand Up @@ -452,7 +429,7 @@ static std::unique_ptr<Tensor> CreateTensor(const AllocatorPtr& alloc, const std

const int npy_type = PyArray_TYPE(darray);
TensorShape shape = GetArrayShape(darray);
auto element_type = NumpyToOnnxRuntimeTensorType(npy_type);
auto element_type = NumpyTypeToOnnxRuntimeTensorType(npy_type);
if (IsNumericNumpyType(npy_type) && use_numpy_data_memory) {
if (pyObject == darray) {
// Use the memory of numpy array directly. The ownership belongs to the calling
Expand Down Expand Up @@ -544,7 +521,7 @@ static void CreateTensorMLValue(const AllocatorPtr& alloc, const std::string& na
static void CreateTensorMLValueOwned(const OrtPybindSingleUseAllocatorPtr& pybind_alloc, const AllocatorPtr& alloc, OrtValue* p_mlvalue) {
auto npy_type = PyArray_TYPE(pybind_alloc->GetContiguous());
TensorShape shape = GetArrayShape(pybind_alloc->GetContiguous());
auto element_type = NumpyToOnnxRuntimeTensorType(npy_type);
auto element_type = NumpyTypeToOnnxRuntimeTensorType(npy_type);

std::unique_ptr<Tensor> p_tensor;

Expand Down
4 changes: 1 addition & 3 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ TensorShape GetShape(const pybind11::array& arr);

int OnnxRuntimeTensorToNumpyType(const DataTypeImpl* tensor_type);

MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type);

MLDataType NumpyToOnnxRuntimeTensorType(int numpy_type);
MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type);

using MemCpyFunc = void (*)(void*, const void*, size_t);

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_ortvalue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void addOrtValueMethods(pybind11::module& m) {
}

auto ml_value = std::make_unique<OrtValue>();
auto ml_type = NumpyTypeToOnnxRuntimeType(type_num);
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num);
Tensor::InitOrtValue(ml_type, gsl::make_span(shape), std::move(allocator), *ml_value);
return ml_value;
})
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/python/onnxruntime_pybind_sparse_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ void addSparseTensorMethods(pybind11::module& m) {

TensorShape dense_shape(py_dense_shape);
auto values_type = GetNumpyArrayType(py_values);
auto ml_type = NumpyToOnnxRuntimeTensorType(values_type);
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(values_type);

std::unique_ptr<PySparseTensor> result;
if (IsNumericNumpyType(values_type)) {
Expand Down Expand Up @@ -199,7 +199,7 @@ void addSparseTensorMethods(pybind11::module& m) {

TensorShape dense_shape(py_dense_shape);
auto values_type = GetNumpyArrayType(py_values);
auto ml_type = NumpyToOnnxRuntimeTensorType(values_type);
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(values_type);

std::unique_ptr<PySparseTensor> result;
if (IsNumericNumpyType(values_type)) {
Expand Down Expand Up @@ -262,7 +262,7 @@ void addSparseTensorMethods(pybind11::module& m) {
TensorShape values_shape = GetShape(py_values);
TensorShape index_shape = GetShape(py_indices);
auto values_type = GetNumpyArrayType(py_values);
auto ml_type = NumpyToOnnxRuntimeTensorType(values_type);
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(values_type);

std::unique_ptr<PySparseTensor> result;
if (IsNumericNumpyType(values_type)) {
Expand Down
54 changes: 54 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python_iobinding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import numpy as np
from numpy.testing import assert_almost_equal
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
from onnx.defs import onnx_opset_version
from onnx import helper
import onnxruntime as onnxrt
from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
OrtDevice as C_OrtDevice, OrtValue as C_OrtValue, SessionIOBinding)
import unittest

from helper import get_name
Expand Down Expand Up @@ -48,6 +54,54 @@ def test_bind_input_to_cpu_arr(self):
# Validate results
self.assertTrue(np.array_equal(self.create_expected_output(), ort_output))

def test_bind_input_types(self):

opset = onnx_opset_version()
devices = [(C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), ['CPUExecutionProvider'])]
if "CUDAExecutionProvider" in onnxrt.get_all_providers():
devices.append((C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0), ['CUDAExecutionProvider']))

for device, provider in devices:
for dtype in [np.float32, np.float64, np.int32, np.uint32,
np.int64, np.uint64, np.int16, np.uint16,
np.int8, np.uint8, np.float16, np.bool_]:
with self.subTest(dtype=dtype, device=str(device)):

x = np.arange(8).reshape((-1, 2)).astype(dtype)
proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype]

X = helper.make_tensor_value_info('X', proto_dtype, [None, x.shape[1]])
Y = helper.make_tensor_value_info('Y', proto_dtype, [None, x.shape[1]])

# inference
node_add = helper.make_node('Identity', ['X'], ['Y'])

# graph
graph_def = helper.make_graph([node_add], 'lr', [X], [Y], [])
model_def = helper.make_model(
graph_def, producer_name='dummy', ir_version=7,
producer_version="0",
opset_imports=[helper.make_operatorsetid('', opset)])

sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider)

bind = SessionIOBinding(sess._sess)
ort_value = C_OrtValue.ortvalue_from_numpy(x, device)
bind.bind_ortvalue_input('X', ort_value)
bind.bind_output('Y', device)
sess._sess.run_with_iobinding(bind, None)
ortvalue = bind.get_outputs()[0]
y = ortvalue.numpy()
assert_almost_equal(x, y)

bind = SessionIOBinding(sess._sess)
bind.bind_input('X', device, dtype, x.shape, ort_value.data_ptr())
bind.bind_output('Y', device)
sess._sess.run_with_iobinding(bind, None)
ortvalue = bind.get_outputs()[0]
y = ortvalue.numpy()
assert_almost_equal(x, y)

def test_bind_input_only(self):
input = self.create_ortvalue_input_on_gpu()

Expand Down

0 comments on commit c48922e

Please sign in to comment.