From 4f57ab9009e73aa06a538ec4eafff90443e35219 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 17 May 2023 17:48:54 -0700 Subject: [PATCH] pyflyte run supports pickle (#1646) Signed-off-by: Kevin Su --- flytekit/clis/sdk_in_container/run.py | 22 +++++++++++++++++++-- tests/flytekit/unit/cli/pyflyte/test_run.py | 2 ++ tests/flytekit/unit/cli/pyflyte/workflow.py | 6 ++++-- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index b56f67c605..336ffbdad6 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import cast +import cloudpickle import rich_click as click import yaml from dataclasses_json import DataClassJsonMixin @@ -33,7 +34,7 @@ from flytekit.configuration.default_images import DefaultImages from flytekit.core import context_manager from flytekit.core.base_task import PythonTask -from flytekit.core.context_manager import FlyteContext +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase @@ -45,6 +46,7 @@ from flytekit.tools import module_loader, script_mode from flytekit.tools.script_mode import _find_project_root from flytekit.tools.translator import Options +from flytekit.types.pickle.pickle import FlytePickleTransformer REMOTE_FLAG_KEY = "remote" RUN_LEVEL_PARAMS_KEY = "run_level_params" @@ -103,6 +105,19 @@ def convert( raise click.BadParameter(f"parameter should be a valid file path, {value}") +class PickleParamType(click.ParamType): + name = "pickle" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + + uri = FlyteContextManager.current_context().file_access.get_random_local_path() + with open(uri, "w+b") as outfile: + cloudpickle.dump(value, outfile) + return FileParam(filepath=str(pathlib.Path(uri).resolve())) + + class DateTimeType(click.DateTime): _NOW_FMT = "now" @@ -227,7 +242,10 @@ def __init__( if self._literal_type.blob: if self._literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE: - self._click_type = FileParamType() + if self._literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT: + self._click_type = PickleParamType() + else: + self._click_type = FileParamType() else: self._click_type = DirParamType() diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 3cfef4aa0f..e33aeebe56 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -117,6 +117,8 @@ def test_pyflyte_run_cli(): json.dumps([{"x": parquet_file}]), "--o", json.dumps({"x": [parquet_file]}), + "--p", + "Any", ], catch_exceptions=False, ) diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 01621a6a01..311f141a22 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -58,8 +58,9 @@ def print_all( m: dict, n: typing.List[typing.Dict[str, FlyteFile]], o: typing.Dict[str, typing.List[FlyteFile]], + p: typing.Any, ): - print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}") + print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o} , {p}") @task @@ -88,6 +89,7 @@ def my_wf( l: dict, n: typing.List[typing.Dict[str, FlyteFile]], o: typing.Dict[str, typing.List[FlyteFile]], + p: typing.Any, remote: pd.DataFrame, image: StructuredDataset, m: dict = {"hello": "world"}, @@ -95,5 +97,5 @@ def my_wf( x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks show_sd(in_sd=x) show_sd(in_sd=image) - print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o) + print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p) return x