Skip to content

Commit

Permalink
[BUG] Make sure Client parameters are strings (chroma-core#1577)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Stringify all paremeters to `Client`s which are meant to be strings.
At present some parameters -- `port` in particular -- can be reasonably
passed as integers which causes weird and unexpected behavior.
	 - Fixes chroma-core#1573

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
beggers authored Feb 20, 2024
1 parent f96be93 commit 8a0f67e
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 18 deletions.
32 changes: 29 additions & 3 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def EphemeralClient(
settings = Settings()
settings.is_persistent = False

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)

return ClientCreator(settings=settings, tenant=tenant, database=database)


Expand All @@ -135,12 +139,16 @@ def PersistentClient(
settings.persist_directory = path
settings.is_persistent = True

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)

return ClientCreator(tenant=tenant, database=database, settings=settings)


def HttpClient(
host: str = "localhost",
port: str = "8000",
port: int = 8000,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
settings: Optional[Settings] = None,
Expand All @@ -165,6 +173,13 @@ def HttpClient(
if settings is None:
settings = Settings()

# Make sure paramaters are the correct types -- users can pass anything.
host = str(host)
port = int(port)
ssl = bool(ssl)
tenant = str(tenant)
database = str(database)

settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
if settings.chroma_server_host and settings.chroma_server_host != host:
raise ValueError(
Expand All @@ -189,7 +204,7 @@ def CloudClient(
settings: Optional[Settings] = None,
*, # Following arguments are keyword-only, intended for testing only.
cloud_host: str = "api.trychroma.com",
cloud_port: str = "8000",
cloud_port: int = 8000,
enable_ssl: bool = True,
) -> ClientAPI:
"""
Expand Down Expand Up @@ -217,6 +232,14 @@ def CloudClient(
if settings is None:
settings = Settings()

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)
api_key = str(api_key)
cloud_host = str(cloud_host)
cloud_port = int(cloud_port)
enable_ssl = bool(enable_ssl)

settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
settings.chroma_server_host = cloud_host
settings.chroma_server_http_port = cloud_port
Expand All @@ -242,9 +265,12 @@ def Client(
tenant: The tenant to use for this client. Defaults to the default tenant.
database: The database to use for this client. Defaults to the default database.
"""

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)

return ClientCreator(tenant=tenant, database=database, settings=settings)


Expand Down
2 changes: 1 addition & 1 deletion chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(self, system: System):

self._api_url = FastAPI.resolve_url(
chroma_server_host=str(system.settings.chroma_server_host),
chroma_server_http_port=int(str(system.settings.chroma_server_http_port)),
chroma_server_http_port=system.settings.chroma_server_http_port,
chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
default_api_path=system.settings.chroma_server_api_default_path,
)
Expand Down
8 changes: 4 additions & 4 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ class Settings(BaseSettings): # type: ignore

chroma_server_host: Optional[str] = None
chroma_server_headers: Optional[Dict[str, str]] = None
chroma_server_http_port: Optional[str] = None
chroma_server_http_port: Optional[int] = None
chroma_server_ssl_enabled: Optional[bool] = False
# the below config value is only applicable to Chroma HTTP clients
chroma_server_ssl_verify: Optional[Union[bool, str]] = None
chroma_server_api_default_path: Optional[str] = "/api/v1"
chroma_server_grpc_port: Optional[str] = None
chroma_server_grpc_port: Optional[int] = None
# eg ["http://localhost:3000"]
chroma_server_cors_allow_origins: List[str] = []

Expand All @@ -141,8 +141,8 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
chroma_server_nofile: Optional[int] = None

pulsar_broker_url: Optional[str] = None
pulsar_admin_port: Optional[str] = "8080"
pulsar_broker_port: Optional[str] = "6650"
pulsar_admin_port: Optional[int] = 8080
pulsar_broker_port: Optional[int] = 6650

chroma_server_auth_provider: Optional[str] = None

Expand Down
6 changes: 3 additions & 3 deletions chromadb/test/client/test_cloud_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def mock_cloud_server(valid_token: str) -> Generator[System, None, None]:
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host=TEST_CLOUD_HOST,
chroma_server_http_port=str(port),
chroma_server_http_port=port,
chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider",
chroma_client_auth_credentials=valid_token,
chroma_client_auth_token_transport_header=TOKEN_TRANSPORT_HEADER,
Expand All @@ -82,7 +82,7 @@ def test_valid_key(mock_cloud_server: System, valid_token: str) -> None:
database=DEFAULT_DATABASE,
api_key=valid_token,
cloud_host=TEST_CLOUD_HOST,
cloud_port=mock_cloud_server.settings.chroma_server_http_port, # type: ignore
cloud_port=mock_cloud_server.settings.chroma_server_http_port or 8000,
enable_ssl=False,
)

Expand All @@ -98,7 +98,7 @@ def test_invalid_key(mock_cloud_server: System, valid_token: str) -> None:
database=DEFAULT_DATABASE,
api_key=invalid_token,
cloud_host=TEST_CLOUD_HOST,
cloud_port=mock_cloud_server.settings.chroma_server_http_port, # type: ignore
cloud_port=mock_cloud_server.settings.chroma_server_http_port or 8000,
enable_ssl=False,
)
client.heartbeat()
4 changes: 2 additions & 2 deletions chromadb/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _fastapi_fixture(
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host="localhost",
chroma_server_http_port=str(port),
chroma_server_http_port=port,
allow_reset=True,
chroma_client_auth_provider=chroma_client_auth_provider,
chroma_client_auth_credentials=chroma_client_auth_credentials,
Expand Down Expand Up @@ -286,7 +286,7 @@ def fastapi_ssl() -> Generator[System, None, None]:
def basic_http_client() -> Generator[System, None, None]:
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_http_port="8000",
chroma_server_http_port=8000,
allow_reset=True,
)
system = System(settings)
Expand Down
6 changes: 3 additions & 3 deletions chromadb/test/test_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_fastapi(self, mock: Mock) -> None:
chroma_api_impl="chromadb.api.fastapi.FastAPI",
persist_directory="./foo",
chroma_server_host="foo",
chroma_server_http_port="80",
chroma_server_http_port=80,
)
)
assert mock.called
Expand All @@ -78,7 +78,7 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None:
settings = chromadb.config.Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host="foo",
chroma_server_http_port="80",
chroma_server_http_port=80,
chroma_server_headers={"foo": "bar"},
)
client = chromadb.Client(settings)
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_legacy_values() -> None:
chroma_api_impl="chromadb.api.local.LocalAPI",
persist_directory="./foo",
chroma_server_host="foo",
chroma_server_http_port="80",
chroma_server_http_port=80,
)
)
client.clear_system_cache()
4 changes: 2 additions & 2 deletions chromadb/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def test_http_client_with_inconsistent_host_settings() -> None:
def test_http_client_with_inconsistent_port_settings() -> None:
try:
chromadb.HttpClient(
port="8002",
port=8002,
settings=Settings(
chroma_server_http_port="8001",
chroma_server_http_port=8001,
),
)
except ValueError as e:
Expand Down

0 comments on commit 8a0f67e

Please sign in to comment.