Skip to content

Commit

Permalink
feat: Add angle type (#449)
Browse files Browse the repository at this point in the history
Closes #448.

Once the Guppy type checker has a proper framework for type coercions,
we should turn angle into a `NumericType`, but this isn't needed for now
  • Loading branch information
mark-koch authored Sep 9, 2024
1 parent 3b778c3 commit 12e41e0
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 0 deletions.
Empty file.
57 changes: 57 additions & 0 deletions guppylang/prelude/_internal/compiler/angle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Compilers for angle operations from the tket2 extension."""

from typing import ClassVar

from hugr import Wire, ops
from hugr import tys as ht

from guppylang.compiler.hugr_extension import UnsupportedOp
from guppylang.definition.custom import (
CustomCallCompiler,
)
from guppylang.tys.ty import NumericType


class AngleOpCompiler(CustomCallCompiler):
"""Compiler for tket2 angle ops.
Automatically translated between the Hugr usize type used in the angle extension
and Guppy's `nat` type.
"""

NAT_TYPE: ClassVar[ht.Type] = NumericType(NumericType.Kind.Nat).to_hugr()

def __init__(self, op_name: str) -> None:
self.op_name = op_name

def nat_to_usize(self, value: Wire) -> Wire:
op = ops.Custom(
op_name="itousize",
signature=ht.FunctionType([self.NAT_TYPE], [ht.USize()]),
extension="arithmetic.conversions",
args=[],
)
return self.builder.add_op(op, value)

def usize_to_nat(self, value: Wire) -> Wire:
op = ops.Custom(
op_name="ifromusize",
signature=ht.FunctionType([self.NAT_TYPE], [ht.USize()]),
extension="arithmetic.conversions",
args=[ht.BoundedNatArg(NumericType.INT_WIDTH)],
)
return self.builder.add_op(op, value)

def compile(self, args: list[Wire]) -> list[Wire]:
sig = ht.FunctionType(self.ty.input.copy(), self.ty.output.copy())
for i, ty in enumerate(sig.input):
if ty == self.NAT_TYPE:
args[i] = self.nat_to_usize(args[i])
sig.input[i] = ht.USize()
op = UnsupportedOp(self.op_name, sig.input, sig.output)
outs: list[Wire] = [*self.builder.add_op(op, *args)]
for i, ty in enumerate(sig.input):
if ty == self.NAT_TYPE:
outs[i] = self.usize_to_nat(outs[i])
sig.output[i] = ht.USize()
return outs
83 changes: 83 additions & 0 deletions guppylang/prelude/angles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Guppy standard module for dyadic rational angles."""

# mypy: disable-error-code="empty-body, misc, override"

from typing import no_type_check

from hugr import tys as ht

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude._internal.checker import CoercingChecker
from guppylang.prelude._internal.compiler.angle import AngleOpCompiler
from guppylang.prelude.builtins import nat

angles = GuppyModule("angles")


_hugr_angle_type = ht.Opaque("angle", ht.TypeBound.Copyable, [], "tket2.quantum")


@guppy.type(angles, _hugr_angle_type)
class angle:
"""The type of angles represented as dyadic rational multiples of 2π."""

@guppy.custom(angles, AngleOpCompiler("afromrad"), CoercingChecker())
def __new__(radians: float) -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("aadd"))
def __add__(self: "angle", other: "angle") -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("asub"))
def __sub__(self: "angle", other: "angle") -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("aneg"))
def __neg__(self: "angle") -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("atorad"))
def __float__(self: "angle") -> float: ...

@guppy.custom(angles, AngleOpCompiler("aeq"))
def __eq__(self: "angle", other: "angle") -> bool: ...

@guppy(angles)
@no_type_check
def __mul__(self: "angle", other: int) -> "angle":
if other < 0:
return self._nat_mul(nat(other))
else:
return -self._nat_mul(nat(other))

@guppy(angles)
@no_type_check
def __rmul__(self: "angle", other: int) -> "angle":
return self * other

@guppy(angles)
@no_type_check
def __truediv__(self: "angle", other: int) -> "angle":
if other < 0:
return self._nat_div(nat(other))
else:
return -self._nat_div(nat(other))

@guppy.custom(angles, AngleOpCompiler("amul"))
def _nat_mul(self: "angle", other: nat) -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("aneg"))
def _nat_div(self: "angle", other: nat) -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("aparts"))
def _parts(self: "angle") -> tuple[nat, nat]: ...

@guppy(angles)
@no_type_check
def numerator(self: "angle") -> nat:
numerator, _ = self._parts()
return numerator

@guppy(angles)
@no_type_check
def log_denominator(self: "angle") -> nat:
_, log_denominator = self._parts()
return log_denominator
27 changes: 27 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from guppylang.decorator import guppy
from guppylang.prelude.angles import angle
from guppylang.prelude.builtins import nat
from guppylang.module import GuppyModule
from tests.util import compile_guppy
Expand Down Expand Up @@ -102,6 +103,32 @@ def arith(x: int, y: float, z: int) -> bool:
validate(arith)


def test_angle_arith(validate):
module = GuppyModule("test")
module.load(angle)

@guppy(module)
def main(a1: angle, a2: angle) -> bool:
a3 = -a1 + a2 * -3
a3 -= a1
a3 += 2 * a1
return a3 / 3 == -a2

validate(module.compile())


def test_angle_float_coercion(validate):
module = GuppyModule("test")
module.load(angle)

@guppy(module)
def main(f: float) -> tuple[angle, float]:
a = angle(f)
return a, float(a)

validate(module.compile())


def test_shortcircuit_assign1(validate):
@compile_guppy
def foo(x: bool, y: int) -> bool:
Expand Down

0 comments on commit 12e41e0

Please sign in to comment.