Skip to content

Commit

Permalink
feat: add version check support (#299)
Browse files Browse the repository at this point in the history
Adds a new decorator, `@context.requires`, which asserts version
compatibility when the server version is known. The check is skipped if
the server version is unknown (e.g., the Connect configuration disables
version information).

Also marks the OAuth API with a '2024.08.0' requirement.

Closes #272
  • Loading branch information
tdstein authored Oct 7, 2024
1 parent 0614ae4 commit 0e9cf02
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 15 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = ["requests>=2.31.0,<3"]
dependencies = [
"requests>=2.31.0,<3",
"packaging"
]

[project.urls]
Source = "https://github.com/posit-dev/posit-sdk-py"
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
requests==2.32.2
packaging==24.1
11 changes: 7 additions & 4 deletions src/posit/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from __future__ import annotations

from typing import overload
from typing import Optional, overload

from requests import Response, Session

from . import hooks, me
from .auth import Auth
from .config import Config
from .content import Content
from .context import Context, ContextManager, requires
from .groups import Groups
from .metrics import Metrics
from .oauth import OAuth
Expand All @@ -18,7 +19,7 @@
from .users import User, Users


class Client:
class Client(ContextManager):
"""
Client connection for Posit Connect.
Expand Down Expand Up @@ -156,9 +157,10 @@ def __init__(self, *args, **kwargs) -> None:
session.hooks["response"].append(hooks.handle_errors)
self.session = session
self.resource_params = ResourceParameters(session, self.cfg.url)
self.ctx = Context(self.session, self.cfg.url)

@property
def version(self) -> str:
def version(self) -> Optional[str]:
"""
The server version.
Expand All @@ -167,7 +169,7 @@ def version(self) -> str:
str
The version of the Posit Connect server.
"""
return self.get("server_settings").json()["version"]
return self.ctx.version

@property
def me(self) -> User:
Expand Down Expand Up @@ -257,6 +259,7 @@ def metrics(self) -> Metrics:
return Metrics(self.resource_params)

@property
@requires(version="2024.08.0")
def oauth(self) -> OAuth:
"""
The OAuth API interface.
Expand Down
45 changes: 45 additions & 0 deletions src/posit/connect/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import functools
from typing import Optional, Protocol

from packaging.version import Version


def requires(version: str):
def decorator(func):
@functools.wraps(func)
def wrapper(instance: ContextManager, *args, **kwargs):
ctx = instance.ctx
if ctx.version and Version(ctx.version) < Version(version):
raise RuntimeError(
f"This API is not available in Connect version {ctx.version}. Please upgrade to version {version} or later.",
)
return func(instance, *args, **kwargs)

return wrapper

return decorator


class Context(dict):
def __init__(self, session, url):
self.session = session
self.url = url

@property
def version(self) -> Optional[str]:
try:
value = self["version"]
except KeyError:
endpoint = self.url + "server_settings"
response = self.session.get(endpoint)
result = response.json()
value = self["version"] = result.get("version")
return value

@version.setter
def version(self, value: str):
self["version"] = value


class ContextManager(Protocol):
ctx: Context
2 changes: 2 additions & 0 deletions tests/posit/connect/external/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_posit_credentials_provider(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
client.ctx.version = None
cp = PositCredentialsProvider(client=client, user_session_token="cit")
assert cp() == {"Authorization": f"Bearer dynamic-viewer-access-token"}

Expand All @@ -57,6 +58,7 @@ def test_posit_credentials_strategy(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
client.ctx.version = None
cs = PositCredentialsStrategy(
local_strategy=mock_strategy(),
user_session_token="cit",
Expand Down
2 changes: 2 additions & 0 deletions tests/posit/connect/external/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_posit_authenticator(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
client.ctx.version = None
auth = PositAuthenticator(
local_authenticator="SNOWFLAKE",
user_session_token="cit",
Expand All @@ -44,6 +45,7 @@ def test_posit_authenticator(self):
def test_posit_authenticator_fallback(self):
# local_authenticator is used when the content is running locally
client = Client(api_key="12345", url="https://connect.example/")
client.ctx.version = None
auth = PositAuthenticator(
local_authenticator="SNOWFLAKE",
user_session_token="cit",
Expand Down
4 changes: 4 additions & 0 deletions tests/posit/connect/oauth/test_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test(self):

# setup
c = Client("https://connect.example", "12345")
c.ctx.version = None
# invoke
associations = c.oauth.integrations.get(guid).associations.find()

Expand Down Expand Up @@ -83,6 +84,7 @@ def test(self):

# setup
c = Client("https://connect.example", "12345")
c.ctx.version = None
# invoke
associations = c.content.get(guid).oauth.associations.find()

Expand Down Expand Up @@ -115,6 +117,7 @@ def test(self):

# setup
c = Client("https://connect.example", "12345")
c.ctx.version = None

# invoke
c.content.get(guid).oauth.associations.update(new_integration_guid)
Expand Down Expand Up @@ -142,6 +145,7 @@ def test(self):

# setup
c = Client("https://connect.example", "12345")
c.ctx.version = None

# invoke
c.content.get(guid).oauth.associations.delete()
Expand Down
9 changes: 7 additions & 2 deletions tests/posit/connect/oauth/test_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test(self):

# setup
c = Client("https://connect.example", "12345")
c.ctx.version = None
integration = c.oauth.integrations.get(guid)

# invoke
Expand All @@ -93,6 +94,7 @@ def test(self):
)

c = Client("https://connect.example", "12345")
c.ctx.version = None
integration = c.oauth.integrations.get(guid)
assert integration.guid == guid

Expand Down Expand Up @@ -137,6 +139,7 @@ def test(self):

# setup
c = Client("https://connect.example", "12345")
c.ctx.version = None

# invoke
integration = c.oauth.integrations.create(
Expand Down Expand Up @@ -164,10 +167,11 @@ def test(self):
)

# setup
client = Client("https://connect.example", "12345")
c = Client("https://connect.example", "12345")
c.ctx.version = None

# invoke
integrations = client.oauth.integrations.find()
integrations = c.oauth.integrations.find()

# assert
assert mock_get.call_count == 1
Expand All @@ -189,6 +193,7 @@ def test(self):

# setup
c = Client("https://connect.example", "12345")
c.ctx.version = None
integration = c.oauth.integrations.get(guid)

assert mock_get.call_count == 1
Expand Down
5 changes: 3 additions & 2 deletions tests/posit/connect/oauth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ def test_get_credentials(self):
"token_type": "Bearer",
},
)
con = Client(api_key="12345", url="https://connect.example/")
assert con.oauth.get_credentials("cit")["access_token"] == "viewer-token"
c = Client(api_key="12345", url="https://connect.example/")
c.ctx.version = None
assert c.oauth.get_credentials("cit")["access_token"] == "viewer-token"
16 changes: 10 additions & 6 deletions tests/posit/connect/oauth/test_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test(self):

# setup
c = Client("https://connect.example", "12345")
c.ctx.version = None
session = c.oauth.sessions.get(guid)

# invoke
Expand All @@ -72,10 +73,11 @@ def test(self):
)

# setup
client = Client("https://connect.example", "12345")
c = Client("https://connect.example", "12345")
c.ctx.version = None

# invoke
sessions = client.oauth.sessions.find()
sessions = c.oauth.sessions.find()

# assert
assert mock_get.call_count == 1
Expand All @@ -94,10 +96,11 @@ def test_params_all(self):
)

# setup
client = Client("https://connect.example", "12345")
c = Client("https://connect.example", "12345")
c.ctx.version = None

# invoke
client.oauth.sessions.find(all=True)
c.oauth.sessions.find(all=True)

# assert
assert mock_get.call_count == 1
Expand All @@ -115,10 +118,11 @@ def test(self):
)

# setup
client = Client("https://connect.example", "12345")
c = Client("https://connect.example", "12345")
c.ctx.version = None

# invoke
session = client.oauth.sessions.get(guid=guid)
session = c.oauth.sessions.get(guid=guid)

# assert
assert mock_get.call_count == 1
Expand Down
10 changes: 10 additions & 0 deletions tests/posit/connect/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,13 @@ def test_delete(self, MockSession):
client = Client(api_key=api_key, url=url)
client.delete("/foo")
client.session.delete.assert_called_once_with("https://connect.example.com/__api__/foo")


class TestClientOAuth:
def test_required_version(self):
api_key = "12345"
url = "https://connect.example.com"
client = Client(api_key=api_key, url=url)
client.ctx.version = "2024.07.0"
with pytest.raises(RuntimeError):
client.oauth
90 changes: 90 additions & 0 deletions tests/posit/connect/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from email.contentmanager import ContentManager
from unittest.mock import MagicMock, Mock

import pytest
import requests
import responses

from posit.connect.context import Context, requires
from posit.connect.urls import Url


class TestRequires:
def test_version_unsupported(self):
class Stub(ContentManager):
def __init__(self, ctx):
self.ctx = ctx

@requires("1.0.0")
def fail(self):
pass

ctx = MagicMock()
ctx.version = "0.0.0"
instance = Stub(ctx)

with pytest.raises(RuntimeError):
instance.fail()

def test_version_supported(self):
class Stub(ContentManager):
def __init__(self, ctx):
self.ctx = ctx

@requires("1.0.0")
def success(self):
pass

ctx = MagicMock()
ctx.version = "1.0.0"
instance = Stub(ctx)

instance.success()

def test_version_missing(self):
class Stub(ContentManager):
def __init__(self, ctx):
self.ctx = ctx

@requires("1.0.0")
def success(self):
pass

ctx = MagicMock()
ctx.version = None
instance = Stub(ctx)

instance.success()


class TestContextVersion:
@responses.activate
def test_unknown(self):
responses.get(
f"http://connect.example/__api__/server_settings",
json={},
)

session = requests.Session()
url = Url("http://connect.example")
ctx = Context(session, url)

assert ctx.version is None

@responses.activate
def test_known(self):
responses.get(
f"http://connect.example/__api__/server_settings",
json={"version": "2024.09.24"},
)

session = requests.Session()
url = Url("http://connect.example")
ctx = Context(session, url)

assert ctx.version == "2024.09.24"

def test_setter(self):
ctx = Context(Mock(), Mock())
ctx.version = "2024.09.24"
assert ctx.version == "2024.09.24"

0 comments on commit 0e9cf02

Please sign in to comment.