From fb75375e40921d95e9c93c9e0d3579bfcea3e193 Mon Sep 17 00:00:00 2001 From: George Date: Tue, 30 Jan 2024 22:58:18 +0100 Subject: [PATCH] fix: convert some types to python jsonable types (#462) * fix: convert some types to python jsonable types * fix: fix nan validation * fix: update overwrite payload and set payload * tests: update not jsonable payload tests --- qdrant_client/local/local_collection.py | 45 +++++++----- tests/congruence_tests/test_payload.py | 91 +++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 16 deletions(-) diff --git a/qdrant_client/local/local_collection.py b/qdrant_client/local/local_collection.py index f6e561d5..05070b42 100644 --- a/qdrant_client/local/local_collection.py +++ b/qdrant_client/local/local_collection.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, get_args import numpy as np +from pydantic.version import VERSION as PYDANTIC_VERSION from qdrant_client import grpc as grpc from qdrant_client._pydantic_compat import construct @@ -37,6 +38,23 @@ DEFAULT_VECTOR_NAME = "" EPSILON = 1.1920929e-7 # https://doc.rust-lang.org/std/f32/constant.EPSILON.html # https://github.com/qdrant/qdrant/blob/7164ac4a5987d28f1c93f5712aef8e09e7d93555/lib/segment/src/spaces/simple_avx.rs#L99C10-L99C10 +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + +if PYDANTIC_V2: + from pydantic_core import to_jsonable_python as _to_jsonable_python +else: + from pydantic.json import ENCODERS_BY_TYPE + + def _to_jsonable_python(x: Any) -> Any: + return ENCODERS_BY_TYPE[type(x)](x) + + +def to_jsonable_python(x: Any) -> Any: + try: + json.dumps(x, allow_nan=False) + return x + except Exception: + return json.loads(json.dumps(x, allow_nan=False, default=_to_jsonable_python)) class LocalCollection: @@ -101,7 +119,7 @@ def load_vectors(self) -> None: self.ids_inv.append(point.id) # payload tracker - self.payload.append(point.payload or {}) + self.payload.append(to_jsonable_python(point.payload) or {}) # persisted named vectors loaded_vector = point.vector @@ -957,11 +975,7 @@ def count(self, count_filter: Optional[types.Filter] = None) -> models.CountResu def _update_point(self, point: models.PointStruct) -> None: idx = self.ids[point.id] - self.payload[idx] = ( - point.payload - if point.payload is not None and json.dumps(point.payload, allow_nan=False) - else {} - ) + self.payload[idx] = to_jsonable_python(point.payload) if point.payload is not None else {} if isinstance(point.vector, list): vectors = {DEFAULT_VECTOR_NAME: point.vector} @@ -996,11 +1010,8 @@ def _add_point(self, point: models.PointStruct) -> None: idx = len(self.ids) self.ids[point.id] = idx self.ids_inv.append(point.id) - self.payload.append( - point.payload - if point.payload is not None and json.dumps(point.payload, allow_nan=False) - else {} - ) + + self.payload.append(to_jsonable_python(point.payload) if point.payload is not None else {}) assert len(self.payload) == len(self.ids_inv), "Payload and ids_inv must be the same size" self.deleted = np.append(self.deleted, 0) @@ -1226,10 +1237,12 @@ def set_payload( ids = self._selector_to_ids(selector) for point_id in ids: idx = self.ids[point_id] - self.payload[idx] = { - **(self.payload[idx] or {}), - **payload, - } + self.payload[idx] = to_jsonable_python( + { + **(self.payload[idx] or {}), + **payload, + } + ) self._persist_by_id(point_id) def overwrite_payload( @@ -1245,7 +1258,7 @@ def overwrite_payload( ids = self._selector_to_ids(selector) for point_id in ids: idx = self.ids[point_id] - self.payload[idx] = payload or {} + self.payload[idx] = to_jsonable_python(payload) or {} self._persist_by_id(point_id) def delete_payload( diff --git a/tests/congruence_tests/test_payload.py b/tests/congruence_tests/test_payload.py index cb8fa609..475e4535 100644 --- a/tests/congruence_tests/test_payload.py +++ b/tests/congruence_tests/test_payload.py @@ -4,6 +4,8 @@ COLLECTION_NAME, compare_collections, generate_fixtures, + init_local, + init_remote, ) NUM_VECTORS = 100 @@ -140,3 +142,92 @@ def test_update_payload(local_client: QdrantClient, remote_client: QdrantClient) # endregion compare_collections(local_client, remote_client, NUM_VECTORS) # sanity check + + +def test_not_jsonable_payload(): + import datetime + import random + import uuid + + local_client = init_local() + remote_client = init_remote() + + vector_size = 2 + vectors_config = models.VectorParams(size=vector_size, distance=models.Distance.COSINE) + local_client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=vectors_config, + ) + remote_client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=vectors_config, + ) + + # subset of types from pydantic.json.ENCODERS_BY_TYPE (pydantic v1) + + payloads = [ + {"bytes": b"123"}, + {"date": datetime.date(2021, 1, 1)}, + {"datetime": datetime.datetime(2021, 1, 1, 1, 1, 1)}, + {"time": datetime.time(1, 1, 1)}, + {"timedelta": datetime.timedelta(seconds=1)}, + {"decimal": 1.0}, + {"frozenset": frozenset([1, 2])}, + {"set": {1, 2}}, + {"uuid": uuid.uuid4()}, + ] + + points = [ + models.PointStruct(id=i, vector=[random.random(), random.random()], payload=payload) + for i, payload in enumerate(payloads) + ] + + for point in points: # for better debugging + local_client.upsert(COLLECTION_NAME, [point]) + remote_client.upsert(COLLECTION_NAME, [point]) + + compare_collections(local_client, remote_client, len(points)) + + local_client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=vectors_config, + ) + remote_client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=vectors_config, + ) + + point_ids = [] + for point in points: + point.payload = None + point_ids.append(point.id) + local_client.upsert(COLLECTION_NAME, [point]) + remote_client.upsert(COLLECTION_NAME, [point]) + + for point_id, payload in zip(point_ids, payloads): + local_client.set_payload( + COLLECTION_NAME, + payload, + models.Filter(must=[models.HasIdCondition(has_id=[point_id])]), + ) + remote_client.set_payload( + COLLECTION_NAME, + payload, + models.Filter(must=[models.HasIdCondition(has_id=[point_id])]), + ) + + compare_collections(local_client, remote_client, len(points)) + + for point_id, payload in zip(point_ids[::-1], payloads): + local_client.overwrite_payload( + COLLECTION_NAME, + payload, + models.Filter(must=[models.HasIdCondition(has_id=[point_id])]), + ) + remote_client.overwrite_payload( + COLLECTION_NAME, + payload, + models.Filter(must=[models.HasIdCondition(has_id=[point_id])]), + ) + + compare_collections(local_client, remote_client, len(points))