Skip to content

Commit

Permalink
add vector validation
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman committed Jul 6, 2024
1 parent 5007cc2 commit 3069fbd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
29 changes: 13 additions & 16 deletions affine/collection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from dataclasses import dataclass, fields
from typing import Any, Generic, Literal, TypeVar, get_args
from typing import Any, Generic, Literal, TypeVar, get_args, get_origin

import numpy as np

# Type variable parametrizing ``Vector`` (``class Vector(Generic[N])``); the
# argument supplied in an annotation such as ``Vector[3]`` is the expected length.
N = TypeVar("N", bound=int)

# Closed set of operation names, expressed as a ``Literal`` type.
Operation = Literal["eq", "lte", "gte"]
Operation = Literal["eq", "lte", "gte", "topk"]


class Vector(Generic[N]):
Expand All @@ -14,6 +14,9 @@ def __init__(self, array: np.ndarray | list):
array = np.array(array)
self.array = array

def __len__(self) -> int:
    """Return the length of the wrapped ``array`` attribute."""
    wrapped = self.array
    return len(wrapped)


@dataclass
class TopK:
Expand Down Expand Up @@ -58,9 +61,14 @@ def __new__(cls, name, bases, dct):


class Collection(metaclass=MetaCollection):
def validate_arrays(cls, values):
    """Return ``values`` unchanged.

    Placeholder hook: no validation is performed here.
    """
    # TODO: check that any vec types have the specified length
    unchanged = values
    return unchanged
def __post_init__(self):
    """Check that each ``Vector[n]``-annotated field holds exactly ``n`` entries.

    Iterates over the dataclass fields of this instance; fields whose
    annotation is not a parametrized ``Vector`` are ignored.

    Raises:
        ValueError: If the length of a vector field's value differs from
            the length given in its annotation.
    """
    for field in fields(self):
        # ``get_origin`` returns the unsubscripted class for ``Vector[n]`` and
        # ``None`` for plain annotations, so an identity check is sufficient.
        if get_origin(field.type) is not Vector:
            continue
        # Use the public typing helper rather than the private ``__args__``.
        n = get_args(field.type)[0]
        actual = len(getattr(self, field.name))
        if actual != n:
            raise ValueError(f"Expected vector of length {n}, got {actual}")

@classmethod
def get_filter_from_kwarg(cls, k: str, v: Any) -> Filter:
Expand All @@ -80,14 +88,3 @@ def get_filter_from_kwarg(cls, k: str, v: Any) -> Filter:
def objects(cls, **kwargs) -> FilterSet:
    """Build a ``FilterSet`` for this collection from keyword arguments.

    Each ``key=value`` pair is converted with ``cls.get_filter_from_kwarg``;
    the resulting filters are bundled together with the class name.
    """
    collected = []
    for key, value in kwargs.items():
        collected.append(cls.get_filter_from_kwarg(key, value))
    return FilterSet(filters=collected, collection=cls.__name__)


# Example
# class Person(Collection):
# age: int
# face_embedding: np.ndarray


# Person.query(age__gte=18, face_embedding=TopK(vector=np.array([1, 2, 3]), k=3))

# use a global connection (like mongoengine and others?)
16 changes: 16 additions & 0 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
from affine.collection import Collection, Vector


def test_vector_validation():
    """A ``Vector[3]`` field rejects a 2-element value and accepts a 3-element one."""

    class C(Collection):
        x: Vector[3]

    with pytest.raises(ValueError) as exc_info:
        C(x=[1, 2])
    # Check the message after the ``with`` block: a statement placed after the
    # raising call inside the block would never execute.
    assert "Expected vector of length 3, got 2" in str(exc_info.value)

    # A correctly sized vector must construct cleanly. Any exception propagates
    # and fails the test with its original traceback intact.
    C(x=[1, 2, 3])

0 comments on commit 3069fbd

Please sign in to comment.