diff --git a/otx/algorithms/common/configs/training_base.py b/otx/algorithms/common/configs/training_base.py
index 1b7032bb9b9..c644c014106 100644
--- a/otx/algorithms/common/configs/training_base.py
+++ b/otx/algorithms/common/configs/training_base.py
@@ -372,4 +372,20 @@ class BaseTilingParameters(ParameterGroup):
             affects_outcome_of=ModelLifecycle.NONE,
         )
 
+        tile_ir_scale_factor = configurable_float(
+            header="OpenVINO IR Scale Factor",
+            description="The purpose of the scale parameter is to optimize the performance and "
+            "efficiency of tiling in OpenVINO IR during inference. By controlling the increase in tile size and "
+            "input size, the scale parameter allows for more efficient parallelization of the workload and "
+            "improves the overall performance and efficiency of the inference process on OpenVINO.",
+            warning="Setting the scale factor value too high may cause the application "
+            "to crash or result in out-of-memory errors. It is recommended to "
+            "adjust the scale factor value carefully based on the available "
+            "hardware resources and the needs of the application.",
+            default_value=2.0,
+            min_value=1.0,
+            max_value=4.0,
+            affects_outcome_of=ModelLifecycle.NONE,
+        )
+
     tiling_parameters = add_parameter_group(BaseTilingParameters)
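A back-of-the-envelope sketch (not part of the patch) of the trade-off this parameter controls: the 4000 px image size below is hypothetical, while the 400 px tile size, 0.5 overlap and 2.0 default factor match values that appear elsewhere in this diff.

```python
# Hypothetical sketch: effect of tile_ir_scale_factor on tile size and request count.
# Assumes a 4000x4000 image; 400 px tiles and 50% overlap are the defaults seen later in the diff.
import math

image_side, tile_size, overlap = 4000, 400, 0.5
for scale in (1.0, 2.0):
    scaled_tile = int(tile_size * scale)          # the IR consumes tiles of this size
    stride = int(scaled_tile * (1 - overlap))     # step between neighbouring tiles
    tiles_per_side = math.ceil(image_side / stride)
    print(scale, scaled_tile, tiles_per_side ** 2)
# 1.0 -> 400 px tiles, 400 requests; 2.0 -> 800 px tiles, 100 requests (4x fewer, each 4x larger)
```
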
diff --git a/otx/algorithms/detection/adapters/mmdet/task.py b/otx/algorithms/detection/adapters/mmdet/task.py
index c7bf809e422..d5ec4ccd28a 100644
--- a/otx/algorithms/detection/adapters/mmdet/task.py
+++ b/otx/algorithms/detection/adapters/mmdet/task.py
@@ -41,7 +41,6 @@
 from otx.algorithms.common.adapters.mmcv.utils import (
     adapt_batch_size,
     build_data_parallel,
-    get_configs_by_pairs,
     patch_data_pipeline,
     patch_from_hyperparams,
 )
@@ -63,7 +62,12 @@
 from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
     DetClassProbabilityMapHook,
 )
-from otx.algorithms.detection.adapters.mmdet.utils import patch_tiling
+from otx.algorithms.detection.adapters.mmdet.utils import (
+    patch_input_preprocessing,
+    patch_input_shape,
+    patch_ir_scale_factor,
+    patch_tiling,
+)
 from otx.algorithms.detection.adapters.mmdet.utils.builder import build_detector
 from otx.algorithms.detection.adapters.mmdet.utils.config_utils import (
     should_cluster_anchors,
@@ -621,67 +625,10 @@ def _init_deploy_cfg(self, cfg) -> Union[Config, None]:
         if os.path.exists(deploy_cfg_path):
             deploy_cfg = MPAConfig.fromfile(deploy_cfg_path)
 
-            def patch_input_preprocessing(deploy_cfg):
-                normalize_cfg = get_configs_by_pairs(
-                    cfg.data.test.pipeline,
-                    dict(type="Normalize"),
-                )
-                assert len(normalize_cfg) == 1
-                normalize_cfg = normalize_cfg[0]
-
-                options = dict(flags=[], args={})
-                # NOTE: OTX loads image in RGB format
-                # so that `to_rgb=True` means a format change to BGR instead.
-                # Conventionally, OpenVINO IR expects a image in BGR format
-                # but OpenVINO IR under OTX assumes a image in RGB format.
-                #
-                # `to_rgb=True` -> a model was trained with images in BGR format
-                #                  and a OpenVINO IR needs to reverse input format from RGB to BGR
-                # `to_rgb=False` -> a model was trained with images in RGB format
-                #                   and a OpenVINO IR does not need to do a reverse
-                if normalize_cfg.get("to_rgb", False):
-                    options["flags"] += ["--reverse_input_channels"]
-                # value must be a list not a tuple
-                if normalize_cfg.get("mean", None) is not None:
-                    options["args"]["--mean_values"] = list(normalize_cfg.get("mean"))
-                if normalize_cfg.get("std", None) is not None:
-                    options["args"]["--scale_values"] = list(normalize_cfg.get("std"))
-
-                # fill default
-                backend_config = deploy_cfg.backend_config
-                if backend_config.get("mo_options") is None:
-                    backend_config.mo_options = ConfigDict()
-                mo_options = backend_config.mo_options
-                if mo_options.get("args") is None:
-                    mo_options.args = ConfigDict()
-                if mo_options.get("flags") is None:
-                    mo_options.flags = []
-
-                # already defiend options have higher priority
-                options["args"].update(mo_options.args)
-                mo_options.args = ConfigDict(options["args"])
-                # make sure no duplicates
-                mo_options.flags.extend(options["flags"])
-                mo_options.flags = list(set(mo_options.flags))
-
-            def patch_input_shape(deploy_cfg):
-                resize_cfg = get_configs_by_pairs(
-                    cfg.data.test.pipeline,
-                    dict(type="Resize"),
-                )
-                assert len(resize_cfg) == 1
-                resize_cfg = resize_cfg[0]
-                size = resize_cfg.size
-                if isinstance(size, int):
-                    size = (size, size)
-                assert all(isinstance(i, int) and i > 0 for i in size)
-                # default is static shape to prevent an unexpected error
-                # when converting to OpenVINO IR
-                deploy_cfg.backend_config.model_inputs = [ConfigDict(opt_shapes=ConfigDict(input=[1, 3, *size]))]
-
-            patch_input_preprocessing(deploy_cfg)
+            patch_input_preprocessing(cfg, deploy_cfg)
             if not deploy_cfg.backend_config.get("model_inputs", []):
-                patch_input_shape(deploy_cfg)
+                patch_input_shape(cfg, deploy_cfg)
+                patch_ir_scale_factor(deploy_cfg, self._hyperparams)
 
         return deploy_cfg
diff --git a/otx/algorithms/detection/adapters/mmdet/utils/__init__.py b/otx/algorithms/detection/adapters/mmdet/utils/__init__.py
index 18238cfbb1b..77b125b0ca5 100644
--- a/otx/algorithms/detection/adapters/mmdet/utils/__init__.py
+++ b/otx/algorithms/detection/adapters/mmdet/utils/__init__.py
@@ -9,6 +9,9 @@
     patch_config,
     patch_datasets,
     patch_evaluation,
+    patch_input_preprocessing,
+    patch_input_shape,
+    patch_ir_scale_factor,
     patch_tiling,
     prepare_for_training,
     set_hyperparams,
@@ -23,4 +26,7 @@
     "set_hyperparams",
     "build_detector",
     "patch_tiling",
+    "patch_input_preprocessing",
+    "patch_input_shape",
+    "patch_ir_scale_factor",
 ]
diff --git a/otx/algorithms/detection/adapters/mmdet/utils/config_utils.py b/otx/algorithms/detection/adapters/mmdet/utils/config_utils.py
index 16e8ff8f518..fe3e77c3592 100644
--- a/otx/algorithms/detection/adapters/mmdet/utils/config_utils.py
+++ b/otx/algorithms/detection/adapters/mmdet/utils/config_utils.py
@@ -373,3 +373,98 @@ def patch_tiling(config, hparams, dataset=None):
         config.update(dict(evaluation=dict(iou_thr=[0.5])))
 
     return config
+
+
+def patch_input_preprocessing(cfg: ConfigDict, deploy_cfg: ConfigDict):
+    """Update backend configuration with input preprocessing options.
+
+    - If `"to_rgb"` in the Normalize config is truthy, `"--reverse_input_channels"` is added as a flag.
+
+    The function then sets default values for the backend configuration in `deploy_cfg`.
+
+    Args:
+        cfg (mmcv.ConfigDict): Config object containing test pipeline and other configurations.
+        deploy_cfg (mmcv.ConfigDict): DeployConfig object containing backend configuration.
+
+    Returns:
+        None: This function updates the input `deploy_cfg` object directly.
+    """
+    normalize_cfgs = get_configs_by_pairs(cfg.data.test.pipeline, dict(type="Normalize"))
+    assert len(normalize_cfgs) == 1
+    normalize_cfg: dict = normalize_cfgs[0]
+
+    # Set options based on Normalize config
+    options = {
+        "flags": ["--reverse_input_channels"] if normalize_cfg.get("to_rgb", False) else [],
+        "args": {
+            "--mean_values": list(normalize_cfg.get("mean", [])),
+            "--scale_values": list(normalize_cfg.get("std", [])),
+        },
+    }
+
+    # Set default backend configuration
+    mo_options = deploy_cfg.backend_config.get("mo_options", ConfigDict())
+    mo_options = ConfigDict() if mo_options is None else mo_options
+    mo_options.args = mo_options.get("args", ConfigDict())
+    mo_options.flags = mo_options.get("flags", [])
+
+    # Override backend configuration with options from Normalize config
+    mo_options.args.update(options["args"])
+    mo_options.flags = list(set(mo_options.flags + options["flags"]))
+
+    deploy_cfg.backend_config.mo_options = mo_options
+
+
+def patch_input_shape(cfg: ConfigDict, deploy_cfg: ConfigDict):
+    """Update backend configuration with input shape information.
+
+    This function retrieves the Resize config from `cfg.data.test.pipeline`, checks that
+    exactly one Resize entry is present, and then sets the input shape for the backend model
+    in `deploy_cfg`:
+
+    ```
+    {
+        "opt_shapes": {
+            "input": [1, 3, *size]
+        }
+    }
+    ```
+
+    Args:
+        cfg (Config): Config object containing test pipeline and other configurations.
+        deploy_cfg (DeployConfig): DeployConfig object containing backend configuration.
+
+    Returns:
+        None: This function updates the input `deploy_cfg` object directly.
+    """
+    resize_cfgs = get_configs_by_pairs(
+        cfg.data.test.pipeline,
+        dict(type="Resize"),
+    )
+    assert len(resize_cfgs) == 1
+    resize_cfg: ConfigDict = resize_cfgs[0]
+    size = resize_cfg.size
+    if isinstance(size, int):
+        size = (size, size)
+    assert all(isinstance(i, int) and i > 0 for i in size)
+    # default is static shape to prevent an unexpected error
+    # when converting to OpenVINO IR
+    deploy_cfg.backend_config.model_inputs = [ConfigDict(opt_shapes=ConfigDict(input=[1, 3, *size]))]
+
+
+def patch_ir_scale_factor(deploy_cfg: ConfigDict, hyper_parameters: DetectionConfig):
+    """Patch IR scale factor in place from hyper parameters to deploy config.
+
+    Args:
+        deploy_cfg (ConfigDict): mmcv deploy config
+        hyper_parameters (DetectionConfig): OTX detection hyper parameters
+    """
+
+    if hyper_parameters.tiling_parameters.enable_tiling:
+        scale_ir_input = deploy_cfg.get("scale_ir_input", False)
+        if scale_ir_input:
+            tile_ir_scale_factor = hyper_parameters.tiling_parameters.tile_ir_scale_factor
+            logger.info(f"Apply OpenVINO IR scale factor: {tile_ir_scale_factor}")
+            ir_input_shape = deploy_cfg.backend_config.model_inputs[0].opt_shapes.input
+            ir_input_shape[2] = int(ir_input_shape[2] * tile_ir_scale_factor)  # height
+            ir_input_shape[3] = int(ir_input_shape[3] * tile_ir_scale_factor)  # width
+            deploy_cfg.ir_config.input_shape = (ir_input_shape[3], ir_input_shape[2])  # width, height
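A minimal usage sketch (not part of the patch) of the new `patch_ir_scale_factor` helper, assuming `otx` and `mmcv` are installed; the 800x800 base shape is made up, and the `SimpleNamespace` stand-in only mimics the attributes of `DetectionConfig` that the helper actually reads.

```python
# Sketch: exercise patch_ir_scale_factor with a hand-built deploy config.
from types import SimpleNamespace

from mmcv import ConfigDict

from otx.algorithms.detection.adapters.mmdet.utils import patch_ir_scale_factor

deploy_cfg = ConfigDict(
    scale_ir_input=True,  # set in a model's deployment.py when IR scaling is wanted
    ir_config=ConfigDict(input_shape=None),
    backend_config=ConfigDict(model_inputs=[ConfigDict(opt_shapes=ConfigDict(input=[1, 3, 800, 800]))]),
)
# Stand-in for DetectionConfig: only the attributes read by the helper.
hyper_parameters = SimpleNamespace(
    tiling_parameters=SimpleNamespace(enable_tiling=True, tile_ir_scale_factor=2.0)
)

patch_ir_scale_factor(deploy_cfg, hyper_parameters)
print(deploy_cfg.backend_config.model_inputs[0].opt_shapes.input)  # [1, 3, 1600, 1600]
print(deploy_cfg.ir_config.input_shape)                            # (1600, 1600), i.e. (width, height)
```
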
diff --git a/otx/algorithms/detection/adapters/openvino/task.py b/otx/algorithms/detection/adapters/openvino/task.py
index 5ca6ef8e7cc..659112dd1d2 100644
--- a/otx/algorithms/detection/adapters/openvino/task.py
+++ b/otx/algorithms/detection/adapters/openvino/task.py
@@ -261,6 +261,7 @@ class OpenVINOTileClassifierWrapper(BaseInferencerWithConverter):
         tile_size (int): tile size
         overlap (float): overlap ratio between tiles
         max_number (int): maximum number of objects per image
+        tile_ir_scale_factor (float, optional): scale factor for tile size
         tile_classifier_model_file (Union[str, bytes, None], optional): tile classifier xml. Defaults to None.
         tile_classifier_weight_file (Union[str, bytes, None], optional): til classifier weight bin. Defaults to None.
         device (str, optional): device to run inference on, such as CPU, GPU or MYRIAD. Defaults to "CPU".
@@ -274,6 +275,7 @@ def __init__(
         tile_size: int = 400,
         overlap: float = 0.5,
         max_number: int = 100,
+        tile_ir_scale_factor: float = 1.0,
         tile_classifier_model_file: Union[str, bytes, None] = None,
         tile_classifier_weight_file: Union[str, bytes, None] = None,
         device: str = "CPU",
@@ -293,7 +295,7 @@ def __init__(
             classifier = Model(model_adapter=adapter, preload=True)
 
         self.tiler = Tiler(
-            tile_size=tile_size,
+            tile_size=int(tile_size * tile_ir_scale_factor),
             overlap=overlap,
             max_number=max_number,
             detector=inferencer.model,
@@ -372,6 +374,10 @@ def load_config(self) -> ADDict:
             if self.model is not None and self.model.get_data("config.json"):
                 json_dict = json.loads(self.model.get_data("config.json"))
                 flatten_config_values(json_dict)
+                # NOTE: for backward compatibility
+                json_dict["tiling_parameters"]["tile_ir_scale_factor"] = json_dict["tiling_parameters"].get(
+                    "tile_ir_scale_factor", 1.0
+                )
                 config = merge_a_into_b(json_dict, config)
         except Exception as e:  # pylint: disable=broad-except
             logger.warning(f"Failed to load config.json: {e}")
@@ -418,6 +424,7 @@ def load_inferencer(
                 self.config.tiling_parameters.tile_size,
                 self.config.tiling_parameters.tile_overlap,
                 self.config.tiling_parameters.tile_max_number,
+                self.config.tiling_parameters.tile_ir_scale_factor,
                 tile_classifier_model_file,
                 tile_classifier_weight_file,
             )
diff --git a/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml b/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml
index 16eb6cb0571..088fb58b33a 100644
--- a/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml
+++ b/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml
@@ -553,5 +553,23 @@ tiling_parameters:
     visible_in_ui: true
     warning: null
 
+  tile_ir_scale_factor:
+    header: OpenVINO IR Scale Factor
+    description: The purpose of the scale parameter is to optimize the performance and efficiency of tiling in OpenVINO IR during inference. By controlling the increase in tile size and input size, the scale parameter allows for more efficient parallelization of the workload and improves the overall performance and efficiency of the inference process on OpenVINO.
+    affects_outcome_of: TRAINING
+    default_value: 2.0
+    min_value: 1.0
+    max_value: 4.0
+    type: FLOAT
+    editable: true
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: 2.0
+    visible_in_ui: true
+    warning: null
+
   type: PARAMETER_GROUP
   visible_in_ui: true
diff --git a/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/deployment.py b/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/deployment.py
index fa981b687fc..f9701f38ac0 100644
--- a/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/deployment.py
+++ b/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/deployment.py
@@ -2,8 +2,11 @@
 
 _base_ = ["../../base/deployments/base_instance_segmentation_dynamic.py"]
 
+scale_ir_input = True
+
 ir_config = dict(
     output_names=["boxes", "labels", "masks"],
+    input_shape=(1024, 1024),
 )
 
 backend_config = dict(
diff --git a/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/deployment.py b/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/deployment.py
index 715e77ef995..8ef82f1ca34 100644
--- a/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/deployment.py
+++ b/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/deployment.py
@@ -2,8 +2,11 @@
 
 _base_ = ["../../base/deployments/base_instance_segmentation_dynamic.py"]
 
+scale_ir_input = True
+
 ir_config = dict(
     output_names=["boxes", "labels", "masks"],
+    input_shape=(1344, 800),
 )
 
 backend_config = dict(
diff --git a/otx/api/utils/tiler.py b/otx/api/utils/tiler.py
index cbbabae68ae..8f5bba63507 100644
--- a/otx/api/utils/tiler.py
+++ b/otx/api/utils/tiler.py
@@ -70,7 +70,7 @@ def tile(self, image: np.ndarray) -> List[List[int]]:
         return coords
 
     def filter_tiles_by_objectness(
-        self, image: np.ndarray, tile_coords: List[List[int]], confidence_threshold: float = 0.45
+        self, image: np.ndarray, tile_coords: List[List[int]], confidence_threshold: float = 0.35
     ):
         """Filter tiles by objectness score by running tile classifier.
 
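The two deployment configs above opt in via `scale_ir_input = True`. Assuming each model's test-time resize matches the `input_shape` declared in its `ir_config`, the exported IR under the default factor of 2.0 should come out roughly as sketched below (plain arithmetic mirroring `patch_ir_scale_factor`, not tool output).

```python
# Sketch: expected IR input after scaling, using the input_shape values from the two
# deployment.py files above and the default tile_ir_scale_factor of 2.0.
deployments = {
    "efficientnetb2b_maskrcnn": {"scale_ir_input": True, "input_shape": (1024, 1024)},  # (w, h)
    "resnet50_maskrcnn": {"scale_ir_input": True, "input_shape": (1344, 800)},          # (w, h)
}
scale = 2.0
for name, cfg in deployments.items():
    width, height = cfg["input_shape"]
    if cfg["scale_ir_input"]:
        width, height = int(width * scale), int(height * scale)
    print(f"{name}: [1, 3, {height}, {width}]")  # NCHW shape the IR is compiled with
# efficientnetb2b_maskrcnn: [1, 3, 2048, 2048]
# resnet50_maskrcnn: [1, 3, 1600, 2688]
```
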
diff --git a/tests/regression/detection/test_tiling_detection.py b/tests/regression/detection/test_tiling_detection.py
index 04bb571aa8b..d33845c1b58 100644
--- a/tests/regression/detection/test_tiling_detection.py
+++ b/tests/regression/detection/test_tiling_detection.py
@@ -178,8 +178,8 @@ def test_otx_deploy_eval_deployment(self, template, tmp_dir_path):
         assert test_result["passed"] is True, test_result["log"]
 
     @e2e_pytest_component
-    @pytest.mark.parametrize("template", templates, ids=templates_ids)
     @pytest.mark.skip(reason="CVS-98026")
+    @pytest.mark.parametrize("template", templates, ids=templates_ids)
     def test_nncf_optimize_eval(self, template, tmp_dir_path):
         self.performance[template.name] = {}
 
diff --git a/tests/regression/instance_segmentation/test_tiling_instnace_segmentation.py b/tests/regression/instance_segmentation/test_tiling_instnace_segmentation.py
index bc1a9bf5c7d..61d5d020137 100644
--- a/tests/regression/instance_segmentation/test_tiling_instnace_segmentation.py
+++ b/tests/regression/instance_segmentation/test_tiling_instnace_segmentation.py
@@ -178,8 +178,8 @@ def test_otx_deploy_eval_deployment(self, template, tmp_dir_path):
         assert test_result["passed"] is True, test_result["log"]
 
     @e2e_pytest_component
-    @pytest.mark.parametrize("template", templates, ids=templates_ids)
     @pytest.mark.skip(reason="CVS-98026")
+    @pytest.mark.parametrize("template", templates, ids=templates_ids)
     def test_nncf_optimize_eval(self, template, tmp_dir_path):
         self.performance[template.name] = {}
 
diff --git a/tests/unit/algorithms/detection/tiling/test_tiling_detection.py b/tests/unit/algorithms/detection/tiling/test_tiling_detection.py
index 70f7a8db998..c1462e8e05b 100644
--- a/tests/unit/algorithms/detection/tiling/test_tiling_detection.py
+++ b/tests/unit/algorithms/detection/tiling/test_tiling_detection.py
@@ -8,10 +8,14 @@
 import numpy as np
 import pytest
 import torch
-from mmcv import ConfigDict
+from mmcv import Config, ConfigDict
 from mmdet.datasets import build_dataloader, build_dataset
+from mmdet.models import DETECTORS
+from openvino.model_zoo.model_api.adapters import OpenvinoAdapter, create_core
+from torch import nn
 
 from otx.algorithms.common.adapters.mmcv.utils.config_utils import MPAConfig
+from otx.algorithms.common.adapters.mmdeploy.apis import MMdeployExporter
 from otx.algorithms.detection.adapters.mmdet.task import MMDetectionTask
 from otx.algorithms.detection.adapters.mmdet.utils import build_detector, patch_tiling
 from otx.api.configuration.helper import create
@@ -32,8 +36,26 @@
 )
 
 
+@DETECTORS.register_module(force=True)
+class MockDetModel(nn.Module):
+    def __init__(self, backbone, train_cfg=None, test_cfg=None, init_cfg=None):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 3, 3)
+        self.box_dummy = torch.nn.AdaptiveAvgPool2d((1, 5))
+        self.label_dummy = torch.nn.AdaptiveAvgPool2d((1))
+        self.mask_dummy = torch.nn.AdaptiveAvgPool2d((28, 28))
+
+    def forward(self, *args, **kwargs):
+        img = args[0]
+        x = self.conv(img)
+        boxes = self.box_dummy(x).mean(1)
+        labels = self.label_dummy(x).mean(1)
+        masks = self.mask_dummy(x).mean(1)
+        return boxes, labels, masks
+
+
 def create_otx_dataset(height: int, width: int, labels: List[str]):
-    """Create a random OTX dataset
+    """Create a random OTX dataset.
 
     Args:
         height (int): The height of the image
@@ -54,11 +76,11 @@ def create_otx_dataset(height: int, width: int, labels: List[str]):
 
 
 class TestTilingDetection:
-    """Test the tiling functionality"""
+    """Test the tiling detection algorithm."""
 
     @pytest.fixture(autouse=True)
     def setUp(self) -> None:
-        """Setup the test case"""
+        """Setup the test case."""
         self.height = 1024
         self.width = 1024
         self.label_names = ["rectangle", "ellipse", "triangle"]
@@ -132,7 +154,7 @@ def setUp(self) -> None:
 
     @e2e_pytest_unit
     def test_tiling_train_dataloader(self):
-        """Test that the training dataloader is built correctly for tiling"""
+        """Test that the training dataloader is built correctly for tiling."""
         dataset = build_dataset(self.train_data_cfg)
         train_dataloader = build_dataloader(dataset, **self.dataloader_cfg)
 
@@ -143,7 +165,7 @@ def test_tiling_test_dataloader(self):
-        """Test that the testing dataloader is built correctly for tiling"""
+        """Test that the testing dataloader is built correctly for tiling."""
         dataset = build_dataset(self.test_data_cfg)
 
         stride = int((1 - self.tile_cfg["overlap_ratio"]) * self.tile_cfg["tile_size"])
@@ -160,7 +182,7 @@ def test_inference_merge(self):
-        """Test that the inference merge works correctly"""
+        """Test that the inference merge works correctly."""
         dataset = build_dataset(self.test_data_cfg)
 
         # create simulated inference results
@@ -222,7 +244,7 @@ def test_load_tiling_parameters(self, tmp_dir_path):
 
     @e2e_pytest_unit
     def test_patch_tiling_func(self):
-        """Test that patch_tiling function works correctly"""
+        """Test that patch_tiling function works correctly."""
         cfg = MPAConfig.fromfile(os.path.join(DEFAULT_ISEG_TEMPLATE_DIR, "model.py"))
         model_template = parse_model_template(os.path.join(DEFAULT_ISEG_TEMPLATE_DIR, "template.yaml"))
         hyper_parameters = create(model_template.hyper_parameters.data)
@@ -235,6 +257,60 @@
         patch_tiling(cfg, hyper_parameters, self.otx_dataset)
 
     @e2e_pytest_unit
-    def test_openvino(self):
-        # TODO[EUGENE]: implement unittest for tiling prediction with openvino
-        pass
+    @pytest.mark.parametrize("scale_factor", [1, 1.5, 2, 3, 4])
+    def test_tile_ir_scale_deploy(self, tmp_dir_path, scale_factor):
+        """Test that the IR scale factor is correctly applied during inference."""
+        model_template = parse_model_template(os.path.join(DEFAULT_ISEG_TEMPLATE_DIR, "template.yaml"))
+        hyper_parameters = create(model_template.hyper_parameters.data)
+        hyper_parameters.tiling_parameters.enable_tiling = True
+        hyper_parameters.tiling_parameters.tile_ir_scale_factor = scale_factor
+        task_env = init_environment(hyper_parameters, model_template)
+        img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+        task = MMDetectionTask(task_env)
+        pipeline = [
+            dict(type="LoadImageFromFile"),
+            dict(
+                type="MultiScaleFlipAug",
+                img_scale=(800, 800),
+                flip=False,
+                transforms=[
+                    dict(type="Resize", keep_ratio=False),
+                    dict(type="RandomFlip"),
+                    dict(type="Normalize", **img_norm_cfg),
+                    dict(type="Pad", size_divisor=32),
+                    dict(type="DefaultFormatBundle"),
+                    dict(type="Collect", keys=["img"]),
+                ],
+            ),
+        ]
+        config = Config(
+            dict(model=dict(type="MockDetModel", backbone=dict(init_cfg=None)), data=dict(test=dict(pipeline=pipeline)))
+        )
+
+        deploy_cfg = task._init_deploy_cfg(config)
+        onnx_path = MMdeployExporter.torch2onnx(
+            tmp_dir_path,
+            np.zeros((50, 50, 3), dtype=np.float32),
+            config,
+            deploy_cfg,
+        )
+        assert isinstance(onnx_path, str)
+        assert os.path.exists(onnx_path)
+
+        openvino_paths = MMdeployExporter.onnx2openvino(
+            tmp_dir_path,
+            onnx_path,
+            deploy_cfg,
+        )
+        for openvino_path in openvino_paths:
+            assert os.path.exists(openvino_path)
+
+        task._init_task()
+        original_width, original_height = task._recipe_cfg.data.test.pipeline[0].img_scale  # w, h
+
+        model_adapter = OpenvinoAdapter(create_core(), openvino_paths[0], openvino_paths[1])
+
+        ir_input_shape = model_adapter.get_input_layers()["image"].shape
+        _, _, ir_height, ir_width = ir_input_shape
+        assert ir_height == original_height * scale_factor
+        assert ir_width == original_width * scale_factor
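The new unit test reads the IR input shape back through the model_api `OpenvinoAdapter`; a quick manual check of an exported model can do the same directly with the OpenVINO runtime. A minimal sketch, assuming `openvino>=2022` is installed and that `model.xml`/`model.bin` are the (hypothetical) file names returned by `onnx2openvino`:

```python
# Standalone sanity check of the exported IR input shape.
from openvino.runtime import Core

core = Core()
model = core.read_model("model.xml")   # the "model.bin" weights next to it are picked up automatically
print(model.inputs[0].shape)           # static NCHW shape baked into the IR,
                                       # expected: original height/width * tile_ir_scale_factor
```
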