diff --git a/src/posit/connect/external/databricks.py b/src/posit/connect/external/databricks.py index 9eb64d7f..3cea970d 100644 --- a/src/posit/connect/external/databricks.py +++ b/src/posit/connect/external/databricks.py @@ -82,6 +82,19 @@ def _get_auth_type(local_auth_type: str) -> str: return POSIT_OAUTH_INTEGRATION_AUTH_TYPE class PositLocalContentCredentialsProvider: + """`CredentialsProvider` implementation which provides a fallback for local development using a client credentials flow. + + There is an open issue against the Databricks CLI which prevents it from returning service principal access tokens. + https://github.com/databricks/cli/issues/1939 + + Until the CLI issue is resolved, this CredentialsProvider implements the approach described in the Databricks documentation + for manually generating a workspace-level access token using OAuth M2M authentication. Once it has acquired an access token, + it returns it as a Bearer authorization header like other `CredentialsProvider` implementations. + + See Also + -------- + * https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html#manually-generate-a-workspace-level-access-token + """ def __init__(self, token_endpoint_url: str, client_id: str, client_secret: str): self._token_endpoint_url = token_endpoint_url @@ -142,6 +155,73 @@ def __call__(self) -> Dict[str, str]: return _new_bearer_authorization_header(credentials) class PositLocalContentCredentialsStrategy(CredentialsStrategy): + """`CredentialsStrategy` implementation which supports local development using OAuth M2M authentication against Databricks. + + There is an open issue against the Databricks CLI which prevents it from returning service principal access tokens. + https://github.com/databricks/cli/issues/1939 + + Until the CLI issue is resolved, this CredentialsStrategy provides a drop-in replacement as a local_strategy that can be used + to develop applications which target Service Account OAuth integrations on Connect. 
+ + Examples + -------- + In the example below, the PositContentCredentialsStrategy can be initialized anywhere that + the Python process can read environment variables. + + CLIENT_ID and CLIENT_SECRET are the credentials associated with the Databricks Service Principal. + + ```python + import os + + from posit.connect.external.databricks import PositContentCredentialsStrategy, PositLocalContentCredentialsStrategy + + import pandas as pd + from databricks import sql + from databricks.sdk.core import ApiClient, Config + from databricks.sdk.service.iam import CurrentUserAPI + + DATABRICKS_HOST = "" + DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" + SQL_HTTP_PATH = "" + TOKEN_ENDPOINT_URL = f"https://{DATABRICKS_HOST}/oidc/v1/token" + + CLIENT_ID = "" + CLIENT_SECRET = "" + + # Rather than relying on the Databricks CLI as a local strategy, we use + # PositLocalContentCredentialsStrategy as a drop-in replacement. + # Can be replaced with the Databricks CLI implementation when + # https://github.com/databricks/cli/issues/1939 is resolved. 
+ local_strategy = PositLocalContentCredentialsStrategy( + token_endpoint_url=TOKEN_ENDPOINT_URL, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + posit_strategy = PositContentCredentialsStrategy(local_strategy=local_strategy) + + cfg = Config(host=DATABRICKS_HOST_URL, credentials_strategy=posit_strategy) + + databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me() + print(f"Hello, {databricks_user_info.display_name}!") + + query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;" + with sql.connect( + server_hostname=DATABRICKS_HOST, + http_path=SQL_HTTP_PATH, + credentials_provider=posit_strategy.sql_credentials_provider(cfg), + ) as connection: + with connection.cursor() as cursor: + cursor.execute(query) + rows = cursor.fetchall() + print(pd.DataFrame([row.asDict() for row in rows])) + ``` + + See Also + -------- + * https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html#manually-generate-a-workspace-level-access-token + """ + def __init__(self, token_endpoint_url: str, client_id: str, client_secret: str): self._token_endpoint_url = token_endpoint_url self._client_id = client_id diff --git a/tests/posit/connect/external/test_databricks.py b/tests/posit/connect/external/test_databricks.py index 1a53c3cd..daf5e1ab 100644 --- a/tests/posit/connect/external/test_databricks.py +++ b/tests/posit/connect/external/test_databricks.py @@ -1,3 +1,5 @@ +import base64 + from typing import Dict from unittest.mock import patch @@ -13,6 +15,8 @@ PositContentCredentialsStrategy, PositCredentialsProvider, PositCredentialsStrategy, + PositLocalContentCredentialsProvider, + PositLocalContentCredentialsStrategy, _get_auth_type, _new_bearer_authorization_header, ) @@ -92,6 +96,36 @@ def test_get_auth_type_local(self): def test_get_auth_type_connect(self): assert _get_auth_type("local-auth") == POSIT_OAUTH_INTEGRATION_AUTH_TYPE + @responses.activate + def test_local_content_credentials_provider(self): + + token_url = "https://my-token/url" + client_id = "client_id" 
+ client_secret = "client_secret_123" + basic_auth = f"{client_id}:{client_secret}" + b64_basic_auth = base64.b64encode(basic_auth.encode('utf-8')).decode('utf-8') + + responses.post( + token_url, + match=[ + responses.matchers.urlencoded_params_matcher( + { + "grant_type": "client_credentials", + "scope": "all-apis", + }, + ), + responses.matchers.header_matcher({"Authorization": f"Basic {b64_basic_auth}"}) + ], + json={ + "access_token": "oauth2-m2m-access-token", + "token_type": "Bearer", + "expires_in": 3600, + }, + ) + + cp = PositLocalContentCredentialsProvider(token_url, client_id, client_secret) + assert cp() == {"Authorization": "Bearer oauth2-m2m-access-token"} + @patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"}) @responses.activate def test_posit_content_credentials_provider(self): @@ -111,6 +145,46 @@ def test_posit_credentials_provider(self): cp = PositCredentialsProvider(client=client, user_session_token="cit") assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"} + + @responses.activate + def test_local_content_credentials_strategy(self): + + token_url = "https://my-token/url" + client_id = "client_id" + client_secret = "client_secret_123" + basic_auth = f"{client_id}:{client_secret}" + b64_basic_auth = base64.b64encode(basic_auth.encode('utf-8')).decode('utf-8') + + + responses.post( + token_url, + match=[ + responses.matchers.urlencoded_params_matcher( + { + "grant_type": "client_credentials", + "scope": "all-apis", + }, + ), + responses.matchers.header_matcher({"Authorization": f"Basic {b64_basic_auth}"}) + ], + json={ + "access_token": "oauth2-m2m-access-token", + "token_type": "Bearer", + "expires_in": 3600, + }, + ) + + cs = PositLocalContentCredentialsStrategy( + token_url, + client_id, + client_secret, + ) + cp = cs() + assert cs.auth_type() == "posit-local-client-credentials" + assert cp() == {"Authorization": "Bearer oauth2-m2m-access-token"} + + + @patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": 
"cit"}) @responses.activate @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})