From ec20e10dbc6a8a4836fd67cac544ef92ebb6266f Mon Sep 17 00:00:00 2001 From: eunwoosh Date: Fri, 5 Jan 2024 12:34:11 +0900 Subject: [PATCH] deal with HPO edge case --- src/otx/cli/utils/hpo.py | 13 ++++++++++--- tests/unit/cli/utils/test_hpo.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/otx/cli/utils/hpo.py b/src/otx/cli/utils/hpo.py index 5a0a82d50af..2e0773bf57e 100644 --- a/src/otx/cli/utils/hpo.py +++ b/src/otx/cli/utils/hpo.py @@ -426,12 +426,16 @@ def _align_batch_size_search_space_to_dataset_size(self): if "range" in self._hpo_config["hp_space"][batch_size_name]: max_val = self._hpo_config["hp_space"][batch_size_name]["range"][1] min_val = self._hpo_config["hp_space"][batch_size_name]["range"][0] + step = 1 + if self._hpo_config["hp_space"][batch_size_name]["param_type"] in ["quniform", "qloguniform"]: + step = self._hpo_config["hp_space"][batch_size_name]["range"][2] if max_val > self._train_dataset_size: max_val = self._train_dataset_size self._hpo_config["hp_space"][batch_size_name]["range"][1] = max_val else: max_val = self._hpo_config["hp_space"][batch_size_name]["max"] min_val = self._hpo_config["hp_space"][batch_size_name]["min"] + step = self._hpo_config["hp_space"][batch_size_name].get("step", 1) if max_val > self._train_dataset_size: max_val = self._train_dataset_size @@ -439,10 +443,13 @@ def _align_batch_size_search_space_to_dataset_size(self): # If trainset size is lower than min batch size range, # fix batch size to trainset size + reason_to_fix_bs = "" if min_val >= max_val: - logger.info( - "Train set size is equal or lower than batch size range. Batch size is fixed to train set size." - ) + reason_to_fix_bs = "Train set size is equal or lower than batch size range." + elif max_val - min_val < step: + reason_to_fix_bs = "Difference between min and train set size is lesser than step." + if reason_to_fix_bs: + logger.info(f"{reason_to_fix_bs} Batch size is fixed to train set size.") del self._hpo_config["hp_space"][batch_size_name] self._fixed_hp[batch_size_name] = self._train_dataset_size self._environment.set_hyper_parameter_using_str_key(self._fixed_hp) diff --git a/tests/unit/cli/utils/test_hpo.py b/tests/unit/cli/utils/test_hpo.py index d0b9467d1e2..f01a048a195 100644 --- a/tests/unit/cli/utils/test_hpo.py +++ b/tests/unit/cli/utils/test_hpo.py @@ -1,4 +1,5 @@ import json +import yaml from copy import deepcopy from pathlib import Path from tempfile import TemporaryDirectory @@ -450,6 +451,19 @@ def test_init_wrong_hpo_time_ratio(self, cls_task_env, hpo_time_ratio): with pytest.raises(ValueError): HpoRunner(cls_task_env, 100, 10, "fake_path", hpo_time_ratio) + @e2e_pytest_unit + @pytest.mark.parametrize("diff_from_min_bs", [0, 1]) + def test_init_fix_batch_size(self, cls_task_env, diff_from_min_bs): + task_env = TaskEnvironmentManager(cls_task_env) + with (Path(task_env.get_model_template_path()).parent / "hpo_config.yaml").open() as f: + hpo_config = yaml.safe_load(f) + batch_size_name = task_env.get_batch_size_name() + min_bs = hpo_config["hp_space"][batch_size_name]["range"][0] + train_dataset_size = min_bs + diff_from_min_bs + + hpo_runner = HpoRunner(cls_task_env, train_dataset_size, 10, "fake_path") + assert batch_size_name in hpo_runner._fixed_hp + @e2e_pytest_unit def test_run_hpo(self, mocker, cls_task_env): cls_task_env.model = None