Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support fixed features in Service API #2372

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,10 @@ def set_search_space(
wrap_error_message_in=CHOLESKY_ERROR_ANNOTATION,
)
def get_next_trial(
self, ttl_seconds: Optional[int] = None, force: bool = False
self,
ttl_seconds: Optional[int] = None,
force: bool = False,
fixed_features: Optional[FixedFeatures] = None,
) -> Tuple[TParameterization, int]:
"""
Generate trial with the next set of parameters to try in the iteration process.
Expand All @@ -508,6 +511,9 @@ def get_next_trial(
failed properly.
force: If set to True, this function will bypass the global stopping
strategy's decision and generate a new trial anyway.
fixed_features: A FixedFeatures object containing any
features that should be fixed at specified values during
generation.

Returns:
Tuple of trial parameterization, trial index
Expand All @@ -530,7 +536,10 @@ def get_next_trial(

try:
trial = self.experiment.new_trial(
generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds
generator_run=self._gen_new_generator_run(
fixed_features=fixed_features
),
ttl_seconds=ttl_seconds,
)
except MaxParallelismReachedException as e:
if self._early_stopping_strategy is not None:
Expand Down Expand Up @@ -580,7 +589,10 @@ def get_current_trial_generation_limit(self) -> Tuple[int, bool]:
return self.generation_strategy.current_generator_run_limit()

def get_next_trials(
self, max_trials: int, ttl_seconds: Optional[int] = None
self,
max_trials: int,
ttl_seconds: Optional[int] = None,
fixed_features: Optional[FixedFeatures] = None,
) -> Tuple[Dict[int, TParameterization], bool]:
"""Generate as many trials as currently possible.

Expand All @@ -597,6 +609,9 @@ def get_next_trials(
ttl_seconds: If specified, will consider the trial failed after this
many seconds. Used to detect dead trials that were not marked
failed properly.
fixed_features: A FixedFeatures object containing any
features that should be fixed at specified values during
generation.

Returns: two-item tuple of:
- mapping from trial indices to parameterizations in those trials,
Expand All @@ -616,7 +631,9 @@ def get_next_trials(
trials_dict = {}
for _ in range(max_trials):
try:
params, trial_index = self.get_next_trial(ttl_seconds=ttl_seconds)
params, trial_index = self.get_next_trial(
ttl_seconds=ttl_seconds, fixed_features=fixed_features
)
trials_dict[trial_index] = params
except OptimizationComplete as err:
logger.info(
Expand Down Expand Up @@ -1744,20 +1761,16 @@ def _save_generation_strategy_to_db_if_possible(
suppress_all_errors=suppress_all_errors,
)

def _get_last_completed_trial_index(self) -> int:
# infer last completed trial as the trial_index to use
# TODO: use Experiment.completed_trials once D46484953 lands.
completed_indices = [
t.index for t in self.experiment.trials_by_status[TrialStatus.COMPLETED]
]
completed_indices.append(0) # handle case of no completed trials
return max(completed_indices)

def _gen_new_generator_run(self, n: int = 1) -> GeneratorRun:
def _gen_new_generator_run(
self, n: int = 1, fixed_features: Optional[FixedFeatures] = None
) -> GeneratorRun:
"""Generate new generator run for this experiment.

Args:
n: Number of arms to generate.
fixed_features: A FixedFeatures object containing any
features that should be fixed at specified values during
generation.
"""
# If random seed is not set for this optimization, context manager does
# nothing; otherwise, it sets the random seed for torch, but only for the
Expand All @@ -1766,10 +1779,12 @@ def _gen_new_generator_run(self, n: int = 1) -> GeneratorRun:
# serious negative impact on the performance of the models that employ
# stochasticity.

fixed_feats = InstantiationBase.make_fixed_observation_features(
fixed_features=FixedFeatures(
parameters={}, trial_index=self._get_last_completed_trial_index()
fixed_feats = (
InstantiationBase.make_fixed_observation_features(
fixed_features=fixed_features
)
if fixed_features
else None
)
with manual_seed(seed=self._random_seed):
return not_none(self.generation_strategy).gen(
Expand Down
20 changes: 15 additions & 5 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
observed_pareto,
predicted_pareto,
)
from ax.service.utils.instantiation import FixedFeatures
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.storage.sqa_store.decoder import Decoder
from ax.storage.sqa_store.encoder import Encoder
Expand Down Expand Up @@ -2847,11 +2848,20 @@ def test_gen_fixed_features(self) -> None:
with mock.patch.object(
GenerationStrategy, "gen", wraps=ax_client.generation_strategy.gen
) as mock_gen:
params, idx = ax_client.get_next_trial()
call_kwargs = mock_gen.call_args_list[0][1]
ff = call_kwargs["fixed_features"]
self.assertEqual(ff.parameters, {})
self.assertEqual(ff.trial_index, 0)
with self.subTest("fixed_features is None"):
params, idx = ax_client.get_next_trial()
call_kwargs = mock_gen.call_args_list[0][1]
ff = call_kwargs["fixed_features"]
self.assertIsNone(ff)
with self.subTest("fixed_features is set"):
fixed_features = FixedFeatures(
parameters={"x": 0.0, "y": 5.0}, trial_index=0
)
params, idx = ax_client.get_next_trial(fixed_features=fixed_features)
call_kwargs = mock_gen.call_args_list[1][1]
ff = call_kwargs["fixed_features"]
self.assertEqual(ff.parameters, fixed_features.parameters)
self.assertEqual(ff.trial_index, 0)

def test_get_optimization_trace_discard_infeasible_trials(self) -> None:
ax_client = AxClient()
Expand Down