forked from facebook/Ax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TimeAsFeature transform (facebook#2438)
Summary: This implements a transform for adding `start_time` and `duration` as features for modeling. Currently, this adds them as `RangeParameter`s (to unblock time-sensitive applications), but in the future it would be good to revise this with a better treatment of non-tunable contextual information. `duration` appears to lead to better model fits on the synthetic example (notebook) than using `end_time`. This also works better than using the midpoint between start time and end time. Reviewed By: bernardbeckerman, Balandat Differential Revision: D57082939
- Loading branch information
1 parent
cea8dc3
commit d92410c
Showing
4 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
118 changes: 118 additions & 0 deletions
118
ax/modelbridge/transforms/tests/test_time_as_feature_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from copy import deepcopy | ||
from unittest import mock | ||
|
||
import numpy as np | ||
from ax.core.observation import Observation, ObservationData, ObservationFeatures | ||
from ax.core.parameter import ParameterType, RangeParameter | ||
from ax.core.search_space import SearchSpace | ||
from ax.exceptions.core import UnsupportedError | ||
from ax.modelbridge.transforms.time_as_feature import TimeAsFeature | ||
from ax.utils.common.testutils import TestCase | ||
from ax.utils.common.timeutils import unixtime_to_pandas_ts | ||
from ax.utils.testing.core_stubs import get_robust_search_space | ||
|
||
|
||
class TimeAsFeatureTransformTest(TestCase):
    """Unit tests for the ``TimeAsFeature`` transform."""

    @staticmethod
    def _empty_data() -> ObservationData:
        """Return an ``ObservationData`` with no metrics."""
        return ObservationData(
            metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
        )

    def setUp(self) -> None:
        """Create the search space, four timed observations and the transform.

        Trial ``i`` starts at unix time ``i`` and ends at ``2 * i + 1``, so its
        duration is ``i + 1``. ``time`` is patched to return 5.0 while the
        transform is constructed.
        """
        super().setUp()
        x_param = RangeParameter(
            "x", lower=1, upper=4, parameter_type=ParameterType.FLOAT
        )
        self.search_space = SearchSpace(parameters=[x_param])
        self.training_feats = [
            ObservationFeatures(
                {"x": idx + 1},
                trial_index=idx,
                start_time=unixtime_to_pandas_ts(float(idx)),
                end_time=unixtime_to_pandas_ts(float(idx + 1 + idx)),
            )
            for idx in range(4)
        ]
        self.training_obs = [
            Observation(data=self._empty_data(), features=feats)
            for feats in self.training_feats
        ]
        time_patch = mock.patch(
            "ax.modelbridge.transforms.time_as_feature.time", return_value=5.0
        )
        with time_patch:
            self.t = TimeAsFeature(
                search_space=self.search_space,
                observations=self.training_obs,
            )

    def test_init(self) -> None:
        """Check the fitted ranges, start-time validation and the zero-range case."""
        expected_state = {
            "current_time": 5.0,
            "min_duration": 1.0,
            "max_duration": 4.0,
            "duration_range": 3.0,
            "min_start_time": 0.0,
            "max_start_time": 3.0,
        }
        for attr_name, expected_value in expected_state.items():
            self.assertEqual(getattr(self.t, attr_name), expected_value)

        # Test validation: an observation without a start time is rejected.
        untimed_obs = Observation(
            data=ObservationData([], np.array([]), np.empty((0, 0))),
            features=ObservationFeatures({"x": 2}),
        )
        msg = (
            "Unable to use TimeAsFeature since not all observations have "
            "start time specified."
        )
        with self.assertRaisesRegex(ValueError, msg):
            TimeAsFeature(
                search_space=self.search_space,
                observations=self.training_obs + [untimed_obs],
            )

        # With a single observation the duration range is reported as 1.0.
        single_obs_transform = TimeAsFeature(
            search_space=self.search_space,
            observations=self.training_obs[:1],
        )
        self.assertEqual(single_obs_transform.duration_range, 1.0)

    def test_TransformObservationFeatures(self) -> None:
        """Transform adds the time parameters; untransform restores the input."""
        expected = deepcopy(self.training_feats)
        for idx, feats in enumerate(expected):
            feats.parameters["start_time"] = float(idx)
            feats.parameters["duration"] = 1 / 3 * idx

        actual = self.t.transform_observation_features(
            deepcopy(self.training_feats)
        )
        self.assertEqual(actual, expected)

        restored = self.t.untransform_observation_features(actual)
        self.assertEqual(restored, self.training_feats)

    def test_TransformSearchSpace(self) -> None:
        """The transformed search space gains ``start_time`` and ``duration``."""
        transformed = self.t.transform_search_space(deepcopy(self.search_space))
        self.assertEqual(
            set(transformed.parameters.keys()), {"x", "start_time", "duration"}
        )
        expected_bounds = (
            ("start_time", 0.0, 3.0),
            ("duration", 0.0, 1.0),
        )
        for name, lower, upper in expected_bounds:
            param = transformed.parameters[name]
            self.assertEqual(param.parameter_type, ParameterType.FLOAT)
            self.assertEqual(param.lower, lower)
            self.assertEqual(param.upper, upper)

    def test_w_robust_search_space(self) -> None:
        """Constructing the transform with a robust search space is rejected."""
        robust_space = get_robust_search_space()
        # Raises an error in __init__.
        with self.assertRaisesRegex(UnsupportedError, "transform is not supported"):
            TimeAsFeature(
                search_space=robust_space,
                observations=[],
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from logging import Logger | ||
from time import time | ||
from typing import List, Optional, TYPE_CHECKING | ||
|
||
import pandas as pd | ||
|
||
from ax.core.observation import Observation, ObservationFeatures | ||
from ax.core.parameter import ParameterType, RangeParameter | ||
from ax.core.search_space import RobustSearchSpace, SearchSpace | ||
from ax.exceptions.core import UnsupportedError | ||
from ax.modelbridge.transforms.base import Transform | ||
from ax.models.types import TConfig | ||
from ax.utils.common.logger import get_logger | ||
from ax.utils.common.timeutils import unixtime_to_pandas_ts | ||
from ax.utils.common.typeutils import checked_cast, not_none | ||
|
||
if TYPE_CHECKING: | ||
# import as module to make sphinx-autodoc-typehints happy | ||
from ax import modelbridge as modelbridge_module # noqa F401 | ||
|
||
|
||
logger: Logger = get_logger(__name__) | ||
|
||
|
||
class TimeAsFeature(Transform):
    """Convert start time and duration into features that can be used for modeling.

    If no end_time is available, the current time is used.

    Duration is normalized to the unit cube.

    Transform is done in-place.

    TODO: revise this when better support for non-tunable features is added.
    """

    def __init__(
        self,
        search_space: Optional[SearchSpace] = None,
        observations: Optional[List[Observation]] = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: Optional[TConfig] = None,
    ) -> None:
        """Record the start-time and duration ranges seen in ``observations``.

        Args:
            search_space: Search space the transform is built for. A
                ``RobustSearchSpace`` is rejected.
            observations: Observations whose features provide the start and
                end times. Required despite the ``None`` default.
            modelbridge: Not referenced by this constructor.
            config: Not referenced by this constructor.

        Raises:
            AssertionError: If ``observations`` is ``None``.
            UnsupportedError: If ``search_space`` is a ``RobustSearchSpace``.
            ValueError: If any observation has no start time.
        """
        assert observations is not None, "TimeAsFeature requires observations"
        if isinstance(search_space, RobustSearchSpace):
            raise UnsupportedError(
                "TimeAsFeature transform is not supported for RobustSearchSpace."
            )
        self.min_start_time: float = float("inf")
        self.max_start_time: float = float("-inf")
        self.min_duration: float = float("inf")
        self.max_duration: float = float("-inf")
        # Captured once so that every observation without an end time is
        # measured against the same reference point.
        self.current_time: float = time()
        for obs in observations:
            obsf = obs.features
            if obsf.start_time is None:
                raise ValueError(
                    "Unable to use TimeAsFeature since not all observations have "
                    "start time specified."
                )
            start_time = not_none(obsf.start_time).timestamp()
            self.min_start_time = min(self.min_start_time, start_time)
            self.max_start_time = max(self.max_start_time, start_time)
            duration = self._get_duration(start_time=start_time, end_time=obsf.end_time)
            self.min_duration = min(self.min_duration, duration)
            self.max_duration = max(self.max_duration, duration)
        self.duration_range: float = self.max_duration - self.min_duration
        if self.duration_range == 0:
            # no need to case-distinguish during normalization
            self.duration_range = 1.0

    def _get_duration(
        self, start_time: float, end_time: Optional[pd.Timestamp]
    ) -> float:
        """Return the elapsed time from ``start_time`` to ``end_time``.

        Args:
            start_time: Start time as a unix timestamp.
            end_time: End time, or ``None`` to use ``self.current_time``.

        Returns:
            The end timestamp (or ``self.current_time``) minus ``start_time``.
        """
        end_timestamp = (
            self.current_time if end_time is None else end_time.timestamp()
        )
        return end_timestamp - start_time

    def transform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Add ``start_time`` and normalized ``duration`` parameters in-place.

        Features without a start time are left untouched.

        Args:
            observation_features: Observation features to modify.

        Returns:
            The same list that was passed in.
        """
        for obsf in observation_features:
            if obsf.start_time is not None:
                start_time = obsf.start_time.timestamp()
                obsf.parameters["start_time"] = start_time
                duration = self._get_duration(
                    start_time=start_time, end_time=obsf.end_time
                )
                # normalize duration to the unit cube
                obsf.parameters["duration"] = (
                    duration - self.min_duration
                ) / self.duration_range
        return observation_features

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Add ``start_time`` and ``duration`` range parameters to ``search_space``.

        ``start_time`` is bounded by the observed start times; ``duration`` is
        bounded by [0, 1] because it is normalized.

        Args:
            search_space: Search space to extend; it is modified and returned.

        Returns:
            The same search space with the two parameters added.

        Raises:
            ValueError: If ``search_space`` already has a parameter named
                ``start_time`` or ``duration``.
        """
        for p_name in ("start_time", "duration"):
            if p_name in search_space.parameters:
                raise ValueError(
                    f"Parameter name {p_name} is reserved when using "
                    "TimeAsFeature transform, but is part of the provided "
                    "search space. Please choose a different name for "
                    "this parameter."
                )
        param = RangeParameter(
            name="start_time",
            parameter_type=ParameterType.FLOAT,
            lower=self.min_start_time,
            upper=self.max_start_time,
        )
        search_space.add_parameter(param)
        param = RangeParameter(
            name="duration",
            parameter_type=ParameterType.FLOAT,
            # duration is normalized to [0,1]
            lower=0.0,
            upper=1.0,
        )
        search_space.add_parameter(param)
        return search_space

    def untransform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Move ``start_time`` and ``duration`` parameters back to timestamps.

        The two parameters are removed from ``parameters`` and used to set
        ``start_time`` and ``end_time`` on each feature, in-place. ``end_time``
        is always set, including for features whose duration was originally
        computed from the current time.

        Features that carry neither parameter are left untouched. These are
        the features that ``transform_observation_features`` skipped for lack
        of a start time, so a round trip no longer fails on them.

        Args:
            observation_features: Observation features to modify.

        Returns:
            The same list that was passed in.
        """
        for obsf in observation_features:
            params = obsf.parameters
            if "start_time" not in params and "duration" not in params:
                # Nothing was added by the forward transform; nothing to undo.
                continue
            start_time = checked_cast(float, params.pop("start_time"))
            obsf.start_time = unixtime_to_pandas_ts(start_time)
            # Invert the unit-cube normalization applied in the forward pass.
            duration = (
                checked_cast(float, params.pop("duration")) * self.duration_range
                + self.min_duration
            )
            obsf.end_time = unixtime_to_pandas_ts(duration + start_time)
        return observation_features
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters