Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MES-1053 - add support for databricks m2m oauth #154

Merged
merged 5 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Dict, Optional
from typing import Dict, Optional, Callable

from databricks import sql
from databricks.sdk.core import oauth_service_principal, Config

from apollo.integrations.db.base_db_proxy_client import BaseDbProxyClient

# Key in the credentials dict whose value is the dict of keyword arguments
# passed to sql.connect().
_ATTR_CONNECT_ARGS = "connect_args"
# Key set in the connect args to carry the OAuth credentials provider callable.
_ATTR_CREDENTIALS_PROVIDER = "credentials_provider"

# Keys looked up in the connect args for the service principal's client ID and
# secret; OAuth is used only when both keys are present.
CLIENT_ID_KEY = "databricks_client_id"
CLIENT_SECRET_KEY = "databricks_client_secret"


class DatabricksSqlWarehouseProxyClient(BaseDbProxyClient):
Expand All @@ -20,8 +25,28 @@ def __init__(self, credentials: Optional[Dict], **kwargs: Dict):
raise ValueError(
f"Databricks agent client requires {_ATTR_CONNECT_ARGS} in credentials"
)

if self._credentials_use_oauth(credentials[_ATTR_CONNECT_ARGS]):
credentials[_ATTR_CONNECT_ARGS][_ATTR_CREDENTIALS_PROVIDER] = (
self._oauth_credentials_provider(credentials[_ATTR_CONNECT_ARGS])
)

self._connection = sql.connect(**credentials[_ATTR_CONNECT_ARGS])

def _credentials_use_oauth(self, connect_args: Dict) -> bool:
    """Tell whether the connect args carry both OAuth service principal keys.

    :param connect_args: the connection arguments to inspect.
    :return: True only when both the client ID key and the client secret key
        are present in ``connect_args``; only key presence is checked, not
        the values stored under them.
    """
    oauth_keys = (CLIENT_ID_KEY, CLIENT_SECRET_KEY)
    return all(oauth_key in connect_args for oauth_key in oauth_keys)

def _oauth_credentials_provider(self, connect_args: Dict) -> Callable:
# create the auth callable here because it can't be serialized
config = Config(
host=connect_args.get("server_hostname"),
# Service Principal UUID
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: super minor, but for consistency I'd define a constant for `server_hostname`.

client_id=connect_args.get(CLIENT_ID_KEY),
# Service Principal Secret
client_secret=connect_args.get(CLIENT_SECRET_KEY),
)
return lambda: oauth_service_principal(config)

@property
def wrapped_client(self):
    """Expose the connection object stored on this proxy client.

    :return: the object held in ``self._connection``.
    """
    connection = self._connection
    return connection
3 changes: 3 additions & 0 deletions apollo/integrations/http/http_proxy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def do_request(
params: Optional[Dict] = None,
verify_ssl: Optional[bool] = None,
retry_status_code_ranges: Optional[List[Tuple]] = None,
data: Optional[str] = None,
) -> Dict:
"""
Executes a single request with no retry, intended to be used for JSON request/response endpoints.
Expand All @@ -83,6 +84,8 @@ def do_request(
request_args = {}
if payload:
request_args["json"] = payload
if data:
request_args["data"] = data
if timeout:
request_args["timeout"] = timeout
if params:
Expand Down
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ azure-storage-blob==12.23.0
boto3==1.34.151
cryptography>=43.0.1
databricks-sql-connector==3.5.0
databricks-sdk==0.11.0
dataclasses-json==0.6.0
duckdb==0.10.3
flask==2.3.3
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ cryptography==43.0.1
# pyjwt
# pyopenssl
# snowflake-connector-python
databricks-sdk==0.11.0
# via -r requirements.in
databricks-sql-connector==3.5.0
# via -r requirements.in
dataclasses-json==0.6.0
Expand Down Expand Up @@ -247,6 +249,7 @@ requests==2.32.3
# via
# -r requirements.in
# azure-core
# databricks-sdk
# databricks-sql-connector
# google-api-core
# google-cloud-storage
Expand Down
Loading