Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support nat/intbool cast operations #459

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 61 additions & 21 deletions guppylang/prelude/_internal/compiler/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Sequence

import hugr
import hugr.std.int
from hugr import Wire, ops
from hugr import tys as ht
from hugr.std.float import FLOAT_T
Expand Down Expand Up @@ -45,45 +45,53 @@ def ine(width: int) -> ops.ExtOp:
return _instantiate_int_op("ine", width, [int_t(width), int_t(width)], [ht.Bool])


def iwiden_u(from_width: int, to_width: int) -> ops.ExtOp:
"""Returns an unsigned `std.arithmetic.int.widen_u` operation."""
return _instantiate_int_op(
"iwiden_u", [from_width, to_width], [int_t(from_width)], [int_t(to_width)]
)


def iwiden_s(from_width: int, to_width: int) -> ops.ExtOp:
"""Returns a signed `std.arithmetic.int.widen_s` operation."""
return _instantiate_int_op(
"iwiden_s", [from_width, to_width], [int_t(from_width)], [int_t(to_width)]
)


# ------------------------------------------------------
# --------- std.arithmetic.conversions ops -------------
# ------------------------------------------------------


def _instantiate_convert_op(
name: str,
inp: list[ht.Type],
out: list[ht.Type],
args: list[ht.TypeArg] | None = None,
) -> ops.ExtOp:
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op(name)
return ops.ExtOp(op_def, ht.FunctionType(inp, out), args or [])


def convert_ifromusize() -> ops.ExtOp:
"""Returns a `std.arithmetic.conversions.ifromusize` operation."""
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("ifromusize")
return ops.ExtOp(
op_def,
ht.FunctionType([ht.USize()], [INT_T]),
)
return _instantiate_convert_op("ifromusize", [ht.USize()], [INT_T])


def convert_itousize() -> ops.ExtOp:
"""Returns a `std.arithmetic.conversions.itousize` operation."""
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("itousize")
return ops.ExtOp(
op_def,
ht.FunctionType([INT_T], [ht.USize()]),
)
return _instantiate_convert_op("itousize", [INT_T], [ht.USize()])


def convert_ifrombool() -> ops.ExtOp:
"""Returns a `std.arithmetic.conversions.ifrombool` operation."""
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("ifrombool")
return ops.ExtOp(
op_def,
ht.FunctionType([ht.Bool], [int_t(1)]),
)
return _instantiate_convert_op("ifrombool", [ht.Bool], [int_t(0)])


def convert_itobool() -> ops.ExtOp:
"""Returns a `std.arithmetic.conversions.itobool` operation."""
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("itobool")
return ops.ExtOp(
op_def,
ht.FunctionType([int_t(1)], [ht.Bool]),
)
return _instantiate_convert_op("itobool", [int_t(0)], [ht.Bool])


# ------------------------------------------------------
Expand Down Expand Up @@ -264,3 +272,35 @@ def compile(self, args: list[Wire]) -> list[Wire]:
ht.FunctionType([FLOAT_T] * len(args), [FLOAT_T]),
)
return list(self.builder.add(ops.MakeTuple()(div, mod)))


class IToBoolCompiler(CustomCallCompiler):
"""Compiler for the `Int` and `Nat` `.__bool__` methods.

Note that the native `std.arithmetic.conversions.itobool` hugr op
only supports 1 bit integers as input.
"""

def compile(self, args: list[Wire]) -> list[Wire]:
# Emit a comparison against zero
[num] = args
zero = self.builder.load(hugr.std.int.IntVal(0, width=6))
out = self.builder.add_op(ine(NumericType.INT_WIDTH), num, zero)
return [out]


class IFromBoolCompiler(CustomCallCompiler):
"""Compiler for the `Bool` `.__int__` and `.__nat__` methods.

Note that the native `std.arithmetic.conversions.ifrombool` hugr op
only produces 1 bit integers as output, so we have to widen the result.
"""

def compile(self, args: list[Wire]) -> list[Wire]:
# Emit an `ifrombool` followed by a widening cast
# We use `widen_u` independently of the target type, since we want the bit `1`
# to be expanded to `0x00000001` even for `nat` types
[boolean] = args
bit = self.builder.add_op(convert_ifrombool(), boolean)
num = self.builder.add_op(iwiden_u(0, NumericType.INT_WIDTH), bit)
return [num]
14 changes: 6 additions & 8 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
FloatDivmodCompiler,
FloatFloordivCompiler,
FloatModCompiler,
IFromBoolCompiler,
IntTruedivCompiler,
IToBoolCompiler,
NatTruedivCompiler,
)
from guppylang.prelude._internal.compiler.array import (
Expand Down Expand Up @@ -101,10 +103,10 @@ def __bool__(self: bool) -> bool: ...
@guppy.hugr_op(builtins, logic_op("Eq"))
def __eq__(self: bool, other: bool) -> bool: ...

@guppy.hugr_op(builtins, unsupported_op("ifrombool")) # TODO: Widen to INT_WIDTH
@guppy.custom(builtins, IFromBoolCompiler())
def __int__(self: bool) -> int: ...

@guppy.hugr_op(builtins, unsupported_op("ifrombool")) # TODO: Widen to INT_WIDTH
@guppy.custom(builtins, IFromBoolCompiler())
def __nat__(self: bool) -> nat: ...

@guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False)
Expand All @@ -128,9 +130,7 @@ def __add__(self: nat, other: nat) -> nat: ...
@guppy.hugr_op(builtins, int_op("iand"))
def __and__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(
builtins, unsupported_op("itobool")
) # TODO: itobool only supports single bit ints
@guppy.custom(builtins, IToBoolCompiler())
def __bool__(self: nat) -> bool: ...

@guppy.custom(builtins, NoopCompiler())
Expand Down Expand Up @@ -273,9 +273,7 @@ def __add__(self: int, other: int) -> int: ...
@guppy.hugr_op(builtins, int_op("iand"))
def __and__(self: int, other: int) -> int: ...

@guppy.hugr_op(
builtins, unsupported_op("itobool")
) # TODO: itobool only supports single bit ints
@guppy.custom(builtins, IToBoolCompiler())
def __bool__(self: int) -> bool: ...

@guppy.custom(builtins, NoopCompiler())
Expand Down
Loading