Skip to content

Commit

Permalink
adding docstrings, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zackverham committed Dec 5, 2024
1 parent 54a47c9 commit 144e895
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 0 deletions.
80 changes: 80 additions & 0 deletions src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ def _get_auth_type(local_auth_type: str) -> str:
return POSIT_OAUTH_INTEGRATION_AUTH_TYPE

class PositLocalContentCredentialsProvider:
"""`CredentialsProvider` implementation which provides a fallback for local development using a client credentials flow.
There is an open issue against the Databricks CLI which prevents it from returning service principal access tokens.
https://github.com/databricks/cli/issues/1939
Until the CLI issue is resolved, this CredentialsProvider implements the approach described in the Databricks documentation
for manually generating a workspace-level access token using OAuth M2M authentication. Once it has acquired an access token,
it returns it as a Bearer authorization header like other `CredentialsProvider` implementations.
See Also
--------
* https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html#manually-generate-a-workspace-level-access-token
"""

def __init__(self, token_endpoint_url: str, client_id: str, client_secret: str):
self._token_endpoint_url = token_endpoint_url
Expand Down Expand Up @@ -142,6 +155,73 @@ def __call__(self) -> Dict[str, str]:
return _new_bearer_authorization_header(credentials)

class PositLocalContentCredentialsStrategy(CredentialsStrategy):
"""`CredentialsStrategy` implementation which supports local development using OAuth M2M authentication against databricks.
There is an open issue against the Databricks CLI which prevents it from returning service principal access tokens.
https://github.com/databricks/cli/issues/1939
Until the CLI issue is resolved, this CredentialsStrategy provides a drop-in replacement as a local_strategy that can be used
to develop applications which target Service Account OAuth integrations on Connect.
Examples
--------
In the example below, the PositContentCredentialsStrategy can be initialized anywhere that
the Python process can read environment variables.
CLIENT_ID and CLIENT_SECRET credentials associated with the Databricks Service Principal.
```python
import os
from posit.connect.external.databricks import PositContentCredentialsStrategy, PositLocalContentCredentialsStrategy
import pandas as pd
from databricks import sql
from databricks.sdk.core import ApiClient, Config
from databricks.sdk.service.iam import CurrentUserAPI
DATABRICKS_HOST = "<REDACTED>"
DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}"
SQL_HTTP_PATH = "<REDACTED>"
TOKEN_ENDPOINT_URL = f"https://{DATABRICKS_HOST}/oidc/v1/token"
CLIENT_ID = "<REDACTED>"
CLIENT_SECRET = "<REDACTED>"
# Rather than relying on the Databricks CLI as a local strategy, we use
# PositLocalContentCredentialsStragtegy as a drop-in replacement.
# Can be replaced with the Databricks CLI implementation when
# https://github.com/databricks/cli/issues/1939 is resolved.
local_strategy = PositLocalContentCredentialsStrategy(
token_endpoint_url=TOKEN_ENDPOINT_URL,
client_id=CLIENT_ID,
client_secret=CLIENT_SECRET,
)
posit_strategy = PositContentCredentialsStrategy(local_strategy=local_strategy)
cfg = Config(host=DATABRICKS_HOST_URL, credentials_strategy=posit_strategy)
databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me()
print(f"Hello, {databricks_user_info.display_name}!")
query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;"
with sql.connect(
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
credentials_provider=posit_strategy.sql_credentials_provider(cfg),
) as connection:
with connection.cursor() as cursor:
cursor.execute(query)
rows = cursor.fetchall()
print(pd.DataFrame([row.asDict() for row in rows]))
```
See Also
--------
* https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html#manually-generate-a-workspace-level-access-token
"""

def __init__(self, token_endpoint_url: str, client_id: str, client_secret: str):
self._token_endpoint_url = token_endpoint_url
self._client_id = client_id
Expand Down
74 changes: 74 additions & 0 deletions tests/posit/connect/external/test_databricks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64

from typing import Dict
from unittest.mock import patch

Expand All @@ -13,6 +15,8 @@
PositContentCredentialsStrategy,
PositCredentialsProvider,
PositCredentialsStrategy,
PositLocalContentCredentialsProvider,
PositLocalContentCredentialsStrategy,
_get_auth_type,
_new_bearer_authorization_header,
)
Expand Down Expand Up @@ -92,6 +96,36 @@ def test_get_auth_type_local(self):
def test_get_auth_type_connect(self):
assert _get_auth_type("local-auth") == POSIT_OAUTH_INTEGRATION_AUTH_TYPE

@responses.activate
def test_local_content_credentials_provider(self):

token_url = "https://my-token/url"
client_id = "client_id"
client_secret = "client_secret_123"
basic_auth = f"{client_id}:{client_secret}"
b64_basic_auth = base64.b64encode(basic_auth.encode('utf-8')).decode('utf-8')

responses.post(
token_url,
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "client_credentials",
"scope": "all-apis",
},
),
responses.matchers.header_matcher({"Authorization": f"Basic {b64_basic_auth}"})
],
json={
"access_token": "oauth2-m2m-access-token",
"token_type": "Bearer",
"expires_in": 3600,
},
)

cp = PositLocalContentCredentialsProvider(token_url, client_id, client_secret)
assert cp() == {"Authorization": "Bearer oauth2-m2m-access-token"}

@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
@responses.activate
def test_posit_content_credentials_provider(self):
Expand All @@ -111,6 +145,46 @@ def test_posit_credentials_provider(self):
cp = PositCredentialsProvider(client=client, user_session_token="cit")
assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"}


@responses.activate
def test_local_content_credentials_strategy(self):

token_url = "https://my-token/url"
client_id = "client_id"
client_secret = "client_secret_123"
basic_auth = f"{client_id}:{client_secret}"
b64_basic_auth = base64.b64encode(basic_auth.encode('utf-8')).decode('utf-8')


responses.post(
token_url,
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "client_credentials",
"scope": "all-apis",
},
),
responses.matchers.header_matcher({"Authorization": f"Basic {b64_basic_auth}"})
],
json={
"access_token": "oauth2-m2m-access-token",
"token_type": "Bearer",
"expires_in": 3600,
},
)

cs = PositLocalContentCredentialsStrategy(
token_url,
client_id,
client_secret,
)
cp = cs()
assert cs.auth_type() == "posit-local-client-credentials"
assert cp() == {"Authorization": "Bearer oauth2-m2m-access-token"}



@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
@responses.activate
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
Expand Down

0 comments on commit 144e895

Please sign in to comment.