diff --git a/apollo/integrations/databricks/databricks_sql_warehouse_proxy_client.py b/apollo/integrations/databricks/databricks_sql_warehouse_proxy_client.py index f636846a..97edbb10 100644 --- a/apollo/integrations/databricks/databricks_sql_warehouse_proxy_client.py +++ b/apollo/integrations/databricks/databricks_sql_warehouse_proxy_client.py @@ -1,10 +1,16 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Callable from databricks import sql +from databricks.sdk.core import oauth_service_principal, Config from apollo.integrations.db.base_db_proxy_client import BaseDbProxyClient _ATTR_CONNECT_ARGS = "connect_args" +_ATTR_CREDENTIALS_PROVIDER = "credentials_provider" + +SERVER_HOSTNAME = "server_hostname" +CLIENT_ID_KEY = "databricks_client_id" +CLIENT_SECRET_KEY = "databricks_client_secret" class DatabricksSqlWarehouseProxyClient(BaseDbProxyClient): @@ -20,8 +26,28 @@ def __init__(self, credentials: Optional[Dict], **kwargs: Dict): raise ValueError( f"Databricks agent client requires {_ATTR_CONNECT_ARGS} in credentials" ) + + if self._credentials_use_oauth(credentials[_ATTR_CONNECT_ARGS]): + credentials[_ATTR_CONNECT_ARGS][_ATTR_CREDENTIALS_PROVIDER] = ( + self._oauth_credentials_provider(credentials[_ATTR_CONNECT_ARGS]) + ) + self._connection = sql.connect(**credentials[_ATTR_CONNECT_ARGS]) + def _credentials_use_oauth(self, connect_args: Dict) -> bool: + return CLIENT_ID_KEY in connect_args and CLIENT_SECRET_KEY in connect_args + + def _oauth_credentials_provider(self, connect_args: Dict) -> Callable: + # create the auth callable here because it can't be serialized + config = Config( + host=connect_args.get(SERVER_HOSTNAME), + # Service Principal UUID + client_id=connect_args.get(CLIENT_ID_KEY), + # Service Principal Secret + client_secret=connect_args.get(CLIENT_SECRET_KEY), + ) + return lambda: oauth_service_principal(config) + @property def wrapped_client(self): return self._connection diff --git a/apollo/integrations/http/http_proxy_client.py b/apollo/integrations/http/http_proxy_client.py index 996e8128..069103f6 100644 --- a/apollo/integrations/http/http_proxy_client.py +++ b/apollo/integrations/http/http_proxy_client.py @@ -60,6 +60,7 @@ def do_request( params: Optional[Dict] = None, verify_ssl: Optional[bool] = None, retry_status_code_ranges: Optional[List[Tuple]] = None, + data: Optional[str] = None, ) -> Dict: """ Executes a single request with no retry, intended to be used for JSON request/response endpoints. @@ -83,6 +84,8 @@ def do_request( request_args = {} if payload: request_args["json"] = payload + if data: + request_args["data"] = data if timeout: request_args["timeout"] = timeout if params: diff --git a/requirements.in b/requirements.in index cb1d7513..4911d065 100644 --- a/requirements.in +++ b/requirements.in @@ -4,6 +4,7 @@ azure-storage-blob==12.23.0 boto3==1.34.151 cryptography>=43.0.1 databricks-sql-connector==3.5.0 +databricks-sdk==0.11.0 dataclasses-json==0.6.0 duckdb==0.10.3 flask==2.3.3 diff --git a/requirements.txt b/requirements.txt index 0c7f7081..28805eb7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,6 +69,8 @@ cryptography==43.0.1 # pyjwt # pyopenssl # snowflake-connector-python +databricks-sdk==0.11.0 + # via -r requirements.in databricks-sql-connector==3.5.0 # via -r requirements.in dataclasses-json==0.6.0 @@ -247,6 +249,7 @@ requests==2.32.3 # via # -r requirements.in # azure-core + # databricks-sdk # databricks-sql-connector # google-api-core # google-cloud-storage