Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine HPO #2175

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix a bug where HPO does not work correctly in multi-GPU training
eunwoosh committed May 23, 2023

Unverified

This user has not yet uploaded their public signing key.
commit 620bcf6789967ee23e567e29a9657990e26e3543
6 changes: 6 additions & 0 deletions otx/cli/utils/hpo.py
Original file line number Diff line number Diff line change
@@ -679,6 +679,12 @@ def run(self):
need_to_save_initial_weight = False
resume_weight_path = self._get_resume_weight_path()
if resume_weight_path is not None:
ret = re.search(r"(\d+)\.pth", resume_weight_path)
if ret is not None:
resume_epoch = int(ret.group(1))
if self._epoch <= resume_epoch: # given epoch is already done
self._report_func(0, 0, done=True)
return
environment.resume_model_weight(resume_weight_path, dataset)
else:
initial_weight = self._load_fixed_initial_weight()
18 changes: 8 additions & 10 deletions otx/hpo/hpo_base.py
Original file line number Diff line number Diff line change
@@ -183,6 +183,7 @@ def __init__(self, trial_id: Any, configuration: Dict, train_environment: Option
self._train_environment = train_environment
self._iteration = None
self.status: TrialStatus = TrialStatus.READY
self._done = False

@property
def id(self):
@@ -204,6 +205,8 @@ def iteration(self, val):
"""Setter for iteration."""
check_positive(val, "iteration")
self._iteration = val
if self.get_progress() < val:
self._done = False

@property
def train_environment(self):
@@ -279,21 +282,16 @@ def save_results(self, save_path: str):
json.dump(results, f)

def finalize(self):
"""Let the trial know that training is done.

If the trial isn't trained up to the given resource, then make it pretend to have trained up to that resource.
"""
if self.get_progress() < self.iteration:
best_score = self.get_best_score()
if best_score is None:
raise RuntimeError(f"Although {self.id} trial doesn't report any score but it's done")
self.register_score(best_score, self.iteration)
"""Set done as True."""
if not self.score:
raise RuntimeError(f"Trial{self.id} didn't report any score but tries to be done.")
self._done = True

def is_done(self):
"""Check the trial is done."""
if self.iteration is None:
raise ValueError("iteration isn't set yet.")
return self.get_progress() >= self.iteration
return self._done or self.get_progress() >= self.iteration


class TrialStatus(IntEnum):
5 changes: 3 additions & 2 deletions otx/hpo/hpo_runner.py
Original file line number Diff line number Diff line change
@@ -163,9 +163,10 @@ def _get_uid(self) -> int:

def _terminate_all_running_processes(self):
for trial in self._running_trials.values():
trial.queue.close()
process = trial.process
if process.is_alive():
logger.warning(f"Kill child process {process.pid}")
logger.info(f"Kill child process {process.pid}")
process.kill()

def _terminate_signal_handler(self, signum, _frame):
@@ -206,7 +207,7 @@ def _report_score(
try:
trial_status = recv_queue.get(timeout=3)
except queue.Empty:
pass
return TrialStatus.RUNNING

while not recv_queue.empty():
trial_status = recv_queue.get_nowait()
4 changes: 3 additions & 1 deletion otx/hpo/hyperband.py
Original file line number Diff line number Diff line change
@@ -53,6 +53,7 @@ def __init__(self, trial_id: Any, configuration: Dict, train_environment: Option
super().__init__(trial_id, configuration, train_environment)
self._rung: Optional[int] = None
self._bracket: Optional[int] = None
self.estimating_max_resource: bool = False

@property
def rung(self):
@@ -708,6 +709,7 @@ def _make_trial_to_estimate_resource(self) -> AshaTrial:
if len(self._trials) == 1: # first trial to estimate
trial.bracket = 0
trial.iteration = self.num_full_iterations
trial.estimating_max_resource = True
elif self._minimum_resource is not None:
trial.iteration = self._minimum_resource
else:
@@ -917,7 +919,7 @@ def report_score(
"""
trial = self._trials[trial_id]
if done:
if self.maximum_resource is None:
if self.maximum_resource is None and trial.estimating_max_resource:
self.maximum_resource = trial.get_progress()
self.num_full_iterations = self.maximum_resource
if not self._need_to_find_resource_value():
121 changes: 82 additions & 39 deletions tests/unit/cli/utils/test_hpo.py
Original file line number Diff line number Diff line change
@@ -476,6 +476,16 @@ def test_run_hpo_w_dataset_smaller_than_batch(self, mocker, cls_task_env):


class TestTrainer:
@pytest.fixture(autouse=True)
def setup(self, tmp_dir):
    """Attach per-test attributes: the HPO working directory and the checkpoint filename pattern."""
    self.hpo_workdir = Path(tmp_dir) / "hpo_dir"
    self.weight_format = "epoch_{}.pth"

@pytest.fixture
def tmp_dir(self):
    """Yield a temporary directory path that is removed once the test finishes."""
    tmp = TemporaryDirectory()
    try:
        yield tmp.name
    finally:
        tmp.cleanup()

@e2e_pytest_unit
def test_init(self, mocker, cls_template_path):
Trainer(
@@ -489,49 +499,82 @@ def test_init(self, mocker, cls_template_path):
metric="fake",
)

@pytest.fixture
def mock_task(self, mocker, tmp_dir):
    """Patch TaskEnvironmentManager.get_train_task to return a mocked task.

    Creates a fake project directory populated with dummy checkpoint files
    (epoch_1.pth .. epoch_4.pth, per the pattern set in setup()) and points the
    mocked task's project_path at it, so code that copies checkpoints has
    files to find.
    """
    # NOTE(review): "fake_proejct" is a typo for "fake_project"; harmless since
    # the name is only used inside this fixture.
    fake_project_path = Path(tmp_dir) / "fake_proejct"
    fake_project_path.mkdir(parents=True)
    for i in range(1, 5):
        # dummy weight files matching self.weight_format
        (fake_project_path / self.weight_format.format(i)).write_text("fake")

    mock_get_train_task = mocker.patch.object(TaskEnvironmentManager, "get_train_task")
    mock_task = mocker.MagicMock()
    mock_task.project_path = str(fake_project_path)
    mock_get_train_task.return_value = mock_task

    return mock_task

@e2e_pytest_unit
def test_run(self, mocker, cls_template_path):
with TemporaryDirectory() as tmp_dir:
# prepare
trial_id = "1"
weight_format = "epoch_{}.pth"
hpo_workdir = Path(tmp_dir) / "hpo_dir"
fake_project_path = Path(tmp_dir) / "fake_proejct"
fake_project_path.mkdir(parents=True)
for i in range(1, 5):
(fake_project_path / weight_format.format(i)).write_text("fake")

mock_get_train_task = mocker.patch.object(TaskEnvironmentManager, "get_train_task")
mock_task = mocker.MagicMock()
mock_task.project_path = str(fake_project_path)
mock_get_train_task.return_value = mock_task

mock_report_func = mocker.MagicMock()

mocker.patch("otx.cli.utils.hpo.get_dataset_adapter")
mocker.patch("otx.cli.utils.hpo.HpoDataset")

# run
trainer = Trainer(
hp_config={"configuration": {"iterations": 10}, "id": trial_id},
report_func=mock_report_func,
model_template=find_and_parse_model_template(cls_template_path),
data_roots=mocker.MagicMock(),
task_type=TaskType.CLASSIFICATION,
hpo_workdir=hpo_workdir,
initial_weight_name="fake",
metric="fake",
)
trainer.run()

# check
mock_report_func.assert_called_once_with(0, 0, done=True) # finilize report
assert hpo_workdir.exists() # make a directory to copy weight
for i in range(1, 5): # check model weights are copied
assert (hpo_workdir / "weight" / trial_id / weight_format.format(i)).exists()
def test_run(self, mocker, cls_template_path, mock_task, tmp_dir):
    """Trainer.run() should train, report done, and copy checkpoints into the HPO workdir."""
    # prepare
    trial_id = "1"
    mock_report_func = mocker.MagicMock()

    mocker.patch("otx.cli.utils.hpo.get_dataset_adapter")
    mocker.patch("otx.cli.utils.hpo.HpoDataset")

    # run
    trainer = Trainer(
        hp_config={"configuration": {"iterations": 10}, "id": trial_id},
        report_func=mock_report_func,
        model_template=find_and_parse_model_template(cls_template_path),
        data_roots=mocker.MagicMock(),
        task_type=TaskType.CLASSIFICATION,
        hpo_workdir=self.hpo_workdir,
        initial_weight_name="fake",
        metric="fake",
    )
    trainer.run()

    # check
    mock_report_func.assert_called_once_with(0, 0, done=True)  # finalize report
    assert self.hpo_workdir.exists()  # a directory to copy weights into was made
    for i in range(1, 5):  # check model weights are copied
        assert (self.hpo_workdir / "weight" / trial_id / self.weight_format.format(i)).exists()

    mock_task.train.assert_called()  # check task.train() is called

@e2e_pytest_unit
def test_run_trial_already_done(self, mocker, cls_template_path, mock_task, tmp_dir):
    """Test the case where the trial was already trained for the requested epochs.

    If a resume checkpoint at or beyond the requested epoch already exists,
    Trainer.run() should report done immediately without training again.
    """
    # prepare
    trial_id = "1"
    epoch_to_run = 10
    weight_dir = self.hpo_workdir / "weight" / trial_id
    # prepare a weight trained past the given epoch
    weight_dir.mkdir(parents=True)
    (weight_dir / self.weight_format.format(epoch_to_run+1)).touch()
    mock_report_func = mocker.MagicMock()

    mocker.patch("otx.cli.utils.hpo.get_dataset_adapter")
    mocker.patch("otx.cli.utils.hpo.HpoDataset")

    # run
    trainer = Trainer(
        hp_config={"configuration": {"iterations": epoch_to_run}, "id": trial_id},
        report_func=mock_report_func,
        model_template=find_and_parse_model_template(cls_template_path),
        data_roots=mocker.MagicMock(),
        task_type=TaskType.CLASSIFICATION,
        hpo_workdir=self.hpo_workdir,
        initial_weight_name="fake",
        metric="fake",
    )
    trainer.run()

    # check
    mock_report_func.assert_called_once_with(0, 0, done=True)  # finalize report
    mock_task.train.assert_not_called()  # task.train() should NOT be called


class TestHpoCallback:
@e2e_pytest_unit
2 changes: 1 addition & 1 deletion tests/unit/hpo/test_hpo_base.py
Original file line number Diff line number Diff line change
@@ -139,7 +139,7 @@ def test_finalize(self, trial):
trial.iteration = 10
trial.register_score(10, 5)
trial.finalize()
assert trial.iteration == trial.get_progress()
assert trial.is_done()

@e2e_pytest_component
def test_finalize_without_registered_score(self, trial):
3 changes: 1 addition & 2 deletions tests/unit/hpo/test_hyperband.py
Original file line number Diff line number Diff line change
@@ -715,7 +715,7 @@ def test_report_score_trial_done(self, hyper_band):
trial = hyper_band.get_next_sample()
hyper_band.report_score(100, 0.1, trial.id)
hyper_band.report_score(0, 0, trial.id, done=True)
assert trial.get_progress() == trial.iteration
assert trial.is_done()

@e2e_pytest_component
def test_get_best_config(self, hyper_band):
@@ -1051,7 +1051,6 @@ def test_without_minimum_maximum_resource(self, good_hyperband_args, num_trial_t
if hyper_band.report_score(score=1, resource=iter, trial_id=trial.id) == TrialStatus.STOP:
break

first_trial = trials_to_estimate[0]
hyper_band.report_score(score=1, resource=max_validation, trial_id=first_trial.id)
hyper_band.report_score(score=0, resource=0, trial_id=first_trial.id, done=True)