Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: Use latest results extension spec #370

Merged
merged 3 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/random_walk_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main() -> int:
mu,
sigma,
)
result(0, eigenvalue) # Expected outcome is 0.5
result("eigenvalue", eigenvalue) # Expected outcome is 0.5
return 0


Expand Down
61 changes: 51 additions & 10 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from contextlib import contextmanager
from typing import Any, TypeGuard, TypeVar

from hugr.serialization import ops
from hugr.serialization import ops, tys
from typing_extensions import assert_never

from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type
from guppylang.cfg.builder import tmp_vars
Expand All @@ -31,9 +32,13 @@
TensorCall,
TypeApply,
)
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import bool_type, get_element_type, is_list_type
from guppylang.tys.const import ConstValue
from guppylang.tys.builtin import (
bool_type,
get_element_type,
is_bool_type,
is_list_type,
)
from guppylang.tys.const import BoundConstVar, ConstValue, ExistentialConstVar
from guppylang.tys.subst import Inst
from guppylang.tys.ty import (
BoundTypeVar,
Expand Down Expand Up @@ -297,14 +302,50 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> OutPortV:
return unpack.out_port(node.struct_ty.fields.index(node.field))

def visit_ResultExpr(self, node: ResultExpr) -> OutPortV:
type_args = [
TypeArg(node.ty),
ConstArg(ConstValue(value=node.tag, ty=NumericType(NumericType.Kind.Nat))),
]
extra_args = []
if isinstance(node.base_ty, NumericType):
match node.base_ty.kind:
case NumericType.Kind.Nat:
base_name = "uint"
extra_args = [
tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))
]
case NumericType.Kind.Int:
base_name = "int"
extra_args = [
tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))
]
case NumericType.Kind.Float:
base_name = "f64"
case kind:
assert_never(kind)
else:
# The only other valid base type is bool
assert is_bool_type(node.base_ty)
base_name = "bool"
if node.array_len is not None:
op_name = f"result_array_{base_name}"
match node.array_len:
case ConstValue(value=value):
assert isinstance(value, int)
extra_args = [tys.TypeArg(tys.BoundedNatArg(n=value)), *extra_args]
case BoundConstVar():
# TODO: We need to handle this once we allow function definitions
# that are generic over array lengths
raise NotImplementedError
case ExistentialConstVar() as var:
raise InternalGuppyError(
f"Unsolved existential variable during Hugr lowering: {var}"
)
case c:
assert_never(c)
else:
op_name = f"result_{base_name}"
args = [tys.TypeArg(tys.StringArg(arg=node.tag)), *extra_args]
op = ops.CustomOp(
extension="tket2.result",
name="result_uint",
args=[arg.to_hugr() for arg in type_args],
name=op_name,
args=args,
parent=UNDEFINED,
)
self.graph.add_node(ops.OpType(op), inputs=[self.visit(node.value)])
Expand Down
9 changes: 6 additions & 3 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any

from guppylang.ast_util import AstNode
from guppylang.tys.const import Const
from guppylang.tys.subst import Inst
from guppylang.tys.ty import FunctionType, StructType, Type

Expand Down Expand Up @@ -192,10 +193,12 @@ class ResultExpr(ast.expr):
"""A `result(tag, value)` expression."""

value: ast.expr
ty: Type
tag: int
base_ty: Type
#: Array length in case this is an array result, otherwise `None`
array_len: Const | None
tag: str

_fields = ("value", "ty", "tag")
_fields = ("value", "base_ty", "array_len", "tag")


class NestedFunctionDef(ast.FunctionDef):
Expand Down
50 changes: 36 additions & 14 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall, ResultExpr
from guppylang.tys.arg import ConstArg
from guppylang.tys.builtin import bool_type, int_type, list_type
from guppylang.tys.const import ConstValue
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import (
bool_type,
int_type,
is_array_type,
is_bool_type,
list_type,
)
from guppylang.tys.const import Const, ConstValue
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
FunctionType,
Expand Down Expand Up @@ -277,28 +283,44 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:


class ResultChecker(CustomCallChecker):
"""Call checker for the `result` function.

This is a temporary hack until we have implemented the proper results mechanism.
"""
"""Call checker for the `result` function."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
check_num_args(2, len(args), self.node)
[tag, value] = args
if not isinstance(tag, ast.Constant) or not isinstance(tag.value, int):
raise GuppyTypeError("Expected an int literal", tag)
if not isinstance(tag, ast.Constant) or not isinstance(tag.value, str):
raise GuppyTypeError("Expected a string literal", tag)
value, ty = ExprSynthesizer(self.ctx).synthesize(value)
if ty.linear:
raise GuppyTypeError(
f"Cannot use value with linear type `{ty}` as a result", value
)
return with_loc(self.node, ResultExpr(value, ty, tag.value)), NoneType()
# We only allow numeric values or vectors of numeric values
err = (
f"Expression of type `{ty}` is not a valid result. Only numeric values or "
"arrays thereof are allowed."
)
if self._is_numeric_or_bool_type(ty):
base_ty = ty
array_len: Const | None = None
elif is_array_type(ty):
[ty_arg, len_arg] = ty.args
assert isinstance(ty_arg, TypeArg)
assert isinstance(len_arg, ConstArg)
Comment on lines +304 to +305
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can these asserts result in unhelpful errors?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, failing asserts would be a compiler bug since you shouldn't be able to construct an array type with invalid args

if not self._is_numeric_or_bool_type(ty_arg.ty):
raise GuppyError(err, value)
base_ty = ty_arg.ty
array_len = len_arg.const
else:
raise GuppyError(err, value)
node = ResultExpr(value, base_ty, array_len, tag.value)
return with_loc(self.node, node), NoneType()

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
return expr, subst

@staticmethod
def _is_numeric_or_bool_type(ty: Type) -> bool:
return isinstance(ty, NumericType) or is_bool_type(ty)


class NatTruedivCompiler(CustomCallCompiler):
"""Compiler for the `nat.__truediv__` method."""
Expand Down
6 changes: 5 additions & 1 deletion guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypeGuard

from hugr.serialization import tys

Expand Down Expand Up @@ -227,6 +227,10 @@ def is_linst_type(ty: Type) -> bool:
return isinstance(ty, OpaqueType) and ty.defn == linst_type_def


def is_array_type(ty: Type) -> TypeGuard[OpaqueType]:
return isinstance(ty, OpaqueType) and ty.defn == array_type_def


def get_element_type(ty: Type) -> Type:
assert isinstance(ty, OpaqueType)
assert ty.defn in (list_type_def, linst_type_def)
Expand Down
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_array_not_numeric.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(x: array[tuple[int, bool], 42]) -> None:
7: result("foo", x)
^
GuppyError: Expression of type `array[(int, bool), 42]` is not a valid result. Only numeric values or arrays thereof are allowed.
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_array_not_numeric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.prelude.builtins import result, array
from tests.util import compile_guppy


@compile_guppy
def foo(x: array[tuple[int, bool], 42]) -> None:
result("foo", x)
8 changes: 4 additions & 4 deletions tests/error/misc_errors/result_tag_not_static.err
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(x: int, y: bool) -> None:
7: result(x, y)
^
GuppyTypeError: Expected an int literal
6: def foo(y: bool) -> None:
7: result("foo" + "bar", y)
^^^^^^^^^^^^^
GuppyTypeError: Expected a string literal
4 changes: 2 additions & 2 deletions tests/error/misc_errors/result_tag_not_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@


@compile_guppy
def foo(x: int, y: bool) -> None:
result(x, y)
def foo(y: bool) -> None:
result("foo" + "bar", y)
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:7
6: def foo(x: int) -> None:
7: result((), x)
^^
GuppyTypeError: Expected an int literal
GuppyTypeError: Expected a string literal
6 changes: 3 additions & 3 deletions tests/error/misc_errors/result_value_linear.err
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:14

12: @guppy(module)
13: def foo(q: qubit) -> None:
14: result(0, q)
^
GuppyTypeError: Cannot use value with linear type `qubit` as a result
14: result("foo", q)
^
GuppyError: Expression of type `qubit` is not a valid result. Only numeric values or arrays thereof are allowed.
2 changes: 1 addition & 1 deletion tests/error/misc_errors/result_value_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@guppy(module)
def foo(q: qubit) -> None:
result(0, q)
result("foo", q)


module.compile()
40 changes: 18 additions & 22 deletions tests/integration/test_result.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,42 @@
from guppylang.prelude.builtins import result
from guppylang.prelude.builtins import result, nat, array
from tests.util import compile_guppy


def test_single(validate):
def test_basic(validate):
@compile_guppy
def main(x: int) -> None:
result(0, x)
result("foo", x)

validate(main)


def test_value(validate):
@compile_guppy
def main(x: int) -> None:
return result(0, x)

validate(main)


def test_nested(validate):
def test_multi(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(42, (x, (y, z)))
def main(w: nat, x: int, y: float, z: bool) -> None:
result("a", w)
result("b", x)
result("c", y)
result("d", z)

validate(main)


def test_multi(validate):
def test_array(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(0, x)
result(1, y)
result(2, z)
def main(w: array[nat, 42], x: array[int, 5], y: array[float, 1], z: array[bool, 0]) -> None:
result("a", w)
result("b", x)
result("c", y)
result("d", z)

validate(main)


def test_same_tag(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(0, x)
result(0, y)
result(0, z)
result("foo", x)
result("foo", y)
result("foo", z)

validate(main)
Loading
Loading