diff --git a/vectordb_orm/base.py b/vectordb_orm/base.py index 8f8b74b..f239901 100644 --- a/vectordb_orm/base.py +++ b/vectordb_orm/base.py @@ -81,10 +81,10 @@ def from_dict(cls, data: dict): :returns: A MilvusBase object. :raises ValueError: If an unexpected attribute name is encountered in the dictionary. """ - obj = cls() + init_payload = {} allowed_keys = list(cls.__annotations__.keys()) for attribute_name, value in data.items(): if attribute_name not in allowed_keys: raise ValueError(f"Key `{attribute_name}` is not allowed on {cls.__name__}") - setattr(obj, attribute_name, value) - return obj + init_payload[attribute_name] = value + return cls(**init_payload) diff --git a/vectordb_orm/tests/test_base.py b/vectordb_orm/tests/test_base.py index 25cbe47..d0bfe1a 100644 --- a/vectordb_orm/tests/test_base.py +++ b/vectordb_orm/tests/test_base.py @@ -17,6 +17,19 @@ def test_create_object(session: str, model: Type[VectorSchemaBase]): assert my_object.id is None +@pytest.mark.parametrize("session,model", SESSION_MODEL_PAIRS) +def test_from_dict(session: str, model: Type[VectorSchemaBase]): + my_object = model.from_dict( + { + "text": 'example', + "embedding": np.array([1.0] * 128) + } + ) + assert my_object.text == 'example' + assert np.array_equal(my_object.embedding, np.array([1.0] * 128)) + assert my_object.id is None + + @pytest.mark.parametrize("session,model", SESSION_MODEL_PAIRS) def test_insert_object(session: str, model: Type[VectorSchemaBase], request): session : VectorSession = request.getfixturevalue(session)