Skip to content

Commit

Permalink
fix: Accept non-negative int literals and py expressions as nats (#708)
Browse files Browse the repository at this point in the history
Fixes #704
  • Loading branch information
mark-koch authored Dec 12, 2024
1 parent d52a00a commit a93d4fe
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 12 deletions.
53 changes: 44 additions & 9 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,15 @@
from guppylang.tys.builtin import (
array_type,
bool_type,
float_type,
get_element_type,
int_type,
is_array_type,
is_bool_type,
is_list_type,
is_sized_iter_type,
list_type,
nat_type,
)
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.subst import Inst, Subst
Expand Down Expand Up @@ -232,6 +236,14 @@ def _synthesize(
"""Invokes the type synthesiser"""
return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars)

def visit_Constant(self, node: ast.Constant, ty: Type) -> tuple[ast.expr, Subst]:
act = python_value_to_guppy_type(node.value, node, self.ctx.globals, ty)
if act is None:
raise GuppyError(IllegalConstant(node, type(node.value)))
node, subst, inst = check_type_against(act, ty, node, self.ctx, self._kind)
assert inst == [], "Const values are not generic"
return node, subst

def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]:
if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts):
return self._fail(ty, node)
Expand Down Expand Up @@ -318,7 +330,9 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:

def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]:
python_val = eval_py_expr(node, self.ctx)
if act := python_value_to_guppy_type(python_val, node.value, self.ctx.globals):
if act := python_value_to_guppy_type(
python_val, node.value, self.ctx.globals, ty
):
subst = unify(ty, act, {})
if subst is None:
self._fail(ty, act, node)
Expand Down Expand Up @@ -1137,25 +1151,43 @@ def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
return python_val


def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type | None:
def python_value_to_guppy_type(
v: Any, node: ast.expr, globals: Globals, type_hint: Type | None = None
) -> Type | None:
"""Turns a primitive Python value into a Guppy type.
Accepts an optional `type_hint` for the expected expression type that is used to
infer a more precise type (e.g. distinguishing between `int` and `nat`). Note that
invalid hints are ignored, i.e. no user error are emitted.
Returns `None` if the Python value cannot be represented in Guppy.
"""
match v:
case bool():
return bool_type()
# Only resolve `int` to `nat` if the user specifically asked for it
case int(n) if type_hint == nat_type() and n >= 0:
return nat_type()
# Otherwise, default to `int` for consistency with Python
case int():
return cast(TypeDef, globals["int"]).check_instantiate([], globals)
return int_type()
case float():
return cast(TypeDef, globals["float"]).check_instantiate([], globals)
return float_type()
case tuple(elts):
tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts]
hints = (
type_hint.element_types
if isinstance(type_hint, TupleType)
else len(elts) * [None]
)
tys = [
python_value_to_guppy_type(elt, node, globals, hint)
for elt, hint in zip(elts, hints, strict=False)
]
if any(ty is None for ty in tys):
return None
return TupleType(cast(list[Type], tys))
case list():
return _python_list_to_guppy_type(v, node, globals)
return _python_list_to_guppy_type(v, node, globals, type_hint)
case _:
# Pytket conversion is an experimental feature
# if pytket and tket2 are installed
Expand Down Expand Up @@ -1186,7 +1218,7 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type


def _python_list_to_guppy_type(
vs: list[Any], node: ast.expr, globals: Globals
vs: list[Any], node: ast.expr, globals: Globals, type_hint: Type | None
) -> OpaqueType | None:
"""Turns a Python list into a Guppy type.
Expand All @@ -1198,11 +1230,14 @@ def _python_list_to_guppy_type(

# All the list elements must have a unifiable types
v, *rest = vs
el_ty = python_value_to_guppy_type(v, node, globals)
elt_hint = (
get_element_type(type_hint) if type_hint and is_array_type(type_hint) else None
)
el_ty = python_value_to_guppy_type(v, node, globals, elt_hint)
if el_ty is None:
return None
for v in rest:
ty = python_value_to_guppy_type(v, node, globals)
ty = python_value_to_guppy_type(v, node, globals, elt_hint)
if ty is None:
return None
if (subst := unify(ty, el_ty, {})) is None:
Expand Down
4 changes: 3 additions & 1 deletion guppylang/std/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def __int__(self: bool) -> int:
@guppy
@no_type_check
def __nat__(self: bool) -> nat:
# TODO: Literals should check against `nat`
# TODO: Type information doesn't flow through the `if` expression, so we
# have to insert the `nat` coercions by hand.
# See https://github.com/CQCL/guppylang/issues/707
return nat(1) if self else nat(0)

@guppy.custom(checker=DunderChecker("__bool__"), higher_order_value=False)
Expand Down
4 changes: 4 additions & 0 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ def int_type() -> NumericType:
return NumericType(NumericType.Kind.Int)


def float_type() -> NumericType:
return NumericType(NumericType.Kind.Float)


def list_type(element_ty: Type) -> OpaqueType:
return OpaqueType([TypeArg(element_ty)], list_type_def)

Expand Down
8 changes: 8 additions & 0 deletions tests/error/py_errors/negative_nat.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Type mismatch (at $FILE:7:11)
|
5 | @compile_guppy
6 | def foo() -> nat:
7 | return py(-1)
| ^^^^^^ Expected expression of type `nat`, got `int`

Guppy compilation failed due to 1 previous error
7 changes: 7 additions & 0 deletions tests/error/py_errors/negative_nat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.std.builtins import nat
from tests.util import compile_guppy


@compile_guppy
def foo() -> nat:
return py(-1)
8 changes: 8 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def const() -> float:
validate(const)


def test_nat_literal(validate):
@compile_guppy
def const() -> nat:
return 42

validate(const)


def test_aug_assign(validate, run_int_fn):
module = GuppyModule("test_aug_assign")

Expand Down
12 changes: 11 additions & 1 deletion tests/integration/test_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import py, array
from guppylang.std.builtins import py, array, nat
from guppylang.std import quantum
from guppylang.std.quantum import qubit
from tests.util import compile_guppy
Expand Down Expand Up @@ -107,6 +107,16 @@ def foo() -> None:
validate(foo)


def test_nats_from_ints(validate):
@compile_guppy
def foo() -> None:
x: nat = py(1)
y: tuple[nat, nat] = py(2, 3)
z: array[nat, 3] = py([4, 5, 6])

validate(foo)


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
@pytest.mark.skip("Fails because of extensions in types #343")
def test_pytket_single_qubit(validate):
Expand Down
13 changes: 12 additions & 1 deletion tests/integration/test_range.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from guppylang.decorator import guppy
from guppylang.std.builtins import nat, range, SizedIter, Range
from guppylang.std.builtins import nat, range, SizedIter, Range, py
from guppylang.module import GuppyModule
from tests.util import compile_guppy

Expand Down Expand Up @@ -45,6 +45,17 @@ def negative() -> SizedIter[Range, 10]:
validate(module.compile())


def test_py_size(validate):
module = GuppyModule("test")
n = 10

@guppy(module)
def negative() -> SizedIter[Range, 10]:
return range(py(n))

validate(module.compile())


def test_static_generic_size(validate):
module = GuppyModule("test")
n = guppy.nat_var("n", module=module)
Expand Down

0 comments on commit a93d4fe

Please sign in to comment.