From 561c12277f1771bdaef52d4a1ef66ea9c6721326 Mon Sep 17 00:00:00 2001 From: Zeke Zumbro <33703690+jezumbro@users.noreply.github.com> Date: Tue, 12 Sep 2023 22:17:30 -0500 Subject: [PATCH] chore(deps)!: Pydantic V2 (#12) BREAKING CHANGE: Moving to Pydantic V2 Co-authored-by: Zeke Zumbro Co-authored-by: Jeferson Daniel --- .gitignore | 1 + pydantic_mongo/abstract_repository.py | 17 ++++++++------ pydantic_mongo/fields.py | 29 +++++++++++++++++------ pydantic_mongo/pagination.py | 9 ++------ requirements_test.txt | 4 ++-- setup.py | 3 ++- test/test_enhance_meta.py | 5 ++-- test/test_fields.py | 19 +++++++-------- test/test_pagination.py | 33 +++++++++++++++------------ test/test_repository.py | 10 ++++---- 10 files changed, 72 insertions(+), 58 deletions(-) diff --git a/.gitignore b/.gitignore index fdf5aab..a4f9d18 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ build/ pydantic_mongo.egg-info/ var/ .pytest_cache/ +.idea/ \ No newline at end of file diff --git a/pydantic_mongo/abstract_repository.py b/pydantic_mongo/abstract_repository.py index 415ce71..b607a12 100644 --- a/pydantic_mongo/abstract_repository.py +++ b/pydantic_mongo/abstract_repository.py @@ -54,20 +54,23 @@ def get_collection(self) -> Collection: def __validate(self): if not issubclass(self.__document_class, BaseModel): raise Exception("Document class should inherit BaseModel") - if "id" not in self.__document_class.__fields__: + if "id" not in self.__document_class.model_fields: raise Exception("Document class should have id field") if not self.__collection_name: raise Exception("Meta should contain collection name") - def to_document(self, model: T) -> dict: + @staticmethod + def to_document(model: T) -> dict: """ Convert model to document + :param model: + :return: dict """ - result = model.dict() - result.pop("id") + data = model.model_dump() + data.pop("id") if model.id: - result["_id"] = model.id - return result + data["_id"] = model.id + return data def __map_id(self, data: dict) -> dict: query = data.copy() @@ -92,7 +95,7 @@ def to_model_custom(self, output_type: Type[OutputT], data: dict) -> OutputT: data_copy = data.copy() if "_id" in data_copy: data_copy["id"] = data_copy.pop("_id") - return output_type.parse_obj(data_copy) + return output_type.model_validate(data_copy) def to_model(self, data: dict) -> T: """ diff --git a/pydantic_mongo/fields.py b/pydantic_mongo/fields.py index 64c4b28..34484be 100644 --- a/pydantic_mongo/fields.py +++ b/pydantic_mongo/fields.py @@ -1,10 +1,29 @@ +from typing import Any + from bson import ObjectId +from pydantic_core import core_schema -class ObjectIdField: +class ObjectIdField(str): @classmethod - def __get_validators__(cls): - yield cls.validate + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: Any + ) -> core_schema.CoreSchema: + object_id_schema = core_schema.chain_schema( + [ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(cls.validate), + ] + ) + return core_schema.json_or_python_schema( + json_schema=object_id_schema, + python_schema=core_schema.union_schema( + [core_schema.is_instance_schema(ObjectId), object_id_schema] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda x: str(x) + ), + ) @classmethod def validate(cls, value): @@ -12,7 +31,3 @@ def validate(cls, value): raise ValueError("Invalid id") return ObjectId(value) - - @classmethod - def __modify_schema__(cls, field_schema): - field_schema.update(type="string") diff --git a/pydantic_mongo/pagination.py b/pydantic_mongo/pagination.py index 2b7de75..5d66752 100644 --- a/pydantic_mongo/pagination.py +++ b/pydantic_mongo/pagination.py @@ -3,22 +3,17 @@ from typing import Any, Generic, List, TypeVar import bson -from bson import ObjectId from pydantic import BaseModel -from pydantic.generics import GenericModel from .errors import PaginationError DataT = TypeVar("DataT") -class Edge(GenericModel, Generic[DataT]): +class Edge(BaseModel, Generic[DataT]): node: DataT cursor: str - class Config: - json_encoders = {ObjectId: str} - def encode_pagination_cursor(data: List) -> str: byte_data = bson.BSON.encode({"v": data}) @@ -37,7 +32,7 @@ def decode_pagination_cursor(data: str) -> List: def get_pagination_cursor_payload(model: BaseModel, keys: List[str]) -> List[Any]: - model_dict = model.dict() + model_dict = model.model_dump() model_dict["_id"] = model_dict["id"] return [__evaluate_dot_notation(model_dict, key) for key in keys] diff --git a/requirements_test.txt b/requirements_test.txt index efc6b56..4fd6847 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -6,9 +6,9 @@ pytest==7.2.1 pytest-cov==4.0.0 pytest-mock==3.10.0 mongomock==4.1.2 -pydantic==1.8.2 +pydantic>=2.0.2 pymongo==4.3.3 mypy==0.991 mypy-extensions==0.4.3 black==23.3.0 -isort==5.12.0 \ No newline at end of file +isort==5.11.5 \ No newline at end of file diff --git a/setup.py b/setup.py index fea222d..b5d0953 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import os + from setuptools import setup long_description = open("README.rst", "r").read() @@ -15,7 +16,7 @@ version=version, packages=["pydantic_mongo"], setup_requires=["wheel"], - install_requires=["pymongo>=4.3,<5.0", "pydantic>=1.6.2,<2.0.0"], + install_requires=["pymongo>=4.3,<5.0", "pydantic>=2.0.2,<3.0.0"], entry_points={ "console_scripts": ["pydantic_mongo = pydantic_mongo.__main__:__main__"], }, diff --git a/test/test_enhance_meta.py b/test/test_enhance_meta.py index 7bb7472..a993cfd 100644 --- a/test/test_enhance_meta.py +++ b/test/test_enhance_meta.py @@ -1,10 +1,11 @@ import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field + from pydantic_mongo import AbstractRepository, ObjectIdField class HamModel(BaseModel): - id: ObjectIdField = None + id: ObjectIdField = Field(default=None) name: str diff --git a/test/test_fields.py b/test/test_fields.py index f3a9b91..6a7043e 100644 --- a/test/test_fields.py +++ b/test/test_fields.py @@ -1,33 +1,30 @@ import pytest from bson import ObjectId -from pydantic import BaseModel -from pydantic.error_wrappers import ValidationError +from pydantic import BaseModel, ValidationError + from pydantic_mongo import ObjectIdField class User(BaseModel): id: ObjectIdField = None - class Config: - json_encoders = {ObjectId: str} - class TestFields: def test_object_id_validation(self): with pytest.raises(ValidationError): - User.parse_obj({"id": "lala"}) - User.parse_obj({"id": "611827f2878b88b49ebb69fc"}) + User.model_validate({"id": "lala"}) + User.model_validate({"id": "611827f2878b88b49ebb69fc"}) def test_object_id_serialize(self): lala = User(id=ObjectId("611827f2878b88b49ebb69fc")) - json_result = lala.json() - assert '{"id": "611827f2878b88b49ebb69fc"}' == json_result + json_result = lala.model_dump_json() + assert '{"id":"611827f2878b88b49ebb69fc"}' == json_result def test_modify_schema(self): user = User(id=ObjectId("611827f2878b88b49ebb69fc")) - schema = user.schema() + schema = user.model_json_schema() assert { "title": "User", "type": "object", - "properties": {"id": {"title": "Id", "type": "string"}}, + "properties": {"id": {"default": None, "title": "Id", "type": "string"}}, } == schema diff --git a/test/test_pagination.py b/test/test_pagination.py index 5a0a486..e539f04 100644 --- a/test/test_pagination.py +++ b/test/test_pagination.py @@ -1,8 +1,14 @@ import datetime +from typing import List + from bson import ObjectId -from pydantic_mongo.pagination import encode_pagination_cursor, decode_pagination_cursor, get_pagination_cursor_payload from pydantic import BaseModel -from typing import List + +from pydantic_mongo.pagination import ( + decode_pagination_cursor, + encode_pagination_cursor, + get_pagination_cursor_payload, +) class Foo(BaseModel): @@ -11,8 +17,8 @@ class Foo(BaseModel): class Bar(BaseModel): - apple = 'x' - banana = 'y' + apple: str = "x" + banana: str = "y" class Spam(BaseModel): @@ -20,26 +26,23 @@ class Spam(BaseModel): foo: Foo bars: List[Bar] - class Config: - json_encoders = {ObjectId: str} - class TestPagination: def test_get_pagination_cursor_payload(self): - spam = Spam(id='lala', foo=Foo(count=1, size=1.0), bars=[Bar()]) + spam = Spam(id="lala", foo=Foo(count=1, size=1.0), bars=[Bar()]) - values = get_pagination_cursor_payload(spam, ['_id', 'id']) - assert values[0] == 'lala' - assert values[1] == 'lala' + values = get_pagination_cursor_payload(spam, ["_id", "id"]) + assert values[0] == "lala" + assert values[1] == "lala" - values = get_pagination_cursor_payload(spam, ['foo.count']) + values = get_pagination_cursor_payload(spam, ["foo.count"]) assert values[0] == 1 - values = get_pagination_cursor_payload(spam, ['bars.0.apple']) - assert values[0] == 'x' + values = get_pagination_cursor_payload(spam, ["bars.0.apple"]) + assert values[0] == "x" def test_cursor_encoding(self): - old_value = [ObjectId('611b158adec89d18984b7d90'), 'a', 1] + old_value = [ObjectId("611b158adec89d18984b7d90"), "a", 1] cursor = encode_pagination_cursor(old_value) new_value = decode_pagination_cursor(cursor) assert old_value == new_value diff --git a/test/test_repository.py b/test/test_repository.py index fda5292..e0de312 100644 --- a/test/test_repository.py +++ b/test/test_repository.py @@ -3,7 +3,8 @@ import mongomock import pytest from bson import ObjectId -from pydantic import BaseModel +from pydantic import BaseModel, Field + from pydantic_mongo import AbstractRepository, ObjectIdField from pydantic_mongo.errors import PaginationError @@ -14,8 +15,8 @@ class Foo(BaseModel): class Bar(BaseModel): - apple = "x" - banana = "y" + apple: str = Field(default="x") + banana: str = Field(default="y") class Spam(BaseModel): @@ -23,9 +24,6 @@ class Spam(BaseModel): foo: Foo bars: List[Bar] - class Config: - json_encoders = {ObjectId: str} - class SpamRepository(AbstractRepository[Spam]): class Meta: