From 8e3b2b25dbbd43ee50df59bf3b4cbf34590dc6f6 Mon Sep 17 00:00:00 2001 From: anthony sottile Date: Mon, 17 Jun 2024 14:24:26 -0400 Subject: [PATCH] ref: add mypy plugin to fix referencing .objects through a TypeVar --- .../hybridcloud/models/webhookpayload.py | 4 +- tests/tools/mypy_helpers/test_plugin.py | 75 +++++++++++++++++++ tools/mypy_helpers/plugin.py | 25 +++++++ 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/src/sentry/hybridcloud/models/webhookpayload.py b/src/sentry/hybridcloud/models/webhookpayload.py index 284b4659f2dab6..76abaf2a63aee2 100644 --- a/src/sentry/hybridcloud/models/webhookpayload.py +++ b/src/sentry/hybridcloud/models/webhookpayload.py @@ -1,7 +1,7 @@ from __future__ import annotations import datetime -from typing import Any +from typing import Any, Self from django.db import models from django.http import HttpRequest @@ -79,7 +79,7 @@ def create_from_request( identifier: int | str, request: HttpRequest, integration_id: int | None = None, - ) -> WebhookPayload: + ) -> Self: metrics.incr("hybridcloud.deliver_webhooks.saved") return cls.objects.create( mailbox_name=f"{provider}:{identifier}", diff --git a/tests/tools/mypy_helpers/test_plugin.py b/tests/tools/mypy_helpers/test_plugin.py index f7eaa93e045d93..42af5e0fb740a2 100644 --- a/tests/tools/mypy_helpers/test_plugin.py +++ b/tests/tools/mypy_helpers/test_plugin.py @@ -234,3 +234,78 @@ def _mypy() -> tuple[int, str]: cfg.write_text('[tool.mypy]\nplugins = ["tools.mypy_helpers.plugin"]\n') ret, out = _mypy() assert ret == 0 + + +def test_resolution_of_objects_across_typevar(tmp_path: pathlib.Path) -> None: + src = """\ +from typing import assert_type, TypeVar + +from sentry.db.models.base import Model + +M = TypeVar("M", bound=Model, covariant=True) + +def f(m: type[M]) -> M: + return m.objects.get() + +class C(Model): pass + +assert_type(f(C), C) +""" + expected = """\ +:8: error: Incompatible return value type (got "Model", expected "M") [return-value] +Found 1 error in 1 file (checked 1 source file) +""" + + # tools tests aren't allowed to import from `sentry` so we fixture + # the particular source file we are testing + models_dir = tmp_path.joinpath("sentry/db/models") + models_dir.mkdir(parents=True) + + models_base_src = """\ +from typing import ClassVar, Self + +from .manager.base import BaseManager + +class Model: + objects: ClassVar[BaseManager[Self]] +""" + models_dir.joinpath("base.pyi").write_text(models_base_src) + + manager_dir = models_dir.joinpath("manager") + manager_dir.mkdir(parents=True) + + manager_base_src = """\ +from typing import Generic, TypeVar + +M = TypeVar("M") + +class BaseManager(Generic[M]): + def get(self) -> M: ... + """ + manager_dir.joinpath("base.pyi").write_text(manager_base_src) + + cfg = tmp_path.joinpath("mypy.toml") + cfg.write_text("[tool.mypy]\nplugins = []\n") + + # can't use our helper above because we're fixturing sentry src, so mimic it here + def _mypy() -> tuple[int, str]: + ret = subprocess.run( + ( + *(sys.executable, "-m", "mypy"), + *("--config", cfg), + *("-c", src), + ), + env={**os.environ, "MYPYPATH": str(tmp_path)}, + capture_output=True, + encoding="UTF-8", + ) + assert not ret.stderr + return ret.returncode, ret.stdout + + ret, out = _mypy() + assert ret + assert out == expected + + cfg.write_text('[tool.mypy]\nplugins = ["tools.mypy_helpers.plugin"]\n') + ret, out = _mypy() + assert ret == 0 diff --git a/tools/mypy_helpers/plugin.py b/tools/mypy_helpers/plugin.py index 9ee797a4541770..8bd4214e3ce82e 100644 --- a/tools/mypy_helpers/plugin.py +++ b/tools/mypy_helpers/plugin.py @@ -17,6 +17,8 @@ NoneType, Type, TypeOfAny, + TypeType, + TypeVarType, UnionType, ) @@ -114,6 +116,23 @@ def _lazy_service_wrapper_attribute(ctx: AttributeContext, *, attr: str) -> Type return member +def _resolve_objects_for_typevars(ctx: AttributeContext) -> Type: + # XXX: hack around python/mypy#17395 + + # self: type[] + # default_attr_type: BaseManager[ConcreteTypeVarBound] + if ( + isinstance(ctx.type, TypeType) + and isinstance(ctx.type.item, TypeVarType) + and isinstance(ctx.default_attr_type, Instance) + and ctx.default_attr_type.type.fullname == "sentry.db.models.manager.base.BaseManager" + ): + tvar = ctx.type.item + return ctx.default_attr_type.copy_modified(args=(tvar,)) + else: + return ctx.default_attr_type + + class SentryMypyPlugin(Plugin): def get_function_signature_hook( self, fullname: str @@ -127,6 +146,12 @@ def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None else: return None + def get_class_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None: + if fullname.startswith("sentry.") and fullname.endswith(".objects"): + return _resolve_objects_for_typevars + else: + return None + def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None: if fullname.startswith("sentry.utils.lazy_service_wrapper.LazyServiceWrapper."): _, attr = fullname.rsplit(".", 1)