Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Accept non-negative int literals and py expressions as nats #708

Merged
merged 4 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,15 @@
from guppylang.tys.builtin import (
array_type,
bool_type,
float_type,
get_element_type,
int_type,
is_array_type,
is_bool_type,
is_list_type,
is_sized_iter_type,
list_type,
nat_type,
)
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.subst import Inst, Subst
Expand Down Expand Up @@ -232,6 +236,14 @@ def _synthesize(
"""Invokes the type synthesiser"""
return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars)

def visit_Constant(self, node: ast.Constant, ty: Type) -> tuple[ast.expr, Subst]:
act = python_value_to_guppy_type(node.value, node, self.ctx.globals, ty)
if act is None:
raise GuppyError(IllegalConstant(node, type(node.value)))
node, subst, inst = check_type_against(act, ty, node, self.ctx, self._kind)
assert inst == [], "Const values are not generic"
return node, subst

def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]:
if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts):
return self._fail(ty, node)
Expand Down Expand Up @@ -318,7 +330,9 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:

def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]:
python_val = eval_py_expr(node, self.ctx)
if act := python_value_to_guppy_type(python_val, node.value, self.ctx.globals):
if act := python_value_to_guppy_type(
python_val, node.value, self.ctx.globals, ty
):
subst = unify(ty, act, {})
if subst is None:
self._fail(ty, act, node)
Expand Down Expand Up @@ -1137,25 +1151,43 @@ def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
return python_val


def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type | None:
def python_value_to_guppy_type(
v: Any, node: ast.expr, globals: Globals, type_hint: Type | None = None
) -> Type | None:
"""Turns a primitive Python value into a Guppy type.

Accepts an optional `type_hint` for the expected expression type that is used to
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could add a comment explaining that this type hint is only used internally for coercing things to nat and it's an internal error to give a type hint to something that wouldnt otherwise synthesize int type

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently it's not an internal error, the hint is just ignored if it's invalid. The actual checking that the synthesised type matches what's expected happens in the ExprChecker.

A wider refactor of this would be good, but imo too much for this PR.

infer a more precise type (e.g. distinguishing between `int` and `nat`). Note that
invalid hints are ignored, i.e. no user error are emitted.

Returns `None` if the Python value cannot be represented in Guppy.
"""
match v:
case bool():
return bool_type()
# Only resolve `int` to `nat` if the user specifically asked for it
case int(n) if type_hint == nat_type() and n >= 0:
return nat_type()
# Otherwise, default to `int` for consistency with Python
case int():
return cast(TypeDef, globals["int"]).check_instantiate([], globals)
return int_type()
case float():
return cast(TypeDef, globals["float"]).check_instantiate([], globals)
return float_type()
Comment on lines +1173 to +1175
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by: No need to look these up in globals

case tuple(elts):
tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts]
hints = (
type_hint.element_types
if isinstance(type_hint, TupleType)
else len(elts) * [None]
)
tys = [
python_value_to_guppy_type(elt, node, globals, hint)
for elt, hint in zip(elts, hints, strict=False)
]
if any(ty is None for ty in tys):
return None
return TupleType(cast(list[Type], tys))
case list():
return _python_list_to_guppy_type(v, node, globals)
return _python_list_to_guppy_type(v, node, globals, type_hint)
case _:
# Pytket conversion is an experimental feature
# if pytket and tket2 are installed
Expand Down Expand Up @@ -1186,7 +1218,7 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type


def _python_list_to_guppy_type(
vs: list[Any], node: ast.expr, globals: Globals
vs: list[Any], node: ast.expr, globals: Globals, type_hint: Type | None
) -> OpaqueType | None:
"""Turns a Python list into a Guppy type.

Expand All @@ -1198,11 +1230,14 @@ def _python_list_to_guppy_type(

# All the list elements must have a unifiable types
v, *rest = vs
el_ty = python_value_to_guppy_type(v, node, globals)
elt_hint = (
get_element_type(type_hint) if type_hint and is_array_type(type_hint) else None
)
el_ty = python_value_to_guppy_type(v, node, globals, elt_hint)
if el_ty is None:
return None
for v in rest:
ty = python_value_to_guppy_type(v, node, globals)
ty = python_value_to_guppy_type(v, node, globals, elt_hint)
if ty is None:
return None
if (subst := unify(ty, el_ty, {})) is None:
Expand Down
4 changes: 3 additions & 1 deletion guppylang/std/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def __int__(self: bool) -> int:
@guppy
@no_type_check
def __nat__(self: bool) -> nat:
# TODO: Literals should check against `nat`
# TODO: Type information doesn't flow through the `if` expression, so we
# have to insert the `nat` coercions by hand.
# See https://github.com/CQCL/guppylang/issues/707
return nat(1) if self else nat(0)
Comment on lines +109 to 112
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly, this still fails due to #707


@guppy.custom(checker=DunderChecker("__bool__"), higher_order_value=False)
Expand Down
4 changes: 4 additions & 0 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ def int_type() -> NumericType:
return NumericType(NumericType.Kind.Int)


def float_type() -> NumericType:
return NumericType(NumericType.Kind.Float)


def list_type(element_ty: Type) -> OpaqueType:
return OpaqueType([TypeArg(element_ty)], list_type_def)

Expand Down
8 changes: 8 additions & 0 deletions tests/error/py_errors/negative_nat.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Type mismatch (at $FILE:7:11)
|
5 | @compile_guppy
6 | def foo() -> nat:
7 | return py(-1)
| ^^^^^^ Expected expression of type `nat`, got `int`

Guppy compilation failed due to 1 previous error
7 changes: 7 additions & 0 deletions tests/error/py_errors/negative_nat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.std.builtins import nat
from tests.util import compile_guppy


@compile_guppy
def foo() -> nat:
return py(-1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a test like range(py(4)) just so we can use it for regression checking

8 changes: 8 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def const() -> float:
validate(const)


def test_nat_literal(validate):
@compile_guppy
def const() -> nat:
return 42

validate(const)


def test_aug_assign(validate, run_int_fn):
module = GuppyModule("test_aug_assign")

Expand Down
12 changes: 11 additions & 1 deletion tests/integration/test_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import py, array
from guppylang.std.builtins import py, array, nat
from guppylang.std import quantum
from guppylang.std.quantum import qubit
from tests.util import compile_guppy
Expand Down Expand Up @@ -107,6 +107,16 @@ def foo() -> None:
validate(foo)


def test_nats_from_ints(validate):
@compile_guppy
def foo() -> None:
x: nat = py(1)
y: tuple[nat, nat] = py(2, 3)
z: array[nat, 3] = py([4, 5, 6])

validate(foo)


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
@pytest.mark.skip("Fails because of extensions in types #343")
def test_pytket_single_qubit(validate):
Expand Down
13 changes: 12 additions & 1 deletion tests/integration/test_range.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from guppylang.decorator import guppy
from guppylang.std.builtins import nat, range, SizedIter, Range
from guppylang.std.builtins import nat, range, SizedIter, Range, py
from guppylang.module import GuppyModule
from tests.util import compile_guppy

Expand Down Expand Up @@ -45,6 +45,17 @@ def negative() -> SizedIter[Range, 10]:
validate(module.compile())


def test_py_size(validate):
module = GuppyModule("test")
n = 10

@guppy(module)
def negative() -> SizedIter[Range, 10]:
return range(py(n))

validate(module.compile())


def test_static_generic_size(validate):
module = GuppyModule("test")
n = guppy.nat_var("n", module=module)
Expand Down
Loading