Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] CIP-2: Auth Providers Proposal #986

Merged
merged 15 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import json
from typing import Optional, cast
from typing import Sequence
from uuid import UUID

import requests
from overrides import override

import chromadb.errors as errors
import chromadb.utils.embedding_functions as ef
from chromadb.api import API
from chromadb.config import Settings, System
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Documents,
Embeddings,
Expand All @@ -14,15 +23,13 @@
QueryResult,
CollectionMetadata,
)
import chromadb.utils.embedding_functions as ef
import requests
import json
from typing import Sequence
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID
from chromadb.auth import (
ClientAuthProvider,
)
from chromadb.auth.providers import RequestsClientAuthProtocolAdapter
from chromadb.auth.registry import resolve_provider
from chromadb.config import Settings, System
from chromadb.telemetry import Telemetry
from overrides import override


class FastAPI(API):
Expand All @@ -47,7 +54,27 @@ def __init__(self, system: System):
)

self._header = system.settings.chroma_server_headers
self._session = requests.Session()
if (
system.settings.chroma_client_auth_provider
and system.settings.chroma_client_auth_protocol_adapter
):
self._auth_provider = self.require(
resolve_provider(
system.settings.chroma_client_auth_provider, ClientAuthProvider
)
)
self._adapter = cast(
RequestsClientAuthProtocolAdapter,
system.require(
resolve_provider(
system.settings.chroma_client_auth_protocol_adapter,
RequestsClientAuthProtocolAdapter,
)
),
)
self._session = self._adapter.session
else:
self._session = requests.Session()
if self._header is not None:
self._session.headers.update(self._header)

Expand Down
207 changes: 207 additions & 0 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
Contains only Auth abstractions, no implementations.
"""
import base64
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Optional,
Dict,
TypeVar,
Tuple,
Generic,
)

from overrides import EnforceOverrides, override
from pydantic import SecretStr

from chromadb.config import (
Component,
System,
)
from chromadb.errors import ChromaError

logger = logging.getLogger(__name__)

T = TypeVar("T")
S = TypeVar("S")


class AuthInfoType(Enum):
COOKIE = "cookie"
HEADER = "header"
URL = "url"
METADATA = "metadata" # gRPC


class ClientAuthResponse(EnforceOverrides, ABC):
@abstractmethod
def get_auth_info_type(self) -> AuthInfoType:
...

@abstractmethod
def get_auth_info(self) -> Tuple[str, SecretStr]:
...


class ClientAuthProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authenticate(self) -> ClientAuthResponse:
pass


class ClientAuthConfigurationProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def get_configuration(self) -> Optional[T]:
pass


class ClientAuthCredentialsProvider(Component, Generic[T]):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def get_credentials(self) -> T:
pass


class ClientAuthProtocolAdapter(Component, Generic[T]):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def inject_credentials(self, injection_context: T) -> None:
pass


# SERVER-SIDE Abstractions


class ServerAuthenticationRequest(EnforceOverrides, ABC, Generic[T]):
@abstractmethod
def get_auth_info(
self, auth_info_type: AuthInfoType, auth_info_id: Optional[str] = None
) -> T:
"""
This method should return the necessary auth info based on the type of authentication (e.g. header, cookie, url)
and a given id for the respective auth type (e.g. name of the header, cookie, url param).

:param auth_info_type: The type of auth info to return
:param auth_info_id: The id of the auth info to return
:return: The auth info which can be specific to the implementation
"""
pass


class ServerAuthenticationResponse(EnforceOverrides, ABC):
def success(self) -> bool:
raise NotImplementedError()


class ServerAuthProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authenticate(self, request: ServerAuthenticationRequest[T]) -> bool:
pass


class ChromaAuthMiddleware(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authenticate(
self, request: ServerAuthenticationRequest[T]
) -> Optional[ServerAuthenticationResponse]:
...

@abstractmethod
def ignore_operation(self, verb: str, path: str) -> bool:
...

@abstractmethod
def instrument_server(self, app: T) -> None:
...


class ServerAuthConfigurationProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def get_configuration(self) -> Optional[T]:
pass


class AuthenticationError(ChromaError):
@override
def code(self) -> int:
return 401

@classmethod
@override
def name(cls) -> str:
return "AuthenticationError"


class AbstractCredentials(EnforceOverrides, ABC, Generic[T]):
"""
The class is used by Auth Providers to encapsulate credentials received from the server
and pass them to a ServerAuthCredentialsProvider.
"""

@abstractmethod
def get_credentials(self) -> Dict[str, T]:
"""
Returns the data encapsulated by the credentials object.
"""
pass


class SecretStrAbstractCredentials(AbstractCredentials[SecretStr]):
@abstractmethod
@override
def get_credentials(self) -> Dict[str, SecretStr]:
"""
Returns the data encapsulated by the credentials object.
"""
pass


class BasicAuthCredentials(SecretStrAbstractCredentials):
def __init__(self, username: SecretStr, password: SecretStr) -> None:
self.username = username
self.password = password

@override
def get_credentials(self) -> Dict[str, SecretStr]:
return {"username": self.username, "password": self.password}

@staticmethod
def from_header(header: str) -> "BasicAuthCredentials":
"""
Parses a basic auth header and returns a BasicAuthCredentials object.
"""
header = header.replace("Basic ", "")
header = header.strip()
base64_decoded = base64.b64decode(header).decode("utf-8")
username, password = base64_decoded.split(":")
return BasicAuthCredentials(SecretStr(username), SecretStr(password))


class ServerAuthCredentialsProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
pass
106 changes: 106 additions & 0 deletions chromadb/auth/basic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import base64
import logging
from typing import Tuple, Any, cast

from overrides import override
from pydantic import SecretStr

from chromadb.auth import (
ServerAuthProvider,
ClientAuthProvider,
ServerAuthenticationRequest,
ServerAuthCredentialsProvider,
AuthInfoType,
BasicAuthCredentials,
ClientAuthCredentialsProvider,
ClientAuthProtocolAdapter,
ClientAuthResponse,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.utils import get_class

logger = logging.getLogger(__name__)

__all__ = ["BasicAuthServerProvider", "BasicAuthClientProvider"]


def _encode_credentials(username: str, password: str) -> SecretStr:
tazarov marked this conversation as resolved.
Show resolved Hide resolved
return SecretStr(
base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("utf-8")
)


class BasicAuthClientAuthResponse(ClientAuthResponse):
def __init__(self, credentials: SecretStr) -> None:
self._credentials = credentials

@override
def get_auth_info_type(self) -> AuthInfoType:
return AuthInfoType.HEADER

@override
def get_auth_info(self) -> Tuple[str, SecretStr]:
return "Authorization", SecretStr(
f"Basic {self._credentials.get_secret_value()}"
)


@register_provider("basic")
class BasicAuthClientProvider(ClientAuthProvider):
_credentials_provider: ClientAuthCredentialsProvider[Any]
_protocol_adapter: ClientAuthProtocolAdapter[Any]
tazarov marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_client_auth_credentials_provider")
self._credentials_provider = system.require(
get_class(
str(system.settings.chroma_client_auth_credentials_provider),
ClientAuthCredentialsProvider,
)
)

@override
def authenticate(self) -> ClientAuthResponse:
_creds = self._credentials_provider.get_credentials()
return BasicAuthClientAuthResponse(
SecretStr(
base64.b64encode(f"{_creds.get_secret_value()}".encode("utf-8")).decode(
"utf-8"
)
)
)


@register_provider("basic")
class BasicAuthServerProvider(ServerAuthProvider):
_credentials_provider: ServerAuthCredentialsProvider

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_server_auth_credentials_provider")
self._credentials_provider = cast(
ServerAuthCredentialsProvider,
system.require(
resolve_provider(
str(system.settings.chroma_server_auth_credentials_provider),
ServerAuthCredentialsProvider,
)
),
)

@override
def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool:
try:
# print(f"BasicAuthServerProvider.authenticate: {}")
tazarov marked this conversation as resolved.
Show resolved Hide resolved
_auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization")
return self._credentials_provider.validate_credentials(
BasicAuthCredentials.from_header(_auth_header)
)
except Exception as e:
logger.error(f"BasicAuthServerProvider.authenticate failed: {repr(e)}")
return False
# raise AuthenticationError()
tazarov marked this conversation as resolved.
Show resolved Hide resolved
Loading