Skip to content

Commit

Permalink
Add Databricks OAuthIntegration credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
dbkegley committed Feb 23, 2024
1 parent 26c2730 commit 93bf580
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 1 deletion.
16 changes: 16 additions & 0 deletions src/posit/connect/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from __future__ import annotations

import os
from requests import Response, Session
from typing import Optional

from . import hooks, urls

from .auth import Auth
from .config import Config
from .oauth import OAuthIntegration
from .users import Users


# Connect sets the value of the environment variable RSTUDIO_PRODUCT = CONNECT
# when content is running on a Connect server. Use this var to determine if the
# client SDK was initialized from a piece of content running on a Connect server.
def is_local() -> bool:
return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT"


class Client:
def __init__(
self,
Expand Down Expand Up @@ -37,13 +46,20 @@ def __init__(

# Place to cache the server settings
self.server_settings = None
self._oauth = None

@property
def connect_version(self):
if self.server_settings is None:
self.server_settings = self.get("server_settings").json()
return self.server_settings["version"]

@property
def oauth(self) -> OAuthIntegration:
if self._oauth is None:
self._oauth = OAuthIntegration(config=self.config, session=self.session)
return self._oauth

@property
def users(self) -> Users:
return Users(config=self.config, session=self.session)
Expand Down
13 changes: 12 additions & 1 deletion src/posit/connect/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from unittest.mock import MagicMock, patch


from .client import Client


Expand All @@ -25,6 +24,18 @@ def MockSession():
yield mock


@pytest.fixture
def MockOAuthIntegration():
with patch("posit.connect.client.OAuthIntegration") as mock:
yield mock


@pytest.fixture
def MockUsers():
with patch("posit.connect.client.Users") as mock:
yield mock


class TestClient:
def test_init(
self,
Expand Down
Empty file.
60 changes: 60 additions & 0 deletions src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import abc
from typing import Callable, Dict, Optional

from ..client import Client, is_local
from ..oauth import OAuthIntegration

HeaderFactory = Callable[[], Dict[str, str]]

# https://github.com/databricks/databricks-sdk-py/blob/v0.20.0/databricks/sdk/credentials_provider.py
# https://github.com/databricks/databricks-sql-python/blob/v3.1.0/src/databricks/sql/auth/authenticators.py
# In order to keep compatibility with the Databricks SDK
class CredentialsProvider(abc.ABC):
"""CredentialsProvider is the protocol (call-side interface)
for authenticating requests to Databricks REST APIs"""

@abc.abstractmethod
def auth_type(self) -> str:
...

@abc.abstractmethod
def __call__(self, *args, **kwargs) -> HeaderFactory:
...


class PositOAuthIntegrationCredentialsProvider(CredentialsProvider):
def __init__(self, posit_oauth: OAuthIntegration, user_identity: str):
self.posit_oauth = posit_oauth
self.user_identity = user_identity

def auth_type(self) -> str:
return "posit-oauth-integration"

def __call__(self, *args, **kwargs) -> HeaderFactory:
def inner() -> Dict[str, str]:
access_token = self.posit_oauth.get_credentials(self.user_identity).json()['access_token']
return {"Authorization": f"Bearer {access_token}"}
return inner


def viewer_credentials_provider(client: Optional[Client], user_identity: Optional[str]) -> Optional[CredentialsProvider]:

# If the content is not running on Connect then viewer auth should
# fall back to the locally configured credentials hierarchy
if is_local():
return None

if client is None:
client = Client()


# If the user-identity-token wasn't provided and we're running on Connect then we raise an exception.
# user_identity is required to impersonate the viewer.
if user_identity is None:
raise Exception("The user-identity-token is required for viewer authentication.")

return PositOAuthIntegrationCredentialsProvider(client.oauth, user_identity)


def service_account_credentials_provider(client: Optional[Client]):
raise NotImplemented
Empty file.
35 changes: 35 additions & 0 deletions src/posit/connect/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from requests import Response, Session
from typing import Optional

from . import urls
from .config import Config


class OAuthIntegration:

def __init__(
self, config: Config, session: Session
) -> None:
self.url = urls.append_path(config.url, "v1/oauth/integrations/credentials")
self.config = config
self.session = session


def get_credentials(self, user_identity: Optional[str]) -> Response:

# craft a basic credential exchange request where the self.config.api_key owner
# is requesting their own credentials
data = dict()
data["grant_type"] = "urn:ietf:params:oauth:grant-type:token-exchange"
data["subject_token_type"] = "urn:posit:connect:api-key"
data["subject_token"] = self.config.api_key

# if this content is running on Connect, then it is allowed to request
# the content viewer's credentials
if user_identity:
data["subject_token_type"] = "urn:posit:connect:user-identity-token"
data["subject_token"] = user_identity

return self.session.post(self.url, json=data)
Empty file added src/posit/connect/oauth_test.py
Empty file.

0 comments on commit 93bf580

Please sign in to comment.