diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index a9613b5d0a..47db35793d 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -34,6 +34,14 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy outputs=kwtypes(square=float), ) +nb_sub_task = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(a=float), + outputs=kwtypes(square=float), + output_notebooks=False, +) + def test_notebook_task_simple(): serialization_settings = flytekit.configuration.SerializationSettings( @@ -176,16 +184,8 @@ def create_sd() -> StructuredDataset: def test_map_over_notebook_task(): - nb_task = NotebookTask( - name="test", - notebook_path=_get_nb_path(nb_name, abs=False), - inputs=kwtypes(a=float), - outputs=kwtypes(square=float), - output_notebooks=False, - ) - @workflow def wf(a: float) -> typing.List[float]: - return map_task(nb_task)(a=[a, a, a]) + return map_task(nb_sub_task)(a=[a, a]) - assert wf(a=3.14) == [9.8596, 9.8596, 9.8596] + assert wf(a=3.14) == [9.8596, 9.8596]