From 68273fe33e90f0ccbc5166d8638c5932df6268ce Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Wed, 17 Apr 2024 13:59:19 -0700 Subject: [PATCH] Support fixed features in Service API (#2372) Summary: Add the possibility of specifying some `FixedFeatures` as `fixed_features` in `AxClient.get_next_trial` and `AxClient.get_next_trials` which is currently only possible with the developer API. Reviewed By: saitcakmak Differential Revision: D56068035 --- ax/service/ax_client.py | 47 ++++++++++++++++++------------ ax/service/tests/test_ax_client.py | 19 ++++++++---- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 7d1a9524a34..f79a64b8ec4 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -495,7 +495,10 @@ def set_search_space( wrap_error_message_in=CHOLESKY_ERROR_ANNOTATION, ) def get_next_trial( - self, ttl_seconds: Optional[int] = None, force: bool = False + self, + ttl_seconds: Optional[int] = None, + force: bool = False, + fixed_features: Optional[FixedFeatures] = None, ) -> Tuple[TParameterization, int]: """ Generate trial with the next set of parameters to try in the iteration process. @@ -508,6 +511,9 @@ def get_next_trial( failed properly. force: If set to True, this function will bypass the global stopping strategy's decision and generate a new trial anyway. + fixed_features: A FixedFeatures object containing any + features that should be fixed at specified values during + generation. Returns: Tuple of trial parameterization, trial index @@ -530,7 +536,10 @@ def get_next_trial( try: trial = self.experiment.new_trial( - generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds + generator_run=self._gen_new_generator_run( + fixed_features=fixed_features + ), + ttl_seconds=ttl_seconds, ) except MaxParallelismReachedException as e: if self._early_stopping_strategy is not None: @@ -580,7 +589,10 @@ def get_current_trial_generation_limit(self) -> Tuple[int, bool]: return self.generation_strategy.current_generator_run_limit() def get_next_trials( - self, max_trials: int, ttl_seconds: Optional[int] = None + self, + max_trials: int, + ttl_seconds: Optional[int] = None, + fixed_features: Optional[FixedFeatures] = None, ) -> Tuple[Dict[int, TParameterization], bool]: """Generate as many trials as currently possible. @@ -597,6 +609,9 @@ def get_next_trials( ttl_seconds: If specified, will consider the trial failed after this many seconds. Used to detect dead trials that were not marked failed properly. + fixed_features: A FixedFeatures object containing any + features that should be fixed at specified values during + generation. Returns: two-item tuple of: - mapping from trial indices to parameterizations in those trials, @@ -616,7 +631,9 @@ def get_next_trials( trials_dict = {} for _ in range(max_trials): try: - params, trial_index = self.get_next_trial(ttl_seconds=ttl_seconds) + params, trial_index = self.get_next_trial( + ttl_seconds=ttl_seconds, fixed_features=fixed_features + ) trials_dict[trial_index] = params except OptimizationComplete as err: logger.info( @@ -1744,20 +1761,16 @@ def _save_generation_strategy_to_db_if_possible( suppress_all_errors=suppress_all_errors, ) - def _get_last_completed_trial_index(self) -> int: - # infer last completed trial as the trial_index to use - # TODO: use Experiment.completed_trials once D46484953 lands. - completed_indices = [ - t.index for t in self.experiment.trials_by_status[TrialStatus.COMPLETED] - ] - completed_indices.append(0) # handle case of no completed trials - return max(completed_indices) - - def _gen_new_generator_run(self, n: int = 1) -> GeneratorRun: + def _gen_new_generator_run( + self, n: int = 1, fixed_features: Optional[FixedFeatures] = None + ) -> GeneratorRun: """Generate new generator run for this experiment. Args: n: Number of arms to generate. + fixed_features: A FixedFeatures object containing any + features that should be fixed at specified values during + generation. """ # If random seed is not set for this optimization, context manager does # nothing; otherwise, it sets the random seed for torch, but only for the @@ -1767,10 +1780,8 @@ def _gen_new_generator_run(self, n: int = 1) -> GeneratorRun: # stochasticity. fixed_feats = InstantiationBase.make_fixed_observation_features( - fixed_features=FixedFeatures( - parameters={}, trial_index=self._get_last_completed_trial_index() - ) - ) + fixed_features=fixed_features + ) if fixed_features else None with manual_seed(seed=self._random_seed): return not_none(self.generation_strategy).gen( experiment=self.experiment, diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index d232efafed8..dc301c1d5db 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -65,6 +65,7 @@ observed_pareto, predicted_pareto, ) +from ax.service.utils.instantiation import FixedFeatures from ax.storage.sqa_store.db import init_test_engine_and_session_factory from ax.storage.sqa_store.decoder import Decoder from ax.storage.sqa_store.encoder import Encoder @@ -2847,11 +2848,19 @@ def test_gen_fixed_features(self) -> None: with mock.patch.object( GenerationStrategy, "gen", wraps=ax_client.generation_strategy.gen ) as mock_gen: - params, idx = ax_client.get_next_trial() - call_kwargs = mock_gen.call_args_list[0][1] - ff = call_kwargs["fixed_features"] - self.assertEqual(ff.parameters, {}) - self.assertEqual(ff.trial_index, 0) + with self.subTest("fixed_features is None"): + params, idx = ax_client.get_next_trial() + call_kwargs = mock_gen.call_args_list[0][1] + ff = call_kwargs["fixed_features"] + self.assertEqual(ff.parameters, {}) + self.assertEqual(ff.trial_index, 0) + with self.subTest("fixed_features is set"): + fixed_features = FixedFeatures(parameters={"x": 0.0, "y": 5.0}) + params, idx = ax_client.get_next_trial(fixed_features=fixed_features) + call_kwargs = mock_gen.call_args_list[1][1] + ff = call_kwargs["fixed_features"] + self.assertEqual(ff.parameters, fixed_features.parameters) + self.assertEqual(ff.trial_index, 0) def test_get_optimization_trace_discard_infeasible_trials(self) -> None: ax_client = AxClient()