Skip to content

Commit

Permalink
Fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Aug 30, 2024
1 parent dc44dcc commit 6d2d2a8
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 10 deletions.
20 changes: 18 additions & 2 deletions baybe/campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from attrs.converters import optional
from attrs.validators import instance_of

from baybe.exceptions import IncompatibilityError
from baybe.objectives.base import Objective, to_objective
from baybe.parameters.base import Parameter
from baybe.recommenders.base import RecommenderProtocol
Expand All @@ -24,7 +25,7 @@
validate_searchspace_from_config,
)
from baybe.serialization import SerialMixin, converter
from baybe.surrogates.base import Surrogate
from baybe.surrogates.base import SurrogateProtocol
from baybe.targets.base import Target
from baybe.telemetry import (
TELEM_LABELS,
Expand Down Expand Up @@ -278,14 +279,23 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior:
candidates: The candidate points in experimental recommendations.
For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`.
Raises:
IncompatibilityError: If the underlying surrogate model exposes no
method for computing the posterior distribution.
Returns:
Posterior: The corresponding posterior object.
For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`.
"""
surrogate = self.get_surrogate()
if not hasattr(surrogate, method_name := "posterior"):
raise IncompatibilityError(
f"The used surrogate type '{surrogate.__class__.__name__}' does not "
f"provide a '{method_name}' method."
)
return surrogate.posterior(candidates)

def get_surrogate(self) -> Surrogate:
def get_surrogate(self) -> SurrogateProtocol:
"""Get the current surrogate model.
Raises:
Expand All @@ -294,6 +304,12 @@ def get_surrogate(self) -> Surrogate:
Returns:
Surrogate: The surrogate of the current recommender.
"""
if self.objective is None:
raise IncompatibilityError(
f"No surrogate is available since no '{Objective.__name__}' is defined."
)

pure_recommender: RecommenderProtocol
if isinstance(self.recommender, MetaRecommender):
pure_recommender = self.recommender.get_current_recommender()
else:
Expand Down
20 changes: 13 additions & 7 deletions baybe/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ class UnusedObjectWarning(UserWarning):


##### Exceptions #####


class IncompatibilityError(Exception):
"""Incompatible components are used together."""


class IncompatibleSearchSpaceError(IncompatibilityError):
"""
A recommender is used with a search space that contains incompatible parts,
e.g. a discrete recommender is used with a hybrid or continuous search space.
"""


class NotEnoughPointsLeftError(Exception):
"""
More recommendations are requested than there are viable parameter configurations
Expand All @@ -24,13 +37,6 @@ class NoMCAcquisitionFunctionError(Exception):
"""


class IncompatibleSearchSpaceError(Exception):
"""
A recommender is used with a search space that contains incompatible parts,
e.g. a discrete recommender is used with a hybrid or continuous search space.
"""


class EmptySearchSpaceError(Exception):
"""The created search space contains no parameters."""

Expand Down
3 changes: 3 additions & 0 deletions baybe/surrogates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class _NoTransform(Enum):
class SurrogateProtocol(Protocol):
"""Type protocol specifying the interface surrogate models need to implement."""

# TODO: Final layout still to be optimized. For example, shall we require a
# `posterior` method?

def fit(
self,
searchspace: SearchSpace,
Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ignore_missing_imports = True
[mypy-gpytorch.*]
ignore_missing_imports = True

[mypy-joblib]
[mypy-joblib.*]
ignore_missing_imports = True

[mypy-mordred]
Expand Down

0 comments on commit 6d2d2a8

Please sign in to comment.