
Enable Workflow.transform to be run with a DataFrame type (#1777)
* Enable `Workflow.transform` to be run with a DataFrame type

* Update typehint for transform arguments

* Use DataFrameType in Workflow typehints matching existing code style

* Use singledispatch to split up the transform method by type (see the sketch after this list)

* Update docstring of `Workflow.transform` to match new types
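
For readers unfamiliar with the pattern named in the commit message, here is a minimal, self-contained sketch of how `functools.singledispatchmethod` dispatches a method on the runtime type of its first non-`self` argument. This is a generic illustration, not NVTabular code; the `Formatter` class and its methods are hypothetical:

from functools import singledispatchmethod


class Formatter:
    @singledispatchmethod
    def render(self, value):
        # Fallback for unregistered types, analogous to the
        # NotImplementedError fallback added to Workflow.transform.
        raise NotImplementedError(f"Unsupported type: {type(value)}")

    @render.register
    def _(self, value: int) -> str:
        # Dispatch is resolved from the type annotation on the
        # first non-self argument.
        return f"int: {value}"

    @render.register
    def _(self, value: str) -> str:
        return f"str: {value!r}"


formatter = Formatter()
formatter.render(3)    # -> "int: 3"
formatter.render("a")  # -> "str: 'a'"
formatter.render(1.5)  # raises NotImplementedError

The commit applies the same structure to `Workflow.transform`: an unregistered-type fallback that raises, plus one registered overload per supported input type.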
oliverholworthy authored Mar 21, 2023
1 parent f726194 commit f14a6a2
Showing 2 changed files with 49 additions and 10 deletions.
50 changes: 40 additions & 10 deletions nvtabular/workflow/workflow.py
@@ -21,6 +21,7 @@
 import time
 import types
 import warnings
+from functools import singledispatchmethod
 from typing import TYPE_CHECKING, Optional
 
 import cloudpickle
@@ -78,26 +79,55 @@ def __init__(self, output_node: WorkflowNode, client: Optional["distributed.Client"] = None):
         self.graph = Graph(output_node)
         self.executor = DaskExecutor(client)
 
-    def transform(self, dataset: Dataset) -> Dataset:
-        """Transforms the dataset by applying the graph of operators to it. Requires the ``fit``
-        method to have already been called, or calculated statistics to be loaded from disk
-
-        This method returns a Dataset object, with the transformations lazily loaded. None
-        of the actual computation will happen until the produced Dataset is consumed, or
-        written out to disk.
+    @singledispatchmethod
+    def transform(self, data):
+        """Transforms the data by applying the graph of operators to it.
+
+        Requires the ``fit`` method to have already been called, or a
+        Workflow that has already been fit and re-loaded from disk
+        (using the ``load`` method).
+
+        This method returns data of the same type. In the case of a
+        `Dataset`, the computation is lazy: it won't happen until the
+        produced Dataset is consumed or written out to disk, e.g. with
+        `dataset.compute()`.
 
         Parameters
         -----------
-        dataset: Dataset
-            Input dataset to transform
+        data: Union[Dataset, DataFrameType]
+            Input Dataset or DataFrame to transform
 
         Returns
         -------
-        Dataset
-            Transformed Dataset with the workflow graph applied to it
+        Dataset or DataFrame
+            Transformed Dataset or DataFrame with the workflow graph applied to it
+
+        Raises
+        ------
+        NotImplementedError
+            If passed an unsupported data type to transform.
         """
+        raise NotImplementedError(
+            f"Workflow.transform received an unsupported type: {type(data)} "
+            "Supported types are a `merlin.io.Dataset` or DataFrame (pandas or cudf)"
+        )
+
+    @transform.register
+    def _(self, dataset: Dataset) -> Dataset:
+        return self._transform_impl(dataset)
+
+    @transform.register
+    def _(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+        return self._transform_df(dataframe)
+
+    if cudf:
+
+        @transform.register
+        def _(self, dataframe: cudf.DataFrame) -> cudf.DataFrame:
+            return self._transform_df(dataframe)
 
     def fit_schema(self, input_schema: Schema):
         """Computes input and output schemas for each node in the Workflow graph
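
As a usage sketch of the new behavior (an illustration assuming an installed nvtabular, modeled on the unit test added below, not code from this commit): passing a pandas or cudf DataFrame to `transform` now returns a transformed DataFrame eagerly, while passing a `merlin.io.Dataset` still returns a lazy `Dataset`.

import pandas as pd

import nvtabular as nvt

df = pd.DataFrame({"user_session": ["1", "2", "4", "4", "5"]})
ops = ["user_session"] >> nvt.ops.Categorify()
workflow = nvt.Workflow(ops)
workflow.fit(nvt.Dataset(df))

# DataFrame in -> DataFrame out, computed eagerly
transformed_df = workflow.transform(df)

# Dataset in -> Dataset out; lazy until consumed,
# e.g. via compute() as noted in the docstring
transformed_ds = workflow.transform(nvt.Dataset(df))
result = transformed_ds.compute()

Note the `if cudf:` guard in the diff above: the cudf overload is registered only when cudf is importable, so the same module still loads on CPU-only installs where only the pandas path is available.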
9 changes: 9 additions & 0 deletions tests/unit/workflow/test_workflow.py
@@ -53,6 +53,15 @@ def test_workflow_double_fit():
     workflow.transform(df_event).to_ddf().compute()
 
 
+def test_workflow_transform_df():
+    df = make_df({"user_session": ["1", "2", "4", "4", "5"]})
+    ops = ["user_session"] >> nvt.ops.Categorify()
+    dataset = nvt.Dataset(df)
+    workflow = nvt.Workflow(ops)
+    workflow.fit(dataset)
+    assert isinstance(workflow.transform(df), type(df))
+
+
 @pytest.mark.parametrize("engine", ["parquet"])
 def test_workflow_fit_op_rename(tmpdir, dataset, engine):
     # NVT
