From fad70757e6a1670a7907892174fc8a5a576001d0 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 19 May 2023 23:44:57 -0700 Subject: [PATCH] Fix tests Signed-off-by: Kevin Su --- plugins/flytekit-papermill/tests/test_task.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index a9613b5d0a..47db35793d 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -34,6 +34,14 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy outputs=kwtypes(square=float), ) +nb_sub_task = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(a=float), + outputs=kwtypes(square=float), + output_notebooks=False, +) + def test_notebook_task_simple(): serialization_settings = flytekit.configuration.SerializationSettings( @@ -176,16 +184,8 @@ def create_sd() -> StructuredDataset: def test_map_over_notebook_task(): - nb_task = NotebookTask( - name="test", - notebook_path=_get_nb_path(nb_name, abs=False), - inputs=kwtypes(a=float), - outputs=kwtypes(square=float), - output_notebooks=False, - ) - @workflow def wf(a: float) -> typing.List[float]: - return map_task(nb_task)(a=[a, a, a]) + return map_task(nb_sub_task)(a=[a, a]) - assert wf(a=3.14) == [9.8596, 9.8596, 9.8596] + assert wf(a=3.14) == [9.8596, 9.8596]