Skip to content

Commit

Permalink
update unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed May 7, 2024
1 parent 11a0870 commit d6554c2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 46 deletions.
20 changes: 16 additions & 4 deletions src/otx/cli/utils/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,9 +657,17 @@ def run_hpo(
logger.debug(f"{best_hpo_weight} will be loaded as best HPO weight")
env_manager.load_model_weight(best_hpo_weight, dataset)

_remove_unused_model_weights(hpo_save_path, best_hpo_weight)
return env_manager.environment


def _remove_unused_model_weights(hpo_save_path: Path, best_hpo_weight: Optional[str] = None):
for weight in hpo_save_path.rglob("*.pth"):
if best_hpo_weight is not None and str(weight) == best_hpo_weight:
continue
weight.unlink()


def get_best_hpo_weight(hpo_dir: Union[str, Path], trial_id: Union[str, Path]) -> Optional[str]:
"""Get best model weight path of the HPO trial.
Expand Down Expand Up @@ -848,11 +856,15 @@ def _finalize_trial(self, task):
weight_dir_path = self._get_weight_dir_path()
weight_dir_path.mkdir(parents=True, exist_ok=True)
self._task.copy_weight(task.project_path, weight_dir_path)
latest_model_weight = self._task.get_latest_weight(weight_dir_path)
best_model_weight = get_best_hpo_weight(self._hpo_workdir, self._hp_config["id"])
necessary_weights = [
self._task.get_latest_weight(weight_dir_path),
get_best_hpo_weight(self._hpo_workdir, self._hp_config["id"]),
]
while None in necessary_weights:
necessary_weights.remove(None)
for each_model_weight in weight_dir_path.iterdir():
for neccesary_weight in [latest_model_weight, best_model_weight]:
if each_model_weight.samefile(neccesary_weight):
for necessary_weight in necessary_weights:
if each_model_weight.samefile(necessary_weight):
break
else:
each_model_weight.unlink()
Expand Down
59 changes: 17 additions & 42 deletions tests/unit/cli/utils/test_hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,13 @@ def test_init_fix_batch_size(self, cls_task_env, diff_from_min_bs):
hpo_runner = HpoRunner(cls_task_env, train_dataset_size, 10, "fake_path")
assert batch_size_name in hpo_runner._fixed_hp

@pytest.fixture
def mock_thread(self, mocker) -> MagicMock:
mock_thread = mocker.patch.object(hpo, "Thread")
return mock_thread

@e2e_pytest_unit
def test_run_hpo(self, mocker, cls_task_env):
def test_run_hpo(self, mocker, cls_task_env, mock_thread):
cls_task_env.model = None
hpo_runner = HpoRunner(cls_task_env, 100, 10, "fake_path")
mock_run_hpo_loop = mocker.patch("otx.cli.utils.hpo.run_hpo_loop")
Expand All @@ -477,7 +482,7 @@ def test_run_hpo(self, mocker, cls_task_env):
mock_hb.assert_called() # make hyperband

@e2e_pytest_unit
def test_run_hpo_w_dataset_smaller_than_batch(self, mocker, cls_task_env):
def test_run_hpo_w_dataset_smaller_than_batch(self, mocker, cls_task_env, mock_thread):
cls_task_env.model = None
hpo_runner = HpoRunner(cls_task_env, 2, 10, "fake_path")
mock_run_hpo_loop = mocker.patch("otx.cli.utils.hpo.run_hpo_loop")
Expand All @@ -494,6 +499,8 @@ class TestTrainer:
def setup(self, tmp_dir):
self.weight_format = "epoch_{}.pth"
self.hpo_workdir = Path(tmp_dir) / "hpo_dir"
self.hpo_workdir.mkdir()
self.trial_id = "1"

@pytest.fixture
def tmp_dir(self):
Expand All @@ -519,6 +526,8 @@ def mock_task(self, mocker, tmp_dir):
fake_project_path.mkdir(parents=True)
for i in range(1, 5):
(fake_project_path / self.weight_format.format(i)).write_text("fake")
with (self.hpo_workdir / f"{self.trial_id}.json").open("w") as f:
json.dump({"id": self.trial_id, "score": {"1": 1, "2": 2, "3": 5, "4": 4}}, f)

mock_get_train_task = mocker.patch.object(TaskEnvironmentManager, "get_train_task")
mock_task = mocker.MagicMock()
Expand Down Expand Up @@ -552,8 +561,12 @@ def test_run(self, mocker, cls_template_path, mock_task, tmp_dir):
# check
mock_report_func.assert_called_once_with(0, 0, done=True) # finilize report
assert self.hpo_workdir.exists() # make a directory to copy weight
for i in range(1, 5): # check model weights are copied
assert (self.hpo_workdir / "weight" / trial_id / self.weight_format.format(i)).exists()
assert (
self.hpo_workdir / "weight" / trial_id / self.weight_format.format(3)
).exists() # check best weight exists
assert (
self.hpo_workdir / "weight" / trial_id / self.weight_format.format(4)
).exists() # check last weight exists

mock_task.train.assert_called() # check task.train() is called

Expand Down Expand Up @@ -589,44 +602,6 @@ def test_run_trial_already_done(self, mocker, cls_template_path, mock_task, tmp_
mock_report_func.assert_called_once_with(0, 0, done=True) # finilize report
mock_task.train.assert_not_called() # check task.train() is called

@e2e_pytest_unit
def test_delete_unused_model_weight(self, mocker, cls_template_path):
# prepare
trial0_weight_dir = self.hpo_workdir / "weight" / "0"
mocker.patch(
"otx.cli.utils.hpo.TaskManager.get_latest_weight", return_value=str(trial0_weight_dir / "latest.pth")
)
mocker.patch("otx.cli.utils.hpo.get_best_hpo_weight", return_value=str(trial0_weight_dir / "best.pth"))

self.hpo_workdir.mkdir()
(self.hpo_workdir / "0.json").touch()
for i in range(2):
weight_dir = self.hpo_workdir / "weight" / str(i)
weight_dir.mkdir(parents=True)
(weight_dir / "latest.pth").touch()
(weight_dir / "best.pth").touch()
(weight_dir / "unused.pth").touch()

# run
trainer = Trainer(
hp_config={"configuration": {"iterations": 10}, "id": "1"},
report_func=mocker.MagicMock(),
model_template=find_and_parse_model_template(cls_template_path),
data_roots=mocker.MagicMock(),
task_type=TaskType.CLASSIFICATION,
hpo_workdir=self.hpo_workdir,
initial_weight_name="fake",
metric="fake",
)
trainer._delete_unused_model_weight()

assert sorted([f.name for f in (self.hpo_workdir / "weight" / "0").iterdir()]) == sorted(
["latest.pth", "best.pth"]
)
assert sorted([f.name for f in (self.hpo_workdir / "weight" / "1").iterdir()]) == sorted(
["latest.pth", "best.pth", "unused.pth"]
)


class TestHpoCallback:
@e2e_pytest_unit
Expand Down

0 comments on commit d6554c2

Please sign in to comment.