-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat!: Implement snowflake auth helpers #268
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Streamlit Example | ||
|
||
## Start the app locally | ||
|
||
```bash | ||
SNOWFLAKE_ACCOUNT = "<snowflake-account-identifier>" | ||
SNOWFLAKE_WAREHOUSE = "<snowflake-warehouse-name>" | ||
|
||
# USER is only required when running the example locally with external browser auth | ||
SNOWFLAKE_USER="<snowflake-username>" streamlit run app.py | ||
``` | ||
|
||
## Deploy to Posit Connect | ||
|
||
Validate that `rsconnect-python` is installed: | ||
|
||
```bash | ||
rsconnect version | ||
``` | ||
|
||
Or install it as documented in the [installation](https://docs.posit.co/rsconnect-python/#installation) section of the documentation. | ||
|
||
To publish, make sure `CONNECT_SERVER`, `CONNECT_API_KEY`, `SNOWFLAKE_ACCOUNT`, `SNOWFLAKE_WAREHOUSE` have valid values. Then, on a terminal session, enter the following command: | ||
|
||
```bash | ||
rsconnect deploy streamlit . \ | ||
--server "${CONNECT_SERVER}" \ | ||
--api-key "${CONNECT_API_KEY}" \ | ||
--environment SNOWFLAKE_ACCOUNT \ | ||
--environment SNOWFLAKE_WAREHOUSE | ||
``` | ||
|
||
Note that the Snowflake environment variables do not need to be resolved by the shell, so they do not include the `$` prefix. | ||
|
||
The Snowflake environment variables only need to be set once, unless a change needs to be made. If the values have not changed, you don’t need to provide them again when you publish updates to the document. | ||
|
||
```bash | ||
rsconnect deploy streamlit . \ | ||
--server "${CONNECT_SERVER}" \ | ||
--api-key "${CONNECT_API_KEY}" | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# -*- coding: utf-8 -*- | ||
# mypy: ignore-errors | ||
import os | ||
|
||
import pandas as pd | ||
import streamlit as st | ||
import snowflake.connector | ||
|
||
from posit.connect.external.snowflake import PositAuthenticator | ||
|
||
ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT") | ||
WAREHOUSE = os.getenv("SNOWFLAKE_WAREHOUSE") | ||
|
||
# USER is only required when running the example locally with external browser auth | ||
USER = os.getenv("SNOWFLAKE_USER") | ||
|
||
# https://docs.snowflake.com/en/user-guide/sample-data-using | ||
DATABASE = os.getenv("SNOWFLAKE_DATABASE", "snowflake_sample_data") | ||
SCHEMA = os.getenv("SNOWFLAKE_SCHEMA", "tpch_sf1") | ||
TABLE = os.getenv("SNOWFLAKE_TABLE", "lineitem") | ||
|
||
session_token = st.context.headers.get("Posit-Connect-User-Session-Token") | ||
auth = PositAuthenticator( | ||
local_authenticator="EXTERNALBROWSER", | ||
user_session_token=session_token) | ||
|
||
con = snowflake.connector.connect( | ||
user=USER, | ||
account=ACCOUNT, | ||
warehouse=WAREHOUSE, | ||
database=DATABASE, | ||
schema=SCHEMA, | ||
authenticator=auth.authenticator(), | ||
token=auth.token(), | ||
) | ||
|
||
snowflake_user = con.cursor().execute("SELECT CURRENT_USER()").fetchone() | ||
st.write(f"Hello, {snowflake_user[0]}!") | ||
|
||
with st.spinner("Loading data from Snowflake..."): | ||
df = pd.read_sql_query(f"SELECT * FROM {TABLE} LIMIT 10", con) | ||
|
||
st.dataframe(df) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
snowflake-connector-python==3.12.1 | ||
streamlit==1.37.0 | ||
posit-sdk>=0.4.1 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import os | ||
|
||
""" | ||
NOTE: The APIs in this module are provided as a convenience and are subject to breaking changes. | ||
""" | ||
|
||
def _is_local() -> bool: | ||
"""Returns true if called from a piece of content running on a Connect server. | ||
|
||
The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`. | ||
We can use this environment variable to determine if the content is running locally | ||
or on a Connect server. | ||
""" | ||
return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT" | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,6 @@ | |
from typing import Callable, Dict, Optional | ||
|
||
from ..client import Client | ||
from ..oauth import OAuthIntegration | ||
|
||
""" | ||
NOTE: These APIs are provided as a convenience and are subject to breaking changes: | ||
|
@@ -41,13 +40,13 @@ def _is_local() -> bool: | |
|
||
|
||
class PositCredentialsProvider: | ||
def __init__(self, posit_oauth: OAuthIntegration, user_session_token: str): | ||
self.posit_oauth = posit_oauth | ||
self.user_session_token = user_session_token | ||
def __init__(self, client: Client, user_session_token: str): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. breaking: PositCredentialsProvider now accepts a |
||
self._client = client | ||
self._user_session_token = user_session_token | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. breaking: These fields are now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! If you can capture the breaking changes in the PR description, it will be easy to pull into the release notes. |
||
|
||
def __call__(self) -> Dict[str, str]: | ||
access_token = self.posit_oauth.get_credentials( | ||
self.user_session_token | ||
access_token = self._client.oauth.get_credentials( | ||
self._user_session_token | ||
)["access_token"] | ||
return {"Authorization": f"Bearer {access_token}"} | ||
|
||
|
@@ -56,12 +55,12 @@ class PositCredentialsStrategy(CredentialsStrategy): | |
def __init__( | ||
self, | ||
local_strategy: CredentialsStrategy, | ||
user_session_token: Optional[str] = None, | ||
client: Optional[Client] = None, | ||
user_session_token: Optional[str] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. potentially breaking? These are named arguments but I did change the order. Callers who are using args instead of kwargs may break |
||
): | ||
self.user_session_token = user_session_token | ||
self.local_strategy = local_strategy | ||
self.client = client | ||
self._local_strategy = local_strategy | ||
self._client = client | ||
self._user_session_token = user_session_token | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. breaking: These fields are now |
||
|
||
def sql_credentials_provider(self, *args, **kwargs): | ||
"""The sql connector attempts to call the credentials provider w/o any args. | ||
|
@@ -90,25 +89,25 @@ def auth_type(self) -> str: | |
https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219 | ||
""" | ||
if _is_local(): | ||
return self.local_strategy.auth_type() | ||
return self._local_strategy.auth_type() | ||
else: | ||
return "posit-oauth-integration" | ||
|
||
def __call__(self, *args, **kwargs) -> CredentialsProvider: | ||
# If the content is not running on Connect then fall back to local_strategy | ||
if _is_local(): | ||
return self.local_strategy(*args, **kwargs) | ||
return self._local_strategy(*args, **kwargs) | ||
|
||
# If the user-session-token wasn't provided and we're running on Connect then we raise an exception. | ||
# user_session_token is required to impersonate the viewer. | ||
if self.user_session_token is None: | ||
if self._user_session_token is None: | ||
raise ValueError( | ||
"The user-session-token is required for viewer authentication." | ||
) | ||
|
||
if self.client is None: | ||
self.client = Client() | ||
if self._client is None: | ||
self._client = Client() | ||
|
||
return PositCredentialsProvider( | ||
self.client.oauth, self.user_session_token | ||
self._client, self._user_session_token | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Optional | ||
|
||
from . import _is_local | ||
from ..client import Client | ||
|
||
""" | ||
NOTE: The APIs in this module are provided as a convenience and are subject to breaking changes. | ||
""" | ||
|
||
class PositAuthenticator: | ||
tdstein marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__( | ||
self, | ||
local_authenticator: Optional[str] = None, | ||
client: Optional[Client] = None, | ||
user_session_token: Optional[str] = None, | ||
): | ||
self._local_authenticator = local_authenticator | ||
self._client = client | ||
self._user_session_token = user_session_token | ||
|
||
def authenticator(self) -> Optional[str]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be a |
||
if _is_local(): | ||
return self._local_authenticator | ||
return "oauth" | ||
|
||
def token(self) -> Optional[str]: | ||
if _is_local(): | ||
return None | ||
|
||
# If the user-session-token wasn't provided and we're running on Connect then we raise an exception. | ||
# user_session_token is required to impersonate the viewer. | ||
if self._user_session_token is None: | ||
raise ValueError( | ||
"The user-session-token is required for viewer authentication." | ||
) | ||
|
||
if self._client is None: | ||
self._client = Client() | ||
|
||
access_token = self._client.oauth.get_credentials( | ||
self._user_session_token | ||
)["access_token"] | ||
return access_token | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Dict | ||
from unittest.mock import patch | ||
|
||
import responses | ||
|
||
from posit.connect import Client | ||
from posit.connect.external.snowflake import PositAuthenticator | ||
|
||
|
||
def register_mocks(): | ||
responses.post( | ||
"https://connect.example/__api__/v1/oauth/integrations/credentials", | ||
match=[ | ||
responses.matchers.urlencoded_params_matcher( | ||
{ | ||
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", | ||
"subject_token_type": "urn:posit:connect:user-session-token", | ||
"subject_token": "cit", | ||
} | ||
) | ||
], | ||
json={ | ||
"access_token": "dynamic-viewer-access-token", | ||
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token", | ||
"token_type": "Bearer", | ||
}, | ||
) | ||
|
||
|
||
class TestPositAuthenticator: | ||
@responses.activate | ||
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) | ||
def test_posit_authenticator(self): | ||
register_mocks() | ||
|
||
client = Client(api_key="12345", url="https://connect.example/") | ||
auth = PositAuthenticator( | ||
local_authenticator="SNOWFLAKE", | ||
user_session_token="cit", | ||
client=client, | ||
) | ||
assert auth.authenticator() == "oauth" | ||
assert auth.token() == "dynamic-viewer-access-token" | ||
|
||
def test_posit_authenticator_fallback(self): | ||
# local_authenticator is used when the content is running locally | ||
client = Client(api_key="12345", url="https://connect.example/") | ||
auth = PositAuthenticator( | ||
local_authenticator="SNOWFLAKE", | ||
user_session_token="cit", | ||
client=client, | ||
) | ||
assert auth.authenticator() == "SNOWFLAKE" | ||
assert auth.token() == None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't found a clear answer to best practices, but I've recently started reducing init.py to only include named exports. Then placing any shared module code in a file of the same name;
external/external.py
here.This also eliminates the need to mark the method as private since it won't be declared in the init.py file.