diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 8dcbdea3..3d23424a 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -108,11 +108,15 @@ from guppylang.tys.builtin import ( array_type, bool_type, + float_type, get_element_type, + int_type, + is_array_type, is_bool_type, is_list_type, is_sized_iter_type, list_type, + nat_type, ) from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.subst import Inst, Subst @@ -232,6 +236,14 @@ def _synthesize( """Invokes the type synthesiser""" return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) + def visit_Constant(self, node: ast.Constant, ty: Type) -> tuple[ast.expr, Subst]: + act = python_value_to_guppy_type(node.value, node, self.ctx.globals, ty) + if act is None: + raise GuppyError(IllegalConstant(node, type(node.value))) + node, subst, inst = check_type_against(act, ty, node, self.ctx, self._kind) + assert inst == [], "Const values are not generic" + return node, subst + def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): return self._fail(ty, node) @@ -318,7 +330,9 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]: python_val = eval_py_expr(node, self.ctx) - if act := python_value_to_guppy_type(python_val, node.value, self.ctx.globals): + if act := python_value_to_guppy_type( + python_val, node.value, self.ctx.globals, ty + ): subst = unify(ty, act, {}) if subst is None: self._fail(ty, act, node) @@ -1137,25 +1151,43 @@ def eval_py_expr(node: PyExpr, ctx: Context) -> Any: return python_val -def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type | None: +def python_value_to_guppy_type( + v: Any, node: ast.expr, globals: Globals, type_hint: Type | None = None +) -> Type | None: """Turns a primitive Python value into a Guppy type. + Accepts an optional `type_hint` for the expected expression type that is used to + infer a more precise type (e.g. distinguishing between `int` and `nat`). Note that + invalid hints are ignored, i.e. no user error are emitted. + Returns `None` if the Python value cannot be represented in Guppy. """ match v: case bool(): return bool_type() + # Only resolve `int` to `nat` if the user specifically asked for it + case int(n) if type_hint == nat_type() and n >= 0: + return nat_type() + # Otherwise, default to `int` for consistency with Python case int(): - return cast(TypeDef, globals["int"]).check_instantiate([], globals) + return int_type() case float(): - return cast(TypeDef, globals["float"]).check_instantiate([], globals) + return float_type() case tuple(elts): - tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts] + hints = ( + type_hint.element_types + if isinstance(type_hint, TupleType) + else len(elts) * [None] + ) + tys = [ + python_value_to_guppy_type(elt, node, globals, hint) + for elt, hint in zip(elts, hints, strict=False) + ] if any(ty is None for ty in tys): return None return TupleType(cast(list[Type], tys)) case list(): - return _python_list_to_guppy_type(v, node, globals) + return _python_list_to_guppy_type(v, node, globals, type_hint) case _: # Pytket conversion is an experimental feature # if pytket and tket2 are installed @@ -1186,7 +1218,7 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type def _python_list_to_guppy_type( - vs: list[Any], node: ast.expr, globals: Globals + vs: list[Any], node: ast.expr, globals: Globals, type_hint: Type | None ) -> OpaqueType | None: """Turns a Python list into a Guppy type. @@ -1198,11 +1230,14 @@ def _python_list_to_guppy_type( # All the list elements must have a unifiable types v, *rest = vs - el_ty = python_value_to_guppy_type(v, node, globals) + elt_hint = ( + get_element_type(type_hint) if type_hint and is_array_type(type_hint) else None + ) + el_ty = python_value_to_guppy_type(v, node, globals, elt_hint) if el_ty is None: return None for v in rest: - ty = python_value_to_guppy_type(v, node, globals) + ty = python_value_to_guppy_type(v, node, globals, elt_hint) if ty is None: return None if (subst := unify(ty, el_ty, {})) is None: diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index a4027cd3..d9843bc2 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -106,7 +106,9 @@ def __int__(self: bool) -> int: @guppy @no_type_check def __nat__(self: bool) -> nat: - # TODO: Literals should check against `nat` + # TODO: Type information doesn't flow through the `if` expression, so we + # have to insert the `nat` coercions by hand. + # See https://github.com/CQCL/guppylang/issues/707 return nat(1) if self else nat(0) @guppy.custom(checker=DunderChecker("__bool__"), higher_order_value=False) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 89ffc5c8..86630c28 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -211,6 +211,10 @@ def int_type() -> NumericType: return NumericType(NumericType.Kind.Int) +def float_type() -> NumericType: + return NumericType(NumericType.Kind.Float) + + def list_type(element_ty: Type) -> OpaqueType: return OpaqueType([TypeArg(element_ty)], list_type_def) diff --git a/tests/error/py_errors/negative_nat.err b/tests/error/py_errors/negative_nat.err new file mode 100644 index 00000000..3cfd76c6 --- /dev/null +++ b/tests/error/py_errors/negative_nat.err @@ -0,0 +1,8 @@ +Error: Type mismatch (at $FILE:7:11) + | +5 | @compile_guppy +6 | def foo() -> nat: +7 | return py(-1) + | ^^^^^^ Expected expression of type `nat`, got `int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/py_errors/negative_nat.py b/tests/error/py_errors/negative_nat.py new file mode 100644 index 00000000..21170653 --- /dev/null +++ b/tests/error/py_errors/negative_nat.py @@ -0,0 +1,7 @@ +from guppylang.std.builtins import nat +from tests.util import compile_guppy + + +@compile_guppy +def foo() -> nat: + return py(-1) diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index 33d03914..d8bbc348 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -42,6 +42,14 @@ def const() -> float: validate(const) +def test_nat_literal(validate): + @compile_guppy + def const() -> nat: + return 42 + + validate(const) + + def test_aug_assign(validate, run_int_fn): module = GuppyModule("test_aug_assign") diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py index f62d50a5..ac9dce41 100644 --- a/tests/integration/test_py.py +++ b/tests/integration/test_py.py @@ -6,7 +6,7 @@ from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.std.builtins import py, array +from guppylang.std.builtins import py, array, nat from guppylang.std import quantum from guppylang.std.quantum import qubit from tests.util import compile_guppy @@ -107,6 +107,16 @@ def foo() -> None: validate(foo) +def test_nats_from_ints(validate): + @compile_guppy + def foo() -> None: + x: nat = py(1) + y: tuple[nat, nat] = py(2, 3) + z: array[nat, 3] = py([4, 5, 6]) + + validate(foo) + + @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") @pytest.mark.skip("Fails because of extensions in types #343") def test_pytket_single_qubit(validate): diff --git a/tests/integration/test_range.py b/tests/integration/test_range.py index 092db623..07cfff48 100644 --- a/tests/integration/test_range.py +++ b/tests/integration/test_range.py @@ -1,5 +1,5 @@ from guppylang.decorator import guppy -from guppylang.std.builtins import nat, range, SizedIter, Range +from guppylang.std.builtins import nat, range, SizedIter, Range, py from guppylang.module import GuppyModule from tests.util import compile_guppy @@ -45,6 +45,17 @@ def negative() -> SizedIter[Range, 10]: validate(module.compile()) +def test_py_size(validate): + module = GuppyModule("test") + n = 10 + + @guppy(module) + def negative() -> SizedIter[Range, 10]: + return range(py(n)) + + validate(module.compile()) + + def test_static_generic_size(validate): module = GuppyModule("test") n = guppy.nat_var("n", module=module)