From 876e171e8c94a254761886591ab05c8f069396b8 Mon Sep 17 00:00:00 2001 From: George Date: Wed, 31 Jan 2024 16:38:07 +0100 Subject: [PATCH] fix: fix implicit ids in upload collection with parallel > 1 (#460) * fix: fix implicit ids in upload collection with parallel > 1 * fix: fix type hints * fix: remove redundant code, simplify type hints * fix: remove redundant import * fix: fix batching * fix: replace generator with list comprehension * fix: sorry, I was wrong * tests: update tests * fix: extend upload records and upload points tests --- qdrant_client/uploader/grpc_uploader.py | 5 ++ qdrant_client/uploader/rest_uploader.py | 5 ++ qdrant_client/uploader/uploader.py | 15 ++-- tests/test_qdrant_client.py | 95 +++++++++++++++++++++++-- 4 files changed, 105 insertions(+), 15 deletions(-) diff --git a/qdrant_client/uploader/grpc_uploader.py b/qdrant_client/uploader/grpc_uploader.py index 549f0fc4..a419bf2a 100644 --- a/qdrant_client/uploader/grpc_uploader.py +++ b/qdrant_client/uploader/grpc_uploader.py @@ -1,5 +1,7 @@ import logging +from itertools import count from typing import Any, Generator, Iterable, Optional, Tuple, Union +from uuid import uuid4 from qdrant_client import grpc as grpc from qdrant_client.connection import get_channel @@ -19,6 +21,9 @@ def upload_batch_grpc( ) -> bool: ids_batch, vectors_batch, payload_batch = batch + ids_batch = (PointId(uuid=str(uuid4())) for _ in count()) if ids_batch is None else ids_batch + payload_batch = (None for _ in count()) if payload_batch is None else payload_batch + points = [ PointStruct( id=RestToGrpc.convert_extended_point_id(idx) if not isinstance(idx, PointId) else idx, diff --git a/qdrant_client/uploader/rest_uploader.py b/qdrant_client/uploader/rest_uploader.py index b67f19e9..341499e8 100644 --- a/qdrant_client/uploader/rest_uploader.py +++ b/qdrant_client/uploader/rest_uploader.py @@ -1,5 +1,7 @@ import logging +from itertools import count from typing import Any, Generator, Iterable, Optional, Tuple, Union 
+from uuid import uuid4 import numpy as np @@ -18,6 +20,9 @@ def upload_batch( ) -> bool: ids_batch, vectors_batch, payload_batch = batch + ids_batch = (str(uuid4()) for _ in count()) if ids_batch is None else ids_batch + payload_batch = (None for _ in count()) if payload_batch is None else payload_batch + points = [ PointStruct( id=idx, diff --git a/qdrant_client/uploader/uploader.py b/qdrant_client/uploader/uploader.py index c3ff0c01..3f7a7528 100644 --- a/qdrant_client/uploader/uploader.py +++ b/qdrant_client/uploader/uploader.py @@ -1,7 +1,6 @@ from abc import ABC from itertools import count, islice from typing import Any, Dict, Generator, Iterable, List, Optional, Union -from uuid import uuid4 import numpy as np @@ -24,11 +23,6 @@ def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable: yield b -def uuid_generator() -> Generator[str, None, None]: - while True: - yield str(uuid4()) - - class BaseUploader(Worker, ABC): @classmethod def iterate_records_batches( @@ -58,13 +52,12 @@ def iterate_batches( batch_size: int, ) -> Iterable: if ids is None: - ids = uuid_generator() + ids_batches: Iterable = (None for _ in count()) + else: + ids_batches = iter_batch(ids, batch_size) - ids_batches = iter_batch(ids, batch_size) if payload is None: - payload_batches: Union[Generator, Iterable] = ( - (None for _ in range(batch_size)) for _ in count() - ) + payload_batches: Iterable = (None for _ in count()) else: payload_batches = iter_batch(payload, batch_size) diff --git a/tests/test_qdrant_client.py b/tests/test_qdrant_client.py index 044912e1..8dda9d76 100644 --- a/tests/test_qdrant_client.py +++ b/tests/test_qdrant_client.py @@ -168,7 +168,8 @@ def test_client_init(): @pytest.mark.parametrize("prefer_grpc", [False, True]) -def test_records_upload(prefer_grpc): +@pytest.mark.parametrize("parallel", [1, 2]) +def test_records_upload(prefer_grpc, parallel): import warnings warnings.simplefilter("ignore", category=DeprecationWarning) @@ -186,7 +187,7 @@ def 
test_records_upload(prefer_grpc): timeout=TIMEOUT, ) - client.upload_records(collection_name=COLLECTION_NAME, records=records, parallel=2) + client.upload_records(collection_name=COLLECTION_NAME, records=records, parallel=parallel) # By default, Qdrant indexes data updates asynchronously, so client don't need to wait before sending next batch # Let's give it a second to actually add all points to a collection. @@ -212,9 +213,26 @@ def test_records_upload(prefer_grpc): assert result_count.count < 900 assert result_count.count > 100 + records = (Record(id=idx, vector=np.random.rand(DIM).tolist()) for idx in range(NUM_VECTORS)) + + client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=DIM, distance=Distance.DOT), + timeout=TIMEOUT, + ) + + client.upload_records( + collection_name=COLLECTION_NAME, records=records, parallel=parallel, wait=True + ) + + collection_info = client.get_collection(collection_name=COLLECTION_NAME) + + assert collection_info.points_count == NUM_VECTORS + @pytest.mark.parametrize("prefer_grpc", [False, True]) -def test_point_upload(prefer_grpc): +@pytest.mark.parametrize("parallel", [1, 2]) +def test_point_upload(prefer_grpc, parallel): points = ( PointStruct( id=idx, vector=np.random.rand(DIM).tolist(), payload=one_random_payload_please(idx) @@ -230,7 +248,7 @@ def test_point_upload(prefer_grpc): timeout=TIMEOUT, ) - client.upload_points(collection_name=COLLECTION_NAME, points=points, parallel=2) + client.upload_points(collection_name=COLLECTION_NAME, points=points, parallel=parallel) # By default, Qdrant indexes data updates asynchronously, so client don't need to wait before sending next batch # Let's give it a second to actually add all points to a collection. 
@@ -256,6 +274,75 @@ def test_point_upload(prefer_grpc): assert result_count.count < 900 assert result_count.count > 100 + client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=DIM, distance=Distance.DOT), + timeout=TIMEOUT, + ) + + points = ( + PointStruct(id=idx, vector=np.random.rand(DIM).tolist()) for idx in range(NUM_VECTORS) + ) + + client.upload_points( + collection_name=COLLECTION_NAME, points=points, parallel=parallel, wait=True + ) + + collection_info = client.get_collection(collection_name=COLLECTION_NAME) + + assert collection_info.points_count == NUM_VECTORS + + +@pytest.mark.parametrize("prefer_grpc", [False, True]) +@pytest.mark.parametrize("parallel", [1, 2]) +def test_upload_collection(prefer_grpc, parallel): + size = 3 + batch_size = 2 + client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT) + + client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=size, distance=Distance.DOT), + timeout=TIMEOUT, + ) + vectors = [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + [10.0, 11.0, 12.0], + [13.0, 14.0, 15.0], + ] + payload = [{"a": 2}, {"b": 3}, {"c": 4}, {"d": 5}, {"e": 6}] + ids = [1, 2, 3, 4, 5] + + client.upload_collection( + collection_name=COLLECTION_NAME, + vectors=vectors, + parallel=parallel, + wait=True, + batch_size=batch_size, + ) + + assert client.get_collection(collection_name=COLLECTION_NAME).points_count == 5 + + client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=size, distance=Distance.DOT), + timeout=TIMEOUT, + ) + + client.upload_collection( + collection_name=COLLECTION_NAME, + vectors=vectors, + payload=payload, + ids=ids, + parallel=parallel, + wait=True, + batch_size=batch_size, + ) + + assert client.get_collection(collection_name=COLLECTION_NAME).points_count == 5 + @pytest.mark.parametrize("prefer_grpc", [False, True]) def test_multiple_vectors(prefer_grpc):