diff --git a/examples/connect/.gitignore b/examples/connect/.gitignore new file mode 100644 index 00000000..afe36a39 --- /dev/null +++ b/examples/connect/.gitignore @@ -0,0 +1 @@ +.posit diff --git a/examples/connect/databricks/README.md b/examples/connect/databricks/README.md new file mode 100644 index 00000000..fb1c223f --- /dev/null +++ b/examples/connect/databricks/README.md @@ -0,0 +1,19 @@ +```bash +# start streamlit locally +DATABRICKS_TOKEN= \ +streamlit run ./sample-content.py + +# deploy the app the first time +publisher deploy -a localhost:3939 -n databricks ./ + +# re-deploy the databricks app +publisher redeploy databricks +``` + +TODO: Test this content with databricks-connect + + +``` +# install the sdk from this branch +pip install git+https://github.com/posit-dev/posit-sdk-py.git@kegs/databricks-oauth-2 +``` diff --git a/examples/connect/databricks/requirements.txt b/examples/connect/databricks/requirements.txt new file mode 100644 index 00000000..99719b8d --- /dev/null +++ b/examples/connect/databricks/requirements.txt @@ -0,0 +1,56 @@ +altair==5.2.0 +attrs==23.2.0 +blinker==1.7.0 +cachetools==5.3.2 +certifi==2024.2.2 +charset-normalizer==3.3.2 +click==8.1.7 +databricks-sdk==0.20.0 +databricks-sql-connector==3.1.0 +et-xmlfile==1.1.0 +gitdb==4.0.11 +GitPython==3.1.42 +google-auth==2.28.0 +idna==3.6 +importlib-metadata==7.0.1 +Jinja2==3.1.3 +jsonschema==4.21.1 +jsonschema-specifications==2023.12.1 +lz4==4.3.3 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +mdurl==0.1.2 +numpy==1.26.4 +oauthlib==3.2.2 +openpyxl==3.1.2 +packaging==23.2 +pandas==2.1.4 +pillow==10.2.0 +posit-sdk @ git+https://github.com/posit-dev/posit-sdk-py.git@24ad71458de56b01b8168e80441ea860236f1933 +protobuf==4.25.3 +pyarrow==14.0.2 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pydeck==0.8.1b0 +Pygments==2.17.2 +python-dateutil==2.8.2 +pytz==2024.1 +referencing==0.33.0 +requests==2.31.0 +rich==13.7.0 +rpds-py==0.18.0 +rsa==4.9 +six==1.16.0 +smmap==5.0.1 +streamlit==1.31.1 +tenacity==8.2.3 +thrift==0.16.0 +toml==0.10.2 +toolz==0.12.1 +tornado==6.4 +typing_extensions==4.9.0 +tzdata==2024.1 +tzlocal==5.2 +urllib3==2.2.1 +validators==0.22.0 +zipp==3.17.0 diff --git a/examples/connect/databricks/sample-content.py b/examples/connect/databricks/sample-content.py new file mode 100644 index 00000000..3620bbc7 --- /dev/null +++ b/examples/connect/databricks/sample-content.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# mypy: ignore-errors +import os + +from posit.connect.external.databricks import viewer_credentials_provider + +from databricks import sql +from databricks.sdk.service.iam import CurrentUserAPI +from databricks.sdk.core import ApiClient, Config + +import pandas as pd +import streamlit as st +from streamlit.web.server.websocket_headers import _get_websocket_headers + +DB_PAT=os.getenv("DATABRICKS_TOKEN") + +DB_HOST=os.getenv("DB_HOST") +DB_HOST_URL = f"https://{DB_HOST}" +SQL_HTTP_PATH=os.getenv("SQL_HTTP_PATH") + +USER_SESSION_TOKEN = None + +# Read the viewer's user session token from the streamlit ws header. +headers = _get_websocket_headers() +if headers: + USER_SESSION_TOKEN = headers.get('Posit-Connect-User-Session') + +credentials_provider = viewer_credentials_provider(user_session_token=USER_SESSION_TOKEN) +cfg = Config(host=DB_HOST_URL, credentials_provider=credentials_provider) +#cfg = Config(host=DB_HOST_URL, token=DB_PAT) + +databricks_user = CurrentUserAPI(ApiClient(cfg)).me() +st.write(f"Hello, {databricks_user.display_name}!") + +with sql.connect( + server_hostname=DB_HOST, + http_path=SQL_HTTP_PATH, + #access_token=DB_PAT) as connection: + auth_type='databricks-oauth', + credentials_provider=credentials_provider) as connection: + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM data") + result = cursor.fetchall() + st.table(pd.DataFrame(result)) + diff --git a/src/posit/connect/client.py b/src/posit/connect/client.py index e22aa6d5..6d9a751e 100644 --- a/src/posit/connect/client.py +++ b/src/posit/connect/client.py @@ -7,6 +7,7 @@ from .auth import Auth from .config import Config +from .oauth import OAuthIntegration from .content import Content from .users import User, Users @@ -51,6 +52,10 @@ def me(self) -> User: response = self.session.get(url) return User(**response.json()) + @property + def oauth(self) -> OAuthIntegration: + return OAuthIntegration(config=self.config, session=self.session) + @property def users(self) -> Users: return Users(config=self.config, session=self.session) diff --git a/src/posit/connect/client_test.py b/src/posit/connect/client_test.py index 435f895e..bda8f5f8 100644 --- a/src/posit/connect/client_test.py +++ b/src/posit/connect/client_test.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch - from .client import Client diff --git a/src/posit/connect/external/__init__.py b/src/posit/connect/external/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/posit/connect/external/databricks.py b/src/posit/connect/external/databricks.py new file mode 100644 index 00000000..cce1a81d --- /dev/null +++ b/src/posit/connect/external/databricks.py @@ -0,0 +1,66 @@ +import abc +import os +from typing import Callable, Dict, Optional + +from ..client import Client +from ..oauth import OAuthIntegration + +HeaderFactory = Callable[[], Dict[str, str]] + +# https://github.com/databricks/databricks-sdk-py/blob/v0.20.0/databricks/sdk/credentials_provider.py +# https://github.com/databricks/databricks-sql-python/blob/v3.1.0/src/databricks/sql/auth/authenticators.py +# In order to keep compatibility with the Databricks SDK +class CredentialsProvider(abc.ABC): + """CredentialsProvider is the protocol (call-side interface) + for authenticating requests to Databricks REST APIs""" + + @abc.abstractmethod + def auth_type(self) -> str: + raise NotImplementedError + + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> HeaderFactory: + raise NotImplementedError + + +class PositOAuthIntegrationCredentialsProvider(CredentialsProvider): + def __init__(self, posit_oauth: OAuthIntegration, user_session_token: str): + self.posit_oauth = posit_oauth + self.user_session_token = user_session_token + + def auth_type(self) -> str: + return "posit-oauth-integration" + + def __call__(self, *args, **kwargs) -> HeaderFactory: + def inner() -> Dict[str, str]: + access_token = self.posit_oauth.get_credentials(self.user_session_token)['access_token'] + return {"Authorization": f"Bearer {access_token}"} + return inner + + +# Use this environment variable to determine if the +# client SDK was initialized from a piece of content running on a Connect server. +def is_local() -> bool: + return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT" + + +def viewer_credentials_provider(client: Optional[Client] = None, user_session_token: Optional[str] = None) -> Optional[CredentialsProvider]: + + # If the content is not running on Connect then viewer auth should + # fall back to the locally configured credentials hierarchy + if is_local(): + return None + + if client is None: + client = Client() + + # If the user-session-token wasn't provided and we're running on Connect then we raise an exception. + # user_session_token is required to impersonate the viewer. + if user_session_token is None: + raise ValueError("The user-session-token is required for viewer authentication.") + + return PositOAuthIntegrationCredentialsProvider(client.oauth, user_session_token) + + +def service_account_credentials_provider(client: Optional[Client] = None): + raise NotImplementedError diff --git a/src/posit/connect/external/databricks_test.py b/src/posit/connect/external/databricks_test.py new file mode 100644 index 00000000..e69de29b diff --git a/src/posit/connect/oauth.py b/src/posit/connect/oauth.py new file mode 100644 index 00000000..9fecdb40 --- /dev/null +++ b/src/posit/connect/oauth.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from requests import Session +from typing import Optional, TypedDict + +from . import urls +from .config import Config + + +class Credentials(TypedDict, total=False): + access_token: str + issued_token_type: str + token_type: str + + +class OAuthIntegration: + + def __init__( + self, config: Config, session: Session + ) -> None: + self.url = urls.append_path(config.url, "v1/oauth/integrations/credentials") + self.config = config + self.session = session + + + def get_credentials(self, user_session_token: Optional[str]=None) -> Credentials: + + # craft a basic credential exchange request where the self.config.api_key owner + # is requesting their own credentials + data = dict() + data["grant_type"] = "urn:ietf:params:oauth:grant-type:token-exchange" + data["subject_token_type"] = "urn:posit:connect:api-key" + data["subject_token"] = self.config.api_key + + # if this content is running on Connect, then it is allowed to request + # the content viewer's credentials + if user_session_token: + data["subject_token_type"] = "urn:posit:connect:user-session-token" + data["subject_token"] = user_session_token + + response = self.session.post(self.url, data=data) + return Credentials(**response.json()) diff --git a/src/posit/connect/oauth_test.py b/src/posit/connect/oauth_test.py new file mode 100644 index 00000000..f81389b3 --- /dev/null +++ b/src/posit/connect/oauth_test.py @@ -0,0 +1,51 @@ +import responses + +from .client import Client + +class TestOAuthIntegrations: + + @responses.activate + def test_get_credentials(self): + responses.post( + "https://connect.example/__api__/v1/oauth/integrations/credentials", + match=[ + responses.matchers.urlencoded_params_matcher( + { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token_type": "urn:posit:connect:user-session-token", + "subject_token": "cit", + } + ) + ], + json={ + "access_token": "viewer-token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + }, + ) + responses.post( + "https://connect.example/__api__/v1/oauth/integrations/credentials", + match=[ + responses.matchers.urlencoded_params_matcher( + { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token_type": "urn:posit:connect:api-key", + "subject_token": "12345", + } + ) + ], + json={ + "access_token": "sdk-user-token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + }, + ) + con = Client(api_key="12345", url="https://connect.example/") + assert ( + con.oauth.get_credentials()["access_token"] + == "sdk-user-token" + ) + assert ( + con.oauth.get_credentials("cit")["access_token"] + == "viewer-token" + )