Skip to content

Commit

Permalink
chore: apply formatting (#254)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdstein authored Aug 8, 2024
1 parent 37bc1dc commit cf4ca56
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 29 deletions.
11 changes: 7 additions & 4 deletions examples/connect/dash/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ def update_page(_):
"Posit-Connect-User-Session-Token"
)
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=session_token)
local_strategy=databricks_cli, user_session_token=session_token
)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)
credentials_strategy=posit_strategy,
)

def get_greeting():
databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me()
Expand All @@ -61,7 +62,9 @@ def get_table():
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg)
credentials_provider=posit_strategy.sql_credentials_provider(
cfg
),
) as connection:
with connection.cursor() as cursor:
cursor.execute(query)
Expand Down
8 changes: 5 additions & 3 deletions examples/connect/fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ async def get_fares(

posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=posit_connect_user_session_token)
user_session_token=posit_connect_user_session_token,
)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)
credentials_strategy=posit_strategy,
)

if rows is None:
query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;"
Expand All @@ -43,7 +45,7 @@ async def get_fares(
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg)
credentials_provider=posit_strategy.sql_credentials_provider(cfg),
) as connection:
with connection.cursor() as cursor:
cursor.execute(query)
Expand Down
9 changes: 5 additions & 4 deletions examples/connect/flask/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ def get_fares():

session_token = request.headers.get("Posit-Connect-User-Session-Token")
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=session_token)
local_strategy=databricks_cli, user_session_token=session_token
)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)
credentials_strategy=posit_strategy,
)

if rows is None:
query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;"
Expand All @@ -45,7 +46,7 @@ def get_fares():
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg)
credentials_provider=posit_strategy.sql_credentials_provider(cfg),
) as connection:
with connection.cursor() as cursor:
cursor.execute(query)
Expand Down
7 changes: 4 additions & 3 deletions examples/connect/shiny-python/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ def server(i: Inputs, o: Outputs, session: Session):
"Posit-Connect-User-Session-Token"
)
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=session_token)
local_strategy=databricks_cli, user_session_token=session_token
)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)
credentials_strategy=posit_strategy,
)

@render.data_frame
def result():
Expand Down
9 changes: 5 additions & 4 deletions examples/connect/streamlit/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@

session_token = st.context.headers.get("Posit-Connect-User-Session-Token")
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=session_token)
local_strategy=databricks_cli, user_session_token=session_token
)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)
credentials_strategy=posit_strategy,
)

databricks_user = CurrentUserAPI(ApiClient(cfg)).me()
st.write(f"Hello, {databricks_user.display_name}!")
Expand All @@ -30,7 +31,7 @@
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg)
credentials_provider=posit_strategy.sql_credentials_provider(cfg),
) as connection:
with connection.cursor() as cursor:
cursor.execute("SELECT * FROM samples.nyctaxi.trips LIMIT 10;")
Expand Down
7 changes: 4 additions & 3 deletions src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# The Databricks SDK CredentialsProvider == Databricks SQL HeaderFactory
CredentialsProvider = Callable[[], Dict[str, str]]


class CredentialsStrategy(abc.ABC):
"""Maintain compatibility with the Databricks SQL/SDK client libraries.
Expand Down Expand Up @@ -52,11 +53,11 @@ def __call__(self) -> Dict[str, str]:


class PositCredentialsStrategy(CredentialsStrategy):

def __init__(self,
def __init__(
self,
local_strategy: CredentialsStrategy,
user_session_token: Optional[str] = None,
client: Optional[Client] = None
client: Optional[Client] = None,
):
self.user_session_token = user_session_token
self.local_strategy = local_strategy
Expand Down
24 changes: 16 additions & 8 deletions tests/posit/connect/external/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
class mock_strategy:
def auth_type(self) -> str:
return "local"

def __call__(self) -> CredentialsProvider:
def inner() -> Dict[str,str]:
def inner() -> Dict[str, str]:
return {"Authorization": "Bearer static-pat-token"}

return inner


Expand Down Expand Up @@ -45,7 +47,9 @@ def test_posit_credentials_provider(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
cp = PositCredentialsProvider(posit_oauth=client.oauth, user_session_token="cit")
cp = PositCredentialsProvider(
posit_oauth=client.oauth, user_session_token="cit"
)
assert cp() == {"Authorization": f"Bearer dynamic-viewer-access-token"}

@responses.activate
Expand All @@ -54,19 +58,23 @@ def test_posit_credentials_strategy(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
cs = PositCredentialsStrategy(local_strategy=mock_strategy(),
user_session_token="cit",
client=client)
cs = PositCredentialsStrategy(
local_strategy=mock_strategy(),
user_session_token="cit",
client=client,
)
cp = cs()
assert cs.auth_type() == "posit-oauth-integration"
assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"}

def test_posit_credentials_strategy_fallback(self):
# local_strategy is used when the content is running locally
client = Client(api_key="12345", url="https://connect.example/")
cs = PositCredentialsStrategy(local_strategy=mock_strategy(),
user_session_token="cit",
client=client)
cs = PositCredentialsStrategy(
local_strategy=mock_strategy(),
user_session_token="cit",
client=client,
)
cp = cs()
assert cs.auth_type() == "local"
assert cp() == {"Authorization": "Bearer static-pat-token"}

0 comments on commit cf4ca56

Please sign in to comment.