From 11e47436ac88de9447a4c3e9c2626118b7c19401 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Fri, 17 May 2024 15:53:36 -0700 Subject: [PATCH] Improve typing in SyntheticFunctions (#2470) Summary: A simple Pyre failure was blocking D57473132. I ended up spending way too much time trying out various options for improving typing in this file and settled on this option in the end. Updating the decorator to change the type from `Optional[T] -> T` is something I considered, but it requires the property to be typed as `Optional[T]`, which made me decide to remove the decorator altogether. Differential Revision: D57508193 --- ax/utils/measurement/synthetic_functions.py | 82 ++++++++------------- 1 file changed, 30 insertions(+), 52 deletions(-) diff --git a/ax/utils/measurement/synthetic_functions.py b/ax/utils/measurement/synthetic_functions.py index 893654ee88f..d7dcfb514cf 100644 --- a/ax/utils/measurement/synthetic_functions.py +++ b/ax/utils/measurement/synthetic_functions.py @@ -7,7 +7,7 @@ # pyre-strict from abc import ABC, abstractmethod -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, TypeVar, Union import numpy as np import torch @@ -16,40 +16,26 @@ from botorch.test_functions import synthetic as botorch_synthetic from pyre_extensions import override - -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. -def informative_failure_on_none(func: Callable) -> Callable: - # pyre-fixme[3]: Return annotation cannot be `Any`. - # pyre-fixme[2]: Parameter must be annotated. - def function_wrapper(*args, **kwargs) -> Any: - res = func(*args, **kwargs) - if res is None: - raise NotImplementedError( - f"{args[0].name} does not specify property " f'"{func.__name__}".' - ) - return not_none(res) - - return function_wrapper +T = TypeVar("T") class SyntheticFunction(ABC): - _required_dimensionality: Optional[int] = None - # pyre-fixme[4]: Attribute must be annotated. - _domain = None - # pyre-fixme[4]: Attribute must be annotated. - _minimums = None - # pyre-fixme[4]: Attribute must be annotated. - _maximums = None - # pyre-fixme[4]: Attribute must be annotated. - _fmin = None - # pyre-fixme[4]: Attribute must be annotated. - _fmax = None + _required_dimensionality: int + _domain: List[Tuple[float, float]] + _minimums: Optional[List[Tuple[float, ...]]] = None + _maximums: Optional[List[Tuple[float, ...]]] = None + _fmin: Optional[float] = None + _fmax: Optional[float] = None + + def informative_failure_on_none(self, attr: Optional[T]) -> T: + if attr is None: + raise NotImplementedError(f"{self.name} does not specify property.") + return not_none(attr) @property - @informative_failure_on_none def name(self) -> str: - return f"{self.__class__.__name__}" + return self.__class__.__name__ def __call__( self, @@ -116,14 +102,12 @@ def f(self, X: np.ndarray) -> Union[float, np.ndarray]: return np.array([self._f(X=x) for x in X]) @property - @informative_failure_on_none - def required_dimensionality(self) -> Optional[int]: + def required_dimensionality(self) -> int: """Required dimensionality of input to this function.""" return self._required_dimensionality @property - @informative_failure_on_none - def domain(self) -> List[Tuple[float, ...]]: + def domain(self) -> List[Tuple[float, float]]: """Domain on which function is evaluated. The list is of the same length as the dimensionality of the inputs, @@ -133,36 +117,32 @@ def domain(self) -> List[Tuple[float, ...]]: return self._domain @property - @informative_failure_on_none def minimums(self) -> List[Tuple[float, ...]]: """List of global minimums. Each element of the list is a d-tuple, where d is the dimensionality of the inputs. There may be more than one global minimums. """ - return self._minimums + return self.informative_failure_on_none(self._minimums) @property - @informative_failure_on_none def maximums(self) -> List[Tuple[float, ...]]: """List of global minimums. Each element of the list is a d-tuple, where d is the dimensionality of the inputs. There may be more than one global minimums. """ - return self._maximums + return self.informative_failure_on_none(self._maximums) @property - @informative_failure_on_none def fmin(self) -> float: """Value at global minimum(s).""" - return self._fmin + return self.informative_failure_on_none(self._fmin) @property - @informative_failure_on_none def fmax(self) -> float: """Value at global minimum(s).""" - return self._fmax + return self.informative_failure_on_none(self._fmax) @abstractmethod def _f(self, X: np.ndarray) -> float: @@ -184,10 +164,8 @@ def __init__( ) -> None: self._botorch_function = botorch_synthetic_function self._required_dimensionality: int = self._botorch_function.dim - self._domain: Optional[List[Tuple[float, float]]] = ( - self._botorch_function._bounds - ) - self._fmin: float = self._botorch_function._optimal_value + self._domain: List[Tuple[float, float]] = self._botorch_function._bounds + self._fmin: Optional[float] = self._botorch_function._optimal_value @override @property @@ -211,7 +189,7 @@ class Hartmann6(SyntheticFunction): """Hartmann6 function (6-dimensional with 1 global minimum).""" _required_dimensionality = 6 - _domain: List[Tuple[int, int]] = [(0, 1) for i in range(6)] + _domain: List[Tuple[float, float]] = [(0.0, 1.0) for i in range(6)] _minimums = [(0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573)] _fmin: float = -3.32237 _fmax = 0.0 @@ -249,7 +227,7 @@ class Aug_Hartmann6(Hartmann6): """Augmented Hartmann6 function (7-dimensional with 1 global minimum).""" _required_dimensionality = 7 - _domain: List[Tuple[int, int]] = [(0, 1) for i in range(7)] + _domain: List[Tuple[float, float]] = [(0.0, 1.0) for i in range(7)] # pyre-fixme[15]: `_minimums` overrides attribute defined in `Hartmann6` # inconsistently. _minimums = [(0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573, 1.0)] @@ -276,7 +254,7 @@ class Branin(SyntheticFunction): """Branin function (2-dimensional with 3 global minima).""" _required_dimensionality = 2 - _domain = [(-5, 10), (0, 15)] + _domain: List[Tuple[float, float]] = [(-5.0, 10.0), (0.0, 15.0)] _minimums: List[Tuple[float, float]] = [ (-np.pi, 12.275), (np.pi, 2.275), @@ -301,11 +279,11 @@ class Aug_Branin(SyntheticFunction): """Augmented Branin function (3-dimensional with infinitely many global minima).""" _required_dimensionality = 3 - _domain = [(-5, 10), (0, 15), (0, 1)] - _minimums: List[Tuple[float, float, int]] = [ - (-np.pi, 12.275, 1), - (np.pi, 2.275, 1), - (9.42478, 2.475, 1), + _domain: List[Tuple[float, float]] = [(-5.0, 10.0), (0.0, 15.0), (0.0, 1.0)] + _minimums: List[Tuple[float, float, float]] = [ + (-np.pi, 12.275, 1.0), + (np.pi, 2.275, 1.0), + (9.42478, 2.475, 1.0), ] _fmin = 0.397887 _fmax = 308.129