diff --git a/external/mmdetection/detection_tasks/apis/detection/config_utils.py b/external/mmdetection/detection_tasks/apis/detection/config_utils.py
index 57eabc94bf2..551ba8d85a9 100644
--- a/external/mmdetection/detection_tasks/apis/detection/config_utils.py
+++ b/external/mmdetection/detection_tasks/apis/detection/config_utils.py
@@ -18,13 +18,18 @@
import os
import tempfile
from collections import defaultdict
-from typing import List, Optional
+from typing import List, Optional, Union
import torch
from mmcv import Config, ConfigDict
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.label import LabelEntity, Domain
from ote_sdk.usecases.reporting.time_monitor_callback import TimeMonitorCallback
+from ote_sdk.utils.argument_checks import (
+ DatasetParamTypeCheck,
+ DirectoryPathCheck,
+ check_input_parameters_type
+)
from detection_tasks.extension.datasets.data_utils import get_anchor_boxes, \
get_sizes_from_dataset_entity, format_list_to_str
@@ -43,14 +48,16 @@
logger = get_root_logger()
+@check_input_parameters_type()
def is_epoch_based_runner(runner_config: ConfigDict):
return 'Epoch' in runner_config.type
+@check_input_parameters_type({"work_dir": DirectoryPathCheck})
def patch_config(config: Config, work_dir: str, labels: List[LabelEntity], domain: Domain, random_seed: Optional[int] = None):
# Set runner if not defined.
if 'runner' not in config:
- config.runner = {'type': 'EpochBasedRunner'}
+ config.runner = ConfigDict({'type': 'EpochBasedRunner'})
# Check that there is no conflict in specification of number of training epochs.
# Move global definition of epochs inside runner config.
@@ -112,6 +119,7 @@ def patch_config(config: Config, work_dir: str, labels: List[LabelEntity], domai
config.seed = random_seed
+@check_input_parameters_type()
def set_hyperparams(config: Config, hyperparams: OTEDetectionConfig):
config.optimizer.lr = float(hyperparams.learning_parameters.learning_rate)
config.lr_config.warmup_iters = int(hyperparams.learning_parameters.learning_rate_warmup_iters)
@@ -126,7 +134,8 @@ def set_hyperparams(config: Config, hyperparams: OTEDetectionConfig):
config.runner.max_iters = total_iterations
-def patch_adaptive_repeat_dataset(config: Config, num_samples: int,
+@check_input_parameters_type()
+def patch_adaptive_repeat_dataset(config: Union[Config, ConfigDict], num_samples: int,
decay: float = -0.002, factor: float = 30):
""" Patch the repeat times and training epochs adatively
@@ -155,14 +164,17 @@ def patch_adaptive_repeat_dataset(config: Config, num_samples: int,
data_train.times = new_repeat
-def prepare_for_testing(config: Config, dataset: DatasetEntity) -> Config:
+@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
+def prepare_for_testing(config: Union[Config, ConfigDict], dataset: DatasetEntity) -> Config:
config = copy.deepcopy(config)
# FIXME. Should working directories be modified here?
config.data.test.ote_dataset = dataset
return config
-def prepare_for_training(config: Config, train_dataset: DatasetEntity, val_dataset: DatasetEntity,
+@check_input_parameters_type({"train_dataset": DatasetParamTypeCheck,
+ "val_dataset": DatasetParamTypeCheck})
+def prepare_for_training(config: Union[Config, ConfigDict], train_dataset: DatasetEntity, val_dataset: DatasetEntity,
time_monitor: TimeMonitorCallback, learning_curves: defaultdict) -> Config:
config = copy.deepcopy(config)
prepare_work_dir(config)
@@ -175,7 +187,8 @@ def prepare_for_training(config: Config, train_dataset: DatasetEntity, val_datas
return config
-def config_to_string(config: Config) -> str:
+@check_input_parameters_type()
+def config_to_string(config: Union[Config, ConfigDict]) -> str:
"""
Convert a full mmdetection config to a string.
@@ -194,6 +207,7 @@ def config_to_string(config: Config) -> str:
return Config(config_copy).pretty_text
+@check_input_parameters_type()
def config_from_string(config_string: str) -> Config:
"""
Generate an mmdetection config dict object from a string.
@@ -207,6 +221,7 @@ def config_from_string(config_string: str) -> Config:
return Config.fromfile(temp_file.name)
+@check_input_parameters_type()
def save_config_to_file(config: Config):
""" Dump the full config to a file. Filename is 'config.py', it is saved in the current work_dir. """
filepath = os.path.join(config.work_dir, 'config.py')
@@ -215,7 +230,8 @@ def save_config_to_file(config: Config):
f.write(config_string)
-def prepare_work_dir(config: Config) -> str:
+@check_input_parameters_type()
+def prepare_work_dir(config: Union[Config, ConfigDict]) -> str:
base_work_dir = config.work_dir
checkpoint_dirs = glob.glob(os.path.join(base_work_dir, "checkpoints_round_*"))
train_round_checkpoint_dir = os.path.join(base_work_dir, f"checkpoints_round_{len(checkpoint_dirs)}")
@@ -230,6 +246,7 @@ def prepare_work_dir(config: Config) -> str:
return train_round_checkpoint_dir
+@check_input_parameters_type()
def set_data_classes(config: Config, labels: List[LabelEntity]):
# Save labels in data configs.
for subset in ('train', 'val', 'test'):
@@ -256,7 +273,8 @@ def set_data_classes(config: Config, labels: List[LabelEntity]):
# self.config.model.CLASSES = label_names
-def patch_datasets(config: Config, domain):
+@check_input_parameters_type()
+def patch_datasets(config: Config, domain: Domain):
def patch_color_conversion(pipeline):
# Default data format for OTE is RGB, while mmdet uses BGR, so negate the color conversion flag.
@@ -289,7 +307,8 @@ def patch_color_conversion(pipeline):
patch_color_conversion(cfg.pipeline)
-def remove_from_config(config, key: str):
+@check_input_parameters_type()
+def remove_from_config(config: Union[Config, ConfigDict], key: str):
if key in config:
if isinstance(config, Config):
del config._cfg_dict[key]
@@ -298,6 +317,8 @@ def remove_from_config(config, key: str):
else:
raise ValueError(f'Unknown config type {type(config)}')
+
+@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector):
if not kmeans_import:
raise ImportError('Sklearn package is not installed. To enable anchor boxes clustering, please install '
@@ -308,7 +329,7 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
if transforms.type == 'MultiScaleFlipAug']
prev_generator = config.model.bbox_head.anchor_generator
group_as = [len(width) for width in prev_generator.widths]
- wh_stats = get_sizes_from_dataset_entity(dataset, target_wh)
+ wh_stats = get_sizes_from_dataset_entity(dataset, list(target_wh))
if len(wh_stats) < sum(group_as):
logger.warning(f'There are not enough objects to cluster: {len(wh_stats)} were detected, while it should be '
@@ -332,7 +353,8 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
return config, model
-def get_data_cfg(config: Config, subset: str = 'train') -> Config:
+@check_input_parameters_type()
+def get_data_cfg(config: Union[Config, ConfigDict], subset: str = 'train') -> Config:
data_cfg = config.data[subset]
while 'dataset' in data_cfg:
data_cfg = data_cfg.dataset
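# A minimal sketch of how a decorator like check_input_parameters_type() can
# enforce the annotations added above. This is an illustrative assumption, not
# the actual ote_sdk implementation; the real decorator also supports custom
# checks such as DirectoryPathCheck and DatasetParamTypeCheck, whose protocol
# is hypothesized here.
import inspect
import typing
from functools import wraps

def check_input_parameters_type(custom_checks: dict = None):
    custom_checks = custom_checks or {}

    def decorator(function):
        signature = inspect.signature(function)
        hints = typing.get_type_hints(function)

        @wraps(function)
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            for name, value in bound.arguments.items():
                if name in custom_checks:
                    # Hypothetical protocol: a check object validating the value.
                    custom_checks[name](value, name).check()
                    continue
                expected = hints.get(name)
                origin = typing.get_origin(expected) or expected
                # Only plain classes are checked; Union/Optional are skipped here.
                if isinstance(origin, type) and not isinstance(value, origin):
                    raise ValueError(
                        f"Unexpected type of '{name}' parameter: "
                        f"expected {expected}, got {type(value)}"
                    )
            return function(*args, **kwargs)

        return wrapper

    return decorator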
diff --git a/external/mmdetection/detection_tasks/apis/detection/inference_task.py b/external/mmdetection/detection_tasks/apis/detection/inference_task.py
index 073a92d4548..7988eb7395f 100644
--- a/external/mmdetection/detection_tasks/apis/detection/inference_task.py
+++ b/external/mmdetection/detection_tasks/apis/detection/inference_task.py
@@ -45,6 +45,10 @@
from ote_sdk.usecases.tasks.interfaces.inference_interface import IInferenceTask
from ote_sdk.usecases.tasks.interfaces.unload_interface import IUnload
from ote_sdk.serialization.label_mapper import label_schema_to_bytes
+from ote_sdk.utils.argument_checks import (
+ DatasetParamTypeCheck,
+ check_input_parameters_type,
+)
from mmdet.apis import export_model
from detection_tasks.apis.detection.config_utils import patch_config, prepare_for_testing, set_hyperparams
@@ -63,6 +67,7 @@ class OTEDetectionInferenceTask(IInferenceTask, IExportTask, IEvaluationTask, IU
_task_environment: TaskEnvironment
+ @check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
""""
Task for inference object detection models using OTEDetection.
@@ -239,6 +244,7 @@ def _add_predictions_to_dataset(self, prediction_results, dataset, confidence_th
dataset_item.append_metadata_item(active_score, model=self._task_environment.model)
+ @check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None) -> DatasetEntity:
""" Analyzes a dataset using the latest inference model. """
@@ -330,7 +336,7 @@ def dummy_dump_features_hook(mod, inp, out):
eval_predictions = zip(eval_predictions, feature_vectors)
return eval_predictions, metric
-
+ @check_input_parameters_type()
def evaluate(self,
output_result_set: ResultSetEntity,
evaluation_metric: Optional[str] = None):
@@ -375,6 +381,7 @@ def unload(self):
logger.warning(f"Done unloading. "
f"Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory")
+ @check_input_parameters_type()
def export(self,
export_type: ExportType,
output_model: ModelEntity):
diff --git a/external/mmdetection/detection_tasks/apis/detection/nncf_task.py b/external/mmdetection/detection_tasks/apis/detection/nncf_task.py
index fd71cbda6eb..e42b4f5118f 100644
--- a/external/mmdetection/detection_tasks/apis/detection/nncf_task.py
+++ b/external/mmdetection/detection_tasks/apis/detection/nncf_task.py
@@ -36,6 +36,10 @@
from ote_sdk.usecases.tasks.interfaces.export_interface import ExportType
from ote_sdk.usecases.tasks.interfaces.optimization_interface import IOptimizationTask
from ote_sdk.usecases.tasks.interfaces.optimization_interface import OptimizationType
+from ote_sdk.utils.argument_checks import (
+ DatasetParamTypeCheck,
+ check_input_parameters_type,
+)
from mmdet.apis import train_detector
from mmdet.apis.fake_input import get_fake_input
@@ -59,6 +63,7 @@
class OTEDetectionNNCFTask(OTEDetectionInferenceTask, IOptimizationTask):
+ @check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
""""
Task for compressing object detection models using NNCF.
@@ -177,12 +182,13 @@ def _create_compressed_model(self, dataset, config):
get_fake_input_func=get_fake_input,
is_accuracy_aware=is_acc_aware_training_set)
+ @check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def optimize(
self,
optimization_type: OptimizationType,
dataset: DatasetEntity,
output_model: ModelEntity,
- optimization_parameters: Optional[OptimizationParameters],
+ optimization_parameters: Optional[OptimizationParameters] = None,
):
if optimization_type is not OptimizationType.NNCF:
raise RuntimeError("NNCF is the only supported optimization")
@@ -247,6 +253,7 @@ def optimize(
self._is_training = False
+ @check_input_parameters_type()
def export(self, export_type: ExportType, output_model: ModelEntity):
if self._compression_ctrl is None:
super().export(export_type, output_model)
@@ -256,6 +263,7 @@ def export(self, export_type: ExportType, output_model: ModelEntity):
super().export(export_type, output_model)
self._model.enable_dynamic_graph_building()
+ @check_input_parameters_type()
def save_model(self, output_model: ModelEntity):
buffer = io.BytesIO()
hyperparams = self._task_environment.get_hyper_parameters(OTEDetectionConfig)
diff --git a/external/mmdetection/detection_tasks/apis/detection/openvino_task.py b/external/mmdetection/detection_tasks/apis/detection/openvino_task.py
index 9c84aa40fdf..5352f6afa63 100644
--- a/external/mmdetection/detection_tasks/apis/detection/openvino_task.py
+++ b/external/mmdetection/detection_tasks/apis/detection/openvino_task.py
@@ -48,6 +48,7 @@
from ote_sdk.usecases.exportable_code.inference import BaseInferencer
from ote_sdk.usecases.exportable_code.prediction_to_annotation_converter import (
DetectionBoxToAnnotationConverter,
+ IPredictionToAnnotationConverter,
MaskToAnnotationConverter,
RotatedRectToAnnotationConverter,
)
@@ -55,6 +56,11 @@
from ote_sdk.usecases.tasks.interfaces.evaluate_interface import IEvaluationTask
from ote_sdk.usecases.tasks.interfaces.inference_interface import IInferenceTask
from ote_sdk.usecases.tasks.interfaces.optimization_interface import IOptimizationTask, OptimizationType
+from ote_sdk.utils.argument_checks import (
+ DatasetParamTypeCheck,
+ check_input_parameters_type,
+)
+from shutil import copyfile, copytree
from typing import Any, Dict, List, Optional, Tuple, Union
from zipfile import ZipFile
@@ -66,24 +72,29 @@
class BaseInferencerWithConverter(BaseInferencer):
- def __init__(self, configuration, model, converter) -> None:
+ @check_input_parameters_type()
+ def __init__(self, configuration: dict, model: Model, converter: IPredictionToAnnotationConverter) -> None:
self.configuration = configuration
self.model = model
self.converter = converter
+ @check_input_parameters_type()
def pre_process(self, image: np.ndarray) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
return self.model.preprocess(image)
+ @check_input_parameters_type()
def post_process(self, prediction: Dict[str, np.ndarray], metadata: Dict[str, Any]) -> AnnotationSceneEntity:
detections = self.model.postprocess(prediction, metadata)
return self.converter.convert_to_annotation(detections, metadata)
+ @check_input_parameters_type()
def forward(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
return self.model.infer_sync(inputs)
class OpenVINODetectionInferencer(BaseInferencerWithConverter):
+ @check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
@@ -115,6 +126,7 @@ def __init__(
class OpenVINOMaskInferencer(BaseInferencerWithConverter):
+ @check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
@@ -149,6 +161,7 @@ def __init__(
class OpenVINORotatedRectInferencer(BaseInferencerWithConverter):
+ @check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
@@ -183,11 +196,13 @@ def __init__(
class OTEOpenVinoDataLoader(DataLoader):
+ @check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def __init__(self, dataset: DatasetEntity, inferencer: BaseInferencer):
self.dataset = dataset
self.inferencer = inferencer
- def __getitem__(self, index):
+ @check_input_parameters_type()
+ def __getitem__(self, index: int):
image = self.dataset[index].numpy
annotation = self.dataset[index].annotation_scene
inputs, metadata = self.inferencer.pre_process(image)
@@ -199,6 +214,7 @@ def __len__(self):
class OpenVINODetectionTask(IDeploymentTask, IInferenceTask, IEvaluationTask, IOptimizationTask):
+ @check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
logger.info('Loading OpenVINO OTEDetectionTask')
self.task_environment = task_environment
@@ -230,6 +246,7 @@ def load_inferencer(self) -> Union[OpenVINODetectionInferencer, OpenVINOMaskInfe
return OpenVINORotatedRectInferencer(*args)
raise RuntimeError(f"Unknown OpenVINO Inferencer TaskType: {self.task_type}")
+ @check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None) -> DatasetEntity:
logger.info('Start OpenVINO inference')
update_progress_callback = default_progress_callback
@@ -243,6 +260,7 @@ def infer(self, dataset: DatasetEntity, inference_parameters: Optional[Inference
logger.info('OpenVINO inference completed')
return dataset
+ @check_input_parameters_type()
def evaluate(self,
output_result_set: ResultSetEntity,
evaluation_metric: Optional[str] = None):
@@ -252,6 +270,7 @@ def evaluate(self,
output_result_set.performance = MetricsHelper.compute_f_measure(output_result_set).get_performance()
logger.info('OpenVINO metric evaluation completed')
+ @check_input_parameters_type()
def deploy(self,
output_model: ModelEntity) -> None:
logger.info('Deploying the model')
@@ -279,11 +298,12 @@ def deploy(self,
output_model.exportable_code = zip_buffer.getvalue()
logger.info('Deploying completed')
+ @check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def optimize(self,
optimization_type: OptimizationType,
dataset: DatasetEntity,
output_model: ModelEntity,
- optimization_parameters: Optional[OptimizationParameters]):
+ optimization_parameters: Optional[OptimizationParameters] = None):
logger.info('Start POT optimization')
if optimization_type is not OptimizationType.POT:
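# Note: with optimization_parameters now defaulting to None in both the NNCF
# and POT optimize() signatures above, optimization can be invoked without an
# explicit parameters object; an illustrative call, assuming a prepared task,
# dataset and output_model:
#
#     task.optimize(OptimizationType.POT, dataset, output_model)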
diff --git a/external/mmdetection/detection_tasks/apis/detection/ote_utils.py b/external/mmdetection/detection_tasks/apis/detection/ote_utils.py
index 991efaaa52b..1fa456e1da6 100644
--- a/external/mmdetection/detection_tasks/apis/detection/ote_utils.py
+++ b/external/mmdetection/detection_tasks/apis/detection/ote_utils.py
@@ -16,7 +16,7 @@
import colorsys
import importlib
import random
-from typing import Callable, Union
+from typing import Callable, Optional, Sequence, Union
import numpy as np
import yaml
@@ -26,10 +26,15 @@
from ote_sdk.entities.label_schema import LabelGroup, LabelGroupType, LabelSchemaEntity
from ote_sdk.entities.train_parameters import UpdateProgressCallback
from ote_sdk.usecases.reporting.time_monitor_callback import TimeMonitorCallback
+from ote_sdk.utils.argument_checks import (
+ YamlFilePathCheck,
+ check_input_parameters_type,
+)
class ColorPalette:
- def __init__(self, n, rng=None):
+ @check_input_parameters_type()
+ def __init__(self, n: int, rng: Optional[random.Random] = None):
assert n > 0
if rng is None:
@@ -40,36 +45,38 @@ def __init__(self, n, rng=None):
for _ in range(1, n):
colors_candidates = [(rng.random(), rng.uniform(0.8, 1.0), rng.uniform(0.5, 1.0))
for _ in range(candidates_num)]
- min_distances = [self.min_distance(hsv_colors, c) for c in colors_candidates]
+ min_distances = [self._min_distance(hsv_colors, c) for c in colors_candidates]
arg_max = np.argmax(min_distances)
hsv_colors.append(colors_candidates[arg_max])
- self.palette = [Color(*self.hsv2rgb(*hsv)) for hsv in hsv_colors]
+ self.palette = [Color(*self._hsv2rgb(*hsv)) for hsv in hsv_colors]
@staticmethod
- def dist(c1, c2):
+ def _dist(c1, c2):
dh = min(abs(c1[0] - c2[0]), 1 - abs(c1[0] - c2[0])) * 2
ds = abs(c1[1] - c2[1])
dv = abs(c1[2] - c2[2])
return dh * dh + ds * ds + dv * dv
@classmethod
- def min_distance(cls, colors_set, color_candidate):
- distances = [cls.dist(o, color_candidate) for o in colors_set]
+ def _min_distance(cls, colors_set, color_candidate):
+ distances = [cls._dist(o, color_candidate) for o in colors_set]
return np.min(distances)
@staticmethod
- def hsv2rgb(h, s, v):
+ def _hsv2rgb(h, s, v):
return tuple(round(c * 255) for c in colorsys.hsv_to_rgb(h, s, v))
- def __getitem__(self, n):
+ @check_input_parameters_type()
+ def __getitem__(self, n: int):
return self.palette[n % len(self.palette)]
def __len__(self):
return len(self.palette)
-def generate_label_schema(label_names, label_domain=Domain.DETECTION):
+@check_input_parameters_type()
+def generate_label_schema(label_names: Sequence[str], label_domain: Domain = Domain.DETECTION):
colors = ColorPalette(len(label_names)) if len(label_names) > 0 else []
not_empty_labels = [LabelEntity(name=name, color=colors[i], domain=label_domain, id=ID(f"{i:08}")) for i, name in
enumerate(label_names)]
@@ -84,13 +91,15 @@ def generate_label_schema(label_names, label_domain=Domain.DETECTION):
return label_schema
+@check_input_parameters_type({"path": YamlFilePathCheck})
def load_template(path):
with open(path) as f:
template = yaml.safe_load(f)
return template
-def get_task_class(path):
+@check_input_parameters_type()
+def get_task_class(path: str):
module_name, class_name = path.rsplit('.', 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
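# Usage sketch for the utilities above (illustrative values). ColorPalette now
# type-checks its constructor and __getitem__ arguments, and
# generate_label_schema accepts any Sequence[str] of label names:
from detection_tasks.apis.detection.ote_utils import ColorPalette, generate_label_schema
from ote_sdk.entities.label import Domain

palette = ColorPalette(3)
first_color = palette[0]  # indexing cycles modulo len(palette)
schema = generate_label_schema(["car", "person"], label_domain=Domain.DETECTION)
# palette["0"] would now raise ValueError from the input-parameter checks.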
diff --git a/external/mmdetection/detection_tasks/apis/detection/train_task.py b/external/mmdetection/detection_tasks/apis/detection/train_task.py
index 7ba4403100d..a7a3b359453 100644
--- a/external/mmdetection/detection_tasks/apis/detection/train_task.py
+++ b/external/mmdetection/detection_tasks/apis/detection/train_task.py
@@ -33,6 +33,10 @@
from ote_sdk.serialization.label_mapper import label_schema_to_bytes
from ote_sdk.usecases.evaluation.metrics_helper import MetricsHelper
from ote_sdk.usecases.tasks.interfaces.training_interface import ITrainingTask
+from ote_sdk.utils.argument_checks import (
+ DatasetParamTypeCheck,
+ check_input_parameters_type,
+)
from mmdet.apis import train_detector
from detection_tasks.apis.detection.config_utils import cluster_anchors, prepare_for_training, set_hyperparams
@@ -81,6 +85,7 @@ def _generate_training_metrics(self, learning_curves, map) -> Optional[List[Metr
return output
+ @check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def train(self, dataset: DatasetEntity, output_model: ModelEntity, train_parameters: Optional[TrainParameters] = None):
""" Trains a model on a dataset """
@@ -191,6 +196,7 @@ def train(self, dataset: DatasetEntity, output_model: ModelEntity, train_paramet
logger.info('Training the model [done]')
+ @check_input_parameters_type()
def save_model(self, output_model: ModelEntity):
buffer = io.BytesIO()
hyperparams_str = ids_to_strings(cfg_helper.convert(self._hyperparams, dict, enum_to_str=True))
diff --git a/external/mmdetection/detection_tasks/extension/datasets/data_utils.py b/external/mmdetection/detection_tasks/extension/datasets/data_utils.py
index ff38a3243af..be1ee03b8e9 100644
--- a/external/mmdetection/detection_tasks/extension/datasets/data_utils.py
+++ b/external/mmdetection/detection_tasks/extension/datasets/data_utils.py
@@ -3,7 +3,7 @@
#
import json
import os.path as osp
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Sequence
import numpy as np
from ote_sdk.entities.annotation import Annotation, AnnotationSceneEntity, AnnotationSceneKind
@@ -16,11 +16,19 @@
from ote_sdk.entities.shapes.polygon import Polygon, Point
from ote_sdk.entities.shapes.rectangle import Rectangle
from ote_sdk.entities.subset import Subset
+from ote_sdk.utils.argument_checks import (
+ DatasetParamTypeCheck,
+ DirectoryPathCheck,
+ OptionalDirectoryPathCheck,
+ JsonFilePathCheck,
+ check_input_parameters_type,
+)
from ote_sdk.utils.shape_factory import ShapeFactory
from pycocotools.coco import COCO
from mmdet.core import BitmapMasks, PolygonMasks
+@check_input_parameters_type({"path": JsonFilePathCheck})
def get_classes_from_annotation(path):
with open(path) as read_file:
content = json.load(read_file)
@@ -31,7 +39,8 @@ def get_classes_from_annotation(path):
class LoadAnnotations:
- def __init__(self, with_bbox=True, with_label=True, with_mask=False):
+ @check_input_parameters_type()
+ def __init__(self, with_bbox: bool = True, with_label: bool = True, with_mask: bool = False):
self.with_bbox = with_bbox
self.with_label = with_label
self.with_mask = with_mask
@@ -57,7 +66,8 @@ def _load_masks(self, results):
results['mask_fields'].append('gt_masks')
return results
- def __call__(self, results):
+ @check_input_parameters_type()
+ def __call__(self, results: Dict[str, Any]):
if self.with_bbox:
results = self._load_bboxes(results)
if results is None:
@@ -77,16 +87,18 @@ def __repr__(self):
class CocoDataset:
+ @check_input_parameters_type({"ann_file": JsonFilePathCheck,
+ "data_root": OptionalDirectoryPathCheck})
def __init__(
self,
- ann_file,
- classes=None,
- data_root=None,
- img_prefix="",
- test_mode=False,
- filter_empty_gt=True,
- min_size=None,
- with_mask=False,
+ ann_file: str,
+ classes: Optional[Sequence[str]] = None,
+ data_root: Optional[str] = None,
+ img_prefix: str = "",
+ test_mode: bool = False,
+ filter_empty_gt: bool = True,
+ min_size: Optional[int] = None,
+ with_mask: bool = False,
):
self.ann_file = ann_file
self.data_root = data_root
@@ -112,7 +124,8 @@ def __init__(
def __len__(self):
return len(self.data_infos)
- def pre_pipeline(self, results):
+ @check_input_parameters_type()
+ def pre_pipeline(self, results: Dict[str, Any]):
results["img_prefix"] = self.img_prefix
results["bbox_fields"] = []
results["mask_fields"] = []
@@ -122,21 +135,24 @@ def _rand_another(self, idx):
pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool)
- def __getitem__(self, idx):
+ @check_input_parameters_type()
+ def __getitem__(self, idx: int):
return self.prepare_img(idx)
def __iter__(self):
for i in range(len(self)):
yield self[i]
- def prepare_img(self, idx):
+ @check_input_parameters_type()
+ def prepare_img(self, idx: int):
img_info = self.data_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
return LoadAnnotations(with_mask=self.with_mask)(results)
- def get_classes(self, classes=None):
+ @check_input_parameters_type()
+ def get_classes(self, classes: Optional[Sequence[str]] = None):
if classes is None:
return get_classes_from_annotation(self.ann_file)
@@ -145,6 +161,7 @@ def get_classes(self, classes=None):
raise ValueError(f"Unsupported type {type(classes)} of classes.")
+ @check_input_parameters_type({"ann_file": JsonFilePathCheck})
def load_annotations(self, ann_file):
self.coco = COCO(ann_file)
self.cat_ids = self.coco.get_cat_ids(cat_names=self.classes)
@@ -157,13 +174,15 @@ def load_annotations(self, ann_file):
data_infos.append(info)
return data_infos
- def get_ann_info(self, idx):
+ @check_input_parameters_type()
+ def get_ann_info(self, idx: int):
img_id = self.data_infos[idx]["id"]
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
return self._parse_ann_info(self.data_infos[idx], ann_info)
- def get_cat_ids(self, idx):
+ @check_input_parameters_type()
+ def get_cat_ids(self, idx: int):
img_id = self.data_infos[idx]["id"]
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
@@ -246,7 +265,8 @@ def _parse_ann_info(self, img_info, ann_info):
return ann
-def find_label_by_name(labels, name, domain):
+@check_input_parameters_type()
+def find_label_by_name(labels: Sequence[LabelEntity], name: str, domain: Domain):
matching_labels = [label for label in labels if label.name == name]
if len(matching_labels) == 1:
return matching_labels[0]
@@ -258,6 +278,8 @@ def find_label_by_name(labels, name, domain):
raise ValueError("Found multiple matching labels")
+@check_input_parameters_type({"ann_file_path": JsonFilePathCheck,
+ "data_root_dir": DirectoryPathCheck})
def load_dataset_items_coco_format(
ann_file_path: str,
data_root_dir: str,
@@ -346,7 +368,8 @@ def create_gt_polygon(polygon_group, label_name):
return dataset_items
-def get_sizes_from_dataset_entity(dataset: DatasetEntity, target_wh: list):
+@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
+def get_sizes_from_dataset_entity(dataset: DatasetEntity, target_wh: List[int]):
"""
Function to get the sizes of instances in a DatasetEntity and resize them to the target size.
@@ -366,7 +389,8 @@ def get_sizes_from_dataset_entity(dataset: DatasetEntity, target_wh: list):
return wh_stats
-def get_anchor_boxes(wh_stats, group_as):
+@check_input_parameters_type()
+def get_anchor_boxes(wh_stats: List[tuple], group_as: List[int]):
from sklearn.cluster import KMeans
kmeans = KMeans(init='k-means++', n_clusters=sum(group_as), random_state=0).fit(wh_stats)
centers = kmeans.cluster_centers_
@@ -382,7 +406,8 @@ def get_anchor_boxes(wh_stats, group_as):
return widths, heights
-def format_list_to_str(value_lists):
+@check_input_parameters_type()
+def format_list_to_str(value_lists: list):
""" Decrease floating point digits in logs """
str_value = ''
for value_list in value_lists:
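# Usage sketch for the annotated anchor utilities above (toy values):
from detection_tasks.extension.datasets.data_utils import (
    format_list_to_str,
    get_anchor_boxes,
)

wh_stats = [(32, 24), (64, 48), (128, 96), (30, 20), (62, 44), (120, 90)]  # (w, h)
group_as = [2, 1]  # two anchor groups: 2 clusters in the first, 1 in the second
widths, heights = get_anchor_boxes(wh_stats, group_as)
print(format_list_to_str([widths, heights]))  # shortened floats for logging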
diff --git a/external/mmdetection/detection_tasks/extension/datasets/mmdataset.py b/external/mmdetection/detection_tasks/extension/datasets/mmdataset.py
index ba7ef08c136..65f1702221b 100644
--- a/external/mmdetection/detection_tasks/extension/datasets/mmdataset.py
+++ b/external/mmdetection/detection_tasks/extension/datasets/mmdataset.py
@@ -13,12 +13,16 @@
# and limitations under the License.
from copy import deepcopy
-from typing import List
+from typing import Any, Dict, List, Sequence
import numpy as np
from ote_sdk.entities.dataset_item import DatasetItemEntity
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.label import Domain, LabelEntity
+from ote_sdk.utils.argument_checks import (
+ DatasetParamTypeCheck,
+ check_input_parameters_type,
+)
from ote_sdk.utils.shape_factory import ShapeFactory
from mmdet.core import PolygonMasks
@@ -27,6 +31,7 @@
from mmdet.datasets.pipelines import Compose
+@check_input_parameters_type()
def get_annotation_mmdet_format(
dataset_item: DatasetItemEntity,
labels: List[LabelEntity],
@@ -130,7 +135,15 @@ def __getitem__(self, index):
return data_info
- def __init__(self, ote_dataset: DatasetEntity, labels: List[LabelEntity], pipeline, domain, test_mode: bool = False):
+ @check_input_parameters_type({"ote_dataset": DatasetParamTypeCheck})
+ def __init__(
+ self,
+ ote_dataset: DatasetEntity,
+ labels: List[LabelEntity],
+ pipeline: Sequence[dict],
+ domain: Domain,
+ test_mode: bool = False,
+ ):
self.ote_dataset = ote_dataset
self.labels = labels
self.CLASSES = list(label.name for label in labels)
@@ -171,6 +184,7 @@ def _rand_another(self, idx):
def _filter_imgs(self, min_size=32):
raise NotImplementedError
+ @check_input_parameters_type()
def prepare_train_img(self, idx: int) -> dict:
"""Get training data and annotations after pipeline.
@@ -181,6 +195,7 @@ def prepare_train_img(self, idx: int) -> dict:
self.pre_pipeline(item)
return self.pipeline(item)
+ @check_input_parameters_type()
def prepare_test_img(self, idx: int) -> dict:
"""Get testing data after pipeline.
@@ -194,13 +209,15 @@ def prepare_test_img(self, idx: int) -> dict:
return self.pipeline(item)
@staticmethod
- def pre_pipeline(results: dict):
+ @check_input_parameters_type()
+ def pre_pipeline(results: Dict[str, Any]):
"""Prepare results dict for pipeline. Add expected keys to the dict. """
results['bbox_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []
- def get_ann_info(self, idx):
+ @check_input_parameters_type()
+ def get_ann_info(self, idx: int):
"""
This method is used for evaluation of predictions. The CustomDataset class implements a method
CustomDataset.evaluate, which uses the class method get_ann_info to retrieve annotations.
diff --git a/external/mmdetection/detection_tasks/extension/utils/hooks.py b/external/mmdetection/detection_tasks/extension/utils/hooks.py
index f631eee7db6..e26eb2bbf11 100644
--- a/external/mmdetection/detection_tasks/extension/utils/hooks.py
+++ b/external/mmdetection/detection_tasks/extension/utils/hooks.py
@@ -17,12 +17,15 @@
import os
from math import inf, isnan
from collections import defaultdict
+from typing import Any, Dict, Optional
from mmcv.runner.hooks import HOOKS, Hook, LoggerHook, LrUpdaterHook
from mmcv.runner import BaseRunner, EpochBasedRunner
from mmcv.runner.dist_utils import master_only
from mmcv.utils import print_log
+from ote_sdk.usecases.reporting.time_monitor_callback import TimeMonitorCallback
+from ote_sdk.utils.argument_checks import check_input_parameters_type
from mmdet.utils.logger import get_root_logger
@@ -31,6 +34,7 @@
@HOOKS.register_module()
class CancelTrainingHook(Hook):
+ @check_input_parameters_type()
def __init__(self, interval: int = 5):
"""
Periodically check whether a stop signal has been sent to the runner during model training.
@@ -53,6 +57,7 @@ def _check_for_stop_signal(runner: BaseRunner):
runner.should_stop = True # Set this flag to true to stop the current training epoch
os.remove(stop_filepath)
+ @check_input_parameters_type()
def after_train_iter(self, runner: BaseRunner):
if not self.every_n_iters(runner, self.interval):
return
@@ -68,7 +73,8 @@ def __init__(self):
"""
pass
- def before_run(self, runner):
+ @check_input_parameters_type()
+ def before_run(self, runner: BaseRunner):
pass
@@ -81,7 +87,8 @@ def __init__(self):
"""
pass
- def after_run(self, runner):
+ @check_input_parameters_type()
+ def after_run(self, runner: BaseRunner):
runner.call_hook('after_train_epoch')
@@ -99,17 +106,19 @@ def __repr__(self):
points.append(f'({x},{y})')
return 'curve[' + ','.join(points) + ']'
+ @check_input_parameters_type()
def __init__(self,
- curves=None,
- interval=10,
- ignore_last=True,
- reset_flag=True,
- by_epoch=True):
+ curves: Optional[Dict[Any, Curve]] = None,
+ interval: int = 10,
+ ignore_last: bool = True,
+ reset_flag: bool = True,
+ by_epoch: bool = True):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.curves = curves if curves is not None else defaultdict(self.Curve)
@master_only
- def log(self, runner):
+ @check_input_parameters_type()
+ def log(self, runner: BaseRunner):
tags = self.get_loggable_tags(runner, allow_text=False)
if runner.max_epochs is not None:
normalized_iter = self.get_iter(runner) / runner.max_iters * runner.max_epochs
@@ -124,7 +133,8 @@ def log(self, runner):
curve.x.append(normalized_iter)
curve.y.append(value)
- def after_train_epoch(self, runner):
+ @check_input_parameters_type()
+ def after_train_epoch(self, runner: BaseRunner):
# Iteration counter is increased right after the last iteration in the epoch,
# temporarily decrease it back.
runner._iter -= 1
@@ -134,13 +144,15 @@ def after_train_epoch(self, runner):
@HOOKS.register_module()
class OTEProgressHook(Hook):
- def __init__(self, time_monitor, verbose=False):
+ @check_input_parameters_type()
+ def __init__(self, time_monitor: TimeMonitorCallback, verbose: bool = False):
super().__init__()
self.time_monitor = time_monitor
self.verbose = verbose
self.print_threshold = 1
- def before_run(self, runner):
+ @check_input_parameters_type()
+ def before_run(self, runner: BaseRunner):
total_epochs = runner.max_epochs if runner.max_epochs is not None else 1
self.time_monitor.total_epochs = total_epochs
self.time_monitor.train_steps = runner.max_iters // total_epochs if total_epochs else 1
@@ -150,16 +162,20 @@ def before_run(self, runner):
self.time_monitor.current_epoch = 0
self.time_monitor.on_train_begin()
- def before_epoch(self, runner):
+ @check_input_parameters_type()
+ def before_epoch(self, runner: BaseRunner):
self.time_monitor.on_epoch_begin(runner.epoch)
- def after_epoch(self, runner):
+ @check_input_parameters_type()
+ def after_epoch(self, runner: BaseRunner):
self.time_monitor.on_epoch_end(runner.epoch, runner.log_buffer.output)
- def before_iter(self, runner):
+ @check_input_parameters_type()
+ def before_iter(self, runner: BaseRunner):
self.time_monitor.on_train_batch_begin(1)
- def after_iter(self, runner):
+ @check_input_parameters_type()
+ def after_iter(self, runner: BaseRunner):
self.time_monitor.on_train_batch_end(1)
if self.verbose:
progress = self.progress
@@ -167,13 +183,16 @@ def after_iter(self, runner):
logger.warning(f'training progress {progress:.0f}%')
self.print_threshold = (progress + 10) // 10 * 10
- def before_val_iter(self, runner):
+ @check_input_parameters_type()
+ def before_val_iter(self, runner: BaseRunner):
self.time_monitor.on_test_batch_begin(1)
- def after_val_iter(self, runner):
+ @check_input_parameters_type()
+ def after_val_iter(self, runner: BaseRunner):
self.time_monitor.on_test_batch_end(1)
- def after_run(self, runner):
+ @check_input_parameters_type()
+ def after_run(self, runner: BaseRunner):
self.time_monitor.on_train_end(1)
self.time_monitor.update_progress_callback(int(self.time_monitor.get_progress()))
@@ -217,10 +236,11 @@ class EarlyStoppingHook(Hook):
]
less_keys = ['loss']
+ @check_input_parameters_type()
def __init__(self,
interval: int,
metric: str = 'bbox_mAP',
- rule: str = None,
+ rule: Optional[str] = None,
patience: int = 5,
iteration_patience: int = 500,
min_delta: float = 0.0):
@@ -274,19 +294,22 @@ def _init_rule(self, rule, key_indicator):
self.key_indicator = key_indicator
self.compare_func = self.rule_map[self.rule]
- def before_run(self, runner):
+ @check_input_parameters_type()
+ def before_run(self, runner: BaseRunner):
self.by_epoch = False if runner.max_epochs is None else True
for hook in runner.hooks:
if isinstance(hook, LrUpdaterHook):
self.warmup_iters = hook.warmup_iters
break
- def after_train_iter(self, runner):
+ @check_input_parameters_type()
+ def after_train_iter(self, runner: BaseRunner):
"""Called after every training iter to evaluate the results."""
if not self.by_epoch:
self._do_check_stopping(runner)
- def after_train_epoch(self, runner):
+ @check_input_parameters_type()
+ def after_train_epoch(self, runner: BaseRunner):
"""Called after every training epoch to evaluate the results."""
if self.by_epoch:
self._do_check_stopping(runner)
@@ -371,14 +394,15 @@ class ReduceLROnPlateauLrUpdaterHook(LrUpdaterHook):
]
less_keys = ['loss']
+ @check_input_parameters_type()
def __init__(self,
- min_lr,
- interval,
- metric='bbox_mAP',
- rule=None,
- factor=0.1,
- patience=3,
- iteration_patience=300,
+ min_lr: float,
+ interval: int,
+ metric: str = 'bbox_mAP',
+ rule: Optional[str] = None,
+ factor: float = 0.1,
+ patience: int = 3,
+ iteration_patience: int = 300,
**kwargs):
super().__init__(**kwargs)
self.interval = interval
@@ -438,7 +462,8 @@ def _should_check_stopping(self, runner):
return False
return True
- def get_lr(self, runner, base_lr):
+ @check_input_parameters_type()
+ def get_lr(self, runner: BaseRunner, base_lr: float):
if not self._should_check_stopping(
runner) or self.warmup_iters > runner.iter:
return base_lr
@@ -480,7 +505,8 @@ def get_lr(self, runner, base_lr):
self.current_lr = max(self.current_lr * self.factor, self.min_lr)
return self.current_lr
- def before_run(self, runner):
+ @check_input_parameters_type()
+ def before_run(self, runner: BaseRunner):
# TODO: remove overloaded method after fixing the issue
# https://github.com/open-mmlab/mmdetection/issues/6572
for group in runner.optimizer.param_groups:
@@ -497,7 +523,8 @@ def before_run(self, runner):
@HOOKS.register_module()
class StopLossNanTrainingHook(Hook):
- def after_train_iter(self, runner):
+ @check_input_parameters_type()
+ def after_train_iter(self, runner: BaseRunner):
if isnan(runner.outputs['loss'].item()):
logger.warning(f"Early Stopping since loss is NaN")
runner.should_stop = True
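# Configuration sketch: the hooks above keep their mmcv registry interface and
# are still enabled from a config dict (values below are illustrative):
custom_hooks = [
    dict(type='CancelTrainingHook', interval=5),  # poll for a stop signal
    dict(type='StopLossNanTrainingHook'),  # stop when the loss becomes NaN
    dict(type='EarlyStoppingHook', interval=1, metric='bbox_mAP', patience=5),
]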
diff --git a/external/mmdetection/detection_tasks/extension/utils/pipelines.py b/external/mmdetection/detection_tasks/extension/utils/pipelines.py
index 82629e49973..dd18ee395cb 100644
--- a/external/mmdetection/detection_tasks/extension/utils/pipelines.py
+++ b/external/mmdetection/detection_tasks/extension/utils/pipelines.py
@@ -14,8 +14,12 @@
import copy
+from typing import Dict, Any, Optional
import numpy as np
+from ote_sdk.entities.label import Domain
+from ote_sdk.utils.argument_checks import check_input_parameters_type
+
from mmdet.datasets.builder import PIPELINES
from ..datasets import get_annotation_mmdet_format
@@ -34,10 +38,12 @@ class LoadImageFromOTEDataset:
:param to_float32: optional bool, True to convert images to fp32. defaults to False
"""
+ @check_input_parameters_type()
def __init__(self, to_float32: bool = False):
self.to_float32 = to_float32
- def __call__(self, results):
+ @check_input_parameters_type()
+ def __call__(self, results: Dict[str, Any]):
dataset_item = results['dataset_item']
img = dataset_item.numpy
shape = img.shape
@@ -77,8 +83,9 @@ class LoadAnnotationFromOTEDataset:
"""
+ @check_input_parameters_type()
def __init__(self, min_size: int, with_bbox: bool = True, with_label: bool = True, with_mask: bool = False, with_seg: bool = False,
- poly2mask: bool = True, with_text: bool = False, domain=None):
+ poly2mask: bool = True, with_text: bool = False, domain: Optional[Domain] = None):
self.with_bbox = with_bbox
self.with_label = with_label
self.with_mask = with_mask
@@ -105,7 +112,8 @@ def _load_masks(results, ann_info):
results['gt_masks'] = copy.deepcopy(ann_info['masks'])
return results
- def __call__(self, results):
+ @check_input_parameters_type()
+ def __call__(self, results: Dict[str, Any]):
dataset_item = results['dataset_item']
label_list = results['ann_info']['label_list']
ann_info = get_annotation_mmdet_format(dataset_item, label_list, self.domain, self.min_size)
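# Configuration sketch for the OTE loading steps above, which are registered in
# mmdet's PIPELINES (the min_size value is an assumption for illustration):
from ote_sdk.entities.label import Domain

train_pipeline = [
    dict(type='LoadImageFromOTEDataset', to_float32=True),
    dict(type='LoadAnnotationFromOTEDataset', min_size=-1, with_bbox=True,
         domain=Domain.DETECTION),
]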
diff --git a/external/mmdetection/detection_tasks/extension/utils/runner.py b/external/mmdetection/detection_tasks/extension/utils/runner.py
index 94889b2d1d2..a245d9c3779 100644
--- a/external/mmdetection/detection_tasks/extension/utils/runner.py
+++ b/external/mmdetection/detection_tasks/extension/utils/runner.py
@@ -11,11 +11,15 @@
import time
import warnings
+from typing import List, Sequence, Optional
import mmcv
import torch.distributed as dist
from mmcv.runner.utils import get_host_info
from mmcv.runner import RUNNERS, EpochBasedRunner, IterBasedRunner, IterLoader, get_dist_info
+from torch.utils.data.dataloader import DataLoader
+
+from ote_sdk.utils.argument_checks import check_input_parameters_type
@RUNNERS.register_module()
@@ -46,7 +50,8 @@ def stop(self) -> bool:
self._max_epochs = self.epoch
return broadcast_obj[0]
- def train(self, data_loader, **kwargs):
+ @check_input_parameters_type()
+ def train(self, data_loader: DataLoader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
@@ -78,7 +83,8 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.should_stop = False
- def main_loop(self, workflow, iter_loaders, **kwargs):
+ @check_input_parameters_type()
+ def main_loop(self, workflow: List[tuple], iter_loaders: Sequence[IterLoader], **kwargs):
while self.iter < self._max_iters:
for i, flow in enumerate(workflow):
self._inner_iter = 0
@@ -95,7 +101,8 @@ def main_loop(self, workflow, iter_loaders, **kwargs):
if self.should_stop:
return
- def run(self, data_loaders, workflow, max_iters=None, **kwargs):
+ @check_input_parameters_type()
+ def run(self, data_loaders: Sequence[DataLoader], workflow: List[tuple], max_iters: Optional[int] = None, **kwargs):
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py
new file mode 100644
index 00000000000..5bca36c90b7
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py
@@ -0,0 +1,449 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from collections import defaultdict
+
+import pytest
+from detection_tasks.apis.detection.config_utils import (
+ cluster_anchors,
+ config_from_string,
+ config_to_string,
+ get_data_cfg,
+ is_epoch_based_runner,
+ patch_adaptive_repeat_dataset,
+ patch_config,
+ patch_datasets,
+ prepare_for_testing,
+ prepare_for_training,
+ prepare_work_dir,
+ remove_from_config,
+ save_config_to_file,
+ set_data_classes,
+ set_hyperparams,
+)
+from detection_tasks.apis.detection.configuration import OTEDetectionConfig
+from mmcv import Config
+from ote_sdk.entities.datasets import DatasetEntity
+from ote_sdk.entities.label import Domain, LabelEntity
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+from ote_sdk.usecases.reporting.time_monitor_callback import TimeMonitorCallback
+
+
+class TestConfigUtilsInputParamsValidation:
+ @e2e_pytest_unit
+ def test_is_epoch_based_runner_input_params_validation(self):
+ """
+ Description:
+ Check "is_epoch_based_runner" function input parameters validation
+
+ Input data:
+ "runner_config" non-ConfigDict object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+        "is_epoch_based_runner" function
+ """
+ with pytest.raises(ValueError):
+ is_epoch_based_runner(runner_config="unexpected_str") # type: ignore
+
+ @e2e_pytest_unit
+ def test_patch_config_input_params_validation(self):
+ """
+ Description:
+ Check "patch_config" function input parameters validation
+
+ Input data:
+ "patch_config" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "patch_config" function
+ """
+ label = LabelEntity(name="test label", domain=Domain.SEGMENTATION)
+ correct_values_dict = {
+ "config": Config(),
+ "work_dir": "./work_dir",
+ "labels": [label],
+ "domain": Domain.DETECTION
+ }
+ unexpected_float = 1.1
+ unexpected_values = [
+ # Unexpected float is specified as "config" parameter
+ ("config", unexpected_float),
+ # Unexpected float is specified as "work_dir" parameter
+ ("work_dir", unexpected_float),
+ # Empty string is specified as "work_dir" parameter
+ ("work_dir", ""),
+ # String with null-character is specified as "work_dir" parameter
+ ("work_dir", "null\0character/path"),
+ # String with non-printable character is specified as "work_dir" parameter
+ ("work_dir", "\non_printable_character/path"),
+ # Unexpected float is specified as "labels" parameter
+ ("labels", unexpected_float),
+ # Unexpected float is specified as nested "label"
+ ("labels", [label, unexpected_float]),
+ # Unexpected float is specified as "domain" parameter
+ ("domain", unexpected_float),
+ # Unexpected float is specified as "random_seed" parameter
+ ("random_seed", unexpected_float),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=patch_config,
+ )
+
+ @e2e_pytest_unit
+ def test_set_hyperparams_input_params_validation(self):
+ """
+ Description:
+ Check "set_hyperparams" function input parameters validation
+
+ Input data:
+ "set_hyperparams" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "set_hyperparams" function
+ """
+ correct_values_dict = {
+ "config": Config(),
+ "hyperparams": OTEDetectionConfig(header="config header"),
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "config" parameter
+ ("config", unexpected_str),
+ # Unexpected string is specified as "hyperparams" parameter
+ ("hyperparams", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=set_hyperparams,
+ )
+
+ @e2e_pytest_unit
+ def test_patch_adaptive_repeat_dataset_input_params_validation(self):
+ """
+ Description:
+ Check "patch_adaptive_repeat_dataset" function input parameters validation
+
+ Input data:
+ "patch_adaptive_repeat_dataset" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "patch_adaptive_repeat_dataset" function
+ """
+ correct_values_dict = {"config": Config(), "num_samples": 10}
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "config" parameter
+ ("config", unexpected_str),
+ # Unexpected string is specified as "num_samples" parameter
+ ("num_samples", unexpected_str),
+ # Unexpected string is specified as "decay" parameter
+ ("decay", unexpected_str),
+ # Unexpected string is specified as "factor" parameter
+ ("factor", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=patch_adaptive_repeat_dataset,
+ )
+
+ @e2e_pytest_unit
+ def test_prepare_for_testing_input_params_validation(self):
+ """
+ Description:
+ Check "prepare_for_testing" function input parameters validation
+
+ Input data:
+ "prepare_for_testing" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "prepare_for_testing" function
+ """
+ correct_values_dict = {"config": Config(), "dataset": DatasetEntity()}
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "config" parameter
+ ("config", unexpected_str),
+ # Unexpected string is specified as "dataset" parameter
+ ("dataset", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=prepare_for_testing,
+ )
+
+ @e2e_pytest_unit
+ def test_prepare_for_training_input_params_validation(self):
+ """
+ Description:
+ Check "prepare_for_training" function input parameters validation
+
+ Input data:
+ "prepare_for_training" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "prepare_for_training" function
+ """
+ dataset = DatasetEntity()
+ time_monitor = TimeMonitorCallback(
+ num_epoch=5, num_train_steps=2, num_val_steps=1, num_test_steps=1
+ )
+ correct_values_dict = {
+ "config": Config(),
+ "train_dataset": dataset,
+ "val_dataset": dataset,
+ "time_monitor": time_monitor,
+ "learning_curves": defaultdict(),
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "config" parameter
+ ("config", unexpected_str),
+ # Unexpected string is specified as "train_dataset" parameter
+ ("train_dataset", unexpected_str),
+ # Unexpected string is specified as "val_dataset" parameter
+ ("val_dataset", unexpected_str),
+ # Unexpected string is specified as "time_monitor" parameter
+ ("time_monitor", unexpected_str),
+ # Unexpected string is specified as "learning_curves" parameter
+ ("learning_curves", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=prepare_for_training,
+ )
+
+ @e2e_pytest_unit
+ def test_config_to_string_input_params_validation(self):
+ """
+ Description:
+ Check "config_to_string" function input parameters validation
+
+ Input data:
+ "config" non-Config type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "config_to_string" function
+ """
+ with pytest.raises(ValueError):
+ config_to_string(config=1) # type: ignore
+
+ @e2e_pytest_unit
+ def test_config_from_string_input_params_validation(self):
+ """
+ Description:
+ Check "config_from_string" function input parameters validation
+
+ Input data:
+ "config_string" non-string type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "config_from_string" function
+ """
+ with pytest.raises(ValueError):
+ config_from_string(config_string=1) # type: ignore
+
+ @e2e_pytest_unit
+ def test_save_config_to_file_input_params_validation(self):
+ """
+ Description:
+ Check "save_config_to_file" function input parameters validation
+
+ Input data:
+ "config" non-Config type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "save_config_to_file" function
+ """
+ with pytest.raises(ValueError):
+ save_config_to_file(config=1) # type: ignore
+
+ @e2e_pytest_unit
+ def test_prepare_work_dir_input_params_validation(self):
+ """
+ Description:
+ Check "prepare_work_dir" function input parameters validation
+
+ Input data:
+ "config" non-Config type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "prepare_work_dir" function
+ """
+ with pytest.raises(ValueError):
+ prepare_work_dir(config=1) # type: ignore
+
+ @e2e_pytest_unit
+ def test_set_data_classes_input_params_validation(self):
+ """
+ Description:
+ Check "set_data_classes" function input parameters validation
+
+ Input data:
+ "set_data_classes" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "set_data_classes" function
+ """
+ label = LabelEntity(name="test label", domain=Domain.SEGMENTATION)
+ correct_values_dict = {"config": Config(), "labels": [label]}
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "config" parameter
+ ("config", unexpected_str),
+ # Unexpected string is specified as "dataset" parameter
+            # Unexpected string is specified as "labels" parameter
+ # Unexpected string is specified as nested "label"
+ ("labels", [label, unexpected_str]),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=set_data_classes,
+ )
+
+ @e2e_pytest_unit
+ def test_patch_datasets_input_params_validation(self):
+ """
+ Description:
+ Check "patch_datasets" function input parameters validation
+
+ Input data:
+ "config" non-Config type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "patch_datasets" function
+ """
+ correct_values_dict = {"config": Config(), "domain": Domain.DETECTION}
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "config" parameter
+ ("config", unexpected_str),
+ # Unexpected string is specified as "domain" parameter
+ ("domain", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=patch_datasets,
+ )
+
+ @e2e_pytest_unit
+ def test_remove_from_config_input_params_validation(self):
+ """
+ Description:
+ Check "remove_from_config" function input parameters validation
+
+ Input data:
+ "remove_from_config" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "remove_from_config" function
+ """
+ correct_values_dict = {"config": Config(), "key": "key_1"}
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "config" parameter
+ ("config", unexpected_int),
+ # Unexpected integer is specified as "key" parameter
+ ("key", unexpected_int),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=remove_from_config,
+ )
+
+ @e2e_pytest_unit
+ def test_cluster_anchors_input_params_validation(self):
+ """
+ Description:
+ Check "cluster_anchors" function input parameters validation
+
+ Input data:
+ "cluster_anchors" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "cluster_anchors" function
+ """
+ correct_values_dict = {
+ "config": Config(),
+ "dataset": DatasetEntity(),
+ "model": None,
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "config" parameter
+ ("config", unexpected_str),
+ # Unexpected string is specified as "dataset" parameter
+ ("dataset", unexpected_str),
+            # Unexpected string is specified as "model" parameter
+ ("model", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=cluster_anchors,
+ )
+
+ @e2e_pytest_unit
+ def test_get_data_cfg_input_params_validation(self):
+ """
+ Description:
+ Check "get_data_cfg" function input parameters validation
+
+ Input data:
+ "get_data_cfg" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "get_data_cfg" function
+ """
+ correct_values_dict = {
+ "config": Config(),
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "config" parameter
+ ("config", unexpected_int),
+ # Unexpected integer is specified as "subset" parameter
+ ("subset", unexpected_int),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=get_data_cfg,
+ )
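# The tests above lean on the ote_sdk helper check_value_error_exception_raised.
# A minimal sketch of its assumed behavior (not the actual implementation):
# each unexpected value is swapped into otherwise-valid kwargs and the call is
# expected to raise ValueError.
import pytest

def check_value_error_exception_raised(correct_parameters, unexpected_values, class_or_function):
    for parameter, value in unexpected_values:
        parameters = dict(correct_parameters)
        parameters[parameter] = value
        with pytest.raises(ValueError):
            class_or_function(**parameters)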
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_data_utils_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_data_utils_params_validation.py
new file mode 100644
index 00000000000..f72fea86261
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_data_utils_params_validation.py
@@ -0,0 +1,549 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+import os.path as osp
+import tempfile
+
+import mmcv
+import pytest
+from detection_tasks.extension.datasets.data_utils import (
+ CocoDataset,
+ LoadAnnotations,
+ find_label_by_name,
+ format_list_to_str,
+ get_anchor_boxes,
+ get_classes_from_annotation,
+ get_sizes_from_dataset_entity,
+ load_dataset_items_coco_format,
+)
+from ote_sdk.entities.datasets import DatasetEntity
+from ote_sdk.entities.label import Domain, LabelEntity
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
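+
+# `check_value_error_exception_raised` (imported above from the OTE SDK test
+# helpers) is expected to call `class_or_function` once per (name, value)
+# pair in `unexpected_values`, overriding that single entry of
+# `correct_parameters` and asserting that the call raises ValueError.
+# Roughly, as a sketch rather than the helper's actual code:
+#
+#     for name, value in unexpected_values:
+#         with pytest.raises(ValueError):
+#             class_or_function(**{**correct_parameters, name: value})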
+
+
+def _create_dummy_coco_json(json_name):
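+    """Write a minimal COCO-format annotation file to ``json_name``: one
+    640x640 image, two box annotations and a single "car" category, which is
+    just enough structure for the path-validation tests in this module."""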
+ image = {
+ "id": 0,
+ "width": 640,
+ "height": 640,
+ "file_name": "fake_name.jpg",
+ }
+
+ annotation_1 = {
+ "id": 1,
+ "image_id": 0,
+ "category_id": 0,
+ "area": 400,
+ "bbox": [50, 60, 20, 20],
+ "iscrowd": 0,
+ }
+
+ annotation_2 = {
+ "id": 2,
+ "image_id": 0,
+ "category_id": 0,
+ "area": 900,
+ "bbox": [100, 120, 30, 30],
+ "iscrowd": 0,
+ }
+
+ categories = [
+ {
+ "id": 0,
+ "name": "car",
+ "supercategory": "car",
+ }
+ ]
+
+ fake_json = {
+ "images": [image],
+ "annotations": [annotation_1, annotation_2],
+ "categories": categories,
+ }
+
+ mmcv.dump(fake_json, json_name)
+
+
+class TestDataUtilsFunctionsInputParamsValidation:
+ @e2e_pytest_unit
+ def test_get_classes_from_annotation_input_params_validation(self):
+ """
+ Description:
+ Check "get_classes_from_annotation" function input parameters validation
+
+ Input data:
+ "path" unexpected object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "get_classes_from_annotation" function
+ """
+ for unexpected_value in [
+ # non string object is specified as "path" parameter
+ 1,
+ # Empty string is specified as "path" parameter
+ "",
+ # Path to file with unexpected extension is specified as "path" parameter
+ "./unexpected_extension.yaml",
+ # Path to non-existing file is specified as "path" parameter
+ "./non_existing.json",
+ # Path with null character is specified as "path" parameter
+ "./null\0char.json",
+ # Path with non-printable character is specified as "path" parameter
+ "./\non_printable_char.json",
+ ]:
+ with pytest.raises(ValueError):
+ get_classes_from_annotation(path=unexpected_value)
+
+ @e2e_pytest_unit
+ def test_find_label_by_name_params_validation(self):
+ """
+ Description:
+ Check "find_label_by_name" function input parameters validation
+
+ Input data:
+ "find_label_by_name" function unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "find_label_by_name" function
+ """
+ label = LabelEntity(name="test label", domain=Domain.DETECTION)
+ correct_values_dict = {
+ "labels": [label],
+ "name": "test label",
+ "domain": Domain.DETECTION,
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "labels" parameter
+ ("labels", unexpected_int),
+ # Unexpected integer is specified as nested label
+ ("labels", [label, unexpected_int]),
+ # Unexpected integer is specified as "name" parameter
+ ("name", unexpected_int),
+ # Unexpected integer is specified as "domain" parameter
+ ("domain", unexpected_int),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=find_label_by_name,
+ )
+
+ @e2e_pytest_unit
+ def test_load_dataset_items_coco_format_params_validation(self):
+ """
+ Description:
+ Check "load_dataset_items_coco_format" function input parameters validation
+
+ Input data:
+ "load_dataset_items_coco_format" function unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "load_dataset_items_coco_format" function
+ """
+ tmp_dir = tempfile.TemporaryDirectory()
+ fake_json_file = osp.join(tmp_dir.name, "fake_data.json")
+ _create_dummy_coco_json(fake_json_file)
+
+ label = LabelEntity(name="test label", domain=Domain.DETECTION)
+ correct_values_dict = {
+ "ann_file_path": fake_json_file,
+ "data_root_dir": tmp_dir.name,
+ "domain": Domain.DETECTION,
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "ann_file_path" parameter
+ ("ann_file_path", unexpected_int),
+ # Empty string is specified as "ann_file_path" parameter
+ ("ann_file_path", ""),
+ # Path to non-json file is specified as "ann_file_path" parameter
+ ("ann_file_path", osp.join(tmp_dir.name, "non_json.jpg")),
+ # Path with null character is specified as "ann_file_path" parameter
+ ("ann_file_path", osp.join(tmp_dir.name, "\0fake_data.json")),
+ # Path with non-printable character is specified as "ann_file_path" parameter
+ ("ann_file_path", osp.join(tmp_dir.name, "\nfake_data.json")),
+ # Path to non-existing file is specified as "ann_file_path" parameter
+ ("ann_file_path", osp.join(tmp_dir.name, "non_existing.json")),
+ # Unexpected integer is specified as "data_root_dir" parameter
+ ("data_root_dir", unexpected_int),
+ # Empty string is specified as "data_root_dir" parameter
+ ("data_root_dir", ""),
+ # Path with null character is specified as "data_root_dir" parameter
+ ("data_root_dir", "./\0null_char"),
+ # Path with non-printable character is specified as "data_root_dir" parameter
+ ("data_root_dir", "./\non_printable_char"),
+ # Unexpected integer is specified as "domain" parameter
+ ("domain", unexpected_int),
+ # Unexpected integer is specified as "subset" parameter
+ ("subset", unexpected_int),
+ # Unexpected integer is specified as "labels_list" parameter
+ ("labels_list", unexpected_int),
+ # Unexpected integer is specified as nested label
+ ("labels_list", [label, unexpected_int]),
+ # Unexpected string is specified as "with_mask" parameter
+ ("with_mask", "unexpected string"),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=load_dataset_items_coco_format,
+ )
+
+ @e2e_pytest_unit
+ def test_get_sizes_from_dataset_entity_params_validation(self):
+ """
+ Description:
+ Check "get_sizes_from_dataset_entity" function input parameters validation
+
+ Input data:
+ "get_sizes_from_dataset_entity" function unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_sizes_from_dataset_entity" function
+ """
+ correct_values_dict = {
+ "dataset": DatasetEntity(),
+ "target_wh": [(0.1, 0.1)],
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "dataset" parameter
+ ("dataset", unexpected_int),
+ # Unexpected integer is specified as "target_wh" parameter
+ ("target_wh", unexpected_int),
+ # Unexpected integer is specified as nested target_wh
+ ("target_wh", [(0.1, 0.1), unexpected_int]),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=get_sizes_from_dataset_entity,
+ )
+
+ @e2e_pytest_unit
+ def test_format_list_to_str_params_validation(self):
+ """
+ Description:
+ Check "format_list_to_str" function input parameters validation
+
+ Input data:
+ "value_lists" unexpected type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "format_list_to_str" function
+ """
+ with pytest.raises(ValueError):
+ format_list_to_str(value_lists="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_get_anchor_boxes_params_validation(self):
+ """
+ Description:
+ Check "get_anchor_boxes" function input parameters validation
+
+ Input data:
+ "get_anchor_boxes" function unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_anchor_boxes" function
+ """
+ correct_values_dict = {
+ "wh_stats": [("wh_stat_1", 1), ("wh_stat_2", 2)],
+ "group_as": [0, 1, 2],
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "wh_stats" parameter
+ ("wh_stats", unexpected_str),
+ # Unexpected string is specified as nested "wh_stat"
+ ("wh_stats", [("wh_stat_1", 1), unexpected_str]),
+ # Unexpected string is specified as "group_as" parameter
+ ("group_as", unexpected_str),
+ # Unexpected string is specified as nested "group_as"
+ ("group_as", [0, 1, 2, unexpected_str]),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=get_anchor_boxes,
+ )
+
+
+class TestLoadAnnotationsInputParamsValidation:
+ @e2e_pytest_unit
+ def test_load_annotations_init_params_validation(self):
+ """
+ Description:
+ Check LoadAnnotations object initialization parameters validation
+
+ Input data:
+ LoadAnnotations object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ LoadAnnotations initialization parameter
+ """
+ for parameter in ["with_bbox", "with_label", "with_mask"]:
+ with pytest.raises(ValueError):
+ LoadAnnotations(**{parameter: "unexpected string"})
+
+ @e2e_pytest_unit
+ def test_load_annotations_call_params_validation(self):
+ """
+ Description:
+ Check LoadAnnotations object "__call__" method input parameters validation
+
+ Input data:
+ "results" parameter with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "__call__" method
+ """
+ load_annotations = LoadAnnotations()
+ unexpected_int = 1
+ for unexpected_value in [
+ # Unexpected integer is specified as "results" parameter
+ unexpected_int,
+ # Unexpected integer is specified as "results" dictionary key
+ {"result_1": "some results", unexpected_int: "unexpected results"},
+ ]:
+ with pytest.raises(ValueError):
+ load_annotations(results=unexpected_value)
+
+
+class TestCocoDatasetInputParamsValidation:
+ @staticmethod
+ def create_fake_json_file():
+ tmp_dir = tempfile.TemporaryDirectory()
+ fake_json_file = osp.join(tmp_dir.name, "fake_data.json")
+ _create_dummy_coco_json(fake_json_file)
+ return fake_json_file
+
+ @staticmethod
+ def dataset():
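+        # CocoDataset reads the annotation file eagerly in __init__, so the
+        # temporary directory only needs to survive for the duration of
+        # this call.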
+ tmp_dir = tempfile.TemporaryDirectory()
+ fake_json_file = osp.join(tmp_dir.name, "fake_data.json")
+ _create_dummy_coco_json(fake_json_file)
+ return CocoDataset(fake_json_file)
+
+ @e2e_pytest_unit
+ def test_coco_dataset_init_params_validation(self):
+ """
+ Description:
+ Check CocoDataset object initialization parameters validation
+
+ Input data:
+ CocoDataset object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ CocoDataset object initialization parameter
+ """
+ tmp_dir = tempfile.TemporaryDirectory()
+ fake_json_file = osp.join(tmp_dir.name, "fake_data.json")
+ _create_dummy_coco_json(fake_json_file)
+
+ correct_values_dict = {
+ "ann_file": fake_json_file,
+ }
+ unexpected_str = "unexpected string"
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "ann_file" parameter
+ ("ann_file", unexpected_int),
+ # Empty string is specified as "ann_file" parameter
+ ("ann_file", ""),
+ # Path to non-json file is specified as "ann_file" parameter
+ ("ann_file", osp.join(tmp_dir.name, "non_json.jpg")),
+ # Path with null character is specified as "ann_file" parameter
+ ("ann_file", osp.join(tmp_dir.name, "\0fake_data.json")),
+ # Path with non-printable character is specified as "ann_file" parameter
+ ("ann_file", osp.join(tmp_dir.name, "\nfake_data.json")),
+ # Path to non-existing file is specified as "ann_file" parameter
+ ("ann_file", osp.join(tmp_dir.name, "non_existing.json")),
+ # Unexpected integer is specified as "classes" parameter
+ ("classes", unexpected_int),
+            # Unexpected integer is specified as nested class
+ ("classes", ["class_1", unexpected_int]),
+ # Unexpected integer is specified as "data_root" parameter
+ ("data_root", unexpected_int),
+ # Empty string is specified as "data_root" parameter
+ ("data_root", ""),
+ # Path with null character is specified as "data_root" parameter
+ ("data_root", "./\0null_char"),
+ # Path with non-printable character is specified as "data_root" parameter
+ ("data_root", "./\non_printable_char"),
+ # Unexpected integer is specified as "img_prefix" parameter
+ ("img_prefix", unexpected_int),
+ # Unexpected string is specified as "test_mode" parameter
+ ("test_mode", unexpected_str),
+ # Unexpected string is specified as "filter_empty_gt" parameter
+ ("filter_empty_gt", unexpected_str),
+ # Unexpected string is specified as "min_size" parameter
+ ("min_size", unexpected_str),
+ # Unexpected string is specified as "with_mask" parameter
+ ("with_mask", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=CocoDataset,
+ )
+
+ @e2e_pytest_unit
+ def test_coco_dataset_pre_pipeline_params_validation(self):
+ """
+ Description:
+ Check CocoDataset object "pre_pipeline" method input parameters validation
+
+ Input data:
+ CocoDataset object, "results" parameter with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "pre_pipeline" method
+ """
+ dataset = self.dataset()
+ unexpected_int = 1
+ for unexpected_value in [
+ # Unexpected integer is specified as "results" parameter
+ unexpected_int,
+ # Unexpected integer is specified as "results" dictionary key
+ {"result_1": "some results", unexpected_int: "unexpected results"},
+ ]:
+ with pytest.raises(ValueError):
+ dataset.pre_pipeline(results=unexpected_value)
+
+ @e2e_pytest_unit
+ def test_coco_dataset_get_item_params_validation(self):
+ """
+ Description:
+ Check CocoDataset object "__getitem__" method input parameters validation
+
+ Input data:
+ CocoDataset object, "idx" non-integer type parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "__getitem__" method
+ """
+ dataset = self.dataset()
+ with pytest.raises(ValueError):
+ dataset.__getitem__(idx="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_coco_dataset_prepare_img_params_validation(self):
+ """
+ Description:
+ Check CocoDataset object "prepare_img" method input parameters validation
+
+ Input data:
+ CocoDataset object, "idx" non-integer type parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "prepare_img" method
+ """
+ dataset = self.dataset()
+ with pytest.raises(ValueError):
+ dataset.prepare_img(idx="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_coco_dataset_get_classes_params_validation(self):
+ """
+ Description:
+ Check CocoDataset object "get_classes" method input parameters validation
+
+ Input data:
+ CocoDataset object, "classes" parameter with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_classes" method
+ """
+ dataset = self.dataset()
+ unexpected_int = 1
+ for unexpected_value in [
+ # Unexpected integer is specified as "classes" parameter
+ unexpected_int,
+ # Unexpected integer is specified as nested "classes" element
+ ["class_1", unexpected_int],
+ ]:
+ with pytest.raises(ValueError):
+ dataset.get_classes(classes=unexpected_value) # type: ignore
+
+ @e2e_pytest_unit
+ def test_coco_dataset_load_annotations_params_validation(self):
+ """
+ Description:
+ Check CocoDataset object "load_annotations" method input parameters validation
+
+ Input data:
+ CocoDataset object, "ann_file" unexpected object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "load_annotations" method
+ """
+ dataset = self.dataset()
+ for unexpected_value in [
+ # Unexpected integer is specified as "ann_file" parameter
+ 1,
+ # Empty string is specified as "ann_file" parameter
+ "",
+ # Path to non-existing file is specified as "ann_file" parameter
+ "./non_existing.json",
+ # Path to non-json file is specified as "ann_file" parameter
+ "./unexpected_type.jpg",
+            # Path with null character is specified as "ann_file" parameter
+ "./null\0char.json",
+ # Path with non-printable character is specified as "input_config" parameter
+ "./null\nchar.json",
+ ]:
+ with pytest.raises(ValueError):
+ dataset.load_annotations(ann_file=unexpected_value)
+
+ @e2e_pytest_unit
+ def test_coco_dataset_get_ann_info_params_validation(self):
+ """
+ Description:
+ Check CocoDataset object "get_ann_info" method input parameters validation
+
+ Input data:
+ CocoDataset object, "idx" non-integer type parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_ann_info" method
+ """
+ dataset = self.dataset()
+ with pytest.raises(ValueError):
+ dataset.get_ann_info(idx="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_coco_dataset_get_cat_ids_params_validation(self):
+ """
+ Description:
+ Check CocoDataset object "get_cat_ids" method input parameters validation
+
+ Input data:
+ CocoDataset object, "idx" non-integer type parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_cat_ids" method
+ """
+ dataset = self.dataset()
+ with pytest.raises(ValueError):
+ dataset.get_cat_ids(idx="unexpected string") # type: ignore
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py
new file mode 100644
index 00000000000..e10ddb5b90e
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py
@@ -0,0 +1,554 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+from logging import Logger
+
+import pytest
+import torch.nn as nn
+from detection_tasks.extension.utils.hooks import (
+ CancelTrainingHook,
+ EarlyStoppingHook,
+ EnsureCorrectBestCheckpointHook,
+ FixedMomentumUpdaterHook,
+ OTELoggerHook,
+ OTEProgressHook,
+    ReduceLROnPlateauLrUpdaterHook,
+    StopLossNanTrainingHook,
+)
+from mmcv.runner import EpochBasedRunner
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+from ote_sdk.usecases.reporting.time_monitor_callback import TimeMonitorCallback
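+
+# These tests assume the hooks are decorated with
+# `check_input_parameters_type`, the mechanism used throughout this change
+# to turn a wrong-typed argument (e.g. a string where a BaseRunner is
+# expected) into the ValueError each test anticipates.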
+
+
+class TestCancelTrainingHook:
+ @e2e_pytest_unit
+ def test_cancel_training_hook_initialization_params_validation(self):
+ """
+ Description:
+ Check CancelTrainingHook object initialization parameters validation
+
+ Input data:
+ "interval" non-int type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ CancelTrainingHook object initialization parameter
+ """
+ with pytest.raises(ValueError):
+ CancelTrainingHook(interval="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_cancel_training_hook_after_train_iter_params_validation(self):
+ """
+ Description:
+ Check CancelTrainingHook object "after_train_iter" method input parameters validation
+
+ Input data:
+ CancelTrainingHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_train_iter" method
+ """
+ hook = CancelTrainingHook()
+ with pytest.raises(ValueError):
+ hook.after_train_iter(runner="unexpected string") # type: ignore
+
+
+class TestFixedMomentumUpdaterHook:
+ @e2e_pytest_unit
+ def test_fixed_momentum_updater_hook_before_run_params_validation(self):
+ """
+ Description:
+ Check FixedMomentumUpdaterHook object "before_run" method input parameters validation
+
+ Input data:
+ FixedMomentumUpdaterHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "before_run" method
+ """
+ hook = FixedMomentumUpdaterHook()
+ with pytest.raises(ValueError):
+ hook.before_run(runner="unexpected string") # type: ignore
+
+
+class TestEnsureCorrectBestCheckpointHook:
+ @e2e_pytest_unit
+ def test_ensure_correct_best_checkpoint_hook_after_run_params_validation(self):
+ """
+ Description:
+ Check EnsureCorrectBestCheckpointHook object "after_run" method input parameters validation
+
+ Input data:
+ EnsureCorrectBestCheckpointHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_run" method
+ """
+ hook = EnsureCorrectBestCheckpointHook()
+ with pytest.raises(ValueError):
+ hook.after_run(runner="unexpected string") # type: ignore
+
+
+class TestOTELoggerHook:
+ @e2e_pytest_unit
+ def test_ote_logger_hook_initialization_parameters_validation(self):
+ """
+ Description:
+ Check OTELoggerHook object initialization parameters validation
+
+ Input data:
+ OTELoggerHook object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OTELoggerHook object initialization parameter
+ """
+ correct_values_dict = {}
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "curves" parameter
+ ("curves", unexpected_str),
+ # Unexpected string is specified as nested curve
+ (
+ "curves",
+ {
+ "expected": OTELoggerHook.Curve(),
+ "unexpected": unexpected_str,
+ },
+ ),
+ # Unexpected string is specified as "interval" parameter
+ ("interval", unexpected_str),
+ # Unexpected string is specified as "ignore_last" parameter
+ ("ignore_last", unexpected_str),
+ # Unexpected string is specified as "reset_flag" parameter
+ ("reset_flag", unexpected_str),
+ # Unexpected string is specified as "by_epoch" parameter
+ ("by_epoch", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=OTELoggerHook,
+ )
+
+ @e2e_pytest_unit
+ def test_ote_logger_hook_log_params_validation(self):
+ """
+ Description:
+ Check OTELoggerHook object "log" method input parameters validation
+
+ Input data:
+ OTELoggerHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "log" method
+ """
+ hook = OTELoggerHook()
+ with pytest.raises(ValueError):
+ hook.log(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_logger_hook_after_train_epoch_params_validation(self):
+ """
+ Description:
+ Check OTELoggerHook object "after_train_epoch" method input parameters validation
+
+ Input data:
+ OTELoggerHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_train_epoch" method
+ """
+ hook = OTELoggerHook()
+ with pytest.raises(ValueError):
+ hook.after_train_epoch(runner="unexpected string") # type: ignore
+
+
+class TestOTEProgressHook:
+ @staticmethod
+ def time_monitor():
+ return TimeMonitorCallback(
+ num_epoch=10, num_train_steps=5, num_test_steps=5, num_val_steps=4
+ )
+
+ def hook(self):
+ return OTEProgressHook(time_monitor=self.time_monitor())
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_initialization_parameters_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object initialization parameters validation
+
+ Input data:
+ OTEProgressHook object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OTEProgressHook object initialization parameter
+ """
+ correct_values_dict = {"time_monitor": self.time_monitor()}
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "time_monitor" parameter
+ ("time_monitor", unexpected_str),
+ # Unexpected string is specified as "verbose" parameter
+ ("verbose", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=OTEProgressHook,
+ )
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_before_run_params_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object "before_run" method input parameters validation
+
+ Input data:
+ OTEProgressHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "before_run" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.before_run(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_before_epoch_params_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object "before_epoch" method input parameters validation
+
+ Input data:
+ OTEProgressHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "before_epoch" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.before_epoch(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_after_epoch_params_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object "after_epoch" method input parameters validation
+
+ Input data:
+ OTEProgressHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_epoch" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.after_epoch(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_before_iter_params_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object "before_iter" method input parameters validation
+
+ Input data:
+ OTEProgressHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "before_iter" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.before_iter(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_after_iter_params_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object "after_iter" method input parameters validation
+
+ Input data:
+ OTEProgressHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_iter" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.after_iter(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_before_val_iter_params_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object "before_val_iter" method input parameters validation
+
+ Input data:
+ OTEProgressHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "before_val_iter" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.before_val_iter(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_after_val_iter_params_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object "after_val_iter" method input parameters validation
+
+ Input data:
+ OTEProgressHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_val_iter" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.after_val_iter(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_progress_hook_after_run_params_validation(self):
+ """
+ Description:
+ Check OTEProgressHook object "after_run" method input parameters validation
+
+ Input data:
+ OTEProgressHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_run" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.after_run(runner="unexpected string") # type: ignore
+
+
+class TestEarlyStoppingHook:
+ @staticmethod
+ def hook():
+ return EarlyStoppingHook(interval=5)
+
+ @e2e_pytest_unit
+ def test_early_stopping_hook_initialization_parameters_validation(self):
+ """
+ Description:
+ Check EarlyStoppingHook object initialization parameters validation
+
+ Input data:
+ EarlyStoppingHook object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ EarlyStoppingHook object initialization parameter
+ """
+ correct_values_dict = {"interval": 5}
+ unexpected_dict = {"unexpected": "dictionary"}
+ unexpected_values = [
+ # Unexpected dictionary is specified as "interval" parameter
+ ("interval", unexpected_dict),
+ # Unexpected dictionary is specified as "metric" parameter
+ ("metric", unexpected_dict),
+ # Unexpected dictionary is specified as "rule" parameter
+ ("rule", unexpected_dict),
+ # Unexpected dictionary is specified as "patience" parameter
+ ("patience", unexpected_dict),
+ # Unexpected dictionary is specified as "iteration_patience" parameter
+ ("iteration_patience", unexpected_dict),
+ # Unexpected dictionary is specified as "min_delta" parameter
+ ("min_delta", unexpected_dict),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=EarlyStoppingHook,
+ )
+
+ @e2e_pytest_unit
+ def test_early_stopping_hook_before_run_params_validation(self):
+ """
+ Description:
+ Check EarlyStoppingHook object "before_run" method input parameters validation
+
+ Input data:
+ EarlyStoppingHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "before_run" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.before_run(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_early_stopping_hook_after_train_iter_params_validation(self):
+ """
+ Description:
+ Check EarlyStoppingHook object "after_train_iter" method input parameters validation
+
+ Input data:
+ EarlyStoppingHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_train_iter" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.after_train_iter(runner="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_early_stopping_hook_after_train_epoch_params_validation(self):
+ """
+ Description:
+ Check EarlyStoppingHook object "after_train_epoch" method input parameters validation
+
+ Input data:
+ EarlyStoppingHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_train_epoch" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.after_train_epoch(runner="unexpected string") # type: ignore
+
+
+class TestReduceLROnPlateauLrUpdaterHook:
+ @staticmethod
+ def hook():
+ return ReduceLROnPlateauLrUpdaterHook(min_lr=0.1, interval=5)
+
+ class MockModel(nn.Module):
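+        # Minimal stand-in: mmcv's BaseRunner requires the wrapped model to
+        # expose a `train_step` attribute, so an empty stub is enough to
+        # construct the EpochBasedRunner used in the "get_lr" test below.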
+ @staticmethod
+ def train_step():
+ pass
+
+ @e2e_pytest_unit
+ def test_reduce_lr_hook_initialization_parameters_validation(self):
+ """
+ Description:
+ Check ReduceLROnPlateauLrUpdaterHook object initialization parameters validation
+
+ Input data:
+ ReduceLROnPlateauLrUpdaterHook object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ ReduceLROnPlateauLrUpdaterHook object initialization parameter
+ """
+ correct_values_dict = {"min_lr": 0.1, "interval": 5}
+ unexpected_dict = {"unexpected": "dictionary"}
+ unexpected_values = [
+ # Unexpected dictionary is specified as "min_lr" parameter
+ ("min_lr", unexpected_dict),
+ # Unexpected dictionary is specified as "interval" parameter
+ ("interval", unexpected_dict),
+ # Unexpected dictionary is specified as "metric" parameter
+ ("metric", unexpected_dict),
+ # Unexpected dictionary is specified as "rule" parameter
+ ("rule", unexpected_dict),
+ # Unexpected dictionary is specified as "factor" parameter
+ ("factor", unexpected_dict),
+ # Unexpected dictionary is specified as "patience" parameter
+ ("patience", unexpected_dict),
+ # Unexpected dictionary is specified as "iteration_patience" parameter
+ ("iteration_patience", unexpected_dict),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=ReduceLROnPlateauLrUpdaterHook,
+ )
+
+ @e2e_pytest_unit
+ def test_reduce_lr_hook_get_lr_params_validation(self):
+ """
+ Description:
+ Check ReduceLROnPlateauLrUpdaterHook object "get_lr" method input parameters validation
+
+ Input data:
+ ReduceLROnPlateauLrUpdaterHook object, "get_lr" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_lr" method
+ """
+ hook = self.hook()
+ runner = EpochBasedRunner(
+ model=self.MockModel(), logger=Logger(name="test logger")
+ )
+ correct_values_dict = {"runner": runner, "base_lr": 0.2}
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "runner" parameter
+ ("runner", unexpected_str),
+ # Unexpected string is specified as "base_lr" parameter
+ ("base_lr", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=hook.get_lr,
+ )
+
+ @e2e_pytest_unit
+ def test_reduce_lr_hook_before_run_params_validation(self):
+ """
+ Description:
+ Check ReduceLROnPlateauLrUpdaterHook object "before_run" method input parameters validation
+
+ Input data:
+ ReduceLROnPlateauLrUpdaterHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "before_run" method
+ """
+ hook = self.hook()
+ with pytest.raises(ValueError):
+ hook.before_run(runner="unexpected string") # type: ignore
+
+
+class TestStopLossNanTrainingHook:
+ @e2e_pytest_unit
+ def test_stop_loss_nan_train_hook_after_train_iter_params_validation(self):
+ """
+ Description:
+ Check StopLossNanTrainingHook object "after_train_iter" method input parameters validation
+
+ Input data:
+ StopLossNanTrainingHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_train_iter" method
+ """
+ hook = StopLossNanTrainingHook()
+ with pytest.raises(ValueError):
+ hook.after_train_iter(runner="unexpected string") # type: ignore
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_inference_task_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_inference_task_params_validation.py
new file mode 100644
index 00000000000..cbd76244a31
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_inference_task_params_validation.py
@@ -0,0 +1,156 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import pytest
+
+from detection_tasks.apis.detection.inference_task import OTEDetectionInferenceTask
+from ote_sdk.configuration.configurable_parameters import ConfigurableParameters
+from ote_sdk.entities.datasets import DatasetEntity
+from ote_sdk.entities.inference_parameters import InferenceParameters
+from ote_sdk.entities.label_schema import LabelSchemaEntity
+from ote_sdk.entities.model import ModelConfiguration, ModelEntity
+from ote_sdk.entities.resultset import ResultSetEntity
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+from ote_sdk.usecases.tasks.interfaces.export_interface import ExportType
+
+
+class MockDetectionInferenceTask(OTEDetectionInferenceTask):
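+    # Bypass the real __init__, which would load a configuration and build
+    # a model; only input validation on the public methods is exercised
+    # here.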
+ def __init__(self):
+ pass
+
+
+class TestInferenceTaskInputParamsValidation:
+ @staticmethod
+ def model():
+ model_configuration = ModelConfiguration(
+ configurable_parameters=ConfigurableParameters(
+ header="header", description="description"
+ ),
+ label_schema=LabelSchemaEntity(),
+ )
+ return ModelEntity(
+ train_dataset=DatasetEntity(), configuration=model_configuration
+ )
+
+ @e2e_pytest_unit
+ def test_ote_detection_inference_task_init_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionInferenceTask object initialization parameters validation
+
+ Input data:
+ "task_environment" non-TaskEnvironment object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OTEDetectionInferenceTask object initialization parameter
+ """
+ with pytest.raises(ValueError):
+ OTEDetectionInferenceTask(task_environment="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_detection_inference_task_infer_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionInferenceTask object "infer" method input parameters validation
+
+ Input data:
+ OTEDetectionInferenceTask object. "infer" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "infer" method
+ """
+ task = MockDetectionInferenceTask()
+ correct_values_dict = {
+ "dataset": DatasetEntity(),
+ "inference_parameters": InferenceParameters(),
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "dataset" parameter
+ ("dataset", unexpected_str),
+ # Unexpected string is specified as "inference_parameters" parameter
+ ("inference_parameters", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.infer,
+ )
+
+ @e2e_pytest_unit
+ def test_ote_detection_inference_task_evaluate_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionInferenceTask object "evaluate" method input parameters validation
+
+ Input data:
+ OTEDetectionInferenceTask object. "evaluate" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "evaluate" method
+ """
+ task = MockDetectionInferenceTask()
+ model = self.model()
+ result_set = ResultSetEntity(
+ model=model,
+ ground_truth_dataset=DatasetEntity(),
+ prediction_dataset=DatasetEntity(),
+ )
+ correct_values_dict = {
+ "output_result_set": result_set,
+ "evaluation_metric": "metric",
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "output_result_set" parameter
+ ("output_result_set", unexpected_int),
+ # Unexpected integer is specified as "evaluation_metric" parameter
+ ("evaluation_metric", unexpected_int),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.evaluate,
+ )
+
+ @e2e_pytest_unit
+ def test_ote_detection_inference_task_export_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionInferenceTask object "export" method input parameters validation
+
+ Input data:
+ OTEDetectionInferenceTask object. "export" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "export" method
+ """
+ task = MockDetectionInferenceTask()
+ model = self.model()
+ correct_values_dict = {
+ "export_type": ExportType.OPENVINO,
+ "output_model": model,
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "export_type" parameter
+ ("export_type", unexpected_str),
+ # Unexpected string is specified as "output_model" parameter
+ ("output_model", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.export,
+ )
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_mmdataset_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_mmdataset_params_validation.py
new file mode 100644
index 00000000000..b35d2ab3e17
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_mmdataset_params_validation.py
@@ -0,0 +1,213 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import numpy as np
+import pytest
+from detection_tasks.extension.datasets.mmdataset import (
+ OTEDataset,
+ get_annotation_mmdet_format,
+)
+
+from ote_sdk.entities.annotation import (
+ Annotation,
+ AnnotationSceneEntity,
+ AnnotationSceneKind,
+)
+from ote_sdk.entities.dataset_item import DatasetItemEntity
+from ote_sdk.entities.datasets import DatasetEntity
+from ote_sdk.entities.image import Image
+from ote_sdk.entities.label import Domain, LabelEntity
+from ote_sdk.entities.scored_label import ScoredLabel
+from ote_sdk.entities.shapes.rectangle import Rectangle
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+
+
+def label_entity():
+ return LabelEntity(name="test label", domain=Domain.DETECTION)
+
+
+def dataset_item():
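+    # A minimal valid item: one random 10x16 RGB image carrying a single
+    # full-image box annotation.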
+ image = Image(data=np.random.randint(low=0, high=255, size=(10, 16, 3)))
+ annotation = Annotation(
+ shape=Rectangle.generate_full_box(), labels=[ScoredLabel(label_entity())]
+ )
+ annotation_scene = AnnotationSceneEntity(
+ annotations=[annotation], kind=AnnotationSceneKind.ANNOTATION
+ )
+ return DatasetItemEntity(media=image, annotation_scene=annotation_scene)
+
+
+class TestMMDatasetFunctionsInputParamsValidation:
+ @e2e_pytest_unit
+ def test_get_annotation_mmdet_format_input_params_validation(self):
+ """
+ Description:
+ Check "get_annotation_mmdet_format" function input parameters validation
+
+ Input data:
+ "get_annotation_mmdet_format" function unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_annotation_mmdet_format" function
+ """
+ label = label_entity()
+ correct_values_dict = {
+ "dataset_item": dataset_item(),
+ "labels": [label],
+ "domain": Domain.DETECTION,
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "dataset_item" parameter
+ ("dataset_item", unexpected_int),
+ # Unexpected integer is specified as "labels" parameter
+ ("labels", unexpected_int),
+ # Unexpected integer is specified as nested label
+ ("labels", [label, unexpected_int]),
+ # Unexpected integer is specified as "domain" parameter
+ ("domain", unexpected_int),
+ # Unexpected string is specified as "min_size" parameter
+ ("min_size", "unexpected string"),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=get_annotation_mmdet_format,
+ )
+
+
+class TestOTEDatasetInputParamsValidation:
+ @staticmethod
+ def dataset():
+ pipeline = [{"type": "LoadImageFromFile", "to_float32": True}]
+ return OTEDataset(
+ ote_dataset=DatasetEntity(),
+ labels=[label_entity()],
+ pipeline=pipeline,
+ test_mode=True,
+ domain=Domain.DETECTION,
+ )
+
+ @e2e_pytest_unit
+ def test_ote_dataset_init_params_validation(self):
+ """
+ Description:
+ Check OTEDataset object initialization parameters validation
+
+ Input data:
+ OTEDataset object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OTEDataset object initialization parameter
+ """
+ label = label_entity()
+
+ correct_values_dict = {
+ "ote_dataset": DatasetEntity(),
+ "labels": [label],
+ "pipeline": [{"type": "LoadImageFromFile", "to_float32": True}],
+ "domain": Domain.DETECTION,
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "ote_dataset" parameter
+ ("ote_dataset", unexpected_str),
+ # Unexpected string is specified as "labels" parameter
+ ("labels", unexpected_str),
+ # Unexpected string is specified as nested label
+ ("labels", [label, unexpected_str]),
+ # Unexpected integer is specified as "pipeline" parameter
+ ("pipeline", 1),
+ # Unexpected string is specified as nested pipeline
+ ("pipeline", [{"config": 1}, unexpected_str]),
+ # Unexpected string is specified as "domain" parameter
+ ("domain", unexpected_str),
+ # Unexpected string is specified as "test_mode" parameter
+ ("test_mode", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=OTEDataset,
+ )
+
+ @e2e_pytest_unit
+ def test_ote_dataset_prepare_train_img_params_validation(self):
+ """
+ Description:
+ Check OTEDataset object "prepare_train_img" method input parameters validation
+
+ Input data:
+ OTEDataset object, "idx" non-integer type parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "prepare_train_img" method
+ """
+ dataset = self.dataset()
+ with pytest.raises(ValueError):
+ dataset.prepare_train_img(idx="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_dataset_prepare_test_img_params_validation(self):
+ """
+ Description:
+ Check OTEDataset object "prepare_test_img" method input parameters validation
+
+ Input data:
+ OTEDataset object, "idx" non-integer type parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "prepare_test_img" method
+ """
+ dataset = self.dataset()
+ with pytest.raises(ValueError):
+ dataset.prepare_test_img(idx="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_ote_dataset_pre_pipeline_params_validation(self):
+ """
+ Description:
+ Check OTEDataset object "pre_pipeline" method input parameters validation
+
+ Input data:
+ OTEDataset object, "results" unexpected type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "pre_pipeline" method
+ """
+ dataset = self.dataset()
+ unexpected_int = 1
+ for unexpected_value in [
+ # Unexpected integer is specified as "results" parameter
+ unexpected_int,
+ # Unexpected integer is specified as "results" dictionary key
+ {"result_1": "some results", unexpected_int: "unexpected results"},
+ ]:
+ with pytest.raises(ValueError):
+ dataset.pre_pipeline(results=unexpected_value)
+
+ @e2e_pytest_unit
+ def test_ote_dataset_get_ann_info_params_validation(self):
+ """
+ Description:
+ Check OTEDataset object "get_ann_info" method input parameters validation
+
+ Input data:
+ OTEDataset object, "idx" non-integer type parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_ann_info" method
+ """
+ dataset = self.dataset()
+ with pytest.raises(ValueError):
+ dataset.get_ann_info(idx="unexpected string") # type: ignore
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_nncf_task_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_nncf_task_params_validation.py
new file mode 100644
index 00000000000..13d48024ab3
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_nncf_task_params_validation.py
@@ -0,0 +1,136 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import pytest
+
+from detection_tasks.apis.detection.nncf_task import OTEDetectionNNCFTask
+from ote_sdk.configuration.configurable_parameters import ConfigurableParameters
+from ote_sdk.entities.datasets import DatasetEntity
+from ote_sdk.entities.label_schema import LabelSchemaEntity
+from ote_sdk.entities.model import ModelConfiguration, ModelEntity
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+from ote_sdk.usecases.tasks.interfaces.export_interface import ExportType
+from ote_sdk.usecases.tasks.interfaces.optimization_interface import OptimizationType
+
+
+class MockNNCFTask(OTEDetectionNNCFTask):
+ def __init__(self):
+ pass
+
+
+class TestNNCFTaskInputParamsValidation:
+ @staticmethod
+ def model():
+ model_configuration = ModelConfiguration(
+ configurable_parameters=ConfigurableParameters(
+ header="header", description="description"
+ ),
+ label_schema=LabelSchemaEntity(),
+ )
+ return ModelEntity(
+ train_dataset=DatasetEntity(), configuration=model_configuration
+ )
+
+ @e2e_pytest_unit
+ def test_nncf_detection_task_init_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionNNCFTask object initialization parameters validation
+
+ Input data:
+ OTEDetectionNNCFTask object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OTEDetectionNNCFTask object initialization parameter
+ """
+ with pytest.raises(ValueError):
+ OTEDetectionNNCFTask(task_environment="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_nncf_detection_task_optimize_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionNNCFTask object "optimize" method input parameters validation
+
+ Input data:
+ OTEDetectionNNCFTask object. "optimize" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "optimize" method
+ """
+ task = MockNNCFTask()
+ correct_values_dict = {
+ "optimization_type": OptimizationType.NNCF,
+ "dataset": DatasetEntity(),
+ "output_model": self.model(),
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "optimization_type" parameter
+ ("optimization_type", unexpected_str),
+ # Unexpected string is specified as "dataset" parameter
+ ("dataset", unexpected_str),
+ # Unexpected string is specified as "output_model" parameter
+ ("output_model", unexpected_str),
+ # Unexpected string is specified as "optimization_parameters" parameter
+ ("optimization_parameters", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.optimize,
+ )
+
+ @e2e_pytest_unit
+ def test_nncf_detection_task_export_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionNNCFTask object "export" method input parameters validation
+
+ Input data:
+ OTEDetectionNNCFTask object. "export" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "export" method
+ """
+ task = MockNNCFTask()
+ correct_values_dict = {
+ "export_type": ExportType.OPENVINO,
+ "output_model": self.model(),
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "export_type" parameter
+ ("export_type", unexpected_str),
+ # Unexpected string is specified as "output_model" parameter
+ ("output_model", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.export,
+ )
+
+ @e2e_pytest_unit
+ def test_nncf_detection_task_save_model_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionNNCFTask object "save_model" method input parameters validation
+
+ Input data:
+ OTEDetectionNNCFTask object, "output_model" non-ModelEntity object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "save_model" method
+ """
+ task = MockNNCFTask()
+ with pytest.raises(ValueError):
+ task.save_model(output_model="unexpected string") # type: ignore
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_openvino_task_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_openvino_task_params_validation.py
new file mode 100644
index 00000000000..b1b1f8f433c
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_openvino_task_params_validation.py
@@ -0,0 +1,503 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import numpy as np
+import pytest
+from detection_tasks.apis.detection.configuration import OTEDetectionConfig
+from detection_tasks.apis.detection.openvino_task import (
+ BaseInferencerWithConverter,
+ OpenVINODetectionInferencer,
+ OpenVINODetectionTask,
+ OpenVINOMaskInferencer,
+ OpenVINORotatedRectInferencer,
+ OTEOpenVinoDataLoader,
+)
+from openvino.model_zoo.model_api.models import Model
+from ote_sdk.configuration.configurable_parameters import ConfigurableParameters
+from ote_sdk.entities.datasets import DatasetEntity
+from ote_sdk.entities.label import Domain, LabelEntity
+from ote_sdk.entities.label_schema import LabelSchemaEntity
+from ote_sdk.entities.model import ModelConfiguration, ModelEntity
+from ote_sdk.entities.resultset import ResultSetEntity
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+from ote_sdk.usecases.exportable_code.prediction_to_annotation_converter import (
+ DetectionToAnnotationConverter,
+)
+from ote_sdk.usecases.tasks.interfaces.optimization_interface import OptimizationType
+
+
+class MockOpenVinoTask(OpenVINODetectionTask):
+ def __init__(self):
+ pass
+
+
+class MockBaseInferencer(BaseInferencerWithConverter):
+ def __init__(self):
+ pass
+
+
+class MockDetectionInferencer(OpenVINODetectionInferencer):
+ def __init__(self):
+ pass
+
+
+class MockModel(Model):
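+    # Empty overrides of the (abstract) preprocess/postprocess hooks so
+    # this stub can be instantiated without a real network behind it.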
+ def __init__(self):
+ pass
+
+ def preprocess(self):
+ pass
+
+ def postprocess(self):
+ pass
+
+
+class TestBaseInferencerWithConverterInputParamsValidation:
+ @e2e_pytest_unit
+ def test_base_inferencer_with_converter_init_params_validation(self):
+ """
+ Description:
+ Check BaseInferencerWithConverter object initialization parameters validation
+
+ Input data:
+        BaseInferencerWithConverter object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ BaseInferencerWithConverter object initialization parameter
+ """
+ model = MockModel()
+ label = LabelEntity(name="test label", domain=Domain.DETECTION)
+ converter = DetectionToAnnotationConverter([label])
+ correct_values_dict = {
+ "configuration": {"inferencer": "configuration"},
+ "model": model,
+ "converter": converter,
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "configuration" parameter
+ ("configuration", unexpected_str),
+ # Unexpected string is specified as "model" parameter
+ ("model", unexpected_str),
+ # Unexpected string is specified as "converter" parameter
+ ("converter", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=BaseInferencerWithConverter,
+ )
+
+ @e2e_pytest_unit
+ def test_base_inferencer_with_converter_pre_process_params_validation(self):
+ """
+ Description:
+ Check BaseInferencerWithConverter object "pre_process" method input parameters validation
+
+ Input data:
+ BaseInferencerWithConverter object, "image" non-ndarray object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "pre_process" method
+ """
+ inferencer = MockBaseInferencer()
+ with pytest.raises(ValueError):
+ inferencer.pre_process(image="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_base_inferencer_with_converter_post_process_params_validation(self):
+ """
+ Description:
+ Check BaseInferencerWithConverter object "post_process" method input parameters validation
+
+ Input data:
+ BaseInferencerWithConverter object, "post_process" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "post_process" method
+ """
+ inferencer = MockBaseInferencer()
+ correct_values_dict = {
+ "prediction": {"prediction_1": np.random.rand(2, 2)},
+ "metadata": {"metadata_1": "some_data"},
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "prediction" parameter
+ ("prediction", unexpected_int),
+ # Unexpected integer is specified as "prediction" dictionary key
+ ("prediction", {unexpected_int: np.random.rand(2, 2)}),
+ # Unexpected integer is specified as "prediction" dictionary value
+ ("prediction", {"prediction_1": unexpected_int}),
+ # Unexpected integer is specified as "metadata" parameter
+ ("metadata", unexpected_int),
+ # Unexpected integer is specified as "metadata" dictionary key
+ ("metadata", {unexpected_int: "some_data"}),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=inferencer.post_process,
+ )
+
+ @e2e_pytest_unit
+ def test_base_inferencer_with_converter_forward_params_validation(self):
+ """
+ Description:
+ Check BaseInferencerWithConverter object "forward" method input parameters validation
+
+ Input data:
+ BaseInferencerWithConverter object, "inputs" unexpected type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "forward" method
+ """
+ inferencer = MockBaseInferencer()
+ unexpected_int = 1
+ for unexpected_value in [
+ # Unexpected integer is specified as "inputs" parameter
+ unexpected_int,
+ # Unexpected integer is specified as "inputs" dictionary key
+ {unexpected_int: np.random.rand(2, 2)},
+ # Unexpected integer is specified as "inputs" dictionary value
+ {"input_1": unexpected_int},
+ ]:
+ with pytest.raises(ValueError):
+ inferencer.forward(inputs=unexpected_value) # type: ignore
+
+
+class TestOpenVINODetectionInferencerInputParamsValidation:
+ @e2e_pytest_unit
+ def test_openvino_detection_inferencer_init_params_validation(self):
+ """
+ Description:
+ Check OpenVINODetectionInferencer object initialization parameters validation
+
+ Input data:
+ OpenVINODetectionInferencer object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OpenVINODetectionInferencer object initialization parameter
+ """
+ correct_values_dict = {
+ "hparams": OTEDetectionConfig("test header"),
+ "label_schema": LabelSchemaEntity(),
+ "model_file": "model data",
+ }
+ unexpected_str = "unexpected string"
+ unexpected_int = 1
+
+ unexpected_values = [
+ # Unexpected string is specified as "hparams" parameter
+ ("hparams", unexpected_str),
+ # Unexpected string is specified as "label_schema" parameter
+ ("label_schema", unexpected_str),
+ # Unexpected integer is specified as "model_file" parameter
+ ("model_file", unexpected_int),
+ # Unexpected integer is specified as "weight_file" parameter
+ ("weight_file", unexpected_int),
+ # Unexpected integer is specified as "device" parameter
+ ("device", unexpected_int),
+ # Unexpected string is specified as "num_requests" parameter
+ ("num_requests", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=OpenVINODetectionInferencer,
+ )
+
+
+class TestOpenVINOMaskInferencerInputParamsValidation:
+ @e2e_pytest_unit
+ def test_openvino_mask_inferencer_init_params_validation(self):
+ """
+ Description:
+ Check OpenVINOMaskInferencer object initialization parameters validation
+
+ Input data:
+ OpenVINOMaskInferencer object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OpenVINOMaskInferencer object initialization parameter
+ """
+ correct_values_dict = {
+ "hparams": OTEDetectionConfig("test header"),
+ "label_schema": LabelSchemaEntity(),
+ "model_file": "model data",
+ }
+ unexpected_str = "unexpected string"
+ unexpected_int = 1
+
+ unexpected_values = [
+ # Unexpected string is specified as "hparams" parameter
+ ("hparams", unexpected_str),
+ # Unexpected string is specified as "label_schema" parameter
+ ("label_schema", unexpected_str),
+ # Unexpected integer is specified as "model_file" parameter
+ ("model_file", unexpected_int),
+ # Unexpected integer is specified as "weight_file" parameter
+ ("weight_file", unexpected_int),
+ # Unexpected integer is specified as "device" parameter
+ ("device", unexpected_int),
+ # Unexpected string is specified as "num_requests" parameter
+ ("num_requests", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=OpenVINOMaskInferencer,
+ )
+
+
+class TestOpenVINORotatedRectInferencerInputParamsValidation:
+ @e2e_pytest_unit
+ def test_openvino_rotated_rect_inferencer_init_params_validation(self):
+ """
+ Description:
+ Check OpenVINORotatedRectInferencer object initialization parameters validation
+
+ Input data:
+ OpenVINORotatedRectInferencer object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OpenVINORotatedRectInferencer object initialization parameter
+ """
+ correct_values_dict = {
+ "hparams": OTEDetectionConfig("test header"),
+ "label_schema": LabelSchemaEntity(),
+ "model_file": "model data",
+ }
+ unexpected_str = "unexpected string"
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected string is specified as "hparams" parameter
+ ("hparams", unexpected_str),
+ # Unexpected string is specified as "label_schema" parameter
+ ("label_schema", unexpected_str),
+ # Unexpected integer is specified as "model_file" parameter
+ ("model_file", unexpected_int),
+ # Unexpected integer is specified as "weight_file" parameter
+ ("weight_file", unexpected_int),
+ # Unexpected integer is specified as "device" parameter
+ ("device", unexpected_int),
+ # Unexpected string is specified as "num_requests" parameter
+ ("num_requests", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=OpenVINORotatedRectInferencer,
+ )
+
+
+class TestOTEOpenVinoDataLoaderInputParamsValidation:
+ @staticmethod
+ def detection_inferencer(openvino_task):
+ return openvino_task.load_inferencer()
+
+ @e2e_pytest_unit
+ def test_openvino_data_loader_init_params_validation(self):
+ """
+ Description:
+ Check OTEOpenVinoDataLoader object initialization parameters validation
+
+ Input data:
+ OTEOpenVinoDataLoader object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OTEOpenVinoDataLoader object initialization parameter
+ """
+ correct_values_dict = {
+ "dataset": DatasetEntity(),
+ "inferencer": MockDetectionInferencer(),
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "dataset" parameter
+ ("dataset", unexpected_str),
+ # Unexpected string is specified as "inferencer" parameter
+ ("inferencer", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=OTEOpenVinoDataLoader,
+ )
+
+ @e2e_pytest_unit
+ def test_openvino_data_loader_getitem_input_params_validation(self):
+ """
+ Description:
+ Check OTEOpenVinoDataLoader object "__getitem__" method input parameters validation
+
+ Input data:
+ OTEOpenVinoDataLoader object. "__getitem__" method non-integer input parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "__getitem__" method
+ """
+ data_loader = OTEOpenVinoDataLoader(
+ dataset=DatasetEntity(), inferencer=MockDetectionInferencer()
+ )
+ with pytest.raises(ValueError):
+ data_loader.__getitem__("unexpected string") # type: ignore
+
+
+class TestOpenVINODetectionTaskInputParamsValidation:
+ @staticmethod
+ def model():
+ model_configuration = ModelConfiguration(
+ configurable_parameters=ConfigurableParameters(
+ header="header", description="description"
+ ),
+ label_schema=LabelSchemaEntity(),
+ )
+ return ModelEntity(
+ train_dataset=DatasetEntity(), configuration=model_configuration
+ )
+
+ @e2e_pytest_unit
+ def test_openvino_task_init_params_validation(self):
+ """
+ Description:
+ Check OpenVINODetectionTask object initialization parameters validation
+
+ Input data:
+ "task_environment" non-TaskEnvironment object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ OpenVINODetectionTask object initialization parameter
+ """
+ with pytest.raises(ValueError):
+ OpenVINODetectionTask(task_environment="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_openvino_task_infer_params_validation(self):
+ """
+ Description:
+ Check OpenVINODetectionTask object "infer" method input parameters validation
+
+ Input data:
+ OpenVINODetectionTask object. "infer" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "infer" method
+ """
+ task = MockOpenVinoTask()
+ correct_values_dict = {"dataset": DatasetEntity()}
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "dataset" parameter
+ ("dataset", unexpected_str),
+ # Unexpected string is specified as "inference_parameters" parameter
+ ("inference_parameters", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.infer,
+ )
+
+ @e2e_pytest_unit
+ def test_openvino_task_evaluate_params_validation(self):
+ """
+ Description:
+ Check OpenVINODetectionTask object "evaluate" method input parameters validation
+
+ Input data:
+ OpenVINODetectionTask object. "evaluate" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "evaluate" method
+ """
+ result_set = ResultSetEntity(
+ model=self.model(),
+ ground_truth_dataset=DatasetEntity(),
+ prediction_dataset=DatasetEntity(),
+ )
+ task = MockOpenVinoTask()
+ correct_values_dict = {"output_result_set": result_set}
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "output_result_set" parameter
+ ("output_result_set", unexpected_int),
+ # Unexpected integer is specified as "evaluation_metric" parameter
+ ("evaluation_metric", unexpected_int),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.evaluate,
+ )
+
+ @e2e_pytest_unit
+ def test_openvino_task_deploy_params_validation(self):
+ """
+ Description:
+ Check OpenVINODetectionTask object "deploy" method input parameters validation
+
+ Input data:
+ OpenVINODetectionTask object. "output_model" non-ModelEntity object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "deploy" method
+ """
+ task = MockOpenVinoTask()
+ with pytest.raises(ValueError):
+ task.deploy("unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_openvino_task_optimize_params_validation(self):
+ """
+ Description:
+ Check OpenVINODetectionTask object "optimize" method input parameters validation
+
+ Input data:
+ OpenVINODetectionTask object. "optimize" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "optimize" method
+ """
+ task = MockOpenVinoTask()
+ correct_values_dict = {
+ "optimization_type": OptimizationType.NNCF,
+ "dataset": DatasetEntity(),
+ "output_model": self.model(),
+ "optimization_parameters": None,
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "optimization_type" parameter
+ ("optimization_type", unexpected_str),
+ # Unexpected string is specified as "dataset" parameter
+ ("dataset", unexpected_str),
+ # Unexpected string is specified as "output_model" parameter
+ ("output_model", unexpected_str),
+ # Unexpected string is specified as "optimization_parameters" parameter
+ ("optimization_parameters", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.optimize,
+ )
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_ote_utils_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_ote_utils_params_validation.py
new file mode 100644
index 00000000000..f0c73a72e85
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_ote_utils_params_validation.py
@@ -0,0 +1,148 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import pytest
+from detection_tasks.apis.detection.ote_utils import (
+ ColorPalette,
+ generate_label_schema,
+ get_task_class,
+ load_template,
+)
+
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+
+
+class TestColorPaletteInputParamsValidation:
+ @staticmethod
+ def color_palette():
+ return ColorPalette(1)
+
+ @e2e_pytest_unit
+ def test_color_palette_init_params_validation(self):
+ """
+ Description:
+ Check ColorPalette object initialization parameters validation
+
+ Input data:
+ ColorPalette object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ ColorPalette object initialization parameter
+ """
+ correct_values_dict = {
+ "n": 1,
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "n" parameter
+ ("n", unexpected_str),
+ # Unexpected string is specified as "rng" parameter
+ ("rng", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=ColorPalette,
+ )
+
+ @e2e_pytest_unit
+ def test_color_palette_get_item_params_validation(self):
+ """
+ Description:
+ Check ColorPalette object "__getitem__" method input parameters validation
+
+ Input data:
+ ColorPalette object, "n" non-integer object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "__getitem__" method
+ """
+ color_palette = self.color_palette()
+ with pytest.raises(ValueError):
+ color_palette.__getitem__("unexpected string") # type: ignore
+
+
+class TestOTEUtilsFunctionsInputParamsValidation:
+ @e2e_pytest_unit
+ def test_generate_label_schema_input_params_validation(self):
+ """
+ Description:
+ Check "generate_label_schema" function input parameters validation
+
+ Input data:
+ "generate_label_schema" function unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "generate_label_schema" function
+ """
+ correct_values_dict = {
+ "label_names": ["label_1", "label_2"],
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "label_names" parameter
+ ("label_names", unexpected_int),
+ # Unexpected integer is specified as nested label name
+ ("label_names", ["label_1", unexpected_int]),
+ # Unexpected integer is specified as "label_domain" parameter
+ ("label_domain", unexpected_int),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=generate_label_schema,
+ )
+
+ @e2e_pytest_unit
+ def test_load_template_params_validation(self):
+ """
+ Description:
+ Check "load_template" function input parameters validation
+
+ Input data:
+ "path" unexpected string with yaml file object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "load_template" function
+ """
+ for incorrect_parameter in [
+ # Unexpected integer is specified as "path" parameter
+ 1,
+ # Empty string is specified as "path" parameter
+ "",
+ # Path to non-existing file is specified as "path" parameter
+ "./non_existing.yaml",
+ # Path to non-yaml file is specified as "path" parameter
+ "./unexpected_type.jpg",
+ # Path with null character is specified as "path" parameter
+ "./null\0char.yaml",
+ # Path with non-printable character is specified as "path" parameter
+ "./non\nprintable.yaml",
+ ]:
+ with pytest.raises(ValueError):
+ load_template(incorrect_parameter)
+
+ @e2e_pytest_unit
+ def test_get_task_class_input_params_validation(self):
+ """
+ Description:
+ Check "get_task_class" function input parameters validation
+
+ Input data:
+ "path" non string-type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "get_task_class" function
+ """
+ with pytest.raises(ValueError):
+ get_task_class(path=1) # type: ignore
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_pipelines_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_pipelines_params_validation.py
new file mode 100644
index 00000000000..8ec743008c4
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_pipelines_params_validation.py
@@ -0,0 +1,123 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import pytest
+from detection_tasks.extension.utils.pipelines import (
+ LoadAnnotationFromOTEDataset,
+ LoadImageFromOTEDataset,
+)
+
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+
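+# Both pipeline stages receive an mmdet-style "results" dict. The validation
+# is expected to reject non-dict inputs and non-string dictionary keys, which
+# is exactly what the integer-key cases below exercise.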
+
+class TestLoadImageFromOTEDatasetInputParamsValidation:
+ @e2e_pytest_unit
+ def test_load_image_from_ote_dataset_init_params_validation(self):
+ """
+ Description:
+ Check LoadImageFromOTEDataset object initialization parameters validation
+
+ Input data:
+ "to_float32" non-bool parameter
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ LoadImageFromOTEDataset object initialization parameter
+ """
+ with pytest.raises(ValueError):
+ LoadImageFromOTEDataset("unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_load_image_from_ote_dataset_call_params_validation(self):
+ """
+ Description:
+ Check LoadImageFromOTEDataset object "__call__" method input parameters validation
+
+ Input data:
+ LoadImageFromOTEDataset object, "results" unexpected type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "__call__" method
+ """
+ load_image_from_ote_dataset = LoadImageFromOTEDataset()
+ unexpected_int = 1
+ for unexpected_value in [
+ # Unexpected integer is specified as "results" parameter
+ unexpected_int,
+ # Unexpected integer is specified as "results" dictionary key
+ {"result_1": "some results", unexpected_int: "unexpected results"},
+ ]:
+ with pytest.raises(ValueError):
+ load_image_from_ote_dataset.__call__(results=unexpected_value)
+
+
+class TestLoadAnnotationFromOTEDatasetInputParamsValidation:
+ @e2e_pytest_unit
+ def test_load_annotation_from_ote_dataset_init_params_validation(self):
+ """
+ Description:
+ Check LoadAnnotationFromOTEDataset object initialization parameters validation
+
+ Input data:
+ LoadAnnotationFromOTEDataset object initialization parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ LoadAnnotationFromOTEDataset object initialization parameter
+ """
+ correct_values_dict = {
+ "min_size": 1,
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "min_size" parameter
+ ("min_size", unexpected_str),
+ # Unexpected string is specified as "with_bbox" parameter
+ ("with_bbox", unexpected_str),
+ # Unexpected string is specified as "with_label" parameter
+ ("with_label", unexpected_str),
+ # Unexpected string is specified as "with_mask" parameter
+ ("with_mask", unexpected_str),
+ # Unexpected string is specified as "with_seg" parameter
+ ("with_seg", unexpected_str),
+ # Unexpected string is specified as "poly2mask" parameter
+ ("poly2mask", unexpected_str),
+ # Unexpected string is specified as "with_text" parameter
+ ("with_text", unexpected_str),
+ # Unexpected string is specified as "domain" parameter
+ ("domain", unexpected_str),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=LoadAnnotationFromOTEDataset,
+ )
+
+ @e2e_pytest_unit
+ def test_load_annotation_from_ote_dataset_call_params_validation(self):
+ """
+ Description:
+ Check LoadAnnotationFromOTEDataset object "__call__" method input parameters validation
+
+ Input data:
+ LoadAnnotationFromOTEDataset object, "results" unexpected type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "__call__" method
+ """
+ load_annotation_from_ote_dataset = LoadAnnotationFromOTEDataset(min_size=1)
+ unexpected_int = 1
+ for unexpected_value in [
+ # Unexpected integer is specified as "results" parameter
+ unexpected_int,
+ # Unexpected integer is specified as "results" dictionary key
+ {"result_1": "some results", unexpected_int: "unexpected results"},
+ ]:
+ with pytest.raises(ValueError):
+ load_annotation_from_ote_dataset(results=unexpected_value)
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_runner_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_runner_params_validation.py
new file mode 100644
index 00000000000..79699cdeefa
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_runner_params_validation.py
@@ -0,0 +1,130 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from logging import Logger
+
+import pytest
+import torch.nn as nn
+from detection_tasks.extension.utils.runner import (
+ EpochRunnerWithCancel,
+ IterBasedRunnerWithCancel,
+ IterLoader,
+)
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+ load_test_dataset,
+)
+from torch.utils.data.dataloader import DataLoader
+
+
+class TestRunnersInputParamsValidation:
+ def iter_based_runner(self):
+ return IterBasedRunnerWithCancel(
+ model=self.MockModel(), logger=Logger(name="test logger")
+ )
+
+ @staticmethod
+ def data_loader():
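+ # load_test_dataset is assumed to return a tuple whose first element is a
+ # torch-compatible dataset; a plain DataLoader around it suffices here.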
+ dataset = load_test_dataset()[0]
+ return DataLoader(dataset)
+
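+ # The runners only require a model exposing train_step; this stub satisfies
+ # the interface, since validation raises before any training is attempted.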
+ class MockModel(nn.Module):
+ @staticmethod
+ def train_step():
+ pass
+
+ @e2e_pytest_unit
+ def test_epoch_runner_with_cancel_train_params_validation(self):
+ """
+ Description:
+ Check EpochRunnerWithCancel object "train" method input parameters validation
+
+ Input data:
+ EpochRunnerWithCancel object. "data_loader" non-DataLoader object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "train" method
+ """
+ runner = EpochRunnerWithCancel(
+ model=self.MockModel(), logger=Logger(name="test logger")
+ )
+ with pytest.raises(ValueError):
+ runner.train(data_loader="unexpected string") # type: ignore
+
+ @e2e_pytest_unit
+ def test_iter_based_runner_with_cancel_main_loop_params_validation(self):
+ """
+ Description:
+ Check IterBasedRunnerWithCancel object "main_loop" method input parameters validation
+
+ Input data:
+ IterBasedRunnerWithCancel object. "main_loop" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "main_loop" method
+ """
+ data_loader = self.data_loader()
+ iter_loader = IterLoader(data_loader)
+ runner = self.iter_based_runner()
+ correct_values_dict = {
+ "workflow": [("train", 1)],
+ "iter_loaders": [iter_loader],
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "workflow" parameter
+ ("workflow", unexpected_int),
+ # Unexpected integer is specified as nested workflow
+ ("workflow", [("train", 1), unexpected_int]),
+ # Unexpected integer is specified as "iter_loaders" parameter
+ ("iter_loaders", unexpected_int),
+ # Unexpected integer is specified as nested iter_loader
+ ("iter_loaders", [iter_loader, unexpected_int]),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=runner.main_loop,
+ )
+
+ @e2e_pytest_unit
+ def test_iter_based_runner_with_cancel_run_params_validation(self):
+ """
+ Description:
+ Check IterBasedRunnerWithCancel object "run" method input parameters validation
+
+ Input data:
+ IterBasedRunnerWithCancel object. "run" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "run" method
+ """
+ data_loader = self.data_loader()
+ runner = self.iter_based_runner()
+ correct_values_dict = {
+ "data_loaders": [data_loader],
+ "workflow": [("train", 1)],
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "data_loaders" parameter
+ ("data_loaders", unexpected_int),
+ # Unexpected integer is specified as nested data_loader
+ ("data_loaders", [data_loader, unexpected_int]),
+ # Unexpected integer is specified as "workflow" parameter
+ ("workflow", unexpected_int),
+ # Unexpected integer is specified as nested workflow
+ ("workflow", [("train", 1), unexpected_int]),
+ # Unexpected string is specified as "max_iters" parameter
+ ("max_iters", "unexpected string"),
+ ]
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=runner.run,
+ )
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py
new file mode 100644
index 00000000000..14d29bafd90
--- /dev/null
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py
@@ -0,0 +1,85 @@
+# Copyright (C) 2021-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import pytest
+
+from detection_tasks.apis.detection.train_task import OTEDetectionTrainingTask
+from ote_sdk.configuration.configurable_parameters import ConfigurableParameters
+from ote_sdk.entities.datasets import DatasetEntity
+from ote_sdk.entities.label_schema import LabelSchemaEntity
+from ote_sdk.entities.model import ModelConfiguration, ModelEntity
+from ote_sdk.test_suite.e2e_test_system import e2e_pytest_unit
+from ote_sdk.tests.parameters_validation.validation_helper import (
+ check_value_error_exception_raised,
+)
+
+
+class MockDetectionTrainingTask(OTEDetectionTrainingTask):
+ def __init__(self):
+ pass
+
+
+class TestOTEDetectionTrainingTaskInputParamsValidation:
+ @staticmethod
+ def model():
+ model_configuration = ModelConfiguration(
+ configurable_parameters=ConfigurableParameters(
+ header="header", description="description"
+ ),
+ label_schema=LabelSchemaEntity(),
+ )
+ return ModelEntity(
+ train_dataset=DatasetEntity(), configuration=model_configuration
+ )
+
+ @e2e_pytest_unit
+ def test_train_task_train_input_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionTrainingTask object "train" method input parameters validation
+
+ Input data:
+ OTEDetectionTrainingTask object, "train" method unexpected-type input parameters
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "train" method
+ """
+ task = MockDetectionTrainingTask()
+ correct_values_dict = {
+ "dataset": DatasetEntity(),
+ "output_model": self.model(),
+ }
+ unexpected_str = "unexpected string"
+ unexpected_values = [
+ # Unexpected string is specified as "dataset" parameter
+ ("dataset", unexpected_str),
+ # Unexpected string is specified as "output_model" parameter
+ ("output_model", unexpected_str),
+ # Unexpected string is specified as "train_parameters" parameter
+ ("train_parameters", unexpected_str),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=task.train,
+ )
+
+ @e2e_pytest_unit
+ def test_train_task_save_model_input_params_validation(self):
+ """
+ Description:
+ Check OTEDetectionTrainingTask object "save_model" method input parameters validation
+
+ Input data:
+ OTEDetectionTrainingTask object, "model" non-ModelEntity object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "save_model" method
+ """
+ task = MockDetectionTrainingTask()
+ with pytest.raises(ValueError):
+ task.save_model("unexpected string") # type: ignore
diff --git a/external/mmdetection/tests/test_ote_api.py b/external/mmdetection/tests/test_ote_api.py
index cedb2d6b6ef..8baf38f0e14 100644
--- a/external/mmdetection/tests/test_ote_api.py
+++ b/external/mmdetection/tests/test_ote_api.py
@@ -214,7 +214,7 @@ def test_cancel_training_detection(self):
def progress_callback(progress: float, score: Optional[float] = None):
training_progress_curve.append(progress)
- train_parameters = TrainParameters
+ train_parameters = TrainParameters()
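+ # Instantiate TrainParameters so update_progress is set on an instance,
+ # not on the class object shared across tests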
train_parameters.update_progress = progress_callback
# Test stopping after some time
@@ -254,7 +254,7 @@ def test_training_progress_tracking(self):
def progress_callback(progress: float, score: Optional[float] = None):
training_progress_curve.append(progress)
- train_parameters = TrainParameters
+ train_parameters = TrainParameters()
train_parameters.update_progress = progress_callback
output_model = ModelEntity(
dataset,
@@ -282,7 +282,7 @@ def test_nncf_optimize_progress_tracking(self):
dataset,
detection_environment.get_model_configuration(),
)
- task.train(dataset, original_model, TrainParameters)
+ task.train(dataset, original_model, TrainParameters())
# Create NNCFTask
detection_environment.model = original_model
@@ -301,7 +301,7 @@ def test_nncf_optimize_progress_tracking(self):
def progress_callback(progress: float, score: Optional[float] = None):
training_progress_curve.append(progress)
- optimization_parameters = OptimizationParameters
+ optimization_parameters = OptimizationParameters()
optimization_parameters.update_progress = progress_callback
nncf_model = ModelEntity(
dataset,
@@ -329,7 +329,7 @@ def progress_callback(progress: int):
assert isinstance(progress, int)
inference_progress_curve.append(progress)
- inference_parameters = InferenceParameters
+ inference_parameters = InferenceParameters()
inference_parameters.update_progress = progress_callback
task.infer(dataset.with_empty_annotations(), inference_parameters)
@@ -352,7 +352,7 @@ def test_inference_task(self):
dataset,
detection_environment.get_model_configuration(),
)
- train_task.train(dataset, trained_model, TrainParameters)
+ train_task.train(dataset, trained_model, TrainParameters())
performance_after_train = self.eval(train_task, trained_model, val_dataset)
# Create InferenceTask
diff --git a/ote_sdk/ote_sdk/utils/argument_checks.py b/ote_sdk/ote_sdk/utils/argument_checks.py
index 144f78f618c..54f98261c32 100644
--- a/ote_sdk/ote_sdk/utils/argument_checks.py
+++ b/ote_sdk/ote_sdk/utils/argument_checks.py
@@ -268,9 +268,7 @@ def check_file_extension(
def check_that_null_character_absents_in_string(parameter: str, parameter_name: str):
"""Function raises ValueError exception if null character: '\0' is specified in path to file"""
if "\0" in parameter:
- raise ValueError(
- rf"null char \\0 is specified in {parameter_name}: {parameter}"
- )
+ raise ValueError(f"null char \\0 is specified in {parameter_name}: {parameter}")
def check_that_file_exists(file_path: str, file_path_name: str):
@@ -475,7 +473,7 @@ def __init__(self, parameter, parameter_name):
class JsonFilePathCheck(FilePathCheck):
- """Class to check json file path parameters"""
+ """Class to check optional yaml file path parameters"""
def __init__(self, parameter, parameter_name):
super().__init__(