Update HSS dummy value logic & expose it in Cast #2362

Closed · wants to merge 2 commits
84 changes: 47 additions & 37 deletions ax/core/search_space.py
@@ -34,7 +34,7 @@
)
from ax.core.parameter_distribution import ParameterDistribution
from ax.core.types import TParameterization
-from ax.exceptions.core import UnsupportedError, UserInputError
+from ax.exceptions.core import AxWarning, UnsupportedError, UserInputError
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
@@ -508,44 +508,49 @@ def flatten_observation_features(
            observation_features: Observation features corresponding to one point
                to flatten.
            inject_dummy_values_to_complete_flat_parameterization: Whether to inject
-                values for parameters that are not in the parameterization if they
-                are not recorded in the observation features' metadata (this can
-                happen if e.g. the point wasn't generated by Ax but attached manually).
+                values for parameters that are not in the parameterization.
+                This will be used to complete the parameterization after re-injecting
+                the parameters that are recorded in the metadata (for parameters
+                that were generated by Ax).
        """
        obs_feats = observation_features
-        if obs_feats.metadata and Keys.FULL_PARAMETERIZATION in obs_feats.metadata:
        # NOTE: We could just use the full parameterization as stored;
        # opting for a safer option of only injecting parameters that were
        # removed, but not altering those that are present if they have different
        # values in full parameterization as stored in metadata.
+        has_full_parameterization = Keys.FULL_PARAMETERIZATION in (
+            obs_feats.metadata or {}
+        )
+
+        if obs_feats.parameters == {} and not has_full_parameterization:
+            # Return as is if the observation feature does not have any parameters.
+            return obs_feats
+
+        if has_full_parameterization:
+            # If full parameterization is recorded, use it to fill in missing values.
            full_parameterization = not_none(obs_feats.metadata)[
                Keys.FULL_PARAMETERIZATION
            ]
            obs_feats.parameters = {**full_parameterization, **obs_feats.parameters}
-            return obs_feats
-
-        if obs_feats.parameters == {}:
-            # Return as is if the observation feature does not have any parameters.
-            return obs_feats
-
-        if inject_dummy_values_to_complete_flat_parameterization:
-            # To cast a parameterization to flattened search space, inject dummy values
-            # for parameters that were not present in it.
-            dummy_values_to_inject = (
-                self._gen_dummy_values_to_complete_flat_parameterization(
-                    observation_features=obs_feats
-                )
-            )
-            obs_feats.parameters = {**dummy_values_to_inject, **obs_feats.parameters}
+        if len(obs_feats.parameters) < len(self.parameters):
+            if inject_dummy_values_to_complete_flat_parameterization:
+                # Inject dummy values for parameters missing from the parameterization.
+                dummy_values_to_inject = (
+                    self._gen_dummy_values_to_complete_flat_parameterization(
+                        observation_features=obs_feats
+                    )
+                )
+                obs_feats.parameters = {
+                    **dummy_values_to_inject,
+                    **obs_feats.parameters,
+                }
+            else:
+                # The parameterization is still incomplete.
+                warnings.warn(
+                    f"Cannot flatten observation features {obs_feats} as full "
+                    "parameterization is not recorded in metadata and "
+                    "`inject_dummy_values_to_complete_flat_parameterization` is "
+                    "set to False.",
+                    AxWarning,
+                    stacklevel=2,
+                )
        return obs_feats
-
-        # We did not have the full parameterization stored, so we either return the
-        # observation features as given without change, or we inject dummy values if
-        # that behavior was requested via the opt-in flag.
-        warnings.warn(
-            f"Cannot flatten observation features {obs_feats} as full "
-            "parameterization is not recorded in metadata."
-        )
-        return obs_feats

    def check_membership(
@@ -664,10 +669,11 @@ def _cast_parameterization(

        Args:
            parameters: Parameterization to cast to hierarchical structure.
-            check_all_parameters_present: Whether to raise an error if a paramete
-                that is expected to be present (according to values of other
-                parameters and the hierarchical structure of the search space)
-                is not specified.
+            check_all_parameters_present: Whether to raise an error if a parameter
+                that is expected to be present (according to values of other
+                parameters and the hierarchical structure of the search space)
+                is not specified. When this is False, if a parameter is missing,
+                its dependents will not be included in the returned parameterization.
        """
        error_msg_prefix: str = (
            f"Parameterization {parameters} violates the hierarchical structure "
@@ -682,11 +688,15 @@ def _find_applicable_parameters(root: Parameter) -> Set[str]:
+ f"Parameter '{root.name}' not in parameterization to cast."
)

if not root.is_hierarchical:
# Return if the root parameter is not hierarchical or if it is not
# in the parameterization to cast.
if not root.is_hierarchical or root.name not in parameters:
return applicable

# Find the dependents of the current root parameter.
root_val = parameters[root.name]
for val, deps in root.dependents.items():
if parameters[root.name] == val:
if root_val == val:
for dep in deps:
applicable.update(_find_applicable_parameters(root=self[dep]))

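For context, here is a minimal sketch of what the new flattening behavior does end to end. The two-branch search space below is hypothetical (loosely mirroring the `hss_1` fixture in the tests), and it assumes Ax's public APIs as of this PR:

```python
from ax.core.observation import ObservationFeatures
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import HierarchicalSearchSpace

# Root "model" selects which leaf parameter is active.
hss = HierarchicalSearchSpace(
    parameters=[
        ChoiceParameter(
            name="model",
            parameter_type=ParameterType.STRING,
            values=["Linear", "XGBoost"],
            dependents={
                "Linear": ["l2_reg_weight"],
                "XGBoost": ["num_boost_rounds"],
            },
        ),
        RangeParameter(
            name="l2_reg_weight",
            parameter_type=ParameterType.FLOAT,
            lower=1e-6,
            upper=1.0,
        ),
        RangeParameter(
            name="num_boost_rounds",
            parameter_type=ParameterType.INT,
            lower=1,
            upper=100,
        ),
    ]
)

# Manually attached features: no full parameterization in metadata, and the
# inactive leaf "num_boost_rounds" is absent.
obs = ObservationFeatures(parameters={"model": "Linear", "l2_reg_weight": 0.1})

# With the opt-in flag, the missing parameter is filled with a dummy value;
# with the flag off, the features are returned unchanged and an AxWarning
# is emitted instead.
flat = hss.flatten_observation_features(
    observation_features=obs,
    inject_dummy_values_to_complete_flat_parameterization=True,
)
assert set(flat.parameters) == {"model", "l2_reg_weight", "num_boost_rounds"}
```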
113 changes: 94 additions & 19 deletions ax/core/tests/test_search_space.py
@@ -13,7 +13,6 @@
from unittest import mock

import pandas as pd
-
from ax.core.arm import Arm
from ax.core.observation import ObservationFeatures
from ax.core.parameter import (
@@ -36,6 +35,7 @@
    SearchSpace,
    SearchSpaceDigest,
)
+from ax.core.types import TParameterization
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
@@ -621,13 +621,12 @@ def setUp(self) -> None:
"num_boost_rounds": 12,
}
)
self.hss_1_arm_missing_param = Arm(
parameters={
"model": "Linear",
"l2_reg_weight": 0.0001,
"num_boost_rounds": 12,
}
)
self.hss_1_missing_params: TParameterization = {
"model": "Linear",
"l2_reg_weight": 0.0001,
"num_boost_rounds": 12,
}
self.hss_1_arm_missing_param = Arm(parameters=self.hss_1_missing_params)
self.hss_1_arm_1_cast = Arm(
parameters={
"model": "Linear",
@@ -759,6 +758,7 @@ def test_flatten(self) -> None:
        self.assertTrue(str(flattened_hss_with_constraints).startswith("SearchSpace"))

    def test_cast_arm(self) -> None:
+        # This uses _cast_parameterization with check_all_parameters_present=True.
        self.assertEqual(  # Check one subtree.
            self.hss_1._cast_arm(arm=self.hss_1_arm_1_flat),
            self.hss_1_arm_1_cast,
@@ -775,6 +775,7 @@
            self.hss_1._cast_arm(arm=self.hss_1_arm_missing_param)

    def test_cast_observation_features(self) -> None:
+        # This uses _cast_parameterization with check_all_parameters_present=False.
        # Ensure that during casting, full parameterization is saved
        # in metadata and actual parameterization is cast to HSS.
        hss_1_obs_feats_1 = ObservationFeatures.from_arm(arm=self.hss_1_arm_1_flat)
@@ -798,6 +799,35 @@
            ObservationFeatures.from_arm(arm=self.hss_1_arm_1_cast),
        )

+    def test_cast_parameterization(self) -> None:
+        # NOTE: This is also tested in test_cast_arm & test_cast_observation_features.
+        with self.assertRaisesRegex(RuntimeError, "not in parameterization to cast"):
+            self.hss_1._cast_parameterization(
+                parameters=self.hss_1_missing_params,
+                check_all_parameters_present=True,
+            )
+        # An active leaf param is missing; it'll get ignored. There's an inactive
+        # leaf param; that'll just get filtered out.
+        self.assertEqual(
+            self.hss_1._cast_parameterization(
+                parameters=self.hss_1_missing_params,
+                check_all_parameters_present=False,
+            ),
+            {"l2_reg_weight": 0.0001, "model": "Linear"},
+        )
+        # A hierarchical param is missing; all its dependents will be ignored.
+        # In this case, it is the root param, so we'll get an empty parameterization.
+        self.assertEqual(
+            self.hss_1._cast_parameterization(
+                parameters={
+                    "l2_reg_weight": 0.0001,
+                    "num_boost_rounds": 12,
+                },
+                check_all_parameters_present=False,
+            ),
+            {},
+        )
+
    def test_flatten_observation_features(self) -> None:
        # Ensure that during casting, full parameterization is saved
        # in metadata and actual parameterization is cast to HSS; during
@@ -807,25 +837,40 @@ def test_flatten_observation_features(self) -> None:
        hss_1_obs_feats_1_cast = self.hss_1.cast_observation_features(
            observation_features=hss_1_obs_feats_1
        )
-        hss_1_obs_feats_1_flattened = self.hss_1.flatten_observation_features(
-            observation_features=hss_1_obs_feats_1_cast
-        )
-        self.assertEqual(  # Cast-flatten roundtrip.
-            hss_1_obs_feats_1.parameters,
-            hss_1_obs_feats_1_flattened.parameters,
-        )
-        self.assertEqual(  # Check that both cast and flattened have full params.
-            hss_1_obs_feats_1_cast.metadata.get(Keys.FULL_PARAMETERIZATION),
-            hss_1_obs_feats_1_flattened.metadata.get(Keys.FULL_PARAMETERIZATION),
-        )
+        for inject_dummy in (True, False):
+            hss_1_obs_feats_1_flattened = self.hss_1.flatten_observation_features(
+                observation_features=hss_1_obs_feats_1_cast,
+                inject_dummy_values_to_complete_flat_parameterization=inject_dummy,
+            )
+            self.assertEqual(  # Cast-flatten roundtrip.
+                hss_1_obs_feats_1.parameters,
+                hss_1_obs_feats_1_flattened.parameters,
+            )
+            self.assertEqual(  # Check that both cast and flattened have full params.
+                hss_1_obs_feats_1_cast.metadata.get(Keys.FULL_PARAMETERIZATION),
+                hss_1_obs_feats_1_flattened.metadata.get(Keys.FULL_PARAMETERIZATION),
+            )
        # Check that flattening observation features without metadata does nothing.
+        # Does not warn here since it already has all parameters.
+        with warnings.catch_warnings(record=True) as ws:
+            self.assertEqual(
+                self.hss_1.flatten_observation_features(
+                    observation_features=hss_1_obs_feats_1
+                ),
+                hss_1_obs_feats_1,
+            )
+        self.assertFalse(
+            any("Cannot flatten observation features" in str(w.message) for w in ws)
+        )
+        # This one warns since it is missing some parameters.
+        obs_ft_missing = ObservationFeatures.from_arm(arm=self.hss_1_arm_missing_param)
+        with warnings.catch_warnings(record=True) as ws:
+            self.assertEqual(
+                self.hss_1.flatten_observation_features(
+                    observation_features=obs_ft_missing
+                ),
+                obs_ft_missing,
+            )
+        self.assertTrue(
+            any("Cannot flatten observation features" in str(w.message) for w in ws)
+        )
@@ -922,6 +967,36 @@ def test_flatten_observation_features_inject_dummy_parameter_values(
            set(self.hss_with_fixed.parameters.keys()),
        )

+    def test_flatten_observation_features_full_and_dummy(self) -> None:
+        # Test flattening when both full features & inject dummy values
+        # are specified. This is relevant if the full parameterization
+        # is from some subset of the search space.
+        # Get an obs feat with hss_1 parameterization in the metadata.
+        hss_1_obs_feats_1 = ObservationFeatures.from_arm(arm=self.hss_1_arm_1_flat)
+        hss_1_obs_feats_1_cast = self.hss_1.cast_observation_features(
+            observation_features=hss_1_obs_feats_1
+        )
+        hss_1_flat_params = hss_1_obs_feats_1.parameters
+        # Flatten it using hss_2, which requires an additional parameter.
+        # This will work but miss a parameter.
+        self.assertEqual(
+            self.hss_2.flatten_observation_features(
+                observation_features=hss_1_obs_feats_1_cast,
+                inject_dummy_values_to_complete_flat_parameterization=False,
+            ).parameters,
+            hss_1_flat_params,
+        )
+        # Now try with inject dummy. This will add the missing param.
+        hss_2_flat = self.hss_2.flatten_observation_features(
+            observation_features=hss_1_obs_feats_1_cast,
+            inject_dummy_values_to_complete_flat_parameterization=True,
+        ).parameters
+        self.assertNotEqual(hss_2_flat, hss_1_flat_params)
+        self.assertEqual(
+            {k: hss_2_flat[k] for k in hss_1_flat_params}, hss_1_flat_params
+        )
+        self.assertEqual(set(hss_2_flat.keys()), set(self.hss_2.parameters.keys()))


class TestRobustSearchSpace(TestCase):
    def setUp(self) -> None:
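The new `test_cast_parameterization` above exercises the relaxed casting path: with `check_all_parameters_present=False`, a missing parameter no longer raises; its dependents are simply dropped. A short sketch of the same behavior, reusing the hypothetical `hss` from the earlier snippet (`_cast_parameterization` is private, so this mirrors the tests rather than a public API):

```python
# Active leaf present; the inactive leaf "num_boost_rounds" is filtered out.
assert hss._cast_parameterization(
    parameters={"model": "Linear", "l2_reg_weight": 0.1, "num_boost_rounds": 12},
    check_all_parameters_present=False,
) == {"model": "Linear", "l2_reg_weight": 0.1}

# Root "model" missing: none of its dependents are applicable, so the cast
# parameterization comes back empty rather than raising.
assert hss._cast_parameterization(
    parameters={"l2_reg_weight": 0.1, "num_boost_rounds": 12},
    check_all_parameters_present=False,
) == {}
```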
26 changes: 22 additions & 4 deletions ax/modelbridge/transforms/cast.py
@@ -10,6 +10,7 @@

from ax.core.observation import Observation, ObservationFeatures
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
+from ax.exceptions.core import UserInputError
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast, not_none
@@ -47,9 +48,21 @@ def __init__(
        config: Optional[TConfig] = None,
    ) -> None:
        self.search_space: SearchSpace = not_none(search_space).clone()
-        self.flatten_hss: bool = (
-            config is None or checked_cast(bool, config.get("flatten_hss", True))
-        ) and isinstance(search_space, HierarchicalSearchSpace)
+        config = (config or {}).copy()
+        self.flatten_hss: bool = checked_cast(
+            bool,
+            config.pop(
+                "flatten_hss", isinstance(search_space, HierarchicalSearchSpace)
+            ),
+        )
+        self.inject_dummy_values_to_complete_flat_parameterization: bool = checked_cast(
+            bool,
+            config.pop("inject_dummy_values_to_complete_flat_parameterization", True),
+        )
+        if config:
+            raise UserInputError(
+                f"Unexpected config parameters for `Cast` transform: {config}."
+            )

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Flattens the hierarchical search space and returns the flat
@@ -90,7 +103,12 @@ def transform_observation_features(
        return [
            checked_cast(
                HierarchicalSearchSpace, self.search_space
-            ).flatten_observation_features(observation_features=obs_feats)
+            ).flatten_observation_features(
+                observation_features=obs_feats,
+                inject_dummy_values_to_complete_flat_parameterization=(
+                    self.inject_dummy_values_to_complete_flat_parameterization
+                ),
+            )
            for obs_feats in observation_features
        ]

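A minimal sketch of the new `Cast` config handling (reusing the hypothetical `hss` from the first snippet; the flag is threaded through to `flatten_observation_features` during transform, as the diff above shows):

```python
from ax.exceptions.core import UserInputError
from ax.modelbridge.transforms.cast import Cast

# Opt out of dummy-value injection via config.
t = Cast(
    search_space=hss,
    config={"inject_dummy_values_to_complete_flat_parameterization": False},
)
assert t.flatten_hss
assert not t.inject_dummy_values_to_complete_flat_parameterization

# Unrecognized keys (e.g. the typo "flatten_hs") now raise instead of being
# silently ignored.
try:
    Cast(search_space=hss, config={"flatten_hs": True})
except UserInputError as err:
    print(err)
```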
25 changes: 25 additions & 0 deletions ax/modelbridge/transforms/tests/test_cast_transform.py
@@ -17,6 +17,7 @@
    RangeParameter,
)
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
+from ax.exceptions.core import UserInputError
from ax.modelbridge.transforms.cast import Cast
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
@@ -68,6 +69,10 @@ def setUp(self) -> None:
            metadata=None,
        )

+    def test_invalid_config(self) -> None:
+        with self.assertRaisesRegex(UserInputError, "Unexpected config"):
+            Cast(search_space=self.search_space, config={"flatten_hs": "foo"})
+
    def test_transform_observation_features(self) -> None:
        # Verify running the transform on already-casted features does nothing
        observation_features = [
@@ -158,6 +163,26 @@ def test_transform_observation_features_HSS(self) -> None:
            self.obs_feats_hss.parameters,
        )

+    def test_transform_observation_features_HSS_dummy_values_setting(self) -> None:
+        t = Cast(
+            search_space=self.hss,
+            config={"inject_dummy_values_to_complete_flat_parameterization": True},
+            observations=[],
+        )
+        self.assertTrue(t.inject_dummy_values_to_complete_flat_parameterization)
+        with patch.object(
+            t.search_space,
+            "flatten_observation_features",
+            wraps=t.search_space.flatten_observation_features,  # pyre-ignore
+        ) as mock_flatten_obsf:
+            t.transform_observation_features(observation_features=[self.obs_feats_hss])
+        mock_flatten_obsf.assert_called_once()
+        self.assertTrue(
+            mock_flatten_obsf.call_args.kwargs[
+                "inject_dummy_values_to_complete_flat_parameterization"
+            ]
+        )
+
    def test_untransform_observation_features_HSS(self) -> None:
        # Test transformation in one subtree of HSS.
        with patch.object(