diff --git a/psycop/common/model_training_v2/trainer/preprocessing/pipeline.py b/psycop/common/model_training_v2/trainer/preprocessing/pipeline.py index aa2e983a5..9456f5484 100644 --- a/psycop/common/model_training_v2/trainer/preprocessing/pipeline.py +++ b/psycop/common/model_training_v2/trainer/preprocessing/pipeline.py @@ -21,7 +21,7 @@ def apply(self, data: PolarsFrame) -> pd.DataFrame: @BaselineRegistry.preprocessing.register("baseline_preprocessing_pipeline") -class BaselinePreprocessingPipeline(PreprocessingPipeline): +class BaselinePreprocessingPipeline(PreprocessingPipeline): def __init__(self, *args: PresplitStep) -> None: self.steps = list(args) diff --git a/psycop/common/model_training_v2/trainer/preprocessing/steps/row_filters.py b/psycop/common/model_training_v2/trainer/preprocessing/steps/row_filters.py index 3d6e7fcad..beb7f22fc 100644 --- a/psycop/common/model_training_v2/trainer/preprocessing/steps/row_filters.py +++ b/psycop/common/model_training_v2/trainer/preprocessing/steps/row_filters.py @@ -9,7 +9,12 @@ @BaselineRegistry.preprocessing.register("age_filter") class AgeFilter(PresplitStep): - def __init__(self, min_age: int, max_age: int = 999, age_col_name: str = "pred_age"): + def __init__( + self, + min_age: int, + max_age: int = 999, + age_col_name: str = "pred_age", + ): self.min_age = min_age self.max_age = max_age self.age = pl.col(age_col_name)