Skip to content

Commit

Permalink
Make sure random seed persists beyond storage (#2671)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2671

When the random seed is not specified in `model_kwargs`, it is set to a fixed value in `__init__`. If we do not store this generated `seed` and reload the experiment / GS, we end up continuing the random generation using a different seed. Storing `seed` in `model_state` will ensure it is stored (in last GR) and reused when the GS is reloaded.

Model state is extracted from last GR in `GS._fit_current_model`: https://www.internalfb.com/code/fbsource/[4d9fa225216d]/fbcode/ax/modelbridge/generation_strategy.py?lines=856
This gets passed down to `ModelSpec.fit` as `**model_kwargs`, which takes precedence over `ModelSpec.model_kwargs`: https://www.internalfb.com/code/fbsource/[4d9fa225216d]/fbcode/ax/modelbridge/model_spec.py?lines=131, which will ensure any `"seed": None` kwarg will get overwritten by the generated seed from last GR.

Also removed the `if not self.deduplicate` block from `_get_state`. Whether we save the state should not depend on `deduplicate`, as it is taken into account during sampling. Not saving the seed would just lead to additional unpredictability in the generator behavior.

Reviewed By: bernardbeckerman

Differential Revision: D61479553

fbshipit-source-id: 456e3da987b872895b50bb73cfb60b7011e38c3a
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Aug 20, 2024
1 parent e157b1f commit 40c8417
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 24 deletions.
45 changes: 27 additions & 18 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from unittest import mock
from unittest.mock import MagicMock, patch

import numpy as np
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -448,6 +449,7 @@ def test_do_not_enforce_min_observations(self) -> None:
def test_sobol_GPEI_strategy(self) -> None:
exp = get_branin_experiment()
self.assertEqual(self.sobol_GPEI_GS.name, "Sobol+GPEI")
expected_seed = None
for i in range(7):
g = self.sobol_GPEI_GS.gen(exp)
exp.new_trial(generator_run=g).run()
Expand All @@ -470,7 +472,7 @@ def test_sobol_GPEI_strategy(self) -> None:
self.assertEqual(
mkw,
{
"seed": None,
"seed": expected_seed,
"deduplicate": True,
"init_position": i,
"scramble": True,
Expand All @@ -491,14 +493,17 @@ def test_sobol_GPEI_strategy(self) -> None:
"fit_on_init": True,
},
)
ms = g._model_state_after_gen
self.assertIsNotNone(ms)
# Generated points are randomized, so just checking that they are there.
self.assertIn("generated_points", ms)
# Remove the randomized generated points to compare the rest.
ms = ms.copy()
del ms["generated_points"]
self.assertEqual(ms, {"init_position": i + 1})
ms = not_none(g._model_state_after_gen).copy()
# Compare the model state to Sobol state.
sobol_model = not_none(self.sobol_GPEI_GS.model).model
self.assertTrue(
np.array_equal(
ms.pop("generated_points"), sobol_model.generated_points
)
)
# Replace expected seed with the one generated in __init__.
expected_seed = sobol_model.seed
self.assertEqual(ms, {"init_position": i + 1, "seed": expected_seed})
# Check completeness error message when GS should be done.
with self.assertRaises(GenerationStrategyCompleted):
g = self.sobol_GPEI_GS.gen(exp)
Expand Down Expand Up @@ -1212,6 +1217,7 @@ def test_gs_with_generation_nodes(self) -> None:
"Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes"
exp = get_branin_experiment()
self.assertEqual(self.sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes")
expected_seed = None

for i in range(7):
g = self.sobol_GPEI_GS_nodes.gen(exp)
Expand All @@ -1235,7 +1241,7 @@ def test_gs_with_generation_nodes(self) -> None:
self.assertEqual(
mkw,
{
"seed": None,
"seed": expected_seed,
"deduplicate": True,
"init_position": i,
"scramble": True,
Expand All @@ -1256,14 +1262,17 @@ def test_gs_with_generation_nodes(self) -> None:
"fit_on_init": True,
},
)
ms = g._model_state_after_gen
self.assertIsNotNone(ms)
# Generated points are randomized, so just checking that they are there.
self.assertIn("generated_points", ms)
# Remove the randomized generated points to compare the rest.
ms = ms.copy()
del ms["generated_points"]
self.assertEqual(ms, {"init_position": i + 1})
ms = not_none(g._model_state_after_gen).copy()
# Compare the model state to Sobol state.
sobol_model = not_none(self.sobol_GPEI_GS_nodes.model).model
self.assertTrue(
np.array_equal(
ms.pop("generated_points"), sobol_model.generated_points
)
)
# Replace expected seed with the one generated in __init__.
expected_seed = sobol_model.seed
self.assertEqual(ms, {"init_position": i + 1, "seed": expected_seed})

def test_clone_reset_nodes(self) -> None:
"""Test that node-based generation strategy is appropriately reset
Expand Down
4 changes: 1 addition & 3 deletions ax/models/random/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ def gen(
@copy_doc(Model._get_state)
def _get_state(self) -> dict[str, Any]:
state = super()._get_state()
if not self.deduplicate:
return state
state.update({"generated_points": self.generated_points})
state.update({"seed": self.seed, "generated_points": self.generated_points})
return state

def _gen_unconstrained(
Expand Down
6 changes: 6 additions & 0 deletions ax/models/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def test_seed(self) -> None:
# With no seed.
self.assertIsInstance(self.random_model.seed, int)

def test_state(self) -> None:
for model in (self.random_model, RandomModel(seed=5)):
state = model._get_state()
self.assertEqual(state["seed"], model.seed)
self.assertEqual(state["generated_points"], model.generated_points)

def test_RandomModelGenSamples(self) -> None:
with self.assertRaises(NotImplementedError):
self.random_model._gen_samples(n=1, tunable_d=1)
Expand Down
17 changes: 15 additions & 2 deletions ax/models/tests/test_sobol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ def test_SobolGeneratorAllTunable(self) -> None:
self.assertTrue(np.all(generated_points >= np_bounds[:, 0]))
self.assertTrue(np.all(generated_points <= np_bounds[:, 1]))
self.assertTrue(np.all(weights == 1.0))
self.assertEqual(generator._get_state().get("init_position"), 3)
state = generator._get_state()
self.assertEqual(state.get("init_position"), 3)
self.assertEqual(state.get("seed"), generator.seed)
self.assertTrue(
np.array_equal(
state.get("generated_points"),
generator.generated_points,
)
)

def test_SobolGeneratorFixedSpace(self) -> None:
generator = SobolGenerator(seed=0, deduplicate=False)
Expand Down Expand Up @@ -308,4 +316,9 @@ def test_SobolGeneratorDedupe(self) -> None:
rounding_func=lambda x: x,
)
self.assertEqual(len(generated_points), 1)
self.assertIsNotNone(generator._get_state().get("generated_points"))
self.assertTrue(
np.array_equal(
generator._get_state().get("generated_points"),
generator.generated_points,
)
)
3 changes: 2 additions & 1 deletion ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,10 @@ def test_ExperimentSaveAndLoadReducedState(
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 9)
# This has seed, generated points and init position.
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
self.assertEqual(len(ms), 2)
self.assertEqual(len(ms), 3)
gm = gr._gen_metadata
self.assertIsNotNone(gm)
self.assertEqual(len(gm), 0)
Expand Down

0 comments on commit 40c8417

Please sign in to comment.