Skip to content

Commit

Permalink
[mypyc] Support __pow__, __rpow__, and __ipow__ dunders
Browse files Browse the repository at this point in the history
Unlike every other slot, power slots are ternary. Lots of special casing
had to be done in generate_bin_op_wrapper() to support the third slot
argument. Add in the fact it's allowed and common to only define
__(r|i)pow__ to take two arguments and you get a mess of a patch...
  • Loading branch information
ichard26 committed Feb 4, 2023
1 parent dc75c19 commit c408b06
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 25 deletions.
6 changes: 6 additions & 0 deletions mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
generate_dunder_wrapper,
generate_get_wrapper,
generate_hash_wrapper,
generate_ipow_wrapper,
generate_len_wrapper,
generate_richcompare_wrapper,
generate_set_del_item_wrapper,
Expand Down Expand Up @@ -109,6 +110,11 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"__ior__": ("nb_inplace_or", generate_dunder_wrapper),
"__ixor__": ("nb_inplace_xor", generate_dunder_wrapper),
"__imatmul__": ("nb_inplace_matrix_multiply", generate_dunder_wrapper),
# Ternary operations. (yes, really)
# These are special cased in generate_bin_op_wrapper().
"__pow__": ("nb_power", generate_bin_op_wrapper),
"__rpow__": ("nb_power", generate_bin_op_wrapper),
"__ipow__": ("nb_inplace_power", generate_ipow_wrapper),
}

AS_ASYNC_SLOT_DEFS: SlotTable = {
Expand Down
105 changes: 88 additions & 17 deletions mypyc/codegen/emitwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,32 @@ def generate_dunder_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
return gen.wrapper_name()


def generate_ipow_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"""Generate a wrapper for native __ipow__.
Since __ipow__ fills a ternary slot, but almost no one defines __ipow__ to take three
arguments, the wrapper needs to tweaked to force it to accept three arguments.
"""
gen = WrapperGenerator(cl, emitter)
gen.set_target(fn)
assert len(fn.args) in (2, 3), "__ipow__ should only take 2 or 3 arguments"
gen.arg_names = ["self", "exp", "mod"]
gen.emit_header()
gen.emit_arg_processing()
handle_third_pow_argument(
fn,
emitter,
gen,
if_unsupported=[
'PyErr_SetString(PyExc_TypeError, "__ipow__ takes 2 positional arguments but 3 were given");',
"return NULL;",
],
)
gen.emit_call()
gen.finish()
return gen.wrapper_name()


def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"""Generates a wrapper for a native binary dunder method.
Expand All @@ -311,13 +337,16 @@ def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"""
gen = WrapperGenerator(cl, emitter)
gen.set_target(fn)
gen.arg_names = ["left", "right"]
if fn.name in ("__pow__", "__rpow__"):
gen.arg_names = ["left", "right", "mod"]
else:
gen.arg_names = ["left", "right"]
wrapper_name = gen.wrapper_name()

gen.emit_header()
if fn.name not in reverse_op_methods and fn.name in reverse_op_method_names:
# There's only a reverse operator method.
generate_bin_op_reverse_only_wrapper(emitter, gen)
generate_bin_op_reverse_only_wrapper(fn, emitter, gen)
else:
rmethod = reverse_op_methods[fn.name]
fn_rev = cl.get_method(rmethod)
Expand All @@ -334,6 +363,7 @@ def generate_bin_op_forward_only_wrapper(
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator
) -> None:
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"])
gen.emit_call(not_implemented_handler="goto typefail;")
gen.emit_error_handling()
emitter.emit_label("typefail")
Expand All @@ -352,19 +382,16 @@ def generate_bin_op_forward_only_wrapper(
# if not isinstance(other, int):
# return NotImplemented
# ...
rmethod = reverse_op_methods[fn.name]
emitter.emit_line(f"_Py_IDENTIFIER({rmethod});")
emitter.emit_line(
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
op_methods_to_symbols[fn.name], rmethod
)
)
generate_bin_op_reverse_dunder_call(fn, emitter, reverse_op_methods[fn.name])
gen.finish()


def generate_bin_op_reverse_only_wrapper(emitter: Emitter, gen: WrapperGenerator) -> None:
def generate_bin_op_reverse_only_wrapper(
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator
) -> None:
gen.arg_names = ["right", "left"]
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"])
gen.emit_call()
gen.emit_error_handling()
emitter.emit_label("typefail")
Expand All @@ -390,7 +417,14 @@ def generate_bin_op_both_wrappers(
)
)
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
gen.emit_call(not_implemented_handler="goto typefail;")
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail2;"])
# Ternary __rpow__ calls aren't a thing so immediately bail
# if ternary __pow__ returns NotImplemented.
if fn.name == "__pow__" and len(fn.args) == 3:
fwd_not_implemented_handler = "goto typefail2;"
else:
fwd_not_implemented_handler = "goto typefail;"
gen.emit_call(not_implemented_handler=fwd_not_implemented_handler)
gen.emit_error_handling()
emitter.emit_line("}")
emitter.emit_label("typefail")
Expand All @@ -402,22 +436,59 @@ def generate_bin_op_both_wrappers(
gen.set_target(fn_rev)
gen.arg_names = ["right", "left"]
gen.emit_arg_processing(error=GotoHandler("typefail2"), raise_exception=False)
handle_third_pow_argument(fn_rev, emitter, gen, if_unsupported=["goto typefail2;"])
gen.emit_call()
gen.emit_error_handling()
emitter.emit_line("} else {")
emitter.emit_line(f"_Py_IDENTIFIER({fn_rev.name});")
emitter.emit_line(
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
op_methods_to_symbols[fn.name], fn_rev.name
)
)
generate_bin_op_reverse_dunder_call(fn, emitter, fn_rev.name)
emitter.emit_line("}")
emitter.emit_label("typefail2")
emitter.emit_line("Py_INCREF(Py_NotImplemented);")
emitter.emit_line("return Py_NotImplemented;")
gen.finish()


def generate_bin_op_reverse_dunder_call(fn: FuncIR, emitter: Emitter, rmethod: str) -> None:
if fn.name in ("__pow__", "__rpow__"):
# Ternary pow() will never call the reverse dunder.
emitter.emit_line("if (obj_mod == Py_None) {")
emitter.emit_line(f"_Py_IDENTIFIER({rmethod});")
emitter.emit_line(
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
op_methods_to_symbols[fn.name], rmethod
)
)
if fn.name in ("__pow__", "__rpow__"):
emitter.emit_line("} else {")
emitter.emit_line("Py_INCREF(Py_NotImplemented);")
emitter.emit_line("return Py_NotImplemented;")
emitter.emit_line("}")


def handle_third_pow_argument(
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator, *, if_unsupported: list[str]
) -> None:
if fn.name not in ("__pow__", "__rpow__", "__ipow__"):
return

if (fn.name in ("__pow__", "__ipow__") and len(fn.args) == 2) or fn.name == "__rpow__":
# If the power dunder only supports two arguments and the third
# argument (AKA mod) is set to a non-default value, simply bail.
#
# Importantly, this prevents any ternary __rpow__ calls from
# happening (as per the language specification).
emitter.emit_line("if (obj_mod != Py_None) {")
for line in if_unsupported:
emitter.emit_line(line)
emitter.emit_line("}")
# The slot wrapper will receive three arguments, but the call only
# supports two so make sure that the third argument isn't passed
# along. This is needed as two-argument __(i)pow__ is allowed and
# rather common.
if len(gen.arg_names) == 3:
gen.arg_names.pop()


RICHCOMPARE_OPS = {
"__lt__": "Py_LT",
"__gt__": "Py_GT",
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ CPyTagged CPyObject_Hash(PyObject *o);
PyObject *CPyObject_GetAttr3(PyObject *v, PyObject *name, PyObject *defl);
PyObject *CPyIter_Next(PyObject *iter);
PyObject *CPyNumber_Power(PyObject *base, PyObject *index);
PyObject *CPyObject_InPlacePower(PyObject *base, PyObject *index);
PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);


Expand Down
5 changes: 5 additions & 0 deletions mypyc/lib-rt/generic_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ PyObject *CPyNumber_Power(PyObject *base, PyObject *index)
return PyNumber_Power(base, index, Py_None);
}

PyObject *CPyNumber_InPlacePower(PyObject *base, PyObject *index)
{
return PyNumber_InPlacePower(base, index, Py_None);
}

PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
PyObject *start_obj = CPyTagged_AsObject(start);
PyObject *end_obj = CPyTagged_AsObject(end);
Expand Down
27 changes: 19 additions & 8 deletions mypyc/primitives/generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,25 @@
priority=0,
)

binary_op(
name="**",
arg_types=[object_rprimitive, object_rprimitive],
return_type=object_rprimitive,
error_kind=ERR_MAGIC,
c_function_name="CPyNumber_Power",
priority=0,
)
for op, c_function in (("**", "CPyNumber_Power"), ("**=", "CPyNumber_InPlacePower")):
binary_op(
name=op,
arg_types=[object_rprimitive, object_rprimitive],
return_type=object_rprimitive,
error_kind=ERR_MAGIC,
c_function_name=c_function,
priority=0,
)

for arg_count, c_function in ((2, "CPyNumber_Power"), (3, "PyNumber_Power")):
function_op(
name="builtins.pow",
arg_types=[object_rprimitive] * arg_count,
return_type=object_rprimitive,
error_kind=ERR_MAGIC,
c_function_name=c_function,
priority=0,
)

binary_op(
name="in",
Expand Down
22 changes: 22 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ def __divmod__(self, other: T_contra) -> T_co: ...
class __SupportsRDivMod(Protocol[T_contra, T_co]):
def __rdivmod__(self, other: T_contra) -> T_co: ...

_M = TypeVar("_M", contravariant=True)

class __SupportsPow2(Protocol[T_contra, T_co]):
def __pow__(self, other: T_contra) -> T_co: ...

class __SupportsPow3NoneOnly(Protocol[T_contra, T_co]):
def __pow__(self, other: T_contra, modulo: None = ...) -> T_co: ...

class __SupportsPow3(Protocol[T_contra, _M, T_co]):
def __pow__(self, other: T_contra, modulo: _M) -> T_co: ...

__SupportsSomeKindOfPow = Union[
__SupportsPow2[Any, Any], __SupportsPow3NoneOnly[Any, Any] | __SupportsPow3[Any, Any, Any]
]

class object:
def __init__(self) -> None: pass
def __eq__(self, x: object) -> bool: pass
Expand Down Expand Up @@ -99,6 +114,7 @@ def __add__(self, n: float) -> float: pass
def __sub__(self, n: float) -> float: pass
def __mul__(self, n: float) -> float: pass
def __truediv__(self, n: float) -> float: pass
def __pow__(self, n: float) -> float: pass
def __neg__(self) -> float: pass
def __pos__(self) -> float: pass
def __abs__(self) -> float: pass
Expand Down Expand Up @@ -318,6 +334,12 @@ def abs(x: __SupportsAbs[T]) -> T: ...
def divmod(x: __SupportsDivMod[T_contra, T_co], y: T_contra) -> T_co: ...
@overload
def divmod(x: T_contra, y: __SupportsRDivMod[T_contra, T_co]) -> T_co: ...
@overload
def pow(base: __SupportsPow2[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ...
@overload
def pow(base: __SupportsPow3NoneOnly[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ...
@overload
def pow(base: __SupportsPow3[T_contra, _M, T_co], exp: T_contra, mod: _M) -> T_co: ...
def exit() -> None: ...
def min(x: T, y: T) -> T: ...
def max(x: T, y: T) -> T: ...
Expand Down
25 changes: 25 additions & 0 deletions mypyc/test-data/irbuild-any.test
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ L0:
[case testFunctionBasedOps]
def f() -> None:
a = divmod(5, 2)
def f2() -> int:
return pow(2, 5)
def f3() -> float:
return pow(2, 5, 3)
[out]
def f():
r0, r1, r2 :: object
Expand All @@ -212,4 +216,25 @@ L0:
r3 = unbox(tuple[float, float], r2)
a = r3
return 1
def f2():
r0, r1, r2 :: object
r3 :: int
L0:
r0 = object 2
r1 = object 5
r2 = CPyNumber_Power(r0, r1)
r3 = unbox(int, r2)
return r3
def f3():
r0, r1, r2, r3 :: object
r4 :: int
r5 :: object
L0:
r0 = object 2
r1 = object 5
r2 = object 3
r3 = PyNumber_Power(r0, r1, r2)
r4 = unbox(int, r3)
r5 = box(int, r4)
return r5

Loading

0 comments on commit c408b06

Please sign in to comment.