Skip to content

Commit

Permalink
chore(deps)!: Pydantic V2 (#12)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: Moving to Pydantic V2

Co-authored-by: Zeke Zumbro <[email protected]>
Co-authored-by: Jeferson Daniel <[email protected]>
  • Loading branch information
3 people authored Sep 13, 2023
1 parent cda9b63 commit 561c122
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 58 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ build/
pydantic_mongo.egg-info/
var/
.pytest_cache/
.idea/
17 changes: 10 additions & 7 deletions pydantic_mongo/abstract_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,23 @@ def get_collection(self) -> Collection:
def __validate(self):
if not issubclass(self.__document_class, BaseModel):
raise Exception("Document class should inherit BaseModel")
if "id" not in self.__document_class.__fields__:
if "id" not in self.__document_class.model_fields:
raise Exception("Document class should have id field")
if not self.__collection_name:
raise Exception("Meta should contain collection name")

def to_document(self, model: T) -> dict:
@staticmethod
def to_document(model: T) -> dict:
"""
Convert model to document
:param model:
:return: dict
"""
result = model.dict()
result.pop("id")
data = model.model_dump()
data.pop("id")
if model.id:
result["_id"] = model.id
return result
data["_id"] = model.id
return data

def __map_id(self, data: dict) -> dict:
query = data.copy()
Expand All @@ -92,7 +95,7 @@ def to_model_custom(self, output_type: Type[OutputT], data: dict) -> OutputT:
data_copy = data.copy()
if "_id" in data_copy:
data_copy["id"] = data_copy.pop("_id")
return output_type.parse_obj(data_copy)
return output_type.model_validate(data_copy)

def to_model(self, data: dict) -> T:
"""
Expand Down
29 changes: 22 additions & 7 deletions pydantic_mongo/fields.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
from typing import Any

from bson import ObjectId
from pydantic_core import core_schema


class ObjectIdField:
class ObjectIdField(str):
@classmethod
def __get_validators__(cls):
yield cls.validate
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
) -> core_schema.CoreSchema:
object_id_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(cls.validate),
]
)
return core_schema.json_or_python_schema(
json_schema=object_id_schema,
python_schema=core_schema.union_schema(
[core_schema.is_instance_schema(ObjectId), object_id_schema]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: str(x)
),
)

@classmethod
def validate(cls, value):
if not ObjectId.is_valid(value):
raise ValueError("Invalid id")

return ObjectId(value)

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(type="string")
9 changes: 2 additions & 7 deletions pydantic_mongo/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,17 @@
from typing import Any, Generic, List, TypeVar

import bson
from bson import ObjectId
from pydantic import BaseModel
from pydantic.generics import GenericModel

from .errors import PaginationError

DataT = TypeVar("DataT")


class Edge(GenericModel, Generic[DataT]):
class Edge(BaseModel, Generic[DataT]):
node: DataT
cursor: str

class Config:
json_encoders = {ObjectId: str}


def encode_pagination_cursor(data: List) -> str:
byte_data = bson.BSON.encode({"v": data})
Expand All @@ -37,7 +32,7 @@ def decode_pagination_cursor(data: str) -> List:


def get_pagination_cursor_payload(model: BaseModel, keys: List[str]) -> List[Any]:
model_dict = model.dict()
model_dict = model.model_dump()
model_dict["_id"] = model_dict["id"]

return [__evaluate_dot_notation(model_dict, key) for key in keys]
Expand Down
4 changes: 2 additions & 2 deletions requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ pytest==7.2.1
pytest-cov==4.0.0
pytest-mock==3.10.0
mongomock==4.1.2
pydantic==1.8.2
pydantic>=2.0.2
pymongo==4.3.3
mypy==0.991
mypy-extensions==0.4.3
black==23.3.0
isort==5.12.0
isort==5.11.5
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

from setuptools import setup

long_description = open("README.rst", "r").read()
Expand All @@ -15,7 +16,7 @@
version=version,
packages=["pydantic_mongo"],
setup_requires=["wheel"],
install_requires=["pymongo>=4.3,<5.0", "pydantic>=1.6.2,<2.0.0"],
install_requires=["pymongo>=4.3,<5.0", "pydantic>=2.0.2,<3.0.0"],
entry_points={
"console_scripts": ["pydantic_mongo = pydantic_mongo.__main__:__main__"],
},
Expand Down
5 changes: 3 additions & 2 deletions test/test_enhance_meta.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field

from pydantic_mongo import AbstractRepository, ObjectIdField


class HamModel(BaseModel):
id: ObjectIdField = None
id: ObjectIdField = Field(default=None)
name: str


Expand Down
19 changes: 8 additions & 11 deletions test/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,30 @@
import pytest
from bson import ObjectId
from pydantic import BaseModel
from pydantic.error_wrappers import ValidationError
from pydantic import BaseModel, ValidationError

from pydantic_mongo import ObjectIdField


class User(BaseModel):
id: ObjectIdField = None

class Config:
json_encoders = {ObjectId: str}


class TestFields:
def test_object_id_validation(self):
with pytest.raises(ValidationError):
User.parse_obj({"id": "lala"})
User.parse_obj({"id": "611827f2878b88b49ebb69fc"})
User.model_validate({"id": "lala"})
User.model_validate({"id": "611827f2878b88b49ebb69fc"})

def test_object_id_serialize(self):
lala = User(id=ObjectId("611827f2878b88b49ebb69fc"))
json_result = lala.json()
assert '{"id": "611827f2878b88b49ebb69fc"}' == json_result
json_result = lala.model_dump_json()
assert '{"id":"611827f2878b88b49ebb69fc"}' == json_result

def test_modify_schema(self):
user = User(id=ObjectId("611827f2878b88b49ebb69fc"))
schema = user.schema()
schema = user.model_json_schema()
assert {
"title": "User",
"type": "object",
"properties": {"id": {"title": "Id", "type": "string"}},
"properties": {"id": {"default": None, "title": "Id", "type": "string"}},
} == schema
33 changes: 18 additions & 15 deletions test/test_pagination.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import datetime
from typing import List

from bson import ObjectId
from pydantic_mongo.pagination import encode_pagination_cursor, decode_pagination_cursor, get_pagination_cursor_payload
from pydantic import BaseModel
from typing import List

from pydantic_mongo.pagination import (
decode_pagination_cursor,
encode_pagination_cursor,
get_pagination_cursor_payload,
)


class Foo(BaseModel):
Expand All @@ -11,35 +17,32 @@ class Foo(BaseModel):


class Bar(BaseModel):
apple = 'x'
banana = 'y'
apple: str = "x"
banana: str = "y"


class Spam(BaseModel):
id: str = None
foo: Foo
bars: List[Bar]

class Config:
json_encoders = {ObjectId: str}


class TestPagination:
def test_get_pagination_cursor_payload(self):
spam = Spam(id='lala', foo=Foo(count=1, size=1.0), bars=[Bar()])
spam = Spam(id="lala", foo=Foo(count=1, size=1.0), bars=[Bar()])

values = get_pagination_cursor_payload(spam, ['_id', 'id'])
assert values[0] == 'lala'
assert values[1] == 'lala'
values = get_pagination_cursor_payload(spam, ["_id", "id"])
assert values[0] == "lala"
assert values[1] == "lala"

values = get_pagination_cursor_payload(spam, ['foo.count'])
values = get_pagination_cursor_payload(spam, ["foo.count"])
assert values[0] == 1

values = get_pagination_cursor_payload(spam, ['bars.0.apple'])
assert values[0] == 'x'
values = get_pagination_cursor_payload(spam, ["bars.0.apple"])
assert values[0] == "x"

def test_cursor_encoding(self):
old_value = [ObjectId('611b158adec89d18984b7d90'), 'a', 1]
old_value = [ObjectId("611b158adec89d18984b7d90"), "a", 1]
cursor = encode_pagination_cursor(old_value)
new_value = decode_pagination_cursor(cursor)
assert old_value == new_value
10 changes: 4 additions & 6 deletions test/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import mongomock
import pytest
from bson import ObjectId
from pydantic import BaseModel
from pydantic import BaseModel, Field

from pydantic_mongo import AbstractRepository, ObjectIdField
from pydantic_mongo.errors import PaginationError

Expand All @@ -14,18 +15,15 @@ class Foo(BaseModel):


class Bar(BaseModel):
apple = "x"
banana = "y"
apple: str = Field(default="x")
banana: str = Field(default="y")


class Spam(BaseModel):
id: ObjectIdField = None
foo: Foo
bars: List[Bar]

class Config:
json_encoders = {ObjectId: str}


class SpamRepository(AbstractRepository[Spam]):
class Meta:
Expand Down

0 comments on commit 561c122

Please sign in to comment.