Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix implicit ids in upload collection with paralell > 1 #460

Merged
merged 9 commits into from
Jan 31, 2024
Merged
5 changes: 5 additions & 0 deletions qdrant_client/uploader/grpc_uploader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from itertools import count
from typing import Any, Generator, Iterable, Optional, Tuple, Union
from uuid import uuid4

from qdrant_client import grpc as grpc
from qdrant_client.connection import get_channel
Expand All @@ -19,6 +21,9 @@ def upload_batch_grpc(
) -> bool:
ids_batch, vectors_batch, payload_batch = batch

ids_batch = (PointId(uuid=str(uuid4())) for _ in count()) if ids_batch is None else ids_batch
payload_batch = (None for _ in count()) if payload_batch is None else payload_batch

points = [
PointStruct(
id=RestToGrpc.convert_extended_point_id(idx) if not isinstance(idx, PointId) else idx,
Expand Down
5 changes: 5 additions & 0 deletions qdrant_client/uploader/rest_uploader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from itertools import count
from typing import Any, Generator, Iterable, Optional, Tuple, Union
from uuid import uuid4

import numpy as np

Expand All @@ -18,6 +20,9 @@ def upload_batch(
) -> bool:
ids_batch, vectors_batch, payload_batch = batch

ids_batch = (str(uuid4()) for _ in count()) if ids_batch is None else ids_batch
payload_batch = (None for _ in count()) if payload_batch is None else payload_batch

points = [
PointStruct(
id=idx,
Expand Down
15 changes: 4 additions & 11 deletions qdrant_client/uploader/uploader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import ABC
from itertools import count, islice
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
from uuid import uuid4

import numpy as np

Expand All @@ -24,11 +23,6 @@ def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
yield b


def uuid_generator() -> Generator[str, None, None]:
while True:
yield str(uuid4())


class BaseUploader(Worker, ABC):
@classmethod
def iterate_records_batches(
Expand Down Expand Up @@ -58,13 +52,12 @@ def iterate_batches(
batch_size: int,
) -> Iterable:
if ids is None:
ids = uuid_generator()
ids_batches: Iterable = (None for _ in count())
else:
ids_batches = iter_batch(ids, batch_size)

ids_batches = iter_batch(ids, batch_size)
if payload is None:
payload_batches: Union[Generator, Iterable] = (
(None for _ in range(batch_size)) for _ in count()
)
payload_batches: Iterable = (None for _ in count())
else:
payload_batches = iter_batch(payload, batch_size)

Expand Down
95 changes: 91 additions & 4 deletions tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def test_client_init():


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_records_upload(prefer_grpc):
@pytest.mark.parametrize("parallel", [1, 2])
def test_records_upload(prefer_grpc, parallel):
import warnings

warnings.simplefilter("ignore", category=DeprecationWarning)
Expand All @@ -186,7 +187,7 @@ def test_records_upload(prefer_grpc):
timeout=TIMEOUT,
)

client.upload_records(collection_name=COLLECTION_NAME, records=records, parallel=2)
client.upload_records(collection_name=COLLECTION_NAME, records=records, parallel=parallel)

# By default, Qdrant indexes data updates asynchronously, so client don't need to wait before sending next batch
# Let's give it a second to actually add all points to a collection.
Expand All @@ -212,9 +213,26 @@ def test_records_upload(prefer_grpc):
assert result_count.count < 900
assert result_count.count > 100

records = (Record(id=idx, vector=np.random.rand(DIM).tolist()) for idx in range(NUM_VECTORS))

client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(size=DIM, distance=Distance.DOT),
timeout=TIMEOUT,
)

client.upload_records(
collection_name=COLLECTION_NAME, records=records, parallel=parallel, wait=True
)

collection_info = client.get_collection(collection_name=COLLECTION_NAME)

assert collection_info.points_count == NUM_VECTORS


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_point_upload(prefer_grpc):
@pytest.mark.parametrize("parallel", [1, 2])
def test_point_upload(prefer_grpc, parallel):
points = (
PointStruct(
id=idx, vector=np.random.rand(DIM).tolist(), payload=one_random_payload_please(idx)
Expand All @@ -230,7 +248,7 @@ def test_point_upload(prefer_grpc):
timeout=TIMEOUT,
)

client.upload_points(collection_name=COLLECTION_NAME, points=points, parallel=2)
client.upload_points(collection_name=COLLECTION_NAME, points=points, parallel=parallel)

# By default, Qdrant indexes data updates asynchronously, so client don't need to wait before sending next batch
# Let's give it a second to actually add all points to a collection.
Expand All @@ -256,6 +274,75 @@ def test_point_upload(prefer_grpc):
assert result_count.count < 900
assert result_count.count > 100

client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(size=DIM, distance=Distance.DOT),
timeout=TIMEOUT,
)

points = (
PointStruct(id=idx, vector=np.random.rand(DIM).tolist()) for idx in range(NUM_VECTORS)
)

client.upload_points(
collection_name=COLLECTION_NAME, points=points, parallel=parallel, wait=True
)

collection_info = client.get_collection(collection_name=COLLECTION_NAME)

assert collection_info.points_count == NUM_VECTORS


@pytest.mark.parametrize("prefer_grpc", [False, True])
@pytest.mark.parametrize("parallel", [1, 2])
def test_upload_collection(prefer_grpc, parallel):
size = 3
batch_size = 2
client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT)

client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(size=size, distance=Distance.DOT),
timeout=TIMEOUT,
)
vectors = [
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
[10.0, 11.0, 12.0],
[13.0, 14.0, 15.0],
]
Comment on lines +308 to +314
Copy link
Contributor

@coszio coszio Jan 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside the scope of this PR, but right now the behavior is to stop at the shortest iterator of any of ids, vectors, or ids. Is it possible to emit a warning when this happens? E.g, for when it stopped with any of those un-exhausted

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we can consider it as a separate issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just realised that having iterators of different length is a valid scenario, e.g. it is valid when ids iterator is infinite.

We can only check the number of ids/payloads/vectors right before making a request, however this check won't help when the smallest iterator is divisible by batch_size

payload = [{"a": 2}, {"b": 3}, {"c": 4}, {"d": 5}, {"e": 6}]
ids = [1, 2, 3, 4, 5]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make another test for locking in the behavior of auto-generating ids when ids = None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is actually already there, the first call to upload_collection does not provide ids and payload, only vectors

The second call provides all of them - vectors, ids and payload. I put the data into one place because if we change vectors, then ids and payload should also be changed


client.upload_collection(
collection_name=COLLECTION_NAME,
vectors=vectors,
parallel=parallel,
wait=True,
batch_size=batch_size,
)

assert client.get_collection(collection_name=COLLECTION_NAME).points_count == 5

client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(size=size, distance=Distance.DOT),
timeout=TIMEOUT,
)

client.upload_collection(
collection_name=COLLECTION_NAME,
vectors=vectors,
payload=payload,
ids=ids,
parallel=parallel,
wait=True,
batch_size=batch_size,
)

assert client.get_collection(collection_name=COLLECTION_NAME).points_count == 5


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_multiple_vectors(prefer_grpc):
Expand Down
Loading