Add init_position argument to UniformGenerator (facebook#2686)
Summary:
Pull Request resolved: facebook#2686

`SobolGenerator` uses `init_position` to ensure that when the model is reconstructed, it resumes candidate generation from where it left off (rather than starting from the beginning of the sequence). Without this, when `deduplicate=False`, a reconstructed model would re-generate points it had already produced, so candidate generation behavior would depend on how often the model was reconstructed. This is undesirable: we want the model to resume generation rather than repeat it from scratch.

Prior to this diff, `UniformGenerator` did not have `init_position` and so had exactly this issue. We fix it here.
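
As a minimal illustration of the failure mode and of the fast-forward fix, here is a sketch in plain numpy (illustrative only, not Ax code; all names are hypothetical):

import numpy as np

# Draw a batch, as a generator would before being torn down.
rs = np.random.RandomState(seed=0)
first_batch = rs.uniform(size=(5, 2))  # consumes 10 scalar draws

# Naive reconstruction from the seed alone repeats the same points.
stale = np.random.RandomState(seed=0)
assert np.allclose(stale.uniform(size=(5, 2)), first_batch)

# Fast-forwarding past the already-consumed draws resumes the sequence.
resumed = np.random.RandomState(seed=0)
resumed.uniform(size=10)  # discard what was generated previously
assert np.allclose(resumed.uniform(size=(5, 2)), rs.uniform(size=(5, 2)))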

Reviewed By: Balandat

Differential Revision: D61622058
saitcakmak authored and facebook-github-bot committed Aug 21, 2024
1 parent 84b307d commit 353054b
Showing 4 changed files with 71 additions and 73 deletions.
ax/models/random/base.py (14 changes: 13 additions & 1 deletion)
@@ -48,6 +48,10 @@ class RandomModel(Model):
of the model will not return the same point twice. This flag
is used in rejection sampling.
seed: An optional seed value for scrambling.
init_position: The initial state of the generator. This is the number
of samples to fast-forward before generating new samples.
Used to ensure that the re-loaded generator will continue generating
from the same sequence rather than starting from scratch.
generated_points: A set of previously generated points to use
for deduplication. These should be provided in the raw transformed
space the model operates in.
@@ -59,6 +63,7 @@ def __init__(
self,
deduplicate: bool = True,
seed: Optional[int] = None,
init_position: int = 0,
generated_points: Optional[np.ndarray] = None,
fallback_to_sample_polytope: bool = False,
) -> None:
@@ -69,6 +74,7 @@ def __init__(
if seed is not None
else checked_cast(int, torch.randint(high=100_000, size=(1,)).item())
)
self.init_position = init_position
# Used for deduplication.
self.generated_points = generated_points
self.fallback_to_sample_polytope = fallback_to_sample_polytope
@@ -180,7 +186,13 @@ def gen(
@copy_doc(Model._get_state)
def _get_state(self) -> dict[str, Any]:
state = super()._get_state()
state.update({"seed": self.seed, "generated_points": self.generated_points})
state.update(
{
"seed": self.seed,
"init_position": self.init_position,
"generated_points": self.generated_points,
}
)
return state

def _gen_unconstrained(
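
With `init_position` now part of the shared `RandomModel` state, a generator reconstructed from `_get_state()` resumes mid-stream instead of replaying the sequence. A sketch of the round-trip, mirroring `test_with_reloaded_state` in the test diff below:

import numpy as np
from ax.models.random.uniform import UniformGenerator

bounds = [(0.0, 1.0)] * 3
original = UniformGenerator(seed=0)
original.gen(n=3, bounds=bounds, rounding_func=lambda x: x)  # advance the state

# _get_state() now carries seed, init_position, and generated_points.
clone = UniformGenerator(**original._get_state())
a, _ = original.gen(n=3, bounds=bounds, rounding_func=lambda x: x)
b, _ = clone.gen(n=3, bounds=bounds, rounding_func=lambda x: x)
assert np.allclose(a, b)
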
ax/models/random/sobol.py (16 changes: 3 additions & 13 deletions)
@@ -6,15 +6,13 @@

# pyre-strict

from typing import Any, Callable, Optional
from typing import Callable, Optional

import numpy as np
import torch
from ax.models.base import Model
from ax.models.model_utils import tunable_feature_indices
from ax.models.random.base import RandomModel
from ax.models.types import TConfig
from ax.utils.common.docutils import copy_doc
from ax.utils.common.typeutils import not_none
from torch.quasirandom import SobolEngine

@@ -26,17 +24,15 @@ class SobolGenerator(RandomModel):
the fit or predict methods.
Attributes:
init_position: The initial state of the Sobol generator.
Starts at 0 by default.
scramble: If True, permutes the parameter values among
the elements of the Sobol sequence. Default is True.
See base `RandomModel` for a description of remaining attributes.
"""

def __init__(
self,
seed: Optional[int] = None,
deduplicate: bool = True,
seed: Optional[int] = None,
init_position: int = 0,
scramble: bool = True,
generated_points: Optional[np.ndarray] = None,
@@ -45,10 +41,10 @@ def __init__(
super().__init__(
deduplicate=deduplicate,
seed=seed,
init_position=init_position,
generated_points=generated_points,
fallback_to_sample_polytope=fallback_to_sample_polytope,
)
self.init_position = init_position
self.scramble = scramble
# Initialize engine on gen.
self._engine: Optional[SobolEngine] = None
@@ -121,12 +117,6 @@ def gen(
self.init_position = not_none(self.engine).num_generated
return (points, weights)

@copy_doc(Model._get_state)
def _get_state(self) -> dict[str, Any]:
state = super()._get_state()
state.update({"init_position": self.init_position})
return state

def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
"""Generate n samples.
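
For reference, `SobolGenerator` gets the same resume behavior from the engine itself: `torch.quasirandom.SobolEngine` tracks `num_generated` (which `gen` stores into `init_position` above), and a rebuilt engine is fast-forwarded by that count. A standalone sketch of the mechanism, outside Ax:

import torch
from torch.quasirandom import SobolEngine

engine = SobolEngine(dimension=3, scramble=True, seed=0)
engine.draw(4)  # engine.num_generated is now 4

# Rebuild with the same seed and skip what was already drawn.
resumed = SobolEngine(dimension=3, scramble=True, seed=0)
resumed.fast_forward(engine.num_generated)
assert torch.allclose(resumed.draw(4), engine.draw(4))
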
ax/models/random/uniform.py (9 changes: 7 additions & 2 deletions)
@@ -10,7 +10,6 @@

import numpy as np
from ax.models.random.base import RandomModel
from scipy.stats import uniform


class UniformGenerator(RandomModel):
@@ -26,16 +25,21 @@ def __init__(
self,
deduplicate: bool = True,
seed: Optional[int] = None,
init_position: int = 0,
generated_points: Optional[np.ndarray] = None,
fallback_to_sample_polytope: bool = False,
) -> None:
super().__init__(
deduplicate=deduplicate,
seed=seed,
init_position=init_position,
generated_points=generated_points,
fallback_to_sample_polytope=fallback_to_sample_polytope,
)
self._rs = np.random.RandomState(seed=self.seed)
if self.init_position > 0:
# Fast-forward the random state by generating & discarding samples.
self._rs.uniform(size=(self.init_position))

def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
"""Generate samples from the scipy uniform distribution.
@@ -48,4 +52,5 @@ def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
samples: An (n x d) array of random points.
"""
return uniform.rvs(size=(n, tunable_d), random_state=self._rs) # pyre-ignore
self.init_position += n * tunable_d
return self._rs.uniform(size=(n, tunable_d))
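
The fast-forward in `__init__` discards `init_position` scalar draws, and `_gen_samples` advances the counter by `n * tunable_d`, one scalar per matrix entry. This works because `RandomState.uniform` fills an (n x d) array with the next n * d values of the underlying stream in row-major order; a quick check of that assumption (numbers are arbitrary):

import numpy as np

k, n, d = 6, 3, 2  # pretend 6 scalar draws were already consumed
all_at_once = np.random.RandomState(0).uniform(size=k + n * d)

rs = np.random.RandomState(0)
rs.uniform(size=k)  # the fast-forward step in UniformGenerator.__init__
assert np.allclose(rs.uniform(size=(n, d)), all_at_once[k:].reshape(n, d))
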
ax/models/tests/test_uniform.py (105 changes: 48 additions & 57 deletions)
@@ -16,36 +16,34 @@
class UniformGeneratorTest(TestCase):
def setUp(self) -> None:
super().setUp()
self.tunable_param_bounds = (0, 1)
self.fixed_param_bounds = (1, 100)
self.tunable_param_bounds = (0.0, 1.0)
self.fixed_param_bounds = (1.0, 100.0)
self.seed = 0
self.expected_points = np.array(
[
[0.5488135, 0.71518937, 0.60276338],
[0.54488318, 0.4236548, 0.64589411],
[0.43758721, 0.891773, 0.96366276],
]
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _create_bounds(self, n_tunable, n_fixed):
def _create_bounds(self, n_tunable: int, n_fixed: int) -> list[tuple[float, float]]:
tunable_bounds = [self.tunable_param_bounds] * n_tunable
fixed_bounds = [self.fixed_param_bounds] * n_fixed
return tunable_bounds + fixed_bounds

def test_UniformGeneratorAllTunable(self) -> None:
generator = UniformGenerator(seed=0)
def test_with_all_tunable(self) -> None:
generator = UniformGenerator(seed=self.seed)
bounds = self._create_bounds(n_tunable=3, n_fixed=0)
generated_points, weights = generator.gen(
n=3, bounds=bounds, rounding_func=lambda x: x
)

expected_points = np.array(
[
[0.5488135, 0.71518937, 0.60276338],
[0.54488318, 0.4236548, 0.64589411],
[0.43758721, 0.891773, 0.96366276],
]
)
self.assertTrue(np.shape(expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(expected_points, generated_points))
self.assertTrue(np.shape(self.expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(self.expected_points, generated_points))
self.assertTrue(np.all(weights == 1.0))

def test_UniformGeneratorFixedSpace(self) -> None:
generator = UniformGenerator(seed=0)
def test_with_fixed_space(self) -> None:
generator = UniformGenerator(seed=self.seed)
bounds = self._create_bounds(n_tunable=0, n_fixed=2)
n = 3
with self.assertRaises(SearchSpaceExhausted):
@@ -55,7 +53,7 @@ def test_UniformGeneratorFixedSpace(self) -> None:
fixed_features={0: 1, 1: 2},
rounding_func=lambda x: x,
)
generator = UniformGenerator(seed=0, deduplicate=False)
generator = UniformGenerator(seed=self.seed, deduplicate=False)
generated_points, _ = generator.gen(
n=3,
bounds=bounds,
@@ -66,57 +64,50 @@ def test_UniformGeneratorOnline(self) -> None:
self.assertTrue(np.shape(expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(expected_points, generated_points))

def test_UniformGeneratorOnline(self) -> None:
def test_generating_one_by_one(self, init_position: int = 0) -> None:
# Verify that the generator will return the expected arms if called
# one at a time.
generator = UniformGenerator(seed=0)
generator = UniformGenerator(seed=self.seed, init_position=init_position)
n_tunable = fixed_param_index = 3
bounds = self._create_bounds(n_tunable=n_tunable, n_fixed=1)

n = 3
expected_points = np.array(
[
[0.5488135, 0.71518937, 0.60276338, 1],
[0.54488318, 0.4236548, 0.64589411, 1],
[0.43758721, 0.891773, 0.96366276, 1],
]
)
for i in range(n):
for i in range(init_position // n_tunable, 3):
generated_points, weights = generator.gen(
n=1,
bounds=bounds,
fixed_features={fixed_param_index: 1},
rounding_func=lambda x: x,
)
self.assertEqual(weights, [1])
self.assertTrue(np.allclose(generated_points, expected_points[i, :]))
self.assertTrue(
np.allclose(generated_points[..., :-1], self.expected_points[i, :])
)
self.assertEqual(generated_points[..., -1], 1)
self.assertEqual(generator.init_position, (i + 1) * n_tunable)

def test_UniformGeneratorReseed(self) -> None:
# Verify that the generator will return the expected arms if called
# one at a time.
generator = UniformGenerator(seed=0)
n_tunable = fixed_param_index = 3
bounds = self._create_bounds(n_tunable=n_tunable, n_fixed=1)
def test_with_init_position(self) -> None:
# These are multiples of 3 since there are 3 tunable parameters.
self.test_generating_one_by_one(init_position=3)
self.test_generating_one_by_one(init_position=6)

n = 3
expected_points = np.array(
[
[0.5488135, 0.71518937, 0.60276338, 1],
[0.54488318, 0.4236548, 0.64589411, 1],
[0.43758721, 0.891773, 0.96366276, 1],
]
def test_with_reloaded_state(self) -> None:
# Check that a reloaded generator will produce the same samples.
org_generator = UniformGenerator()
bounds = self._create_bounds(n_tunable=3, n_fixed=0)
# Generate some to advance the state.
org_generator.gen(n=3, bounds=bounds, rounding_func=lambda x: x)
# Construct a new generator with the state.
new_generator = UniformGenerator(**org_generator._get_state())
# Compare the generated samples.
org_samples, _ = org_generator.gen(
n=3, bounds=bounds, rounding_func=lambda x: x
)
for i in range(n):
generated_points, weights = generator.gen(
n=1,
bounds=bounds,
fixed_features={fixed_param_index: 1},
rounding_func=lambda x: x,
)
self.assertEqual(weights, [1])
self.assertTrue(np.allclose(generated_points, expected_points[i, :]))
new_samples, _ = new_generator.gen(
n=3, bounds=bounds, rounding_func=lambda x: x
)
self.assertTrue(np.allclose(org_samples, new_samples))

def test_UniformGeneratorWithOrderConstraints(self) -> None:
def test_with_order_constraints(self) -> None:
# Enforce dim_0 <= dim_1 <= dim_2 <= dim_3.
# Enforce both fixed and tunable constraints.
generator = UniformGenerator(seed=0)
@@ -143,7 +134,7 @@ def test_UniformGeneratorWithOrderConstraints(self) -> None:
self.assertTrue(np.shape(expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(expected_points, generated_points))

def test_UniformGeneratorWithLinearConstraints(self) -> None:
def test_with_linear_constraints(self) -> None:
# Enforce dim_0 <= dim_1 <= dim_2 <= dim_3.
# Enforce both fixed and tunable constraints.
generator = UniformGenerator(seed=0)
@@ -169,7 +160,7 @@ def test_UniformGeneratorWithLinearConstraints(self) -> None:
self.assertTrue(np.shape(expected_points) == np.shape(generated_points))
self.assertTrue(np.allclose(expected_points, generated_points))

def test_UniformGeneratorBadBounds(self) -> None:
def test_with_bad_bounds(self) -> None:
generator = UniformGenerator()
with self.assertRaises(ValueError):
generated_points, weights = generator.gen(
