Skip to content

Commit

Permalink
Merge pull request #980 from openvinotoolkit/leo/modeltemplate-isglobal
Browse files Browse the repository at this point in the history
[OTE_SDK] expand ModelTemplate.is_global
  • Loading branch information
druzhkov-paul authored Mar 18, 2022
2 parents 4fc20a2 + 5060c88 commit abafaa3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
5 changes: 4 additions & 1 deletion ote_sdk/ote_sdk/entities/model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,10 @@ def is_task_global(self) -> bool:
"""
Returns ``True`` if the task is global task i.e. if task produces global labels
"""
return self.task_type in [TaskType.CLASSIFICATION]
return self.task_type in (
TaskType.CLASSIFICATION,
TaskType.ANOMALY_CLASSIFICATION,
)

def supports_auto_hpo(self) -> bool:
"""
Expand Down
33 changes: 21 additions & 12 deletions ote_sdk/ote_sdk/tests/entities/test_model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,27 +965,36 @@ def test_model_template_is_task_global(self):
Test passes if is_task_global method of ModelTemplate object returns expected bool values related to
task_type attribute
<b>Steps</b>
1. Check is_task_global method returns True if task_type equal to CLASSIFICATION
2. Check is_task_global method returns False if task_type not equal to CLASSIFICATION
1. Check is_task_global method returns True if task_type equal to CLASSIFICATION or ANOMALY_CLASSIFICATION
2. Check is_task_global method returns False if task_type not equal to CLASSIFICATION or ANOMALY_CLASSIFICATION
"""
# Check is_task_global method returns True
default_parameters = self.default_model_parameters()
task_global_parameters = dict(default_parameters)
task_global_parameters["task_type"] = TaskType.CLASSIFICATION
task_global_model_template = ModelTemplate(**task_global_parameters)
assert task_global_model_template.is_task_global()
# Check is_task_global method returns False
# Check is_task_global method returns True for CLASSIFICATION and ANOMALY_CLASSIFICATION
for global_task_type in (
TaskType.CLASSIFICATION,
TaskType.ANOMALY_CLASSIFICATION,
):
default_parameters = self.default_model_parameters()
task_global_parameters = dict(default_parameters)
task_global_parameters["task_type"] = global_task_type
task_global_model_template = ModelTemplate(**task_global_parameters)
assert (
task_global_model_template.is_task_global()
), f"Expected True value returned by is_task_global for {global_task_type}"
# Check is_task_global method returns False for the other tasks
non_global_task_parameters = dict(default_parameters)
non_global_tasks_list = []
for task_type in TaskType:
if task_type != TaskType.CLASSIFICATION:
if task_type not in (
TaskType.CLASSIFICATION,
TaskType.ANOMALY_CLASSIFICATION,
):
non_global_tasks_list.append(task_type)
for non_global_task in non_global_tasks_list:
non_global_task_parameters["task_type"] = non_global_task
non_global_task_template = ModelTemplate(**non_global_task_parameters)
assert not non_global_task_template.is_task_global(), (
f"Expected False value returned by is_task_global method for {non_global_task}, only CLASSIFICATION "
f"task type is global"
f"Expected False value returned by is_task_global method for {non_global_task}, "
f"only CLASSIFICATION and ANOMALY_CLASSIFICATION task types are global"
)


Expand Down

0 comments on commit abafaa3

Please sign in to comment.