Skip to content

Commit

Permalink
Merge branch 'main' into feature/orjson-parser-async-endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov authored Feb 14, 2024
2 parents 98fdfc1 + da68516 commit a1ddd69
Show file tree
Hide file tree
Showing 30 changed files with 759 additions and 29 deletions.
9 changes: 9 additions & 0 deletions DEVELOP.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ api = chromadb.HttpClient(host="localhost", port="8000")

print(api.heartbeat())
```
## Local dev setup for distributed chroma
We use Tilt to provide the local dev setup. Tilt is an open-source tool for local Kubernetes development.
##### Requirements
- Docker
- Local Kubernetes cluster (Recommended: [OrbStack](https://orbstack.dev/) for macOS, [Kind](https://kind.sigs.k8s.io/) for Linux)
- [Tilt](https://docs.tilt.dev/)

To start distributed Chroma in the workspace, run `tilt up`. It will create all the required resources and build the necessary Docker images in the current kubectl context.
Once done, it will expose Chroma on port 8000. You can also visit the Tilt dashboard UI at http://localhost:10350/. To clean and remove all the resources created by Tilt, use `tilt down`.

## Testing

Expand Down
11 changes: 10 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ COPY --from=builder /install /usr/local
COPY ./bin/docker_entrypoint.sh /docker_entrypoint.sh
COPY ./ /chroma

# Make the entrypoint script executable inside the image.
RUN chmod +x /docker_entrypoint.sh

# Server configuration, overridable at `docker run -e ...` time.
# NOTE(review): these are consumed by the CMD string below, which the
# entrypoint script expands with `eval` — Docker itself does not expand
# env vars inside an exec-form CMD.
ENV CHROMA_HOST_ADDR "0.0.0.0"
ENV CHROMA_HOST_PORT 8000
ENV CHROMA_WORKERS 1
ENV CHROMA_LOG_CONFIG "chromadb/log_config.yml"
ENV CHROMA_TIMEOUT_KEEP_ALIVE 30

EXPOSE 8000

# NOTE(review): this page is a rendered diff with +/- markers stripped —
# the line below appears to be the REMOVED pre-change CMD; in the final
# Dockerfile only the ENTRYPOINT + CMD pair that follows should remain.
# TODO confirm against the actual committed file.
CMD ["/docker_entrypoint.sh"]
# Entrypoint handles startup; CMD supplies default uvicorn flags as a single
# string that the entrypoint script eval-expands (see bin/docker_entrypoint.sh).
ENTRYPOINT ["/docker_entrypoint.sh"]
CMD [ "--workers ${CHROMA_WORKERS} --host ${CHROMA_HOST_ADDR} --port ${CHROMA_HOST_PORT} --proxy-headers --log-config ${CHROMA_LOG_CONFIG} --timeout-keep-alive ${CHROMA_TIMEOUT_KEEP_ALIVE}"]
30 changes: 30 additions & 0 deletions Tiltfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Tiltfile: local distributed-Chroma dev environment.
# Builds the three service images, then applies the k8s manifests in
# dependency order (setup -> pulsar/server -> coordinator -> worker).

# Go coordinator image.
docker_build('coordinator',
context='.',
dockerfile='./go/coordinator/Dockerfile'
)

# Python API server image (repo-root Dockerfile).
docker_build('server',
context='.',
dockerfile='./Dockerfile',
)

# Rust worker image.
docker_build('worker',
context='.',
dockerfile='./rust/worker/Dockerfile'
)


# Cluster-level prerequisites: namespace, RBAC, the MemberList CRD and the
# worker-memberlist instance, grouped under one "k8s_setup" Tilt resource.
k8s_yaml(['k8s/dev/setup.yaml'])
k8s_resource(
objects=['chroma:Namespace', 'memberlist-reader:ClusterRole', 'memberlist-reader:ClusterRoleBinding', 'pod-list-role:Role', 'pod-list-role-binding:RoleBinding', 'memberlists.chroma.cluster:CustomResourceDefinition','worker-memberlist:MemberList'],
new_name='k8s_setup',
labels=["infrastructure"]
)
# Pulsar message broker; waits for the setup resources above.
k8s_yaml(['k8s/dev/pulsar.yaml'])
k8s_resource('pulsar', resource_deps=['k8s_setup'], labels=["infrastructure"])
# API server, forwarded to localhost:8000.
k8s_yaml(['k8s/dev/server.yaml'])
k8s_resource('server', resource_deps=['k8s_setup'],labels=["chroma"], port_forwards=8000 )
# Coordinator starts after pulsar and the server are up.
k8s_yaml(['k8s/dev/coordinator.yaml'])
k8s_resource('coordinator', resource_deps=['pulsar', 'server'], labels=["chroma"])
# Worker starts last, after the coordinator.
k8s_yaml(['k8s/dev/worker.yaml'])
k8s_resource('worker', resource_deps=['coordinator'],labels=["chroma"])
12 changes: 11 additions & 1 deletion bin/docker_entrypoint.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
#!/bin/bash
# Container entrypoint: starts the Chroma uvicorn server, accepting either
# bare uvicorn flags (normal case, supplied by the Dockerfile CMD) or a full
# legacy "uvicorn chromadb.app:app ..." command line (deprecated).
set -e

export IS_PERSISTENT=1
export CHROMA_SERVER_NOFILE=65535
# NOTE(review): this page is a rendered diff with +/- markers stripped — the
# unconditional exec below appears to be the REMOVED pre-change line; in the
# final script only the arg-dispatch logic that follows should remain.
# TODO confirm against the actual committed file.
exec uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --proxy-headers --log-config chromadb/log_config.yml --timeout-keep-alive 30
# Collapse all CLI args into one string; `eval echo` below expands any
# ${CHROMA_*} placeholders passed through from the Dockerfile CMD.
args="$@"

if [[ $args =~ ^uvicorn.* ]]; then
# Legacy invocation: caller already specified the full uvicorn command.
echo "Starting server with args: $(eval echo "$args")"
# Red-colored deprecation warning.
echo -e "\033[31mWARNING: Please remove 'uvicorn chromadb.app:app' from your command line arguments. This is now handled by the entrypoint script."
exec $(eval echo "$args")
else
# Normal invocation: args are just flags; prepend the uvicorn app target.
echo "Starting 'uvicorn chromadb.app:app' with args: $(eval echo "$args")"
exec uvicorn chromadb.app:app $(eval echo "$args")
fi
34 changes: 18 additions & 16 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json
import orjson as json
import logging
from typing import Optional, cast, Tuple
from typing import Sequence
Expand Down Expand Up @@ -138,14 +138,16 @@ def __init__(self, system: System):
self._session = requests.Session()
if self._header is not None:
self._session.headers.update(self._header)
if self._settings.chroma_server_ssl_verify is not None:
self._session.verify = self._settings.chroma_server_ssl_verify

@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
@override
def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive"""
resp = self._session.get(self._api_url)
raise_chroma_error(resp)
# NOTE(review): rendered diff shows both lines; the first (resp.json()) is
# the removed pre-change line, the second is the new orjson-based parse.
return int(resp.json()["nanosecond heartbeat"])
return int(json.loads(resp.text)["nanosecond heartbeat"])

@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -175,7 +177,7 @@ def get_database(
params={"tenant": tenant},
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Database(
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
)
Expand All @@ -196,7 +198,7 @@ def get_tenant(self, name: str) -> Tenant:
self._api_url + "/tenants/" + name,
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Tenant(name=resp_json["name"])

@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
Expand All @@ -219,7 +221,7 @@ def list_collections(
},
)
raise_chroma_error(resp)
json_collections = resp.json()
json_collections = json.loads(resp.text)
collections = []
for json_collection in json_collections:
collections.append(Collection(self, **json_collection))
Expand All @@ -237,7 +239,7 @@ def count_collections(
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
return cast(int, resp.json())
return cast(int, json.loads(resp.text))

@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -266,7 +268,7 @@ def create_collection(
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Collection(
client=self,
id=resp_json["id"],
Expand Down Expand Up @@ -300,7 +302,7 @@ def get_collection(
self._api_url + "/collections/" + name if name else str(id), params=_params
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Collection(
client=self,
name=resp_json["name"],
Expand Down Expand Up @@ -379,7 +381,7 @@ def _count(
self._api_url + "/collections/" + str(collection_id) + "/count"
)
raise_chroma_error(resp)
return cast(int, resp.json())
return cast(int, json.loads(resp.text))

@trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -432,7 +434,7 @@ def _get(
)

raise_chroma_error(resp)
body = resp.json()
body = json.loads(resp.text)
return GetResult(
ids=body["ids"],
embeddings=body.get("embeddings", None),
Expand Down Expand Up @@ -460,7 +462,7 @@ def _delete(
)

raise_chroma_error(resp)
return cast(IDs, resp.json())
return cast(IDs, json.loads(resp.text))

@trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
def _submit_batch(
Expand Down Expand Up @@ -584,7 +586,7 @@ def _query(
)

raise_chroma_error(resp)
body = resp.json()
body = json.loads(resp.text)

return QueryResult(
ids=body["ids"],
Expand All @@ -602,15 +604,15 @@ def reset(self) -> bool:
"""Resets the database"""
resp = self._session.post(self._api_url + "/reset")
raise_chroma_error(resp)
return cast(bool, resp.json())
return cast(bool, json.loads(resp.text))

@trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
@override
def get_version(self) -> str:
"""Returns the version of the server"""
resp = self._session.get(self._api_url + "/version")
raise_chroma_error(resp)
# NOTE(review): rendered diff shows both lines; the first (resp.json()) is
# the removed pre-change line, the second is the new orjson-based parse.
return cast(str, resp.json())
return cast(str, json.loads(resp.text))

@override
def get_settings(self) -> Settings:
Expand All @@ -624,7 +626,7 @@ def max_batch_size(self) -> int:
if self._max_batch_size == -1:
resp = self._session.get(self._api_url + "/pre-flight-checks")
raise_chroma_error(resp)
self._max_batch_size = cast(int, resp.json()["max_batch_size"])
self._max_batch_size = cast(int, json.loads(resp.text)["max_batch_size"])
return self._max_batch_size


Expand All @@ -635,7 +637,7 @@ def raise_chroma_error(resp: requests.Response) -> None:

chroma_error = None
try:
body = resp.json()
body = json.loads(resp.text)
if "error" in body:
if body["error"] in errors.error_types:
chroma_error = errors.error_types[body["error"]](body["message"])
Expand Down
12 changes: 10 additions & 2 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

# Re-export types from chromadb.types
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]

META_KEY_CHROMA_DOCUMENT = "chroma:document"
T = TypeVar("T")
OneOrMany = Union[T, List[T]]

Expand Down Expand Up @@ -265,6 +265,10 @@ def validate_metadata(metadata: Metadata) -> Metadata:
if len(metadata) == 0:
raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
for key, value in metadata.items():
if key == META_KEY_CHROMA_DOCUMENT:
raise ValueError(
f"Expected metadata to not contain the reserved key {META_KEY_CHROMA_DOCUMENT}"
)
if not isinstance(key, str):
raise TypeError(
f"Expected metadata key to be a str, got {key} which is a {type(key)}"
Expand Down Expand Up @@ -476,7 +480,11 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
raise ValueError(
f"Expected each embedding in the embeddings to be a list, got {embeddings}"
)
for embedding in embeddings:
for i, embedding in enumerate(embeddings):
if len(embedding) == 0:
raise ValueError(
f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {i}"
)
if not all(
[
isinstance(value, (int, float)) and not isinstance(value, bool)
Expand Down
4 changes: 3 additions & 1 deletion chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from abc import ABC
from graphlib import TopologicalSorter
from typing import Optional, List, Any, Dict, Set, Iterable
from typing import Optional, List, Any, Dict, Set, Iterable, Union
from typing import Type, TypeVar, cast

from overrides import EnforceOverrides
Expand Down Expand Up @@ -122,6 +122,8 @@ class Settings(BaseSettings): # type: ignore
chroma_server_headers: Optional[Dict[str, str]] = None
chroma_server_http_port: Optional[str] = None
chroma_server_ssl_enabled: Optional[bool] = False
# the below config value is only applicable to Chroma HTTP clients
chroma_server_ssl_verify: Optional[Union[bool, str]] = None
chroma_server_api_default_path: Optional[str] = "/api/v1"
chroma_server_grpc_port: Optional[str] = None
# eg ["http://localhost:3000"]
Expand Down
3 changes: 2 additions & 1 deletion chromadb/db/migrations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import Sequence
from typing_extensions import TypedDict, NotRequired
from importlib_resources.abc import Traversable
Expand Down Expand Up @@ -253,7 +254,7 @@ def _read_migration_file(file: MigrationFile, hash_alg: str) -> Migration:
sql = file["path"].read_text()

if hash_alg == "md5":
hash = hashlib.md5(sql.encode("utf-8")).hexdigest()
hash = hashlib.md5(sql.encode("utf-8"), usedforsecurity=False).hexdigest() if sys.version_info >= (3, 9) else hashlib.md5(sql.encode("utf-8")).hexdigest()
elif hash_alg == "sha256":
hash = hashlib.sha256(sql.encode("utf-8")).hexdigest()
else:
Expand Down
17 changes: 17 additions & 0 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ def _where_doc_criterion(
def delete(self) -> None:
t = Table("embeddings")
t1 = Table("embedding_metadata")
t2 = Table("embedding_fulltext_search")
q0 = (
self._db.querybuilder()
.from_(t1)
Expand Down Expand Up @@ -603,7 +604,23 @@ def delete(self) -> None:
)
)
)
q_fts = (
self._db.querybuilder()
.from_(t2)
.delete()
.where(
t2.rowid.isin(
self._db.querybuilder()
.from_(t)
.select(t.id)
.where(
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
)
)
)
with self._db.tx() as cur:
cur.execute(*get_sql(q_fts))
cur.execute(*get_sql(q0))
cur.execute(*get_sql(q))

Expand Down
Loading

0 comments on commit a1ddd69

Please sign in to comment.