diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py b/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py index bce2ef2653..648ba1c6e0 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py @@ -11,4 +11,4 @@ record_outputs """ -from .task import NotebookTask, record_outputs +from .task import NotebookTask, load_flytedirectory, load_flytefile, load_structureddataset, record_outputs diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index b70dfcc910..b1f472e99a 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -2,24 +2,28 @@ import logging import os import sys +import tempfile import typing from typing import Any import nbformat import papermill as pm +from flyteidl.core.literals_pb2 import Literal as _pb2_Literal from flyteidl.core.literals_pb2 import LiteralMap as _pb2_LiteralMap from google.protobuf import text_format as _text_format from nbconvert import HTMLExporter -from flytekit import FlyteContext, PythonInstanceTask +from flytekit import FlyteContext, PythonInstanceTask, StructuredDataset from flytekit.configuration import SerializationSettings +from flytekit.core import utils from flytekit.core.context_manager import ExecutionParameters from flytekit.deck.deck import Deck from flytekit.extend import Interface, TaskPlugins, TypeEngine from flytekit.loggers import logger from flytekit.models import task as task_models -from flytekit.models.literals import LiteralMap -from flytekit.types.file import HTMLPage, PythonNotebook +from flytekit.models.literals import Literal, LiteralMap +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile, HTMLPage, PythonNotebook T = typing.TypeVar("T") @@ -28,6 +32,8 @@ def _dummy_task_func(): return None +SAVE_AS_LITERAL = (FlyteFile, FlyteDirectory, StructuredDataset) + PAPERMILL_TASK_PREFIX = "pm.nb" @@ -255,6 +261,10 @@ def execute(self, **kwargs) -> Any: singleton """ logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.") + for k, v in kwargs.items(): + if isinstance(v, SAVE_AS_LITERAL): + kwargs[k] = save_python_val_to_file(v) + # Execute Notebook via Papermill. pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore @@ -265,6 +275,7 @@ def execute(self, **kwargs) -> Any: if outputs: m = outputs.literals output_list = [] + for k, type_v in self.python_interface.outputs.items(): if k == self._IMPLICIT_OP_NOTEBOOK: output_list.append(self.output_notebook_path) @@ -274,7 +285,7 @@ def execute(self, **kwargs) -> Any: v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v) output_list.append(v) else: - raise RuntimeError(f"Expected output {k} of type {v} not found in the notebook outputs") + raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs") return tuple(output_list) @@ -307,3 +318,80 @@ def record_outputs(**kwargs) -> str: lit = TypeEngine.to_literal(ctx, python_type=type(v), python_val=v, expected=expected) m[k] = lit return LiteralMap(literals=m).to_flyte_idl() + + +def save_python_val_to_file(input: Any) -> str: + """Save a python value to a local file as a Flyte literal. + + Args: + input (Any): the python value + + Returns: + str: the path to the file + """ + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(type(input)) + lit = TypeEngine.to_literal(ctx, python_type=type(input), python_val=input, expected=expected) + + tmp_file = tempfile.mktemp(suffix="bin") + utils.write_proto_to_file(lit.to_flyte_idl(), tmp_file) + return tmp_file + + +def load_python_val_from_file(path: str, dtype: T) -> T: + """Loads a python value from a Flyte literal saved to a local file. + + If the path matches the type, it is returned as is. This enables + reusing the parameters cell for local development. + + Args: + path (str): path to the file + dtype (T): the type of the literal + + Returns: + T: the python value of the literal + """ + if isinstance(path, dtype): + return path + + proto = utils.load_proto_from_file(_pb2_Literal, path) + lit = Literal.from_flyte_idl(proto) + ctx = FlyteContext.current_context() + python_value = TypeEngine.to_python_value(ctx, lit, dtype) + return python_value + + +def load_flytefile(path: str) -> T: + """Loads a FlyteFile from a file. + + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=FlyteFile) + + +def load_flytedirectory(path: str) -> T: + """Loads a FlyteDirectory from a file. + + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=FlyteDirectory) + + +def load_structureddataset(path: str) -> T: + """Loads a StructuredDataset from a file. + + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=StructuredDataset) diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 1947d09445..0e54e7082e 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,14 +1,17 @@ import datetime import os +import tempfile +import pandas as pd from flytekitplugins.papermill import NotebookTask from flytekitplugins.pod import Pod from kubernetes.client import V1Container, V1PodSpec import flytekit -from flytekit import kwtypes +from flytekit import StructuredDataset, kwtypes, task from flytekit.configuration import Image, ImageConfig -from flytekit.types.file import PythonNotebook +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile, PythonNotebook from .testdata.datatype import X @@ -134,3 +137,38 @@ def test_notebook_pod_task(): nb.get_command(serialization_settings) == nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"] ) + + +def test_flyte_types(): + @task + def create_file() -> FlyteFile: + tmp_file = tempfile.mktemp() + with open(tmp_file, "w") as f: + f.write("abc") + return FlyteFile(path=tmp_file) + + @task + def create_dir() -> FlyteDirectory: + tmp_dir = tempfile.mkdtemp() + with open(os.path.join(tmp_dir, "file.txt"), "w") as f: + f.write("abc") + return FlyteDirectory(path=tmp_dir) + + @task + def create_sd() -> StructuredDataset: + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + return StructuredDataset(dataframe=df) + + ff = create_file() + fd = create_dir() + sd = create_sd() + + nb_name = "nb-types" + nb_types = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset), + outputs=kwtypes(success=bool), + ) + success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd) + assert success is True, "Notebook execution failed" diff --git a/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb b/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb index ebdf9a3c71..1ad7aaed4a 100644 --- a/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb +++ b/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb @@ -34,7 +34,6 @@ "outputs": [], "source": [ "from flytekitplugins.papermill import record_outputs\n", - "\n", "record_outputs(square=out)" ] }, @@ -49,7 +48,7 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -63,9 +62,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.10.10" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb b/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb new file mode 100644 index 0000000000..824b1d39ae --- /dev/null +++ b/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "ff = None\n", + "fd = None\n", + "sd = None" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "from flytekitplugins.papermill import (\n", + " load_flytefile, load_flytedirectory, load_structureddataset,\n", + " record_outputs\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ff = load_flytefile(ff)\n", + "fd = load_flytedirectory(fd)\n", + "sd = load_structureddataset(sd)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# read file\n", + "with open(ff.download(), 'r') as f:\n", + " text = f.read()\n", + " assert text == \"abc\", \"Text does not match\"\n", + "\n", + "# check file inside directory\n", + "with open(os.path.join(fd.download(),\"file.txt\"), 'r') as f:\n", + " text = f.read()\n", + " assert text == \"abc\", \"Text does not match\"\n", + "\n", + "# check dataset\n", + "df = sd.open(pd.DataFrame).all()\n", + "expected = pd.DataFrame({\"a\": [1, 2], \"b\": [3, 4]})\n", + "assert df.equals(expected), \"Dataframes do not match\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "outputs" + ] + }, + "outputs": [], + "source": [ + "record_outputs(success=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}