Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing in SyntheticFunctions #2470

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 30 additions & 52 deletions ax/utils/measurement/synthetic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, TypeVar, Union

import numpy as np
import torch
Expand All @@ -16,40 +16,26 @@
from botorch.test_functions import synthetic as botorch_synthetic
from pyre_extensions import override


# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def informative_failure_on_none(func: Callable) -> Callable:
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter must be annotated.
def function_wrapper(*args, **kwargs) -> Any:
res = func(*args, **kwargs)
if res is None:
raise NotImplementedError(
f"{args[0].name} does not specify property " f'"{func.__name__}".'
)
return not_none(res)

return function_wrapper
T = TypeVar("T")


class SyntheticFunction(ABC):

_required_dimensionality: Optional[int] = None
# pyre-fixme[4]: Attribute must be annotated.
_domain = None
# pyre-fixme[4]: Attribute must be annotated.
_minimums = None
# pyre-fixme[4]: Attribute must be annotated.
_maximums = None
# pyre-fixme[4]: Attribute must be annotated.
_fmin = None
# pyre-fixme[4]: Attribute must be annotated.
_fmax = None
_required_dimensionality: int
_domain: List[Tuple[float, float]]
_minimums: Optional[List[Tuple[float, ...]]] = None
_maximums: Optional[List[Tuple[float, ...]]] = None
_fmin: Optional[float] = None
_fmax: Optional[float] = None

def informative_failure_on_none(self, attr: Optional[T]) -> T:
if attr is None:
raise NotImplementedError(f"{self.name} does not specify property.")
return not_none(attr)

@property
@informative_failure_on_none
def name(self) -> str:
return f"{self.__class__.__name__}"
return self.__class__.__name__

def __call__(
self,
Expand Down Expand Up @@ -116,14 +102,12 @@ def f(self, X: np.ndarray) -> Union[float, np.ndarray]:
return np.array([self._f(X=x) for x in X])

@property
@informative_failure_on_none
def required_dimensionality(self) -> Optional[int]:
def required_dimensionality(self) -> int:
"""Required dimensionality of input to this function."""
return self._required_dimensionality

@property
@informative_failure_on_none
def domain(self) -> List[Tuple[float, ...]]:
def domain(self) -> List[Tuple[float, float]]:
"""Domain on which function is evaluated.

The list is of the same length as the dimensionality of the inputs,
Expand All @@ -133,36 +117,32 @@ def domain(self) -> List[Tuple[float, ...]]:
return self._domain

@property
@informative_failure_on_none
def minimums(self) -> List[Tuple[float, ...]]:
"""List of global minimums.

Each element of the list is a d-tuple, where d is the dimensionality
of the inputs. There may be more than one global minimums.
"""
return self._minimums
return self.informative_failure_on_none(self._minimums)

@property
@informative_failure_on_none
def maximums(self) -> List[Tuple[float, ...]]:
"""List of global minimums.

Each element of the list is a d-tuple, where d is the dimensionality
of the inputs. There may be more than one global minimums.
"""
return self._maximums
return self.informative_failure_on_none(self._maximums)

@property
@informative_failure_on_none
def fmin(self) -> float:
"""Value at global minimum(s)."""
return self._fmin
return self.informative_failure_on_none(self._fmin)

@property
@informative_failure_on_none
def fmax(self) -> float:
"""Value at global minimum(s)."""
return self._fmax
return self.informative_failure_on_none(self._fmax)

@abstractmethod
def _f(self, X: np.ndarray) -> float:
Expand All @@ -184,10 +164,8 @@ def __init__(
) -> None:
self._botorch_function = botorch_synthetic_function
self._required_dimensionality: int = self._botorch_function.dim
self._domain: Optional[List[Tuple[float, float]]] = (
self._botorch_function._bounds
)
self._fmin: float = self._botorch_function._optimal_value
self._domain: List[Tuple[float, float]] = self._botorch_function._bounds
self._fmin: Optional[float] = self._botorch_function._optimal_value

@override
@property
Expand All @@ -211,7 +189,7 @@ class Hartmann6(SyntheticFunction):
"""Hartmann6 function (6-dimensional with 1 global minimum)."""

_required_dimensionality = 6
_domain: List[Tuple[int, int]] = [(0, 1) for i in range(6)]
_domain: List[Tuple[float, float]] = [(0.0, 1.0) for i in range(6)]
_minimums = [(0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573)]
_fmin: float = -3.32237
_fmax = 0.0
Expand Down Expand Up @@ -249,7 +227,7 @@ class Aug_Hartmann6(Hartmann6):
"""Augmented Hartmann6 function (7-dimensional with 1 global minimum)."""

_required_dimensionality = 7
_domain: List[Tuple[int, int]] = [(0, 1) for i in range(7)]
_domain: List[Tuple[float, float]] = [(0.0, 1.0) for i in range(7)]
# pyre-fixme[15]: `_minimums` overrides attribute defined in `Hartmann6`
# inconsistently.
_minimums = [(0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573, 1.0)]
Expand All @@ -276,7 +254,7 @@ class Branin(SyntheticFunction):
"""Branin function (2-dimensional with 3 global minima)."""

_required_dimensionality = 2
_domain = [(-5, 10), (0, 15)]
_domain: List[Tuple[float, float]] = [(-5.0, 10.0), (0.0, 15.0)]
_minimums: List[Tuple[float, float]] = [
(-np.pi, 12.275),
(np.pi, 2.275),
Expand All @@ -301,11 +279,11 @@ class Aug_Branin(SyntheticFunction):
"""Augmented Branin function (3-dimensional with infinitely many global minima)."""

_required_dimensionality = 3
_domain = [(-5, 10), (0, 15), (0, 1)]
_minimums: List[Tuple[float, float, int]] = [
(-np.pi, 12.275, 1),
(np.pi, 2.275, 1),
(9.42478, 2.475, 1),
_domain: List[Tuple[float, float]] = [(-5.0, 10.0), (0.0, 15.0), (0.0, 1.0)]
_minimums: List[Tuple[float, float, float]] = [
(-np.pi, 12.275, 1.0),
(np.pi, 2.275, 1.0),
(9.42478, 2.475, 1.0),
]
_fmin = 0.397887
_fmax = 308.129
Expand Down