diff --git a/qdrant_client/local/local_collection.py b/qdrant_client/local/local_collection.py index 442ae82a..05070b42 100644 --- a/qdrant_client/local/local_collection.py +++ b/qdrant_client/local/local_collection.py @@ -41,18 +41,20 @@ PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if PYDANTIC_V2: - from pydantic_core import to_jsonable_python + from pydantic_core import to_jsonable_python as _to_jsonable_python else: from pydantic.json import ENCODERS_BY_TYPE - def to_jsonable_python(x: Any) -> Any: - try: - json.dumps(x, allow_nan=False) - return x - except Exception: - return json.loads( - json.dumps(x, allow_nan=False, default=lambda y: ENCODERS_BY_TYPE[type(y)](y)) - ) + def _to_jsonable_python(x: Any) -> Any: + return ENCODERS_BY_TYPE[type(x)](x) + + +def to_jsonable_python(x: Any) -> Any: + try: + json.dumps(x, allow_nan=False) + return x + except Exception: + return json.loads(json.dumps(x, allow_nan=False, default=_to_jsonable_python)) class LocalCollection: @@ -117,7 +119,7 @@ def load_vectors(self) -> None: self.ids_inv.append(point.id) # payload tracker - self.payload.append(point.payload or {}) + self.payload.append(to_jsonable_python(point.payload) or {}) # persisted named vectors loaded_vector = point.vector @@ -1235,10 +1237,12 @@ def set_payload( ids = self._selector_to_ids(selector) for point_id in ids: idx = self.ids[point_id] - self.payload[idx] = { - **(self.payload[idx] or {}), - **payload, - } + self.payload[idx] = to_jsonable_python( + { + **(self.payload[idx] or {}), + **payload, + } + ) self._persist_by_id(point_id) def overwrite_payload( @@ -1254,7 +1258,7 @@ def overwrite_payload( ids = self._selector_to_ids(selector) for point_id in ids: idx = self.ids[point_id] - self.payload[idx] = payload or {} + self.payload[idx] = to_jsonable_python(payload) or {} self._persist_by_id(point_id) def delete_payload( diff --git a/tests/congruence_tests/test_payload.py b/tests/congruence_tests/test_payload.py index cb8fa609..2d35680d 100644 --- a/tests/congruence_tests/test_payload.py +++ b/tests/congruence_tests/test_payload.py @@ -4,6 +4,8 @@ COLLECTION_NAME, compare_collections, generate_fixtures, + init_local, + init_remote, ) NUM_VECTORS = 100 @@ -140,3 +142,48 @@ def test_update_payload(local_client: QdrantClient, remote_client: QdrantClient) # endregion compare_collections(local_client, remote_client, NUM_VECTORS) # sanity check + + +def test_upsert_payload(): + import datetime + import random + import uuid + + local_client = init_local() + remote_client = init_remote() + + vector_size = 2 + vectors_config = models.VectorParams(size=vector_size, distance=models.Distance.COSINE) + local_client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=vectors_config, + ) + remote_client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=vectors_config, + ) + + # subset of types from pydantic.json.ENCODERS_BY_TYPE (pydantic v1) + + payloads = [ + {"bytes": b"123"}, + {"date": datetime.date(2021, 1, 1)}, + {"datetime": datetime.datetime(2021, 1, 1, 1, 1, 1)}, + {"time": datetime.time(1, 1, 1)}, + {"timedelta": datetime.timedelta(seconds=1)}, + {"decimal": 1.0}, + {"frozenset": frozenset([1, 2])}, + {"set": {1, 2}}, + {"uuid": uuid.uuid4()}, + ] + + points = [ + models.PointStruct(id=i, vector=[random.random(), random.random()], payload=payload) + for i, payload in enumerate(payloads) + ] + + for point in points: # for better debugging + local_client.upsert(COLLECTION_NAME, [point]) + remote_client.upsert(COLLECTION_NAME, [point]) + + compare_collections(local_client, remote_client, len(points)) diff --git a/tests/congruence_tests/test_updates.py b/tests/congruence_tests/test_updates.py index 64c918ca..b16ce64c 100644 --- a/tests/congruence_tests/test_updates.py +++ b/tests/congruence_tests/test_updates.py @@ -385,51 +385,3 @@ def test_upload_wrong_vectors(): wrong_vectors_collection, points=[models.PointStruct(id=1, vector=unnamed_vector)], ) - - -def test_upsert_payload(): - import datetime - import ipaddress - import random - import re - - from pydantic import SecretBytes, SecretStr - from pydantic.networks import NameEmail - - local_client = init_local() - remote_client = init_remote() - - vector_size = 2 - vectors_config = models.VectorParams(size=vector_size, distance=models.Distance.COSINE) - local_client.recreate_collection( - collection_name=COLLECTION_NAME, - vectors_config=vectors_config, - ) - remote_client.recreate_collection( - collection_name=COLLECTION_NAME, - vectors_config=vectors_config, - ) - - # subset of types from pydantic.json.ENCODERS_BY_TYPE (pydantic v1) - - payloads = [ - {"bytes": b"123"}, - {"date": datetime.date(2021, 1, 1)}, - {"datetime": datetime.datetime(2021, 1, 1, 1, 1, 1)}, - {"time": datetime.time(1, 1, 1)}, - {"timedelta": datetime.timedelta(seconds=1)}, - {"decimal": 1.0}, - {"frozenset": frozenset([1, 2])}, - {"set": {1, 2}}, - {"uuid": uuid.uuid4()}, - ] - - points = [ - models.PointStruct(id=i, vector=[random.random(), random.random()], payload=payload) - for i, payload in enumerate(payloads) - ] - - for point in points: # for better debugging - local_client.upsert(COLLECTION_NAME, [point]) - remote_client.upsert(COLLECTION_NAME, [point]) - compare_collections(local_client, remote_client, len(points))