fix: fix implicit ids in upload collection with paralell > 1 (#460)
* fix: fix implicit ids in upload collection with paralell > 1

* fix: fix type hints

* fix: remove redundant code, simplify type hints

* fix: remove redundant import

* fix: fix batching

* fix: replace generator with list comprehension

* fix: sorry, I was wrong

* tests: update tests

* fix: extend upload records and upload points tests
joein committed Jan 31, 2024
1 parent 640fe5a commit 876e171
Showing 4 changed files with 105 additions and 15 deletions.

qdrant_client/uploader/grpc_uploader.py (5 additions, 0 deletions)
@@ -1,5 +1,7 @@
 import logging
+from itertools import count
 from typing import Any, Generator, Iterable, Optional, Tuple, Union
+from uuid import uuid4
 
 from qdrant_client import grpc as grpc
 from qdrant_client.connection import get_channel
@@ -19,6 +21,9 @@ def upload_batch_grpc(
 ) -> bool:
     ids_batch, vectors_batch, payload_batch = batch
 
+    ids_batch = (PointId(uuid=str(uuid4())) for _ in count()) if ids_batch is None else ids_batch
+    payload_batch = (None for _ in count()) if payload_batch is None else payload_batch
+
     points = [
         PointStruct(
             id=RestToGrpc.convert_extended_point_id(idx) if not isinstance(idx, PointId) else idx,
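
The two added lines above are the worker-side half of the fix: when a batch arrives without explicit ids, the uploader lazily draws a fresh UUID per point (and a None payload) instead of expecting the parent process to have filled them in. A minimal, runnable sketch of the same pattern, using plain dicts and made-up vectors rather than the client's PointStruct:

    from itertools import count
    from uuid import uuid4

    # Illustrative batch: three vectors, no explicit ids, no payloads.
    vectors_batch = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    ids_batch = None
    payload_batch = None

    # Infinite lazy generators are safe here: zip() stops at the shortest
    # iterable, so exactly one UUID / one None is drawn per vector.
    ids_batch = (str(uuid4()) for _ in count()) if ids_batch is None else ids_batch
    payload_batch = (None for _ in count()) if payload_batch is None else payload_batch

    points = [
        {"id": point_id, "vector": vector, "payload": payload}
        for point_id, vector, payload in zip(ids_batch, vectors_batch, payload_batch)
    ]

    assert len(points) == 3 and all(point["id"] for point in points)

This also ties in with the "fix: replace generator with list comprehension" step in the commit message: the list comprehension is driven by the finite vectors batch, so the unbounded id and payload generators never run away.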

qdrant_client/uploader/rest_uploader.py (5 additions, 0 deletions)
@@ -1,5 +1,7 @@
 import logging
+from itertools import count
 from typing import Any, Generator, Iterable, Optional, Tuple, Union
+from uuid import uuid4
 
 import numpy as np
 
@@ -18,6 +20,9 @@ def upload_batch(
 ) -> bool:
     ids_batch, vectors_batch, payload_batch = batch
 
+    ids_batch = (str(uuid4()) for _ in count()) if ids_batch is None else ids_batch
+    payload_batch = (None for _ in count()) if payload_batch is None else payload_batch
+
     points = [
         PointStruct(
             id=idx,

qdrant_client/uploader/uploader.py (4 additions, 11 deletions)
@@ -1,7 +1,6 @@
 from abc import ABC
 from itertools import count, islice
 from typing import Any, Dict, Generator, Iterable, List, Optional, Union
-from uuid import uuid4
 
 import numpy as np
 
@@ -24,11 +23,6 @@ def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
         yield b
 
 
-def uuid_generator() -> Generator[str, None, None]:
-    while True:
-        yield str(uuid4())
-
-
 class BaseUploader(Worker, ABC):
     @classmethod
     def iterate_records_batches(
@@ -58,13 +52,12 @@ def iterate_batches(
         batch_size: int,
     ) -> Iterable:
         if ids is None:
-            ids = uuid_generator()
+            ids_batches: Iterable = (None for _ in count())
+        else:
+            ids_batches = iter_batch(ids, batch_size)
 
-        ids_batches = iter_batch(ids, batch_size)
         if payload is None:
-            payload_batches: Union[Generator, Iterable] = (
-                (None for _ in range(batch_size)) for _ in count()
-            )
+            payload_batches: Iterable = (None for _ in count())
         else:
             payload_batches = iter_batch(payload, batch_size)
 
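
On the dispatching side, iterate_batches no longer pre-generates UUIDs in the parent process: when ids is None it now emits a bare None placeholder per batch (trivially cheap to hand to worker processes), and each worker expands that placeholder into per-point UUIDs as shown in the uploader changes above. A simplified sketch of the new batching behaviour, with a trimmed-down iter_batch and dummy integer "vectors" standing in for the real record types:

    from itertools import count, islice
    from typing import Iterable, Optional

    def iter_batch(iterable: Iterable, size: int) -> Iterable:
        # Trimmed-down stand-in for the client's helper: yield lists of up to `size` items.
        items = iter(iterable)
        while batch := list(islice(items, size)):
            yield batch

    def iterate_batches(vectors: Iterable, ids: Optional[Iterable], batch_size: int) -> Iterable:
        # No ids supplied: ship a plain None per batch and let the worker generate UUIDs.
        ids_batches: Iterable = (None for _ in count()) if ids is None else iter_batch(ids, batch_size)
        yield from zip(ids_batches, iter_batch(vectors, batch_size))

    batches = list(iterate_batches(vectors=range(5), ids=None, batch_size=2))
    print(batches)  # [(None, [0, 1]), (None, [2, 3]), (None, [4])]

Deferring id creation to the workers keeps the implicit-ids path symmetric with the implicit-payload path (both now hand a plain None to the worker) and makes the uuid_generator helper unnecessary, hence its removal.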

tests/test_qdrant_client.py (91 additions, 4 deletions)
@@ -168,7 +168,8 @@ def test_client_init():
 
 
 @pytest.mark.parametrize("prefer_grpc", [False, True])
-def test_records_upload(prefer_grpc):
+@pytest.mark.parametrize("parallel", [1, 2])
+def test_records_upload(prefer_grpc, parallel):
     import warnings
 
     warnings.simplefilter("ignore", category=DeprecationWarning)
@@ -186,7 +187,7 @@ def test_records_upload(prefer_grpc):
         timeout=TIMEOUT,
     )
 
-    client.upload_records(collection_name=COLLECTION_NAME, records=records, parallel=2)
+    client.upload_records(collection_name=COLLECTION_NAME, records=records, parallel=parallel)
 
     # By default, Qdrant indexes data updates asynchronously, so client don't need to wait before sending next batch
     # Let's give it a second to actually add all points to a collection.
@@ -212,9 +213,26 @@ def test_records_upload(prefer_grpc):
     assert result_count.count < 900
     assert result_count.count > 100
 
+    records = (Record(id=idx, vector=np.random.rand(DIM).tolist()) for idx in range(NUM_VECTORS))
+
+    client.recreate_collection(
+        collection_name=COLLECTION_NAME,
+        vectors_config=VectorParams(size=DIM, distance=Distance.DOT),
+        timeout=TIMEOUT,
+    )
+
+    client.upload_records(
+        collection_name=COLLECTION_NAME, records=records, parallel=parallel, wait=True
+    )
+
+    collection_info = client.get_collection(collection_name=COLLECTION_NAME)
+
+    assert collection_info.points_count == NUM_VECTORS
+
 
 @pytest.mark.parametrize("prefer_grpc", [False, True])
-def test_point_upload(prefer_grpc):
+@pytest.mark.parametrize("parallel", [1, 2])
+def test_point_upload(prefer_grpc, parallel):
     points = (
         PointStruct(
             id=idx, vector=np.random.rand(DIM).tolist(), payload=one_random_payload_please(idx)
@@ -230,7 +248,7 @@ def test_point_upload(prefer_grpc):
         timeout=TIMEOUT,
     )
 
-    client.upload_points(collection_name=COLLECTION_NAME, points=points, parallel=2)
+    client.upload_points(collection_name=COLLECTION_NAME, points=points, parallel=parallel)
 
     # By default, Qdrant indexes data updates asynchronously, so client don't need to wait before sending next batch
     # Let's give it a second to actually add all points to a collection.
@@ -256,6 +274,75 @@ def test_point_upload(prefer_grpc):
     assert result_count.count < 900
     assert result_count.count > 100
 
+    client.recreate_collection(
+        collection_name=COLLECTION_NAME,
+        vectors_config=VectorParams(size=DIM, distance=Distance.DOT),
+        timeout=TIMEOUT,
+    )
+
+    points = (
+        PointStruct(id=idx, vector=np.random.rand(DIM).tolist()) for idx in range(NUM_VECTORS)
+    )
+
+    client.upload_points(
+        collection_name=COLLECTION_NAME, points=points, parallel=parallel, wait=True
+    )
+
+    collection_info = client.get_collection(collection_name=COLLECTION_NAME)
+
+    assert collection_info.points_count == NUM_VECTORS
+
+
+@pytest.mark.parametrize("prefer_grpc", [False, True])
+@pytest.mark.parametrize("parallel", [1, 2])
+def test_upload_collection(prefer_grpc, parallel):
+    size = 3
+    batch_size = 2
+    client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT)
+
+    client.recreate_collection(
+        collection_name=COLLECTION_NAME,
+        vectors_config=VectorParams(size=size, distance=Distance.DOT),
+        timeout=TIMEOUT,
+    )
+    vectors = [
+        [1.0, 2.0, 3.0],
+        [4.0, 5.0, 6.0],
+        [7.0, 8.0, 9.0],
+        [10.0, 11.0, 12.0],
+        [13.0, 14.0, 15.0],
+    ]
+    payload = [{"a": 2}, {"b": 3}, {"c": 4}, {"d": 5}, {"e": 6}]
+    ids = [1, 2, 3, 4, 5]
+
+    client.upload_collection(
+        collection_name=COLLECTION_NAME,
+        vectors=vectors,
+        parallel=parallel,
+        wait=True,
+        batch_size=batch_size,
+    )
+
+    assert client.get_collection(collection_name=COLLECTION_NAME).points_count == 5
+
+    client.recreate_collection(
+        collection_name=COLLECTION_NAME,
+        vectors_config=VectorParams(size=size, distance=Distance.DOT),
+        timeout=TIMEOUT,
+    )
+
+    client.upload_collection(
+        collection_name=COLLECTION_NAME,
+        vectors=vectors,
+        payload=payload,
+        ids=ids,
+        parallel=parallel,
+        wait=True,
+        batch_size=batch_size,
+    )
+
+    assert client.get_collection(collection_name=COLLECTION_NAME).points_count == 5
+
 
 @pytest.mark.parametrize("prefer_grpc", [False, True])
 def test_multiple_vectors(prefer_grpc):
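
For reference, the user-facing path these tests exercise (implicit ids with parallel workers) comes down to calls like the following sketch; the collection name, vector size, and connection settings are illustrative, and it assumes a Qdrant instance reachable on the default local port:

    import numpy as np
    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, VectorParams

    client = QdrantClient(host="localhost", port=6333)

    client.recreate_collection(
        collection_name="demo",
        vectors_config=VectorParams(size=4, distance=Distance.DOT),
    )

    # No `ids` argument: the client has to generate UUIDs itself.
    # With parallel > 1 this is exactly the code path fixed by this commit.
    client.upload_collection(
        collection_name="demo",
        vectors=np.random.rand(100, 4),
        parallel=2,
        wait=True,
    )

    assert client.get_collection(collection_name="demo").points_count == 100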
