Skip to content

Commit

Permalink
refactor: Passing config dict to OAuthSettings constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Nov 14, 2024
1 parent 7b919d2 commit 4e91069
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,16 @@ class OAuth2Settings:

ALLOWED_WORKSPACES_KEY = "allowed_workspaces"
PROVIDERS_KEY = "providers"
EXTRA_BACKENDS_KEY = "extra_backends"

def __init__(
self,
allow_http_redirect: bool = False,
providers: List[OAuth2ClientProvider] = None,
allowed_workspaces: List[AllowedWorkspace] = None,
**kwargs, # Ignore any other key
**settings,
):
self.allow_http_redirect = allow_http_redirect
self.allowed_workspaces = allowed_workspaces or []
self._providers = providers or []
self.allowed_workspaces = self._build_workspaces(settings) or []
self._providers = self._build_providers(settings) or []

if self.allow_http_redirect:
# See https://stackoverflow.com/questions/27785375/testing-flask-oauthlib-locally-without-https
Expand All @@ -67,16 +66,7 @@ def from_yaml(cls, yaml_file: str) -> "OAuth2Settings":
"""Creates an instance of OAuth2Settings from a YAML file."""

with open(yaml_file) as f:
return cls.from_dict(yaml.safe_load(f))

@classmethod
def from_dict(cls, settings: dict) -> "OAuth2Settings":
"""Creates an instance of OAuth2Settings from a dictionary."""

settings[cls.PROVIDERS_KEY] = cls._build_providers(settings)
settings[cls.ALLOWED_WORKSPACES_KEY] = cls._build_workspaces(settings)

return cls(**settings)
return cls(**yaml.safe_load(f))

@classmethod
def _build_workspaces(cls, settings: dict) -> List[AllowedWorkspace]:
Expand Down
18 changes: 8 additions & 10 deletions argilla-server/tests/unit/api/handlers/v1/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,14 @@

@pytest.fixture
def default_oauth_settings() -> OAuth2Settings:
return OAuth2Settings.from_dict(
{
"providers": [
{
"name": "huggingface",
"client_id": "client_id",
"client_secret": "client_secret",
}
],
}
return OAuth2Settings(
providers=[
{
"name": "huggingface",
"client_id": "client_id",
"client_secret": "client_secret",
}
]
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,18 @@
class TestOAuth2Settings:
def test_configure_unsupported_provider(self):
with pytest.raises(NotFoundError):
OAuth2Settings.from_dict({"providers": [{"name": "unsupported"}]})
OAuth2Settings(providers=[{"name": "unsupported"}])

def test_configure_github_provider(self):
settings = OAuth2Settings.from_dict(
{
"providers": [
{
"name": "github",
"client_id": "github_client_id",
"client_secret": "github_client_secret",
"scope": "user:email",
}
]
}
settings = OAuth2Settings(
providers=[
{
"name": "github",
"client_id": "github_client_id",
"client_secret": "github_client_secret",
"scope": "user:email",
}
]
)
github_provider = settings.providers["github"]

Expand All @@ -44,17 +42,15 @@ def test_configure_github_provider(self):
assert github_provider.scope == ["user:email"]

def test_configure_huggingface_provider(self):
settings = OAuth2Settings.from_dict(
{
"providers": [
{
"name": "huggingface",
"client_id": "huggingface_client_id",
"client_secret": "huggingface_client_secret",
"scope": "openid profile email",
}
]
}
settings = OAuth2Settings(
providers=[
{
"name": "huggingface",
"client_id": "huggingface_client_id",
"client_secret": "huggingface_client_secret",
"scope": "openid profile email",
}
]
)
huggingface_provider = settings.providers["huggingface"]

Expand Down
8 changes: 2 additions & 6 deletions argilla-server/tests/unit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@ def test_server_timing_header(self):
async def test_create_allowed_workspaces(self, db: AsyncSession):
with mock.patch(
"argilla_server.security.settings.Settings.oauth",
new_callable=lambda: OAuth2Settings.from_dict(
{
"allowed_workspaces": [{"name": "ws1"}, {"name": "ws2"}],
}
),
new_callable=lambda: OAuth2Settings(allowed_workspaces=[{"name": "ws1"}, {"name": "ws2"}]),
):
await _create_oauth_allowed_workspaces(db)

Expand All @@ -102,7 +98,7 @@ async def test_create_workspaces_with_existing_workspaces(self, db: AsyncSession

with mock.patch(
"argilla_server.security.settings.Settings.oauth",
new_callable=lambda: OAuth2Settings(allowed_workspaces=[AllowedWorkspace(name=ws.name)]),
new_callable=lambda: OAuth2Settings(allowed_workspaces=[{"name": ws.name}]),
):
await _create_oauth_allowed_workspaces(db)

Expand Down

0 comments on commit 4e91069

Please sign in to comment.