Skip to content

Commit

Permalink
[ENH] CIP-2: Auth Providers Proposal (#986)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
	 - Auth Provide Client and Server Side Abstractions
	 - Basic Auth Provider

## Test plan
Unit tests for authorized endpoints

## Documentation Changes
Docs should change to describe how to use auth providers on the client
and server. CIP added in `docs/`
  • Loading branch information
tazarov authored Aug 23, 2023
1 parent 8a7f0ba commit 48700dd
Show file tree
Hide file tree
Showing 24 changed files with 1,458 additions and 43 deletions.
47 changes: 37 additions & 10 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import json
from typing import Optional, cast
from typing import Sequence
from uuid import UUID

import requests
from overrides import override

import chromadb.errors as errors
import chromadb.utils.embedding_functions as ef
from chromadb.api import API
from chromadb.config import Settings, System
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Documents,
Embeddings,
Expand All @@ -14,15 +23,13 @@
QueryResult,
CollectionMetadata,
)
import chromadb.utils.embedding_functions as ef
import requests
import json
from typing import Sequence
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID
from chromadb.auth import (
ClientAuthProvider,
)
from chromadb.auth.providers import RequestsClientAuthProtocolAdapter
from chromadb.auth.registry import resolve_provider
from chromadb.config import Settings, System
from chromadb.telemetry import Telemetry
from overrides import override


class FastAPI(API):
Expand All @@ -47,7 +54,27 @@ def __init__(self, system: System):
)

self._header = system.settings.chroma_server_headers
self._session = requests.Session()
if (
system.settings.chroma_client_auth_provider
and system.settings.chroma_client_auth_protocol_adapter
):
self._auth_provider = self.require(
resolve_provider(
system.settings.chroma_client_auth_provider, ClientAuthProvider
)
)
self._adapter = cast(
RequestsClientAuthProtocolAdapter,
system.require(
resolve_provider(
system.settings.chroma_client_auth_protocol_adapter,
RequestsClientAuthProtocolAdapter,
)
),
)
self._session = self._adapter.session
else:
self._session = requests.Session()
if self._header is not None:
self._session.headers.update(self._header)

Expand Down
207 changes: 207 additions & 0 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
Contains only Auth abstractions, no implementations.
"""
import base64
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Optional,
Dict,
TypeVar,
Tuple,
Generic,
)

from overrides import EnforceOverrides, override
from pydantic import SecretStr

from chromadb.config import (
Component,
System,
)
from chromadb.errors import ChromaError

logger = logging.getLogger(__name__)

T = TypeVar("T")
S = TypeVar("S")


class AuthInfoType(Enum):
COOKIE = "cookie"
HEADER = "header"
URL = "url"
METADATA = "metadata" # gRPC


class ClientAuthResponse(EnforceOverrides, ABC):
@abstractmethod
def get_auth_info_type(self) -> AuthInfoType:
...

@abstractmethod
def get_auth_info(self) -> Tuple[str, SecretStr]:
...


class ClientAuthProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authenticate(self) -> ClientAuthResponse:
pass


class ClientAuthConfigurationProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def get_configuration(self) -> Optional[T]:
pass


class ClientAuthCredentialsProvider(Component, Generic[T]):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def get_credentials(self) -> T:
pass


class ClientAuthProtocolAdapter(Component, Generic[T]):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def inject_credentials(self, injection_context: T) -> None:
pass


# SERVER-SIDE Abstractions


class ServerAuthenticationRequest(EnforceOverrides, ABC, Generic[T]):
@abstractmethod
def get_auth_info(
self, auth_info_type: AuthInfoType, auth_info_id: Optional[str] = None
) -> T:
"""
This method should return the necessary auth info based on the type of authentication (e.g. header, cookie, url)
and a given id for the respective auth type (e.g. name of the header, cookie, url param).
:param auth_info_type: The type of auth info to return
:param auth_info_id: The id of the auth info to return
:return: The auth info which can be specific to the implementation
"""
pass


class ServerAuthenticationResponse(EnforceOverrides, ABC):
def success(self) -> bool:
raise NotImplementedError()


class ServerAuthProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authenticate(self, request: ServerAuthenticationRequest[T]) -> bool:
pass


class ChromaAuthMiddleware(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authenticate(
self, request: ServerAuthenticationRequest[T]
) -> Optional[ServerAuthenticationResponse]:
...

@abstractmethod
def ignore_operation(self, verb: str, path: str) -> bool:
...

@abstractmethod
def instrument_server(self, app: T) -> None:
...


class ServerAuthConfigurationProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def get_configuration(self) -> Optional[T]:
pass


class AuthenticationError(ChromaError):
@override
def code(self) -> int:
return 401

@classmethod
@override
def name(cls) -> str:
return "AuthenticationError"


class AbstractCredentials(EnforceOverrides, ABC, Generic[T]):
"""
The class is used by Auth Providers to encapsulate credentials received from the server
and pass them to a ServerAuthCredentialsProvider.
"""

@abstractmethod
def get_credentials(self) -> Dict[str, T]:
"""
Returns the data encapsulated by the credentials object.
"""
pass


class SecretStrAbstractCredentials(AbstractCredentials[SecretStr]):
@abstractmethod
@override
def get_credentials(self) -> Dict[str, SecretStr]:
"""
Returns the data encapsulated by the credentials object.
"""
pass


class BasicAuthCredentials(SecretStrAbstractCredentials):
def __init__(self, username: SecretStr, password: SecretStr) -> None:
self.username = username
self.password = password

@override
def get_credentials(self) -> Dict[str, SecretStr]:
return {"username": self.username, "password": self.password}

@staticmethod
def from_header(header: str) -> "BasicAuthCredentials":
"""
Parses a basic auth header and returns a BasicAuthCredentials object.
"""
header = header.replace("Basic ", "")
header = header.strip()
base64_decoded = base64.b64decode(header).decode("utf-8")
username, password = base64_decoded.split(":")
return BasicAuthCredentials(SecretStr(username), SecretStr(password))


class ServerAuthCredentialsProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
pass
96 changes: 96 additions & 0 deletions chromadb/auth/basic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import base64
import logging
from typing import Tuple, Any, cast

from overrides import override
from pydantic import SecretStr

from chromadb.auth import (
ServerAuthProvider,
ClientAuthProvider,
ServerAuthenticationRequest,
ServerAuthCredentialsProvider,
AuthInfoType,
BasicAuthCredentials,
ClientAuthCredentialsProvider,
ClientAuthResponse,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.utils import get_class

logger = logging.getLogger(__name__)

__all__ = ["BasicAuthServerProvider", "BasicAuthClientProvider"]


class BasicAuthClientAuthResponse(ClientAuthResponse):
def __init__(self, credentials: SecretStr) -> None:
self._credentials = credentials

@override
def get_auth_info_type(self) -> AuthInfoType:
return AuthInfoType.HEADER

@override
def get_auth_info(self) -> Tuple[str, SecretStr]:
return "Authorization", SecretStr(
f"Basic {self._credentials.get_secret_value()}"
)


@register_provider("basic")
class BasicAuthClientProvider(ClientAuthProvider):
_credentials_provider: ClientAuthCredentialsProvider[Any]

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_client_auth_credentials_provider")
self._credentials_provider = system.require(
get_class(
str(system.settings.chroma_client_auth_credentials_provider),
ClientAuthCredentialsProvider,
)
)

@override
def authenticate(self) -> ClientAuthResponse:
_creds = self._credentials_provider.get_credentials()
return BasicAuthClientAuthResponse(
SecretStr(
base64.b64encode(f"{_creds.get_secret_value()}".encode("utf-8")).decode(
"utf-8"
)
)
)


@register_provider("basic")
class BasicAuthServerProvider(ServerAuthProvider):
_credentials_provider: ServerAuthCredentialsProvider

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_server_auth_credentials_provider")
self._credentials_provider = cast(
ServerAuthCredentialsProvider,
system.require(
resolve_provider(
str(system.settings.chroma_server_auth_credentials_provider),
ServerAuthCredentialsProvider,
)
),
)

@override
def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool:
try:
_auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization")
return self._credentials_provider.validate_credentials(
BasicAuthCredentials.from_header(_auth_header)
)
except Exception as e:
logger.error(f"BasicAuthServerProvider.authenticate failed: {repr(e)}")
return False
Loading

0 comments on commit 48700dd

Please sign in to comment.