From 99aa3e7232e10632a6091dcb15dd64d27c2d7a78 Mon Sep 17 00:00:00 2001 From: Russell Brooks Date: Mon, 13 May 2024 00:08:57 -0500 Subject: [PATCH 1/2] passthrough ONNX session options --- unstructured_inference/models/detectron2onnx.py | 7 +++++++ unstructured_inference/models/yolox.py | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/unstructured_inference/models/detectron2onnx.py b/unstructured_inference/models/detectron2onnx.py index 79cd0a1a..b9df790f 100644 --- a/unstructured_inference/models/detectron2onnx.py +++ b/unstructured_inference/models/detectron2onnx.py @@ -99,6 +99,7 @@ def initialize( model_path: str, label_map: Dict[int, str], confidence_threshold: Optional[float] = None, + session_options_dict: Optional[Dict[str, Union[int, bool, str]]] = None, ): """Loads the detectron2 model using the specified parameters""" if not os.path.exists(model_path) and "detectron2_quantized" in model_path: @@ -115,8 +116,14 @@ def initialize( ] providers = [provider for provider in ordered_providers if provider in available_providers] + session_options = onnxruntime.SessionOptions() + if session_options_dict: + for option_name, option_value in session_options_dict.items(): + setattr(session_options, option_name, option_value) + self.model = onnxruntime.InferenceSession( model_path, + sess_options=session_options, providers=providers, ) self.model_path = model_path diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py index 0acd93f3..14e535e0 100644 --- a/unstructured_inference/models/yolox.py +++ b/unstructured_inference/models/yolox.py @@ -3,7 +3,7 @@ # https://github.com/Megvii-BaseDetection/YOLOX/blob/237e943ac64aa32eb32f875faa93ebb18512d41d/yolox/data/data_augment.py # https://github.com/Megvii-BaseDetection/YOLOX/blob/ac379df3c97d1835ebd319afad0c031c36d03f36/yolox/utils/demo_utils.py -from typing import List, cast +from typing import Dict, List, Optional, Union, cast import cv2 import numpy as np @@ -68,7 +68,12 @@ def predict(self, x: PILImage.Image): super().predict(x) return self.image_processing(x) - def initialize(self, model_path: str, label_map: dict): + def initialize( + self, + model_path: str, + label_map: dict, + session_options_dict: Optional[Dict[str, Union[int, bool, str]]] = None, + ): """Start inference session for YoloX model.""" self.model_path = model_path @@ -80,8 +85,14 @@ def initialize(self, model_path: str, label_map: dict): ] providers = [provider for provider in ordered_providers if provider in available_providers] + session_options = onnxruntime.SessionOptions() + if session_options_dict: + for option_name, option_value in session_options_dict.items(): + setattr(session_options, option_name, option_value) + self.model = onnxruntime.InferenceSession( model_path, + sess_options=session_options, providers=providers, ) From 0505f29f6102809be9aeb1a95dcb782193650706 Mon Sep 17 00:00:00 2001 From: Russell Brooks Date: Mon, 13 May 2024 00:09:18 -0500 Subject: [PATCH 2/2] tidy up test --- .../models/test_detectron2onnx.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/test_unstructured_inference/models/test_detectron2onnx.py b/test_unstructured_inference/models/test_detectron2onnx.py index 3be5916f..464b3823 100644 --- a/test_unstructured_inference/models/test_detectron2onnx.py +++ b/test_unstructured_inference/models/test_detectron2onnx.py @@ -37,11 +37,18 @@ def test_load_default_model(monkeypatch): @pytest.mark.parametrize(("model_path", "label_map"), [("asdf", "diufs"), ("dfaw", "hfhfhfh")]) def test_load_model(model_path, label_map): - with patch.object(detectron2.onnxruntime, "InferenceSession", return_value=True): + session_options_dict = {"intra_op_num_threads": 1, "inter_op_num_threads": 1} + with patch.object(detectron2.onnxruntime, "InferenceSession") as mock_session: model = detectron2.UnstructuredDetectronONNXModel() - model.initialize(model_path=model_path, label_map=label_map) - args, _ = detectron2.onnxruntime.InferenceSession.call_args - assert args == (model_path,) + model.initialize( + model_path=model_path, + label_map=label_map, + session_options_dict=session_options_dict + ) + args, kwargs = mock_session.call_args + assert args[0] == model_path + assert kwargs['sess_options'].intra_op_num_threads == 1 + assert kwargs['sess_options'].inter_op_num_threads == 1 assert label_map == model.label_map