From f14a6a2b5aea949831feeb47bdc18a02c28564df Mon Sep 17 00:00:00 2001
From: Oliver Holworthy
Date: Tue, 21 Mar 2023 18:13:46 +0000
Subject: [PATCH] Enable `Workflow.transform` to be run with a DataFrame type
 (#1777)

* Enable `Workflow.transform` to be run with a DataFrame type

* Update typehint for transform arguments

* Use DataFrameType in Workflow typehints matching existing code style

* Use singledispatch to split up transform method by type

* Update docstring of `Workflow.transform` to match new types
---
 nvtabular/workflow/workflow.py       | 50 ++++++++++++++++++++++------
 tests/unit/workflow/test_workflow.py |  9 +++++
 2 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/nvtabular/workflow/workflow.py b/nvtabular/workflow/workflow.py
index ac4d701e0b8..78274181ed7 100755
--- a/nvtabular/workflow/workflow.py
+++ b/nvtabular/workflow/workflow.py
@@ -21,6 +21,7 @@
 import time
 import types
 import warnings
+from functools import singledispatchmethod
 from typing import TYPE_CHECKING, Optional
 
 import cloudpickle
@@ -78,26 +79,55 @@ def __init__(self, output_node: WorkflowNode, client: Optional["distributed.Clie
         self.graph = Graph(output_node)
         self.executor = DaskExecutor(client)
 
-    def transform(self, dataset: Dataset) -> Dataset:
-        """Transforms the dataset by applying the graph of operators to it. Requires the ``fit``
-        method to have already been called, or calculated statistics to be loaded from disk
+    @singledispatchmethod
+    def transform(self, data):
+        """Transforms the data by applying the graph of operators to it.
 
-        This method returns a Dataset object, with the transformations lazily loaded. None
-        of the actual computation will happen until the produced Dataset is consumed, or
-        written out to disk.
+        Requires the ``fit`` method to have already been called, or
+        using a Workflow that has already been fit and re-loaded from
+        disk (using the ``load`` method).
+
+        This method returns data of the same type.
+
+        In the case of a `Dataset`, the computation is lazy: it won't
+        happen until the produced Dataset is consumed or written out
+        to disk, e.g. with ``dataset.compute()``.
 
         Parameters
         -----------
-        dataset: Dataset
-            Input dataset to transform
+        data: Union[Dataset, DataFrameType]
+            Input Dataset or DataFrame to transform
 
         Returns
         -------
-        Dataset
-            Transformed Dataset with the workflow graph applied to it
+        Dataset or DataFrame
+            Transformed Dataset or DataFrame with the workflow graph applied to it
+
+        Raises
+        ------
+        NotImplementedError
+            If passed an unsupported data type to transform.
+        """
+        raise NotImplementedError(
+            f"Workflow.transform received an unsupported type: {type(data)} "
+            "Supported types are a `merlin.io.Dataset` or DataFrame (pandas or cudf)"
+        )
+
+    @transform.register
+    def _(self, dataset: Dataset) -> Dataset:
         return self._transform_impl(dataset)
 
+    @transform.register
+    def _(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+        return self._transform_df(dataframe)
+
+    if cudf:
+
+        @transform.register
+        def _(self, dataframe: cudf.DataFrame) -> cudf.DataFrame:
+            return self._transform_df(dataframe)
+
     def fit_schema(self, input_schema: Schema):
         """Computes input and output schemas for each node in the Workflow graph
diff --git a/tests/unit/workflow/test_workflow.py b/tests/unit/workflow/test_workflow.py
index ebb4cce590e..28df6142748 100755
--- a/tests/unit/workflow/test_workflow.py
+++ b/tests/unit/workflow/test_workflow.py
@@ -53,6 +53,15 @@ def test_workflow_double_fit():
         workflow.transform(df_event).to_ddf().compute()
 
 
+def test_workflow_transform_df():
+    df = make_df({"user_session": ["1", "2", "4", "4", "5"]})
+    ops = ["user_session"] >> nvt.ops.Categorify()
+    dataset = nvt.Dataset(df)
+    workflow = nvt.Workflow(ops)
+    workflow.fit(dataset)
+    assert isinstance(workflow.transform(df), type(df))
+
+
 @pytest.mark.parametrize("engine", ["parquet"])
 def test_workflow_fit_op_rename(tmpdir, dataset, engine):
     # NVT
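
For reference, a minimal usage sketch of the behavior this patch enables: after
``fit``, the same ``transform`` call accepts either a Dataset (returning a lazy
Dataset) or a DataFrame (returning an eagerly computed DataFrame of the same
type), while any other input falls through to the base dispatch method and
raises. This is an illustration, not code from the repository; it assumes
``make_df`` is importable from ``merlin.core.dispatch`` as in the test suite.

import nvtabular as nvt
from merlin.core.dispatch import make_df  # assumed import, as used by the tests

# Build and fit a workflow, mirroring test_workflow_transform_df above
df = make_df({"user_session": ["1", "2", "4", "4", "5"]})
workflow = nvt.Workflow(["user_session"] >> nvt.ops.Categorify())
workflow.fit(nvt.Dataset(df))

# Dataset in -> Dataset out: lazy, nothing runs until the result is consumed
transformed = workflow.transform(nvt.Dataset(df))
result = transformed.compute()

# DataFrame in -> DataFrame out (pandas or cudf), computed eagerly
result_df = workflow.transform(df)

# Any other input type falls through to the base singledispatch method
try:
    workflow.transform([1, 2, 3])
except NotImplementedError as exc:
    print(exc)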