Skip to content

Commit

Permalink
[mypyc] Optimize classmethod calls via cls (#14789)
Browse files Browse the repository at this point in the history
If the class has no subclasses, we can statically bind the call target:
```
class C:
    @classmethod
    def f(cls) -> int:
        return cls.g()  # This can be statically bound, same as C.g()

    @classmethod
    def g(cls) -> int:
        return 1
```
For this to be safe, also reject assignments to the "cls" argument in
classmethods in compiled code.

This makes the deltablue benchmark about 11% faster.
  • Loading branch information
JukkaL authored Mar 2, 2023
1 parent 9393c22 commit 43883fa
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 50 deletions.
4 changes: 4 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ def deserialize(cls, data: JsonDict) -> Decorator:

VAR_FLAGS: Final = [
"is_self",
"is_cls",
"is_initialized_in_class",
"is_staticmethod",
"is_classmethod",
Expand Down Expand Up @@ -935,6 +936,7 @@ class Var(SymbolNode):
"type",
"final_value",
"is_self",
"is_cls",
"is_ready",
"is_inferred",
"is_initialized_in_class",
Expand Down Expand Up @@ -967,6 +969,8 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
self.type: mypy.types.Type | None = type # Declared or inferred type, or None
# Is this the first argument to an ordinary method (usually "self")?
self.is_self = False
# Is this the first argument to a classmethod (typically "cls")?
self.is_cls = False
self.is_ready = True # If inferred, is the inferred type available?
self.is_inferred = self.type is None
# Is this initialized explicitly to a non-None value in class body?
Expand Down
7 changes: 5 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,8 +1369,11 @@ def analyze_function_body(self, defn: FuncItem) -> None:
# The first argument of a non-static, non-class method is like 'self'
# (though the name could be different), having the enclosing class's
# instance type.
if is_method and not defn.is_static and not defn.is_class and defn.arguments:
defn.arguments[0].variable.is_self = True
if is_method and not defn.is_static and defn.arguments:
if not defn.is_class:
defn.arguments[0].variable.is_self = True
else:
defn.arguments[0].variable.is_cls = True

defn.body.accept(self)
self.function_stack.pop()
Expand Down
7 changes: 6 additions & 1 deletion mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def __init__(
self.base_mro: list[ClassIR] = [self]

# Direct subclasses of this class (use subclasses() to also include non-direct ones)
# None if separate compilation prevents this from working
# None if separate compilation prevents this from working.
#
# Often it's better to use has_no_subclasses() or subclasses() instead.
self.children: list[ClassIR] | None = []

# Instance attributes that are initialized in the class body.
Expand Down Expand Up @@ -301,6 +303,9 @@ def get_method(self, name: str, *, prefer_method: bool = False) -> FuncIR | None
def has_method_decl(self, name: str) -> bool:
return any(name in ir.method_decls for ir in self.mro)

def has_no_subclasses(self) -> bool:
return self.children == [] and not self.allow_interpreted_subclasses

def subclasses(self) -> set[ClassIR] | None:
"""Return all subclasses of this class, both direct and indirect.
Expand Down
8 changes: 7 additions & 1 deletion mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,11 @@ def load_final_literal_value(self, val: int | str | bytes | float | bool, line:
else:
assert False, "Unsupported final literal value"

def get_assignment_target(self, lvalue: Lvalue, line: int = -1) -> AssignmentTarget:
def get_assignment_target(
self, lvalue: Lvalue, line: int = -1, *, for_read: bool = False
) -> AssignmentTarget:
if line == -1:
line = lvalue.line
if isinstance(lvalue, NameExpr):
# If we are visiting a decorator, then the SymbolNode we really want to be looking at
# is the function that is decorated, not the entire Decorator node itself.
Expand All @@ -578,6 +582,8 @@ def get_assignment_target(self, lvalue: Lvalue, line: int = -1) -> AssignmentTar
# New semantic analyzer doesn't create ad-hoc Vars for special forms.
assert lvalue.is_special_form
symbol = Var(lvalue.name)
if not for_read and isinstance(symbol, Var) and symbol.is_cls:
self.error("Cannot assign to the first argument of classmethod", line)
if lvalue.kind == LDEF:
if symbol not in self.symtables[-1]:
# If the function is a generator function, then first define a new variable
Expand Down
65 changes: 40 additions & 25 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from mypy.types import Instance, ProperType, TupleType, TypeType, get_proper_type
from mypyc.common import MAX_SHORT_INT
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD
from mypyc.ir.ops import (
Assign,
Expand Down Expand Up @@ -174,7 +175,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
)
return obj
else:
return builder.read(builder.get_assignment_target(expr), expr.line)
return builder.read(builder.get_assignment_target(expr, for_read=True), expr.line)

return builder.load_global(expr)

Expand Down Expand Up @@ -336,30 +337,7 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
# Call a method via the *class*
assert isinstance(callee.expr.node, TypeInfo)
ir = builder.mapper.type_to_ir[callee.expr.node]
decl = ir.method_decl(callee.name)
args = []
arg_kinds, arg_names = expr.arg_kinds[:], expr.arg_names[:]
# Add the class argument for class methods in extension classes
if decl.kind == FUNC_CLASSMETHOD and ir.is_ext_class:
args.append(builder.load_native_type_object(callee.expr.node.fullname))
arg_kinds.insert(0, ARG_POS)
arg_names.insert(0, None)
args += [builder.accept(arg) for arg in expr.args]

if ir.is_ext_class:
return builder.builder.call(decl, args, arg_kinds, arg_names, expr.line)
else:
obj = builder.accept(callee.expr)
return builder.gen_method_call(
obj,
callee.name,
args,
builder.node_type(expr),
expr.line,
expr.arg_kinds,
expr.arg_names,
)

return call_classmethod(builder, ir, expr, callee)
elif builder.is_module_member_expr(callee):
# Fall back to a PyCall for non-native module calls
function = builder.accept(callee)
Expand All @@ -368,6 +346,17 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
function, args, expr.line, arg_kinds=expr.arg_kinds, arg_names=expr.arg_names
)
else:
if isinstance(callee.expr, RefExpr):
node = callee.expr.node
if isinstance(node, Var) and node.is_cls:
typ = get_proper_type(node.type)
if isinstance(typ, TypeType) and isinstance(typ.item, Instance):
class_ir = builder.mapper.type_to_ir.get(typ.item.type)
if class_ir and class_ir.is_ext_class and class_ir.has_no_subclasses():
# Call a native classmethod via cls that can be statically bound,
# since the class has no subclasses.
return call_classmethod(builder, class_ir, expr, callee)

receiver_typ = builder.node_type(callee.expr)

# If there is a specializer for this method name/type, try calling it.
Expand All @@ -389,6 +378,32 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
)


def call_classmethod(builder: IRBuilder, ir: ClassIR, expr: CallExpr, callee: MemberExpr) -> Value:
decl = ir.method_decl(callee.name)
args = []
arg_kinds, arg_names = expr.arg_kinds[:], expr.arg_names[:]
# Add the class argument for class methods in extension classes
if decl.kind == FUNC_CLASSMETHOD and ir.is_ext_class:
args.append(builder.load_native_type_object(ir.fullname))
arg_kinds.insert(0, ARG_POS)
arg_names.insert(0, None)
args += [builder.accept(arg) for arg in expr.args]

if ir.is_ext_class:
return builder.builder.call(decl, args, arg_kinds, arg_names, expr.line)
else:
obj = builder.accept(callee.expr)
return builder.gen_method_call(
obj,
callee.name,
args,
builder.node_type(expr),
expr.line,
expr.arg_kinds,
expr.arg_names,
)


def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: SuperExpr) -> Value:
if callee.info is None or (len(callee.call.args) != 0 and len(callee.call.args) != 2):
return translate_call(builder, expr, callee)
Expand Down
69 changes: 69 additions & 0 deletions mypyc/test-data/irbuild-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,75 @@ L0:
r3 = CPyTagged_Add(r0, r2)
return r3

[case testCallClassMethodViaCls]
class C:
@classmethod
def f(cls, x: int) -> int:
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x

class D:
@classmethod
def f(cls, x: int) -> int:
# TODO: This could aso be optimized, since g is not ever overridden
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x

class DD(D):
pass
[out]
def C.f(cls, x):
cls :: object
x :: int
r0 :: object
r1 :: int
L0:
r0 = __main__.C :: type
r1 = C.g(r0, x)
return r1
def C.g(cls, x):
cls :: object
x :: int
L0:
return x
def D.f(cls, x):
cls :: object
x :: int
r0 :: str
r1, r2 :: object
r3 :: int
L0:
r0 = 'g'
r1 = box(int, x)
r2 = CPyObject_CallMethodObjArgs(cls, r0, r1, 0)
r3 = unbox(int, r2)
return r3
def D.g(cls, x):
cls :: object
x :: int
L0:
return x

[case testCannotAssignToClsArgument]
from typing import Any, cast

class C:
@classmethod
def m(cls) -> None:
cls = cast(Any, D) # E: Cannot assign to the first argument of classmethod
cls, x = cast(Any, D), 1 # E: Cannot assign to the first argument of classmethod
cls, x = cast(Any, [1, 2]) # E: Cannot assign to the first argument of classmethod
cls.m()

class D:
pass

[case testSuper1]
class A:
def __init__(self, x: int) -> None:
Expand Down
107 changes: 86 additions & 21 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -662,42 +662,107 @@ Traceback (most recent call last):
AttributeError: attribute 'x' of 'X' undefined

[case testClassMethods]
MYPY = False
if MYPY:
from typing import ClassVar
from typing import ClassVar, Any
from typing_extensions import final
from mypy_extensions import mypyc_attr

from interp import make_interpreted_subclass

class C:
lurr: 'ClassVar[int]' = 9
lurr: ClassVar[int] = 9
@staticmethod
def foo(x: int) -> int: return 10 + x
def foo(x: int) -> int:
return 10 + x
@classmethod
def bar(cls, x: int) -> int: return cls.lurr + x
def bar(cls, x: int) -> int:
return cls.lurr + x
@staticmethod
def baz(x: int, y: int = 10) -> int: return y - x
def baz(x: int, y: int = 10) -> int:
return y - x
@classmethod
def quux(cls, x: int, y: int = 10) -> int: return y - x
def quux(cls, x: int, y: int = 10) -> int:
return y - x
@classmethod
def call_other(cls, x: int) -> int:
return cls.quux(x, 3)

class D(C):
def f(self) -> int:
return super().foo(1) + super().bar(2) + super().baz(10) + super().quux(10)

def test1() -> int:
def ctest1() -> int:
return C.foo(1) + C.bar(2) + C.baz(10) + C.quux(10) + C.quux(y=10, x=9)
def test2() -> int:

def ctest2() -> int:
c = C()
return c.foo(1) + c.bar(2) + c.baz(10)
[file driver.py]
from native import *
assert C.foo(10) == 20
assert C.bar(10) == 19
c = C()
assert c.foo(10) == 20
assert c.bar(10) == 19

assert test1() == 23
assert test2() == 22
CAny: Any = C

def test_classmethod_using_any() -> None:
assert CAny.foo(10) == 20
assert CAny.bar(10) == 19

def test_classmethod_on_instance() -> None:
c = C()
assert c.foo(10) == 20
assert c.bar(10) == 19
assert c.call_other(1) == 2

def test_classmethod_misc() -> None:
assert ctest1() == 23
assert ctest2() == 22
assert C.call_other(2) == 1

def test_classmethod_using_super() -> None:
d = D()
assert d.f() == 22

d = D()
assert d.f() == 22
@final
class F1:
@classmethod
def f(cls, x: int) -> int:
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x + 1

class F2: # Implicitly final (no subclasses)
@classmethod
def f(cls, x: int) -> int:
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x + 1

def test_classmethod_of_final_class() -> None:
assert F1.f(5) == 6
assert F2.f(7) == 8

@mypyc_attr(allow_interpreted_subclasses=True)
class CI:
@classmethod
def f(cls, x: int) -> int:
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x + 1

def test_classmethod_with_allow_interpreted() -> None:
assert CI.f(4) == 5
sub = make_interpreted_subclass(CI)
assert sub.f(4) == 7

[file interp.py]
def make_interpreted_subclass(base):
class Sub(base):
@classmethod
def g(cls, x: int) -> int:
return x + 3
return Sub

[case testSuper]
from mypy_extensions import trait
Expand Down

0 comments on commit 43883fa

Please sign in to comment.