Skip to content

Commit

Permalink
Map over notebook task (#1650)
Browse files Browse the repository at this point in the history
* map over notebook

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* tests

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* add a flag

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored and eapolinario committed Jul 10, 2023
1 parent a7247e4 commit 4a272e2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
17 changes: 12 additions & 5 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flytekit.core.constants import SdkTaskType
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.tracker import TrackedInstance
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
Expand All @@ -34,7 +34,7 @@ class MapPythonTask(PythonTask):

def __init__(
self,
python_function_task: typing.Union[PythonFunctionTask, functools.partial],
python_function_task: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
bound_inputs: Optional[Set[str]] = None,
Expand Down Expand Up @@ -65,7 +65,10 @@ def __init__(
actual_task = python_function_task

if not isinstance(actual_task, PythonFunctionTask):
raise ValueError("Map tasks can only compose of Python Functon Tasks currently")
if isinstance(actual_task, PythonInstanceTask):
pass
else:
raise ValueError("Map tasks can only compose of PythonFuncton and PythonInstanceTasks currently")

if len(actual_task.python_interface.outputs.keys()) > 1:
raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs")
Expand All @@ -76,7 +79,11 @@ def __init__(

collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs)
self._run_task: PythonFunctionTask = actual_task
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
if isinstance(actual_task, PythonInstanceTask):
mod = actual_task.task_type
f = actual_task.lhs
else:
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest()
name = f"{mod}.map_{f}_{h}"

Expand Down Expand Up @@ -271,7 +278,7 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
task_function: typing.Union[PythonFunctionTask, functools.partial],
task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
concurrency: int = 0,
min_success_ratio: float = 1.0,
**kwargs,
Expand Down
8 changes: 7 additions & 1 deletion plugins/flytekit-papermill/flytekitplugins/papermill/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
task_config: T = None,
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
outputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
output_notebooks: typing.Optional[bool] = True,
**kwargs,
):
# Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
Expand Down Expand Up @@ -165,13 +166,16 @@ def __init__(
if not os.path.exists(self._notebook_path):
raise ValueError(f"Illegal notebook path passed in {self._notebook_path}")

if outputs:
if output_notebooks:
if outputs is None:
outputs = {}
outputs.update(
{
self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE,
self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE,
}
)

super().__init__(
name,
task_config,
Expand Down Expand Up @@ -287,6 +291,8 @@ def execute(self, **kwargs) -> Any:
else:
raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs")

if len(output_list) == 1:
return output_list[0]
return tuple(output_list)

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
Expand Down
19 changes: 18 additions & 1 deletion plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import datetime
import os
import tempfile
import typing

import pandas as pd
from flytekitplugins.papermill import NotebookTask
from flytekitplugins.pod import Pod
from kubernetes.client import V1Container, V1PodSpec

import flytekit
from flytekit import StructuredDataset, kwtypes, task
from flytekit import StructuredDataset, kwtypes, map_task, task, workflow
from flytekit.configuration import Image, ImageConfig
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, PythonNotebook
Expand All @@ -33,6 +34,14 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy
outputs=kwtypes(square=float),
)

nb_sub_task = NotebookTask(
name="test",
notebook_path=_get_nb_path(nb_name, abs=False),
inputs=kwtypes(a=float),
outputs=kwtypes(square=float),
output_notebooks=False,
)


def test_notebook_task_simple():
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -172,3 +181,11 @@ def create_sd() -> StructuredDataset:
)
success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd)
assert success is True, "Notebook execution failed"


def test_map_over_notebook_task():
@workflow
def wf(a: float) -> typing.List[float]:
return map_task(nb_sub_task)(a=[a, a])

assert wf(a=3.14) == [9.8596, 9.8596]

0 comments on commit 4a272e2

Please sign in to comment.