From 4a272e26fbb5054f9545be741c2f04abf215a65e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 23 May 2023 21:48:30 -0700 Subject: [PATCH] Map over notebook task (#1650) * map over notebook Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * add a flag Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su --- flytekit/core/map_task.py | 17 ++++++++++++----- .../flytekitplugins/papermill/task.py | 8 +++++++- plugins/flytekit-papermill/tests/test_task.py | 19 ++++++++++++++++++- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 44f18de2a3..6e47ce716c 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -16,7 +16,7 @@ from flytekit.core.constants import SdkTaskType from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface -from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask from flytekit.core.tracker import TrackedInstance from flytekit.core.utils import timeit from flytekit.exceptions import scopes as exception_scopes @@ -34,7 +34,7 @@ class MapPythonTask(PythonTask): def __init__( self, - python_function_task: typing.Union[PythonFunctionTask, functools.partial], + python_function_task: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, bound_inputs: Optional[Set[str]] = None, @@ -65,7 +65,10 @@ def __init__( actual_task = python_function_task if not isinstance(actual_task, PythonFunctionTask): - raise ValueError("Map tasks can only compose of Python Functon Tasks currently") + if isinstance(actual_task, PythonInstanceTask): + pass + else: + raise ValueError("Map tasks can only compose of PythonFuncton and PythonInstanceTasks currently") if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") @@ -76,7 +79,11 @@ def __init__( collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) self._run_task: PythonFunctionTask = actual_task - _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + if isinstance(actual_task, PythonInstanceTask): + mod = actual_task.task_type + f = actual_task.lhs + else: + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() name = f"{mod}.map_{f}_{h}" @@ -271,7 +278,7 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( - task_function: typing.Union[PythonFunctionTask, functools.partial], + task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs, diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index b1f472e99a..6f4ed6886c 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -133,6 +133,7 @@ def __init__( task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, + output_notebooks: typing.Optional[bool] = True, **kwargs, ): # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used @@ -165,13 +166,16 @@ def __init__( if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") - if outputs: + if output_notebooks: + if outputs is None: + outputs = {} outputs.update( { self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE, self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE, } ) + super().__init__( name, task_config, @@ -287,6 +291,8 @@ def execute(self, **kwargs) -> Any: else: raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs") + if len(output_list) == 1: + return output_list[0] return tuple(output_list) def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 0e54e7082e..47db35793d 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,6 +1,7 @@ import datetime import os import tempfile +import typing import pandas as pd from flytekitplugins.papermill import NotebookTask @@ -8,7 +9,7 @@ from kubernetes.client import V1Container, V1PodSpec import flytekit -from flytekit import StructuredDataset, kwtypes, task +from flytekit import StructuredDataset, kwtypes, map_task, task, workflow from flytekit.configuration import Image, ImageConfig from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile, PythonNotebook @@ -33,6 +34,14 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy outputs=kwtypes(square=float), ) +nb_sub_task = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(a=float), + outputs=kwtypes(square=float), + output_notebooks=False, +) + def test_notebook_task_simple(): serialization_settings = flytekit.configuration.SerializationSettings( @@ -172,3 +181,11 @@ def create_sd() -> StructuredDataset: ) success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd) assert success is True, "Notebook execution failed" + + +def test_map_over_notebook_task(): + @workflow + def wf(a: float) -> typing.List[float]: + return map_task(nb_sub_task)(a=[a, a]) + + assert wf(a=3.14) == [9.8596, 9.8596]