Skip to content

Commit

Permalink
Improve workaround and also apply it for surrogate model renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Sep 6, 2024
1 parent e465e86 commit 97c800c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 20 deletions.
25 changes: 23 additions & 2 deletions baybe/recommenders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import cattrs
import pandas as pd
from cattrs import override

from baybe.objectives.base import Objective
from baybe.searchspace import SearchSpace
Expand Down Expand Up @@ -51,15 +52,35 @@ def recommend(
...


# TODO: The workarounds below are currently required since the hooks created through
# `unstructure_base` and `get_base_structure_hook` do not reuse the hooks of the
# actual class, hence we cannot control things there. Fix is already planned and also
# needed for other reasons.

# Register (un-)structure hooks
converter.register_unstructure_hook(
RecommenderProtocol,
lambda x: unstructure_base(
x,
# TODO: Remove once deprecation got expired:
overrides=dict(acquisition_function_cls=cattrs.override(omit=True)),
overrides=dict(
acquisition_function_cls=cattrs.override(omit=True),
# Temporary workaround (see TODO note above)
_surrogate_model=override(rename="surrogate_model"),
_current_recommender=override(omit=False),
_current_recommender_was_used=override(omit=False),
),
),
)
converter.register_structure_hook(
RecommenderProtocol, get_base_structure_hook(RecommenderProtocol)
RecommenderProtocol,
get_base_structure_hook(
RecommenderProtocol,
# Temporary workaround (see TODO note above)
overrides=dict(
_surrogate_model=override(rename="surrogate_model"),
_current_recommender=override(omit=False),
_current_recommender_was_used=override(omit=False),
),
),
)
23 changes: 5 additions & 18 deletions baybe/recommenders/meta/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,11 @@
class MetaRecommender(SerialMixin, RecommenderProtocol, ABC):
"""Abstract base class for all meta recommenders."""

# TODO: The attributes should be `init=False` but this currently prevents them from
# being serialized. The reason is that setting `_cattrs_include_init_false=True`
# for this class has currently no effect when serializing it as
# a `RecommenderProtocol`, since the hook of the latter does not reuse the
# hook of the actual class. Fix is already planned and also needed for other
# reasons. Until that, as a workaround, we expose the attributes as "private"
# attributes.

_current_recommender: PureRecommender | None = field(
alias="_current_recommender", default=None, kw_only=True
)
"""The current recommender. (For internal use only!)"""

_current_recommender_was_used: bool = field(
alias="_current_recommender_was_used", default=False, kw_only=True
)
"""Flag indicating if the current recommender has already been used.
(For internal use only!)"""
_current_recommender: PureRecommender | None = field(default=None, init=False)
"""The current recommender."""

_current_recommender_was_used: bool = field(default=False, init=False)
"""Flag indicating if the current recommender has already been used."""

@abstractmethod
def select_recommender(
Expand Down

0 comments on commit 97c800c

Please sign in to comment.