-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Closes #448. Once the Guppy type checker has a proper framework for type coercions, we should turn angle into a `NumericType`, but this isn't needed for now
- Loading branch information
Showing
4 changed files
with
167 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
"""Compilers for angle operations from the tket2 extension.""" | ||
|
||
from typing import ClassVar | ||
|
||
from hugr import Wire, ops | ||
from hugr import tys as ht | ||
|
||
from guppylang.compiler.hugr_extension import UnsupportedOp | ||
from guppylang.definition.custom import ( | ||
CustomCallCompiler, | ||
) | ||
from guppylang.tys.ty import NumericType | ||
|
||
|
||
class AngleOpCompiler(CustomCallCompiler): | ||
"""Compiler for tket2 angle ops. | ||
Automatically translated between the Hugr usize type used in the angle extension | ||
and Guppy's `nat` type. | ||
""" | ||
|
||
NAT_TYPE: ClassVar[ht.Type] = NumericType(NumericType.Kind.Nat).to_hugr() | ||
|
||
def __init__(self, op_name: str) -> None: | ||
self.op_name = op_name | ||
|
||
def nat_to_usize(self, value: Wire) -> Wire: | ||
op = ops.Custom( | ||
op_name="itousize", | ||
signature=ht.FunctionType([self.NAT_TYPE], [ht.USize()]), | ||
extension="arithmetic.conversions", | ||
args=[], | ||
) | ||
return self.builder.add_op(op, value) | ||
|
||
def usize_to_nat(self, value: Wire) -> Wire: | ||
op = ops.Custom( | ||
op_name="ifromusize", | ||
signature=ht.FunctionType([self.NAT_TYPE], [ht.USize()]), | ||
extension="arithmetic.conversions", | ||
args=[ht.BoundedNatArg(NumericType.INT_WIDTH)], | ||
) | ||
return self.builder.add_op(op, value) | ||
|
||
def compile(self, args: list[Wire]) -> list[Wire]: | ||
sig = ht.FunctionType(self.ty.input.copy(), self.ty.output.copy()) | ||
for i, ty in enumerate(sig.input): | ||
if ty == self.NAT_TYPE: | ||
args[i] = self.nat_to_usize(args[i]) | ||
sig.input[i] = ht.USize() | ||
op = UnsupportedOp(self.op_name, sig.input, sig.output) | ||
outs: list[Wire] = [*self.builder.add_op(op, *args)] | ||
for i, ty in enumerate(sig.input): | ||
if ty == self.NAT_TYPE: | ||
outs[i] = self.usize_to_nat(outs[i]) | ||
sig.output[i] = ht.USize() | ||
return outs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
"""Guppy standard module for dyadic rational angles.""" | ||
|
||
# mypy: disable-error-code="empty-body, misc, override" | ||
|
||
from typing import no_type_check | ||
|
||
from hugr import tys as ht | ||
|
||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
from guppylang.prelude._internal.checker import CoercingChecker | ||
from guppylang.prelude._internal.compiler.angle import AngleOpCompiler | ||
from guppylang.prelude.builtins import nat | ||
|
||
angles = GuppyModule("angles") | ||
|
||
|
||
_hugr_angle_type = ht.Opaque("angle", ht.TypeBound.Copyable, [], "tket2.quantum") | ||
|
||
|
||
@guppy.type(angles, _hugr_angle_type) | ||
class angle: | ||
"""The type of angles represented as dyadic rational multiples of 2π.""" | ||
|
||
@guppy.custom(angles, AngleOpCompiler("afromrad"), CoercingChecker()) | ||
def __new__(radians: float) -> "angle": ... | ||
|
||
@guppy.custom(angles, AngleOpCompiler("aadd")) | ||
def __add__(self: "angle", other: "angle") -> "angle": ... | ||
|
||
@guppy.custom(angles, AngleOpCompiler("asub")) | ||
def __sub__(self: "angle", other: "angle") -> "angle": ... | ||
|
||
@guppy.custom(angles, AngleOpCompiler("aneg")) | ||
def __neg__(self: "angle") -> "angle": ... | ||
|
||
@guppy.custom(angles, AngleOpCompiler("atorad")) | ||
def __float__(self: "angle") -> float: ... | ||
|
||
@guppy.custom(angles, AngleOpCompiler("aeq")) | ||
def __eq__(self: "angle", other: "angle") -> bool: ... | ||
|
||
@guppy(angles) | ||
@no_type_check | ||
def __mul__(self: "angle", other: int) -> "angle": | ||
if other < 0: | ||
return self._nat_mul(nat(other)) | ||
else: | ||
return -self._nat_mul(nat(other)) | ||
|
||
@guppy(angles) | ||
@no_type_check | ||
def __rmul__(self: "angle", other: int) -> "angle": | ||
return self * other | ||
|
||
@guppy(angles) | ||
@no_type_check | ||
def __truediv__(self: "angle", other: int) -> "angle": | ||
if other < 0: | ||
return self._nat_div(nat(other)) | ||
else: | ||
return -self._nat_div(nat(other)) | ||
|
||
@guppy.custom(angles, AngleOpCompiler("amul")) | ||
def _nat_mul(self: "angle", other: nat) -> "angle": ... | ||
|
||
@guppy.custom(angles, AngleOpCompiler("aneg")) | ||
def _nat_div(self: "angle", other: nat) -> "angle": ... | ||
|
||
@guppy.custom(angles, AngleOpCompiler("aparts")) | ||
def _parts(self: "angle") -> tuple[nat, nat]: ... | ||
|
||
@guppy(angles) | ||
@no_type_check | ||
def numerator(self: "angle") -> nat: | ||
numerator, _ = self._parts() | ||
return numerator | ||
|
||
@guppy(angles) | ||
@no_type_check | ||
def log_denominator(self: "angle") -> nat: | ||
_, log_denominator = self._parts() | ||
return log_denominator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters