From 9b40c028f716948321d0f41e84afd8446e982724 Mon Sep 17 00:00:00 2001 From: Eugene Karimov Date: Wed, 2 Mar 2022 22:01:47 +0100 Subject: [PATCH] Add Deferrable Databricks operators --- .../providers/databricks/hooks/databricks.py | 29 ++ .../databricks/hooks/databricks_base.py | 237 +++++++++++- .../databricks/operators/databricks.py | 106 ++++-- .../providers/databricks/triggers/__init__.py | 17 + .../databricks/triggers/databricks.py | 77 ++++ .../providers/databricks/utils/__init__.py | 16 + .../providers/databricks/utils/databricks.py | 69 ++++ .../operators/run_now.rst | 7 + .../operators/submit_run.rst | 7 + setup.py | 2 + .../databricks/hooks/test_databricks.py | 352 +++++++++++++++++- .../databricks/operators/test_databricks.py | 239 +++++++++--- .../providers/databricks/triggers/__init__.py | 17 + .../databricks/triggers/test_databricks.py | 153 ++++++++ tests/providers/databricks/utils/__init__.py | 16 + .../providers/databricks/utils/databricks.py | 62 +++ 16 files changed, 1313 insertions(+), 93 deletions(-) create mode 100644 airflow/providers/databricks/triggers/__init__.py create mode 100644 airflow/providers/databricks/triggers/databricks.py create mode 100644 airflow/providers/databricks/utils/__init__.py create mode 100644 airflow/providers/databricks/utils/databricks.py create mode 100644 tests/providers/databricks/triggers/__init__.py create mode 100644 tests/providers/databricks/triggers/test_databricks.py create mode 100644 tests/providers/databricks/utils/__init__.py create mode 100644 tests/providers/databricks/utils/databricks.py diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 79116604124ba..400bbe895588a 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -25,6 +25,7 @@ or the ``api/2.1/jobs/runs/submit`` `endpoint `_. """ +import json from typing import Any, Dict, List, Optional from requests import exceptions as requests_exceptions @@ -92,6 +93,13 @@ def __eq__(self, other: object) -> bool: def __repr__(self) -> str: return str(self.__dict__) + def to_json(self) -> str: + return json.dumps(self.__dict__) + + @classmethod + def from_json(cls, data: str) -> 'RunState': + return RunState(**json.loads(data)) + class DatabricksHook(BaseDatabricksHook): """ @@ -198,6 +206,16 @@ def get_run_page_url(self, run_id: int) -> str: response = self._do_api_call(GET_RUN_ENDPOINT, json) return response['run_page_url'] + async def a_get_run_page_url(self, run_id: int) -> str: + """ + Async version of `get_run_page_url()`. + :param run_id: id of the run + :return: URL of the run page + """ + json = {'run_id': run_id} + response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) + return response['run_page_url'] + def get_job_id(self, run_id: int) -> int: """ Retrieves job_id from run_id. @@ -229,6 +247,17 @@ def get_run_state(self, run_id: int) -> RunState: state = response['state'] return RunState(**state) + async def a_get_run_state(self, run_id: int) -> RunState: + """ + Async version of `get_run_state()`. + :param run_id: id of the run + :return: state of the run + """ + json = {'run_id': run_id} + response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) + state = response['state'] + return RunState(**state) + def get_run_state_str(self, run_id: int) -> str: """ Return the string representation of RunState. diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py index 6e0f1b44d8ae7..5b18dad9303ef 100644 --- a/airflow/providers/databricks/hooks/databricks_base.py +++ b/airflow/providers/databricks/hooks/databricks_base.py @@ -28,11 +28,19 @@ from typing import Any, Dict, Optional, Tuple from urllib.parse import urlparse +import aiohttp import requests from requests import PreparedRequest, exceptions as requests_exceptions from requests.auth import AuthBase, HTTPBasicAuth from requests.exceptions import JSONDecodeError -from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential +from tenacity import ( + AsyncRetrying, + RetryError, + Retrying, + retry_if_exception, + stop_after_attempt, + wait_exponential, +) from airflow import __version__ from airflow.exceptions import AirflowException @@ -135,6 +143,14 @@ def host(self) -> str: return host + async def __aenter__(self): + self._session = aiohttp.ClientSession() + return self + + async def __aexit__(self, *err): + await self._session.close() + self._session = None + @staticmethod def _parse_host(host: str) -> str: """ @@ -169,6 +185,13 @@ def _get_retry_object(self) -> Retrying: """ return Retrying(**self.retry_args) + def _a_get_retry_object(self) -> AsyncRetrying: + """ + Instantiates an async retry object + :return: instance of AsyncRetrying class + """ + return AsyncRetrying(**self.retry_args) + def _get_aad_token(self, resource: str) -> str: """ Function to get AAD token for given resource. Supports managed identity or service principal auth @@ -234,6 +257,72 @@ def _get_aad_token(self, resource: str) -> str: return token + async def _a_get_aad_token(self, resource: str) -> str: + """ + Async version of `_get_aad_token()`. + :param resource: resource to issue token to + :return: AAD token, or raise an exception + """ + aad_token = self.aad_tokens.get(resource) + if aad_token and self._is_aad_token_valid(aad_token): + return aad_token['token'] + + self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...') + try: + async for attempt in self._a_get_retry_object(): + with attempt: + if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): + params = { + "api-version": "2018-02-01", + "resource": resource, + } + async with self._session.get( + url=AZURE_METADATA_SERVICE_TOKEN_URL, + params=params, + headers={**USER_AGENT_HEADER, "Metadata": "true"}, + timeout=self.aad_timeout_seconds, + ) as resp: + resp.raise_for_status() + jsn = await resp.json() + else: + tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id'] + data = { + "grant_type": "client_credentials", + "client_id": self.databricks_conn.login, + "resource": resource, + "client_secret": self.databricks_conn.password, + } + azure_ad_endpoint = self.databricks_conn.extra_dejson.get( + "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT + ) + async with self._session.post( + url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), + data=data, + headers={ + **USER_AGENT_HEADER, + 'Content-Type': 'application/x-www-form-urlencoded', + }, + timeout=self.aad_timeout_seconds, + ) as resp: + resp.raise_for_status() + jsn = await resp.json() + if ( + 'access_token' not in jsn + or jsn.get('token_type') != 'Bearer' + or 'expires_on' not in jsn + ): + raise AirflowException(f"Can't get necessary data from AAD token: {jsn}") + + token = jsn['access_token'] + self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])} + break + except RetryError: + raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.') + except aiohttp.ClientResponseError as err: + raise AirflowException(f'Response: {err.message}, Status Code: {err.status}') + + return token + def _get_aad_headers(self) -> dict: """ Fills AAD headers if necessary (SPN is outside of the workspace) @@ -248,6 +337,20 @@ def _get_aad_headers(self) -> dict: headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token return headers + async def _a_get_aad_headers(self) -> dict: + """ + Async version of `_get_aad_headers()`. + :return: dictionary with filled AAD headers + """ + headers = {} + if 'azure_resource_id' in self.databricks_conn.extra_dejson: + mgmt_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT) + headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[ + 'azure_resource_id' + ] + headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token + return headers + @staticmethod def _is_aad_token_valid(aad_token: dict) -> bool: """ @@ -281,6 +384,23 @@ def _check_azure_metadata_service() -> None: except (requests_exceptions.RequestException, ValueError) as e: raise AirflowException(f"Can't reach Azure Metadata Service: {e}") + async def _a_check_azure_metadata_service(self): + """Async version of `_check_azure_metadata_service()`.""" + try: + async with self._session.get( + url=AZURE_METADATA_SERVICE_INSTANCE_URL, + params={"api-version": "2021-02-01"}, + headers={"Metadata": "true"}, + timeout=2, + ) as resp: + jsn = await resp.json() + if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']: + raise AirflowException( + f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" + ) + except (requests_exceptions.RequestException, ValueError) as e: + raise AirflowException(f"Can't reach Azure Metadata Service: {e}") + def _get_token(self, raise_error: bool = False) -> Optional[str]: if 'token' in self.databricks_conn.extra_dejson: self.log.info( @@ -304,6 +424,29 @@ def _get_token(self, raise_error: bool = False) -> Optional[str]: return None + async def _a_get_token(self, raise_error: bool = False) -> Optional[str]: + if 'token' in self.databricks_conn.extra_dejson: + self.log.info( + 'Using token auth. For security reasons, please set token in Password field instead of extra' + ) + return self.databricks_conn.extra_dejson["token"] + elif not self.databricks_conn.login and self.databricks_conn.password: + self.log.info('Using token auth.') + return self.databricks_conn.password + elif 'azure_tenant_id' in self.databricks_conn.extra_dejson: + if self.databricks_conn.login == "" or self.databricks_conn.password == "": + raise AirflowException("Azure SPN credentials aren't provided") + self.log.info('Using AAD Token for SPN.') + return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE) + elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): + self.log.info('Using AAD Token for managed identity.') + await self._a_check_azure_metadata_service() + return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE) + elif raise_error: + raise AirflowException('Token authentication isn\'t configured') + + return None + def _log_request_error(self, attempt_num: int, error: str) -> None: self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error) @@ -374,6 +517,55 @@ def _do_api_call( else: raise e + async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, Any]] = None): + """ + Async version of `_do_api_call()`. + :param endpoint_info: Tuple of method and endpoint + :param json: Parameters for this API call. + :return: If the api call returns a OK status code, + this function returns the response in JSON. Otherwise, throw an AirflowException. + """ + method, endpoint = endpoint_info + + url = f'https://{self.host}/{endpoint}' + + aad_headers = await self._a_get_aad_headers() + headers = {**USER_AGENT_HEADER.copy(), **aad_headers} + + auth: aiohttp.BasicAuth + token = await self._a_get_token() + if token: + auth = BearerAuth(token) + else: + self.log.info('Using basic auth.') + auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password) + + request_func: Any + if method == 'GET': + request_func = self._session.get + elif method == 'POST': + request_func = self._session.post + elif method == 'PATCH': + request_func = self._session.patch + else: + raise AirflowException('Unexpected HTTP Method: ' + method) + try: + async for attempt in self._a_get_retry_object(): + with attempt: + async with request_func( + url, + json=json, + auth=auth, + headers={**headers, **USER_AGENT_HEADER}, + timeout=self.timeout_seconds, + ) as response: + response.raise_for_status() + return await response.json() + except RetryError: + raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.') + except aiohttp.ClientResponseError as err: + raise AirflowException(f'Response: {err.message}, Status Code: {err.status}') + @staticmethod def _get_error_code(exception: BaseException) -> str: if isinstance(exception, requests_exceptions.HTTPError): @@ -387,19 +579,25 @@ def _get_error_code(exception: BaseException) -> str: @staticmethod def _retryable_error(exception: BaseException) -> bool: - if not isinstance(exception, requests_exceptions.RequestException): - return False - return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or ( - exception.response is not None - and ( - exception.response.status_code >= 500 - or exception.response.status_code == 429 - or ( - exception.response.status_code == 400 - and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK' + if isinstance(exception, requests_exceptions.RequestException): + if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or ( + exception.response is not None + and ( + exception.response.status_code >= 500 + or exception.response.status_code == 429 + or ( + exception.response.status_code == 400 + and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK' + ) ) - ) - ) + ): + return True + + if isinstance(exception, aiohttp.ClientResponseError): + if exception.status >= 500 or exception.status == 429: + return True + + return False class _TokenAuth(AuthBase): @@ -414,3 +612,16 @@ def __init__(self, token: str) -> None: def __call__(self, r: PreparedRequest) -> PreparedRequest: r.headers['Authorization'] = 'Bearer ' + self.token return r + + +class BearerAuth(aiohttp.BasicAuth): + """aiohttp only ships BasicAuth, for Bearer auth we need a subclass of BasicAuth.""" + + def __new__(cls, token: str) -> 'BearerAuth': + return super().__new__(cls, token) # type: ignore + + def __init__(self, token: str) -> None: + self.token = token + + def encode(self) -> str: + return f'Bearer {self.token}' diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index f9e4f247642d3..e1e4e28cc7502 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -19,51 +19,24 @@ """This module contains Databricks operators.""" import time +from logging import Logger from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink, XCom -from airflow.providers.databricks.hooks.databricks import DatabricksHook +from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState +from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger +from airflow.providers.databricks.utils.databricks import deep_string_coerce, validate_trigger_event if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context +DEFER_METHOD_NAME = 'execute_complete' XCOM_RUN_ID_KEY = 'run_id' XCOM_RUN_PAGE_URL_KEY = 'run_page_url' -def _deep_string_coerce(content, json_path: str = 'json') -> Union[str, list, dict]: - """ - Coerces content or all values of content if it is a dict to a string. The - function will throw if content contains non-string or non-numeric types. - - The reason why we have this function is because the ``self.json`` field must be a - dict with only string values. This is because ``render_template`` will fail - for numerical values. - """ - coerce = _deep_string_coerce - if isinstance(content, str): - return content - elif isinstance( - content, - ( - int, - float, - ), - ): - # Databricks can tolerate either numeric or string types in the API backend. - return str(content) - elif isinstance(content, (list, tuple)): - return [coerce(e, f'{json_path}[{i}]') for i, e in enumerate(content)] - elif isinstance(content, dict): - return {k: coerce(v, f'{json_path}[{k}]') for k, v in list(content.items())} - else: - param_type = type(content) - msg = f'Type {param_type} used for parameter {json_path} is not a number or a string' - raise AirflowException(msg) - - def _handle_databricks_operator_execution(operator, hook, log, context) -> None: """ Handles the Airflow + Databricks lifecycle logic for a Databricks operator @@ -103,6 +76,47 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None: log.info('View run status, Spark UI, and logs at %s', run_page_url) +def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None: + """ + Handles the Airflow + Databricks lifecycle logic for deferrable Databricks operators + + :param operator: Databricks async operator being handled + :param context: Airflow context + """ + if operator.do_xcom_push and context is not None: + context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id) + log.info(f'Run submitted with run_id: {operator.run_id}') + + run_page_url = hook.get_run_page_url(operator.run_id) + if operator.do_xcom_push and context is not None: + context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url) + log.info(f'View run status, Spark UI, and logs at {run_page_url}') + + if operator.wait_for_termination: + operator.defer( + trigger=DatabricksExecutionTrigger( + run_id=operator.run_id, + databricks_conn_id=operator.databricks_conn_id, + polling_period_seconds=operator.polling_period_seconds, + ), + method_name=DEFER_METHOD_NAME, + ) + + +def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None: + validate_trigger_event(event) + run_state = RunState.from_json(event['run_state']) + run_page_url = event['run_page_url'] + log.info(f'View run status, Spark UI, and logs at {run_page_url}') + + if run_state.is_successful: + log.info('Job run completed successfully.') + return + else: + error_message = f'Job run failed with terminal state: {run_state}' + raise AirflowException(error_message) + + class DatabricksJobRunLink(BaseOperatorLink): """Constructs a link to monitor a Databricks Job Run.""" @@ -356,7 +370,7 @@ def __init__( if access_control_list is not None: self.json['access_control_list'] = access_control_list - self.json = _deep_string_coerce(self.json) + self.json = deep_string_coerce(self.json) # This variable will be used in case our task gets killed. self.run_id: Optional[int] = None self.do_xcom_push = do_xcom_push @@ -385,6 +399,18 @@ def on_kill(self): self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id) +class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator): + """Deferrable version of ``DatabricksSubmitRunOperator``""" + + def execute(self, context): + hook = self._get_hook() + self.run_id = hook.submit_run(self.json) + _handle_deferrable_databricks_operator_execution(self, hook, self.log, context) + + def execute_complete(self, context: Optional[dict], event: dict): + _handle_deferrable_databricks_operator_completion(event, self.log) + + class DatabricksRunNowOperator(BaseOperator): """ Runs an existing Spark job run to Databricks using the @@ -596,7 +622,7 @@ def __init__( if idempotency_token is not None: self.json['idempotency_token'] = idempotency_token - self.json = _deep_string_coerce(self.json) + self.json = deep_string_coerce(self.json) # This variable will be used in case our task gets killed. self.run_id: Optional[int] = None self.do_xcom_push = do_xcom_push @@ -629,3 +655,15 @@ def on_kill(self): ) else: self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id) + + +class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator): + """Deferrable version of ``DatabricksRunNowOperator``""" + + def execute(self, context): + hook = self._get_hook() + self.run_id = hook.run_now(self.json) + _handle_deferrable_databricks_operator_execution(self, hook, self.log, context) + + def execute_complete(self, context: Optional[dict], event: dict): + _handle_deferrable_databricks_operator_completion(event, self.log) diff --git a/airflow/providers/databricks/triggers/__init__.py b/airflow/providers/databricks/triggers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/airflow/providers/databricks/triggers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py new file mode 100644 index 0000000000000..5f50f5aff29ee --- /dev/null +++ b/airflow/providers/databricks/triggers/databricks.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import asyncio +import logging +from typing import Any, Dict, Tuple + +from airflow.providers.databricks.hooks.databricks import DatabricksHook + +try: + from airflow.triggers.base import BaseTrigger, TriggerEvent +except ImportError: + logging.getLogger(__name__).warning( + 'Deferrable Operators only work starting Airflow 2.2', + exc_info=True, + ) + BaseTrigger = object # type: ignore + TriggerEvent = None # type: ignore + + +class DatabricksExecutionTrigger(BaseTrigger): + """ + The trigger handles the logic of async communication with DataBricks API. + + :param run_id: id of the run + :param databricks_conn_id: Reference to the :ref:`Databricks connection `. + :param polling_period_seconds: Controls the rate of the poll for the result of this run. + By default, the trigger will poll every 30 seconds. + """ + + def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: int = 30) -> None: + super().__init__() + self.run_id = run_id + self.databricks_conn_id = databricks_conn_id + self.polling_period_seconds = polling_period_seconds + self.hook = DatabricksHook(databricks_conn_id) + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + return ( + 'airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger', + { + 'run_id': self.run_id, + 'databricks_conn_id': self.databricks_conn_id, + 'polling_period_seconds': self.polling_period_seconds, + }, + ) + + async def run(self): + async with self.hook: + run_page_url = await self.hook.a_get_run_page_url(self.run_id) + while True: + run_state = await self.hook.a_get_run_state(self.run_id) + if run_state.is_terminal: + yield TriggerEvent( + { + 'run_id': self.run_id, + 'run_state': run_state.to_json(), + 'run_page_url': run_page_url, + } + ) + break + else: + await asyncio.sleep(self.polling_period_seconds) diff --git a/airflow/providers/databricks/utils/__init__.py b/airflow/providers/databricks/utils/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/databricks/utils/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py new file mode 100644 index 0000000000000..96935d806344e --- /dev/null +++ b/airflow/providers/databricks/utils/databricks.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Union + +from airflow.exceptions import AirflowException +from airflow.providers.databricks.hooks.databricks import RunState + + +def deep_string_coerce(content, json_path: str = 'json') -> Union[str, list, dict]: + """ + Coerces content or all values of content if it is a dict to a string. The + function will throw if content contains non-string or non-numeric types. + The reason why we have this function is because the ``self.json`` field must be a + dict with only string values. This is because ``render_template`` will fail + for numerical values. + """ + coerce = deep_string_coerce + if isinstance(content, str): + return content + elif isinstance( + content, + ( + int, + float, + ), + ): + # Databricks can tolerate either numeric or string types in the API backend. + return str(content) + elif isinstance(content, (list, tuple)): + return [coerce(e, f'{json_path}[{i}]') for i, e in enumerate(content)] + elif isinstance(content, dict): + return {k: coerce(v, f'{json_path}[{k}]') for k, v in list(content.items())} + else: + param_type = type(content) + msg = f'Type {param_type} used for parameter {json_path} is not a number or a string' + raise AirflowException(msg) + + +def validate_trigger_event(event: dict): + """ + Validates correctness of the event + received from :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger` + """ + keys_to_check = ['run_id', 'run_page_url', 'run_state'] + for key in keys_to_check: + if key not in event: + raise AirflowException(f'Could not find `{key}` in the event: {event}') + + try: + RunState.from_json(event['run_state']) + except Exception: + raise AirflowException(f'Run state returned by the Trigger is incorrect: {event["run_state"]}') diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst index 8b2e6010ee1fb..a4b00d9005c81 100644 --- a/docs/apache-airflow-providers-databricks/operators/run_now.rst +++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst @@ -45,3 +45,10 @@ All other parameters are optional and described in documentation for ``Databrick * ``python_named_parameters`` * ``jar_params`` * ``spark_submit_params`` + +DatabricksRunNowDeferrableOperator +================================== + +Deferrable version of the :class:`~airflow.providers.databricks.operators.DatabricksRunNowOperator` operator. + +It allows to utilize Airflow workers more effectively using `new functionality introduced in Airflow 2.2.0 `_ diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst index da71194da78fd..81f9dfd32f382 100644 --- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst +++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst @@ -75,3 +75,10 @@ You can also use named parameters to initialize the operator and run the job. :language: python :start-after: [START howto_operator_databricks_named] :end-before: [END howto_operator_databricks_named] + +DatabricksSubmitRunDeferrableOperator +===================================== + +Deferrable version of the :class:`~airflow.providers.databricks.operators.DatabricksSubmitRunOperator` operator. + +It allows to utilize Airflow workers more effectively using `new functionality introduced in Airflow 2.2.0 `_ diff --git a/setup.py b/setup.py index f01ebb21b59e2..9ccd6cf6c1aa4 100644 --- a/setup.py +++ b/setup.py @@ -265,6 +265,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version databricks = [ 'requests>=2.26.0, <3', 'databricks-sql-connector>=2.0.0, <3.0.0', + 'aiohttp>=3.6.3, <4', ] datadog = [ 'datadog>=0.14.0', @@ -602,6 +603,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version # Dependencies needed for development only devel_only = [ + 'asynctest~=0.13', 'aws_xray_sdk', 'beautifulsoup4>=4.7.1', 'black', diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 5a93ed7b4221f..2997984f7b90b 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -19,10 +19,11 @@ import itertools import json +import sys import time import unittest -from unittest import mock +import aiohttp import pytest import tenacity from requests import exceptions as requests_exceptions @@ -31,7 +32,12 @@ from airflow import __version__ from airflow.exceptions import AirflowException from airflow.models import Connection -from airflow.providers.databricks.hooks.databricks import SUBMIT_RUN_ENDPOINT, DatabricksHook, RunState +from airflow.providers.databricks.hooks.databricks import ( + GET_RUN_ENDPOINT, + SUBMIT_RUN_ENDPOINT, + DatabricksHook, + RunState, +) from airflow.providers.databricks.hooks.databricks_base import ( AZURE_DEFAULT_AD_ENDPOINT, AZURE_MANAGEMENT_ENDPOINT, @@ -39,9 +45,17 @@ AZURE_TOKEN_SERVICE_URL, DEFAULT_DATABRICKS_SCOPE, TOKEN_REFRESH_LEAD_TIME, + BearerAuth, ) from airflow.utils.session import provide_session +if sys.version_info < (3, 8): + from asynctest import mock + from asynctest.mock import CoroutineMock as AsyncMock +else: + from unittest import mock + from unittest.mock import AsyncMock + TASK_ID = 'databricks-operator' DEFAULT_CONN_ID = 'databricks_default' NOTEBOOK_TASK = {'notebook_path': '/test'} @@ -172,6 +186,7 @@ def list_jobs_endpoint(host): def create_valid_response_mock(content): response = mock.MagicMock() response.json.return_value = content + response.__aenter__.return_value.json = AsyncMock(return_value=content) return response @@ -785,6 +800,18 @@ def test_is_successful(self): run_state = RunState('TERMINATED', 'SUCCESS', '') assert run_state.is_successful + def test_to_json(self): + run_state = RunState('TERMINATED', 'SUCCESS', '') + expected = json.dumps( + {'life_cycle_state': 'TERMINATED', 'result_state': 'SUCCESS', 'state_message': ''} + ) + assert expected == run_state.to_json() + + def test_from_json(self): + state = {'life_cycle_state': 'TERMINATED', 'result_state': 'SUCCESS', 'state_message': ''} + expected = RunState('TERMINATED', 'SUCCESS', '') + assert expected == RunState.from_json(json.dumps(state)) + def create_aad_token_for_resource(resource: str) -> dict: return { @@ -976,3 +1003,324 @@ def test_submit_run(self, mock_requests): args = mock_requests.post.call_args kwargs = args[1] assert kwargs['auth'].token == TOKEN + + +class TestDatabricksHookAsyncMethods: + """ + Tests for async functionality of DatabricksHook. + """ + + @provide_session + def setup_method(self, method, session=None): + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.host = HOST + conn.login = LOGIN + conn.password = PASSWORD + conn.extra = None + session.commit() + + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + + @pytest.mark.asyncio + async def test_init_async_session(self): + async with self.hook: + assert isinstance(self.hook._session, aiohttp.ClientSession) + assert self.hook._session is None + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + async def test_do_api_call_retries_with_retryable_error(self, mock_get): + mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=500) + with mock.patch.object(self.hook.log, 'error') as mock_errors: + async with self.hook: + with pytest.raises(AirflowException): + await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {}) + assert mock_errors.call_count == DEFAULT_RETRY_NUMBER + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + async def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_get): + mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=400) + with mock.patch.object(self.hook.log, 'error') as mock_errors: + async with self.hook: + with pytest.raises(AirflowException): + await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {}) + mock_errors.assert_not_called() + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + async def test_do_api_call_succeeds_after_retrying(self, mock_get): + mock_get.side_effect = [ + aiohttp.ClientResponseError(None, None, status=500), + create_valid_response_mock({'run_id': '1'}), + ] + with mock.patch.object(self.hook.log, 'error') as mock_errors: + async with self.hook: + response = await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {}) + assert mock_errors.call_count == 1 + assert response == {'run_id': '1'} + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + async def test_do_api_call_waits_between_retries(self, mock_get): + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + + mock_get.side_effect = aiohttp.ClientResponseError(None, None, status=500) + with mock.patch.object(self.hook.log, 'error') as mock_errors: + async with self.hook: + with pytest.raises(AirflowException): + await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {}) + assert mock_errors.call_count == DEFAULT_RETRY_NUMBER + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.patch') + async def test_do_api_call_patch(self, mock_patch): + mock_patch.return_value.__aenter__.return_value.json = AsyncMock( + return_value={'cluster_name': 'new_name'} + ) + data = {'cluster_name': 'new_name'} + async with self.hook: + patched_cluster_name = await self.hook._a_do_api_call(('PATCH', 'api/2.1/jobs/runs/submit'), data) + + assert patched_cluster_name['cluster_name'] == 'new_name' + mock_patch.assert_called_once_with( + submit_run_endpoint(HOST), + json={'cluster_name': 'new_name'}, + auth=aiohttp.BasicAuth(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds, + ) + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + async def test_get_run_page_url(self, mock_get): + mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) + async with self.hook: + run_page_url = await self.hook.a_get_run_page_url(RUN_ID) + + assert run_page_url == RUN_PAGE_URL + mock_get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=aiohttp.BasicAuth(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds, + ) + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + async def test_get_run_state(self, mock_get): + mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) + + async with self.hook: + run_state = await self.hook.a_get_run_state(RUN_ID) + + assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) + mock_get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=aiohttp.BasicAuth(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds, + ) + + +class TestDatabricksHookAsyncAadToken: + """ + Tests for DatabricksHook using async methods when + auth is done with AAD token for SP as user inside workspace. + """ + + @provide_session + def setup_method(self, method, session=None): + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.login = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + conn.password = 'secret' + conn.extra = json.dumps( + { + 'host': HOST, + 'azure_tenant_id': '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d', + } + ) + session.commit() + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post') + async def test_get_run_state(self, mock_post, mock_get): + mock_post.return_value.__aenter__.return_value.json = AsyncMock( + return_value=create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE) + ) + mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) + + async with self.hook: + run_state = await self.hook.a_get_run_state(RUN_ID) + + assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) + mock_get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=BearerAuth(TOKEN), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds, + ) + + +class TestDatabricksHookAsyncAadTokenOtherClouds: + """ + Tests for DatabricksHook using async methodswhen auth is done with AAD token + for SP as user inside workspace and using non-global Azure cloud (China, GovCloud, Germany) + """ + + @provide_session + def setup_method(self, method, session=None): + self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d' + self.ad_endpoint = 'https://login.microsoftonline.de' + self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.login = self.client_id + conn.password = 'secret' + conn.extra = json.dumps( + { + 'host': HOST, + 'azure_tenant_id': self.tenant_id, + 'azure_ad_endpoint': self.ad_endpoint, + } + ) + session.commit() + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post') + async def test_get_run_state(self, mock_post, mock_get): + mock_post.return_value.__aenter__.return_value.json = AsyncMock( + return_value=create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE) + ) + mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) + + async with self.hook: + run_state = await self.hook.a_get_run_state(RUN_ID) + + assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) + + ad_call_args = mock_post.call_args_list[0] + assert ad_call_args[1]['url'] == AZURE_TOKEN_SERVICE_URL.format(self.ad_endpoint, self.tenant_id) + assert ad_call_args[1]['data']['client_id'] == self.client_id + assert ad_call_args[1]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE + + mock_get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=BearerAuth(TOKEN), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds, + ) + + +class TestDatabricksHookAsyncAadTokenSpOutside: + """ + Tests for DatabricksHook using async methods when auth is done with AAD token for SP outside of workspace. + """ + + @provide_session + def setup_method(self, method, session=None): + self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d' + self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.login = self.client_id + conn.password = 'secret' + conn.host = HOST + conn.extra = json.dumps( + { + 'azure_resource_id': '/Some/resource', + 'azure_tenant_id': self.tenant_id, + } + ) + session.commit() + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post') + async def test_get_run_state(self, mock_post, mock_get): + mock_post.return_value.__aenter__.return_value.json.side_effect = AsyncMock( + side_effect=[ + create_aad_token_for_resource(AZURE_MANAGEMENT_ENDPOINT), + create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE), + ] + ) + mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) + + async with self.hook: + run_state = await self.hook.a_get_run_state(RUN_ID) + + assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) + + ad_call_args = mock_post.call_args_list[0] + assert ad_call_args[1]['url'] == AZURE_TOKEN_SERVICE_URL.format( + AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id + ) + assert ad_call_args[1]['data']['client_id'] == self.client_id + assert ad_call_args[1]['data']['resource'] == AZURE_MANAGEMENT_ENDPOINT + + ad_call_args = mock_post.call_args_list[1] + assert ad_call_args[1]['url'] == AZURE_TOKEN_SERVICE_URL.format( + AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id + ) + assert ad_call_args[1]['data']['client_id'] == self.client_id + assert ad_call_args[1]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE + + mock_get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=BearerAuth(TOKEN), + headers={ + **USER_AGENT_HEADER, + 'X-Databricks-Azure-Workspace-Resource-Id': '/Some/resource', + 'X-Databricks-Azure-SP-Management-Token': TOKEN, + }, + timeout=self.hook.timeout_seconds, + ) + + +class TestDatabricksHookAsyncAadTokenManagedIdentity: + """ + Tests for DatabricksHook using async methods when + auth is done with AAD leveraging Managed Identity authentication + """ + + @provide_session + def setup_method(self, method, session=None): + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.host = HOST + conn.extra = json.dumps( + { + 'use_azure_managed_identity': True, + } + ) + session.commit() + session.commit() + self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get') + async def test_get_run_state(self, mock_get): + mock_get.return_value.__aenter__.return_value.json.side_effect = AsyncMock( + side_effect=[ + {'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}}, + create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE), + GET_RUN_RESPONSE, + ] + ) + + async with self.hook: + run_state = await self.hook.a_get_run_state(RUN_ID) + + assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) + + ad_call_args = mock_get.call_args_list[0] + assert ad_call_args[1]['url'] == AZURE_METADATA_SERVICE_INSTANCE_URL + assert ad_call_args[1]['params']['api-version'] > '2018-02-01' + assert ad_call_args[1]['headers']['Metadata'] == 'true' diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index c467234e09bc2..34e41a673236c 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -22,14 +22,17 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG from airflow.providers.databricks.hooks.databricks import RunState -from airflow.providers.databricks.operators import databricks as databricks_operator from airflow.providers.databricks.operators.databricks import ( + DatabricksRunNowDeferrableOperator, DatabricksRunNowOperator, + DatabricksSubmitRunDeferrableOperator, DatabricksSubmitRunOperator, ) +from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger +from airflow.providers.databricks.utils import databricks as utils DATE = '2017-04-20' TASK_ID = 'databricks-operator' @@ -46,6 +49,7 @@ EXISTING_CLUSTER_ID = 'existing-cluster-id' RUN_NAME = 'run-name' RUN_ID = 1 +RUN_PAGE_URL = 'run-page-url' JOB_ID = "42" JOB_NAME = "job-name" NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"} @@ -56,26 +60,6 @@ SPARK_SUBMIT_PARAMS = ["--class", "org.apache.spark.examples.SparkPi"] -class TestDatabricksOperatorSharedFunctions(unittest.TestCase): - def test_deep_string_coerce(self): - test_json = { - 'test_int': 1, - 'test_float': 1.0, - 'test_dict': {'key': 'value'}, - 'test_list': [1, 1.0, 'a', 'b'], - 'test_tuple': (1, 1.0, 'a', 'b'), - } - - expected = { - 'test_int': '1', - 'test_float': '1.0', - 'test_dict': {'key': 'value'}, - 'test_list': ['1', '1.0', 'a', 'b'], - 'test_tuple': ['1', '1.0', 'a', 'b'], - } - assert databricks_operator._deep_string_coerce(test_json) == expected - - class TestDatabricksSubmitRunOperator(unittest.TestCase): def test_init_with_notebook_task_named_parameters(self): """ @@ -84,7 +68,7 @@ def test_init_with_notebook_task_named_parameters(self): op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK ) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} ) @@ -97,7 +81,7 @@ def test_init_with_spark_python_task_named_parameters(self): op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, spark_python_task=SPARK_PYTHON_TASK ) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, 'spark_python_task': SPARK_PYTHON_TASK, 'run_name': TASK_ID} ) @@ -110,7 +94,7 @@ def test_init_with_spark_submit_task_named_parameters(self): op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, spark_submit_task=SPARK_SUBMIT_TASK ) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, 'spark_submit_task': SPARK_SUBMIT_TASK, 'run_name': TASK_ID} ) @@ -122,7 +106,7 @@ def test_init_with_json(self): """ json = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} ) assert expected == op.json @@ -130,7 +114,7 @@ def test_init_with_json(self): def test_init_with_tasks(self): tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}] op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks) - expected = databricks_operator._deep_string_coerce({'run_name': TASK_ID, "tasks": tasks}) + expected = utils.deep_string_coerce({'run_name': TASK_ID, "tasks": tasks}) assert expected == op.json def test_init_with_specified_run_name(self): @@ -139,7 +123,7 @@ def test_init_with_specified_run_name(self): """ json = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': RUN_NAME} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': RUN_NAME} ) assert expected == op.json @@ -151,7 +135,7 @@ def test_pipeline_task(self): pipeline_task = {"pipeline_id": "test-dlt"} json = {'new_cluster': NEW_CLUSTER, 'run_name': RUN_NAME, "pipeline_task": pipeline_task} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, "pipeline_task": pipeline_task, 'run_name': RUN_NAME} ) assert expected == op.json @@ -168,7 +152,7 @@ def test_init_with_merging(self): 'notebook_task': NOTEBOOK_TASK, } op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'new_cluster': override_new_cluster, 'notebook_task': NOTEBOOK_TASK, @@ -185,7 +169,7 @@ def test_init_with_templating(self): dag = DAG('test', start_date=datetime.now()) op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json) op.render_template_fields(context={'ds': DATE}) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'new_cluster': NEW_CLUSTER, 'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK, @@ -220,7 +204,7 @@ def test_exec_success(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} ) db_mock_class.assert_called_once_with( @@ -252,7 +236,7 @@ def test_exec_failure(self, db_mock_class): with pytest.raises(AirflowException): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, @@ -299,7 +283,7 @@ def test_wait_for_termination(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} ) db_mock_class.assert_called_once_with( @@ -327,7 +311,7 @@ def test_no_wait_for_termination(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} ) db_mock_class.assert_called_once_with( @@ -342,13 +326,98 @@ def test_no_wait_for_termination(self, db_mock_class): db_mock.get_run_state.assert_not_called() +class TestDatabricksSubmitRunDeferrableOperator(unittest.TestCase): + @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook') + def test_execute_task_deferred(self, db_mock_class): + """ + Test the execute function in case where the run is successful. + """ + run = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + } + op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) + db_mock = db_mock_class.return_value + db_mock.submit_run.return_value = 1 + db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '') + + with pytest.raises(TaskDeferred) as exc: + op.execute(None) + self.assertTrue(isinstance(exc.value.trigger, DatabricksExecutionTrigger)) + self.assertEqual(exc.value.method_name, 'execute_complete') + + expected = utils.deep_string_coerce( + {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + ) + + db_mock.submit_run.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + self.assertEqual(RUN_ID, op.run_id) + + def test_execute_complete_success(self): + """ + Test `execute_complete` function in case the Trigger has returned a successful completion event. + """ + run = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + } + event = { + 'run_id': RUN_ID, + 'run_page_url': RUN_PAGE_URL, + 'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(), + } + + op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) + self.assertIsNone(op.execute_complete(context=None, event=event)) + + @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook') + def test_execute_complete_failure(self, db_mock_class): + """ + Test `execute_complete` function in case the Trigger has returned a failure completion event. + """ + run_state_failed = RunState('TERMINATED', 'FAILED', '') + run = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + } + event = { + 'run_id': RUN_ID, + 'run_page_url': RUN_PAGE_URL, + 'run_state': run_state_failed.to_json(), + } + + op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) + with pytest.raises(AirflowException): + op.execute_complete(context=None, event=event) + + db_mock = db_mock_class.return_value + db_mock.submit_run.return_value = 1 + db_mock.get_run_state.return_value = run_state_failed + + with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'): + op.execute_complete(context=None, event=event) + + def test_execute_complete_incorrect_event_validation_failure(self): + event = {} + op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID) + with pytest.raises(AirflowException): + op.execute_complete(context=None, event=event) + + class TestDatabricksRunNowOperator(unittest.TestCase): def test_init_with_named_parameters(self): """ Test the initializer with the named parameters. """ op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID) - expected = databricks_operator._deep_string_coerce({'job_id': 42}) + expected = utils.deep_string_coerce({'job_id': 42}) assert expected == op.json @@ -365,7 +434,7 @@ def test_init_with_json(self): } op = DatabricksRunNowOperator(task_id=TASK_ID, json=json) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'notebook_params': NOTEBOOK_PARAMS, 'jar_params': JAR_PARAMS, @@ -397,7 +466,7 @@ def test_init_with_merging(self): spark_submit_params=SPARK_SUBMIT_PARAMS, ) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'notebook_params': override_notebook_params, 'jar_params': override_jar_params, @@ -415,7 +484,7 @@ def test_init_with_templating(self): dag = DAG('test', start_date=datetime.now()) op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json) op.render_template_fields(context={'ds': DATE}) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'notebook_params': NOTEBOOK_PARAMS, 'jar_params': RENDERED_TEMPLATED_JAR_PARAMS, @@ -447,7 +516,7 @@ def test_exec_success(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, @@ -481,7 +550,7 @@ def test_exec_failure(self, db_mock_class): with pytest.raises(AirflowException): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, @@ -522,7 +591,7 @@ def test_wait_for_termination(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, @@ -552,7 +621,7 @@ def test_no_wait_for_termination(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, @@ -596,7 +665,7 @@ def test_exec_with_job_name(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce( + expected = utils.deep_string_coerce( { 'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, @@ -629,3 +698,85 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class): op.execute(None) db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME) + + +class TestDatabricksRunNowDeferrableOperator(unittest.TestCase): + @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook') + def test_execute_task_deferred(self, db_mock_class): + """ + Test the execute function in case where the run is successful. + """ + run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS} + op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = 1 + db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '') + + with pytest.raises(TaskDeferred) as exc: + op.execute(None) + self.assertTrue(isinstance(exc.value.trigger, DatabricksExecutionTrigger)) + self.assertEqual(exc.value.method_name, 'execute_complete') + + expected = utils.deep_string_coerce( + { + 'notebook_params': NOTEBOOK_PARAMS, + 'notebook_task': NOTEBOOK_TASK, + 'jar_params': JAR_PARAMS, + 'job_id': JOB_ID, + } + ) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + ) + + db_mock.run_now.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + self.assertEqual(RUN_ID, op.run_id) + + def test_execute_complete_success(self): + """ + Test `execute_complete` function in case the Trigger has returned a successful completion event. + """ + run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS} + event = { + 'run_id': RUN_ID, + 'run_page_url': RUN_PAGE_URL, + 'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(), + } + + op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) + self.assertIsNone(op.execute_complete(context=None, event=event)) + + @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook') + def test_execute_complete_failure(self, db_mock_class): + """ + Test `execute_complete` function in case the Trigger has returned a failure completion event. + """ + run_state_failed = RunState('TERMINATED', 'FAILED', '') + run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS} + event = { + 'run_id': RUN_ID, + 'run_page_url': RUN_PAGE_URL, + 'run_state': run_state_failed.to_json(), + } + + op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) + with pytest.raises(AirflowException): + op.execute_complete(context=None, event=event) + + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = 1 + db_mock.get_run_state.return_value = run_state_failed + + with pytest.raises(AirflowException, match=f'Job run failed with terminal state: {run_state_failed}'): + op.execute_complete(context=None, event=event) + + def test_execute_complete_incorrect_event_validation_failure(self): + event = {} + op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID) + with pytest.raises(AirflowException): + op.execute_complete(context=None, event=event) diff --git a/tests/providers/databricks/triggers/__init__.py b/tests/providers/databricks/triggers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/databricks/triggers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py new file mode 100644 index 0000000000000..cecbed11388ca --- /dev/null +++ b/tests/providers/databricks/triggers/test_databricks.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys + +import pytest + +from airflow.models import Connection +from airflow.providers.databricks.hooks.databricks import RunState +from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger +from airflow.triggers.base import TriggerEvent +from airflow.utils.session import provide_session + +if sys.version_info < (3, 8): + from asynctest import mock +else: + from unittest import mock + +DEFAULT_CONN_ID = 'databricks_default' +HOST = 'xx.cloud.databricks.com' +LOGIN = 'login' +PASSWORD = 'password' +POLLING_INTERVAL_SECONDS = 30 +RETRY_DELAY = 10 +RETRY_LIMIT = 3 +RUN_ID = 1 +JOB_ID = 42 +RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1' + +RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] + +LIFE_CYCLE_STATE_PENDING = 'PENDING' +LIFE_CYCLE_STATE_TERMINATED = 'TERMINATED' + +STATE_MESSAGE = 'Waiting for cluster' + +GET_RUN_RESPONSE_PENDING = { + 'job_id': JOB_ID, + 'run_page_url': RUN_PAGE_URL, + 'state': { + 'life_cycle_state': LIFE_CYCLE_STATE_PENDING, + 'state_message': STATE_MESSAGE, + 'result_state': None, + }, +} +GET_RUN_RESPONSE_TERMINATED = { + 'job_id': JOB_ID, + 'run_page_url': RUN_PAGE_URL, + 'state': { + 'life_cycle_state': LIFE_CYCLE_STATE_TERMINATED, + 'state_message': None, + 'result_state': 'SUCCESS', + }, +} + + +class TestDatabricksExecutionTrigger: + @provide_session + def setup_method(self, method, session=None): + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.host = HOST + conn.login = LOGIN + conn.password = PASSWORD + conn.extra = None + session.commit() + + self.trigger = DatabricksExecutionTrigger( + run_id=RUN_ID, + databricks_conn_id=DEFAULT_CONN_ID, + polling_period_seconds=POLLING_INTERVAL_SECONDS, + ) + + def test_serialize(self): + assert self.trigger.serialize() == ( + 'airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger', + { + 'run_id': RUN_ID, + 'databricks_conn_id': DEFAULT_CONN_ID, + 'polling_period_seconds': POLLING_INTERVAL_SECONDS, + }, + ) + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url') + @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state') + async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url): + mock_get_run_page_url.return_value = RUN_PAGE_URL + mock_get_run_state.return_value = RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message='', + result_state='SUCCESS', + ) + + trigger_event = self.trigger.run() + async for event in trigger_event: + assert event == TriggerEvent( + { + 'run_id': RUN_ID, + 'run_state': RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message='', result_state='SUCCESS' + ).to_json(), + 'run_page_url': RUN_PAGE_URL, + } + ) + + @pytest.mark.asyncio + @mock.patch('airflow.providers.databricks.triggers.databricks.asyncio.sleep') + @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url') + @mock.patch('airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state') + async def test_sleep_between_retries(self, mock_get_run_state, mock_get_run_page_url, mock_sleep): + mock_get_run_page_url.return_value = RUN_PAGE_URL + mock_get_run_state.side_effect = [ + RunState( + life_cycle_state=LIFE_CYCLE_STATE_PENDING, + state_message='', + result_state='', + ), + RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message='', + result_state='SUCCESS', + ), + ] + + trigger_event = self.trigger.run() + async for event in trigger_event: + assert event == TriggerEvent( + { + 'run_id': RUN_ID, + 'run_state': RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message='', result_state='SUCCESS' + ).to_json(), + 'run_page_url': RUN_PAGE_URL, + } + ) + mock_sleep.assert_called_once() + mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS) diff --git a/tests/providers/databricks/utils/__init__.py b/tests/providers/databricks/utils/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/databricks/utils/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/databricks/utils/databricks.py b/tests/providers/databricks/utils/databricks.py new file mode 100644 index 0000000000000..d450a19ceb6c0 --- /dev/null +++ b/tests/providers/databricks/utils/databricks.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import unittest + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.databricks.hooks.databricks import RunState +from airflow.providers.databricks.utils.databricks import deep_string_coerce, validate_trigger_event + +RUN_ID = 1 +RUN_PAGE_URL = 'run-page-url' + + +class TestDatabricksOperatorSharedFunctions(unittest.TestCase): + def test_deep_string_coerce(self): + test_json = { + 'test_int': 1, + 'test_float': 1.0, + 'test_dict': {'key': 'value'}, + 'test_list': [1, 1.0, 'a', 'b'], + 'test_tuple': (1, 1.0, 'a', 'b'), + } + + expected = { + 'test_int': '1', + 'test_float': '1.0', + 'test_dict': {'key': 'value'}, + 'test_list': ['1', '1.0', 'a', 'b'], + 'test_tuple': ['1', '1.0', 'a', 'b'], + } + assert deep_string_coerce(test_json) == expected + + def test_validate_trigger_event_success(self): + event = { + 'run_id': RUN_ID, + 'run_page_url': RUN_PAGE_URL, + 'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(), + } + self.assertIsNone(validate_trigger_event(event)) + + def test_validate_trigger_event_failure(self): + event = {} + with pytest.raises(AirflowException): + validate_trigger_event(event)