From c99aebd84ee7a22e4dfe5fcc088f715b5b5e29eb Mon Sep 17 00:00:00 2001 From: saltykox Date: Tue, 5 Apr 2022 13:46:33 +0300 Subject: [PATCH] added tests to cover get_data_cfg function and StopLossNanTrainingHook after_train_iter method --- .../apis/detection/config_utils.py | 1 + ...test_ote_config_utils_params_validation.py | 31 +++++++++++++++++++ .../test_ote_hooks_params_validation.py | 20 ++++++++++++ .../test_ote_train_task_params_validation.py | 2 +- 4 files changed, 53 insertions(+), 1 deletion(-) diff --git a/external/mmdetection/detection_tasks/apis/detection/config_utils.py b/external/mmdetection/detection_tasks/apis/detection/config_utils.py index 0151b68a856..3c1d2a04a91 100644 --- a/external/mmdetection/detection_tasks/apis/detection/config_utils.py +++ b/external/mmdetection/detection_tasks/apis/detection/config_utils.py @@ -353,6 +353,7 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector) return config, model +@check_input_parameters_type() def get_data_cfg(config: Config, subset: str = 'train') -> Config: data_cfg = config.data[subset] while 'dataset' in data_cfg: diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py index 0a42ec15c89..5bca36c90b7 100644 --- a/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py +++ b/external/mmdetection/tests/ote_params_validation/test_ote_config_utils_params_validation.py @@ -9,6 +9,7 @@ cluster_anchors, config_from_string, config_to_string, + get_data_cfg, is_epoch_based_runner, patch_adaptive_repeat_dataset, patch_config, @@ -416,3 +417,33 @@ def test_cluster_anchors_input_params_validation(self): unexpected_values=unexpected_values, class_or_function=cluster_anchors, ) + + @e2e_pytest_unit + def test_get_data_cfg_input_params_validation(self): + """ + Description: + Check "get_data_cfg" function input parameters validation + + Input data: + "get_data_cfg" function input parameters with unexpected type + + Expected results: + Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for + "get_data_cfg" function + """ + correct_values_dict = { + "config": Config(), + } + unexpected_int = 1 + unexpected_values = [ + # Unexpected integer is specified as "config" parameter + ("config", unexpected_int), + # Unexpected integer is specified as "subset" parameter + ("subset", unexpected_int), + ] + + check_value_error_exception_raised( + correct_parameters=correct_values_dict, + unexpected_values=unexpected_values, + class_or_function=get_data_cfg, + ) diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py index 219ea6b2226..e10ddb5b90e 100644 --- a/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py +++ b/external/mmdetection/tests/ote_params_validation/test_ote_hooks_params_validation.py @@ -12,6 +12,7 @@ FixedMomentumUpdaterHook, OTELoggerHook, OTEProgressHook, + StopLossNanTrainingHook, ReduceLROnPlateauLrUpdaterHook, ) from mmcv.runner import EpochBasedRunner @@ -532,3 +533,22 @@ def test_reduce_lr_hook_before_run_params_validation(self): hook = self.hook() with pytest.raises(ValueError): hook.before_run(runner="unexpected string") # type: ignore + + +class TestStopLossNanTrainingHook: + @e2e_pytest_unit + def test_stop_loss_nan_train_hook_after_train_iter_params_validation(self): + """ + Description: + Check StopLossNanTrainingHook object "after_train_iter" method input parameters validation + + Input data: + StopLossNanTrainingHook object, "runner" non-BaseRunner type object + + Expected results: + Test passes if ValueError exception is raised when unexpected type object is specified as + input parameter for "after_train_iter" method + """ + hook = StopLossNanTrainingHook() + with pytest.raises(ValueError): + hook.after_train_iter(runner="unexpected string") # type: ignore diff --git a/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py b/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py index 14d7a750fe6..14d29bafd90 100644 --- a/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py +++ b/external/mmdetection/tests/ote_params_validation/test_ote_train_task_params_validation.py @@ -34,7 +34,7 @@ def model(): ) @e2e_pytest_unit - def test_train_task_input_params_validation(self): + def test_train_task_train_input_params_validation(self): """ Description: Check OTEDetectionTrainingTask object "train" method input parameters validation