Skip to content

Commit

Permalink
Typing improvements to RangeParameter (#2327)
Browse files Browse the repository at this point in the history
Summary:

This commit improves the typing coverage for `RangeParameter` and its downstream applications, as well as getters and setters four its bounds that include bounds validation.

Reviewed By: bernardbeckerman

Differential Revision: D55805080
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Apr 12, 2024
1 parent c122890 commit 22d132e
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 90 deletions.
5 changes: 1 addition & 4 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,7 @@ def is_factorial(self) -> bool:
param_levels: DefaultDict[str, Dict[Union[str, float], int]] = defaultdict(dict)
for arm in self.arms:
for param_name, param_value in arm.parameters.items():
# Expected `Union[float, str]` for 2nd anonymous parameter to call
# `dict.__setitem__` but got `Optional[Union[bool, float, str]]`.
# pyre-fixme[6]: Expected `Union[float, str]` for 1st param but got `...
param_levels[param_name][param_value] = 1
param_levels[param_name][not_none(param_value)] = 1
param_cardinality = 1
for param_values in param_levels.values():
param_cardinality *= len(param_values)
Expand Down
108 changes: 56 additions & 52 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from copy import deepcopy
from enum import Enum
from math import inf
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import cast, Dict, List, Optional, Tuple, Type, Union
from warnings import warn

from ax.core.types import TParamValue, TParamValueList
from ax.core.types import TNumeric, TParamValue, TParamValueList
from ax.exceptions.core import AxWarning, UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.typeutils import not_none
from pyre_extensions import assert_is_instance

# Tolerance for floating point comparisons. This is relatively permissive,
# and allows for serializing at rather low numerical precision.
Expand Down Expand Up @@ -81,7 +82,7 @@ def _get_parameter_type(python_type: Type) -> ParameterType:
class Parameter(SortableBase, metaclass=ABCMeta):
_is_fidelity: bool = False
_name: str
_target_value: Optional[TParamValue] = None
_target_value: TParamValue = None
_parameter_type: ParameterType

def cast(self, value: TParamValue) -> TParamValue:
Expand Down Expand Up @@ -125,7 +126,7 @@ def is_hierarchical(self) -> bool:
)

@property
def target_value(self) -> Optional[TParamValue]:
def target_value(self) -> TParamValue:
return self._target_value

@property
Expand Down Expand Up @@ -234,7 +235,7 @@ def __init__(
logit_scale: bool = False,
digits: Optional[int] = None,
is_fidelity: bool = False,
target_value: Optional[TParamValue] = None,
target_value: TParamValue = None,
) -> None:
"""Initialize RangeParameter
Expand All @@ -259,17 +260,16 @@ def __init__(
)

self._name = name
if parameter_type not in (ParameterType.INT, ParameterType.FLOAT):
raise UserInputError("RangeParameter type must be int or float.")
self._parameter_type = parameter_type
self._digits = digits
# pyre-fixme[4]: Attribute must be annotated.
self._lower = self.cast(lower)
# pyre-fixme[4]: Attribute must be annotated.
self._upper = self.cast(upper)
self._lower: TNumeric = not_none(self.cast(lower))
self._upper: TNumeric = not_none(self.cast(upper))
self._log_scale = log_scale
self._logit_scale = logit_scale
self._is_fidelity = is_fidelity
# pyre-fixme[4]: Attribute must be annotated.
self._target_value = self.cast(target_value)
self._target_value: Optional[TNumeric] = self.cast(target_value)

self._validate_range_param(
parameter_type=parameter_type,
Expand All @@ -279,16 +279,16 @@ def __init__(
logit_scale=logit_scale,
)

def cardinality(self) -> float:
def cardinality(self) -> TNumeric:
if self.parameter_type == ParameterType.FLOAT:
return inf

return self.upper - self.lower + 1
return int(self.upper) - int(self.lower) + 1

def _validate_range_param(
self,
lower: TParamValue,
upper: TParamValue,
lower: TNumeric,
upper: TNumeric,
log_scale: bool,
logit_scale: bool,
parameter_type: Optional[ParameterType] = None,
Expand All @@ -298,15 +298,13 @@ def _validate_range_param(
ParameterType.FLOAT,
):
raise UserInputError("RangeParameter type must be int or float.")
# pyre-fixme[58]: `>=` is not supported for operand types `Union[None, bool,
# float, int, str]` and `Union[None, bool, float, int, str]`.

upper = float(upper)
if lower >= upper:
raise UserInputError(
f"Upper bound of {self.name} must be strictly larger than lower."
f"Got: ({lower}, {upper})."
)
# pyre-fixme[58]: `-` is not supported for operand types `Union[None, bool,
# float, int, str]` and `Union[None, bool, float, int, str]`.
width: float = upper - lower
if width < 100 * EPS:
raise UserInputError(
Expand All @@ -316,12 +314,8 @@ def _validate_range_param(
)
if log_scale and logit_scale:
raise UserInputError("Can't use both log and logit.")
# pyre-fixme[58]: `<=` is not supported for operand types `Union[None, bool,
# float, int, str]` and `int`.
if log_scale and lower <= 0:
raise UserInputError("Cannot take log when min <= 0.")
# pyre-fixme[58]: `<=` is not supported for operand types `Union[None, bool,
# float, int, str]` and `int`.
if logit_scale and (lower <= 0 or upper >= 1):
raise UserInputError("Logit requires lower > 0 and upper < 1")
if not (self.is_valid_type(lower)) or not (self.is_valid_type(upper)):
Expand All @@ -330,23 +324,43 @@ def _validate_range_param(
)

@property
def upper(self) -> float:
def upper(self) -> TNumeric:
"""Upper bound of the parameter range.
Value is cast to parameter type upon set and also validated
to ensure the bound is strictly greater than lower bound.
"""
return self._upper

@upper.setter
def upper(self, value: TNumeric) -> None:
self._validate_range_param(
lower=self.lower,
upper=value,
log_scale=self.log_scale,
logit_scale=self.logit_scale,
)
self._upper = not_none(self.cast(value))

@property
def lower(self) -> float:
def lower(self) -> TNumeric:
"""Lower bound of the parameter range.
Value is cast to parameter type upon set and also validated
to ensure the bound is strictly less than upper bound.
"""
return self._lower

@lower.setter
def lower(self, value: TNumeric) -> None:
self._validate_range_param(
lower=value,
upper=self.upper,
log_scale=self.log_scale,
logit_scale=self.logit_scale,
)
self._lower = not_none(self.cast(value))

@property
def digits(self) -> Optional[int]:
"""Number of digits to round values to for float type.
Expand Down Expand Up @@ -381,8 +395,8 @@ def update_range(
if upper is None:
upper = self._upper

cast_lower = self.cast(lower)
cast_upper = self.cast(upper)
cast_lower = not_none(self.cast(lower))
cast_upper = not_none(self.cast(upper))
self._validate_range_param(
lower=cast_lower,
upper=cast_upper,
Expand All @@ -397,10 +411,8 @@ def set_digits(self, digits: int) -> RangeParameter:
self._digits = digits

# Re-scale min and max to new digits definition
cast_lower = self.cast(self._lower)
cast_upper = self.cast(self._upper)
# pyre-fixme[58]: `>=` is not supported for operand types `Union[None, bool,
# float, int, str]` and `Union[None, bool, float, int, str]`.
cast_lower = not_none(self.cast(self._lower))
cast_upper = not_none(self.cast(self._upper))
if cast_lower >= cast_upper:
raise UserInputError(
f"Lower bound {cast_lower} is >= upper bound {cast_upper}."
Expand Down Expand Up @@ -451,9 +463,7 @@ def is_valid_type(self, value: TParamValue) -> bool:

# This might have issues with ints > 2^24
if self.parameter_type is ParameterType.INT:
# pyre-fixme[6]: Expected `Union[_SupportsIndex, bytearray, bytes, str,
# typing.SupportsFloat]` for 1st param but got `Union[None, float, str]`.
return isinstance(value, int) or float(value).is_integer()
return isinstance(value, int) or float(not_none(value)).is_integer()
return True

def clone(self) -> RangeParameter:
Expand All @@ -469,13 +479,12 @@ def clone(self) -> RangeParameter:
target_value=self._target_value,
)

def cast(self, value: TParamValue) -> TParamValue:
def cast(self, value: TParamValue) -> Optional[TNumeric]:
if value is None:
return None
if self.parameter_type is ParameterType.FLOAT and self._digits is not None:
# pyre-fixme[6]: Expected `None` for 2nd param but got `Optional[int]`.
return round(float(value), self._digits)
return self.python_type(value)
return round(float(value), not_none(self._digits))
return assert_is_instance(self.python_type(value), TNumeric)

def __repr__(self) -> str:
ret_val = self._base_repr()
Expand Down Expand Up @@ -526,7 +535,7 @@ def __init__(
is_ordered: Optional[bool] = None,
is_task: bool = False,
is_fidelity: bool = False,
target_value: Optional[TParamValue] = None,
target_value: TParamValue = None,
sort_values: Optional[bool] = None,
dependents: Optional[Dict[TParamValue, List[str]]] = None,
) -> None:
Expand Down Expand Up @@ -561,9 +570,8 @@ def __init__(
stacklevel=2,
)
values = list(dict_values)
self._values: List[TParamValue] = self._cast_values(values)
# pyre-fixme[4]: Attribute must be annotated.
self._is_ordered = (

self._is_ordered: bool = (
is_ordered
if is_ordered is not None
else self._get_default_bool_and_warn(param_string="is_ordered")
Expand All @@ -575,11 +583,9 @@ def __init__(
else self._get_default_bool_and_warn(param_string="sort_values")
)
if self.sort_values:
# pyre-ignore[6]: values/self._values expects List[Union[None, bool, float,
# int, str]] but sorted() takes/returns
# List[Variable[_typeshed.SupportsLessThanT (bound to
# _typeshed.SupportsLessThan)]]
self._values = self._cast_values(sorted(values))
values = cast(List[TParamValue], sorted([not_none(v) for v in values]))
self._values: List[TParamValue] = self._cast_values(values)

if dependents:
for value in dependents:
if value not in self.values:
Expand Down Expand Up @@ -714,7 +720,7 @@ def __init__(
parameter_type: ParameterType,
value: TParamValue,
is_fidelity: bool = False,
target_value: Optional[TParamValue] = None,
target_value: TParamValue = None,
dependents: Optional[Dict[TParamValue, List[str]]] = None,
) -> None:
"""Initialize FixedParameter
Expand All @@ -737,11 +743,9 @@ def __init__(

self._name = name
self._parameter_type = parameter_type
# pyre-fixme[4]: Attribute must be annotated.
self._value = self.cast(value)
self._value: TParamValue = self.cast(value)
self._is_fidelity = is_fidelity
# pyre-fixme[4]: Attribute must be annotated.
self._target_value = self.cast(target_value)
self._target_value: TParamValue = self.cast(target_value)
# NOTE: We don't need to check that dependent parameters actually exist as
# that is done in `HierarchicalSearchSpace` constructor.
if dependents:
Expand Down
3 changes: 1 addition & 2 deletions ax/core/parameter_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ class ParameterConstraint(SortableBase):
Constraints are expressed using a map from parameter name to weight
followed by a bound.
The constraint is satisfied if w * v <= b where:
The constraint is satisfied if sum_i(w_i * v_i) <= b where:
w is the vector of parameter weights.
v is a vector of parameter values.
b is the specified bound.
* is the dot product operator.
"""

def __init__(self, constraint_dict: Dict[str, float], bound: float) -> None:
Expand Down
6 changes: 3 additions & 3 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,7 @@ def check_membership(

# parameter constraints only accept numeric parameters
numerical_param_dict = {
# pyre-fixme[6]: Expected `typing.Union[...oat]` but got `unknown`.
name: float(value)
name: float(not_none(value))
for name, value in parameterization.items()
if self.parameters[name].is_numeric
}
Expand Down Expand Up @@ -544,7 +543,8 @@ def flatten_observation_features(
# that behavior was requested via the opt-in flag.
warnings.warn(
f"Cannot flatten observation features {obs_feats} as full "
"parameterization is not recorded in metadata."
"parameterization is not recorded in metadata.",
stacklevel=2,
)
return obs_feats

Expand Down
4 changes: 3 additions & 1 deletion ax/core/tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

# pyre-strict

from typing import cast, List

from ax.core.parameter import (
_get_parameter_type,
ChoiceParameter,
Expand Down Expand Up @@ -299,7 +301,7 @@ def test_Properties(self) -> None:
)
self.assertTrue(int_param.is_ordered)
self.assertListEqual(
int_param.values, sorted(int_param.values) # pyre-fixme[6]
int_param.values, sorted(cast(List[int], int_param.values))
)
float_param = ChoiceParameter(
name="x", parameter_type=ParameterType.FLOAT, values=[1.5, 2.5, 3.5]
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _suggest_gp_model(
if parameter.parameter_type == ParameterType.FLOAT:
all_range_parameters_are_discrete = False
else:
num_param_discrete_values = int(parameter.upper - parameter.lower) + 1
num_param_discrete_values = parameter.cardinality()
num_possible_points *= num_param_discrete_values

if should_enumerate_param:
Expand Down
13 changes: 8 additions & 5 deletions ax/modelbridge/transforms/int_range_to_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

# pyre-strict

from typing import Dict, List, Optional, Set, TYPE_CHECKING
from numbers import Real
from typing import cast, Dict, List, Optional, Set, TYPE_CHECKING

from ax.core.observation import Observation
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
Expand Down Expand Up @@ -35,14 +36,16 @@ def __init__(
) -> None:
assert search_space is not None, "IntRangeToChoice requires search space"
config = config or {}
self.max_choices: float = config.get("max_choices", float("inf")) # pyre-ignore
self.max_choices: float = float(
cast(Real, (config.get("max_choices", float("inf"))))
)
# Identify parameters that should be transformed
self.transform_parameters: Set[str] = {
p_name
for p_name, p in search_space.parameters.items()
if isinstance(p, RangeParameter)
and p.parameter_type == ParameterType.INT
and p.upper - p.lower + 1 <= self.max_choices
and p.cardinality() <= self.max_choices
}

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
Expand All @@ -52,9 +55,9 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
p_name in self.transform_parameters
and isinstance(p, RangeParameter)
and p.parameter_type == ParameterType.INT
and p.upper - p.lower + 1 <= self.max_choices
and p.cardinality() <= self.max_choices
):
values = list(range(p.lower, p.upper + 1)) # pyre-ignore
values = list(range(int(p.lower), int(p.upper) + 1))
target_value = (
None
if p.target_value is None
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/transforms/int_to_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
for p_name, p in self.search_space.parameters.items()
if isinstance(p, RangeParameter)
and p.parameter_type == ParameterType.INT
and (p.upper - p.lower + 1 >= self.min_choices or p.log_scale)
and ((p.cardinality() >= self.min_choices) or p.log_scale)
}
if contains_constrained_integer(self.search_space, self.transform_parameters):
self.rounding = "randomized"
Expand Down
Loading

0 comments on commit 22d132e

Please sign in to comment.