-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add Databricks OAuthIntegration credentials provider * Return Credentials response. Dont cache oauthIntegration resource in client * Add oauth integration tests * Test databricks content on Connect * Move is_local; check CONNECT_SERVER and CONNECT_CONTENT_GUID * Use post body temporarily for oauth/integration/credentials endpoints * use form xml post in credentials request * Add sample databricks content * Fix linter * update sample content requirements * Revert env var change * Use user-session-token instead of content-identity-token
- Loading branch information
Showing
11 changed files
with
285 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.posit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
```bash | ||
# start streamlit locally | ||
DATABRICKS_TOKEN=<DB_PAT> \ | ||
streamlit run ./sample-content.py | ||
|
||
# deploy the app the first time | ||
publisher deploy -a localhost:3939 -n databricks ./ | ||
|
||
# re-deploy the databricks app | ||
publisher redeploy databricks | ||
``` | ||
|
||
TODO: Test this content with databricks-connect | ||
<https://docs.databricks.com/en/dev-tools/databricks-connect/python/index.html> | ||
|
||
``` | ||
# install the sdk from this branch | ||
pip install git+https://github.com/posit-dev/posit-sdk-py.git@kegs/databricks-oauth-2 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
altair==5.2.0 | ||
attrs==23.2.0 | ||
blinker==1.7.0 | ||
cachetools==5.3.2 | ||
certifi==2024.2.2 | ||
charset-normalizer==3.3.2 | ||
click==8.1.7 | ||
databricks-sdk==0.20.0 | ||
databricks-sql-connector==3.1.0 | ||
et-xmlfile==1.1.0 | ||
gitdb==4.0.11 | ||
GitPython==3.1.42 | ||
google-auth==2.28.0 | ||
idna==3.6 | ||
importlib-metadata==7.0.1 | ||
Jinja2==3.1.3 | ||
jsonschema==4.21.1 | ||
jsonschema-specifications==2023.12.1 | ||
lz4==4.3.3 | ||
markdown-it-py==3.0.0 | ||
MarkupSafe==2.1.5 | ||
mdurl==0.1.2 | ||
numpy==1.26.4 | ||
oauthlib==3.2.2 | ||
openpyxl==3.1.2 | ||
packaging==23.2 | ||
pandas==2.1.4 | ||
pillow==10.2.0 | ||
posit-sdk @ git+https://github.com/posit-dev/posit-sdk-py.git@24ad71458de56b01b8168e80441ea860236f1933 | ||
protobuf==4.25.3 | ||
pyarrow==14.0.2 | ||
pyasn1==0.5.1 | ||
pyasn1-modules==0.3.0 | ||
pydeck==0.8.1b0 | ||
Pygments==2.17.2 | ||
python-dateutil==2.8.2 | ||
pytz==2024.1 | ||
referencing==0.33.0 | ||
requests==2.31.0 | ||
rich==13.7.0 | ||
rpds-py==0.18.0 | ||
rsa==4.9 | ||
six==1.16.0 | ||
smmap==5.0.1 | ||
streamlit==1.31.1 | ||
tenacity==8.2.3 | ||
thrift==0.16.0 | ||
toml==0.10.2 | ||
toolz==0.12.1 | ||
tornado==6.4 | ||
typing_extensions==4.9.0 | ||
tzdata==2024.1 | ||
tzlocal==5.2 | ||
urllib3==2.2.1 | ||
validators==0.22.0 | ||
zipp==3.17.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# -*- coding: utf-8 -*- | ||
# mypy: ignore-errors | ||
import os | ||
|
||
from posit.connect.external.databricks import viewer_credentials_provider | ||
|
||
from databricks import sql | ||
from databricks.sdk.service.iam import CurrentUserAPI | ||
from databricks.sdk.core import ApiClient, Config | ||
|
||
import pandas as pd | ||
import streamlit as st | ||
from streamlit.web.server.websocket_headers import _get_websocket_headers | ||
|
||
DB_PAT=os.getenv("DATABRICKS_TOKEN") | ||
|
||
DB_HOST=os.getenv("DB_HOST") | ||
DB_HOST_URL = f"https://{DB_HOST}" | ||
SQL_HTTP_PATH=os.getenv("SQL_HTTP_PATH") | ||
|
||
USER_SESSION_TOKEN = None | ||
|
||
# Read the viewer's user session token from the streamlit ws header. | ||
headers = _get_websocket_headers() | ||
if headers: | ||
USER_SESSION_TOKEN = headers.get('Posit-Connect-User-Session') | ||
|
||
credentials_provider = viewer_credentials_provider(user_session_token=USER_SESSION_TOKEN) | ||
cfg = Config(host=DB_HOST_URL, credentials_provider=credentials_provider) | ||
#cfg = Config(host=DB_HOST_URL, token=DB_PAT) | ||
|
||
databricks_user = CurrentUserAPI(ApiClient(cfg)).me() | ||
st.write(f"Hello, {databricks_user.display_name}!") | ||
|
||
with sql.connect( | ||
server_hostname=DB_HOST, | ||
http_path=SQL_HTTP_PATH, | ||
#access_token=DB_PAT) as connection: | ||
auth_type='databricks-oauth', | ||
credentials_provider=credentials_provider) as connection: | ||
with connection.cursor() as cursor: | ||
cursor.execute("SELECT * FROM data") | ||
result = cursor.fetchall() | ||
st.table(pd.DataFrame(result)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,6 @@ | |
|
||
from unittest.mock import MagicMock, patch | ||
|
||
|
||
from .client import Client | ||
|
||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import abc | ||
import os | ||
from typing import Callable, Dict, Optional | ||
|
||
from ..client import Client | ||
from ..oauth import OAuthIntegration | ||
|
||
HeaderFactory = Callable[[], Dict[str, str]] | ||
|
||
# https://github.com/databricks/databricks-sdk-py/blob/v0.20.0/databricks/sdk/credentials_provider.py | ||
# https://github.com/databricks/databricks-sql-python/blob/v3.1.0/src/databricks/sql/auth/authenticators.py | ||
# In order to keep compatibility with the Databricks SDK | ||
class CredentialsProvider(abc.ABC): | ||
"""CredentialsProvider is the protocol (call-side interface) | ||
for authenticating requests to Databricks REST APIs""" | ||
|
||
@abc.abstractmethod | ||
def auth_type(self) -> str: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def __call__(self, *args, **kwargs) -> HeaderFactory: | ||
raise NotImplementedError | ||
|
||
|
||
class PositOAuthIntegrationCredentialsProvider(CredentialsProvider): | ||
def __init__(self, posit_oauth: OAuthIntegration, user_session_token: str): | ||
self.posit_oauth = posit_oauth | ||
self.user_session_token = user_session_token | ||
|
||
def auth_type(self) -> str: | ||
return "posit-oauth-integration" | ||
|
||
def __call__(self, *args, **kwargs) -> HeaderFactory: | ||
def inner() -> Dict[str, str]: | ||
access_token = self.posit_oauth.get_credentials(self.user_session_token)['access_token'] | ||
return {"Authorization": f"Bearer {access_token}"} | ||
return inner | ||
|
||
|
||
# Use this environment variable to determine if the | ||
# client SDK was initialized from a piece of content running on a Connect server. | ||
def is_local() -> bool: | ||
return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT" | ||
|
||
|
||
def viewer_credentials_provider(client: Optional[Client] = None, user_session_token: Optional[str] = None) -> Optional[CredentialsProvider]: | ||
|
||
# If the content is not running on Connect then viewer auth should | ||
# fall back to the locally configured credentials hierarchy | ||
if is_local(): | ||
return None | ||
|
||
if client is None: | ||
client = Client() | ||
|
||
# If the user-session-token wasn't provided and we're running on Connect then we raise an exception. | ||
# user_session_token is required to impersonate the viewer. | ||
if user_session_token is None: | ||
raise ValueError("The user-session-token is required for viewer authentication.") | ||
|
||
return PositOAuthIntegrationCredentialsProvider(client.oauth, user_session_token) | ||
|
||
|
||
def service_account_credentials_provider(client: Optional[Client] = None): | ||
raise NotImplementedError |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from __future__ import annotations | ||
|
||
from requests import Session | ||
from typing import Optional, TypedDict | ||
|
||
from . import urls | ||
from .config import Config | ||
|
||
|
||
class Credentials(TypedDict, total=False): | ||
access_token: str | ||
issued_token_type: str | ||
token_type: str | ||
|
||
|
||
class OAuthIntegration: | ||
|
||
def __init__( | ||
self, config: Config, session: Session | ||
) -> None: | ||
self.url = urls.append_path(config.url, "v1/oauth/integrations/credentials") | ||
self.config = config | ||
self.session = session | ||
|
||
|
||
def get_credentials(self, user_session_token: Optional[str]=None) -> Credentials: | ||
|
||
# craft a basic credential exchange request where the self.config.api_key owner | ||
# is requesting their own credentials | ||
data = dict() | ||
data["grant_type"] = "urn:ietf:params:oauth:grant-type:token-exchange" | ||
data["subject_token_type"] = "urn:posit:connect:api-key" | ||
data["subject_token"] = self.config.api_key | ||
|
||
# if this content is running on Connect, then it is allowed to request | ||
# the content viewer's credentials | ||
if user_session_token: | ||
data["subject_token_type"] = "urn:posit:connect:user-session-token" | ||
data["subject_token"] = user_session_token | ||
|
||
response = self.session.post(self.url, data=data) | ||
return Credentials(**response.json()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import responses | ||
|
||
from .client import Client | ||
|
||
class TestOAuthIntegrations: | ||
|
||
@responses.activate | ||
def test_get_credentials(self): | ||
responses.post( | ||
"https://connect.example/__api__/v1/oauth/integrations/credentials", | ||
match=[ | ||
responses.matchers.urlencoded_params_matcher( | ||
{ | ||
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", | ||
"subject_token_type": "urn:posit:connect:user-session-token", | ||
"subject_token": "cit", | ||
} | ||
) | ||
], | ||
json={ | ||
"access_token": "viewer-token", | ||
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token", | ||
"token_type": "Bearer", | ||
}, | ||
) | ||
responses.post( | ||
"https://connect.example/__api__/v1/oauth/integrations/credentials", | ||
match=[ | ||
responses.matchers.urlencoded_params_matcher( | ||
{ | ||
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", | ||
"subject_token_type": "urn:posit:connect:api-key", | ||
"subject_token": "12345", | ||
} | ||
) | ||
], | ||
json={ | ||
"access_token": "sdk-user-token", | ||
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token", | ||
"token_type": "Bearer", | ||
}, | ||
) | ||
con = Client(api_key="12345", url="https://connect.example/") | ||
assert ( | ||
con.oauth.get_credentials()["access_token"] | ||
== "sdk-user-token" | ||
) | ||
assert ( | ||
con.oauth.get_credentials("cit")["access_token"] | ||
== "viewer-token" | ||
) |