Skip to content

Commit

Permalink
ref: add mypy plugin to fix referencing .objects through a TypeVar
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile-sentry committed Jun 17, 2024
1 parent 070a2b3 commit 8e3b2b2
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/sentry/hybridcloud/models/webhookpayload.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import datetime
from typing import Any
from typing import Any, Self

from django.db import models
from django.http import HttpRequest
Expand Down Expand Up @@ -79,7 +79,7 @@ def create_from_request(
identifier: int | str,
request: HttpRequest,
integration_id: int | None = None,
) -> WebhookPayload:
) -> Self:
metrics.incr("hybridcloud.deliver_webhooks.saved")
return cls.objects.create(
mailbox_name=f"{provider}:{identifier}",
Expand Down
75 changes: 75 additions & 0 deletions tests/tools/mypy_helpers/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,78 @@ def _mypy() -> tuple[int, str]:
cfg.write_text('[tool.mypy]\nplugins = ["tools.mypy_helpers.plugin"]\n')
ret, out = _mypy()
assert ret == 0


def test_resolution_of_objects_across_typevar(tmp_path: pathlib.Path) -> None:
src = """\
from typing import assert_type, TypeVar
from sentry.db.models.base import Model
M = TypeVar("M", bound=Model, covariant=True)
def f(m: type[M]) -> M:
return m.objects.get()
class C(Model): pass
assert_type(f(C), C)
"""
expected = """\
<string>:8: error: Incompatible return value type (got "Model", expected "M") [return-value]
Found 1 error in 1 file (checked 1 source file)
"""

# tools tests aren't allowed to import from `sentry` so we fixture
# the particular source file we are testing
models_dir = tmp_path.joinpath("sentry/db/models")
models_dir.mkdir(parents=True)

models_base_src = """\
from typing import ClassVar, Self
from .manager.base import BaseManager
class Model:
objects: ClassVar[BaseManager[Self]]
"""
models_dir.joinpath("base.pyi").write_text(models_base_src)

manager_dir = models_dir.joinpath("manager")
manager_dir.mkdir(parents=True)

manager_base_src = """\
from typing import Generic, TypeVar
M = TypeVar("M")
class BaseManager(Generic[M]):
def get(self) -> M: ...
"""
manager_dir.joinpath("base.pyi").write_text(manager_base_src)

cfg = tmp_path.joinpath("mypy.toml")
cfg.write_text("[tool.mypy]\nplugins = []\n")

# can't use our helper above because we're fixturing sentry src, so mimic it here
def _mypy() -> tuple[int, str]:
ret = subprocess.run(
(
*(sys.executable, "-m", "mypy"),
*("--config", cfg),
*("-c", src),
),
env={**os.environ, "MYPYPATH": str(tmp_path)},
capture_output=True,
encoding="UTF-8",
)
assert not ret.stderr
return ret.returncode, ret.stdout

ret, out = _mypy()
assert ret
assert out == expected

cfg.write_text('[tool.mypy]\nplugins = ["tools.mypy_helpers.plugin"]\n')
ret, out = _mypy()
assert ret == 0
25 changes: 25 additions & 0 deletions tools/mypy_helpers/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
NoneType,
Type,
TypeOfAny,
TypeType,
TypeVarType,
UnionType,
)

Expand Down Expand Up @@ -114,6 +116,23 @@ def _lazy_service_wrapper_attribute(ctx: AttributeContext, *, attr: str) -> Type
return member


def _resolve_objects_for_typevars(ctx: AttributeContext) -> Type:
# XXX: hack around python/mypy#17395

# self: type[<TypeVar>]
# default_attr_type: BaseManager[ConcreteTypeVarBound]
if (
isinstance(ctx.type, TypeType)
and isinstance(ctx.type.item, TypeVarType)
and isinstance(ctx.default_attr_type, Instance)
and ctx.default_attr_type.type.fullname == "sentry.db.models.manager.base.BaseManager"
):
tvar = ctx.type.item
return ctx.default_attr_type.copy_modified(args=(tvar,))
else:
return ctx.default_attr_type


class SentryMypyPlugin(Plugin):
def get_function_signature_hook(
self, fullname: str
Expand All @@ -127,6 +146,12 @@ def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None
else:
return None

def get_class_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
if fullname.startswith("sentry.") and fullname.endswith(".objects"):
return _resolve_objects_for_typevars
else:
return None

def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
if fullname.startswith("sentry.utils.lazy_service_wrapper.LazyServiceWrapper."):
_, attr = fullname.rsplit(".", 1)
Expand Down

0 comments on commit 8e3b2b2

Please sign in to comment.