diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 2686e7fc72c92..f239eb5089188 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -261,6 +261,28 @@ async def a_get_run_state(self, run_id: int) -> RunState:
         state = response['state']
         return RunState(**state)
 
+    def get_run(self, run_id: int) -> Dict[str, Any]:
+        """
+        Retrieve run information.
+
+        :param run_id: id of the run
+        :return: information about the run
+        """
+        json = {'run_id': run_id}
+        response = self._do_api_call(GET_RUN_ENDPOINT, json)
+        return response
+
+    async def a_get_run(self, run_id: int) -> Dict[str, Any]:
+        """
+        Async version of `get_run`.
+
+        :param run_id: id of the run
+        :return: information about the run
+        """
+        json = {'run_id': run_id}
+        response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
+        return response
+
     def get_run_state_str(self, run_id: int) -> str:
         """
         Return the string representation of RunState.
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 7658c81381226..8750565429e48 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -54,20 +54,39 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
 
     if operator.wait_for_termination:
         while True:
-            run_state = hook.get_run_state(operator.run_id)
+            run_info = hook.get_run(operator.run_id)
+            run_state = RunState(**run_info['state'])
             if run_state.is_terminal:
                 if run_state.is_successful:
                     log.info('%s completed successfully.', operator.task_id)
                     log.info('View run status, Spark UI, and logs at %s', run_page_url)
                     return
                 else:
-                    run_output = hook.get_run_output(operator.run_id)
-                    notebook_error = run_output['error']
-                    error_message = (
-                        f'{operator.task_id} failed with terminal state: {run_state} '
-                        f'and with the error {notebook_error}'
-                    )
+                    if run_state.result_state == "FAILED":
+                        task_run_id = None
+                        if 'tasks' in run_info:
+                            for task in run_info['tasks']:
+                                if task.get("state", {}).get("result_state", "") == "FAILED":
+                                    task_run_id = task["run_id"]
+                        if task_run_id is not None:
+                            run_output = hook.get_run_output(task_run_id)
+                            if 'error' in run_output:
+                                notebook_error = run_output['error']
+                            else:
+                                notebook_error = run_state.state_message
+                        else:
+                            notebook_error = run_state.state_message
+                        error_message = (
+                            f'{operator.task_id} failed with terminal state: {run_state} '
+                            f'and with the error {notebook_error}'
+                        )
+                    else:
+                        error_message = (
+                            f'{operator.task_id} failed with terminal state: {run_state} '
+                            f'and with the error {run_state.state_message}'
+                        )
                     raise AirflowException(error_message)
+
             else:
                 log.info('%s in run state: %s', operator.task_id, run_state)
                 log.info('View run status, Spark UI, and logs at %s', run_page_url)
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index e94de7290163c..bea6c1210c577 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -19,6 +19,7 @@
 import unittest
 from datetime import datetime
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 
@@ -60,6 +61,28 @@
 SPARK_SUBMIT_PARAMS = ["--class", "org.apache.spark.examples.SparkPi"]
 
 
+def mock_dict(d: dict):
+    m = MagicMock()
+    m.return_value = d
+    return m
+
+
+def make_run_with_state_mock(
+    lifecycle_state: str, result_state: str, state_message: str = "", run_id=1, job_id=JOB_ID
+):
+    return mock_dict(
+        {
+            "job_id": job_id,
+            "run_id": run_id,
+            "state": {
+                "life_cycle_state": lifecycle_state,
+                "result_state": result_state,
+                "state_message": state_message,
+            },
+        }
+    )
+
+
 class TestDatabricksSubmitRunOperator(unittest.TestCase):
     def test_init_with_notebook_task_named_parameters(self):
         """
@@ -218,7 +241,7 @@ def test_exec_success(self, db_mock_class):
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.submit_run.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)
 
@@ -235,7 +258,7 @@ def test_exec_success(self, db_mock_class):
 
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
         assert RUN_ID == op.run_id
 
     @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
@@ -250,7 +273,7 @@ def test_exec_failure(self, db_mock_class):
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.submit_run.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         with pytest.raises(AirflowException):
             op.execute(None)
@@ -271,7 +294,7 @@ def test_exec_failure(self, db_mock_class):
         )
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
         assert RUN_ID == op.run_id
 
     @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
@@ -297,7 +320,7 @@ def test_wait_for_termination(self, db_mock_class):
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.submit_run.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         assert op.wait_for_termination
 
@@ -316,7 +339,7 @@ def test_wait_for_termination(self, db_mock_class):
 
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
 
     @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
     def test_no_wait_for_termination(self, db_mock_class):
@@ -345,7 +368,7 @@ def test_no_wait_for_termination(self, db_mock_class):
 
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_not_called()
+        db_mock.get_run.assert_not_called()
 
 
 class TestDatabricksSubmitRunDeferrableOperator(unittest.TestCase):
@@ -361,7 +384,7 @@ def test_execute_task_deferred(self, db_mock_class):
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.submit_run.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         with pytest.raises(TaskDeferred) as exc:
             op.execute(None)
@@ -422,7 +445,7 @@ def test_execute_complete_failure(self, db_mock_class):
 
         db_mock = db_mock_class.return_value
         db_mock.submit_run.return_value = 1
-        db_mock.get_run_state.return_value = run_state_failed
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
             op.execute_complete(context=None, event=event)
@@ -535,7 +558,7 @@ def test_exec_success(self, db_mock_class):
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.run_now.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)
 
@@ -557,7 +580,7 @@ def test_exec_success(self, db_mock_class):
         )
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
         assert RUN_ID == op.run_id
 
     @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
@@ -569,7 +592,7 @@ def test_exec_failure(self, db_mock_class):
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.run_now.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         with pytest.raises(AirflowException):
             op.execute(None)
@@ -591,7 +614,64 @@ def test_exec_failure(self, db_mock_class):
         )
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
+        assert RUN_ID == op.run_id
+
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+    def test_exec_failure_with_message(self, db_mock_class):
+        """
+        Test the execute function in the case where the run failed.
+        """
+        run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+        db_mock.get_run = mock_dict(
+            {
+                "job_id": JOB_ID,
+                "run_id": 1,
+                "state": {
+                    "life_cycle_state": "TERMINATED",
+                    "result_state": "FAILED",
+                    "state_message": "failed",
+                },
+                "tasks": [
+                    {
+                        "run_id": 2,
+                        "state": {
+                            "life_cycle_state": "TERMINATED",
+                            "result_state": "FAILED",
+                            "state_message": "failed",
+                        },
+                    }
+                ],
+            }
+        )
+        db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."})
+
+        with pytest.raises(AirflowException) as exc_info:
+            op.execute(None)
+
+        assert exc_info.value.args[0].endswith(" Exception: Something went wrong...")
+
+        expected = utils.deep_string_coerce(
+            {
+                'notebook_params': NOTEBOOK_PARAMS,
+                'notebook_task': NOTEBOOK_TASK,
+                'jar_params': JAR_PARAMS,
+                'job_id': JOB_ID,
+            }
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller='DatabricksRunNowOperator',
+        )
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
         assert RUN_ID == op.run_id
 
     @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
@@ -610,7 +690,7 @@ def test_wait_for_termination(self, db_mock_class):
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.run_now.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         assert op.wait_for_termination
 
@@ -634,7 +714,7 @@ def test_wait_for_termination(self, db_mock_class):
 
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
 
     @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
     def test_no_wait_for_termination(self, db_mock_class):
@@ -665,7 +745,7 @@ def test_no_wait_for_termination(self, db_mock_class):
 
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_not_called()
+        db_mock.get_run.assert_not_called()
 
     def test_init_exception_with_job_name_and_job_id(self):
         exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
@@ -688,7 +768,7 @@ def test_exec_with_job_name(self, db_mock_class):
         db_mock = db_mock_class.return_value
         db_mock.find_job_id_by_name.return_value = JOB_ID
         db_mock.run_now.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)
 
@@ -711,7 +791,7 @@ def test_exec_with_job_name(self, db_mock_class):
         db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
-        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
         assert RUN_ID == op.run_id
 
     @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
@@ -738,7 +818,7 @@ def test_execute_task_deferred(self, db_mock_class):
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.run_now.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         with pytest.raises(TaskDeferred) as exc:
             op.execute(None)
@@ -799,7 +879,7 @@ def test_execute_complete_failure(self, db_mock_class):
 
         db_mock = db_mock_class.return_value
         db_mock.run_now.return_value = 1
-        db_mock.get_run_state.return_value = run_state_failed
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'):
             op.execute_complete(context=None, event=event)
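
Note (commentary, not part of the patch): the failure path added to _handle_databricks_operator_execution can be exercised in isolation. Below is a minimal, self-contained sketch of that lookup logic; StubHook and extract_failure_error are hypothetical names used only for illustration, standing in for the real DatabricksHook and the inline operator code.

    from typing import Any, Dict, Optional


    class StubHook:
        """Stand-in for DatabricksHook with canned Jobs API payloads (hypothetical fixture)."""

        def get_run(self, run_id: int) -> Dict[str, Any]:
            # Shape matches the `run_info` dict consumed by the patched operator code.
            return {
                "job_id": 42,
                "run_id": run_id,
                "state": {
                    "life_cycle_state": "TERMINATED",
                    "result_state": "FAILED",
                    "state_message": "failed",
                },
                "tasks": [
                    {"run_id": 2, "state": {"result_state": "FAILED"}},
                ],
            }

        def get_run_output(self, task_run_id: int) -> Dict[str, Any]:
            return {"error": "Exception: Something went wrong..."}


    def extract_failure_error(hook: StubHook, run_id: int) -> str:
        """Same lookup as the patch: prefer the failed task's run output,
        fall back to the run's top-level state_message."""
        run_info = hook.get_run(run_id)
        state = run_info["state"]
        if state.get("result_state") != "FAILED":
            return state.get("state_message", "")
        # Find the sub-task whose result_state is FAILED, if any.
        task_run_id: Optional[int] = None
        for task in run_info.get("tasks", []):
            if task.get("state", {}).get("result_state", "") == "FAILED":
                task_run_id = task["run_id"]
        if task_run_id is not None:
            run_output = hook.get_run_output(task_run_id)
            if "error" in run_output:
                return run_output["error"]
        return state.get("state_message", "")


    if __name__ == "__main__":
        # Prints: Exception: Something went wrong...
        print(extract_failure_error(StubHook(), run_id=1))

The per-task indirection is the substance of the change: the Databricks Jobs API returns run output for individual task runs rather than for a multi-task job run as a whole, which appears to be why the patch resolves the failed task's run_id before calling get_run_output, falling back to run_state.state_message when no task-level error is available.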