diff --git a/src/autora/experimentalist/prediction_filter/__init__.py b/src/autora/experimentalist/prediction_filter/__init__.py index 07cf484..1c85da4 100644 --- a/src/autora/experimentalist/prediction_filter/__init__.py +++ b/src/autora/experimentalist/prediction_filter/__init__.py @@ -1,19 +1,18 @@ """ Example Experimentalist """ +from typing import Callable + import numpy as np import pandas as pd - from sklearn.base import BaseEstimator -from typing import Union, List, Callable - def filter( - conditions: pd.DataFrame, - model: BaseEstimator, - filter_function: Callable, - reset_index: bool = True + conditions: pd.DataFrame, + model: BaseEstimator, + filter_function: Callable, + reset_index: bool = True, ) -> pd.DataFrame: """ Filter conditions based on the expected outcome io the mdeol @@ -47,20 +46,19 @@ def filter( def __filter(x): y = model.predict(np.array(x)) - if hasattr(y, 'shape') and y.shape == np.array([1]): + if hasattr(y, "shape") and y.shape == np.array([1]): y = y[0] _bool = filter_function(y) return _bool - new_conditions['__prediction'] = \ - new_conditions.apply(__filter, axis=1) + new_conditions["__prediction"] = new_conditions.apply(__filter, axis=1) - _c = new_conditions[new_conditions['__prediction']] - _c = _c.drop(columns=['__prediction']) + _c = new_conditions[new_conditions["__prediction"]] + _c = _c.drop(columns=["__prediction"]) if reset_index: _c.reset_index(drop=True, inplace=True) return _c -prediction_filter = filter \ No newline at end of file +prediction_filter = filter