Skip to content

Commit

Permalink
Correctly handle output of the failed tasks (#25427)
Browse files Browse the repository at this point in the history
In the Jobs API 2.1, we can't call `get_run_output` on the top-level Run ID because that's
not supported by the API — we need to call this function on the specific sub-run of the job,
even if the job consists of a single task

closes: #25286
  • Loading branch information
alexott authored Aug 3, 2022
1 parent 87a0bd9 commit 679a853
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 27 deletions.
22 changes: 22 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,28 @@ async def a_get_run_state(self, run_id: int) -> RunState:
state = response['state']
return RunState(**state)

def get_run(self, run_id: int) -> Dict[str, Any]:
    """
    Retrieve information about a run.

    :param run_id: id of the run
    :return: full run metadata as returned by the API (the caller can read
        e.g. the ``state`` and ``tasks`` entries from it)
    """
    payload = {'run_id': run_id}
    return self._do_api_call(GET_RUN_ENDPOINT, payload)

async def a_get_run(self, run_id: int) -> Dict[str, Any]:
    """
    Async version of `get_run`.

    :param run_id: id of the run
    :return: full run metadata as returned by the API
    """
    payload = {'run_id': run_id}
    result = await self._a_do_api_call(GET_RUN_ENDPOINT, payload)
    return result

def get_run_state_str(self, run_id: int) -> str:
"""
Return the string representation of RunState.
Expand Down
33 changes: 26 additions & 7 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,39 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:

if operator.wait_for_termination:
while True:
run_state = hook.get_run_state(operator.run_id)
run_info = hook.get_run(operator.run_id)
run_state = RunState(**run_info['state'])
if run_state.is_terminal:
if run_state.is_successful:
log.info('%s completed successfully.', operator.task_id)
log.info('View run status, Spark UI, and logs at %s', run_page_url)
return
else:
run_output = hook.get_run_output(operator.run_id)
notebook_error = run_output['error']
error_message = (
f'{operator.task_id} failed with terminal state: {run_state} '
f'and with the error {notebook_error}'
)
if run_state.result_state == "FAILED":
task_run_id = None
if 'tasks' in run_info:
for task in run_info['tasks']:
if task.get("state", {}).get("result_state", "") == "FAILED":
task_run_id = task["run_id"]
if task_run_id is not None:
run_output = hook.get_run_output(task_run_id)
if 'error' in run_output:
notebook_error = run_output['error']
else:
notebook_error = run_state.state_message
else:
notebook_error = run_state.state_message
error_message = (
f'{operator.task_id} failed with terminal state: {run_state} '
f'and with the error {notebook_error}'
)
else:
error_message = (
f'{operator.task_id} failed with terminal state: {run_state} '
f'and with the error {run_state.state_message}'
)
raise AirflowException(error_message)

else:
log.info('%s in run state: %s', operator.task_id, run_state)
log.info('View run status, Spark UI, and logs at %s', run_page_url)
Expand Down
120 changes: 100 additions & 20 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest
from datetime import datetime
from unittest import mock
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -60,6 +61,28 @@
SPARK_SUBMIT_PARAMS = ["--class", "org.apache.spark.examples.SparkPi"]


def mock_dict(d: dict):
    """Return a ``MagicMock`` that yields the given dict when called."""
    return MagicMock(return_value=d)


def make_run_with_state_mock(
    lifecycle_state: str, result_state: str, state_message: str = "", run_id=1, job_id=JOB_ID
):
    """Build a mocked ``get_run`` replacement that reports a run in the given state."""
    state = {
        "life_cycle_state": lifecycle_state,
        "result_state": result_state,
        "state_message": state_message,
    }
    return mock_dict({"job_id": job_id, "run_id": run_id, "state": state})


class TestDatabricksSubmitRunOperator(unittest.TestCase):
def test_init_with_notebook_task_named_parameters(self):
"""
Expand Down Expand Up @@ -218,7 +241,7 @@ def test_exec_success(self, db_mock_class):
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

Expand All @@ -235,7 +258,7 @@ def test_exec_success(self, db_mock_class):

db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -250,7 +273,7 @@ def test_exec_failure(self, db_mock_class):
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException):
op.execute(None)
Expand All @@ -271,7 +294,7 @@ def test_exec_failure(self, db_mock_class):
)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -297,7 +320,7 @@ def test_wait_for_termination(self, db_mock_class):
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

assert op.wait_for_termination

Expand All @@ -316,7 +339,7 @@ def test_wait_for_termination(self, db_mock_class):

db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_no_wait_for_termination(self, db_mock_class):
Expand Down Expand Up @@ -345,7 +368,7 @@ def test_no_wait_for_termination(self, db_mock_class):

db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_not_called()
db_mock.get_run.assert_not_called()


class TestDatabricksSubmitRunDeferrableOperator(unittest.TestCase):
Expand All @@ -361,7 +384,7 @@ def test_execute_task_deferred(self, db_mock_class):
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

with pytest.raises(TaskDeferred) as exc:
op.execute(None)
Expand Down Expand Up @@ -422,7 +445,7 @@ def test_execute_complete_failure(self, db_mock_class):

db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = run_state_failed
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
op.execute_complete(context=None, event=event)
Expand Down Expand Up @@ -535,7 +558,7 @@ def test_exec_success(self, db_mock_class):
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

Expand All @@ -557,7 +580,7 @@ def test_exec_success(self, db_mock_class):
)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -569,7 +592,7 @@ def test_exec_failure(self, db_mock_class):
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException):
op.execute(None)
Expand All @@ -591,7 +614,64 @@ def test_exec_failure(self, db_mock_class):
)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_exec_failure_with_message(self, db_mock_class):
    """
    Execute should surface the failed task's error output when the run fails.
    """
    json_payload = {
        'notebook_params': NOTEBOOK_PARAMS,
        'notebook_task': NOTEBOOK_TASK,
        'jar_params': JAR_PARAMS,
    }
    op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json_payload)
    db_mock = db_mock_class.return_value
    db_mock.run_now.return_value = 1
    # Top-level run and its single sub-task both report FAILED; the error
    # details must be fetched from the sub-task's run_id (2), not run_id 1.
    failed_state = {
        "life_cycle_state": "TERMINATED",
        "result_state": "FAILED",
        "state_message": "failed",
    }
    db_mock.get_run = mock_dict(
        {
            "job_id": JOB_ID,
            "run_id": 1,
            "state": failed_state,
            "tasks": [{"run_id": 2, "state": failed_state}],
        }
    )
    db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."})

    with pytest.raises(AirflowException) as exc_info:
        op.execute(None)

    # The raised message must end with the task-level error, not the state message.
    assert exc_info.value.args[0].endswith(" Exception: Something went wrong...")

    expected = utils.deep_string_coerce(
        {
            'notebook_params': NOTEBOOK_PARAMS,
            'notebook_task': NOTEBOOK_TASK,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID,
        }
    )
    db_mock_class.assert_called_once_with(
        DEFAULT_CONN_ID,
        retry_limit=op.databricks_retry_limit,
        retry_delay=op.databricks_retry_delay,
        retry_args=None,
        caller='DatabricksRunNowOperator',
    )
    db_mock.run_now.assert_called_once_with(expected)
    db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
    db_mock.get_run.assert_called_once_with(RUN_ID)
    assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -610,7 +690,7 @@ def test_wait_for_termination(self, db_mock_class):
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

assert op.wait_for_termination

Expand All @@ -634,7 +714,7 @@ def test_wait_for_termination(self, db_mock_class):

db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_no_wait_for_termination(self, db_mock_class):
Expand Down Expand Up @@ -665,7 +745,7 @@ def test_no_wait_for_termination(self, db_mock_class):

db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_not_called()
db_mock.get_run.assert_not_called()

def test_init_exception_with_job_name_and_job_id(self):
exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
Expand All @@ -688,7 +768,7 @@ def test_exec_with_job_name(self, db_mock_class):
db_mock = db_mock_class.return_value
db_mock.find_job_id_by_name.return_value = JOB_ID
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

Expand All @@ -711,7 +791,7 @@ def test_exec_with_job_name(self, db_mock_class):
db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
Expand All @@ -738,7 +818,7 @@ def test_execute_task_deferred(self, db_mock_class):
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

with pytest.raises(TaskDeferred) as exc:
op.execute(None)
Expand Down Expand Up @@ -799,7 +879,7 @@ def test_execute_complete_failure(self, db_mock_class):

db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = run_state_failed
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
op.execute_complete(context=None, event=event)
Expand Down

0 comments on commit 679a853

Please sign in to comment.