Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract out attribute copy during clone_to #2288

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from copy import deepcopy
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
Expand Down Expand Up @@ -865,3 +866,20 @@ def _validate_can_attach_data(self) -> None:
f"Trial {self.index} has been marked {self.status.name}, so it "
"no longer expects data."
)

def _update_trial_attrs_on_clone(
self,
new_trial: BaseTrial,
) -> None:
"""Updates attributes of the trial that are not copied over when cloning
a trial.

Args:
new_trial: The cloned trial.
new_experiment: The experiment that the cloned trial belongs to.
new_status: The new status of the cloned trial.
"""
new_trial._run_metadata = deepcopy(self._run_metadata)
new_trial._stop_metadata = deepcopy(self._stop_metadata)
new_trial._num_arms_created = self._num_arms_created
new_trial.runner = self._runner.clone() if self._runner else None
6 changes: 1 addition & 5 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import warnings

from collections import defaultdict, OrderedDict
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -603,10 +602,7 @@ def clone_to(
self._status_quo.clone(),
weight=sq_weight,
)
new_trial.runner = self._runner.clone() if self._runner else None
new_trial._run_metadata = deepcopy(self._run_metadata)
new_trial._stop_metadata = deepcopy(self._stop_metadata)
new_trial._num_arms_created = self._num_arms_created
self._update_trial_attrs_on_clone(new_trial=new_trial)
return new_trial

def attach_batch_trial_data(
Expand Down
8 changes: 1 addition & 7 deletions ax/core/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from __future__ import annotations

from copy import deepcopy

from functools import partial

from logging import Logger
Expand Down Expand Up @@ -351,9 +349,5 @@ def clone_to(
)
if self.generator_run is not None:
new_trial.add_generator_run(self.generator_run.clone())
new_trial._run_metadata = deepcopy(self._run_metadata)
new_trial._stop_metadata = deepcopy(self._stop_metadata)
new_trial._num_arms_created = self._num_arms_created
new_trial.runner = self._runner.clone() if self._runner else None

self._update_trial_attrs_on_clone(new_trial=new_trial)
return new_trial
Loading