From 5614ffa0ef87e604e07e3f83079d0499c7d22886 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Sun, 5 Feb 2023 11:57:24 +0000 Subject: [PATCH] [mypyc] Generate faster code for bool comparisons and arithmetic (#14489) Generate specialized, efficient IR for various operations on bools. These are covered: * Bool comparisons * Mixed bool/integer comparisons * Bool arithmetic (binary and unary) * Mixed bool/integer arithmetic and bitwise ops Mixed operations where the left operand is a `bool` and the right operand is a native int still have some unnecessary conversions between native int and `int`. This would be a bit trickier to fix and is seems rare, so it doesn't seem urgent to fix this. Fixes mypyc/mypyc#968. --- mypyc/analysis/ircheck.py | 6 +- mypyc/irbuild/ll_builder.py | 82 +++++--- mypyc/test-data/irbuild-bool.test | 319 ++++++++++++++++++++++++++++++ mypyc/test-data/irbuild-i64.test | 39 ++++ mypyc/test-data/irbuild-int.test | 11 ++ mypyc/test-data/run-bools.test | 102 ++++++++++ 6 files changed, 533 insertions(+), 26 deletions(-) diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index e96c640fa8a1..719faebfcee8 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -217,6 +217,10 @@ def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None: source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}" ) + def check_compatibility(self, op: Op, t: RType, s: RType) -> None: + if not can_coerce_to(t, s) or not can_coerce_to(s, t): + self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible") + def visit_goto(self, op: Goto) -> None: self.check_control_op_targets(op) @@ -375,7 +379,7 @@ def visit_int_op(self, op: IntOp) -> None: pass def visit_comparison_op(self, op: ComparisonOp) -> None: - pass + self.check_compatibility(op, op.lhs.type, op.rhs.type) def visit_load_mem(self, op: LoadMem) -> None: pass diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 019f709f0acc..691f4729e4a4 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -199,6 +199,9 @@ ">>=", } +# Binary operations on bools that are specialized and don't just promote operands to int +BOOL_BINARY_OPS: Final = {"&", "&=", "|", "|=", "^", "^=", "==", "!=", "<", "<=", ">", ">="} + class LowLevelIRBuilder: def __init__(self, current_module: str, mapper: Mapper, options: CompilerOptions) -> None: @@ -326,13 +329,13 @@ def coerce( ): # Equivalent types return src - elif ( - is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type) - ) and is_int_rprimitive(target_type): + elif (is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)) and is_tagged( + target_type + ): shifted = self.int_op( bool_rprimitive, src, Integer(1, bool_rprimitive), IntOp.LEFT_SHIFT ) - return self.add(Extend(shifted, int_rprimitive, signed=False)) + return self.add(Extend(shifted, target_type, signed=False)) elif ( is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type) ) and is_fixed_width_rtype(target_type): @@ -1245,48 +1248,45 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: return self.compare_bytes(lreg, rreg, op, line) if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping: return self.compare_tagged(lreg, rreg, op, line) - if ( - is_bool_rprimitive(ltype) - and is_bool_rprimitive(rtype) - and op in ("&", "&=", "|", "|=", "^", "^=") - ): - return self.bool_bitwise_op(lreg, rreg, op[0], line) + if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS: + if op in ComparisonOp.signed_ops: + return self.bool_comparison_op(lreg, rreg, op, line) + else: + return self.bool_bitwise_op(lreg, rreg, op[0], line) if isinstance(rtype, RInstance) and op in ("in", "not in"): return self.translate_instance_contains(rreg, lreg, op, line) if is_fixed_width_rtype(ltype): if op in FIXED_WIDTH_INT_BINARY_OPS: if op.endswith("="): op = op[:-1] + if op != "//": + op_id = int_op_to_id[op] + else: + op_id = IntOp.DIV + if is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype): + rreg = self.coerce(rreg, ltype, line) + rtype = ltype if is_fixed_width_rtype(rtype) or is_tagged(rtype): - if op != "//": - op_id = int_op_to_id[op] - else: - op_id = IntOp.DIV return self.fixed_width_int_op(ltype, lreg, rreg, op_id, line) if isinstance(rreg, Integer): # TODO: Check what kind of Integer - if op != "//": - op_id = int_op_to_id[op] - else: - op_id = IntOp.DIV return self.fixed_width_int_op( ltype, lreg, Integer(rreg.value >> 1, ltype), op_id, line ) elif op in ComparisonOp.signed_ops: if is_int_rprimitive(rtype): rreg = self.coerce_int_to_fixed_width(rreg, ltype, line) + elif is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype): + rreg = self.coerce(rreg, ltype, line) op_id = ComparisonOp.signed_ops[op] if is_fixed_width_rtype(rreg.type): return self.comparison_op(lreg, rreg, op_id, line) if isinstance(rreg, Integer): return self.comparison_op(lreg, Integer(rreg.value >> 1, ltype), op_id, line) elif is_fixed_width_rtype(rtype): - if ( - isinstance(lreg, Integer) or is_tagged(ltype) - ) and op in FIXED_WIDTH_INT_BINARY_OPS: + if op in FIXED_WIDTH_INT_BINARY_OPS: if op.endswith("="): op = op[:-1] - # TODO: Support comparison ops (similar to above) if op != "//": op_id = int_op_to_id[op] else: @@ -1296,15 +1296,38 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: return self.fixed_width_int_op( rtype, Integer(lreg.value >> 1, rtype), rreg, op_id, line ) - else: + if is_tagged(ltype): + return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line) + if is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype): + lreg = self.coerce(lreg, rtype, line) return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line) elif op in ComparisonOp.signed_ops: if is_int_rprimitive(ltype): lreg = self.coerce_int_to_fixed_width(lreg, rtype, line) + elif is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype): + lreg = self.coerce(lreg, rtype, line) op_id = ComparisonOp.signed_ops[op] if isinstance(lreg, Integer): return self.comparison_op(Integer(lreg.value >> 1, rtype), rreg, op_id, line) + if is_fixed_width_rtype(lreg.type): + return self.comparison_op(lreg, rreg, op_id, line) + + # Mixed int comparisons + if op in ("==", "!="): + op_id = ComparisonOp.signed_ops[op] + if is_tagged(ltype) and is_subtype(rtype, ltype): + rreg = self.coerce(rreg, int_rprimitive, line) + return self.comparison_op(lreg, rreg, op_id, line) + if is_tagged(rtype) and is_subtype(ltype, rtype): + lreg = self.coerce(lreg, int_rprimitive, line) return self.comparison_op(lreg, rreg, op_id, line) + elif op in op in int_comparison_op_mapping: + if is_tagged(ltype) and is_subtype(rtype, ltype): + rreg = self.coerce(rreg, short_int_rprimitive, line) + return self.compare_tagged(lreg, rreg, op, line) + if is_tagged(rtype) and is_subtype(ltype, rtype): + lreg = self.coerce(lreg, short_int_rprimitive, line) + return self.compare_tagged(lreg, rreg, op, line) call_c_ops_candidates = binary_ops.get(op, []) target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line) @@ -1509,14 +1532,21 @@ def bool_bitwise_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value assert False, op return self.add(IntOp(bool_rprimitive, lreg, rreg, code, line)) + def bool_comparison_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: + op_id = ComparisonOp.signed_ops[op] + return self.comparison_op(lreg, rreg, op_id, line) + def unary_not(self, value: Value, line: int) -> Value: mask = Integer(1, value.type, line) return self.int_op(value.type, value, mask, IntOp.XOR, line) def unary_op(self, value: Value, expr_op: str, line: int) -> Value: typ = value.type - if (is_bool_rprimitive(typ) or is_bit_rprimitive(typ)) and expr_op == "not": - return self.unary_not(value, line) + if is_bool_rprimitive(typ) or is_bit_rprimitive(typ): + if expr_op == "not": + return self.unary_not(value, line) + if expr_op == "+": + return value if is_fixed_width_rtype(typ): if expr_op == "-": # Translate to '0 - x' @@ -1532,6 +1562,8 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value: if is_short_int_rprimitive(typ): num >>= 1 return Integer(-num, typ, value.line) + if is_tagged(typ) and expr_op == "+": + return value if isinstance(typ, RInstance): if expr_op == "-": method = "__neg__" diff --git a/mypyc/test-data/irbuild-bool.test b/mypyc/test-data/irbuild-bool.test index 407ab8bcda93..9257d8d63f7e 100644 --- a/mypyc/test-data/irbuild-bool.test +++ b/mypyc/test-data/irbuild-bool.test @@ -142,3 +142,322 @@ L2: r4 = 0 L3: return r4 + +[case testBoolComparisons] +def eq(x: bool, y: bool) -> bool: + return x == y + +def neq(x: bool, y: bool) -> bool: + return x != y + +def lt(x: bool, y: bool) -> bool: + return x < y + +def le(x: bool, y: bool) -> bool: + return x <= y + +def gt(x: bool, y: bool) -> bool: + return x > y + +def ge(x: bool, y: bool) -> bool: + return x >= y +[out] +def eq(x, y): + x, y :: bool + r0 :: bit +L0: + r0 = x == y + return r0 +def neq(x, y): + x, y :: bool + r0 :: bit +L0: + r0 = x != y + return r0 +def lt(x, y): + x, y :: bool + r0 :: bit +L0: + r0 = x < y :: signed + return r0 +def le(x, y): + x, y :: bool + r0 :: bit +L0: + r0 = x <= y :: signed + return r0 +def gt(x, y): + x, y :: bool + r0 :: bit +L0: + r0 = x > y :: signed + return r0 +def ge(x, y): + x, y :: bool + r0 :: bit +L0: + r0 = x >= y :: signed + return r0 + +[case testBoolMixedComparisons1] +from mypy_extensions import i64 + +def eq1(x: int, y: bool) -> bool: + return x == y + +def eq2(x: bool, y: int) -> bool: + return x == y + +def neq1(x: i64, y: bool) -> bool: + return x != y + +def neq2(x: bool, y: i64) -> bool: + return x != y +[out] +def eq1(x, y): + x :: int + y, r0 :: bool + r1 :: int + r2 :: bit +L0: + r0 = y << 1 + r1 = extend r0: builtins.bool to builtins.int + r2 = x == r1 + return r2 +def eq2(x, y): + x :: bool + y :: int + r0 :: bool + r1 :: int + r2 :: bit +L0: + r0 = x << 1 + r1 = extend r0: builtins.bool to builtins.int + r2 = r1 == y + return r2 +def neq1(x, y): + x :: int64 + y :: bool + r0 :: int64 + r1 :: bit +L0: + r0 = extend y: builtins.bool to int64 + r1 = x != r0 + return r1 +def neq2(x, y): + x :: bool + y, r0 :: int64 + r1 :: bit +L0: + r0 = extend x: builtins.bool to int64 + r1 = r0 != y + return r1 + +[case testBoolMixedComparisons2] +from mypy_extensions import i64 + +def lt1(x: bool, y: int) -> bool: + return x < y + +def lt2(x: int, y: bool) -> bool: + return x < y + +def gt1(x: bool, y: i64) -> bool: + return x < y + +def gt2(x: i64, y: bool) -> bool: + return x < y +[out] +def lt1(x, y): + x :: bool + y :: int + r0 :: bool + r1 :: short_int + r2 :: native_int + r3 :: bit + r4 :: native_int + r5, r6, r7 :: bit + r8 :: bool + r9 :: bit +L0: + r0 = x << 1 + r1 = extend r0: builtins.bool to short_int + r2 = r1 & 1 + r3 = r2 == 0 + r4 = y & 1 + r5 = r4 == 0 + r6 = r3 & r5 + if r6 goto L1 else goto L2 :: bool +L1: + r7 = r1 < y :: signed + r8 = r7 + goto L3 +L2: + r9 = CPyTagged_IsLt_(r1, y) + r8 = r9 +L3: + return r8 +def lt2(x, y): + x :: int + y, r0 :: bool + r1 :: short_int + r2 :: native_int + r3 :: bit + r4 :: native_int + r5, r6, r7 :: bit + r8 :: bool + r9 :: bit +L0: + r0 = y << 1 + r1 = extend r0: builtins.bool to short_int + r2 = x & 1 + r3 = r2 == 0 + r4 = r1 & 1 + r5 = r4 == 0 + r6 = r3 & r5 + if r6 goto L1 else goto L2 :: bool +L1: + r7 = x < r1 :: signed + r8 = r7 + goto L3 +L2: + r9 = CPyTagged_IsLt_(x, r1) + r8 = r9 +L3: + return r8 +def gt1(x, y): + x :: bool + y, r0 :: int64 + r1 :: bit +L0: + r0 = extend x: builtins.bool to int64 + r1 = r0 < y :: signed + return r1 +def gt2(x, y): + x :: int64 + y :: bool + r0 :: int64 + r1 :: bit +L0: + r0 = extend y: builtins.bool to int64 + r1 = x < r0 :: signed + return r1 + +[case testBoolBitwise] +from mypy_extensions import i64 +def bitand(x: bool, y: bool) -> bool: + b = x & y + return b +def bitor(x: bool, y: bool) -> bool: + b = x | y + return b +def bitxor(x: bool, y: bool) -> bool: + b = x ^ y + return b +def invert(x: bool) -> int: + return ~x +def mixed_bitand(x: i64, y: bool) -> i64: + return x & y +[out] +def bitand(x, y): + x, y, r0, b :: bool +L0: + r0 = x & y + b = r0 + return b +def bitor(x, y): + x, y, r0, b :: bool +L0: + r0 = x | y + b = r0 + return b +def bitxor(x, y): + x, y, r0, b :: bool +L0: + r0 = x ^ y + b = r0 + return b +def invert(x): + x, r0 :: bool + r1, r2 :: int +L0: + r0 = x << 1 + r1 = extend r0: builtins.bool to builtins.int + r2 = CPyTagged_Invert(r1) + return r2 +def mixed_bitand(x, y): + x :: int64 + y :: bool + r0, r1 :: int64 +L0: + r0 = extend y: builtins.bool to int64 + r1 = x & r0 + return r1 + +[case testBoolArithmetic] +def add(x: bool, y: bool) -> int: + z = x + y + return z +def mixed(b: bool, n: int) -> int: + z = b + n + z -= b + z = z * b + return z +def negate(b: bool) -> int: + return -b +def unary_plus(b: bool) -> int: + x = +b + return x +[out] +def add(x, y): + x, y, r0 :: bool + r1 :: int + r2 :: bool + r3, r4, z :: int +L0: + r0 = x << 1 + r1 = extend r0: builtins.bool to builtins.int + r2 = y << 1 + r3 = extend r2: builtins.bool to builtins.int + r4 = CPyTagged_Add(r1, r3) + z = r4 + return z +def mixed(b, n): + b :: bool + n :: int + r0 :: bool + r1, r2, z :: int + r3 :: bool + r4, r5 :: int + r6 :: bool + r7, r8 :: int +L0: + r0 = b << 1 + r1 = extend r0: builtins.bool to builtins.int + r2 = CPyTagged_Add(r1, n) + z = r2 + r3 = b << 1 + r4 = extend r3: builtins.bool to builtins.int + r5 = CPyTagged_Subtract(z, r4) + z = r5 + r6 = b << 1 + r7 = extend r6: builtins.bool to builtins.int + r8 = CPyTagged_Multiply(z, r7) + z = r8 + return z +def negate(b): + b, r0 :: bool + r1, r2 :: int +L0: + r0 = b << 1 + r1 = extend r0: builtins.bool to builtins.int + r2 = CPyTagged_Negate(r1) + return r2 +def unary_plus(b): + b, r0 :: bool + r1, x :: int +L0: + r0 = b << 1 + r1 = extend r0: builtins.bool to builtins.int + x = r1 + return x diff --git a/mypyc/test-data/irbuild-i64.test b/mypyc/test-data/irbuild-i64.test index 6b8dd357421f..253d1a837c7b 100644 --- a/mypyc/test-data/irbuild-i64.test +++ b/mypyc/test-data/irbuild-i64.test @@ -1731,6 +1731,45 @@ def f5(): L0: return 4 +[case testI64OperationsWithBools] +from mypy_extensions import i64 + +# TODO: Other mixed operations + +def add_bool_to_int(n: i64, b: bool) -> i64: + return n + b + +def compare_bool_to_i64(n: i64, b: bool) -> bool: + if n == b: + return b != n + return True +[out] +def add_bool_to_int(n, b): + n :: int64 + b :: bool + r0, r1 :: int64 +L0: + r0 = extend b: builtins.bool to int64 + r1 = n + r0 + return r1 +def compare_bool_to_i64(n, b): + n :: int64 + b :: bool + r0 :: int64 + r1 :: bit + r2 :: int64 + r3 :: bit +L0: + r0 = extend b: builtins.bool to int64 + r1 = n == r0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = extend b: builtins.bool to int64 + r3 = r2 != n + return r3 +L2: + return 1 + [case testI64Cast] from typing import cast from mypy_extensions import i64 diff --git a/mypyc/test-data/irbuild-int.test b/mypyc/test-data/irbuild-int.test index aebadce5650e..fbe00aff4040 100644 --- a/mypyc/test-data/irbuild-int.test +++ b/mypyc/test-data/irbuild-int.test @@ -222,3 +222,14 @@ def int_to_int(n): n :: int L0: return n + +[case testIntUnaryPlus] +def unary_plus(n: int) -> int: + x = +n + return x +[out] +def unary_plus(n): + n, x :: int +L0: + x = n + return x diff --git a/mypyc/test-data/run-bools.test b/mypyc/test-data/run-bools.test index e23b35d82fc5..522296592c54 100644 --- a/mypyc/test-data/run-bools.test +++ b/mypyc/test-data/run-bools.test @@ -16,6 +16,9 @@ False [case testBoolOps] from typing import Optional, Any +MYPY = False +if MYPY: + from mypy_extensions import i64 def f(x: bool) -> bool: if x: @@ -119,3 +122,102 @@ def test_any_to_bool() -> None: b: Any = a + 1 assert not bool(a) assert bool(b) + +def eq(x: bool, y: bool) -> bool: + return x == y + +def ne(x: bool, y: bool) -> bool: + return x != y + +def lt(x: bool, y: bool) -> bool: + return x < y + +def le(x: bool, y: bool) -> bool: + return x <= y + +def gt(x: bool, y: bool) -> bool: + return x > y + +def ge(x: bool, y: bool) -> bool: + return x >= y + +def test_comparisons() -> None: + for x in True, False: + for y in True, False: + x2: Any = x + y2: Any = y + assert eq(x, y) == (x2 == y2) + assert ne(x, y) == (x2 != y2) + assert lt(x, y) == (x2 < y2) + assert le(x, y) == (x2 <= y2) + assert gt(x, y) == (x2 > y2) + assert ge(x, y) == (x2 >= y2) + +def eq_mixed(x: bool, y: int) -> bool: + return x == y + +def neq_mixed(x: int, y: bool) -> bool: + return x != y + +def lt_mixed(x: bool, y: int) -> bool: + return x < y + +def gt_mixed(x: int, y: bool) -> bool: + return x > y + +def test_mixed_comparisons() -> None: + for x in True, False: + for n in -(1 << 70), -123, 0, 1, 1753, 1 << 70: + assert eq_mixed(x, n) == (int(x) == n) + assert neq_mixed(n, x) == (n != int(x)) + assert lt_mixed(x, n) == (int(x) < n) + assert gt_mixed(n, x) == (n > int(x)) + +def add(x: bool, y: bool) -> int: + return x + y + +def add_mixed(b: bool, n: int) -> int: + return b + n + +def sub_mixed(n: int, b: bool) -> int: + return n - b + +def test_arithmetic() -> None: + for x in True, False: + for y in True, False: + assert add(x, y) == int(x) + int(y) + for n in -(1 << 70), -123, 0, 1, 1753, 1 << 70: + assert add_mixed(x, n) == int(x) + n + assert sub_mixed(n, x) == n - int(x) + +def add_mixed_i64(b: bool, n: i64) -> i64: + return b + n + +def sub_mixed_i64(n: i64, b: bool) -> i64: + return n - b + +def test_arithmetic_i64() -> None: + for x in True, False: + for n in -(1 << 62), -123, 0, 1, 1753, 1 << 62: + assert add_mixed_i64(x, n) == int(x) + n + assert sub_mixed_i64(n, x) == n - int(x) + +def eq_mixed_i64(x: bool, y: i64) -> bool: + return x == y + +def neq_mixed_i64(x: i64, y: bool) -> bool: + return x != y + +def lt_mixed_i64(x: bool, y: i64) -> bool: + return x < y + +def gt_mixed_i64(x: i64, y: bool) -> bool: + return x > y + +def test_mixed_comparisons_i64() -> None: + for x in True, False: + for n in -(1 << 62), -123, 0, 1, 1753, 1 << 62: + assert eq_mixed_i64(x, n) == (int(x) == n) + assert neq_mixed_i64(n, x) == (n != int(x)) + assert lt_mixed_i64(x, n) == (int(x) < n) + assert gt_mixed_i64(n, x) == (n > int(x))