Skip to content

Commit

Permalink
fix: update overwrite payload and set payload
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Jan 26, 2024
1 parent 9d11506 commit a7900b2
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 63 deletions.
34 changes: 19 additions & 15 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,20 @@
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

if PYDANTIC_V2:
from pydantic_core import to_jsonable_python
from pydantic_core import to_jsonable_python as _to_jsonable_python
else:
from pydantic.json import ENCODERS_BY_TYPE

def to_jsonable_python(x: Any) -> Any:
try:
json.dumps(x, allow_nan=False)
return x
except Exception:
return json.loads(
json.dumps(x, allow_nan=False, default=lambda y: ENCODERS_BY_TYPE[type(y)](y))
)
def _to_jsonable_python(x: Any) -> Any:
return ENCODERS_BY_TYPE[type(x)](x)


def to_jsonable_python(x: Any) -> Any:
try:
json.dumps(x, allow_nan=False)
return x
except Exception:
return json.loads(json.dumps(x, allow_nan=False, default=_to_jsonable_python))


class LocalCollection:
Expand Down Expand Up @@ -117,7 +119,7 @@ def load_vectors(self) -> None:
self.ids_inv.append(point.id)

# payload tracker
self.payload.append(point.payload or {})
self.payload.append(to_jsonable_python(point.payload) or {})

# persisted named vectors
loaded_vector = point.vector
Expand Down Expand Up @@ -1235,10 +1237,12 @@ def set_payload(
ids = self._selector_to_ids(selector)
for point_id in ids:
idx = self.ids[point_id]
self.payload[idx] = {
**(self.payload[idx] or {}),
**payload,
}
self.payload[idx] = to_jsonable_python(
{
**(self.payload[idx] or {}),
**payload,
}
)
self._persist_by_id(point_id)

def overwrite_payload(
Expand All @@ -1254,7 +1258,7 @@ def overwrite_payload(
ids = self._selector_to_ids(selector)
for point_id in ids:
idx = self.ids[point_id]
self.payload[idx] = payload or {}
self.payload[idx] = to_jsonable_python(payload) or {}
self._persist_by_id(point_id)

def delete_payload(
Expand Down
47 changes: 47 additions & 0 deletions tests/congruence_tests/test_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
COLLECTION_NAME,
compare_collections,
generate_fixtures,
init_local,
init_remote,
)

NUM_VECTORS = 100
Expand Down Expand Up @@ -140,3 +142,48 @@ def test_update_payload(local_client: QdrantClient, remote_client: QdrantClient)
# endregion

compare_collections(local_client, remote_client, NUM_VECTORS) # sanity check


def test_upsert_payload():
import datetime
import random
import uuid

local_client = init_local()
remote_client = init_remote()

vector_size = 2
vectors_config = models.VectorParams(size=vector_size, distance=models.Distance.COSINE)
local_client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=vectors_config,
)
remote_client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=vectors_config,
)

# subset of types from pydantic.json.ENCODERS_BY_TYPE (pydantic v1)

payloads = [
{"bytes": b"123"},
{"date": datetime.date(2021, 1, 1)},
{"datetime": datetime.datetime(2021, 1, 1, 1, 1, 1)},
{"time": datetime.time(1, 1, 1)},
{"timedelta": datetime.timedelta(seconds=1)},
{"decimal": 1.0},
{"frozenset": frozenset([1, 2])},
{"set": {1, 2}},
{"uuid": uuid.uuid4()},
]

points = [
models.PointStruct(id=i, vector=[random.random(), random.random()], payload=payload)
for i, payload in enumerate(payloads)
]

for point in points: # for better debugging
local_client.upsert(COLLECTION_NAME, [point])
remote_client.upsert(COLLECTION_NAME, [point])

compare_collections(local_client, remote_client, len(points))
48 changes: 0 additions & 48 deletions tests/congruence_tests/test_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,51 +385,3 @@ def test_upload_wrong_vectors():
wrong_vectors_collection,
points=[models.PointStruct(id=1, vector=unnamed_vector)],
)


def test_upsert_payload():
import datetime
import ipaddress
import random
import re

from pydantic import SecretBytes, SecretStr
from pydantic.networks import NameEmail

local_client = init_local()
remote_client = init_remote()

vector_size = 2
vectors_config = models.VectorParams(size=vector_size, distance=models.Distance.COSINE)
local_client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=vectors_config,
)
remote_client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=vectors_config,
)

# subset of types from pydantic.json.ENCODERS_BY_TYPE (pydantic v1)

payloads = [
{"bytes": b"123"},
{"date": datetime.date(2021, 1, 1)},
{"datetime": datetime.datetime(2021, 1, 1, 1, 1, 1)},
{"time": datetime.time(1, 1, 1)},
{"timedelta": datetime.timedelta(seconds=1)},
{"decimal": 1.0},
{"frozenset": frozenset([1, 2])},
{"set": {1, 2}},
{"uuid": uuid.uuid4()},
]

points = [
models.PointStruct(id=i, vector=[random.random(), random.random()], payload=payload)
for i, payload in enumerate(payloads)
]

for point in points: # for better debugging
local_client.upsert(COLLECTION_NAME, [point])
remote_client.upsert(COLLECTION_NAME, [point])
compare_collections(local_client, remote_client, len(points))

0 comments on commit a7900b2

Please sign in to comment.