diff --git a/mypy_django_plugin/transformers/meta.py b/mypy_django_plugin/transformers/meta.py index 007d8e48d..05f56500d 100644 --- a/mypy_django_plugin/transformers/meta.py +++ b/mypy_django_plugin/transformers/meta.py @@ -5,7 +5,7 @@ from mypy.types import TypeOfAny from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import helpers def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType: @@ -20,21 +20,25 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: # Options instance assert isinstance(ctx.type, Instance) + # bail if list of generic params is empty + if len(ctx.type.args) == 0: + return ctx.default_return_type + model_type = ctx.type.args[0] if not isinstance(model_type, Instance): - return _get_field_instance(ctx, fullnames.FIELD_FULLNAME) + return ctx.default_return_type model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname()) if model_cls is None: - return _get_field_instance(ctx, fullnames.FIELD_FULLNAME) + return ctx.default_return_type field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name') if field_name_expr is None: - return _get_field_instance(ctx, fullnames.FIELD_FULLNAME) + return ctx.default_return_type field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context) if field_name is None: - return _get_field_instance(ctx, fullnames.FIELD_FULLNAME) + return ctx.default_return_type try: field = model_cls._meta.get_field(field_name) diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 3ba73b4ac..a524aaba2 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -3,6 +3,7 @@ from django.core.exceptions import FieldError from django.db.models.base import Model +from django.db.models.fields.related import RelatedField from mypy.newsemanal.typeanal import TypeAnalyser from mypy.nodes import Expression, NameExpr, TypeInfo from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext @@ -10,11 +11,19 @@ from mypy.types import Type as MypyType from mypy.types import TypeOfAny -from django.db.models.fields.related import RelatedField from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import fullnames, helpers +def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]: + for base_type in [queryset_type, *queryset_type.type.bases]: + if (len(base_type.args) + and isinstance(base_type.args[0], Instance) + and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)): + return base_type.args[0] + return None + + def determine_proper_manager_type(ctx: FunctionContext) -> MypyType: default_return_type = ctx.default_return_type assert isinstance(default_return_type, Instance) @@ -98,11 +107,10 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: assert isinstance(ctx.type, Instance) assert isinstance(ctx.default_return_type, Instance) - # bail if queryset of Any or other non-instances - if not isinstance(ctx.type.args[0], Instance): + model_type = _extract_model_type_from_queryset(ctx.type) + if model_type is None: return AnyType(TypeOfAny.from_omitted_generics) - model_type = ctx.type.args[0] model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname()) if model_cls is None: return ctx.default_return_type @@ -148,11 +156,10 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan assert isinstance(ctx.type, Instance) assert isinstance(ctx.default_return_type, Instance) - # if queryset of non-instance type - if not isinstance(ctx.type.args[0], Instance): + model_type = _extract_model_type_from_queryset(ctx.type) + if model_type is None: return AnyType(TypeOfAny.from_omitted_generics) - model_type = ctx.type.args[0] model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname()) if model_cls is None: return ctx.default_return_type diff --git a/setup.py b/setup.py index 118380131..220e0150d 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def find_stub_files(name: str) -> List[str]: setup( name="django-stubs", - version="1.0.1", + version="1.0.2", description='Mypy stubs for Django', long_description=readme, long_description_content_type='text/markdown', diff --git a/test-data/typecheck/managers/querysets/test_values_list.yml b/test-data/typecheck/managers/querysets/test_values_list.yml index 82018202a..94a124308 100644 --- a/test-data/typecheck/managers/querysets/test_values_list.yml +++ b/test-data/typecheck/managers/querysets/test_values_list.yml @@ -204,4 +204,22 @@ class Publisher(models.Model): pass class Blog(models.Model): - publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) \ No newline at end of file + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) + +- case: subclass_of_queryset_has_proper_typings_on_methods + main: | + from myapp.models import TransactionQuerySet + reveal_type(TransactionQuerySet()) # N: Revealed type is 'myapp.models.TransactionQuerySet' + reveal_type(TransactionQuerySet().values()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Transaction, TypedDict({'id': builtins.int, 'total': builtins.int})]' + reveal_type(TransactionQuerySet().values_list()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Transaction, Tuple[builtins.int, builtins.int]]' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class TransactionQuerySet(models.QuerySet['Transaction']): + pass + class Transaction(models.Model): + total = models.IntegerField()