Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(deps)!: Pydantic V2 #12

Merged
merged 7 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ build/
pydantic_mongo.egg-info/
var/
.pytest_cache/
.idea/
17 changes: 10 additions & 7 deletions pydantic_mongo/abstract_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,23 @@ def get_collection(self) -> Collection:
def __validate(self):
if not issubclass(self.__document_class, BaseModel):
raise Exception("Document class should inherit BaseModel")
if "id" not in self.__document_class.__fields__:
if "id" not in self.__document_class.model_fields:
raise Exception("Document class should have id field")
if not self.__collection_name:
raise Exception("Meta should contain collection name")

def to_document(self, model: T) -> dict:
@staticmethod
def to_document(model: T) -> dict:
"""
Convert model to document
:param model:
:return: dict
"""
result = model.dict()
result.pop("id")
data = model.model_dump()
data.pop("id")
if model.id:
result["_id"] = model.id
return result
data["_id"] = model.id
return data

def __map_id(self, data: dict) -> dict:
query = data.copy()
Expand All @@ -92,7 +95,7 @@ def to_model_custom(self, output_type: Type[OutputT], data: dict) -> OutputT:
data_copy = data.copy()
if "_id" in data_copy:
data_copy["id"] = data_copy.pop("_id")
return output_type.parse_obj(data_copy)
return output_type.model_validate(data_copy)

def to_model(self, data: dict) -> T:
"""
Expand Down
29 changes: 22 additions & 7 deletions pydantic_mongo/fields.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
from typing import Any

from bson import ObjectId
from pydantic_core import core_schema


class ObjectIdField:
class ObjectIdField(str):
@classmethod
def __get_validators__(cls):
yield cls.validate
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
) -> core_schema.CoreSchema:
object_id_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(cls.validate),
]
)
return core_schema.json_or_python_schema(
json_schema=object_id_schema,
python_schema=core_schema.union_schema(
[core_schema.is_instance_schema(ObjectId), object_id_schema]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: str(x)
),
)

@classmethod
def validate(cls, value):
if not ObjectId.is_valid(value):
raise ValueError("Invalid id")

return ObjectId(value)

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(type="string")
9 changes: 2 additions & 7 deletions pydantic_mongo/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,17 @@
from typing import Any, Generic, List, TypeVar

import bson
from bson import ObjectId
from pydantic import BaseModel
from pydantic.generics import GenericModel

from .errors import PaginationError

DataT = TypeVar("DataT")


class Edge(GenericModel, Generic[DataT]):
class Edge(BaseModel, Generic[DataT]):
node: DataT
cursor: str

class Config:
json_encoders = {ObjectId: str}


def encode_pagination_cursor(data: List) -> str:
byte_data = bson.BSON.encode({"v": data})
Expand All @@ -37,7 +32,7 @@ def decode_pagination_cursor(data: str) -> List:


def get_pagination_cursor_payload(model: BaseModel, keys: List[str]) -> List[Any]:
model_dict = model.dict()
model_dict = model.model_dump()
model_dict["_id"] = model_dict["id"]

return [__evaluate_dot_notation(model_dict, key) for key in keys]
Expand Down
4 changes: 2 additions & 2 deletions requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ pytest==7.2.1
pytest-cov==4.0.0
pytest-mock==3.10.0
mongomock==4.1.2
pydantic==1.8.2
pydantic>=2.0.2
pymongo==4.3.3
mypy==0.991
mypy-extensions==0.4.3
black==23.3.0
isort==5.12.0
isort
12 changes: 4 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

from setuptools import setup

long_description = open("README.rst", "r").read()
Expand All @@ -15,14 +16,9 @@
version=version,
packages=["pydantic_mongo"],
setup_requires=["wheel"],
install_requires=[
"pymongo>=4.3,<5.0",
"pydantic>=1.6.2,<2.0.0"
],
install_requires=["pymongo>=4.3,<5.0", "pydantic>=2.0.2"],
entry_points={
"console_scripts": [
"pydantic_mongo = pydantic_mongo.__main__:__main__"
],
"console_scripts": ["pydantic_mongo = pydantic_mongo.__main__:__main__"],
},
description="Document object mapper for pydantic and pymongo",
long_description=long_description,
Expand All @@ -37,5 +33,5 @@
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
python_requires=">=3.7"
python_requires=">=3.7",
)
5 changes: 3 additions & 2 deletions test/test_enhance_meta.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field

from pydantic_mongo import AbstractRepository, ObjectIdField


class HamModel(BaseModel):
id: ObjectIdField = None
id: ObjectIdField = Field(default=None)
name: str


Expand Down
19 changes: 8 additions & 11 deletions test/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,30 @@
import pytest
from bson import ObjectId
from pydantic import BaseModel
from pydantic.error_wrappers import ValidationError
from pydantic import BaseModel, ValidationError

from pydantic_mongo import ObjectIdField


class User(BaseModel):
id: ObjectIdField = None

class Config:
json_encoders = {ObjectId: str}


class TestFields:
def test_object_id_validation(self):
with pytest.raises(ValidationError):
User.parse_obj({"id": "lala"})
User.parse_obj({"id": "611827f2878b88b49ebb69fc"})
User.model_validate({"id": "lala"})
User.model_validate({"id": "611827f2878b88b49ebb69fc"})

def test_object_id_serialize(self):
lala = User(id=ObjectId("611827f2878b88b49ebb69fc"))
json_result = lala.json()
assert '{"id": "611827f2878b88b49ebb69fc"}' == json_result
json_result = lala.model_dump_json()
assert '{"id":"611827f2878b88b49ebb69fc"}' == json_result

def test_modify_schema(self):
user = User(id=ObjectId("611827f2878b88b49ebb69fc"))
schema = user.schema()
schema = user.model_json_schema()
assert {
"title": "User",
"type": "object",
"properties": {"id": {"title": "Id", "type": "string"}},
"properties": {"id": {"default": None, "title": "Id", "type": "string"}},
} == schema
33 changes: 18 additions & 15 deletions test/test_pagination.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import datetime
from typing import List

from bson import ObjectId
from pydantic_mongo.pagination import encode_pagination_cursor, decode_pagination_cursor, get_pagination_cursor_payload
from pydantic import BaseModel
from typing import List

from pydantic_mongo.pagination import (
decode_pagination_cursor,
encode_pagination_cursor,
get_pagination_cursor_payload,
)


class Foo(BaseModel):
Expand All @@ -11,35 +17,32 @@ class Foo(BaseModel):


class Bar(BaseModel):
apple = 'x'
banana = 'y'
apple: str = "x"
banana: str = "y"


class Spam(BaseModel):
id: str = None
foo: Foo
bars: List[Bar]

class Config:
json_encoders = {ObjectId: str}


class TestPagination:
def test_get_pagination_cursor_payload(self):
spam = Spam(id='lala', foo=Foo(count=1, size=1.0), bars=[Bar()])
spam = Spam(id="lala", foo=Foo(count=1, size=1.0), bars=[Bar()])

values = get_pagination_cursor_payload(spam, ['_id', 'id'])
assert values[0] == 'lala'
assert values[1] == 'lala'
values = get_pagination_cursor_payload(spam, ["_id", "id"])
assert values[0] == "lala"
assert values[1] == "lala"

values = get_pagination_cursor_payload(spam, ['foo.count'])
values = get_pagination_cursor_payload(spam, ["foo.count"])
assert values[0] == 1

values = get_pagination_cursor_payload(spam, ['bars.0.apple'])
assert values[0] == 'x'
values = get_pagination_cursor_payload(spam, ["bars.0.apple"])
assert values[0] == "x"

def test_cursor_encoding(self):
old_value = [ObjectId('611b158adec89d18984b7d90'), 'a', 1]
old_value = [ObjectId("611b158adec89d18984b7d90"), "a", 1]
cursor = encode_pagination_cursor(old_value)
new_value = decode_pagination_cursor(cursor)
assert old_value == new_value
10 changes: 4 additions & 6 deletions test/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import mongomock
import pytest
from bson import ObjectId
from pydantic import BaseModel
from pydantic import BaseModel, Field

from pydantic_mongo import AbstractRepository, ObjectIdField
from pydantic_mongo.errors import PaginationError

Expand All @@ -14,18 +15,15 @@ class Foo(BaseModel):


class Bar(BaseModel):
apple = "x"
banana = "y"
apple: str = Field(default="x")
banana: str = Field(default="y")


class Spam(BaseModel):
id: ObjectIdField = None
foo: Foo
bars: List[Bar]

class Config:
json_encoders = {ObjectId: str}


class SpamRepository(AbstractRepository[Spam]):
class Meta:
Expand Down