Skip to content

Commit

Permalink
Fix labels names in hierarchical config (#3879)
Browse files Browse the repository at this point in the history
* Fix hierarchical config

* Add exceptions handling

* Add exceptions checks to other tasks

* Fix black
  • Loading branch information
sovrasov authored Aug 23, 2024
1 parent 4bb9e1c commit 7bdf708
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/otx/algorithms/classification/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _async_callback(self, request: Any, callback_args: tuple) -> None:
result_handler(id, annotation, aux_data)

except Exception as e:
logger.exception(e)
self.callback_exceptions.append(e)

def predict(self, image: np.ndarray) -> Tuple[ClassificationResult, AnnotationSceneEntity]:
Expand Down Expand Up @@ -280,6 +281,9 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu

self.inferencer.await_all()

if self.inferencer.callback_exceptions:
raise RuntimeError("Inference failed, check the exceptions log.")

self._avg_time_per_image = total_time / len(dataset)
logger.info(f"Avg time per image: {self._avg_time_per_image} secs")
logger.info(f"Total time: {total_time} secs")
Expand Down
11 changes: 7 additions & 4 deletions src/otx/algorithms/classification/utils/cls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
from otx.api.entities.label import LabelEntity
from otx.api.entities.label_schema import LabelSchemaEntity
from otx.api.serialization.label_mapper import LabelSchemaMapper
from otx.api.utils.labels_utils import get_normalized_label_name


def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disable=too-many-locals
"""Get multihead info by label schema."""
all_groups = label_schema.get_groups(include_empty=False)
all_groups_str = []
for g in all_groups:
group_labels_str = [lbl.name for lbl in g.labels]
group_labels_str = [get_normalized_label_name(lbl) for lbl in g.labels]
all_groups_str.append(group_labels_str)

single_label_groups = [g for g in all_groups_str if len(g) == 1]
Expand Down Expand Up @@ -112,7 +113,7 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
all_labels = ""
all_label_ids = ""
for lbl in label_entities:
all_labels += lbl.name.replace(" ", "_") + " "
all_labels += get_normalized_label_name(lbl) + " "
all_label_ids += f"{lbl.id_} "

mapi_config[("model_info", "labels")] = all_labels.strip()
Expand All @@ -122,7 +123,9 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema)
hierarchical_config["label_tree_edges"] = []
for edge in label_schema.label_tree.edges: # (child, parent)
hierarchical_config["label_tree_edges"].append((edge[0].name, edge[1].name))
hierarchical_config["label_tree_edges"].append(
(get_normalized_label_name(edge[0]), get_normalized_label_name(edge[1]))
)

mapi_config[("model_info", "hierarchical_config")] = json.dumps(hierarchical_config)
return mapi_config
Expand All @@ -137,7 +140,7 @@ def get_hierarchical_label_list(hierarchical_cls_heads_info: Dict, labels: List)
hierarchical_labels = []
for label_str, _ in label_to_idx.items():
for label_entity in labels:
if label_entity.name == label_str:
if get_normalized_label_name(label_entity) == label_str:
hierarchical_labels.append(label_entity)
break
return hierarchical_labels
4 changes: 4 additions & 0 deletions src/otx/algorithms/detection/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def _async_callback(self, request: Any, callback_args: tuple) -> None:
result_handler(id, processed_prediciton, features)

except Exception as e:
logger.exception(e)
self.callback_exceptions.append(e)

def enqueue_prediction(self, image: np.ndarray, id: int, result_handler: Any) -> None:
Expand Down Expand Up @@ -557,6 +558,9 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu

self.inferencer.await_all()

if self.inferencer.callback_exceptions:
raise RuntimeError("Inference failed, check the exceptions log.")

self._avg_time_per_image = total_time / len(dataset)
logger.info(f"Avg time per image: {self._avg_time_per_image} secs")
logger.info(f"Total time: {total_time} secs")
Expand Down
4 changes: 4 additions & 0 deletions src/otx/algorithms/segmentation/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def _async_callback(self, request: Any, callback_args: tuple) -> None:
result_handler(id, annotation, processed_prediciton.feature_vector, processed_prediciton.saliency_map)

except Exception as e:
logger.exception(e)
self.callback_exceptions.append(e)


Expand Down Expand Up @@ -254,6 +255,9 @@ def add_prediction(

self.inferencer.await_all()

if self.inferencer.callback_exceptions:
raise RuntimeError("Inference failed, check the exceptions log.")

self._avg_time_per_image = total_time / len(dataset)
logger.info(f"Avg time per image: {self._avg_time_per_image} secs")
logger.info(f"Total time: {total_time} secs")
Expand Down
5 changes: 5 additions & 0 deletions src/otx/api/utils/labels_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ def get_empty_label(label_schema: LabelSchemaEntity) -> Optional[LabelEntity]:
if empty_candidates:
return empty_candidates[0]
return None


def get_normalized_label_name(label: LabelEntity) -> str:
"""Gets a nomalized label name"""
return label.name.replace(" ", "_")

0 comments on commit 7bdf708

Please sign in to comment.