
Commit

Fix a bug where an error is raised when the train set size is greater than the minimum batch size in HPO by exactly 1 (#2760)

deal with HPO edge case
eunwoosh authored Jan 8, 2024
1 parent 4d92e6c commit 2a19fda
Showing 2 changed files with 24 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/otx/cli/utils/hpo.py
@@ -426,23 +426,30 @@ def _align_batch_size_search_space_to_dataset_size(self):
             if "range" in self._hpo_config["hp_space"][batch_size_name]:
                 max_val = self._hpo_config["hp_space"][batch_size_name]["range"][1]
                 min_val = self._hpo_config["hp_space"][batch_size_name]["range"][0]
+                step = 1
+                if self._hpo_config["hp_space"][batch_size_name]["param_type"] in ["quniform", "qloguniform"]:
+                    step = self._hpo_config["hp_space"][batch_size_name]["range"][2]
                 if max_val > self._train_dataset_size:
                     max_val = self._train_dataset_size
                     self._hpo_config["hp_space"][batch_size_name]["range"][1] = max_val
             else:
                 max_val = self._hpo_config["hp_space"][batch_size_name]["max"]
                 min_val = self._hpo_config["hp_space"][batch_size_name]["min"]
+                step = self._hpo_config["hp_space"][batch_size_name].get("step", 1)

                 if max_val > self._train_dataset_size:
                     max_val = self._train_dataset_size
                     self._hpo_config["hp_space"][batch_size_name]["max"] = max_val

             # If trainset size is lower than min batch size range,
             # fix batch size to trainset size
+            reason_to_fix_bs = ""
             if min_val >= max_val:
-                logger.info(
-                    "Train set size is equal or lower than batch size range. Batch size is fixed to train set size."
-                )
+                reason_to_fix_bs = "Train set size is equal or lower than batch size range."
+            elif max_val - min_val < step:
+                reason_to_fix_bs = "Difference between min and train set size is lesser than step."
+            if reason_to_fix_bs:
+                logger.info(f"{reason_to_fix_bs} Batch size is fixed to train set size.")
                 del self._hpo_config["hp_space"][batch_size_name]
                 self._fixed_hp[batch_size_name] = self._train_dataset_size
                 self._environment.set_hyper_parameter_using_str_key(self._fixed_hp)
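In plain terms, the change clamps the upper end of the batch-size search space to the train set size and then fixes the batch size outright whenever the remaining range is unusable: either the minimum already meets or exceeds the clamped maximum, or the gap between them is smaller than the quantization step (the case that previously raised an error when the train set exceeded the minimum batch size by exactly 1). Below is a minimal standalone sketch of that logic, not the OTX implementation; the function name, the learning_parameters.batch_size key, and the [8, 512, 2] range are assumptions made only for illustration.

import logging

logger = logging.getLogger(__name__)


def align_batch_size_to_dataset(hp_space, train_dataset_size):
    """Return hyper-parameters to fix when the batch-size search range collapses."""
    fixed_hp = {}
    bs_space = hp_space["learning_parameters.batch_size"]  # hypothetical key name
    min_val, max_val, step = bs_space["range"]  # [min, max, step], as in a quniform range

    # Clamp the upper bound to the train set size, as the changed code does.
    max_val = min(max_val, train_dataset_size)

    reason_to_fix_bs = ""
    if min_val >= max_val:
        reason_to_fix_bs = "Train set size is equal or lower than batch size range."
    elif max_val - min_val < step:
        # The edge case this commit addresses: e.g. min=8, step=2, train set size=9
        # leaves no valid quantized value above the minimum, so the search space
        # is dropped and the batch size is pinned to the train set size.
        reason_to_fix_bs = "Difference between min and train set size is lesser than step."
    if reason_to_fix_bs:
        logger.info("%s Batch size is fixed to train set size.", reason_to_fix_bs)
        fixed_hp["learning_parameters.batch_size"] = train_dataset_size
    return fixed_hp


if __name__ == "__main__":
    # Hypothetical quniform search space: batch size in [8, 512] with step 2.
    space = {"learning_parameters.batch_size": {"range": [8, 512, 2]}}
    print(align_batch_size_to_dataset(space, train_dataset_size=9))
    # expected: {'learning_parameters.batch_size': 9}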
14 changes: 14 additions & 0 deletions tests/unit/cli/utils/test_hpo.py
@@ -1,4 +1,5 @@
 import json
+import yaml
 from copy import deepcopy
 from pathlib import Path
 from tempfile import TemporaryDirectory
@@ -450,6 +451,19 @@ def test_init_wrong_hpo_time_ratio(self, cls_task_env, hpo_time_ratio):
         with pytest.raises(ValueError):
             HpoRunner(cls_task_env, 100, 10, "fake_path", hpo_time_ratio)

+    @e2e_pytest_unit
+    @pytest.mark.parametrize("diff_from_min_bs", [0, 1])
+    def test_init_fix_batch_size(self, cls_task_env, diff_from_min_bs):
+        task_env = TaskEnvironmentManager(cls_task_env)
+        with (Path(task_env.get_model_template_path()).parent / "hpo_config.yaml").open() as f:
+            hpo_config = yaml.safe_load(f)
+        batch_size_name = task_env.get_batch_size_name()
+        min_bs = hpo_config["hp_space"][batch_size_name]["range"][0]
+        train_dataset_size = min_bs + diff_from_min_bs
+
+        hpo_runner = HpoRunner(cls_task_env, train_dataset_size, 10, "fake_path")
+        assert batch_size_name in hpo_runner._fixed_hp
+
     @e2e_pytest_unit
     def test_run_hpo(self, mocker, cls_task_env):
         cls_task_env.model = None
