Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Implement snowflake auth helpers #268

Merged
merged 4 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 0 additions & 47 deletions examples/0001-overview.qmd

This file was deleted.

11 changes: 11 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
### Posit SDK Examples

For more in-depth SDK examples, covering a variety of use cases, check out the
[Posit Connect Cookbook](https://docs.posit.co/connect/cookbook/getting-started/).

> [!NOTE]
> The databricks and snowflake examples will be removed from this repo is a future SDK release.
> Please see the updated examples in the
> [Impersonating the Content Viewer](https://docs.posit.co/connect/cookbook/content/impersonating-the-content-viewer/)
> section of the Connect Cookbook.

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from databricks import sql
from databricks.sdk.core import ApiClient, Config, databricks_cli
from databricks.sdk.service.iam import CurrentUserAPI

from posit.connect.external.databricks import PositCredentialsStrategy

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from databricks.sdk.core import Config, databricks_cli
from fastapi import FastAPI, Header
from fastapi.responses import JSONResponse

from posit.connect.external.databricks import PositCredentialsStrategy

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from databricks import sql
from databricks.sdk.core import Config, databricks_cli
from flask import Flask, request

from posit.connect.external.databricks import PositCredentialsStrategy

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from databricks import sql
from databricks.sdk.core import ApiClient, Config, databricks_cli
from databricks.sdk.service.iam import CurrentUserAPI
from shiny import App, Inputs, Outputs, Session, render, ui

from posit.connect.external.databricks import PositCredentialsStrategy
from shiny import App, Inputs, Outputs, Session, render, ui

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from databricks import sql
from databricks.sdk.core import ApiClient, Config, databricks_cli
from databricks.sdk.service.iam import CurrentUserAPI

from posit.connect.external.databricks import PositCredentialsStrategy

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
Expand Down
41 changes: 41 additions & 0 deletions examples/connect/snowflake/streamlit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Streamlit Example

## Start the app locally

```bash
SNOWFLAKE_ACCOUNT = "<snowflake-account-identifier>"
SNOWFLAKE_WAREHOUSE = "<snowflake-warehouse-name>"

# USER is only required when running the example locally with external browser auth
SNOWFLAKE_USER="<snowflake-username>" streamlit run app.py
```

## Deploy to Posit Connect

Validate that `rsconnect-python` is installed:

```bash
rsconnect version
```

Or install it as documented in the [installation](https://docs.posit.co/rsconnect-python/#installation) section of the documentation.

To publish, make sure `CONNECT_SERVER`, `CONNECT_API_KEY`, `SNOWFLAKE_ACCOUNT`, `SNOWFLAKE_WAREHOUSE` have valid values. Then, on a terminal session, enter the following command:

```bash
rsconnect deploy streamlit . \
--server "${CONNECT_SERVER}" \
--api-key "${CONNECT_API_KEY}" \
--environment SNOWFLAKE_ACCOUNT \
--environment SNOWFLAKE_WAREHOUSE
```

Note that the Snowflake environment variables do not need to be resolved by the shell, so they do not include the `$` prefix.

The Snowflake environment variables only need to be set once, unless a change needs to be made. If the values have not changed, you don’t need to provide them again when you publish updates to the document.

```bash
rsconnect deploy streamlit . \
--server "${CONNECT_SERVER}" \
--api-key "${CONNECT_API_KEY}"
```
42 changes: 42 additions & 0 deletions examples/connect/snowflake/streamlit/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# mypy: ignore-errors
import os

import pandas as pd
import snowflake.connector
import streamlit as st
from posit.connect.external.snowflake import PositAuthenticator

ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT")
WAREHOUSE = os.getenv("SNOWFLAKE_WAREHOUSE")

# USER is only required when running the example locally with external browser auth
USER = os.getenv("SNOWFLAKE_USER")

# https://docs.snowflake.com/en/user-guide/sample-data-using
DATABASE = os.getenv("SNOWFLAKE_DATABASE", "snowflake_sample_data")
SCHEMA = os.getenv("SNOWFLAKE_SCHEMA", "tpch_sf1")
TABLE = os.getenv("SNOWFLAKE_TABLE", "lineitem")

session_token = st.context.headers.get("Posit-Connect-User-Session-Token")
auth = PositAuthenticator(
local_authenticator="EXTERNALBROWSER", user_session_token=session_token
)

con = snowflake.connector.connect(
user=USER,
account=ACCOUNT,
warehouse=WAREHOUSE,
database=DATABASE,
schema=SCHEMA,
authenticator=auth.authenticator,
token=auth.token,
)

snowflake_user = con.cursor().execute("SELECT CURRENT_USER()").fetchone()
st.write(f"Hello, {snowflake_user[0]}!")

with st.spinner("Loading data from Snowflake..."):
df = pd.read_sql_query(f"SELECT * FROM {TABLE} LIMIT 10", con)

st.dataframe(df)
3 changes: 3 additions & 0 deletions examples/connect/snowflake/streamlit/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
snowflake-connector-python==3.12.1
streamlit==1.37.0
posit-sdk>=0.4.1
1 change: 0 additions & 1 deletion integration/tests/posit/connect/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from packaging import version

from posit import connect

client = connect.Client()
Expand Down
1 change: 0 additions & 1 deletion integration/tests/posit/connect/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
from packaging import version

from posit import connect

from . import CONNECT_VERSION
Expand Down
1 change: 1 addition & 0 deletions src/posit/connect/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# NOTE: The APIs in this module are provided as a convenience and are subject to breaking changes.
49 changes: 18 additions & 31 deletions src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import abc
import os
from typing import Callable, Dict, Optional

from ..client import Client
from ..oauth import OAuthIntegration
from .external import is_local

"""
NOTE: These APIs are provided as a convenience and are subject to breaking changes:
Expand All @@ -30,24 +29,14 @@ def __call__(self, *args, **kwargs) -> CredentialsProvider:
raise NotImplementedError


def _is_local() -> bool:
"""Returns true if called from a piece of content running on a Connect server.

The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`.
We can use this environment variable to determine if the content is running locally
or on a Connect server.
"""
return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT"


class PositCredentialsProvider:
def __init__(self, posit_oauth: OAuthIntegration, user_session_token: str):
self.posit_oauth = posit_oauth
self.user_session_token = user_session_token
def __init__(self, client: Client, user_session_token: str):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

breaking: PositCredentialsProvider now accepts a Client instead of an OAuthIntegration resource. This fits better with the API changes that @zackverham is about to make.

self._client = client
self._user_session_token = user_session_token
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

breaking: These fields are now _internal

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! If you can capture the breaking changes in the PR description, it will be easy to pull into the release notes.


def __call__(self) -> Dict[str, str]:
access_token = self.posit_oauth.get_credentials(
self.user_session_token
access_token = self._client.oauth.get_credentials(
self._user_session_token
)["access_token"]
return {"Authorization": f"Bearer {access_token}"}

Expand All @@ -56,12 +45,12 @@ class PositCredentialsStrategy(CredentialsStrategy):
def __init__(
self,
local_strategy: CredentialsStrategy,
user_session_token: Optional[str] = None,
client: Optional[Client] = None,
user_session_token: Optional[str] = None,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially breaking? These are named arguments but I did change the order. Callers who are using args instead of kwargs may break

):
self.user_session_token = user_session_token
self.local_strategy = local_strategy
self.client = client
self._local_strategy = local_strategy
self._client = client
self._user_session_token = user_session_token
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

breaking: These fields are now _internal


def sql_credentials_provider(self, *args, **kwargs):
"""The sql connector attempts to call the credentials provider w/o any args.
Expand Down Expand Up @@ -89,26 +78,24 @@ def auth_type(self) -> str:
NOTE: The databricks-sql client does not use auth_type to set the user-agent.
https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219
"""
if _is_local():
return self.local_strategy.auth_type()
if is_local():
return self._local_strategy.auth_type()
else:
return "posit-oauth-integration"

def __call__(self, *args, **kwargs) -> CredentialsProvider:
# If the content is not running on Connect then fall back to local_strategy
if _is_local():
return self.local_strategy(*args, **kwargs)
if is_local():
return self._local_strategy(*args, **kwargs)

# If the user-session-token wasn't provided and we're running on Connect then we raise an exception.
# user_session_token is required to impersonate the viewer.
if self.user_session_token is None:
if self._user_session_token is None:
raise ValueError(
"The user-session-token is required for viewer authentication."
)

if self.client is None:
self.client = Client()
if self._client is None:
self._client = Client()

return PositCredentialsProvider(
self.client.oauth, self.user_session_token
)
return PositCredentialsProvider(self._client, self._user_session_token)
11 changes: 11 additions & 0 deletions src/posit/connect/external/external.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import os


def is_local() -> bool:
"""Returns true if called from a piece of content running on a Connect server.

The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`.
We can use this environment variable to determine if the content is running locally
or on a Connect server.
"""
return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT"
46 changes: 46 additions & 0 deletions src/posit/connect/external/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional

from ..client import Client
from .external import is_local

"""
NOTE: The APIs in this module are provided as a convenience and are subject to breaking changes.
"""


class PositAuthenticator:
tdstein marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
local_authenticator: Optional[str] = None,
client: Optional[Client] = None,
user_session_token: Optional[str] = None,
):
self._local_authenticator = local_authenticator
self._client = client
self._user_session_token = user_session_token

@property
def authenticator(self) -> Optional[str]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a @property?

if is_local():
return self._local_authenticator
return "oauth"

@property
def token(self) -> Optional[str]:
if is_local():
return None

# If the user-session-token wasn't provided and we're running on Connect then we raise an exception.
# user_session_token is required to impersonate the viewer.
if self._user_session_token is None:
raise ValueError(
"The user-session-token is required for viewer authentication."
)

if self._client is None:
self._client = Client()

access_token = self._client.oauth.get_credentials(
self._user_session_token
)["access_token"]
return access_token
5 changes: 1 addition & 4 deletions tests/posit/connect/external/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from unittest.mock import patch

import responses

from posit.connect import Client
from posit.connect.external.databricks import (
CredentialsProvider,
Expand Down Expand Up @@ -48,9 +47,7 @@ def test_posit_credentials_provider(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
cp = PositCredentialsProvider(
posit_oauth=client.oauth, user_session_token="cit"
)
cp = PositCredentialsProvider(client=client, user_session_token="cit")
assert cp() == {"Authorization": f"Bearer dynamic-viewer-access-token"}

@responses.activate
Expand Down
Loading
Loading