Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix label list order for h-label classification #2440

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix label list for h-label cls
GalyaZalesskaya committed Aug 18, 2023
commit dd6130cbdd3133c6e5f2c6a3e4823ffa0e30f411
18 changes: 15 additions & 3 deletions src/otx/algorithms/classification/adapters/openvino/task.py
Original file line number Diff line number Diff line change
@@ -76,7 +76,7 @@
IOptimizationTask,
OptimizationType,
)
from otx.api.utils.dataset_utils import add_saliency_maps_to_dataset_item
from otx.api.utils.dataset_utils import add_saliency_maps_to_dataset_item, get_hierarchical_label_list

logger = logging.getLogger(__name__)

@@ -228,12 +228,18 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu
if saliency_map is not None and repr_vector is not None:
feature_vec_media = TensorEntity(name="representation_vector", numpy=repr_vector.reshape(-1))
dataset_item.append_metadata_item(feature_vec_media, model=self.model)
label_list = self.task_environment.get_labels()
# Fix the order for hierarchical labels to adjust classes with model outputs
if self.inferencer.model.hierarchical:
label_list = get_hierarchical_label_list(
self.inferencer.model.hierarchical_info["cls_heads_info"], label_list
)

add_saliency_maps_to_dataset_item(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self.model,
labels=self.task_environment.get_labels(),
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
@@ -284,6 +290,12 @@ def explain(
explain_predicted_classes = explain_parameters.explain_predicted_classes

dataset_size = len(dataset)
label_list = self.task_environment.get_labels()
# Fix the order for hierarchical labels to adjust classes with model outputs
if self.inferencer.model.hierarchical:
label_list = get_hierarchical_label_list(
self.inferencer.model.hierarchical_info["cls_heads_info"], label_list
)
for i, dataset_item in enumerate(dataset, 1):
predicted_scene, _, saliency_map, _, _ = self.inferencer.predict(dataset_item.numpy)
if saliency_map is None:
@@ -298,7 +310,7 @@ def explain(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self.model,
labels=self.task_environment.get_labels(),
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
14 changes: 11 additions & 3 deletions src/otx/algorithms/classification/task.py
Original file line number Diff line number Diff line change
@@ -75,7 +75,7 @@
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.evaluation.metrics_helper import MetricsHelper
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.api.utils.dataset_utils import add_saliency_maps_to_dataset_item
from otx.api.utils.dataset_utils import add_saliency_maps_to_dataset_item, get_hierarchical_label_list
from otx.api.utils.labels_utils import get_empty_label
from otx.cli.utils.multi_gpu import is_multigpu_child_process

@@ -345,6 +345,10 @@ def _add_predictions_to_dataset(

dataset_size = len(dataset)
pos_thr = 0.5
label_list = self._labels
# Fix the order for hierarchical labels to adjust classes with model outputs
if self._hierarchical:
label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
for i, (dataset_item, prediction_items) in enumerate(zip(dataset, prediction_results)):
prediction_item, feature_vector, saliency_map = prediction_items
if any(np.isnan(prediction_item)):
@@ -373,7 +377,7 @@ def _add_predictions_to_dataset(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self._task_environment.model,
labels=self._labels,
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
@@ -436,13 +440,17 @@ def _add_explanations_to_dataset(
):
"""Loop over dataset again and assign saliency maps."""
dataset_size = len(dataset)
label_list = self._labels
# Fix the order for hierarchical labels to adjust classes with model outputs
if self._hierarchical:
label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
for i, (dataset_item, prediction_item, saliency_map) in enumerate(zip(dataset, predictions, saliency_maps)):
item_labels = self._get_item_labels(prediction_item, pos_thr=0.5)
add_saliency_maps_to_dataset_item(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self._task_environment.model,
labels=self._labels,
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
12 changes: 12 additions & 0 deletions src/otx/api/utils/dataset_utils.py
Original file line number Diff line number Diff line change
@@ -272,3 +272,15 @@ def non_linear_normalization(saliency_map: np.ndarray) -> np.ndarray:
saliency_map = 255.0 / (max_soft_score + 1e-12) * saliency_map

return np.uint8(np.floor(saliency_map))


def get_hierarchical_label_list(hierarchical_info, labels):
sovrasov marked this conversation as resolved.
Show resolved Hide resolved
"""Return hierarchical labels list which is adjusted to model outputs classes."""
hierarchical_labels = []
for head_idx in range(hierarchical_info["num_multiclass_heads"]):
logits_begin, logits_end = hierarchical_info["head_idx_to_logits_range"][str(head_idx)]
for logit in range(0, logits_end - logits_begin):
label_str = hierarchical_info["all_groups"][head_idx][logit]
otx_label = next(x for x in labels if x.name == label_str)
sovrasov marked this conversation as resolved.
Show resolved Hide resolved
hierarchical_labels.append(otx_label)
sovrasov marked this conversation as resolved.
Show resolved Hide resolved
return hierarchical_labels