From 3069fbd9b2a0dd2b8ddade8e5c27b10ffa6e290f Mon Sep 17 00:00:00 2001 From: "Eric O. Korman" Date: Fri, 5 Jul 2024 19:38:47 -0500 Subject: [PATCH] add vector validation --- affine/collection.py | 29 +++++++++++++---------------- tests/test_collection.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 16 deletions(-) create mode 100644 tests/test_collection.py diff --git a/affine/collection.py b/affine/collection.py index bbfb758..cf3e4c5 100644 --- a/affine/collection.py +++ b/affine/collection.py @@ -1,11 +1,11 @@ from dataclasses import dataclass, fields -from typing import Any, Generic, Literal, TypeVar, get_args +from typing import Any, Generic, Literal, TypeVar, get_args, get_origin import numpy as np N = TypeVar("N", bound=int) -Operation = Literal["eq", "lte", "gte"] +Operation = Literal["eq", "lte", "gte", "topk"] class Vector(Generic[N]): @@ -14,6 +14,9 @@ def __init__(self, array: np.ndarray | list): array = np.array(array) self.array = array + def __len__(self) -> int: + return len(self.array) + @dataclass class TopK: @@ -58,9 +61,14 @@ def __new__(cls, name, bases, dct): class Collection(metaclass=MetaCollection): - def validate_arrays(cls, values): - # check that any vec types have the specified length - return values + def __post_init__(self): + for field in fields(self): + if get_origin(field.type) == Vector: + n = field.type.__args__[0] + if len(getattr(self, field.name)) != n: + raise ValueError( + f"Expected vector of length {n}, got {len(getattr(self, field.name))}" + ) @classmethod def get_filter_from_kwarg(cls, k: str, v: Any) -> Filter: @@ -80,14 +88,3 @@ def get_filter_from_kwarg(cls, k: str, v: Any) -> Filter: def objects(cls, **kwargs) -> FilterSet: filters = [cls.get_filter_from_kwarg(k, v) for k, v in kwargs.items()] return FilterSet(filters=filters, collection=cls.__name__) - - -# Example -# class Person(Collection): -# age: int -# face_embedding: np.ndarray - - -# Person.query(age__gte=18, face_embedding=TopK(vector=np.array([1, 2, 3]), k=3)) - -# use a global connection (like mongoengine and others?) diff --git a/tests/test_collection.py b/tests/test_collection.py new file mode 100644 index 0000000..2b8f24f --- /dev/null +++ b/tests/test_collection.py @@ -0,0 +1,16 @@ +import pytest +from affine.collection import Collection, Vector + + +def test_vector_validation(): + class C(Collection): + x: Vector[3] + + with pytest.raises(ValueError) as exc_info: + C(x=[1, 2]) + assert "Expected vector of length 3, got 2" in str(exc_info.value) + + try: + C(x=[1, 2, 3]) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}")