Skip to content

Commit

Permalink
checks all cells executed correctly
Browse files Browse the repository at this point in the history
Signed-off-by: esad <[email protected]>
  • Loading branch information
peridotml committed May 8, 2023
1 parent ffb17fb commit e632803
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 4 additions & 6 deletions plugins/flytekit-papermill/flytekitplugins/papermill/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,11 @@ def execute(self, **kwargs) -> Any:
singleton
"""
logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.")
# Execute Notebook via Papermill.

for k, v in kwargs.items():
if isinstance(v, (FlyteFile, FlyteDirectory)):
kwargs[k] = save_literal_to_file(v)
elif isinstance(v, StructuredDataset):
if isinstance(v, (FlyteFile, FlyteDirectory, StructuredDataset)):
kwargs[k] = save_literal_to_file(v)

# Execute Notebook via Papermill.
pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore

outputs = self.extract_outputs(self.output_notebook_path)
Expand All @@ -276,6 +273,7 @@ def execute(self, **kwargs) -> Any:
if outputs:
m = outputs.literals
output_list = []

for k, type_v in self.python_interface.outputs.items():
if k == self._IMPLICIT_OP_NOTEBOOK:
output_list.append(self.output_notebook_path)
Expand All @@ -285,7 +283,7 @@ def execute(self, **kwargs) -> Any:
v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v)
output_list.append(v)
else:
raise RuntimeError(f"Expected output {k} of type {v} not found in the notebook outputs")
raise RuntimeError(f"Expected output {k} of type {type_v} not found in the notebook outputs")

return tuple(output_list)

Expand Down
4 changes: 3 additions & 1 deletion plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,7 @@ def create_sd() -> StructuredDataset:
name="test",
notebook_path=_get_nb_path(nb_name, abs=False),
inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset),
outputs=kwtypes(success=bool),
)
nb_types.execute(ff=ff, fd=fd, sd=sd)
success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd)
assert success == True, "Notebook execution failed"

0 comments on commit e632803

Please sign in to comment.