diff --git a/tests/test_qdrant_client.py b/tests/test_qdrant_client.py index 850e391b..8dda9d76 100644 --- a/tests/test_qdrant_client.py +++ b/tests/test_qdrant_client.py @@ -213,6 +213,22 @@ def test_records_upload(prefer_grpc, parallel): assert result_count.count < 900 assert result_count.count > 100 + records = (Record(id=idx, vector=np.random.rand(DIM).tolist()) for idx in range(NUM_VECTORS)) + + client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=DIM, distance=Distance.DOT), + timeout=TIMEOUT, + ) + + client.upload_records( + collection_name=COLLECTION_NAME, records=records, parallel=parallel, wait=True + ) + + collection_info = client.get_collection(collection_name=COLLECTION_NAME) + + assert collection_info.points_count == NUM_VECTORS + @pytest.mark.parametrize("prefer_grpc", [False, True]) @pytest.mark.parametrize("parallel", [1, 2]) @@ -258,6 +274,24 @@ def test_point_upload(prefer_grpc, parallel): assert result_count.count < 900 assert result_count.count > 100 + client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=DIM, distance=Distance.DOT), + timeout=TIMEOUT, + ) + + points = ( + PointStruct(id=idx, vector=np.random.rand(DIM).tolist()) for idx in range(NUM_VECTORS) + ) + + client.upload_points( + collection_name=COLLECTION_NAME, points=points, parallel=parallel, wait=True + ) + + collection_info = client.get_collection(collection_name=COLLECTION_NAME) + + assert collection_info.points_count == NUM_VECTORS + @pytest.mark.parametrize("prefer_grpc", [False, True]) @pytest.mark.parametrize("parallel", [1, 2])