Skip to content

Commit

Permalink
Return error code on fail (#1408)
Browse files Browse the repository at this point in the history
* AWS batch return error code once it fails

Signed-off-by: Kevin Su <[email protected]>

* AWS batch return error code once it fails

Signed-off-by: Kevin Su <[email protected]>

* update tests

Signed-off-by: Kevin Su <[email protected]>

* Update tests

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
pingsutw authored and samhita-alla committed Feb 2, 2023
1 parent 8988a92 commit 963c120
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
6 changes: 6 additions & 0 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def _dispatch_execute(
logger.info(f"Engine folder written successfully to the output prefix {output_prefix}")
logger.debug("Finished _dispatch_execute")

if os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" and _constants.ERROR_FILE_NAME in output_file_dict:
# This env is set by the flytepropeller
# AWS batch job get the status from the exit code, so once we catch the error,
# we should return the error code here
exit(1)


def get_one_of(*args) -> str:
"""
Expand Down
32 changes: 32 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import OrderedDict

import mock
import pytest
from flyteidl.core.errors_pb2 import ErrorDocument

from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution
Expand Down Expand Up @@ -110,6 +111,37 @@ def verify_output(*args, **kwargs):
assert mock_write_to_file.call_count == 1


@mock.patch.dict(os.environ, {"FLYTE_FAIL_ON_ERROR": "True"})
@mock.patch("flytekit.core.utils.load_proto_from_file")
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data")
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
@mock.patch("flytekit.core.utils.write_proto_to_file")
def test_dispatch_execute_return_error_code(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
mock_get_data.return_value = True
mock_upload_dir.return_value = True

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
)
) as ctx:
python_task = mock.MagicMock()
python_task.dispatch_execute.side_effect = Exception("random")

empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl()
mock_load_proto.return_value = empty_literal_map

def verify_output(*args, **kwargs):
assert isinstance(args[0], ErrorDocument)

mock_write_to_file.side_effect = verify_output

with pytest.raises(SystemExit) as cm:
_dispatch_execute(ctx, python_task, "inputs path", "outputs prefix")
pytest.assertEqual(cm.value.code, 1)


# This function collects outputs instead of writing them to a file.
# See flytekit.core.utils.write_proto_to_file for the original
def get_output_collector(results: OrderedDict):
Expand Down

0 comments on commit 963c120

Please sign in to comment.