diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index a58b01d482..304932a828 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -11,6 +11,7 @@ from flytekit import FlyteContext, PythonInstanceTask from flytekit.core.context_manager import ExecutionParameters +from flytekit.deck.deck import Deck from flytekit.extend import Interface, TaskPlugins, TypeEngine from flytekit.loggers import logger from flytekit.models.literals import LiteralMap @@ -63,6 +64,7 @@ class NotebookTask(PythonInstanceTask[T]): name="modulename.my_notebook_task", # the name should be unique within all your tasks, usually it is a good # idea to use the modulename notebook_path="../path/to/my_notebook", + render_deck=True, inputs=kwtypes(v=int), outputs=kwtypes(x=int, y=str), metadata=TaskMetadata(retries=3, cache=True, cache_version="1.0"), @@ -76,7 +78,7 @@ class NotebookTask(PythonInstanceTask[T]): #. It captures the executed notebook in its entirety and is available from Flyte with the name ``out_nb``. #. It also converts the captured notebook into an ``html`` page, which the FlyteConsole will render called - - ``out_rendered_nb`` + ``out_rendered_nb``. If ``render_deck=True`` is passed, this html content will be inserted into a deck. .. note: @@ -109,6 +111,7 @@ def __init__( self, name: str, notebook_path: str, + render_deck: bool = False, task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, @@ -128,6 +131,8 @@ def __init__( task_type = f"nb-{self._config_task_instance.task_type}" self._notebook_path = os.path.abspath(notebook_path) + self._render_deck = render_deck + if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") @@ -225,6 +230,15 @@ def execute(self, **kwargs) -> Any: return tuple(output_list) def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + if self._render_deck: + nb_deck = Deck(self._IMPLICIT_RENDERED_NOTEBOOK) + with open(self.rendered_output_path, "r") as f: + notebook_html = f.read() + nb_deck.append(notebook_html) + # Since user_params is passed by reference, this modifies the object in the outside scope + # which then causes the deck to be rendered later during the dispatch_execute function. + user_params.decks.append(nb_deck) + return self._config_task_instance.post_execute(user_params, rval) diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index ca25eea028..d60e68cdb0 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -69,3 +69,17 @@ def test_notebook_task_complex(): assert nb.python_interface.outputs.keys() == {"h", "w", "x", "out_nb", "out_rendered_nb"} assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out") assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html") + + +def test_notebook_deck_local_execution_doesnt_fail(): + nb_name = "nb-simple" + nb = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + render_deck=True, + inputs=kwtypes(pi=float), + outputs=kwtypes(square=float), + ) + sqr, out, render = nb.execute(pi=4) + # This is largely a no assert test to ensure render_deck never inhibits local execution. + assert nb._render_deck, "Passing render deck to init should result in private attribute being set"