Skip to content

Commit

Permalink
Add Deferrable Databricks operators
Browse files Browse the repository at this point in the history
  • Loading branch information
eskarimov committed May 3, 2022
1 parent af54630 commit 9b40c02
Show file tree
Hide file tree
Showing 16 changed files with 1,313 additions and 93 deletions.
29 changes: 29 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
or the ``api/2.1/jobs/runs/submit``
`endpoint <https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit>`_.
"""
import json
from typing import Any, Dict, List, Optional

from requests import exceptions as requests_exceptions
Expand Down Expand Up @@ -92,6 +93,13 @@ def __eq__(self, other: object) -> bool:
def __repr__(self) -> str:
return str(self.__dict__)

def to_json(self) -> str:
return json.dumps(self.__dict__)

@classmethod
def from_json(cls, data: str) -> 'RunState':
return RunState(**json.loads(data))


class DatabricksHook(BaseDatabricksHook):
"""
Expand Down Expand Up @@ -198,6 +206,16 @@ def get_run_page_url(self, run_id: int) -> str:
response = self._do_api_call(GET_RUN_ENDPOINT, json)
return response['run_page_url']

async def a_get_run_page_url(self, run_id: int) -> str:
"""
Async version of `get_run_page_url()`.
:param run_id: id of the run
:return: URL of the run page
"""
json = {'run_id': run_id}
response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
return response['run_page_url']

def get_job_id(self, run_id: int) -> int:
"""
Retrieves job_id from run_id.
Expand Down Expand Up @@ -229,6 +247,17 @@ def get_run_state(self, run_id: int) -> RunState:
state = response['state']
return RunState(**state)

async def a_get_run_state(self, run_id: int) -> RunState:
"""
Async version of `get_run_state()`.
:param run_id: id of the run
:return: state of the run
"""
json = {'run_id': run_id}
response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
state = response['state']
return RunState(**state)

def get_run_state_str(self, run_id: int) -> str:
"""
Return the string representation of RunState.
Expand Down
237 changes: 224 additions & 13 deletions airflow/providers/databricks/hooks/databricks_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,19 @@
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse

import aiohttp
import requests
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import JSONDecodeError
from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential
from tenacity import (
AsyncRetrying,
RetryError,
Retrying,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)

from airflow import __version__
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -135,6 +143,14 @@ def host(self) -> str:

return host

async def __aenter__(self):
self._session = aiohttp.ClientSession()
return self

async def __aexit__(self, *err):
await self._session.close()
self._session = None

@staticmethod
def _parse_host(host: str) -> str:
"""
Expand Down Expand Up @@ -169,6 +185,13 @@ def _get_retry_object(self) -> Retrying:
"""
return Retrying(**self.retry_args)

def _a_get_retry_object(self) -> AsyncRetrying:
"""
Instantiates an async retry object
:return: instance of AsyncRetrying class
"""
return AsyncRetrying(**self.retry_args)

def _get_aad_token(self, resource: str) -> str:
"""
Function to get AAD token for given resource. Supports managed identity or service principal auth
Expand Down Expand Up @@ -234,6 +257,72 @@ def _get_aad_token(self, resource: str) -> str:

return token

async def _a_get_aad_token(self, resource: str) -> str:
"""
Async version of `_get_aad_token()`.
:param resource: resource to issue token to
:return: AAD token, or raise an exception
"""
aad_token = self.aad_tokens.get(resource)
if aad_token and self._is_aad_token_valid(aad_token):
return aad_token['token']

self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
try:
async for attempt in self._a_get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
params = {
"api-version": "2018-02-01",
"resource": resource,
}
async with self._session.get(
url=AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
headers={**USER_AGENT_HEADER, "Metadata": "true"},
timeout=self.aad_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()
else:
tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
)
async with self._session.post(
url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={
**USER_AGENT_HEADER,
'Content-Type': 'application/x-www-form-urlencoded',
},
timeout=self.aad_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()
if (
'access_token' not in jsn
or jsn.get('token_type') != 'Bearer'
or 'expires_on' not in jsn
):
raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")

token = jsn['access_token']
self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
break
except RetryError:
raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
except aiohttp.ClientResponseError as err:
raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')

return token

def _get_aad_headers(self) -> dict:
"""
Fills AAD headers if necessary (SPN is outside of the workspace)
Expand All @@ -248,6 +337,20 @@ def _get_aad_headers(self) -> dict:
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
return headers

async def _a_get_aad_headers(self) -> dict:
"""
Async version of `_get_aad_headers()`.
:return: dictionary with filled AAD headers
"""
headers = {}
if 'azure_resource_id' in self.databricks_conn.extra_dejson:
mgmt_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
'azure_resource_id'
]
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
return headers

@staticmethod
def _is_aad_token_valid(aad_token: dict) -> bool:
"""
Expand Down Expand Up @@ -281,6 +384,23 @@ def _check_azure_metadata_service() -> None:
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")

async def _a_check_azure_metadata_service(self):
"""Async version of `_check_azure_metadata_service()`."""
try:
async with self._session.get(
url=AZURE_METADATA_SERVICE_INSTANCE_URL,
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
timeout=2,
) as resp:
jsn = await resp.json()
if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
raise AirflowException(
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
)
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")

def _get_token(self, raise_error: bool = False) -> Optional[str]:
if 'token' in self.databricks_conn.extra_dejson:
self.log.info(
Expand All @@ -304,6 +424,29 @@ def _get_token(self, raise_error: bool = False) -> Optional[str]:

return None

async def _a_get_token(self, raise_error: bool = False) -> Optional[str]:
if 'token' in self.databricks_conn.extra_dejson:
self.log.info(
'Using token auth. For security reasons, please set token in Password field instead of extra'
)
return self.databricks_conn.extra_dejson["token"]
elif not self.databricks_conn.login and self.databricks_conn.password:
self.log.info('Using token auth.')
return self.databricks_conn.password
elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
if self.databricks_conn.login == "" or self.databricks_conn.password == "":
raise AirflowException("Azure SPN credentials aren't provided")
self.log.info('Using AAD Token for SPN.')
return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
self.log.info('Using AAD Token for managed identity.')
await self._a_check_azure_metadata_service()
return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
elif raise_error:
raise AirflowException('Token authentication isn\'t configured')

return None

def _log_request_error(self, attempt_num: int, error: str) -> None:
self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)

Expand Down Expand Up @@ -374,6 +517,55 @@ def _do_api_call(
else:
raise e

async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, Any]] = None):
"""
Async version of `_do_api_call()`.
:param endpoint_info: Tuple of method and endpoint
:param json: Parameters for this API call.
:return: If the api call returns a OK status code,
this function returns the response in JSON. Otherwise, throw an AirflowException.
"""
method, endpoint = endpoint_info

url = f'https://{self.host}/{endpoint}'

aad_headers = await self._a_get_aad_headers()
headers = {**USER_AGENT_HEADER.copy(), **aad_headers}

auth: aiohttp.BasicAuth
token = await self._a_get_token()
if token:
auth = BearerAuth(token)
else:
self.log.info('Using basic auth.')
auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)

request_func: Any
if method == 'GET':
request_func = self._session.get
elif method == 'POST':
request_func = self._session.post
elif method == 'PATCH':
request_func = self._session.patch
else:
raise AirflowException('Unexpected HTTP Method: ' + method)
try:
async for attempt in self._a_get_retry_object():
with attempt:
async with request_func(
url,
json=json,
auth=auth,
headers={**headers, **USER_AGENT_HEADER},
timeout=self.timeout_seconds,
) as response:
response.raise_for_status()
return await response.json()
except RetryError:
raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
except aiohttp.ClientResponseError as err:
raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')

@staticmethod
def _get_error_code(exception: BaseException) -> str:
if isinstance(exception, requests_exceptions.HTTPError):
Expand All @@ -387,19 +579,25 @@ def _get_error_code(exception: BaseException) -> str:

@staticmethod
def _retryable_error(exception: BaseException) -> bool:
if not isinstance(exception, requests_exceptions.RequestException):
return False
return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
exception.response is not None
and (
exception.response.status_code >= 500
or exception.response.status_code == 429
or (
exception.response.status_code == 400
and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
if isinstance(exception, requests_exceptions.RequestException):
if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
exception.response is not None
and (
exception.response.status_code >= 500
or exception.response.status_code == 429
or (
exception.response.status_code == 400
and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
)
)
)
)
):
return True

if isinstance(exception, aiohttp.ClientResponseError):
if exception.status >= 500 or exception.status == 429:
return True

return False


class _TokenAuth(AuthBase):
Expand All @@ -414,3 +612,16 @@ def __init__(self, token: str) -> None:
def __call__(self, r: PreparedRequest) -> PreparedRequest:
r.headers['Authorization'] = 'Bearer ' + self.token
return r


class BearerAuth(aiohttp.BasicAuth):
"""aiohttp only ships BasicAuth, for Bearer auth we need a subclass of BasicAuth."""

def __new__(cls, token: str) -> 'BearerAuth':
return super().__new__(cls, token) # type: ignore

def __init__(self, token: str) -> None:
self.token = token

def encode(self) -> str:
return f'Bearer {self.token}'
Loading

0 comments on commit 9b40c02

Please sign in to comment.