Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tiling Spatial Concatenation for OpenVINO IR #2052

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions otx/algorithms/common/configs/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,20 @@ class BaseTilingParameters(ParameterGroup):
affects_outcome_of=ModelLifecycle.NONE,
)

tile_ir_scale_factor = configurable_float(
header="OpenVINO IR Scale Factor",
description="The purpose of the scale parameter is to optimize the performance and "
"efficiency of tiling in OpenVINO IR during inference. By controlling the increase in tile size and "
"input size, the scale parameter allows for more efficient parallelization of the workload and "
"improve the overall performance and efficiency of the inference process on OpenVINO.",
warning="Setting the scale factor value too high may cause the application "
"to crash or result in out-of-memory errors. It is recommended to "
"adjust the scale factor value carefully based on the available "
"hardware resources and the needs of the application.",
default_value=2.0,
min_value=1.0,
max_value=4.0,
affects_outcome_of=ModelLifecycle.NONE,
)

tiling_parameters = add_parameter_group(BaseTilingParameters)
71 changes: 9 additions & 62 deletions otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from otx.algorithms.common.adapters.mmcv.utils import (
adapt_batch_size,
build_data_parallel,
get_configs_by_pairs,
patch_data_pipeline,
patch_from_hyperparams,
)
Expand All @@ -63,7 +62,12 @@
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)
from otx.algorithms.detection.adapters.mmdet.utils import patch_tiling
from otx.algorithms.detection.adapters.mmdet.utils import (
patch_input_preprocessing,
patch_input_shape,
patch_ir_scale_factor,
patch_tiling,
)
from otx.algorithms.detection.adapters.mmdet.utils.builder import build_detector
from otx.algorithms.detection.adapters.mmdet.utils.config_utils import (
should_cluster_anchors,
Expand Down Expand Up @@ -621,67 +625,10 @@ def _init_deploy_cfg(self, cfg) -> Union[Config, None]:
if os.path.exists(deploy_cfg_path):
deploy_cfg = MPAConfig.fromfile(deploy_cfg_path)

def patch_input_preprocessing(deploy_cfg):
normalize_cfg = get_configs_by_pairs(
cfg.data.test.pipeline,
dict(type="Normalize"),
)
assert len(normalize_cfg) == 1
normalize_cfg = normalize_cfg[0]

options = dict(flags=[], args={})
# NOTE: OTX loads image in RGB format
# so that `to_rgb=True` means a format change to BGR instead.
# Conventionally, OpenVINO IR expects a image in BGR format
# but OpenVINO IR under OTX assumes a image in RGB format.
#
# `to_rgb=True` -> a model was trained with images in BGR format
# and a OpenVINO IR needs to reverse input format from RGB to BGR
# `to_rgb=False` -> a model was trained with images in RGB format
# and a OpenVINO IR does not need to do a reverse
if normalize_cfg.get("to_rgb", False):
options["flags"] += ["--reverse_input_channels"]
# value must be a list not a tuple
if normalize_cfg.get("mean", None) is not None:
options["args"]["--mean_values"] = list(normalize_cfg.get("mean"))
if normalize_cfg.get("std", None) is not None:
options["args"]["--scale_values"] = list(normalize_cfg.get("std"))

# fill default
backend_config = deploy_cfg.backend_config
if backend_config.get("mo_options") is None:
backend_config.mo_options = ConfigDict()
mo_options = backend_config.mo_options
if mo_options.get("args") is None:
mo_options.args = ConfigDict()
if mo_options.get("flags") is None:
mo_options.flags = []

# already defined options have higher priority
options["args"].update(mo_options.args)
mo_options.args = ConfigDict(options["args"])
# make sure no duplicates
mo_options.flags.extend(options["flags"])
mo_options.flags = list(set(mo_options.flags))

def patch_input_shape(deploy_cfg):
resize_cfg = get_configs_by_pairs(
cfg.data.test.pipeline,
dict(type="Resize"),
)
assert len(resize_cfg) == 1
resize_cfg = resize_cfg[0]
size = resize_cfg.size
if isinstance(size, int):
size = (size, size)
assert all(isinstance(i, int) and i > 0 for i in size)
# default is static shape to prevent an unexpected error
# when converting to OpenVINO IR
deploy_cfg.backend_config.model_inputs = [ConfigDict(opt_shapes=ConfigDict(input=[1, 3, *size]))]

patch_input_preprocessing(deploy_cfg)
patch_input_preprocessing(cfg, deploy_cfg)
if not deploy_cfg.backend_config.get("model_inputs", []):
patch_input_shape(deploy_cfg)
patch_input_shape(cfg, deploy_cfg)
patch_ir_scale_factor(deploy_cfg, self._hyperparams)

return deploy_cfg

Expand Down
6 changes: 6 additions & 0 deletions otx/algorithms/detection/adapters/mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
patch_config,
patch_datasets,
patch_evaluation,
patch_input_preprocessing,
patch_input_shape,
patch_ir_scale_factor,
patch_tiling,
prepare_for_training,
set_hyperparams,
Expand All @@ -23,4 +26,7 @@
"set_hyperparams",
"build_detector",
"patch_tiling",
"patch_input_preprocessing",
"patch_input_shape",
"patch_ir_scale_factor",
]
95 changes: 95 additions & 0 deletions otx/algorithms/detection/adapters/mmdet/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,98 @@ def patch_tiling(config, hparams, dataset=None):
config.update(dict(evaluation=dict(iou_thr=[0.5])))

return config


def patch_input_preprocessing(cfg: ConfigDict, deploy_cfg: ConfigDict):
    """Update backend configuration with input preprocessing options.

    - If `"to_rgb"` in Normalize config is truthy, it adds `"--reverse_input_channels"` as a flag.
    - `--mean_values` / `--scale_values` are forwarded only when the Normalize
      transform actually defines `mean` / `std`.

    The function then sets default values for the backend configuration in `deploy_cfg`.
    Options already present in `deploy_cfg.backend_config.mo_options` take priority
    over the ones derived here.

    Args:
        cfg (mmcv.ConfigDict): Config object containing test pipeline and other configurations.
        deploy_cfg (mmcv.ConfigDict): DeployConfig object containing backend configuration.

    Returns:
        None: This function updates the input `deploy_cfg` object directly.
    """
    normalize_cfgs = get_configs_by_pairs(cfg.data.test.pipeline, dict(type="Normalize"))
    assert len(normalize_cfgs) == 1
    normalize_cfg: dict = normalize_cfgs[0]

    # Derive Model Optimizer options from the Normalize transform.
    # NOTE: OTX loads images in RGB, so `to_rgb=True` means the model was trained
    # on BGR inputs and the OpenVINO IR must reverse the input channel order.
    options: dict = {"flags": [], "args": {}}
    if normalize_cfg.get("to_rgb", False):
        options["flags"].append("--reverse_input_channels")
    # Only forward mean/std when present; MO argument values must be lists, not tuples.
    if normalize_cfg.get("mean", None) is not None:
        options["args"]["--mean_values"] = list(normalize_cfg["mean"])
    if normalize_cfg.get("std", None) is not None:
        options["args"]["--scale_values"] = list(normalize_cfg["std"])

    # Fill defaults for the backend configuration.
    mo_options = deploy_cfg.backend_config.get("mo_options", None)
    mo_options = ConfigDict() if mo_options is None else mo_options
    mo_options.args = mo_options.get("args", ConfigDict())
    mo_options.flags = mo_options.get("flags", [])

    # Already-defined options have higher priority: update the derived options
    # with the existing ones, then write the merged result back.
    options["args"].update(mo_options.args)
    mo_options.args = ConfigDict(options["args"])
    # Merge flags without duplicates.
    mo_options.flags = list(set(mo_options.flags + options["flags"]))

    deploy_cfg.backend_config.mo_options = mo_options


def patch_input_shape(cfg: ConfigDict, deploy_cfg: ConfigDict):
    """Set a static input shape on the backend model configuration.

    Reads the single Resize transform from `cfg.data.test.pipeline`, checks that
    exactly one Resize exists, then writes its size into
    `deploy_cfg.backend_config.model_inputs` as

    ```
    {
        "opt_shapes": {
            "input": [1, 3, *size]
        }
    }
    ```

    Args:
        cfg (Config): Config object containing test pipeline and other configurations.
        deploy_cfg (DeployConfig): DeployConfig object containing backend configuration.

    Returns:
        None: This function updates the input `deploy_cfg` object directly.
    """
    resize_cfgs = get_configs_by_pairs(cfg.data.test.pipeline, dict(type="Resize"))
    assert len(resize_cfgs) == 1
    height_width = resize_cfgs[0].size
    # A bare int means a square input.
    if isinstance(height_width, int):
        height_width = (height_width, height_width)
    assert all(isinstance(side, int) and side > 0 for side in height_width)
    # A static shape is used by default to prevent unexpected errors
    # when converting to OpenVINO IR.
    static_shape = [1, 3, *height_width]
    deploy_cfg.backend_config.model_inputs = [ConfigDict(opt_shapes=ConfigDict(input=static_shape))]


def patch_ir_scale_factor(deploy_cfg: ConfigDict, hyper_parameters: DetectionConfig):
    """Patch IR scale factor inplace from hyper parameters to deploy config.

    When tiling is enabled and the deploy config opts in via `scale_ir_input`,
    the optimal input height/width are multiplied by the configured
    `tile_ir_scale_factor` and `ir_config.input_shape` is updated to match.

    Args:
        deploy_cfg (ConfigDict): mmcv deploy config
        hyper_parameters (DetectionConfig): OTX detection hyper parameters
    """
    tiling_params = hyper_parameters.tiling_parameters
    # Guard clauses: nothing to do unless tiling and IR scaling are both enabled.
    if not tiling_params.enable_tiling:
        return
    if not deploy_cfg.get("scale_ir_input", False):
        return

    tile_ir_scale_factor = tiling_params.tile_ir_scale_factor
    logger.info(f"Apply OpenVINO IR scale factor: {tile_ir_scale_factor}")
    opt_shape = deploy_cfg.backend_config.model_inputs[0].opt_shapes.input
    opt_shape[2] = int(opt_shape[2] * tile_ir_scale_factor)  # height
    opt_shape[3] = int(opt_shape[3] * tile_ir_scale_factor)  # width
    deploy_cfg.ir_config.input_shape = (opt_shape[3], opt_shape[2])  # width, height
9 changes: 8 additions & 1 deletion otx/algorithms/detection/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ class OpenVINOTileClassifierWrapper(BaseInferencerWithConverter):
tile_size (int): tile size
overlap (float): overlap ratio between tiles
max_number (int): maximum number of objects per image
tile_ir_scale_factor (float, optional): scale factor for tile size
tile_classifier_model_file (Union[str, bytes, None], optional): tile classifier xml. Defaults to None.
tile_classifier_weight_file (Union[str, bytes, None], optional): tile classifier weight bin. Defaults to None.
device (str, optional): device to run inference on, such as CPU, GPU or MYRIAD. Defaults to "CPU".
Expand All @@ -274,6 +275,7 @@ def __init__(
tile_size: int = 400,
overlap: float = 0.5,
max_number: int = 100,
tile_ir_scale_factor: float = 1.0,
tile_classifier_model_file: Union[str, bytes, None] = None,
tile_classifier_weight_file: Union[str, bytes, None] = None,
device: str = "CPU",
Expand All @@ -293,7 +295,7 @@ def __init__(
classifier = Model(model_adapter=adapter, preload=True)

self.tiler = Tiler(
tile_size=tile_size,
tile_size=int(tile_size * tile_ir_scale_factor),
overlap=overlap,
max_number=max_number,
detector=inferencer.model,
Expand Down Expand Up @@ -372,6 +374,10 @@ def load_config(self) -> ADDict:
if self.model is not None and self.model.get_data("config.json"):
json_dict = json.loads(self.model.get_data("config.json"))
flatten_config_values(json_dict)
# NOTE: for backward compatibility
json_dict["tiling_parameters"]["tile_ir_scale_factor"] = json_dict["tiling_parameters"].get(
"tile_ir_scale_factor", 1.0
)
config = merge_a_into_b(json_dict, config)
except Exception as e: # pylint: disable=broad-except
logger.warning(f"Failed to load config.json: {e}")
Expand Down Expand Up @@ -418,6 +424,7 @@ def load_inferencer(
self.config.tiling_parameters.tile_size,
self.config.tiling_parameters.tile_overlap,
self.config.tiling_parameters.tile_max_number,
self.config.tiling_parameters.tile_ir_scale_factor,
tile_classifier_model_file,
tile_classifier_weight_file,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -553,5 +553,23 @@ tiling_parameters:
visible_in_ui: true
warning: null

tile_ir_scale_factor:
header: OpenVINO IR Scale Factor
description: The purpose of the scale parameter is to optimize the performance and efficiency of tiling in OpenVINO IR during inference. By controlling the increase in tile size and input size, the scale parameter allows for more efficient parallelization of the workload and improve the overall performance and efficiency of the inference process on OpenVINO.
affects_outcome_of: TRAINING
default_value: 2.0
min_value: 1.0
max_value: 4.0
type: FLOAT
editable: true
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
value: 2.0
visible_in_ui: true
warning: null

type: PARAMETER_GROUP
visible_in_ui: true
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

_base_ = ["../../base/deployments/base_instance_segmentation_dynamic.py"]

scale_ir_input = True

ir_config = dict(
output_names=["boxes", "labels", "masks"],
input_shape=(1024, 1024),
)

backend_config = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

_base_ = ["../../base/deployments/base_instance_segmentation_dynamic.py"]

scale_ir_input = True

ir_config = dict(
output_names=["boxes", "labels", "masks"],
input_shape=(1344, 800),
)

backend_config = dict(
Expand Down
2 changes: 1 addition & 1 deletion otx/api/utils/tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def tile(self, image: np.ndarray) -> List[List[int]]:
return coords

def filter_tiles_by_objectness(
self, image: np.ndarray, tile_coords: List[List[int]], confidence_threshold: float = 0.45
self, image: np.ndarray, tile_coords: List[List[int]], confidence_threshold: float = 0.35
):
"""Filter tiles by objectness score by running tile classifier.

Expand Down
2 changes: 1 addition & 1 deletion tests/regression/detection/test_tiling_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def test_otx_deploy_eval_deployment(self, template, tmp_dir_path):
assert test_result["passed"] is True, test_result["log"]

@e2e_pytest_component
@pytest.mark.parametrize("template", templates, ids=templates_ids)
@pytest.mark.skip(reason="CVS-98026")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_nncf_optimize_eval(self, template, tmp_dir_path):
self.performance[template.name] = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def test_otx_deploy_eval_deployment(self, template, tmp_dir_path):
assert test_result["passed"] is True, test_result["log"]

@e2e_pytest_component
@pytest.mark.parametrize("template", templates, ids=templates_ids)
@pytest.mark.skip(reason="CVS-98026")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_nncf_optimize_eval(self, template, tmp_dir_path):
self.performance[template.name] = {}

Expand Down
Loading