diff --git a/guppylang/prelude/_internal/compiler/arithmetic.py b/guppylang/prelude/_internal/compiler/arithmetic.py index d8af8ad9..bea97af6 100644 --- a/guppylang/prelude/_internal/compiler/arithmetic.py +++ b/guppylang/prelude/_internal/compiler/arithmetic.py @@ -2,7 +2,7 @@ from collections.abc import Sequence -import hugr +import hugr.std.int from hugr import Wire, ops from hugr import tys as ht from hugr.std.float import FLOAT_T @@ -45,45 +45,53 @@ def ine(width: int) -> ops.ExtOp: return _instantiate_int_op("ine", width, [int_t(width), int_t(width)], [ht.Bool]) +def iwiden_u(from_width: int, to_width: int) -> ops.ExtOp: + """Returns an unsigned `std.arithmetic.int.widen_u` operation.""" + return _instantiate_int_op( + "iwiden_u", [from_width, to_width], [int_t(from_width)], [int_t(to_width)] + ) + + +def iwiden_s(from_width: int, to_width: int) -> ops.ExtOp: + """Returns a signed `std.arithmetic.int.widen_s` operation.""" + return _instantiate_int_op( + "iwiden_s", [from_width, to_width], [int_t(from_width)], [int_t(to_width)] + ) + + # ------------------------------------------------------ # --------- std.arithmetic.conversions ops ------------- # ------------------------------------------------------ +def _instantiate_convert_op( + name: str, + inp: list[ht.Type], + out: list[ht.Type], + args: list[ht.TypeArg] | None = None, +) -> ops.ExtOp: + op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op(name) + return ops.ExtOp(op_def, ht.FunctionType(inp, out), args or []) + + def convert_ifromusize() -> ops.ExtOp: """Returns a `std.arithmetic.conversions.ifromusize` operation.""" - op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("ifromusize") - return ops.ExtOp( - op_def, - ht.FunctionType([ht.USize()], [INT_T]), - ) + return _instantiate_convert_op("ifromusize", [ht.USize()], [INT_T]) def convert_itousize() -> ops.ExtOp: """Returns a `std.arithmetic.conversions.itousize` operation.""" - op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("itousize") - return ops.ExtOp( - op_def, - ht.FunctionType([INT_T], [ht.USize()]), - ) + return _instantiate_convert_op("itousize", [INT_T], [ht.USize()]) def convert_ifrombool() -> ops.ExtOp: """Returns a `std.arithmetic.conversions.ifrombool` operation.""" - op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("ifrombool") - return ops.ExtOp( - op_def, - ht.FunctionType([ht.Bool], [int_t(1)]), - ) + return _instantiate_convert_op("ifrombool", [ht.Bool], [int_t(0)]) def convert_itobool() -> ops.ExtOp: """Returns a `std.arithmetic.conversions.itobool` operation.""" - op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("itobool") - return ops.ExtOp( - op_def, - ht.FunctionType([int_t(1)], [ht.Bool]), - ) + return _instantiate_convert_op("itobool", [int_t(0)], [ht.Bool]) # ------------------------------------------------------ @@ -264,3 +272,35 @@ def compile(self, args: list[Wire]) -> list[Wire]: ht.FunctionType([FLOAT_T] * len(args), [FLOAT_T]), ) return list(self.builder.add(ops.MakeTuple()(div, mod))) + + +class IToBoolCompiler(CustomCallCompiler): + """Compiler for the `Int` and `Nat` `.__bool__` methods. + + Note that the native `std.arithmetic.conversions.itobool` hugr op + only supports 1 bit integers as input. + """ + + def compile(self, args: list[Wire]) -> list[Wire]: + # Emit a comparison against zero + [num] = args + zero = self.builder.load(hugr.std.int.IntVal(0, width=6)) + out = self.builder.add_op(ine(NumericType.INT_WIDTH), num, zero) + return [out] + + +class IFromBoolCompiler(CustomCallCompiler): + """Compiler for the `Bool` `.__int__` and `.__nat__` methods. + + Note that the native `std.arithmetic.conversions.ifrombool` hugr op + only produces 1 bit integers as output, so we have to widen the result. + """ + + def compile(self, args: list[Wire]) -> list[Wire]: + # Emit an `ifrombool` followed by a widening cast + # We use `widen_u` independently of the target type, since we want the bit `1` + # to be expanded to `0x00000001` even for `nat` types + [boolean] = args + bit = self.builder.add_op(convert_ifrombool(), boolean) + num = self.builder.add_op(iwiden_u(0, NumericType.INT_WIDTH), bit) + return [num] diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index d98482aa..752381e8 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -26,7 +26,9 @@ FloatDivmodCompiler, FloatFloordivCompiler, FloatModCompiler, + IFromBoolCompiler, IntTruedivCompiler, + IToBoolCompiler, NatTruedivCompiler, ) from guppylang.prelude._internal.compiler.array import ( @@ -101,10 +103,10 @@ def __bool__(self: bool) -> bool: ... @guppy.hugr_op(builtins, logic_op("Eq")) def __eq__(self: bool, other: bool) -> bool: ... - @guppy.hugr_op(builtins, unsupported_op("ifrombool")) # TODO: Widen to INT_WIDTH + @guppy.custom(builtins, IFromBoolCompiler()) def __int__(self: bool) -> int: ... - @guppy.hugr_op(builtins, unsupported_op("ifrombool")) # TODO: Widen to INT_WIDTH + @guppy.custom(builtins, IFromBoolCompiler()) def __nat__(self: bool) -> nat: ... @guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False) @@ -128,9 +130,7 @@ def __add__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("iand")) def __and__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op( - builtins, unsupported_op("itobool") - ) # TODO: itobool only supports single bit ints + @guppy.custom(builtins, IToBoolCompiler()) def __bool__(self: nat) -> bool: ... @guppy.custom(builtins, NoopCompiler()) @@ -273,9 +273,7 @@ def __add__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("iand")) def __and__(self: int, other: int) -> int: ... - @guppy.hugr_op( - builtins, unsupported_op("itobool") - ) # TODO: itobool only supports single bit ints + @guppy.custom(builtins, IToBoolCompiler()) def __bool__(self: int) -> bool: ... @guppy.custom(builtins, NoopCompiler())