Skip to content

Commit

Permalink
Transform user_passes_test signature
Browse files Browse the repository at this point in the history
Add a transformer that updates the signature of user_passes_test to
reflect settings.AUTH_USER_MODEL.

Fixes typeddjango#1058
  • Loading branch information
ljodal committed Nov 3, 2022
1 parent e88f942 commit 6c372f1
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ClassDefContext,
DynamicClassDefContext,
FunctionContext,
FunctionSigContext,
MethodContext,
SemanticAnalyzerPluginInterface,
)
Expand Down Expand Up @@ -328,7 +329,9 @@ def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> Sema
return ctx.api


def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
def get_typechecker_api(
ctx: Union[AttributeContext, MethodContext, FunctionContext, FunctionSigContext]
) -> TypeChecker:
if not isinstance(ctx.api, TypeChecker):
raise ValueError("Not a TypeChecker")
return ctx.api
Expand Down
9 changes: 9 additions & 0 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
ClassDefContext,
DynamicClassDefContext,
FunctionContext,
FunctionSigContext,
MethodContext,
Plugin,
)
from mypy.types import FunctionLike
from mypy.types import Type as MypyType

import mypy_django_plugin.transformers.orm_lookups
from mypy_django_plugin.config import DjangoPluginConfig
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields, forms, init_create, meta, querysets, request, settings
from mypy_django_plugin.transformers.auth import transform_user_passes_test
from mypy_django_plugin.transformers.functional import resolve_str_promise_attribute
from mypy_django_plugin.transformers.managers import (
create_new_manager_class_from_as_manager_method,
Expand Down Expand Up @@ -323,6 +326,12 @@ def get_dynamic_class_hook(self, fullname: str) -> Optional[Callable[[DynamicCla
return create_new_manager_class_from_as_manager_method
return None

def get_function_signature_hook(self, fullname: str) -> Optional[Callable[[FunctionSigContext], FunctionLike]]:
if fullname == "django.contrib.auth.decorators.user_passes_test":
return partial(transform_user_passes_test, django_context=self.django_context)

return None


def plugin(version: str) -> Type[NewSemanalDjangoPlugin]:
return NewSemanalDjangoPlugin
40 changes: 40 additions & 0 deletions mypy_django_plugin/transformers/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from mypy.plugin import FunctionSigContext
from mypy.types import CallableType, FunctionLike, Instance, UnionType

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers


def transform_user_passes_test(ctx: FunctionSigContext, django_context: DjangoContext) -> FunctionLike:
"""
Update the signature of user_passes_test to reflect settings.AUTH_USER_MODEL
"""

auth_user_model = django_context.settings.AUTH_USER_MODEL
try:
model_cls = django_context.apps_registry.get_model(auth_user_model)
except LookupError:
return ctx.default_signature
model_cls_fullname = helpers.get_class_fullname(model_cls)
user_model_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), model_cls_fullname)
if user_model_info is None:
return ctx.default_signature

if not ctx.default_signature.arg_types or not isinstance(ctx.default_signature.arg_types[0], CallableType):
return ctx.default_signature

test_func_type = ctx.default_signature.arg_types[0]

if not test_func_type.arg_types or not isinstance(test_func_type.arg_types[0], UnionType):
return ctx.default_signature
union = test_func_type.arg_types[0]

new_union = UnionType([Instance(user_model_info, [])] + union.items[1:])

new_test_func_type = test_func_type.copy_modified(
arg_types=[new_union] + test_func_type.arg_types[1:],
)

return ctx.default_signature.copy_modified(
arg_types=[new_test_func_type] + ctx.default_signature.arg_types[1:],
)
22 changes: 22 additions & 0 deletions tests/typecheck/contrib/auth/test_decorators.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@
@user_passes_test(lambda u: u.get_username().startswith('super'))
def view_func(request: HttpRequest) -> HttpResponse: ...
reveal_type(view_func) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
- case: user_passes_test_custom_user_model
main: |
from typing import Union
from django.contrib.auth.decorators import user_passes_test
from django.contrib.auth.models import AnonymousUser
from django.http import HttpRequest, HttpResponse
from users.models import User
def check_user(user: Union[User, AnonymousUser]) -> bool: ...
@user_passes_test(check_user)
def view_func(request: HttpRequest) -> HttpResponse: ...
custom_settings: |
AUTH_USER_MODEL = "users.User"
INSTALLED_APPS = ("django.contrib.auth", "django.contrib.contenttypes", "users")
files:
- path: users/__init__.py
- path: users/models.py
content: |
from django.contrib.auth.models import AbstractBaseUser
from django.db import models
class User(AbstractBaseUser):
email = models.EmailField(unique=True)
USERNAME_FIELD = "email"
- case: user_passes_test_bare_is_error
main: |
from django.http import HttpRequest, HttpResponse
Expand Down

0 comments on commit 6c372f1

Please sign in to comment.