Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using sidecar handler to run Papermill task #1143

Merged
merged 4 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions plugins/flytekit-papermill/dev-requirements.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
flyteidl>=1.0.0
git+https://github.com/flyteorg/flytekit@sd-data-persistence#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark
# vcs+protocol://repo_url/#egg=pkg&subdirectory=flyte
flytekitplugins-pod==v1.2.0b0
git+https://github.com/flyteorg/flytekit@master#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark
49 changes: 43 additions & 6 deletions plugins/flytekit-papermill/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with python 3.8
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# pip-compile dev-requirements.in
Expand All @@ -8,8 +8,12 @@ arrow==1.2.1
# via jinja2-time
binaryornot==0.4.4
# via cookiecutter
cachetools==5.2.0
# via google-auth
certifi==2021.10.8
# via requests
# via
# kubernetes
# requests
chardet==4.0.0
# via binaryornot
charset-normalizer==2.0.10
Expand Down Expand Up @@ -43,9 +47,15 @@ flyteidl==1.0.0.post1
# -r dev-requirements.in
# flytekit
flytekit==1.1.0b0
# via flytekitplugins-spark
flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@sd-data-persistence#subdirectory=plugins/flytekit-spark
# via
# flytekitplugins-pod
# flytekitplugins-spark
flytekitplugins-pod==v1.2.0b0
# via -r dev-requirements.in
flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@master#subdirectory=plugins/flytekit-spark
# via -r dev-requirements.in
google-auth==2.11.0
# via kubernetes
googleapis-common-protos==1.55.0
# via
# flyteidl
Expand All @@ -68,6 +78,8 @@ jinja2-time==0.2.0
# via cookiecutter
keyring==23.5.0
# via flytekit
kubernetes==24.2.0
# via flytekitplugins-pod
markupsafe==2.0.1
# via jinja2
marshmallow==3.14.1
Expand All @@ -87,6 +99,8 @@ numpy==1.22.1
# via
# pandas
# pyarrow
oauthlib==3.2.0
# via requests-oauthlib
pandas==1.3.5
# via flytekit
poyo==0.5.0
Expand All @@ -106,13 +120,20 @@ py4j==0.10.9.3
# via pyspark
pyarrow==6.0.1
# via flytekit
pyasn1==0.4.8
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.2.8
# via google-auth
pyspark==3.2.1
# via flytekitplugins-spark
python-dateutil==2.8.1
# via
# arrow
# croniter
# flytekit
# kubernetes
# pandas
python-json-logger==2.0.2
# via flytekit
Expand All @@ -125,23 +146,33 @@ pytz==2021.3
# flytekit
# pandas
pyyaml==6.0
# via flytekit
# via
# flytekit
# kubernetes
regex==2021.11.10
# via docker-image-py
requests==2.27.1
# via
# cookiecutter
# docker
# flytekit
# kubernetes
# requests-oauthlib
# responses
requests-oauthlib==1.3.1
# via kubernetes
responses==0.17.0
# via flytekit
retry==0.9.2
# via flytekit
rsa==4.9
# via google-auth
six==1.16.0
# via
# cookiecutter
# google-auth
# grpcio
# kubernetes
# python-dateutil
# responses
sortedcontainers==2.4.0
Expand All @@ -159,10 +190,13 @@ typing-inspect==0.7.1
urllib3==1.26.8
# via
# flytekit
# kubernetes
# requests
# responses
websocket-client==1.3.2
# via docker
# via
# docker
# kubernetes
wheel==0.37.1
# via flytekit
wrapt==1.13.3
Expand All @@ -171,3 +205,6 @@ wrapt==1.13.3
# flytekit
zipp==3.7.0
# via importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# setuptools
29 changes: 26 additions & 3 deletions plugins/flytekit-papermill/flytekitplugins/papermill/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from nbconvert import HTMLExporter

from flytekit import FlyteContext, PythonInstanceTask
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.deck.deck import Deck
from flytekit.extend import Interface, TaskPlugins, TypeEngine
from flytekit.loggers import logger
from flytekit.models import task as task_models
from flytekit.models.literals import LiteralMap
from flytekit.types.file import HTMLPage, PythonNotebook

Expand Down Expand Up @@ -123,12 +125,13 @@ def __init__(
# errors.
# This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work.
plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func, **kwargs)
# Rename the internal task so that there are no conflicts at serialization time. Technically these internal
# tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
# at serialization time.
self._config_task_instance._name = f"{PAPERMILL_TASK_PREFIX}.{name}"
task_type = f"nb-{self._config_task_instance.task_type}"
task_type = f"{self._config_task_instance.task_type}"
task_type_version = self._config_task_instance.task_type_version
self._notebook_path = os.path.abspath(notebook_path)

self._render_deck = render_deck
Expand All @@ -144,7 +147,12 @@ def __init__(
}
)
super().__init__(
name, task_config, task_type=task_type, interface=Interface(inputs=inputs, outputs=outputs), **kwargs
name,
task_config,
task_type=task_type,
task_type_version=task_type_version,
interface=Interface(inputs=inputs, outputs=outputs),
**kwargs,
)

@property
Expand All @@ -159,6 +167,21 @@ def output_notebook_path(self) -> str:
def rendered_output_path(self) -> str:
return self._notebook_path.split(".ipynb")[0] + "-out.html"

def get_container(self, settings: SerializationSettings) -> task_models.Container:
return self._config_task_instance.get_container(settings)

def get_k8s_pod(self, settings: SerializationSettings) -> task_models.K8sPod:
# The task name in original command is incorrect because we use _dummy_task_func to construct the _config_task_instance.
# Therefore, Here we replace primary container's command with NotebookTask's command.
def fn(settings: SerializationSettings) -> typing.List[str]:
return self.get_command(settings)

self._config_task_instance.set_command_fn(fn)
return self._config_task_instance.get_k8s_pod(settings)

def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]:
return self._config_task_instance.get_config(settings)

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
return self._config_task_instance.pre_execute(user_params)

Expand Down
37 changes: 37 additions & 0 deletions plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
import os

from flytekitplugins.papermill import NotebookTask
from flytekitplugins.pod import Pod
from kubernetes.client import V1Container, V1PodSpec

import flytekit
from flytekit import kwtypes
from flytekit.configuration import Image, ImageConfig
from flytekit.types.file import PythonNotebook

from .testdata.datatype import X
Expand Down Expand Up @@ -83,3 +87,36 @@ def test_notebook_deck_local_execution_doesnt_fail():
sqr, out, render = nb.execute(pi=4)
# This is largely a no assert test to ensure render_deck never inhibits local execution.
assert nb._render_deck, "Passing render deck to init should result in private attribute being set"


def generate_por_spec_for_task():
primary_container = V1Container(name="primary")
pod_spec = V1PodSpec(containers=[primary_container])

return pod_spec


nb = NotebookTask(
name="test",
task_config=Pod(pod_spec=generate_por_spec_for_task(), primary_container_name="primary"),
notebook_path=_get_nb_path("nb-simple", abs=False),
inputs=kwtypes(h=str, n=int, w=str),
outputs=kwtypes(h=str, w=PythonNotebook, x=X),
)


def test_notebook_pod_task():
serialization_settings = flytekit.configuration.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
)

assert nb.get_container(serialization_settings) is None
assert nb.get_config(serialization_settings)["primary_container_name"] == "primary"
assert (
nb.get_command(serialization_settings)
== nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"]
)