From f2d2020a99a4ffd6f8f2b5dbc9da0f2376085b06 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 10 Dec 2024 18:10:35 +0000 Subject: [PATCH 1/3] fix: Accept int literals and py expressions as nats --- guppylang/checker/expr_checker.py | 52 +++++++++++++++++++++----- guppylang/std/builtins.py | 4 +- guppylang/tys/builtin.py | 4 ++ tests/error/py_errors/negative_nat.err | 8 ++++ tests/error/py_errors/negative_nat.py | 7 ++++ tests/integration/test_arithmetic.py | 8 ++++ tests/integration/test_py.py | 12 +++++- 7 files changed, 84 insertions(+), 11 deletions(-) create mode 100644 tests/error/py_errors/negative_nat.err create mode 100644 tests/error/py_errors/negative_nat.py diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 61e24674..b1b4e8da 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -107,11 +107,14 @@ from guppylang.tys.arg import TypeArg from guppylang.tys.builtin import ( bool_type, + float_type, get_element_type, + int_type, is_bool_type, is_list_type, is_sized_iter_type, list_type, + nat_type, ) from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.subst import Inst, Subst @@ -231,6 +234,14 @@ def _synthesize( """Invokes the type synthesiser""" return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) + def visit_Constant(self, node: ast.Constant, ty: Type) -> tuple[ast.expr, Subst]: + act = python_value_to_guppy_type(node.value, node, self.ctx.globals, ty) + if act is None: + raise GuppyError(IllegalConstant(node, type(node.value))) + node, subst, inst = check_type_against(act, ty, node, self.ctx, self._kind) + assert inst == [], "Const values are not generic" + return node, subst + def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): return self._fail(ty, node) @@ -317,7 +328,9 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]: python_val = eval_py_expr(node, self.ctx) - if act := python_value_to_guppy_type(python_val, node.value, self.ctx.globals): + if act := python_value_to_guppy_type( + python_val, node.value, self.ctx.globals, ty + ): subst = unify(ty, act, {}) if subst is None: self._fail(ty, act, node) @@ -1138,25 +1151,43 @@ def eval_py_expr(node: PyExpr, ctx: Context) -> Any: return python_val -def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type | None: +def python_value_to_guppy_type( + v: Any, node: ast.expr, globals: Globals, type_hint: Type | None = None +) -> Type | None: """Turns a primitive Python value into a Guppy type. + Accepts an optional `type_hint` for the expected expression type that is used to + infer a more precise type (e.g. distinguishing between `int` and `nat`). Note that + invalid hints are ignored, i.e. no user error are emitted. + Returns `None` if the Python value cannot be represented in Guppy. """ match v: case bool(): return bool_type() + # Only resolve `int` to `nat` if the user specifically asked for it + case int(n) if type_hint == nat_type() and n >= 0: + return nat_type() + # Otherwise, default to `int` for consistency with Python case int(): - return cast(TypeDef, globals["int"]).check_instantiate([], globals) + return int_type() case float(): - return cast(TypeDef, globals["float"]).check_instantiate([], globals) + return float_type() case tuple(elts): - tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts] + hints = ( + type_hint.element_types + if isinstance(type_hint, TupleType) + else len(elts) * [None] + ) + tys = [ + python_value_to_guppy_type(elt, node, globals, hint) + for elt, hint in zip(elts, hints, strict=False) + ] if any(ty is None for ty in tys): return None return TupleType(cast(list[Type], tys)) case list(): - return _python_list_to_guppy_type(v, node, globals) + return _python_list_to_guppy_type(v, node, globals, type_hint) case _: # Pytket conversion is an experimental feature # if pytket and tket2 are installed @@ -1187,7 +1218,7 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type def _python_list_to_guppy_type( - vs: list[Any], node: ast.expr, globals: Globals + vs: list[Any], node: ast.expr, globals: Globals, type_hint: Type | None ) -> OpaqueType | None: """Turns a Python list into a Guppy type. @@ -1199,11 +1230,14 @@ def _python_list_to_guppy_type( # All the list elements must have a unifiable types v, *rest = vs - el_ty = python_value_to_guppy_type(v, node, globals) + elt_hint = ( + get_element_type(type_hint) if type_hint and is_list_type(type_hint) else None + ) + el_ty = python_value_to_guppy_type(v, node, globals, elt_hint) if el_ty is None: return None for v in rest: - ty = python_value_to_guppy_type(v, node, globals) + ty = python_value_to_guppy_type(v, node, globals, elt_hint) if ty is None: return None if (subst := unify(ty, el_ty, {})) is None: diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index a4027cd3..d9843bc2 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -106,7 +106,9 @@ def __int__(self: bool) -> int: @guppy @no_type_check def __nat__(self: bool) -> nat: - # TODO: Literals should check against `nat` + # TODO: Type information doesn't flow through the `if` expression, so we + # have to insert the `nat` coercions by hand. + # See https://github.com/CQCL/guppylang/issues/707 return nat(1) if self else nat(0) @guppy.custom(checker=DunderChecker("__bool__"), higher_order_value=False) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 89ffc5c8..86630c28 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -211,6 +211,10 @@ def int_type() -> NumericType: return NumericType(NumericType.Kind.Int) +def float_type() -> NumericType: + return NumericType(NumericType.Kind.Float) + + def list_type(element_ty: Type) -> OpaqueType: return OpaqueType([TypeArg(element_ty)], list_type_def) diff --git a/tests/error/py_errors/negative_nat.err b/tests/error/py_errors/negative_nat.err new file mode 100644 index 00000000..3cfd76c6 --- /dev/null +++ b/tests/error/py_errors/negative_nat.err @@ -0,0 +1,8 @@ +Error: Type mismatch (at $FILE:7:11) + | +5 | @compile_guppy +6 | def foo() -> nat: +7 | return py(-1) + | ^^^^^^ Expected expression of type `nat`, got `int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/py_errors/negative_nat.py b/tests/error/py_errors/negative_nat.py new file mode 100644 index 00000000..21170653 --- /dev/null +++ b/tests/error/py_errors/negative_nat.py @@ -0,0 +1,7 @@ +from guppylang.std.builtins import nat +from tests.util import compile_guppy + + +@compile_guppy +def foo() -> nat: + return py(-1) diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index 33d03914..d8bbc348 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -42,6 +42,14 @@ def const() -> float: validate(const) +def test_nat_literal(validate): + @compile_guppy + def const() -> nat: + return 42 + + validate(const) + + def test_aug_assign(validate, run_int_fn): module = GuppyModule("test_aug_assign") diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py index 7b945b89..6bb72bf8 100644 --- a/tests/integration/test_py.py +++ b/tests/integration/test_py.py @@ -6,7 +6,7 @@ from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.std.builtins import py, array +from guppylang.std.builtins import py, array, nat from guppylang.std import quantum from guppylang.std.quantum import qubit from tests.util import compile_guppy @@ -107,6 +107,16 @@ def foo() -> None: validate(foo) +def test_nats_from_ints(validate): + @compile_guppy + def foo() -> None: + x: nat = py(1) + y: tuple[nat, nat] = py(2, 3) + z: list[nat] = py([4, 5, 6]) + + validate(foo) + + @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") @pytest.mark.skip("Fails because of extensions in types #343") def test_pytket_single_qubit(validate): From 943bd71e84b1d621cfae2d0e4965c3588b03bb48 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 12 Dec 2024 11:40:58 +0000 Subject: [PATCH 2/3] Add range test --- tests/integration/test_range.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_range.py b/tests/integration/test_range.py index 2babd982..05665efe 100644 --- a/tests/integration/test_range.py +++ b/tests/integration/test_range.py @@ -1,5 +1,5 @@ from guppylang.decorator import guppy -from guppylang.std.builtins import nat, range, SizedIter, Range +from guppylang.std.builtins import nat, range, SizedIter, Range, py from guppylang.module import GuppyModule from tests.util import compile_guppy @@ -43,3 +43,14 @@ def negative() -> SizedIter[Range, 10]: return range(10) validate(module.compile()) + + +def test_py_size(validate): + module = GuppyModule("test") + n = 10 + + @guppy(module) + def negative() -> SizedIter[Range, 10]: + return range(py(n)) + + validate(module.compile()) From 5ad8fe62f0564af9d5902ea5b1b7d62d2398531f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 12 Dec 2024 11:49:12 +0000 Subject: [PATCH 3/3] Fix list to array --- guppylang/checker/expr_checker.py | 3 ++- tests/integration/test_py.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index f39c5c6e..3d23424a 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -111,6 +111,7 @@ float_type, get_element_type, int_type, + is_array_type, is_bool_type, is_list_type, is_sized_iter_type, @@ -1230,7 +1231,7 @@ def _python_list_to_guppy_type( # All the list elements must have a unifiable types v, *rest = vs elt_hint = ( - get_element_type(type_hint) if type_hint and is_list_type(type_hint) else None + get_element_type(type_hint) if type_hint and is_array_type(type_hint) else None ) el_ty = python_value_to_guppy_type(v, node, globals, elt_hint) if el_ty is None: diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py index 43512a2f..ac9dce41 100644 --- a/tests/integration/test_py.py +++ b/tests/integration/test_py.py @@ -112,7 +112,7 @@ def test_nats_from_ints(validate): def foo() -> None: x: nat = py(1) y: tuple[nat, nat] = py(2, 3) - z: list[nat] = py([4, 5, 6]) + z: array[nat, 3] = py([4, 5, 6]) validate(foo)