Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Generate faster code for bool comparisons and arithmetic #14489

Merged
merged 13 commits into from
Feb 5, 2023
6 changes: 5 additions & 1 deletion mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
)

def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
if not can_coerce_to(t, s) or not can_coerce_to(s, t):
self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")

def visit_goto(self, op: Goto) -> None:
self.check_control_op_targets(op)

Expand Down Expand Up @@ -375,7 +379,7 @@ def visit_int_op(self, op: IntOp) -> None:
pass

def visit_comparison_op(self, op: ComparisonOp) -> None:
pass
self.check_compatibility(op, op.lhs.type, op.rhs.type)

def visit_load_mem(self, op: LoadMem) -> None:
pass
Expand Down
82 changes: 57 additions & 25 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@
">>=",
}

# Binary operations on bools that are specialized and don't just promote operands to int
BOOL_BINARY_OPS: Final = {"&", "&=", "|", "|=", "^", "^=", "==", "!=", "<", "<=", ">", ">="}


class LowLevelIRBuilder:
def __init__(self, current_module: str, mapper: Mapper, options: CompilerOptions) -> None:
Expand Down Expand Up @@ -326,13 +329,13 @@ def coerce(
):
# Equivalent types
return src
elif (
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
) and is_int_rprimitive(target_type):
elif (is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)) and is_tagged(
target_type
):
shifted = self.int_op(
bool_rprimitive, src, Integer(1, bool_rprimitive), IntOp.LEFT_SHIFT
)
return self.add(Extend(shifted, int_rprimitive, signed=False))
return self.add(Extend(shifted, target_type, signed=False))
elif (
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
) and is_fixed_width_rtype(target_type):
Expand Down Expand Up @@ -1245,48 +1248,45 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
return self.compare_bytes(lreg, rreg, op, line)
if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping:
return self.compare_tagged(lreg, rreg, op, line)
if (
is_bool_rprimitive(ltype)
and is_bool_rprimitive(rtype)
and op in ("&", "&=", "|", "|=", "^", "^=")
):
return self.bool_bitwise_op(lreg, rreg, op[0], line)
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
if op in ComparisonOp.signed_ops:
return self.bool_comparison_op(lreg, rreg, op, line)
else:
return self.bool_bitwise_op(lreg, rreg, op[0], line)
if isinstance(rtype, RInstance) and op in ("in", "not in"):
return self.translate_instance_contains(rreg, lreg, op, line)
if is_fixed_width_rtype(ltype):
if op in FIXED_WIDTH_INT_BINARY_OPS:
if op.endswith("="):
op = op[:-1]
if op != "//":
op_id = int_op_to_id[op]
else:
op_id = IntOp.DIV
if is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
rreg = self.coerce(rreg, ltype, line)
rtype = ltype
if is_fixed_width_rtype(rtype) or is_tagged(rtype):
if op != "//":
op_id = int_op_to_id[op]
else:
op_id = IntOp.DIV
return self.fixed_width_int_op(ltype, lreg, rreg, op_id, line)
if isinstance(rreg, Integer):
# TODO: Check what kind of Integer
if op != "//":
op_id = int_op_to_id[op]
else:
op_id = IntOp.DIV
return self.fixed_width_int_op(
ltype, lreg, Integer(rreg.value >> 1, ltype), op_id, line
)
elif op in ComparisonOp.signed_ops:
if is_int_rprimitive(rtype):
rreg = self.coerce_int_to_fixed_width(rreg, ltype, line)
elif is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
rreg = self.coerce(rreg, ltype, line)
op_id = ComparisonOp.signed_ops[op]
if is_fixed_width_rtype(rreg.type):
return self.comparison_op(lreg, rreg, op_id, line)
if isinstance(rreg, Integer):
return self.comparison_op(lreg, Integer(rreg.value >> 1, ltype), op_id, line)
elif is_fixed_width_rtype(rtype):
if (
isinstance(lreg, Integer) or is_tagged(ltype)
) and op in FIXED_WIDTH_INT_BINARY_OPS:
if op in FIXED_WIDTH_INT_BINARY_OPS:
if op.endswith("="):
op = op[:-1]
# TODO: Support comparison ops (similar to above)
if op != "//":
op_id = int_op_to_id[op]
else:
Expand All @@ -1296,15 +1296,38 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
return self.fixed_width_int_op(
rtype, Integer(lreg.value >> 1, rtype), rreg, op_id, line
)
else:
if is_tagged(ltype):
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
if is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
lreg = self.coerce(lreg, rtype, line)
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
elif op in ComparisonOp.signed_ops:
if is_int_rprimitive(ltype):
lreg = self.coerce_int_to_fixed_width(lreg, rtype, line)
elif is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
lreg = self.coerce(lreg, rtype, line)
op_id = ComparisonOp.signed_ops[op]
if isinstance(lreg, Integer):
return self.comparison_op(Integer(lreg.value >> 1, rtype), rreg, op_id, line)
if is_fixed_width_rtype(lreg.type):
return self.comparison_op(lreg, rreg, op_id, line)

# Mixed int comparisons
if op in ("==", "!="):
op_id = ComparisonOp.signed_ops[op]
if is_tagged(ltype) and is_subtype(rtype, ltype):
rreg = self.coerce(rreg, int_rprimitive, line)
return self.comparison_op(lreg, rreg, op_id, line)
if is_tagged(rtype) and is_subtype(ltype, rtype):
lreg = self.coerce(lreg, int_rprimitive, line)
return self.comparison_op(lreg, rreg, op_id, line)
elif op in op in int_comparison_op_mapping:
if is_tagged(ltype) and is_subtype(rtype, ltype):
rreg = self.coerce(rreg, short_int_rprimitive, line)
return self.compare_tagged(lreg, rreg, op, line)
if is_tagged(rtype) and is_subtype(ltype, rtype):
lreg = self.coerce(lreg, short_int_rprimitive, line)
return self.compare_tagged(lreg, rreg, op, line)

call_c_ops_candidates = binary_ops.get(op, [])
target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line)
Expand Down Expand Up @@ -1509,14 +1532,21 @@ def bool_bitwise_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value
assert False, op
return self.add(IntOp(bool_rprimitive, lreg, rreg, code, line))

def bool_comparison_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
op_id = ComparisonOp.signed_ops[op]
return self.comparison_op(lreg, rreg, op_id, line)

def unary_not(self, value: Value, line: int) -> Value:
mask = Integer(1, value.type, line)
return self.int_op(value.type, value, mask, IntOp.XOR, line)

def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
typ = value.type
if (is_bool_rprimitive(typ) or is_bit_rprimitive(typ)) and expr_op == "not":
return self.unary_not(value, line)
if is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
if expr_op == "not":
return self.unary_not(value, line)
if expr_op == "+":
return value
if is_fixed_width_rtype(typ):
if expr_op == "-":
# Translate to '0 - x'
Expand All @@ -1532,6 +1562,8 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
if is_short_int_rprimitive(typ):
num >>= 1
return Integer(-num, typ, value.line)
if is_tagged(typ) and expr_op == "+":
return value
if isinstance(typ, RInstance):
if expr_op == "-":
method = "__neg__"
Expand Down
Loading