-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] CIP-2: Auth Providers Proposal (#986)
## Description of changes *Summarize the changes made by this PR.* - New functionality - Auth Provide Client and Server Side Abstractions - Basic Auth Provider ## Test plan Unit tests for authorized endpoints ## Documentation Changes Docs should change to describe how to use auth providers on the client and server. CIP added in `docs/`
- Loading branch information
Showing
24 changed files
with
1,458 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
""" | ||
Contains only Auth abstractions, no implementations. | ||
""" | ||
import base64 | ||
import logging | ||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
from typing import ( | ||
Optional, | ||
Dict, | ||
TypeVar, | ||
Tuple, | ||
Generic, | ||
) | ||
|
||
from overrides import EnforceOverrides, override | ||
from pydantic import SecretStr | ||
|
||
from chromadb.config import ( | ||
Component, | ||
System, | ||
) | ||
from chromadb.errors import ChromaError | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
T = TypeVar("T") | ||
S = TypeVar("S") | ||
|
||
|
||
class AuthInfoType(Enum): | ||
COOKIE = "cookie" | ||
HEADER = "header" | ||
URL = "url" | ||
METADATA = "metadata" # gRPC | ||
|
||
|
||
class ClientAuthResponse(EnforceOverrides, ABC): | ||
@abstractmethod | ||
def get_auth_info_type(self) -> AuthInfoType: | ||
... | ||
|
||
@abstractmethod | ||
def get_auth_info(self) -> Tuple[str, SecretStr]: | ||
... | ||
|
||
|
||
class ClientAuthProvider(Component): | ||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
|
||
@abstractmethod | ||
def authenticate(self) -> ClientAuthResponse: | ||
pass | ||
|
||
|
||
class ClientAuthConfigurationProvider(Component): | ||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
|
||
@abstractmethod | ||
def get_configuration(self) -> Optional[T]: | ||
pass | ||
|
||
|
||
class ClientAuthCredentialsProvider(Component, Generic[T]): | ||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
|
||
@abstractmethod | ||
def get_credentials(self) -> T: | ||
pass | ||
|
||
|
||
class ClientAuthProtocolAdapter(Component, Generic[T]): | ||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
|
||
@abstractmethod | ||
def inject_credentials(self, injection_context: T) -> None: | ||
pass | ||
|
||
|
||
# SERVER-SIDE Abstractions | ||
|
||
|
||
class ServerAuthenticationRequest(EnforceOverrides, ABC, Generic[T]): | ||
@abstractmethod | ||
def get_auth_info( | ||
self, auth_info_type: AuthInfoType, auth_info_id: Optional[str] = None | ||
) -> T: | ||
""" | ||
This method should return the necessary auth info based on the type of authentication (e.g. header, cookie, url) | ||
and a given id for the respective auth type (e.g. name of the header, cookie, url param). | ||
:param auth_info_type: The type of auth info to return | ||
:param auth_info_id: The id of the auth info to return | ||
:return: The auth info which can be specific to the implementation | ||
""" | ||
pass | ||
|
||
|
||
class ServerAuthenticationResponse(EnforceOverrides, ABC): | ||
def success(self) -> bool: | ||
raise NotImplementedError() | ||
|
||
|
||
class ServerAuthProvider(Component): | ||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
|
||
@abstractmethod | ||
def authenticate(self, request: ServerAuthenticationRequest[T]) -> bool: | ||
pass | ||
|
||
|
||
class ChromaAuthMiddleware(Component): | ||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
|
||
@abstractmethod | ||
def authenticate( | ||
self, request: ServerAuthenticationRequest[T] | ||
) -> Optional[ServerAuthenticationResponse]: | ||
... | ||
|
||
@abstractmethod | ||
def ignore_operation(self, verb: str, path: str) -> bool: | ||
... | ||
|
||
@abstractmethod | ||
def instrument_server(self, app: T) -> None: | ||
... | ||
|
||
|
||
class ServerAuthConfigurationProvider(Component): | ||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
|
||
@abstractmethod | ||
def get_configuration(self) -> Optional[T]: | ||
pass | ||
|
||
|
||
class AuthenticationError(ChromaError): | ||
@override | ||
def code(self) -> int: | ||
return 401 | ||
|
||
@classmethod | ||
@override | ||
def name(cls) -> str: | ||
return "AuthenticationError" | ||
|
||
|
||
class AbstractCredentials(EnforceOverrides, ABC, Generic[T]): | ||
""" | ||
The class is used by Auth Providers to encapsulate credentials received from the server | ||
and pass them to a ServerAuthCredentialsProvider. | ||
""" | ||
|
||
@abstractmethod | ||
def get_credentials(self) -> Dict[str, T]: | ||
""" | ||
Returns the data encapsulated by the credentials object. | ||
""" | ||
pass | ||
|
||
|
||
class SecretStrAbstractCredentials(AbstractCredentials[SecretStr]): | ||
@abstractmethod | ||
@override | ||
def get_credentials(self) -> Dict[str, SecretStr]: | ||
""" | ||
Returns the data encapsulated by the credentials object. | ||
""" | ||
pass | ||
|
||
|
||
class BasicAuthCredentials(SecretStrAbstractCredentials): | ||
def __init__(self, username: SecretStr, password: SecretStr) -> None: | ||
self.username = username | ||
self.password = password | ||
|
||
@override | ||
def get_credentials(self) -> Dict[str, SecretStr]: | ||
return {"username": self.username, "password": self.password} | ||
|
||
@staticmethod | ||
def from_header(header: str) -> "BasicAuthCredentials": | ||
""" | ||
Parses a basic auth header and returns a BasicAuthCredentials object. | ||
""" | ||
header = header.replace("Basic ", "") | ||
header = header.strip() | ||
base64_decoded = base64.b64decode(header).decode("utf-8") | ||
username, password = base64_decoded.split(":") | ||
return BasicAuthCredentials(SecretStr(username), SecretStr(password)) | ||
|
||
|
||
class ServerAuthCredentialsProvider(Component): | ||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
|
||
@abstractmethod | ||
def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import base64 | ||
import logging | ||
from typing import Tuple, Any, cast | ||
|
||
from overrides import override | ||
from pydantic import SecretStr | ||
|
||
from chromadb.auth import ( | ||
ServerAuthProvider, | ||
ClientAuthProvider, | ||
ServerAuthenticationRequest, | ||
ServerAuthCredentialsProvider, | ||
AuthInfoType, | ||
BasicAuthCredentials, | ||
ClientAuthCredentialsProvider, | ||
ClientAuthResponse, | ||
) | ||
from chromadb.auth.registry import register_provider, resolve_provider | ||
from chromadb.config import System | ||
from chromadb.utils import get_class | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
__all__ = ["BasicAuthServerProvider", "BasicAuthClientProvider"] | ||
|
||
|
||
class BasicAuthClientAuthResponse(ClientAuthResponse): | ||
def __init__(self, credentials: SecretStr) -> None: | ||
self._credentials = credentials | ||
|
||
@override | ||
def get_auth_info_type(self) -> AuthInfoType: | ||
return AuthInfoType.HEADER | ||
|
||
@override | ||
def get_auth_info(self) -> Tuple[str, SecretStr]: | ||
return "Authorization", SecretStr( | ||
f"Basic {self._credentials.get_secret_value()}" | ||
) | ||
|
||
|
||
@register_provider("basic") | ||
class BasicAuthClientProvider(ClientAuthProvider): | ||
_credentials_provider: ClientAuthCredentialsProvider[Any] | ||
|
||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
self._settings = system.settings | ||
system.settings.require("chroma_client_auth_credentials_provider") | ||
self._credentials_provider = system.require( | ||
get_class( | ||
str(system.settings.chroma_client_auth_credentials_provider), | ||
ClientAuthCredentialsProvider, | ||
) | ||
) | ||
|
||
@override | ||
def authenticate(self) -> ClientAuthResponse: | ||
_creds = self._credentials_provider.get_credentials() | ||
return BasicAuthClientAuthResponse( | ||
SecretStr( | ||
base64.b64encode(f"{_creds.get_secret_value()}".encode("utf-8")).decode( | ||
"utf-8" | ||
) | ||
) | ||
) | ||
|
||
|
||
@register_provider("basic") | ||
class BasicAuthServerProvider(ServerAuthProvider): | ||
_credentials_provider: ServerAuthCredentialsProvider | ||
|
||
def __init__(self, system: System) -> None: | ||
super().__init__(system) | ||
self._settings = system.settings | ||
system.settings.require("chroma_server_auth_credentials_provider") | ||
self._credentials_provider = cast( | ||
ServerAuthCredentialsProvider, | ||
system.require( | ||
resolve_provider( | ||
str(system.settings.chroma_server_auth_credentials_provider), | ||
ServerAuthCredentialsProvider, | ||
) | ||
), | ||
) | ||
|
||
@override | ||
def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool: | ||
try: | ||
_auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization") | ||
return self._credentials_provider.validate_credentials( | ||
BasicAuthCredentials.from_header(_auth_header) | ||
) | ||
except Exception as e: | ||
logger.error(f"BasicAuthServerProvider.authenticate failed: {repr(e)}") | ||
return False |
Oops, something went wrong.