Skip to content

Commit

Permalink
Merge pull request #1120 from openvinotoolkit/da/use-is-anomalous
Browse files Browse the repository at this point in the history
[ANOMALY] Use is_anomalous attribute instead of string matching
  • Loading branch information
goodsong81 authored May 31, 2022
2 parents c240c4b + 16f2138 commit 2e18117
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,10 @@ def test_create_converter(self):
name="Normal", domain=Domain.ANOMALY_CLASSIFICATION, id=ID("1")
),
LabelEntity(
name="Anomalous", domain=Domain.ANOMALY_CLASSIFICATION, id=ID("2")
name="Anomalous",
domain=Domain.ANOMALY_CLASSIFICATION,
id=ID("2"),
is_anomalous=True,
),
]
label_group = LabelGroup(
Expand All @@ -310,7 +313,12 @@ def test_create_converter(self):
# "ANOMALY_DETECTION" is specified as "converter_type"
labels = [
LabelEntity(name="Normal", domain=Domain.ANOMALY_DETECTION, id=ID("1")),
LabelEntity(name="Anomalous", domain=Domain.ANOMALY_DETECTION, id=ID("2")),
LabelEntity(
name="Anomalous",
domain=Domain.ANOMALY_DETECTION,
id=ID("2"),
is_anomalous=True,
),
]
label_group = LabelGroup(name="Anomaly detection labels group", labels=labels)
label_schema = LabelSchemaEntity(label_groups=[label_group])
Expand All @@ -325,7 +333,10 @@ def test_create_converter(self):
labels = [
LabelEntity(name="Normal", domain=Domain.ANOMALY_SEGMENTATION, id=ID("1")),
LabelEntity(
name="Anomalous", domain=Domain.ANOMALY_SEGMENTATION, id=ID("2")
name="Anomalous",
domain=Domain.ANOMALY_SEGMENTATION,
id=ID("2"),
is_anomalous=True,
),
]
label_group = LabelGroup(name="Anomaly detection labels group", labels=labels)
Expand Down Expand Up @@ -947,8 +958,18 @@ def test_anomaly_classification_to_annotation_init(
non_empty_labels = [
LabelEntity(name="Normal", domain=Domain.CLASSIFICATION, id=ID("1")),
LabelEntity(name="Normal", domain=Domain.CLASSIFICATION, id=ID("2")),
LabelEntity(name="Anomalous", domain=Domain.CLASSIFICATION, id=ID("1")),
LabelEntity(name="Anomalous", domain=Domain.CLASSIFICATION, id=ID("2")),
LabelEntity(
name="Anomalous",
domain=Domain.CLASSIFICATION,
id=ID("1"),
is_anomalous=True,
),
LabelEntity(
name="Anomalous",
domain=Domain.CLASSIFICATION,
id=ID("2"),
is_anomalous=True,
),
]
label_group = LabelGroup(
name="Classification labels group", labels=non_empty_labels
Expand Down Expand Up @@ -1030,7 +1051,12 @@ def check_annotation(actual_annotation: Annotation, expected_labels: list):

non_empty_labels = [
LabelEntity(name="Normal", domain=Domain.CLASSIFICATION, id=ID("1")),
LabelEntity(name="Anomalous", domain=Domain.CLASSIFICATION, id=ID("2")),
LabelEntity(
name="Anomalous",
domain=Domain.CLASSIFICATION,
id=ID("2"),
is_anomalous=True,
),
]
label_group = LabelGroup(
name="Anomaly classification labels group", labels=non_empty_labels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,8 @@ class AnomalyClassificationToAnnotationConverter(IPredictionToAnnotationConverte

def __init__(self, label_schema: LabelSchemaEntity):
labels = label_schema.get_labels(include_empty=False)
self.normal_label = [label for label in labels if label.name == "Normal"][0]
self.anomalous_label = [label for label in labels if label.name == "Anomalous"][
0
]
self.normal_label = [label for label in labels if not label.is_anomalous][0]
self.anomalous_label = [label for label in labels if label.is_anomalous][0]

def convert_to_annotation(
self, predictions: np.ndarray, metadata: Dict[str, Any]
Expand Down Expand Up @@ -290,10 +288,8 @@ class AnomalySegmentationToAnnotationConverter(IPredictionToAnnotationConverter)

def __init__(self, label_schema: LabelSchemaEntity):
labels = label_schema.get_labels(include_empty=False)
self.normal_label = [label for label in labels if label.name == "Normal"][0]
self.anomalous_label = [label for label in labels if label.name == "Anomalous"][
0
]
self.normal_label = [label for label in labels if not label.is_anomalous][0]
self.anomalous_label = [label for label in labels if label.is_anomalous][0]
self.label_map = {0: self.normal_label, 1: self.anomalous_label}

def convert_to_annotation(
Expand Down Expand Up @@ -327,10 +323,8 @@ def __init__(self, label_schema: LabelSchemaEntity):
:param label_schema: Label Schema containing the label info of the task
"""
labels = label_schema.get_labels(include_empty=False)
self.normal_label = [label for label in labels if label.name == "Normal"][0]
self.anomalous_label = [label for label in labels if label.name == "Anomalous"][
0
]
self.normal_label = [label for label in labels if not label.is_anomalous][0]
self.anomalous_label = [label for label in labels if label.is_anomalous][0]
self.label_map = {0: self.normal_label, 1: self.anomalous_label}

def convert_to_annotation(
Expand Down

0 comments on commit 2e18117

Please sign in to comment.