diff --git a/trieste/ask_tell_optimization.py b/trieste/ask_tell_optimization.py index aa809f3abd..e4ed1bc0b8 100644 --- a/trieste/ask_tell_optimization.py +++ b/trieste/ask_tell_optimization.py @@ -192,7 +192,10 @@ def __init__( if not isinstance(models, Mapping): models = {OBJECTIVE: models} + self._filtered_datasets = datasets + # reassure the type checker that everything is tagged + datasets = cast(Dict[Tag, Dataset], datasets) models = cast(Dict[Tag, TrainableProbabilisticModelType], models) # Get set of dataset and model keys, ignoring any local tag index. That is, only the @@ -206,7 +209,6 @@ def __init__( ) self._datasets = datasets - self._filtered_datasets = datasets self._models = models self._query_plot_dfs: dict[int, pd.DataFrame] = {}