diff --git a/ote_sdk/ote_sdk/entities/model_template.py b/ote_sdk/ote_sdk/entities/model_template.py index d548db57a21..8adfa2f2786 100644 --- a/ote_sdk/ote_sdk/entities/model_template.py +++ b/ote_sdk/ote_sdk/entities/model_template.py @@ -424,7 +424,10 @@ def is_task_global(self) -> bool: """ Returns ``True`` if the task is global task i.e. if task produces global labels """ - return self.task_type in [TaskType.CLASSIFICATION] + return self.task_type in ( + TaskType.CLASSIFICATION, + TaskType.ANOMALY_CLASSIFICATION, + ) def supports_auto_hpo(self) -> bool: """ diff --git a/ote_sdk/ote_sdk/tests/entities/test_model_template.py b/ote_sdk/ote_sdk/tests/entities/test_model_template.py index d01e6045869..a9a2f12f1f9 100644 --- a/ote_sdk/ote_sdk/tests/entities/test_model_template.py +++ b/ote_sdk/ote_sdk/tests/entities/test_model_template.py @@ -965,27 +965,36 @@ def test_model_template_is_task_global(self): Test passes if is_task_global method of ModelTemplate object returns expected bool values related to task_type attribute Steps - 1. Check is_task_global method returns True if task_type equal to CLASSIFICATION - 2. Check is_task_global method returns False if task_type not equal to CLASSIFICATION + 1. Check is_task_global method returns True if task_type equal to CLASSIFICATION or ANOMALY_CLASSIFICATION + 2. Check is_task_global method returns False if task_type not equal to CLASSIFICATION or ANOMALY_CLASSIFICATION """ - # Check is_task_global method returns True - default_parameters = self.default_model_parameters() - task_global_parameters = dict(default_parameters) - task_global_parameters["task_type"] = TaskType.CLASSIFICATION - task_global_model_template = ModelTemplate(**task_global_parameters) - assert task_global_model_template.is_task_global() - # Check is_task_global method returns False + # Check is_task_global method returns True for CLASSIFICATION and ANOMALY_CLASSIFICATION + for global_task_type in ( + TaskType.CLASSIFICATION, + TaskType.ANOMALY_CLASSIFICATION, + ): + default_parameters = self.default_model_parameters() + task_global_parameters = dict(default_parameters) + task_global_parameters["task_type"] = global_task_type + task_global_model_template = ModelTemplate(**task_global_parameters) + assert ( + task_global_model_template.is_task_global() + ), f"Expected True value returned by is_task_global for {global_task_type}" + # Check is_task_global method returns False for the other tasks non_global_task_parameters = dict(default_parameters) non_global_tasks_list = [] for task_type in TaskType: - if task_type != TaskType.CLASSIFICATION: + if task_type not in ( + TaskType.CLASSIFICATION, + TaskType.ANOMALY_CLASSIFICATION, + ): non_global_tasks_list.append(task_type) for non_global_task in non_global_tasks_list: non_global_task_parameters["task_type"] = non_global_task non_global_task_template = ModelTemplate(**non_global_task_parameters) assert not non_global_task_template.is_task_global(), ( - f"Expected False value returned by is_task_global method for {non_global_task}, only CLASSIFICATION " - f"task type is global" + f"Expected False value returned by is_task_global method for {non_global_task}, " + f"only CLASSIFICATION and ANOMALY_CLASSIFICATION task types are global" )