Skip to content

Commit

Permalink
Adds an OAuthIntegration (#52)
Browse files Browse the repository at this point in the history
* Add Databricks OAuthIntegration credentials provider

* Return Credentials response. Dont cache oauthIntegration resource in client

* Add oauth integration tests

* Test databricks content on Connect

* Move is_local; check CONNECT_SERVER and CONNECT_CONTENT_GUID

* Use post body temporarily for oauth/integration/credentials endpoints

* use form xml post in credentials request

* Add sample databricks content

* Fix linter

* update sample content requirements

* Revert env var change

* Use user-session-token instead of content-identity-token
  • Loading branch information
dbkegley authored Feb 27, 2024
1 parent 4e929e4 commit 19b6cab
Show file tree
Hide file tree
Showing 11 changed files with 285 additions and 1 deletion.
1 change: 1 addition & 0 deletions examples/connect/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.posit
19 changes: 19 additions & 0 deletions examples/connect/databricks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
```bash
# start streamlit locally
DATABRICKS_TOKEN=<DB_PAT> \
streamlit run ./sample-content.py

# deploy the app the first time
publisher deploy -a localhost:3939 -n databricks ./

# re-deploy the databricks app
publisher redeploy databricks
```

TODO: Test this content with databricks-connect
<https://docs.databricks.com/en/dev-tools/databricks-connect/python/index.html>

```
# install the sdk from this branch
pip install git+https://github.com/posit-dev/posit-sdk-py.git@kegs/databricks-oauth-2
```
56 changes: 56 additions & 0 deletions examples/connect/databricks/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
altair==5.2.0
attrs==23.2.0
blinker==1.7.0
cachetools==5.3.2
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
databricks-sdk==0.20.0
databricks-sql-connector==3.1.0
et-xmlfile==1.1.0
gitdb==4.0.11
GitPython==3.1.42
google-auth==2.28.0
idna==3.6
importlib-metadata==7.0.1
Jinja2==3.1.3
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
lz4==4.3.3
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
numpy==1.26.4
oauthlib==3.2.2
openpyxl==3.1.2
packaging==23.2
pandas==2.1.4
pillow==10.2.0
posit-sdk @ git+https://github.com/posit-dev/posit-sdk-py.git@24ad71458de56b01b8168e80441ea860236f1933
protobuf==4.25.3
pyarrow==14.0.2
pyasn1==0.5.1
pyasn1-modules==0.3.0
pydeck==0.8.1b0
Pygments==2.17.2
python-dateutil==2.8.2
pytz==2024.1
referencing==0.33.0
requests==2.31.0
rich==13.7.0
rpds-py==0.18.0
rsa==4.9
six==1.16.0
smmap==5.0.1
streamlit==1.31.1
tenacity==8.2.3
thrift==0.16.0
toml==0.10.2
toolz==0.12.1
tornado==6.4
typing_extensions==4.9.0
tzdata==2024.1
tzlocal==5.2
urllib3==2.2.1
validators==0.22.0
zipp==3.17.0
45 changes: 45 additions & 0 deletions examples/connect/databricks/sample-content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# mypy: ignore-errors
import os

from posit.connect.external.databricks import viewer_credentials_provider

from databricks import sql
from databricks.sdk.service.iam import CurrentUserAPI
from databricks.sdk.core import ApiClient, Config

import pandas as pd
import streamlit as st
from streamlit.web.server.websocket_headers import _get_websocket_headers

DB_PAT=os.getenv("DATABRICKS_TOKEN")

DB_HOST=os.getenv("DB_HOST")
DB_HOST_URL = f"https://{DB_HOST}"
SQL_HTTP_PATH=os.getenv("SQL_HTTP_PATH")

USER_SESSION_TOKEN = None

# Read the viewer's user session token from the streamlit ws header.
headers = _get_websocket_headers()
if headers:
USER_SESSION_TOKEN = headers.get('Posit-Connect-User-Session')

credentials_provider = viewer_credentials_provider(user_session_token=USER_SESSION_TOKEN)
cfg = Config(host=DB_HOST_URL, credentials_provider=credentials_provider)
#cfg = Config(host=DB_HOST_URL, token=DB_PAT)

databricks_user = CurrentUserAPI(ApiClient(cfg)).me()
st.write(f"Hello, {databricks_user.display_name}!")

with sql.connect(
server_hostname=DB_HOST,
http_path=SQL_HTTP_PATH,
#access_token=DB_PAT) as connection:
auth_type='databricks-oauth',
credentials_provider=credentials_provider) as connection:
with connection.cursor() as cursor:
cursor.execute("SELECT * FROM data")
result = cursor.fetchall()
st.table(pd.DataFrame(result))

5 changes: 5 additions & 0 deletions src/posit/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .auth import Auth
from .config import Config
from .oauth import OAuthIntegration
from .content import Content
from .users import User, Users

Expand Down Expand Up @@ -51,6 +52,10 @@ def me(self) -> User:
response = self.session.get(url)
return User(**response.json())

@property
def oauth(self) -> OAuthIntegration:
return OAuthIntegration(config=self.config, session=self.session)

@property
def users(self) -> Users:
return Users(config=self.config, session=self.session)
Expand Down
1 change: 0 additions & 1 deletion src/posit/connect/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from unittest.mock import MagicMock, patch


from .client import Client


Expand Down
Empty file.
66 changes: 66 additions & 0 deletions src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import abc
import os
from typing import Callable, Dict, Optional

from ..client import Client
from ..oauth import OAuthIntegration

HeaderFactory = Callable[[], Dict[str, str]]

# https://github.com/databricks/databricks-sdk-py/blob/v0.20.0/databricks/sdk/credentials_provider.py
# https://github.com/databricks/databricks-sql-python/blob/v3.1.0/src/databricks/sql/auth/authenticators.py
# In order to keep compatibility with the Databricks SDK
class CredentialsProvider(abc.ABC):
"""CredentialsProvider is the protocol (call-side interface)
for authenticating requests to Databricks REST APIs"""

@abc.abstractmethod
def auth_type(self) -> str:
raise NotImplementedError

@abc.abstractmethod
def __call__(self, *args, **kwargs) -> HeaderFactory:
raise NotImplementedError


class PositOAuthIntegrationCredentialsProvider(CredentialsProvider):
def __init__(self, posit_oauth: OAuthIntegration, user_session_token: str):
self.posit_oauth = posit_oauth
self.user_session_token = user_session_token

def auth_type(self) -> str:
return "posit-oauth-integration"

def __call__(self, *args, **kwargs) -> HeaderFactory:
def inner() -> Dict[str, str]:
access_token = self.posit_oauth.get_credentials(self.user_session_token)['access_token']
return {"Authorization": f"Bearer {access_token}"}
return inner


# Use this environment variable to determine if the
# client SDK was initialized from a piece of content running on a Connect server.
def is_local() -> bool:
return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT"


def viewer_credentials_provider(client: Optional[Client] = None, user_session_token: Optional[str] = None) -> Optional[CredentialsProvider]:

# If the content is not running on Connect then viewer auth should
# fall back to the locally configured credentials hierarchy
if is_local():
return None

if client is None:
client = Client()

# If the user-session-token wasn't provided and we're running on Connect then we raise an exception.
# user_session_token is required to impersonate the viewer.
if user_session_token is None:
raise ValueError("The user-session-token is required for viewer authentication.")

return PositOAuthIntegrationCredentialsProvider(client.oauth, user_session_token)


def service_account_credentials_provider(client: Optional[Client] = None):
raise NotImplementedError
Empty file.
42 changes: 42 additions & 0 deletions src/posit/connect/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from requests import Session
from typing import Optional, TypedDict

from . import urls
from .config import Config


class Credentials(TypedDict, total=False):
access_token: str
issued_token_type: str
token_type: str


class OAuthIntegration:

def __init__(
self, config: Config, session: Session
) -> None:
self.url = urls.append_path(config.url, "v1/oauth/integrations/credentials")
self.config = config
self.session = session


def get_credentials(self, user_session_token: Optional[str]=None) -> Credentials:

# craft a basic credential exchange request where the self.config.api_key owner
# is requesting their own credentials
data = dict()
data["grant_type"] = "urn:ietf:params:oauth:grant-type:token-exchange"
data["subject_token_type"] = "urn:posit:connect:api-key"
data["subject_token"] = self.config.api_key

# if this content is running on Connect, then it is allowed to request
# the content viewer's credentials
if user_session_token:
data["subject_token_type"] = "urn:posit:connect:user-session-token"
data["subject_token"] = user_session_token

response = self.session.post(self.url, data=data)
return Credentials(**response.json())
51 changes: 51 additions & 0 deletions src/posit/connect/oauth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import responses

from .client import Client

class TestOAuthIntegrations:

@responses.activate
def test_get_credentials(self):
responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:user-session-token",
"subject_token": "cit",
}
)
],
json={
"access_token": "viewer-token",
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
"token_type": "Bearer",
},
)
responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:api-key",
"subject_token": "12345",
}
)
],
json={
"access_token": "sdk-user-token",
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
"token_type": "Bearer",
},
)
con = Client(api_key="12345", url="https://connect.example/")
assert (
con.oauth.get_credentials()["access_token"]
== "sdk-user-token"
)
assert (
con.oauth.get_credentials("cit")["access_token"]
== "viewer-token"
)

0 comments on commit 19b6cab

Please sign in to comment.