Skip to content

Commit

Permalink
Add test for dict init
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Apr 25, 2023
1 parent fe168d1 commit 8b7a942
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
6 changes: 3 additions & 3 deletions vectordb_orm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def from_dict(cls, data: dict):
:returns: A MilvusBase object.
:raises ValueError: If an unexpected attribute name is encountered in the dictionary.
"""
obj = cls()
init_payload = {}
allowed_keys = list(cls.__annotations__.keys())
for attribute_name, value in data.items():
if attribute_name not in allowed_keys:
raise ValueError(f"Key `{attribute_name}` is not allowed on {cls.__name__}")
setattr(obj, attribute_name, value)
return obj
init_payload[attribute_name] = value
return cls(**init_payload)
13 changes: 13 additions & 0 deletions vectordb_orm/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ def test_create_object(session: str, model: Type[VectorSchemaBase]):
assert my_object.id is None


@pytest.mark.parametrize("session,model", SESSION_MODEL_PAIRS)
def test_from_dict(session: str, model: Type[VectorSchemaBase]):
my_object = model.from_dict(
{
"text": 'example',
"embedding": np.array([1.0] * 128)
}
)
assert my_object.text == 'example'
assert np.array_equal(my_object.embedding, np.array([1.0] * 128))
assert my_object.id is None


@pytest.mark.parametrize("session,model", SESSION_MODEL_PAIRS)
def test_insert_object(session: str, model: Type[VectorSchemaBase], request):
session : VectorSession = request.getfixturevalue(session)
Expand Down

0 comments on commit 8b7a942

Please sign in to comment.