Skip to content

Commit

Permalink
Update HSS dummy value logic & expose it in Cast
Browse files Browse the repository at this point in the history
Summary: Updates `HSS.flatten_observation_features` to add dummy values even when the full parameterization is recorded in metadata, as long as `inject_dummy_values_to_complete_flat_parameterization=True` & there are missing parameters. Also exposes this setting in `Cast`, to make it usable in experiments.

Differential Revision: D53029754
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 13, 2024
1 parent 8efc837 commit 1537a61
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 47 deletions.
67 changes: 36 additions & 31 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from ax.core.parameter_distribution import ParameterDistribution
from ax.core.types import TParameterization
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import AxWarning, UnsupportedError, UserInputError
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -508,44 +508,49 @@ def flatten_observation_features(
observation_features: Observation features corresponding to one point
to flatten.
inject_dummy_values_to_complete_flat_parameterization: Whether to inject
values for parameters that are not in the parameterization if they
are not recorded in the observation features' metadata (this can
happen if e.g. the point wasn't generated by Ax but attached manually).
values for parameters that are not in the parameterization.
This will be used to complete the parameterization after re-injecting
the parameters that are recorded in the metadata (for parameters
that were generated by Ax).
"""
obs_feats = observation_features
if obs_feats.metadata and Keys.FULL_PARAMETERIZATION in obs_feats.metadata:
# NOTE: We could just use the full parameterization as stored;
# opting for a safer option of only injecting parameters that were
# removed, but not altering those that are present if they have different
# values in full parameterization as stored in metadata.
has_full_parameterization = Keys.FULL_PARAMETERIZATION in (
obs_feats.metadata or {}
)

if obs_feats.parameters == {} and not has_full_parameterization:
# Return as is if the observation feature does not have any parameters.
return obs_feats

if has_full_parameterization:
# If full parameterization is recorded, use it to fill in missing values.
full_parameterization = not_none(obs_feats.metadata)[
Keys.FULL_PARAMETERIZATION
]
obs_feats.parameters = {**full_parameterization, **obs_feats.parameters}
return obs_feats

if obs_feats.parameters == {}:
# Return as is if the observation feature does not have any parameters.
return obs_feats

if inject_dummy_values_to_complete_flat_parameterization:
# To cast a parameterization to flattened search space, inject dummy values
# for parameters that were not present in it.
dummy_values_to_inject = (
self._gen_dummy_values_to_complete_flat_parameterization(
observation_features=obs_feats
if len(obs_feats.parameters) < len(self.parameters):
if inject_dummy_values_to_complete_flat_parameterization:
# Inject dummy values for parameters missing from the parameterization.
dummy_values_to_inject = (
self._gen_dummy_values_to_complete_flat_parameterization(
observation_features=obs_feats
)
)
obs_feats.parameters = {
**dummy_values_to_inject,
**obs_feats.parameters,
}
else:
# The parameterization is still incomplete.
warnings.warn(
f"Cannot flatten observation features {obs_feats} as full "
"parameterization is not recorded in metadata and "
"`inject_dummy_values_to_complete_flat_parameterization` is "
"set to False.",
AxWarning,
stacklevel=2,
)
)
obs_feats.parameters = {**dummy_values_to_inject, **obs_feats.parameters}
return obs_feats

# We did not have the full parameterization stored, so we either return the
# observation features as given without change, or we inject dummy values if
# that behavior was requested via the opt-in flag.
warnings.warn(
f"Cannot flatten observation features {obs_feats} as full "
"parameterization is not recorded in metadata."
)
return obs_feats

def check_membership(
Expand Down
68 changes: 56 additions & 12 deletions ax/core/tests/test_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from unittest import mock

import pandas as pd

from ax.core.arm import Arm
from ax.core.observation import ObservationFeatures
from ax.core.parameter import (
Expand Down Expand Up @@ -838,25 +837,40 @@ def test_flatten_observation_features(self) -> None:
hss_1_obs_feats_1_cast = self.hss_1.cast_observation_features(
observation_features=hss_1_obs_feats_1
)
hss_1_obs_feats_1_flattened = self.hss_1.flatten_observation_features(
observation_features=hss_1_obs_feats_1_cast
)
self.assertEqual( # Cast-flatten roundtrip.
hss_1_obs_feats_1.parameters,
hss_1_obs_feats_1_flattened.parameters,
)
self.assertEqual( # Check that both cast and flattened have full params.
hss_1_obs_feats_1_cast.metadata.get(Keys.FULL_PARAMETERIZATION),
hss_1_obs_feats_1_flattened.metadata.get(Keys.FULL_PARAMETERIZATION),
)
for inject_dummy in (True, False):
hss_1_obs_feats_1_flattened = self.hss_1.flatten_observation_features(
observation_features=hss_1_obs_feats_1_cast,
inject_dummy_values_to_complete_flat_parameterization=inject_dummy,
)
self.assertEqual( # Cast-flatten roundtrip.
hss_1_obs_feats_1.parameters,
hss_1_obs_feats_1_flattened.parameters,
)
self.assertEqual( # Check that both cast and flattened have full params.
hss_1_obs_feats_1_cast.metadata.get(Keys.FULL_PARAMETERIZATION),
hss_1_obs_feats_1_flattened.metadata.get(Keys.FULL_PARAMETERIZATION),
)
# Check that flattening observation features without metadata does nothing.
# Does not warn here since it already has all parameters.
with warnings.catch_warnings(record=True) as ws:
self.assertEqual(
self.hss_1.flatten_observation_features(
observation_features=hss_1_obs_feats_1
),
hss_1_obs_feats_1,
)
self.assertFalse(
any("Cannot flatten observation features" in str(w.message) for w in ws)
)
# This one warns since it is missing some parameters.
obs_ft_missing = ObservationFeatures.from_arm(arm=self.hss_1_arm_missing_param)
with warnings.catch_warnings(record=True) as ws:
self.assertEqual(
self.hss_1.flatten_observation_features(
observation_features=obs_ft_missing
),
obs_ft_missing,
)
self.assertTrue(
any("Cannot flatten observation features" in str(w.message) for w in ws)
)
Expand Down Expand Up @@ -953,6 +967,36 @@ def test_flatten_observation_features_inject_dummy_parameter_values(
set(self.hss_with_fixed.parameters.keys()),
)

def test_flatten_observation_features_full_and_dummy(self) -> None:
# Test flattening when both full features & inject dummy values
# are specified. This is relevant if the full parameterization
# is from some subset of the search space.
# Get an obs feat with hss_1 parameterization in the metadata.
hss_1_obs_feats_1 = ObservationFeatures.from_arm(arm=self.hss_1_arm_1_flat)
hss_1_obs_feats_1_cast = self.hss_1.cast_observation_features(
observation_features=hss_1_obs_feats_1
)
hss_1_flat_params = hss_1_obs_feats_1.parameters
# Flatten it using hss_2, which requires an additional parameter.
# This will work but miss a parameter.
self.assertEqual(
self.hss_2.flatten_observation_features(
observation_features=hss_1_obs_feats_1_cast,
inject_dummy_values_to_complete_flat_parameterization=False,
).parameters,
hss_1_flat_params,
)
# Now try with inject dummy. This will add the mising param.
hss_2_flat = self.hss_2.flatten_observation_features(
observation_features=hss_1_obs_feats_1_cast,
inject_dummy_values_to_complete_flat_parameterization=True,
).parameters
self.assertNotEqual(hss_2_flat, hss_1_flat_params)
self.assertEqual(
{k: hss_2_flat[k] for k in hss_1_flat_params}, hss_1_flat_params
)
self.assertEqual(set(hss_2_flat.keys()), set(self.hss_2.parameters.keys()))


class TestRobustSearchSpace(TestCase):
def setUp(self) -> None:
Expand Down
26 changes: 22 additions & 4 deletions ax/modelbridge/transforms/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ax.core.observation import Observation, ObservationFeatures
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.exceptions.core import UserInputError
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast, not_none
Expand Down Expand Up @@ -47,9 +48,21 @@ def __init__(
config: Optional[TConfig] = None,
) -> None:
self.search_space: SearchSpace = not_none(search_space).clone()
self.flatten_hss: bool = (
config is None or checked_cast(bool, config.get("flatten_hss", True))
) and isinstance(search_space, HierarchicalSearchSpace)
config = (config or {}).copy()
self.flatten_hss: bool = checked_cast(
bool,
config.pop(
"flatten_hss", isinstance(search_space, HierarchicalSearchSpace)
),
)
self.inject_dummy_values_to_complete_flat_parameterization: bool = checked_cast(
bool,
config.pop("inject_dummy_values_to_complete_flat_parameterization", True),
)
if config:
raise UserInputError(
f"Unexpected config parameters for `Cast` transform: {config}."
)

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
"""Flattens the hierarchical search space and returns the flat
Expand Down Expand Up @@ -90,7 +103,12 @@ def transform_observation_features(
return [
checked_cast(
HierarchicalSearchSpace, self.search_space
).flatten_observation_features(observation_features=obs_feats)
).flatten_observation_features(
observation_features=obs_feats,
inject_dummy_values_to_complete_flat_parameterization=(
self.inject_dummy_values_to_complete_flat_parameterization
),
)
for obs_feats in observation_features
]

Expand Down
25 changes: 25 additions & 0 deletions ax/modelbridge/transforms/tests/test_cast_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RangeParameter,
)
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.exceptions.core import UserInputError
from ax.modelbridge.transforms.cast import Cast
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -68,6 +69,10 @@ def setUp(self) -> None:
metadata=None,
)

def test_invalid_config(self) -> None:
with self.assertRaisesRegex(UserInputError, "Unexpected config"):
Cast(search_space=self.search_space, config={"flatten_hs": "foo"})

def test_transform_observation_features(self) -> None:
# Verify running the transform on already-casted features does nothing
observation_features = [
Expand Down Expand Up @@ -158,6 +163,26 @@ def test_transform_observation_features_HSS(self) -> None:
self.obs_feats_hss.parameters,
)

def test_transform_observation_features_HSS_dummy_values_setting(self) -> None:
t = Cast(
search_space=self.hss,
config={"inject_dummy_values_to_complete_flat_parameterization": True},
observations=[],
)
self.assertTrue(t.inject_dummy_values_to_complete_flat_parameterization)
with patch.object(
t.search_space,
"flatten_observation_features",
wraps=t.search_space.flatten_observation_features, # pyre-ignore
) as mock_flatten_obsf:
t.transform_observation_features(observation_features=[self.obs_feats_hss])
mock_flatten_obsf.assert_called_once()
self.assertTrue(
mock_flatten_obsf.call_args.kwargs[
"inject_dummy_values_to_complete_flat_parameterization"
]
)

def test_untransform_observation_features_HSS(self) -> None:
# Test transformation in one subtree of HSS.
with patch.object(
Expand Down

0 comments on commit 1537a61

Please sign in to comment.