Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly handle output of the failed tasks #25427

Merged
merged 1 commit into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,28 @@ async def a_get_run_state(self, run_id: int) -> RunState:
state = response['state']
return RunState(**state)

def get_run(self, run_id: int) -> Dict[str, Any]:
    """
    Retrieve run information.

    :param run_id: id of the run
    :return: full run information as returned by the API (includes the
        run's ``state`` and, for multi-task jobs, its ``tasks``)
    """
    # Named `payload` (not `json`) to avoid shadowing the stdlib module name.
    payload = {'run_id': run_id}
    return self._do_api_call(GET_RUN_ENDPOINT, payload)

async def a_get_run(self, run_id: int) -> Dict[str, Any]:
    """
    Async version of `get_run`.

    :param run_id: id of the run
    :return: full run information as returned by the API (includes the
        run's ``state`` and, for multi-task jobs, its ``tasks``)
    """
    # Named `payload` (not `json`) to avoid shadowing the stdlib module name.
    payload = {'run_id': run_id}
    return await self._a_do_api_call(GET_RUN_ENDPOINT, payload)

def get_run_state_str(self, run_id: int) -> str:
"""
Return the string representation of RunState.
Expand Down
33 changes: 26 additions & 7 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,39 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:

if operator.wait_for_termination:
while True:
run_state = hook.get_run_state(operator.run_id)
run_info = hook.get_run(operator.run_id)
run_state = RunState(**run_info['state'])
if run_state.is_terminal:
if run_state.is_successful:
log.info('%s completed successfully.', operator.task_id)
log.info('View run status, Spark UI, and logs at %s', run_page_url)
return
else:
run_output = hook.get_run_output(operator.run_id)
notebook_error = run_output['error']
error_message = (
f'{operator.task_id} failed with terminal state: {run_state} '
f'and with the error {notebook_error}'
)
if run_state.result_state == "FAILED":
task_run_id = None
if 'tasks' in run_info:
for task in run_info['tasks']:
if task.get("state", {}).get("result_state", "") == "FAILED":
task_run_id = task["run_id"]
if task_run_id is not None:
run_output = hook.get_run_output(task_run_id)
if 'error' in run_output:
notebook_error = run_output['error']
else:
notebook_error = run_state.state_message
else:
notebook_error = run_state.state_message
error_message = (
f'{operator.task_id} failed with terminal state: {run_state} '
f'and with the error {notebook_error}'
)
else:
error_message = (
f'{operator.task_id} failed with terminal state: {run_state} '
f'and with the error {run_state.state_message}'
)
raise AirflowException(error_message)

else:
log.info('%s in run state: %s', operator.task_id, run_state)
log.info('View run status, Spark UI, and logs at %s', run_page_url)
Expand Down
120 changes: 100 additions & 20 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest
from datetime import datetime
from unittest import mock
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -60,6 +61,28 @@
SPARK_SUBMIT_PARAMS = ["--class", "org.apache.spark.examples.SparkPi"]


def mock_dict(d: dict) -> MagicMock:
    """Return a MagicMock that yields ``d`` whenever it is called.

    Used to stand in for hook methods such as ``get_run`` and
    ``get_run_output`` that return a JSON-like dict.

    :param d: the dict the mock should return when called
    :return: a ``MagicMock`` with ``return_value`` set to ``d``
    """
    m = MagicMock()
    m.return_value = d
    return m


def make_run_with_state_mock(
    lifecycle_state: str, result_state: str, state_message: str = "", run_id=1, job_id=JOB_ID
):
    """Build a mocked ``get_run`` payload whose state matches the given fields."""
    state = {
        "life_cycle_state": lifecycle_state,
        "result_state": result_state,
        "state_message": state_message,
    }
    return mock_dict({"job_id": job_id, "run_id": run_id, "state": state})


class TestDatabricksSubmitRunOperator(unittest.TestCase):
def test_init_with_notebook_task_named_parameters(self):
"""
Expand Down Expand Up @@ -218,7 +241,7 @@ def test_exec_success(self, db_mock_class):
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

Expand All @@ -235,7 +258,7 @@ def test_exec_success(self, db_mock_class):

db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -250,7 +273,7 @@ def test_exec_failure(self, db_mock_class):
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException):
op.execute(None)
Expand All @@ -271,7 +294,7 @@ def test_exec_failure(self, db_mock_class):
)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -297,7 +320,7 @@ def test_wait_for_termination(self, db_mock_class):
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

assert op.wait_for_termination

Expand All @@ -316,7 +339,7 @@ def test_wait_for_termination(self, db_mock_class):

db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_no_wait_for_termination(self, db_mock_class):
Expand Down Expand Up @@ -345,7 +368,7 @@ def test_no_wait_for_termination(self, db_mock_class):

db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_not_called()
db_mock.get_run.assert_not_called()


class TestDatabricksSubmitRunDeferrableOperator(unittest.TestCase):
Expand All @@ -361,7 +384,7 @@ def test_execute_task_deferred(self, db_mock_class):
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

with pytest.raises(TaskDeferred) as exc:
op.execute(None)
Expand Down Expand Up @@ -422,7 +445,7 @@ def test_execute_complete_failure(self, db_mock_class):

db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = run_state_failed
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
op.execute_complete(context=None, event=event)
Expand Down Expand Up @@ -535,7 +558,7 @@ def test_exec_success(self, db_mock_class):
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

Expand All @@ -557,7 +580,7 @@ def test_exec_success(self, db_mock_class):
)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -569,7 +592,7 @@ def test_exec_failure(self, db_mock_class):
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException):
op.execute(None)
Expand All @@ -591,7 +614,64 @@ def test_exec_failure(self, db_mock_class):
)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_exec_failure_with_message(self, db_mock_class):
"""
Test the execute function when the run fails with a failed sub-task:
the error text from the failed task's run output must be surfaced in
the raised AirflowException, not just the generic state message.
"""
run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
# get_run reports a terminal FAILED run containing one failed sub-task
# (run_id=2); the operator should look up that task's output for details.
db_mock.get_run = mock_dict(
{
"job_id": JOB_ID,
"run_id": 1,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "failed",
},
"tasks": [
{
"run_id": 2,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "failed",
},
}
],
}
)
# The failed task's run output carries the actual notebook exception text.
db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."})

with pytest.raises(AirflowException) as exc_info:
op.execute(None)

# The message must end with the error from get_run_output, proving the
# task-level error (not run_state.state_message) was used.
assert exc_info.value.args[0].endswith(" Exception: Something went wrong...")

expected = utils.deep_string_coerce(
{
'notebook_params': NOTEBOOK_PARAMS,
'notebook_task': NOTEBOOK_TASK,
'jar_params': JAR_PARAMS,
'job_id': JOB_ID,
}
)
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay,
retry_args=None,
caller='DatabricksRunNowOperator',
)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -610,7 +690,7 @@ def test_wait_for_termination(self, db_mock_class):
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

assert op.wait_for_termination

Expand All @@ -634,7 +714,7 @@ def test_wait_for_termination(self, db_mock_class):

db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_no_wait_for_termination(self, db_mock_class):
Expand Down Expand Up @@ -665,7 +745,7 @@ def test_no_wait_for_termination(self, db_mock_class):

db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_not_called()
db_mock.get_run.assert_not_called()

def test_init_exception_with_job_name_and_job_id(self):
exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
Expand All @@ -688,7 +768,7 @@ def test_exec_with_job_name(self, db_mock_class):
db_mock = db_mock_class.return_value
db_mock.find_job_id_by_name.return_value = JOB_ID
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

Expand All @@ -711,7 +791,7 @@ def test_exec_with_job_name(self, db_mock_class):
db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -738,7 +818,7 @@ def test_execute_task_deferred(self, db_mock_class):
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

with pytest.raises(TaskDeferred) as exc:
op.execute(None)
Expand Down Expand Up @@ -799,7 +879,7 @@ def test_execute_complete_failure(self, db_mock_class):

db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = run_state_failed
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
op.execute_complete(context=None, event=event)
Expand Down