diff --git a/external/mmdetection/detection_tasks/apis/detection/config_utils.py b/external/mmdetection/detection_tasks/apis/detection/config_utils.py
index 0151b68a856..3c1d2a04a91 100644
--- a/external/mmdetection/detection_tasks/apis/detection/config_utils.py
+++ b/external/mmdetection/detection_tasks/apis/detection/config_utils.py
@@ -353,6 +353,7 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
return config, model
+@check_input_parameters_type()
def get_data_cfg(config: Config, subset: str = 'train') -> Config:
data_cfg = config.data[subset]
while 'dataset' in data_cfg:
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py
index 0a42ec15c89..5bca36c90b7 100644
--- a/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py
@@ -9,6 +9,7 @@
cluster_anchors,
config_from_string,
config_to_string,
+ get_data_cfg,
is_epoch_based_runner,
patch_adaptive_repeat_dataset,
patch_config,
@@ -416,3 +417,33 @@ def test_cluster_anchors_input_params_validation(self):
unexpected_values=unexpected_values,
class_or_function=cluster_anchors,
)
+
+ @e2e_pytest_unit
+ def test_get_data_cfg_input_params_validation(self):
+ """
+ Description:
+ Check "get_data_cfg" function input parameters validation
+
+ Input data:
+ "get_data_cfg" function input parameters with unexpected type
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
+ "get_data_cfg" function
+ """
+ correct_values_dict = {
+ "config": Config(),
+ }
+ unexpected_int = 1
+ unexpected_values = [
+ # Unexpected integer is specified as "config" parameter
+ ("config", unexpected_int),
+ # Unexpected integer is specified as "subset" parameter
+ ("subset", unexpected_int),
+ ]
+
+ check_value_error_exception_raised(
+ correct_parameters=correct_values_dict,
+ unexpected_values=unexpected_values,
+ class_or_function=get_data_cfg,
+ )
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py
index 219ea6b2226..e10ddb5b90e 100644
--- a/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py
@@ -12,6 +12,7 @@
FixedMomentumUpdaterHook,
OTELoggerHook,
OTEProgressHook,
+ StopLossNanTrainingHook,
ReduceLROnPlateauLrUpdaterHook,
)
from mmcv.runner import EpochBasedRunner
@@ -532,3 +533,22 @@ def test_reduce_lr_hook_before_run_params_validation(self):
hook = self.hook()
with pytest.raises(ValueError):
hook.before_run(runner="unexpected string") # type: ignore
+
+
+class TestStopLossNanTrainingHook:
+ @e2e_pytest_unit
+ def test_stop_loss_nan_train_hook_after_train_iter_params_validation(self):
+ """
+ Description:
+ Check StopLossNanTrainingHook object "after_train_iter" method input parameters validation
+
+ Input data:
+ StopLossNanTrainingHook object, "runner" non-BaseRunner type object
+
+ Expected results:
+ Test passes if ValueError exception is raised when unexpected type object is specified as
+ input parameter for "after_train_iter" method
+ """
+ hook = StopLossNanTrainingHook()
+ with pytest.raises(ValueError):
+ hook.after_train_iter(runner="unexpected string") # type: ignore
diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py
index 14d7a750fe6..14d29bafd90 100644
--- a/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py
+++ b/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py
@@ -34,7 +34,7 @@ def model():
)
@e2e_pytest_unit
- def test_train_task_input_params_validation(self):
+ def test_train_task_train_input_params_validation(self):
"""
Description:
Check OTEDetectionTrainingTask object "train" method input parameters validation