Skip to content

Commit

Permalink
Fix label list order for h-label classification (#2440)
Browse files Browse the repository at this point in the history
* Fix label list for h-label cls
* Fix unit tests
  • Loading branch information
GalyaZalesskaya authored and yunchu committed Aug 30, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 6c93b24 commit 441ba1b
Showing 5 changed files with 50 additions and 6 deletions.
19 changes: 15 additions & 4 deletions src/otx/algorithms/classification/adapters/openvino/task.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@
from otx.algorithms.classification.configs import ClassificationConfig
from otx.algorithms.classification.utils import (
get_cls_deploy_config,
get_hierarchical_label_list,
)
from otx.algorithms.common.utils import OTXOpenVinoDataLoader
from otx.algorithms.common.utils.ir import check_if_quantized
@@ -216,14 +217,18 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu
if saliency_map.ndim > 1 and repr_vector.ndim > 0:
feature_vec_media = TensorEntity(name="representation_vector", numpy=repr_vector.reshape(-1))
dataset_item.append_metadata_item(feature_vec_media, model=self.model)
if saliency_map.ndim == 4 and saliency_map.shape[0] == 1:
saliency_map = saliency_map.squeeze()
label_list = self.task_environment.get_labels()
# Fix the order for hierarchical labels to adjust classes with model outputs
if self.inferencer.model.hierarchical:
label_list = get_hierarchical_label_list(
self.inferencer.model.hierarchical_info["cls_heads_info"], label_list
)

add_saliency_maps_to_dataset_item(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self.model,
labels=self.task_environment.get_labels(),
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
@@ -272,6 +277,12 @@ def explain(
explain_predicted_classes = explain_parameters.explain_predicted_classes

dataset_size = len(dataset)
label_list = self.task_environment.get_labels()
# Fix the order for hierarchical labels to adjust classes with model outputs
if self.inferencer.model.hierarchical:
label_list = get_hierarchical_label_list(
self.inferencer.model.hierarchical_info["cls_heads_info"], label_list
)
for i, dataset_item in enumerate(dataset, 1):
cls_result, predicted_scene = self.inferencer.predict(dataset_item.numpy)

@@ -292,7 +303,7 @@ def explain(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self.model,
labels=self.task_environment.get_labels(),
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
13 changes: 11 additions & 2 deletions src/otx/algorithms/classification/task.py
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
get_cls_deploy_config,
get_cls_inferencer_configuration,
get_cls_model_api_configuration,
get_hierarchical_label_list,
)
from otx.algorithms.classification.utils import (
get_multihead_class_info as get_hierarchical_info,
@@ -350,6 +351,10 @@ def _add_predictions_to_dataset(

dataset_size = len(dataset)
pos_thr = 0.5
label_list = self._labels
# Fix the order for hierarchical labels to adjust classes with model outputs
if self._hierarchical:
label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
for i, (dataset_item, prediction_items) in enumerate(zip(dataset, prediction_results)):
prediction_item, feature_vector, saliency_map = prediction_items
if any(np.isnan(prediction_item)):
@@ -378,7 +383,7 @@ def _add_predictions_to_dataset(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self._task_environment.model,
labels=self._labels,
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
@@ -440,13 +445,17 @@ def _add_explanations_to_dataset(
):
"""Loop over dataset again and assign saliency maps."""
dataset_size = len(dataset)
label_list = self._labels
# Fix the order for hierarchical labels to adjust classes with model outputs
if self._hierarchical:
label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
for i, (dataset_item, prediction_item, saliency_map) in enumerate(zip(dataset, predictions, saliency_maps)):
item_labels = self._get_item_labels(prediction_item, pos_thr=0.5)
add_saliency_maps_to_dataset_item(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self._task_environment.model,
labels=self._labels,
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
2 changes: 2 additions & 0 deletions src/otx/algorithms/classification/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -8,10 +8,12 @@
get_cls_deploy_config,
get_cls_inferencer_configuration,
get_cls_model_api_configuration,
get_hierarchical_label_list,
get_multihead_class_info,
)

__all__ = [
"get_hierarchical_label_list",
"get_multihead_class_info",
"get_cls_inferencer_configuration",
"get_cls_deploy_config",
21 changes: 21 additions & 0 deletions src/otx/algorithms/classification/utils/cls_utils.py
Original file line number Diff line number Diff line change
@@ -117,3 +117,24 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c

mapi_config[("model_info", "hierarchical_config")] = json.dumps(hierarchical_config)
return mapi_config


def get_hierarchical_label_list(hierarchical_info, labels):
"""Return hierarchical labels list which is adjusted to model outputs classes."""
hierarchical_labels = []
for head_idx in range(hierarchical_info["num_multiclass_heads"]):
logits_begin, logits_end = hierarchical_info["head_idx_to_logits_range"][str(head_idx)]
for logit in range(0, logits_end - logits_begin):
label_str = hierarchical_info["all_groups"][head_idx][logit]
label_idx = hierarchical_info["label_to_idx"][label_str]
hierarchical_labels.append(labels[label_idx])

if hierarchical_info["num_multilabel_classes"]:
logits_begin = hierarchical_info["num_single_label_classes"]
logits_end = len(labels)
for logit_idx, logit in enumerate(range(0, logits_end - logits_begin)):
label_str_idx = hierarchical_info["num_multiclass_heads"] + logit_idx
label_str = hierarchical_info["all_groups"][label_str_idx][0]
label_idx = hierarchical_info["label_to_idx"][label_str]
hierarchical_labels.append(labels[label_idx])
return hierarchical_labels
Original file line number Diff line number Diff line change
@@ -143,6 +143,7 @@ def test_explain(self, mocker):
self.fake_ann_scene,
),
)
self.cls_ov_task.inferencer.model.hierarchical = False
updpated_dataset = self.cls_ov_task.explain(self.dataset)

assert updpated_dataset is not None

0 comments on commit 441ba1b

Please sign in to comment.