Skip to content

Commit

Permalink
Add List[str] to TConfig definition (#2360)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2360

This resolves a few Pyre errors.

Reviewed By: saitcakmak

Differential Revision: D56084211

fbshipit-source-id: f845eeb2162d3f421c31330d9766becf3b8c01a7
  • Loading branch information
esantorella authored and facebook-github-bot committed Apr 13, 2024
1 parent cefe7bf commit 4579469
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 20 deletions.
2 changes: 1 addition & 1 deletion ax/modelbridge/transforms/logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class Logit(Transform):
"""Apply logit transfor to a float RangeParameter domain.
"""Apply logit transform to a float RangeParameter domain.
Transform is done in-place.
"""
Expand Down
6 changes: 3 additions & 3 deletions ax/modelbridge/transforms/tests/test_log_y_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_TransformObservations(self) -> None:
tf = LogY(
search_space=None,
observations=[],
config={"metrics": ["m3"]}, # pyre-ignore
config={"metrics": ["m3"]},
)
obsd1 = deepcopy(self.obsd1)
obsd1_ = tf._transform_observation_data([obsd1])
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_TransformOptimizationConfig(self) -> None:
tf = LogY(
search_space=None,
observations=self.observations,
config={"metrics": ["m1"]}, # pyre-ignore
config={"metrics": ["m1"]},
)
oc_tf = tf.transform_optimization_config(deepcopy(oc), None, None)
self.assertEqual(oc_tf, oc)
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_TransformOptimizationConfigMOO(self) -> None:
tf = LogY(
search_space=None,
observations=self.observations,
config={"metrics": ["m1"]}, # pyre-ignore
config={"metrics": ["m1"]},
)
oc_tf = tf.transform_optimization_config(deepcopy(oc), None, None)
oc.objective_thresholds[0].bound = math.log(1.234)
Expand Down
19 changes: 4 additions & 15 deletions ax/modelbridge/transforms/tests/test_power_y_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,6 @@ def test_Init(self) -> None:
PowerTransformY(**shared_init_args, config={})
# Test default init
for m in ["m1", "m2"]:
# pyre-fixme[6]: For 1st param expected `List[ObservationData]` but got
# `Optional[List[ObservationData]]`.
# pyre-fixme[6]: For 1st param expected `List[ObservationFeatures]` but
# got `Optional[List[ObservationData]]`.
# pyre-fixme[6]: For 1st param expected `Optional[ModelBridge]` but got
# `Optional[List[ObservationData]]`.
# pyre-fixme[6]: For 1st param expected `SearchSpace` but got
# `Optional[List[ObservationData]]`.
# pyre-fixme[6]: For 2nd param expected `Optional[Dict[str, Union[None,
# Dict[str, typing.Any], OptimizationConfig, AcquisitionFunction, float,
# int, str]]]` but got `Dict[str, List[str]]`.
tf = PowerTransformY(**shared_init_args, config={"metrics": [m]})
# tf.power_transforms should only exist for m and be a PowerTransformer
self.assertIsInstance(tf.power_transforms, dict)
Expand Down Expand Up @@ -187,7 +176,7 @@ def test_TransformAndUntransformOneMetric(self) -> None:
pt = PowerTransformY(
search_space=None,
observations=deepcopy(self.observations[:2]),
config={"metrics": ["m1"]}, # pyre-ignore
config={"metrics": ["m1"]},
)

# Transform the data and make sure we don't touch m1
Expand Down Expand Up @@ -217,7 +206,7 @@ def test_TransformAndUntransformAllMetrics(self) -> None:
pt = PowerTransformY(
search_space=None,
observations=deepcopy(self.observations[:2]),
config={"metrics": ["m1", "m2"]}, # pyre-ignore
config={"metrics": ["m1", "m2"]},
)

observation_data_tf = pt._transform_observation_data(
Expand Down Expand Up @@ -253,7 +242,7 @@ def test_CompareToSklearn(self) -> None:
pt = PowerTransformY(
search_space=None,
observations=deepcopy(self.observations[:3]),
config={"metrics": ["m1"]}, # pyre-ignore
config={"metrics": ["m1"]},
)
observation_data_tf = pt._transform_observation_data(observation_data)
y2 = [data.means[0] for data in observation_data_tf]
Expand All @@ -268,7 +257,7 @@ def test_TransformOptimizationConfig(self) -> None:
tf = PowerTransformY(
search_space=None,
observations=self.observations[:2],
config={"metrics": ["m1"]}, # pyre-ignore
config={"metrics": ["m1"]},
)
oc_tf = tf.transform_optimization_config(deepcopy(oc), None, None)
self.assertEqual(oc_tf, oc)
Expand Down
3 changes: 2 additions & 1 deletion ax/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pyre-strict

from typing import Any, Dict, Union
from typing import Any, Dict, List, Union

from ax.core.optimization_config import OptimizationConfig
from ax.models.winsorization_config import WinsorizationConfig
Expand All @@ -20,6 +20,7 @@
float,
str,
AcquisitionFunction,
List[str],
Dict[int, Any],
Dict[str, Any],
OptimizationConfig,
Expand Down

0 comments on commit 4579469

Please sign in to comment.