Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] fix persistent HNSW parameter migration #2511

Merged
merged 3 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
UniqueConstraintError,
)
from chromadb.db.system import SysDB
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
Expand Down Expand Up @@ -776,7 +777,12 @@ def _insert_config_from_legacy_params(
collections_t = Table("collections")

# Get any existing HNSW params from the metadata
hnsw_metadata_params = HnswParams.extract(metadata or {})
metadata = metadata or {}
if metadata.get("hnsw:batch_size") or metadata.get("hnsw:sync_threshold"):
codetheweb marked this conversation as resolved.
Show resolved Hide resolved
hnsw_metadata_params = PersistentHnswParams.extract(metadata)
else:
hnsw_metadata_params = HnswParams.extract(metadata)

hnsw_configuration = HNSWConfigurationInternal.from_legacy_params(
hnsw_metadata_params # type: ignore[arg-type]
)
Expand Down
10 changes: 6 additions & 4 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def collections(
with_hnsw_params: bool = False,
has_embeddings: Optional[bool] = None,
has_documents: Optional[bool] = None,
with_persistent_hnsw_params: bool = False,
with_persistent_hnsw_params: st.SearchStrategy[bool] = st.just(False),
max_hnsw_batch_size: int = 2000,
max_hnsw_sync_threshold: int = 2000,
) -> Collection:
Expand All @@ -294,16 +294,18 @@ def collections(
dimension = draw(st.integers(min_value=2, max_value=2048))
dtype = draw(st.sampled_from(float_types))

if with_persistent_hnsw_params and not with_hnsw_params:
use_persistent_hnsw_params = draw(with_persistent_hnsw_params)

if use_persistent_hnsw_params and not with_hnsw_params:
raise ValueError(
"with_hnsw_params requires with_persistent_hnsw_params to be true"
"with_persistent_hnsw_params requires with_hnsw_params to be true"
)

if with_hnsw_params:
if metadata is None:
metadata = {}
metadata.update(test_hnsw_config)
if with_persistent_hnsw_params:
if use_persistent_hnsw_params:
metadata["hnsw:batch_size"] = draw(
st.integers(min_value=3, max_value=max_hnsw_batch_size)
)
Expand Down
7 changes: 6 additions & 1 deletion chromadb/test/property/test_cross_version_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,12 @@ def persist_generated_data_with_old_version(

# Since we can't pickle the embedding function, we always generate record sets with embeddings
collection_st: st.SearchStrategy[strategies.Collection] = st.shared(
strategies.collections(with_hnsw_params=True, has_embeddings=True), key="coll"
strategies.collections(
with_hnsw_params=True,
has_embeddings=True,
with_persistent_hnsw_params=st.booleans(),
),
key="coll",
)


Expand Down
2 changes: 1 addition & 1 deletion chromadb/test/property/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]:
collection_st = st.shared(
strategies.collections(
with_hnsw_params=True,
with_persistent_hnsw_params=True,
with_persistent_hnsw_params=st.just(True),
# Makes it more likely to find persist-related bugs (by default these are set to 2000).
max_hnsw_batch_size=10,
max_hnsw_sync_threshold=10,
Expand Down
Loading