Skip to content

Commit

Permalink
fix: convert some types to python jsonable types (#462)
Browse files Browse the repository at this point in the history
* fix: convert some types to python jsonable types

* fix: fix nan validation

* fix: update overwrite payload and set payload

* tests: update not jsonable payload tests
  • Loading branch information
joein authored Jan 30, 2024
1 parent aff0432 commit fb75375
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 16 deletions.
45 changes: 29 additions & 16 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, get_args

import numpy as np
from pydantic.version import VERSION as PYDANTIC_VERSION

from qdrant_client import grpc as grpc
from qdrant_client._pydantic_compat import construct
Expand Down Expand Up @@ -37,6 +38,23 @@
DEFAULT_VECTOR_NAME = ""
EPSILON = 1.1920929e-7 # https://doc.rust-lang.org/std/f32/constant.EPSILON.html
# https://github.com/qdrant/qdrant/blob/7164ac4a5987d28f1c93f5712aef8e09e7d93555/lib/segment/src/spaces/simple_avx.rs#L99C10-L99C10
# True when the installed pydantic reports a 2.x version string.
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

if PYDANTIC_V2:
    # pydantic v2: reuse the converter shipped with pydantic_core.
    from pydantic_core import to_jsonable_python as _to_jsonable_python
else:
    from pydantic.json import ENCODERS_BY_TYPE

    def _to_jsonable_python(x: Any) -> Any:
        """Convert ``x`` to a JSON-serializable value using pydantic v1 encoders.

        Intended for use as the ``default`` hook of ``json.dumps``.

        The lookup walks the MRO of ``type(x)`` (excluding ``object``) so that
        subclasses of a registered type are encoded with their parent's
        encoder; an exact-type match is still preferred because ``type(x)``
        comes first in the MRO.

        Raises:
            TypeError: if no encoder is registered for ``type(x)`` or any of
                its base classes. ``TypeError`` (rather than ``KeyError``) is
                what ``json`` expects from a ``default`` hook that cannot
                handle a value.
        """
        for base in type(x).__mro__[:-1]:
            encoder = ENCODERS_BY_TYPE.get(base)
            if encoder is not None:
                return encoder(x)
        raise TypeError(f"Object of type '{type(x).__name__}' is not JSON serializable")


def to_jsonable_python(x: Any) -> Any:
    """Return ``x`` in a form that the standard ``json`` module can serialize.

    Args:
        x: Arbitrary value, e.g. a point payload.

    Returns:
        ``x`` itself (the same object, not a copy) when ``json.dumps`` already
        accepts it. Otherwise a new structure obtained by serializing ``x``
        with the module-level ``_to_jsonable_python`` helper as the ``default``
        hook and parsing the resulting JSON back.

    Raises:
        ValueError: if ``x`` contains NaN or infinite floats (``allow_nan=False``)
            or a circular reference.
        TypeError: if ``x`` has keys ``json`` cannot use, or the ``default``
            hook reports a value it cannot convert.
    """
    try:
        json.dumps(x, allow_nan=False)
    except TypeError:
        # Only a TypeError can be repaired by the ``default`` hook. A
        # ValueError (NaN, circular reference) would fail the same way on a
        # second pass, so it is deliberately left to propagate from the first.
        return json.loads(json.dumps(x, allow_nan=False, default=_to_jsonable_python))
    return x


class LocalCollection:
Expand Down Expand Up @@ -101,7 +119,7 @@ def load_vectors(self) -> None:
self.ids_inv.append(point.id)

# payload tracker
self.payload.append(point.payload or {})
self.payload.append(to_jsonable_python(point.payload) or {})

# persisted named vectors
loaded_vector = point.vector
Expand Down Expand Up @@ -957,11 +975,7 @@ def count(self, count_filter: Optional[types.Filter] = None) -> models.CountResu

def _update_point(self, point: models.PointStruct) -> None:
idx = self.ids[point.id]
self.payload[idx] = (
point.payload
if point.payload is not None and json.dumps(point.payload, allow_nan=False)
else {}
)
self.payload[idx] = to_jsonable_python(point.payload) if point.payload is not None else {}

if isinstance(point.vector, list):
vectors = {DEFAULT_VECTOR_NAME: point.vector}
Expand Down Expand Up @@ -996,11 +1010,8 @@ def _add_point(self, point: models.PointStruct) -> None:
idx = len(self.ids)
self.ids[point.id] = idx
self.ids_inv.append(point.id)
self.payload.append(
point.payload
if point.payload is not None and json.dumps(point.payload, allow_nan=False)
else {}
)

self.payload.append(to_jsonable_python(point.payload) if point.payload is not None else {})
assert len(self.payload) == len(self.ids_inv), "Payload and ids_inv must be the same size"
self.deleted = np.append(self.deleted, 0)

Expand Down Expand Up @@ -1226,10 +1237,12 @@ def set_payload(
ids = self._selector_to_ids(selector)
for point_id in ids:
idx = self.ids[point_id]
self.payload[idx] = {
**(self.payload[idx] or {}),
**payload,
}
self.payload[idx] = to_jsonable_python(
{
**(self.payload[idx] or {}),
**payload,
}
)
self._persist_by_id(point_id)

def overwrite_payload(
Expand All @@ -1245,7 +1258,7 @@ def overwrite_payload(
ids = self._selector_to_ids(selector)
for point_id in ids:
idx = self.ids[point_id]
self.payload[idx] = payload or {}
self.payload[idx] = to_jsonable_python(payload) or {}
self._persist_by_id(point_id)

def delete_payload(
Expand Down
91 changes: 91 additions & 0 deletions tests/congruence_tests/test_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
COLLECTION_NAME,
compare_collections,
generate_fixtures,
init_local,
init_remote,
)

NUM_VECTORS = 100
Expand Down Expand Up @@ -140,3 +142,92 @@ def test_update_payload(local_client: QdrantClient, remote_client: QdrantClient)
# endregion

compare_collections(local_client, remote_client, NUM_VECTORS) # sanity check


def test_not_jsonable_payload():
    """Check that local and remote clients agree on payloads that plain ``json`` rejects.

    The same payloads are written three ways -- via ``upsert``, ``set_payload``
    and ``overwrite_payload`` -- and after each stage the two collections are
    compared with ``compare_collections``.
    """
    import datetime
    import random
    import uuid

    # Order matters throughout: every operation hits the local client first,
    # then the remote one.
    clients = (init_local(), init_remote())

    vector_size = 2
    vectors_config = models.VectorParams(size=vector_size, distance=models.Distance.COSINE)

    def recreate() -> None:
        # Reset the collection on both clients.
        for client in clients:
            client.recreate_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=vectors_config,
            )

    def upsert_one(point) -> None:
        # Points are sent one at a time for better debugging.
        for client in clients:
            client.upsert(COLLECTION_NAME, [point])

    def by_id(point_id):
        # Selector matching exactly one point.
        return models.Filter(must=[models.HasIdCondition(has_id=[point_id])])

    recreate()

    # subset of types from pydantic.json.ENCODERS_BY_TYPE (pydantic v1)
    payloads = [
        {"bytes": b"123"},
        {"date": datetime.date(2021, 1, 1)},
        {"datetime": datetime.datetime(2021, 1, 1, 1, 1, 1)},
        {"time": datetime.time(1, 1, 1)},
        {"timedelta": datetime.timedelta(seconds=1)},
        # NOTE(review): key says "decimal" but the value is a plain float -- confirm intent
        {"decimal": 1.0},
        {"frozenset": frozenset([1, 2])},
        {"set": {1, 2}},
        {"uuid": uuid.uuid4()},
    ]

    points = []
    for idx, payload in enumerate(payloads):
        vector = [random.random() for _ in range(vector_size)]
        points.append(models.PointStruct(id=idx, vector=vector, payload=payload))
    expected_count = len(points)

    # Stage 1: payloads supplied directly through upsert.
    for point in points:
        upsert_one(point)
    compare_collections(*clients, expected_count)

    # Stage 2: start from payload-less points, then attach payloads with set_payload.
    recreate()
    point_ids = [point.id for point in points]
    for point in points:
        point.payload = None
        upsert_one(point)

    for point_id, payload in zip(point_ids, payloads):
        for client in clients:
            client.set_payload(COLLECTION_NAME, payload, by_id(point_id))
    compare_collections(*clients, expected_count)

    # Stage 3: overwrite, pairing payloads with the ids in reverse order.
    for point_id, payload in zip(reversed(point_ids), payloads):
        for client in clients:
            client.overwrite_payload(COLLECTION_NAME, payload, by_id(point_id))
    compare_collections(*clients, expected_count)

0 comments on commit fb75375

Please sign in to comment.