Skip to content

Commit

Permalink
feat: Add pi constant (#451)
Browse files Browse the repository at this point in the history
Closes #450. Depends on #449.

* Adds a `guppy.constant` function to declare constants in modules
* Adds a `pi` constant to the angle module for convenient writing of
angles
  • Loading branch information
mark-koch authored Sep 9, 2024
1 parent 12e41e0 commit 9d35a78
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 24 deletions.
77 changes: 54 additions & 23 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from types import FrameType, ModuleType
from typing import Any, TypeVar

import hugr.ext
from hugr import ops
from hugr import tys as ht
from hugr import val as hv

from guppylang.ast_util import annotate_location, has_empty_body
from guppylang.definition.common import DefId
from guppylang.definition.const import RawConstDef
from guppylang.definition.custom import (
CustomCallChecker,
CustomCallCompiler,
Expand Down Expand Up @@ -251,6 +253,15 @@ def dec(f: Callable[..., Any]) -> RawFunctionDecl:

return dec

def constant(
self, module: GuppyModule, name: str, ty: str, value: hv.Value
) -> RawConstDef:
"""Adds a constant to a module, backed by a `hugr.val.Value`."""
type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`")
defn = RawConstDef(DefId.fresh(module), name, None, type_ast, value)
module.register_def(defn)
return defn

def extern(
self,
module: GuppyModule,
Expand All @@ -260,28 +271,7 @@ def extern(
constant: bool = True,
) -> RawExternDef:
"""Adds an extern symbol to a module."""
try:
type_ast = ast.parse(ty, mode="eval").body
except SyntaxError:
err = f"Not a valid Guppy type: `{ty}`"
raise GuppyError(err) from None

# Try to annotate the type AST with source information. This requires us to
# inspect the stack frame of the caller
if frame := inspect.currentframe(): # noqa: SIM102
if caller_frame := frame.f_back: # noqa: SIM102
if caller_module := inspect.getmodule(caller_frame):
info = inspect.getframeinfo(caller_frame)
source_lines, _ = inspect.getsourcelines(caller_module)
source = "".join(source_lines)
annotate_location(type_ast, source, info.filename, 0)
# Modify the AST so that all sub-nodes span the entire line. We
# can't give a better location since we don't know the column
# offset of the `ty` argument
for node in [type_ast, *ast.walk(type_ast)]:
node.lineno, node.col_offset = info.lineno, 0
node.end_col_offset = len(source_lines[info.lineno - 1])

type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`")
defn = RawExternDef(
DefId.fresh(module), name, None, symbol or name, constant, type_ast
)
Expand Down Expand Up @@ -327,3 +317,44 @@ def registered_modules(self) -> list[ModuleIdentifier]:


guppy = _Guppy()


def _parse_expr_string(ty_str: str, parse_err: str) -> ast.expr:
"""Helper function to parse expressions that are provided as strings.
Tries to infer the source location were the given string was defined by inspecting
the call stack.
"""
try:
expr_ast = ast.parse(ty_str, mode="eval").body
except SyntaxError:
raise GuppyError(parse_err) from None

# Try to annotate the type AST with source information. This requires us to
# inspect the stack frame of the caller
if caller_frame := _get_calling_frame():
info = inspect.getframeinfo(caller_frame)
if caller_module := inspect.getmodule(caller_frame):
source_lines, _ = inspect.getsourcelines(caller_module)
source = "".join(source_lines)
annotate_location(expr_ast, source, info.filename, 0)
# Modify the AST so that all sub-nodes span the entire line. We
# can't give a better location since we don't know the column
# offset of the `ty` argument
for node in [expr_ast, *ast.walk(expr_ast)]:
node.lineno, node.col_offset = info.lineno, 0
node.end_col_offset = len(source_lines[info.lineno - 1])
return expr_ast


def _get_calling_frame() -> FrameType | None:
"""Finds the first frame that called this function outside the current module."""
frame = inspect.currentframe()
while frame:
module = inspect.getmodule(frame)
if module is None:
break
if module.__file__ != __file__:
return frame
frame = frame.f_back
return None
62 changes: 62 additions & 0 deletions guppylang/definition/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import ast
from dataclasses import dataclass, field

from hugr import Node, Wire
from hugr import val as hv
from hugr.build.dfg import DefinitionBuilder, OpVar

from guppylang.ast_util import AstNode
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.tys.parsing import type_from_ast


@dataclass(frozen=True)
class RawConstDef(ParsableDef):
"""A raw constant definition as provided by the user."""

type_ast: ast.expr
value: hv.Value

description: str = field(default="constant", init=False)

def parse(self, globals: Globals) -> "ConstDef":
"""Parses and checks the user-provided signature of the function."""
return ConstDef(
self.id,
self.name,
self.defined_at,
type_from_ast(self.type_ast, globals, None),
self.type_ast,
self.value,
)


@dataclass(frozen=True)
class ConstDef(RawConstDef, ValueDef, CompilableDef):
"""A constant with a checked type."""

def compile_outer(self, graph: DefinitionBuilder[OpVar]) -> "CompiledConstDef":
const_node = graph.add_const(self.value)
return CompiledConstDef(
self.id,
self.name,
self.defined_at,
self.ty,
self.type_ast,
self.value,
const_node,
)


@dataclass(frozen=True)
class CompiledConstDef(ConstDef, CompiledValueDef):
"""A constant that has been compiled to a Hugr node."""

const_node: Node

def load(self, dfg: DFContainer, globals: CompiledGlobals, node: AstNode) -> Wire:
"""Loads the extern value into a local Hugr dataflow graph."""
return dfg.builder.load(self.const_node)
17 changes: 17 additions & 0 deletions guppylang/prelude/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import no_type_check

from hugr import tys as ht
from hugr import val as hv

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
Expand All @@ -18,6 +19,22 @@
_hugr_angle_type = ht.Opaque("angle", ht.TypeBound.Copyable, [], "tket2.quantum")


def _hugr_angle_value(numerator: int, log_denominator: int) -> hv.Value:
custom_const = {
"log_denominator": log_denominator,
"value": numerator,
}
return hv.Extension(
name="ConstAngle",
typ=_hugr_angle_type,
val=custom_const,
extensions=["quantum.tket2"],
)


pi = guppy.constant(angles, "pi", ty="angle", value=_hugr_angle_value(1, 1))


@guppy.type(angles, _hugr_angle_type)
class angle:
"""The type of angles represented as dyadic rational multiples of 2π."""
Expand Down
16 changes: 15 additions & 1 deletion tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from guppylang.decorator import guppy
from guppylang.prelude.angles import angle
from guppylang.prelude.angles import angle, pi
from guppylang.prelude.builtins import nat
from guppylang.module import GuppyModule
from tests.util import compile_guppy
Expand Down Expand Up @@ -129,6 +129,20 @@ def main(f: float) -> tuple[angle, float]:
validate(module.compile())


def test_angle_pi(validate):
module = GuppyModule("test")
module.load(angle, pi)

@guppy(module)
def main() -> angle:
a = 2 * pi
a += -pi / 3
a += 3 * pi / 2
return a

validate(module.compile())


def test_shortcircuit_assign1(validate):
@compile_guppy
def foo(x: bool, y: int) -> bool:
Expand Down

0 comments on commit 9d35a78

Please sign in to comment.