From 9d2d34d6989ddf79fba348dab8d5a233cc1bf30b Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 18 Oct 2023 11:51:57 +0200 Subject: [PATCH 1/3] Update MAPI rt infor for detection --- src/otx/algorithms/detection/utils/utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/otx/algorithms/detection/utils/utils.py b/src/otx/algorithms/detection/utils/utils.py index 500ec1ad7cf..648f2195835 100644 --- a/src/otx/algorithms/detection/utils/utils.py +++ b/src/otx/algorithms/detection/utils/utils.py @@ -110,16 +110,22 @@ def get_det_model_api_configuration( """Get ModelAPI config.""" omz_config = {} all_labels = "" + all_label_ids = "" if task_type == TaskType.DETECTION: omz_config[("model_info", "model_type")] = "ssd" + omz_config[("model_info", "task_type")] = "detection" if task_type == TaskType.INSTANCE_SEGMENTATION: omz_config[("model_info", "model_type")] = "MaskRCNN" + omz_config[("model_info", "task_type")] = "instance_segmentation" all_labels = "otx_empty_lbl " + all_label_ids = "None " if tiling_parameters.enable_tiling: omz_config[("model_info", "resize_type")] = "fit_to_window_letterbox" if task_type == TaskType.ROTATED_DETECTION: - omz_config[("model_info", "model_type")] = "rotated_detection" + omz_config[("model_info", "model_type")] = "MaskRCNN" + omz_config[("model_info", "task_type")] = "rotated_detection" all_labels = "otx_empty_lbl " + all_label_ids = "None " if tiling_parameters.enable_tiling: omz_config[("model_info", "resize_type")] = "fit_to_window_letterbox" @@ -137,9 +143,10 @@ def get_det_model_api_configuration( for lbl in label_schema.get_labels(include_empty=False): all_labels += lbl.name.replace(" ", "_") + " " - all_labels = all_labels.strip() + all_label_ids += f"{lbl.id} " - omz_config[("model_info", "labels")] = all_labels + omz_config[("model_info", "labels")] = all_labels.strip() + omz_config[("model_info", "label_ids")] = all_label_ids.strip() return omz_config From f29a7da609db3f1f79511f0106700d19c7c5f550 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 18 Oct 2023 13:36:03 +0200 Subject: [PATCH 2/3] Upadte export info for cls, det and seg --- src/otx/algorithms/classification/utils/cls_utils.py | 8 ++++++-- src/otx/algorithms/detection/utils/utils.py | 2 +- src/otx/algorithms/segmentation/utils/metadata.py | 6 ++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/otx/algorithms/classification/utils/cls_utils.py b/src/otx/algorithms/classification/utils/cls_utils.py index b1506ccc8e7..968586a7d5a 100644 --- a/src/otx/algorithms/classification/utils/cls_utils.py +++ b/src/otx/algorithms/classification/utils/cls_utils.py @@ -98,16 +98,20 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c """Get ModelAPI config.""" mapi_config = {} mapi_config[("model_info", "model_type")] = "Classification" + mapi_config[("model_info", "task_type")] = "classification" mapi_config[("model_info", "confidence_threshold")] = str(inference_config["confidence_threshold"]) mapi_config[("model_info", "multilabel")] = str(inference_config["multilabel"]) mapi_config[("model_info", "hierarchical")] = str(inference_config["hierarchical"]) mapi_config[("model_info", "output_raw_scores")] = str(True) all_labels = "" + all_label_ids = "" for lbl in label_schema.get_labels(include_empty=False): all_labels += lbl.name.replace(" ", "_") + " " - all_labels = all_labels.strip() - mapi_config[("model_info", "labels")] = all_labels + all_label_ids += f"{lbl.id_} " + + mapi_config[("model_info", "labels")] = all_labels.strip() + mapi_config[("model_info", "label_ids")] = all_label_ids.strip() hierarchical_config = {} hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema) diff --git a/src/otx/algorithms/detection/utils/utils.py b/src/otx/algorithms/detection/utils/utils.py index 648f2195835..90c9a7e4476 100644 --- a/src/otx/algorithms/detection/utils/utils.py +++ b/src/otx/algorithms/detection/utils/utils.py @@ -143,7 +143,7 @@ def get_det_model_api_configuration( for lbl in label_schema.get_labels(include_empty=False): all_labels += lbl.name.replace(" ", "_") + " " - all_label_ids += f"{lbl.id} " + all_label_ids += f"{lbl.id_} " omz_config[("model_info", "labels")] = all_labels.strip() omz_config[("model_info", "label_ids")] = all_label_ids.strip() diff --git a/src/otx/algorithms/segmentation/utils/metadata.py b/src/otx/algorithms/segmentation/utils/metadata.py index 9ecc4e320c2..38ff4e31810 100644 --- a/src/otx/algorithms/segmentation/utils/metadata.py +++ b/src/otx/algorithms/segmentation/utils/metadata.py @@ -12,13 +12,15 @@ def get_seg_model_api_configuration(label_schema: LabelSchemaEntity, hyperparams: ConfigDict): """Get ModelAPI config.""" all_labels = "" + all_label_ids = "" for lbl in label_schema.get_labels(include_empty=False): all_labels += lbl.name.replace(" ", "_") + " " - all_labels = all_labels.strip() + all_label_ids += f"{lbl.id_} " return { ("model_info", "model_type"): "Segmentation", ("model_info", "soft_threshold"): str(hyperparams.postprocessing.soft_threshold), ("model_info", "blur_strength"): str(hyperparams.postprocessing.blur_strength), - ("model_info", "labels"): all_labels, + ("model_info", "labels"): all_labels.strip(), + ("model_info", "label_ids"): all_label_ids.strip(), } From 70fae68c9edc4d55dbc6e8d0311f7419887c64b8 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 18 Oct 2023 13:55:04 +0200 Subject: [PATCH 3/3] Update unit tests --- src/otx/algorithms/segmentation/utils/metadata.py | 1 + tests/unit/algorithms/classification/utils/test_utils.py | 4 ++++ tests/unit/algorithms/detection/utils/test_detection_utils.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/src/otx/algorithms/segmentation/utils/metadata.py b/src/otx/algorithms/segmentation/utils/metadata.py index 38ff4e31810..4a4012a024e 100644 --- a/src/otx/algorithms/segmentation/utils/metadata.py +++ b/src/otx/algorithms/segmentation/utils/metadata.py @@ -23,4 +23,5 @@ def get_seg_model_api_configuration(label_schema: LabelSchemaEntity, hyperparams ("model_info", "blur_strength"): str(hyperparams.postprocessing.blur_strength), ("model_info", "labels"): all_labels.strip(), ("model_info", "label_ids"): all_label_ids.strip(), + ("model_info", "task_type"): "segmentation", } diff --git a/tests/unit/algorithms/classification/utils/test_utils.py b/tests/unit/algorithms/classification/utils/test_utils.py index 009005f3cea..95dbf4883db 100644 --- a/tests/unit/algorithms/classification/utils/test_utils.py +++ b/tests/unit/algorithms/classification/utils/test_utils.py @@ -93,3 +93,7 @@ def test_get_cls_model_api_configuration(default_hierarchical_data): assert len(model_api_cfg) > 0 assert model_api_cfg[("model_info", "confidence_threshold")] == str(config["confidence_threshold"]) assert ("model_info", "hierarchical_config") in model_api_cfg + assert ("model_info", "labels") in model_api_cfg + assert ("model_info", "label_ids") in model_api_cfg + assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split()) + assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split()) diff --git a/tests/unit/algorithms/detection/utils/test_detection_utils.py b/tests/unit/algorithms/detection/utils/test_detection_utils.py index 77c46a8c855..0a3a645e29e 100644 --- a/tests/unit/algorithms/detection/utils/test_detection_utils.py +++ b/tests/unit/algorithms/detection/utils/test_detection_utils.py @@ -34,3 +34,7 @@ def test_get_det_model_api_configuration(): tiling_parameters.tile_overlap / tiling_parameters.tile_ir_scale_factor ) assert model_api_cfg[("model_info", "max_pred_number")] == str(tiling_parameters.tile_max_number) + assert ("model_info", "labels") in model_api_cfg + assert ("model_info", "label_ids") in model_api_cfg + assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split()) + assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split())