Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Implement snowflake auth helpers #268

Merged
merged 4 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions examples/connect/snowflake/streamlit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Streamlit Example

## Start the app locally

```bash
SNOWFLAKE_ACCOUNT = "<snowflake-account-identifier>"
SNOWFLAKE_WAREHOUSE = "<snowflake-warehouse-name>"

# USER is only required when running the example locally with external browser auth
SNOWFLAKE_USER="<snowflake-username>" streamlit run app.py
```

## Deploy to Posit Connect

Validate that `rsconnect-python` is installed:

```bash
rsconnect version
```

Or install it as documented in the [installation](https://docs.posit.co/rsconnect-python/#installation) section of the documentation.

To publish, make sure `CONNECT_SERVER`, `CONNECT_API_KEY`, `SNOWFLAKE_ACCOUNT`, `SNOWFLAKE_WAREHOUSE` have valid values. Then, on a terminal session, enter the following command:

```bash
rsconnect deploy streamlit . \
--server "${CONNECT_SERVER}" \
--api-key "${CONNECT_API_KEY}" \
--environment SNOWFLAKE_ACCOUNT \
--environment SNOWFLAKE_WAREHOUSE
```

Note that the Snowflake environment variables do not need to be resolved by the shell, so they do not include the `$` prefix.

The Snowflake environment variables only need to be set once, unless a change needs to be made. If the values have not changed, you don’t need to provide them again when you publish updates to the document.

```bash
rsconnect deploy streamlit . \
--server "${CONNECT_SERVER}" \
--api-key "${CONNECT_API_KEY}"
```
43 changes: 43 additions & 0 deletions examples/connect/snowflake/streamlit/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# mypy: ignore-errors
import os

import pandas as pd
import streamlit as st
import snowflake.connector

from posit.connect.external.snowflake import PositAuthenticator

ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT")
WAREHOUSE = os.getenv("SNOWFLAKE_WAREHOUSE")

# USER is only required when running the example locally with external browser auth
USER = os.getenv("SNOWFLAKE_USER")

# https://docs.snowflake.com/en/user-guide/sample-data-using
DATABASE = os.getenv("SNOWFLAKE_DATABASE", "snowflake_sample_data")
SCHEMA = os.getenv("SNOWFLAKE_SCHEMA", "tpch_sf1")
TABLE = os.getenv("SNOWFLAKE_TABLE", "lineitem")

session_token = st.context.headers.get("Posit-Connect-User-Session-Token")
auth = PositAuthenticator(
local_authenticator="EXTERNALBROWSER",
user_session_token=session_token)

con = snowflake.connector.connect(
user=USER,
account=ACCOUNT,
warehouse=WAREHOUSE,
database=DATABASE,
schema=SCHEMA,
authenticator=auth.authenticator(),
token=auth.token(),
)

snowflake_user = con.cursor().execute("SELECT CURRENT_USER()").fetchone()
st.write(f"Hello, {snowflake_user[0]}!")

with st.spinner("Loading data from Snowflake..."):
df = pd.read_sql_query(f"SELECT * FROM {TABLE} LIMIT 10", con)

st.dataframe(df)
3 changes: 3 additions & 0 deletions examples/connect/snowflake/streamlit/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
snowflake-connector-python==3.12.1
streamlit==1.37.0
posit-sdk>=0.4.1
14 changes: 14 additions & 0 deletions src/posit/connect/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os

"""
NOTE: The APIs in this module are provided as a convenience and are subject to breaking changes.
"""

def _is_local() -> bool:
"""Returns true if called from a piece of content running on a Connect server.

The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`.
We can use this environment variable to determine if the content is running locally
or on a Connect server.
"""
return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't found a clear answer to best practices, but I've recently started reducing init.py to only include named exports. Then placing any shared module code in a file of the same name; external/external.py here.

This also eliminates the need to mark the method as private since it won't be declared in the init.py file.

31 changes: 15 additions & 16 deletions src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Callable, Dict, Optional

from ..client import Client
from ..oauth import OAuthIntegration

"""
NOTE: These APIs are provided as a convenience and are subject to breaking changes:
Expand Down Expand Up @@ -41,13 +40,13 @@ def _is_local() -> bool:


class PositCredentialsProvider:
def __init__(self, posit_oauth: OAuthIntegration, user_session_token: str):
self.posit_oauth = posit_oauth
self.user_session_token = user_session_token
def __init__(self, client: Client, user_session_token: str):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

breaking: PositCredentialsProvider now accepts a Client instead of an OAuthIntegration resource. This fits better with the API changes that @zackverham is about to make.

self._client = client
self._user_session_token = user_session_token
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

breaking: These fields are now _internal

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! If you can capture the breaking changes in the PR description, it will be easy to pull into the release notes.


def __call__(self) -> Dict[str, str]:
access_token = self.posit_oauth.get_credentials(
self.user_session_token
access_token = self._client.oauth.get_credentials(
self._user_session_token
)["access_token"]
return {"Authorization": f"Bearer {access_token}"}

Expand All @@ -56,12 +55,12 @@ class PositCredentialsStrategy(CredentialsStrategy):
def __init__(
self,
local_strategy: CredentialsStrategy,
user_session_token: Optional[str] = None,
client: Optional[Client] = None,
user_session_token: Optional[str] = None,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially breaking? These are named arguments but I did change the order. Callers who are using args instead of kwargs may break

):
self.user_session_token = user_session_token
self.local_strategy = local_strategy
self.client = client
self._local_strategy = local_strategy
self._client = client
self._user_session_token = user_session_token
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

breaking: These fields are now _internal


def sql_credentials_provider(self, *args, **kwargs):
"""The sql connector attempts to call the credentials provider w/o any args.
Expand Down Expand Up @@ -90,25 +89,25 @@ def auth_type(self) -> str:
https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219
"""
if _is_local():
return self.local_strategy.auth_type()
return self._local_strategy.auth_type()
else:
return "posit-oauth-integration"

def __call__(self, *args, **kwargs) -> CredentialsProvider:
# If the content is not running on Connect then fall back to local_strategy
if _is_local():
return self.local_strategy(*args, **kwargs)
return self._local_strategy(*args, **kwargs)

# If the user-session-token wasn't provided and we're running on Connect then we raise an exception.
# user_session_token is required to impersonate the viewer.
if self.user_session_token is None:
if self._user_session_token is None:
raise ValueError(
"The user-session-token is required for viewer authentication."
)

if self.client is None:
self.client = Client()
if self._client is None:
self._client = Client()

return PositCredentialsProvider(
self.client.oauth, self.user_session_token
self._client, self._user_session_token
)
44 changes: 44 additions & 0 deletions src/posit/connect/external/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Optional

from . import _is_local
from ..client import Client

"""
NOTE: The APIs in this module are provided as a convenience and are subject to breaking changes.
"""

class PositAuthenticator:
tdstein marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
local_authenticator: Optional[str] = None,
client: Optional[Client] = None,
user_session_token: Optional[str] = None,
):
self._local_authenticator = local_authenticator
self._client = client
self._user_session_token = user_session_token

def authenticator(self) -> Optional[str]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a @property?

if _is_local():
return self._local_authenticator
return "oauth"

def token(self) -> Optional[str]:
if _is_local():
return None

# If the user-session-token wasn't provided and we're running on Connect then we raise an exception.
# user_session_token is required to impersonate the viewer.
if self._user_session_token is None:
raise ValueError(
"The user-session-token is required for viewer authentication."
)

if self._client is None:
self._client = Client()

access_token = self._client.oauth.get_credentials(
self._user_session_token
)["access_token"]
return access_token

2 changes: 1 addition & 1 deletion tests/posit/connect/external/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_posit_credentials_provider(self):

client = Client(api_key="12345", url="https://connect.example/")
cp = PositCredentialsProvider(
posit_oauth=client.oauth, user_session_token="cit"
client=client, user_session_token="cit"
)
assert cp() == {"Authorization": f"Bearer dynamic-viewer-access-token"}

Expand Down
54 changes: 54 additions & 0 deletions tests/posit/connect/external/test_snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Dict
from unittest.mock import patch

import responses

from posit.connect import Client
from posit.connect.external.snowflake import PositAuthenticator


def register_mocks():
responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:user-session-token",
"subject_token": "cit",
}
)
],
json={
"access_token": "dynamic-viewer-access-token",
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
"token_type": "Bearer",
},
)


class TestPositAuthenticator:
@responses.activate
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_posit_authenticator(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
auth = PositAuthenticator(
local_authenticator="SNOWFLAKE",
user_session_token="cit",
client=client,
)
assert auth.authenticator() == "oauth"
assert auth.token() == "dynamic-viewer-access-token"

def test_posit_authenticator_fallback(self):
# local_authenticator is used when the content is running locally
client = Client(api_key="12345", url="https://connect.example/")
auth = PositAuthenticator(
local_authenticator="SNOWFLAKE",
user_session_token="cit",
client=client,
)
assert auth.authenticator() == "SNOWFLAKE"
assert auth.token() == None
Loading