Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add pi constant #451

Merged
merged 6 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 54 additions & 23 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from types import FrameType, ModuleType
from typing import Any, TypeVar

from hugr import Hugr, ops
from hugr import tys as ht
from hugr import val as hv

from guppylang.ast_util import annotate_location, has_empty_body
from guppylang.definition.common import DefId
from guppylang.definition.const import RawConstDef
from guppylang.definition.custom import (
CustomCallChecker,
CustomCallCompiler,
Expand Down Expand Up @@ -250,6 +252,15 @@ def dec(f: Callable[..., Any]) -> RawFunctionDecl:

return dec

def constant(
self, module: GuppyModule, name: str, ty: str, value: hv.Value
) -> RawConstDef:
"""Adds a constant to a module, backed by a `hugr.val.Value`."""
type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`")
defn = RawConstDef(DefId.fresh(module), name, None, type_ast, value)
module.register_def(defn)
return defn

def extern(
self,
module: GuppyModule,
Expand All @@ -259,28 +270,7 @@ def extern(
constant: bool = True,
) -> RawExternDef:
"""Adds an extern symbol to a module."""
try:
type_ast = ast.parse(ty, mode="eval").body
except SyntaxError:
err = f"Not a valid Guppy type: `{ty}`"
raise GuppyError(err) from None

# Try to annotate the type AST with source information. This requires us to
# inspect the stack frame of the caller
if frame := inspect.currentframe(): # noqa: SIM102
if caller_frame := frame.f_back: # noqa: SIM102
if caller_module := inspect.getmodule(caller_frame):
info = inspect.getframeinfo(caller_frame)
source_lines, _ = inspect.getsourcelines(caller_module)
source = "".join(source_lines)
annotate_location(type_ast, source, info.filename, 0)
# Modify the AST so that all sub-nodes span the entire line. We
# can't give a better location since we don't know the column
# offset of the `ty` argument
for node in [type_ast, *ast.walk(type_ast)]:
node.lineno, node.col_offset = info.lineno, 0
node.end_col_offset = len(source_lines[info.lineno - 1])

type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`")
defn = RawExternDef(
DefId.fresh(module), name, None, symbol or name, constant, type_ast
)
Expand Down Expand Up @@ -326,3 +316,44 @@ def registered_modules(self) -> list[ModuleIdentifier]:


guppy = _Guppy()


def _parse_expr_string(ty_str: str, parse_err: str) -> ast.expr:
"""Helper function to parse expressions that are provided as strings.

Tries to infer the source location were the given string was defined by inspecting
the call stack.
"""
try:
expr_ast = ast.parse(ty_str, mode="eval").body
except SyntaxError:
raise GuppyError(parse_err) from None

# Try to annotate the type AST with source information. This requires us to
# inspect the stack frame of the caller
if caller_frame := _get_calling_frame():
info = inspect.getframeinfo(caller_frame)
if caller_module := inspect.getmodule(caller_frame):
source_lines, _ = inspect.getsourcelines(caller_module)
source = "".join(source_lines)
annotate_location(expr_ast, source, info.filename, 0)
# Modify the AST so that all sub-nodes span the entire line. We
# can't give a better location since we don't know the column
# offset of the `ty` argument
for node in [expr_ast, *ast.walk(expr_ast)]:
node.lineno, node.col_offset = info.lineno, 0
node.end_col_offset = len(source_lines[info.lineno - 1])
return expr_ast


def _get_calling_frame() -> FrameType | None:
"""Finds the first frame that called this function outside the current module."""
frame = inspect.currentframe()
while frame:
module = inspect.getmodule(frame)
if module is None:
break
if module.__file__ != __file__:
return frame
frame = frame.f_back
return None
62 changes: 62 additions & 0 deletions guppylang/definition/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import ast
from dataclasses import dataclass, field

from hugr import Node, Wire
from hugr import val as hv
from hugr.dfg import OpVar, _DefinitionBuilder

from guppylang.ast_util import AstNode
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.tys.parsing import type_from_ast


@dataclass(frozen=True)
class RawConstDef(ParsableDef):
"""A raw constant definition as provided by the user."""

type_ast: ast.expr
value: hv.Value

description: str = field(default="constant", init=False)

def parse(self, globals: Globals) -> "ConstDef":
"""Parses and checks the user-provided signature of the function."""
return ConstDef(
self.id,
self.name,
self.defined_at,
type_from_ast(self.type_ast, globals, None),
self.type_ast,
self.value,
)


@dataclass(frozen=True)
class ConstDef(RawConstDef, ValueDef, CompilableDef):
"""A constant with a checked type."""

def compile_outer(self, graph: _DefinitionBuilder[OpVar]) -> "CompiledConstDef":
const_node = graph.add_const(self.value)
return CompiledConstDef(
self.id,
self.name,
self.defined_at,
self.ty,
self.type_ast,
self.value,
const_node,
)


@dataclass(frozen=True)
class CompiledConstDef(ConstDef, CompiledValueDef):
"""A constant that has been compiled to a Hugr node."""

const_node: Node

def load(self, dfg: DFContainer, globals: CompiledGlobals, node: AstNode) -> Wire:
"""Loads the extern value into a local Hugr dataflow graph."""
return dfg.builder.load(self.const_node)
47 changes: 47 additions & 0 deletions guppylang/prelude/_internal/compiler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import ClassVar

import hugr
from hugr import Wire, ops
from hugr import tys as ht
Expand Down Expand Up @@ -249,6 +251,51 @@ def compile(self, args: list[Wire]) -> list[Wire]:
return [q]


class AngleOpCompiler(CustomCallCompiler):
"""Compiler for tket2 angle ops.

Automatically translated between the Hugr usize type used in the angle extension
and Guppy's `nat` type.
"""

NAT_TYPE: ClassVar[ht.Type] = NumericType(NumericType.Kind.Nat).to_hugr()

def __init__(self, op_name: str) -> None:
self.op_name = op_name

def nat_to_usize(self, value: Wire) -> Wire:
op = ops.Custom(
name="itousize",
signature=ht.FunctionType([self.NAT_TYPE], [ht.USize()]),
extension="arithmetic.conversions",
args=[ht.BoundedNatArg(NumericType.INT_WIDTH)],
)
return self.builder.add_op(op, value)

def usize_to_nat(self, value: Wire) -> Wire:
op = ops.Custom(
name="ifromusize",
signature=ht.FunctionType([self.NAT_TYPE], [ht.USize()]),
extension="arithmetic.conversions",
args=[ht.BoundedNatArg(NumericType.INT_WIDTH)],
)
return self.builder.add_op(op, value)

def compile(self, args: list[Wire]) -> list[Wire]:
sig = ht.FunctionType(self.ty.input.copy(), self.ty.output.copy())
for i, ty in enumerate(sig.input):
if ty == self.NAT_TYPE:
args[i] = self.nat_to_usize(args[i])
sig.input[i] = ht.USize()
op = ops.Custom(self.op_name, sig, extension="guppy.unsupported", args=[])
outs: list[Wire] = [*self.builder.add_op(op, *args)]
for i, ty in enumerate(sig.input):
if ty == self.NAT_TYPE:
outs[i] = self.usize_to_nat(outs[i])
sig.output[i] = ht.USize()
return outs


class ArrayGetitemCompiler(CustomCallCompiler):
"""Compiler for the `array.__getitem__` function."""

Expand Down
104 changes: 104 additions & 0 deletions guppylang/prelude/angles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Guppy standard module for dyadic rational angles."""

# mypy: disable-error-code="empty-body, misc, override"

from typing import no_type_check

from hugr import tys as ht
from hugr import val as hv

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude._internal.checker import CoercingChecker
from guppylang.prelude._internal.compiler import (
AngleOpCompiler,
)
from guppylang.prelude.builtins import nat

angles = GuppyModule("angles")


_hugr_angle_type = ht.Opaque(
"angle", ht.TypeBound.Copyable, [ht.BoundedNatArg(1)], "quantum.tket2"
)


def _hugr_angle_value(numerator: int, log_denominator: int) -> hv.Value:
custom_const = {
"log_denominator": log_denominator,
"value": numerator,
}
return hv.Extension(
name="ConstAngle",
typ=_hugr_angle_type,
val=custom_const,
extensions=["quantum.tket2"],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: this will be "tket2.quantum" in next tket2 release

)


pi = guppy.constant(angles, "pi", ty="angle", value=_hugr_angle_value(1, 1))


@guppy.type(angles, _hugr_angle_type)
class angle:
"""The type of angles represented as dyadic rational multiples of 2π."""

@guppy.custom(angles, AngleOpCompiler("afromrad"), CoercingChecker())
def __new__(radians: float) -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("aadd"))
def __add__(self: "angle", other: "angle") -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("asub"))
def __sub__(self: "angle", other: "angle") -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("aneg"))
def __neg__(self: "angle") -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("atorad"))
def __float__(self: "angle") -> float: ...

@guppy.custom(angles, AngleOpCompiler("aeq"))
def __eq__(self: "angle", other: "angle") -> bool: ...

@guppy(angles)
@no_type_check
def __mul__(self: "angle", other: int) -> "angle":
if other < 0:
return self._nat_mul(nat(other))
else:
return -self._nat_mul(nat(other))

@guppy(angles)
@no_type_check
def __rmul__(self: "angle", other: int) -> "angle":
return self * other

@guppy(angles)
@no_type_check
def __truediv__(self: "angle", other: int) -> "angle":
if other < 0:
return self._nat_div(nat(other))
else:
return -self._nat_div(nat(other))

@guppy.custom(angles, AngleOpCompiler("amul"))
def _nat_mul(self: "angle", other: nat) -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("aneg"))
def _nat_div(self: "angle", other: nat) -> "angle": ...

@guppy.custom(angles, AngleOpCompiler("aparts"))
def _parts(self: "angle") -> tuple[nat, nat]: ...

@guppy(angles)
@no_type_check
def numerator(self: "angle") -> nat:
numerator, _ = self._parts()
return numerator

@guppy(angles)
@no_type_check
def log_denominator(self: "angle") -> nat:
_, log_denominator = self._parts()
return log_denominator
41 changes: 41 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from guppylang.decorator import guppy
from guppylang.prelude.angles import angle, pi
from guppylang.prelude.builtins import nat
from guppylang.module import GuppyModule
from tests.util import compile_guppy
Expand Down Expand Up @@ -101,6 +102,46 @@ def arith(x: int, y: float, z: int) -> bool:
validate(arith)


def test_angle_arith(validate):
module = GuppyModule("test")
module.load(angle)

@guppy(module)
def main(a1: angle, a2: angle) -> bool:
a3 = -a1 + a2 * -3
a3 -= a1
a3 += 2 * a1
return a3 / 3 == -a2

validate(module.compile())


def test_angle_float_coercion(validate):
module = GuppyModule("test")
module.load(angle)

@guppy(module)
def main(f: float) -> tuple[angle, float]:
a = angle(f)
return a, float(a)

validate(module.compile())


def test_angle_pi(validate):
module = GuppyModule("test")
module.load(angle, pi)

@guppy(module)
def main() -> angle:
a = 2 * pi
a += -pi / 3
a += 3 * pi / 2
return a

validate(module.compile())


def test_shortcircuit_assign1(validate):
@compile_guppy
def foo(x: bool, y: int) -> bool:
Expand Down
Loading