Skip to content

Commit

Permalink
fix: More work on typing support; add black and ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
last-partizan authored Nov 3, 2023
1 parent bc57439 commit b9a0e0e
Show file tree
Hide file tree
Showing 12 changed files with 433 additions and 60 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,24 @@ jobs:
- 27017:27017
steps:
- uses: actions/checkout@v3
- name: Install poetry
run: pipx install poetry
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
cache: 'poetry'
- name: Cache virtualenv
uses: actions/cache@v3
with:
key: venv-${{ runner.os }}-${{ steps.setup_python.outputs.python-version}}-${{ hashFiles('poetry.lock') }}
path: .venv
- name: Set up env
run: |
python -m pip install -U -q poetry pip
poetry install -q
poetry run pip install -q "${{ matrix.django }}"
- name: Run tests
run: |
poetry run ruff .
poetry run black --check .
poetry run python -m pytest
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@ publish:

test:
poetry run python -m pytest

codegen:
python codegen.py
black django_mongoengine/fields/__init__.py
ruff django_mongoengine/ --fix # It doesn't work with filename.
31 changes: 31 additions & 0 deletions codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
def generate_fields():
"""
Typing support cannot handle monkey-patching at runtime, so we need to generate fields explicitly.
"""
from mongoengine import fields
from django_mongoengine.fields import djangoflavor as mixins

fields_code = str(_fields)
for fname in fields.__all__:
mixin_name = fname if hasattr(mixins, fname) else "DjangoField"
fields_code += f"class {fname}(_mixins.{mixin_name}, _fields.{fname}):\n pass\n"

return fields_code


_fields = """
from mongoengine import fields as _fields
from . import djangoflavor as _mixins
from django_mongoengine.utils.monkey import patch_mongoengine_field
for f in ["StringField", "ObjectIdField"]:
patch_mongoengine_field(f)
"""

if __name__ == "__main__":
fname = "django_mongoengine/fields/__init__.py"
# This content required, because otherwise mixins import does not work.
open(fname, "w").write("from mongoengine.fields import *")
content = generate_fields()
open(fname, "w").write(content)
16 changes: 10 additions & 6 deletions django_mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from bson.objectid import ObjectId
from django.db.models import Model
Expand All @@ -11,11 +11,15 @@
from mongoengine import document as me
from mongoengine.base import metaclasses as mtc
from mongoengine.errors import FieldDoesNotExist
from typing_extensions import Self

from .fields import ObjectIdField
from .forms.document_options import DocumentMetaWrapper
from .queryset import QuerySetManager

if TYPE_CHECKING:
from mongoengine.fields import StringField

# TopLevelDocumentMetaclass is using ObjectIdField to create default pk field,
# if one's not set explicitly.
# We need to know it's not editable and auto_created.
Expand Down Expand Up @@ -43,11 +47,11 @@ def __new__(cls, name, bases, attrs):


class DjangoFlavor:
id: Any
objects: Any = QuerySetManager()
_meta: DocumentMetaWrapper
_default_manager: Any = QuerySetManager()
id: StringField
objects = QuerySetManager[Self]()
_default_manager = QuerySetManager[Self]()
_get_pk_val = Model.__dict__["_get_pk_val"]
_meta: DocumentMetaWrapper
DoesNotExist: type[DoesNotExist]

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -115,7 +119,7 @@ class DynamicDocument(DjangoFlavor, me.DynamicDocument):
...

class EmbeddedDocument(DjangoFlavor, me.EmbeddedDocument):
...
_instance: Document

class DynamicEmbeddedDocument(DjangoFlavor, me.DynamicEmbeddedDocument):
...
203 changes: 171 additions & 32 deletions django_mongoengine/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,182 @@
from . import djangoflavor
from mongoengine import fields as _fields
from . import djangoflavor as _mixins
from django_mongoengine.utils.monkey import patch_mongoengine_field

for f in ["StringField", "ObjectIdField"]:
patch_mongoengine_field(f)

def init_module():
"""
Create classes with Django-flavor mixins,
use DjangoField mixin as default
"""
import sys

from mongoengine import fields
class StringField(_mixins.StringField, _fields.StringField):
pass

current_module = sys.modules[__name__]
current_module.__all__ = fields.__all__

for name in fields.__all__:
fieldcls = getattr(fields, name)
mixin = getattr(djangoflavor, name, djangoflavor.DjangoField)
setattr(
current_module,
name,
type(name, (mixin, fieldcls), {}),
)
class URLField(_mixins.URLField, _fields.URLField):
pass


def patch_mongoengine_field(field_name):
"""
patch mongoengine.[field_name] for comparison support
becouse it's required in django.forms.models.fields_for_model
importing using mongoengine internal import cache
"""
from mongoengine import common
class EmailField(_mixins.EmailField, _fields.EmailField):
pass

field = common._import_class(field_name)
for k in ["__eq__", "__lt__", "__hash__", "attname", "get_internal_type"]:
if k not in field.__dict__:
setattr(field, k, djangoflavor.DjangoField.__dict__[k])

class IntField(_mixins.IntField, _fields.IntField):
pass

init_module()

for f in ["StringField", "ObjectIdField"]:
patch_mongoengine_field(f)
class LongField(_mixins.DjangoField, _fields.LongField):
pass


class FloatField(_mixins.FloatField, _fields.FloatField):
pass


class DecimalField(_mixins.DecimalField, _fields.DecimalField):
pass


class BooleanField(_mixins.BooleanField, _fields.BooleanField):
pass


class DateTimeField(_mixins.DateTimeField, _fields.DateTimeField):
pass


class DateField(_mixins.DjangoField, _fields.DateField):
pass


class ComplexDateTimeField(_mixins.DjangoField, _fields.ComplexDateTimeField):
pass


class EmbeddedDocumentField(_mixins.EmbeddedDocumentField, _fields.EmbeddedDocumentField):
pass


class ObjectIdField(_mixins.DjangoField, _fields.ObjectIdField):
pass


class GenericEmbeddedDocumentField(_mixins.DjangoField, _fields.GenericEmbeddedDocumentField):
pass


class DynamicField(_mixins.DjangoField, _fields.DynamicField):
pass


class ListField(_mixins.ListField, _fields.ListField):
pass


class SortedListField(_mixins.DjangoField, _fields.SortedListField):
pass


class EmbeddedDocumentListField(_mixins.DjangoField, _fields.EmbeddedDocumentListField):
pass


class DictField(_mixins.DictField, _fields.DictField):
pass


class MapField(_mixins.DjangoField, _fields.MapField):
pass


class ReferenceField(_mixins.ReferenceField, _fields.ReferenceField):
pass


class CachedReferenceField(_mixins.DjangoField, _fields.CachedReferenceField):
pass


class LazyReferenceField(_mixins.DjangoField, _fields.LazyReferenceField):
pass


class GenericLazyReferenceField(_mixins.DjangoField, _fields.GenericLazyReferenceField):
pass


class GenericReferenceField(_mixins.DjangoField, _fields.GenericReferenceField):
pass


class BinaryField(_mixins.DjangoField, _fields.BinaryField):
pass


class GridFSError(_mixins.DjangoField, _fields.GridFSError):
pass


class GridFSProxy(_mixins.DjangoField, _fields.GridFSProxy):
pass


class FileField(_mixins.FileField, _fields.FileField):
pass


class ImageGridFsProxy(_mixins.DjangoField, _fields.ImageGridFsProxy):
pass


class ImproperlyConfigured(_mixins.ImproperlyConfigured, _fields.ImproperlyConfigured):
pass


class ImageField(_mixins.ImageField, _fields.ImageField):
pass


class GeoPointField(_mixins.DjangoField, _fields.GeoPointField):
pass


class PointField(_mixins.DjangoField, _fields.PointField):
pass


class LineStringField(_mixins.DjangoField, _fields.LineStringField):
pass


class PolygonField(_mixins.DjangoField, _fields.PolygonField):
pass


class SequenceField(_mixins.DjangoField, _fields.SequenceField):
pass


class UUIDField(_mixins.DjangoField, _fields.UUIDField):
pass


class EnumField(_mixins.DjangoField, _fields.EnumField):
pass


class MultiPointField(_mixins.DjangoField, _fields.MultiPointField):
pass


class MultiLineStringField(_mixins.DjangoField, _fields.MultiLineStringField):
pass


class MultiPolygonField(_mixins.DjangoField, _fields.MultiPolygonField):
pass


class GeoJsonBaseField(_mixins.DjangoField, _fields.GeoJsonBaseField):
pass


class Decimal128Field(_mixins.DjangoField, _fields.Decimal128Field):
pass
2 changes: 1 addition & 1 deletion django_mongoengine/forms/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, form, *args, **kwargs):
kwargs['widget'] = EmbeddedFieldWidget(self.form.fields)
kwargs['initial'] = [f.initial for f in self.form.fields.values()]
kwargs['require_all_fields'] = False
super().__init__(fields=tuple([f for f in self.form.fields.values()]), *args, **kwargs)
super().__init__(fields=tuple(self.form.fields.values()), *args, **kwargs)

def bound_data(self, data, initial):
return data
Expand Down
Loading

0 comments on commit b9a0e0e

Please sign in to comment.