diff --git a/examples/sample_pipeline/pipeline.py b/examples/sample_pipeline/pipeline.py index f984c129..25648e2d 100644 --- a/examples/sample_pipeline/pipeline.py +++ b/examples/sample_pipeline/pipeline.py @@ -50,8 +50,11 @@ ], ) class CalculateChunkLength(PandasTransformComponent): + def __init__(self, feature_name: str, **kwargs): + self.feature_name = feature_name + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: - dataframe["chunk_length"] = dataframe["text"].apply(len) + dataframe[self.feature_name] = dataframe["text"].apply(len) return dataframe @@ -59,4 +62,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: ref=CalculateChunkLength, consumes={"text": pa.string()}, produces={"chunk_length": pa.int32()}, + arguments={"feature_name": "chunk_length"}, )