Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into ts/load-pytket
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana-s committed Dec 13, 2024
2 parents 21aff7b + 3ad49ff commit d000c00
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 33 deletions.
58 changes: 47 additions & 11 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,17 @@
from guppylang.span import Span, to_span
from guppylang.tys.arg import TypeArg
from guppylang.tys.builtin import (
array_type,
bool_type,
float_type,
get_element_type,
int_type,
is_array_type,
is_bool_type,
is_list_type,
is_sized_iter_type,
list_type,
nat_type,
)
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.subst import Inst, Subst
Expand Down Expand Up @@ -231,6 +236,14 @@ def _synthesize(
"""Invokes the type synthesiser"""
return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars)

def visit_Constant(self, node: ast.Constant, ty: Type) -> tuple[ast.expr, Subst]:
act = python_value_to_guppy_type(node.value, node, self.ctx.globals, ty)
if act is None:
raise GuppyError(IllegalConstant(node, type(node.value)))
node, subst, inst = check_type_against(act, ty, node, self.ctx, self._kind)
assert inst == [], "Const values are not generic"
return node, subst

def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]:
if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts):
return self._fail(ty, node)
Expand Down Expand Up @@ -317,7 +330,9 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:

def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]:
python_val = eval_py_expr(node, self.ctx)
if act := python_value_to_guppy_type(python_val, node.value, self.ctx.globals):
if act := python_value_to_guppy_type(
python_val, node.value, self.ctx.globals, ty
):
subst = unify(ty, act, {})
if subst is None:
self._fail(ty, act, node)
Expand Down Expand Up @@ -1136,25 +1151,43 @@ def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
return python_val


def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type | None:
def python_value_to_guppy_type(
v: Any, node: ast.expr, globals: Globals, type_hint: Type | None = None
) -> Type | None:
"""Turns a primitive Python value into a Guppy type.
Accepts an optional `type_hint` for the expected expression type that is used to
infer a more precise type (e.g. distinguishing between `int` and `nat`). Note that
invalid hints are ignored, i.e. no user error are emitted.
Returns `None` if the Python value cannot be represented in Guppy.
"""
match v:
case bool():
return bool_type()
# Only resolve `int` to `nat` if the user specifically asked for it
case int(n) if type_hint == nat_type() and n >= 0:
return nat_type()
# Otherwise, default to `int` for consistency with Python
case int():
return cast(TypeDef, globals["int"]).check_instantiate([], globals)
return int_type()
case float():
return cast(TypeDef, globals["float"]).check_instantiate([], globals)
return float_type()
case tuple(elts):
tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts]
hints = (
type_hint.element_types
if isinstance(type_hint, TupleType)
else len(elts) * [None]
)
tys = [
python_value_to_guppy_type(elt, node, globals, hint)
for elt, hint in zip(elts, hints, strict=False)
]
if any(ty is None for ty in tys):
return None
return TupleType(cast(list[Type], tys))
case list():
return _python_list_to_guppy_type(v, node, globals)
return _python_list_to_guppy_type(v, node, globals, type_hint)
case _:
# Pytket conversion is an experimental feature
# if pytket and tket2 are installed
Expand Down Expand Up @@ -1185,26 +1218,29 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type


def _python_list_to_guppy_type(
vs: list[Any], node: ast.expr, globals: Globals
vs: list[Any], node: ast.expr, globals: Globals, type_hint: Type | None
) -> OpaqueType | None:
"""Turns a Python list into a Guppy type.
Returns `None` if the list contains different types or types that are not
representable in Guppy.
"""
if len(vs) == 0:
return list_type(ExistentialTypeVar.fresh("T", False))
return array_type(ExistentialTypeVar.fresh("T", False), 0)

# All the list elements must have a unifiable types
v, *rest = vs
el_ty = python_value_to_guppy_type(v, node, globals)
elt_hint = (
get_element_type(type_hint) if type_hint and is_array_type(type_hint) else None
)
el_ty = python_value_to_guppy_type(v, node, globals, elt_hint)
if el_ty is None:
return None
for v in rest:
ty = python_value_to_guppy_type(v, node, globals)
ty = python_value_to_guppy_type(v, node, globals, elt_hint)
if ty is None:
return None
if (subst := unify(ty, el_ty, {})) is None:
raise GuppyError(PyExprIncoherentListError(node))
el_ty = el_ty.substitute(subst)
return list_type(el_ty)
return array_type(el_ty, len(vs))
15 changes: 11 additions & 4 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from hugr import val as hv
from hugr.build.cond_loop import Conditional
from hugr.build.dfg import DP, DfBase
from hugr.std.collections import ListVal
from typing_extensions import assert_never

from guppylang.ast_util import AstNode, AstVisitor, get_type
Expand Down Expand Up @@ -55,8 +54,8 @@
from guppylang.tys.builtin import (
get_element_type,
int_type,
is_array_type,
is_bool_type,
is_list_type,
)
from guppylang.tys.const import BoundConstVar, ConstValue, ExistentialConstVar
from guppylang.tys.subst import Inst
Expand Down Expand Up @@ -601,10 +600,18 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None:
if doesnt_contain_none(vs):
return hv.Tuple(*vs)
case list(elts):
assert is_list_type(exp_ty)
assert is_array_type(exp_ty)
vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts]
if doesnt_contain_none(vs):
return ListVal(vs, get_element_type(exp_ty).to_hugr())
# TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1497
return hv.Extension(
name="ArrayValue",
typ=exp_ty.to_hugr(),
# The value list must be serialized at this point, otherwise the
# `Extension` value would not be serializable.
val=[v._to_serial_root() for v in vs],
extensions=["unsupported"],
)
case _:
# TODO replace with hugr protocol handling: https://github.com/CQCL/guppylang/issues/563
# Pytket conversion is an experimental feature
Expand Down
22 changes: 17 additions & 5 deletions guppylang/std/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from guppylang.nodes import (
DesugaredArrayComp,
DesugaredGeneratorExpr,
GenericParamValue,
GlobalCall,
MakeIter,
ResultExpr,
Expand Down Expand Up @@ -360,12 +361,23 @@ class RangeChecker(CustomCallChecker):
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
check_num_args(1, len(args), self.node)
[stop] = args
stop, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument")
range_iter, range_ty = self.make_range(stop)
if isinstance(stop, ast.Constant):
return to_sized_iter(range_iter, range_ty, stop.value, self.ctx)
stop_checked, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument")
range_iter, range_ty = self.make_range(stop_checked)
# Check if `stop` is a statically known value. Note that we need to do this on
# the original `stop` instead of `stop_checked` to avoid any previously inserted
# `int` coercions.
if (static_stop := self.check_static(stop)) is not None:
return to_sized_iter(range_iter, range_ty, static_stop, self.ctx)
return range_iter, range_ty

def check_static(self, stop: ast.expr) -> "int | Const | None":
stop, _ = ExprSynthesizer(self.ctx).synthesize(stop, allow_free_vars=True)
if isinstance(stop, ast.Constant) and isinstance(stop.value, int):
return stop.value
if isinstance(stop, GenericParamValue) and stop.param.ty == nat_type():
return stop.param.to_bound().const
return None

def range_ty(self) -> StructType:
from guppylang.std.builtins import Range

Expand All @@ -382,7 +394,7 @@ def make_range(self, stop: ast.expr) -> tuple[ast.expr, Type]:


def to_sized_iter(
iterator: ast.expr, range_ty: Type, size: int, ctx: Context
iterator: ast.expr, range_ty: Type, size: "int | Const", ctx: Context
) -> tuple[ast.expr, Type]:
"""Adds a static size annotation to an iterator."""
sized_iter_ty = sized_iter_type(range_ty, size)
Expand Down
4 changes: 3 additions & 1 deletion guppylang/std/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def __int__(self: bool) -> int:
@guppy
@no_type_check
def __nat__(self: bool) -> nat:
# TODO: Literals should check against `nat`
# TODO: Type information doesn't flow through the `if` expression, so we
# have to insert the `nat` coercions by hand.
# See https://github.com/CQCL/guppylang/issues/707
return nat(1) if self else nat(0)

@guppy.custom(checker=DunderChecker("__bool__"), higher_order_value=False)
Expand Down
20 changes: 19 additions & 1 deletion guppylang/std/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from guppylang.std._internal.util import quantum_op
from guppylang.std.angles import angle
from guppylang.std.builtins import owned
from guppylang.std.builtins import array, owned
from guppylang.std.option import Option


Expand Down Expand Up @@ -144,3 +144,21 @@ def measure(q: qubit @ owned) -> bool: ...
@guppy.hugr_op(quantum_op("Reset"))
@no_type_check
def reset(q: qubit) -> None: ...


N = guppy.nat_var("N")


@guppy
@no_type_check
def measure_array(qubits: array[qubit, N] @ owned) -> array[bool, N]:
"""Measure an array of qubits, returning an array of bools."""
return array(measure(q) for q in qubits)


@guppy
@no_type_check
def discard_array(qubits: array[qubit, N] @ owned) -> None:
"""Discard an array of qubits."""
for q in qubits:
discard(q)
4 changes: 4 additions & 0 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ def int_type() -> NumericType:
return NumericType(NumericType.Kind.Int)


def float_type() -> NumericType:
return NumericType(NumericType.Kind.Float)


def list_type(element_ty: Type) -> OpaqueType:
return OpaqueType([TypeArg(element_ty)], list_type_def)

Expand Down
2 changes: 1 addition & 1 deletion tests/error/py_errors/list_empty.err
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Error: Cannot infer type (at $FILE:6:9)
5 | def foo() -> None:
6 | xs = py([])
| ^^^^^^ Cannot infer type variables in expression of type
| `list[?T]`
| `array[?T, 0]`

Guppy compilation failed due to 1 previous error
8 changes: 8 additions & 0 deletions tests/error/py_errors/negative_nat.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Type mismatch (at $FILE:7:11)
|
5 | @compile_guppy
6 | def foo() -> nat:
7 | return py(-1)
| ^^^^^^ Expected expression of type `nat`, got `int`

Guppy compilation failed due to 1 previous error
7 changes: 7 additions & 0 deletions tests/error/py_errors/negative_nat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.std.builtins import nat
from tests.util import compile_guppy


@compile_guppy
def foo() -> nat:
return py(-1)
8 changes: 8 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def const() -> float:
validate(const)


def test_nat_literal(validate):
@compile_guppy
def const() -> nat:
return 42

validate(const)


def test_aug_assign(validate, run_int_fn):
module = GuppyModule("test_aug_assign")

Expand Down
20 changes: 15 additions & 5 deletions tests/integration/test_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import py, array
from guppylang.std.builtins import py, array, nat
from guppylang.std import quantum
from guppylang.std.quantum import qubit
from tests.util import compile_guppy
Expand Down Expand Up @@ -76,7 +76,7 @@ def foo() -> int:

def test_list_basic(validate):
@compile_guppy
def foo() -> list[int]:
def foo() -> array[int, 3]:
xs = py([1, 2, 3])
return xs

Expand All @@ -85,7 +85,7 @@ def foo() -> list[int]:

def test_list_empty(validate):
@compile_guppy
def foo() -> list[int]:
def foo() -> array[int, 0]:
return py([])

validate(foo)
Expand All @@ -94,15 +94,25 @@ def foo() -> list[int]:
def test_list_empty_nested(validate):
@compile_guppy
def foo() -> None:
xs: list[tuple[int, list[bool]]] = py([(42, [])])
xs: array[tuple[int, array[bool, 0]], 1] = py([(42, [])])

validate(foo)


def test_list_empty_multiple(validate):
@compile_guppy
def foo() -> None:
xs: tuple[list[int], list[bool]] = py([], [])
xs: tuple[array[int, 0], array[bool, 0]] = py([], [])

validate(foo)


def test_nats_from_ints(validate):
@compile_guppy
def foo() -> None:
x: nat = py(1)
y: tuple[nat, nat] = py(2, 3)
z: array[nat, 3] = py([4, 5, 6])

validate(foo)

Expand Down
Loading

0 comments on commit d000c00

Please sign in to comment.