Skip to content

Commit

Permalink
Add signature for attr.evolve (#14526)
Browse files Browse the repository at this point in the history
Validate `attr.evolve` calls to specify correct arguments and types.

The implementation makes it so that at every point where `attr.evolve`
is called, the signature is modified to expect the attrs class'
initializer's arguments (but so that they're all kw-only and optional).

Notes:
- Added `class dict: pass` to some fixtures files since our attrs type
stubs now have **kwargs and that triggers a `builtin.dict` lookup in
dozens of attrs tests.
- Looking up the type of the 1st argument with
`ctx.api.expr_checker.accept(inst_arg)` which is a hack since it's not
part of the plugin API. This is a compromise for due to #10216.

Fixes #14525.
  • Loading branch information
ikonst authored Mar 6, 2023
1 parent 2ab1d82 commit bbc9cce
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 1 deletion.
66 changes: 65 additions & 1 deletion mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from typing_extensions import Final, Literal

import mypy.plugin # To avoid circular imports.
from mypy.checker import TypeChecker
from mypy.errorcodes import LITERAL_REQ
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
from mypy.messages import format_type_bare
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
Expand Down Expand Up @@ -77,6 +79,7 @@
SELF_TVAR_NAME: Final = "_AT"
MAGIC_ATTR_NAME: Final = "__attrs_attrs__"
MAGIC_ATTR_CLS_NAME_TEMPLATE: Final = "__{}_AttrsAttributes__" # The tuple subclass pattern.
ATTRS_INIT_NAME: Final = "__attrs_init__"


class Converter:
Expand Down Expand Up @@ -330,7 +333,7 @@ def attr_class_maker_callback(

adder = MethodAdder(ctx)
# If __init__ is not being generated, attrs still generates it as __attrs_init__ instead.
_add_init(ctx, attributes, adder, "__init__" if init else "__attrs_init__")
_add_init(ctx, attributes, adder, "__init__" if init else ATTRS_INIT_NAME)
if order:
_add_order(ctx, adder)
if frozen:
Expand Down Expand Up @@ -888,3 +891,64 @@ def add_method(
"""
self_type = self_type if self_type is not None else self.self_type
add_method(self.ctx, method_name, args, ret_type, self_type, tvd)


def _get_attrs_init_type(typ: Type) -> CallableType | None:
"""
If `typ` refers to an attrs class, gets the type of its initializer method.
"""
typ = get_proper_type(typ)
if not isinstance(typ, Instance):
return None
magic_attr = typ.type.get(MAGIC_ATTR_NAME)
if magic_attr is None or not magic_attr.plugin_generated:
return None
init_method = typ.type.get_method("__init__") or typ.type.get_method(ATTRS_INIT_NAME)
if not isinstance(init_method, FuncDef) or not isinstance(init_method.type, CallableType):
return None
return init_method.type


def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
"""
Generates a signature for the 'attr.evolve' function that's specific to the call site
and dependent on the type of the first argument.
"""
if len(ctx.args) != 2:
# Ideally the name and context should be callee's, but we don't have it in FunctionSigContext.
ctx.api.fail(f'"{ctx.default_signature.name}" has unexpected type annotation', ctx.context)
return ctx.default_signature

if len(ctx.args[0]) != 1:
return ctx.default_signature # leave it to the type checker to complain

inst_arg = ctx.args[0][0]

# <hack>
assert isinstance(ctx.api, TypeChecker)
inst_type = ctx.api.expr_checker.accept(inst_arg)
# </hack>

inst_type = get_proper_type(inst_type)
if isinstance(inst_type, AnyType):
return ctx.default_signature
inst_type_str = format_type_bare(inst_type)

attrs_init_type = _get_attrs_init_type(inst_type)
if not attrs_init_type:
ctx.api.fail(
f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class',
ctx.context,
)
return ctx.default_signature

# AttrClass.__init__ has the following signature (or similar, if having kw-only & defaults):
# def __init__(self, attr1: Type1, attr2: Type2) -> None:
# We want to generate a signature for evolve that looks like this:
# def evolve(inst: AttrClass, *, attr1: Type1 = ..., attr2: Type2 = ...) -> AttrClass:
return attrs_init_type.copy_modified(
arg_names=["inst"] + attrs_init_type.arg_names[1:],
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT for _ in attrs_init_type.arg_kinds[1:]],
ret_type=inst_type,
name=f"{ctx.default_signature.name} of {inst_type_str}",
)
10 changes: 10 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AttributeContext,
ClassDefContext,
FunctionContext,
FunctionSigContext,
MethodContext,
MethodSigContext,
Plugin,
Expand Down Expand Up @@ -46,6 +47,15 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return singledispatch.create_singledispatch_function_callback
return None

def get_function_signature_hook(
self, fullname: str
) -> Callable[[FunctionSigContext], FunctionLike] | None:
from mypy.plugins import attrs

if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
return attrs.evolve_function_sig_callback
return None

def get_method_signature_hook(
self, fullname: str
) -> Callable[[MethodSigContext], FunctionLike] | None:
Expand Down
78 changes: 78 additions & 0 deletions test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -1867,3 +1867,81 @@ D(1, "").a = 2 # E: Cannot assign to final attribute "a"
D(1, "").b = "2" # E: Cannot assign to final attribute "b"

[builtins fixtures/property.pyi]

[case testEvolve]
import attr

class Base:
pass

class Derived(Base):
pass

class Other:
pass

@attr.s(auto_attribs=True)
class C:
name: str
b: Base

c = C(name='foo', b=Derived())
c = attr.evolve(c)
c = attr.evolve(c, name='foo')
c = attr.evolve(c, 'foo') # E: Too many positional arguments for "evolve" of "C"
c = attr.evolve(c, b=Derived())
c = attr.evolve(c, b=Base())
c = attr.evolve(c, b=Other()) # E: Argument "b" to "evolve" of "C" has incompatible type "Other"; expected "Base"
c = attr.evolve(c, name=42) # E: Argument "name" to "evolve" of "C" has incompatible type "int"; expected "str"
c = attr.evolve(c, foobar=42) # E: Unexpected keyword argument "foobar" for "evolve" of "C"

# test passing instance as 'inst' kw
c = attr.evolve(inst=c, name='foo')
c = attr.evolve(not_inst=c, name='foo') # E: Missing positional argument "inst" in call to "evolve"

# test determining type of first argument's expression from something that's not NameExpr
def f() -> C:
return c

c = attr.evolve(f(), name='foo')

[builtins fixtures/attr.pyi]

[case testEvolveFromNonAttrs]
import attr

attr.evolve(42, name='foo') # E: Argument 1 to "evolve" has incompatible type "int"; expected an attrs class
attr.evolve(None, name='foo') # E: Argument 1 to "evolve" has incompatible type "None"; expected an attrs class
[case testEvolveFromAny]
from typing import Any
import attr

any: Any = 42
ret = attr.evolve(any, name='foo')
reveal_type(ret) # N: Revealed type is "Any"

[typing fixtures/typing-medium.pyi]

[case testEvolveVariants]
from typing import Any
import attr
import attrs


@attr.s(auto_attribs=True)
class C:
name: str

c = C(name='foo')

c = attr.assoc(c, name='test')
c = attr.assoc(c, name=42) # E: Argument "name" to "assoc" of "C" has incompatible type "int"; expected "str"

c = attrs.evolve(c, name='test')
c = attrs.evolve(c, name=42) # E: Argument "name" to "evolve" of "C" has incompatible type "int"; expected "str"

c = attrs.assoc(c, name='test')
c = attrs.assoc(c, name=42) # E: Argument "name" to "assoc" of "C" has incompatible type "int"; expected "str"

[builtins fixtures/attr.pyi]
[typing fixtures/typing-medium.pyi]
3 changes: 3 additions & 0 deletions test-data/unit/lib-stub/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,6 @@ def field(
order: Optional[bool] = ...,
on_setattr: Optional[object] = ...,
) -> Any: ...

def evolve(inst: _T, **changes: Any) -> _T: ...
def assoc(inst: _T, **changes: Any) -> _T: ...
3 changes: 3 additions & 0 deletions test-data/unit/lib-stub/attrs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,6 @@ def field(
order: Optional[bool] = ...,
on_setattr: Optional[object] = ...,
) -> Any: ...

def evolve(inst: _T, **changes: Any) -> _T: ...
def assoc(inst: _T, **changes: Any) -> _T: ...

0 comments on commit bbc9cce

Please sign in to comment.