Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
younesStrittmatter committed Jun 4, 2024
1 parent 7509b4d commit f4f5e46
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions src/autora/experimentalist/prediction_filter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
"""
Example Experimentalist
"""
from typing import Callable

import numpy as np
import pandas as pd

from sklearn.base import BaseEstimator

from typing import Union, List, Callable


def filter(
conditions: pd.DataFrame,
model: BaseEstimator,
filter_function: Callable,
reset_index: bool = True
conditions: pd.DataFrame,
model: BaseEstimator,
filter_function: Callable,
reset_index: bool = True,
) -> pd.DataFrame:
"""
Filter conditions based on the expected outcome io the mdeol
Expand Down Expand Up @@ -47,20 +46,19 @@ def filter(

def __filter(x):
y = model.predict(np.array(x))
if hasattr(y, 'shape') and y.shape == np.array([1]):
if hasattr(y, "shape") and y.shape == np.array([1]):
y = y[0]
_bool = filter_function(y)
return _bool

new_conditions['__prediction'] = \
new_conditions.apply(__filter, axis=1)
new_conditions["__prediction"] = new_conditions.apply(__filter, axis=1)

_c = new_conditions[new_conditions['__prediction']]
_c = _c.drop(columns=['__prediction'])
_c = new_conditions[new_conditions["__prediction"]]
_c = _c.drop(columns=["__prediction"])
if reset_index:
_c.reset_index(drop=True, inplace=True)

return _c


prediction_filter = filter
prediction_filter = filter

0 comments on commit f4f5e46

Please sign in to comment.