From 1e3fd2e8c0ef384d8994dbb664d2e4ad25ef3275 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Nov 2024 13:08:59 +0000 Subject: [PATCH] feat: Type check array comprehensions --- guppylang/checker/expr_checker.py | 32 +++-- guppylang/prelude/_internal/checker.py | 133 ++++++++++++++---- guppylang/prelude/builtins.py | 2 +- guppylang/tys/builtin.py | 4 +- .../comprehension_wrong_length.err | 7 + .../comprehension_wrong_length.py | 16 +++ .../array_errors/guarded_comprehension.err | 7 + .../array_errors/guarded_comprehension.py | 16 +++ .../array_errors/nested_comprehension.err | 7 + .../array_errors/nested_comprehension.py | 16 +++ .../array_errors/non_static_comprehension.err | 7 + .../array_errors/non_static_comprehension.py | 16 +++ 12 files changed, 221 insertions(+), 42 deletions(-) create mode 100644 tests/error/array_errors/comprehension_wrong_length.err create mode 100644 tests/error/array_errors/comprehension_wrong_length.py create mode 100644 tests/error/array_errors/guarded_comprehension.err create mode 100644 tests/error/array_errors/guarded_comprehension.py create mode 100644 tests/error/array_errors/nested_comprehension.err create mode 100644 tests/error/array_errors/nested_comprehension.py create mode 100644 tests/error/array_errors/non_static_comprehension.err create mode 100644 tests/error/array_errors/non_static_comprehension.py diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index c21ffede..b0a11e5b 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -60,6 +60,7 @@ from guppylang.experimental import check_function_tensors_enabled, check_lists_enabled from guppylang.nodes import ( DesugaredGenerator, + DesugaredGeneratorExpr, DesugaredListComp, FieldAccessAndDrop, GlobalName, @@ -228,7 +229,9 @@ def visit_DesugaredListComp( ) -> tuple[ast.expr, Subst]: if not is_list_type(ty): return self._fail(ty, node) - node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + node.generators, node.elt, elt_ty = synthesize_comprehension( + node, node.generators, node.elt, self.ctx + ) subst = unify(get_element_type(ty), elt_ty, {}) if subst is None: actual = list_type(elt_ty) @@ -446,10 +449,23 @@ def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]: return node, list_type(el_ty) def visit_DesugaredListComp(self, node: DesugaredListComp) -> tuple[ast.expr, Type]: - node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + node.generators, node.elt, elt_ty = synthesize_comprehension( + node, node.generators, node.elt, self.ctx + ) result_ty = list_type(elt_ty) return node, result_ty + def visit_DesugaredGeneratorExpr( + self, node: DesugaredGeneratorExpr + ) -> tuple[ast.expr, Type]: + # This is a generator in an arbitrary expression position. We don't support + # generators as first-class value yet, so we always error out here. Special + # cases where generator are allowed need to explicitly check for them (e.g. see + # the handling of array comprehensions in the compiler for the `array` function) + raise GuppyError( + "Generator expressions are not supported in this positions", node + ) + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]: # We need to synthesise the argument type, so we can look up dunder methods node.operand, op_ty = self.synthesize(node.operand) @@ -1017,15 +1033,15 @@ def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type def synthesize_comprehension( - node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context -) -> tuple[DesugaredListComp, Type]: + node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context +) -> tuple[list[DesugaredGenerator], ast.expr, Type]: """Helper function to synthesise the element type of a list comprehension.""" from guppylang.checker.stmt_checker import StmtChecker # If there are no more generators left, we can check the list element if not gens: - node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt) - return node, elt_ty + elt, elt_ty = ExprSynthesizer(ctx).synthesize(elt) + return gens, elt, elt_ty # Check the iterator in the outer context gen, *gens = gens @@ -1049,12 +1065,12 @@ def synthesize_comprehension( gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx) # Check remaining generators - node, elt_ty = synthesize_comprehension(node, gens, inner_ctx) + gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx) # The iter finalizer is again checked in the outer context gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend) gen.iterend = with_type(iterend_ty, gen.iterend) - return node, elt_ty + return [gen, *gens], elt, elt_ty def eval_py_expr(node: PyExpr, ctx: Context) -> Any: diff --git a/guppylang/prelude/_internal/checker.py b/guppylang/prelude/_internal/checker.py index 7a2eafe2..3f871591 100644 --- a/guppylang/prelude/_internal/checker.py +++ b/guppylang/prelude/_internal/checker.py @@ -1,6 +1,8 @@ import ast from typing import cast +from typing_extensions import assert_never + from guppylang.ast_util import AstNode, with_loc, with_type from guppylang.checker.core import Context from guppylang.checker.expr_checker import ( @@ -10,6 +12,7 @@ check_num_args, check_type_against, synthesize_call, + synthesize_comprehension, ) from guppylang.definition.custom import ( CustomCallChecker, @@ -19,14 +22,22 @@ from guppylang.definition.struct import CheckedStructDef, RawStructDef from guppylang.definition.value import CallableDef from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppylang.nodes import GlobalCall, ResultExpr +from guppylang.nodes import ( + DesugaredArrayComp, + DesugaredGeneratorExpr, + GlobalCall, + MakeIter, + ResultExpr, +) from guppylang.tys.arg import ConstArg, TypeArg from guppylang.tys.builtin import ( array_type, bool_type, + get_iter_size, int_type, is_array_type, is_bool_type, + is_sized_iter_type, sized_iter_type, ) from guppylang.tys.const import Const, ConstValue @@ -195,47 +206,107 @@ class NewArrayChecker(CustomCallChecker): """Function call checker for the `array.__new__` function.""" def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - if len(args) == 0: - raise GuppyTypeError( - "Cannot infer the array element type. Consider adding a type " - "annotation.", - self.node, - ) - [fst, *rest] = args - fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) - checker = ExprChecker(self.ctx) - for i in range(len(rest)): - rest[i], subst = checker.check(rest[i], ty) - assert len(subst) == 0, "Array element type is closed" - result_ty = array_type(ty, len(args)) - call = GlobalCall( - def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args - ) - return with_loc(self.node, call), result_ty + match args: + case []: + raise GuppyTypeError( + "Cannot infer the array element type. Consider adding a type " + "annotation.", + self.node, + ) + # Either an array comprehension + case [DesugaredGeneratorExpr() as compr]: + return self.synthesize_array_comprehension(compr) + # Or a list of array elements + case [fst, *rest]: + fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) + checker = ExprChecker(self.ctx) + for i in range(len(rest)): + rest[i], subst = checker.check(rest[i], ty) + assert len(subst) == 0, "Array element type is closed" + result_ty = array_type(ty, len(args)) + call = GlobalCall( + def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args + ) + return with_loc(self.node, call), result_ty + case args: + return assert_never(args) # type: ignore[arg-type] def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: if not is_array_type(ty): raise GuppyTypeError( f"Expected expression of type `{ty}`, got `array`", self.node ) + subst: Subst = {} match ty.args: case [TypeArg(ty=elem_ty), ConstArg(ConstValue(value=int(length)))]: - subst: Subst = {} - checker = ExprChecker(self.ctx) - for i in range(len(args)): - args[i], s = checker.check(args[i], elem_ty.substitute(subst)) - subst |= s - if len(args) != length: - raise GuppyTypeError( - f"Expected expression of type `{ty}`, got " - f"`array[{elem_ty}, {len(args)}]`", - self.node, - ) - call = GlobalCall(def_id=self.func.id, args=args, type_args=ty.args) - return with_loc(self.node, call), subst + match args: + # Either an array comprehension + case [DesugaredGeneratorExpr() as compr]: + # TODO: We could use the type information to infer some stuff + # in the comprehension + arr_compr, res_ty = self.synthesize_array_comprehension(compr) + subst, _ = check_type_against(res_ty, ty, self.node) + return arr_compr, subst + # Or a list of array elements + case args: + checker = ExprChecker(self.ctx) + for i in range(len(args)): + args[i], s = checker.check( + args[i], elem_ty.substitute(subst) + ) + subst |= s + if len(args) != length: + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got " + f"`array[{elem_ty}, {len(args)}]`", + self.node, + ) + call = GlobalCall( + def_id=self.func.id, args=args, type_args=ty.args + ) + return with_loc(self.node, call), subst case type_args: raise InternalGuppyError(f"Invalid array type args: {type_args}") + def synthesize_array_comprehension( + self, compr: DesugaredGeneratorExpr + ) -> tuple[DesugaredArrayComp, Type]: + # Array comprehensions require a static size. To keep things simple, we'll only + # allow a single generator for now, so we don't have to reason about products + # of iterator sizes. + if len(compr.generators) > 1: + # Individual generator objects unfortunately don't have a span in Python's + # AST, so we have to use the whole expression span + raise GuppyError("Nested array comprehensions are not supported", compr) + [gen] = compr.generators + # Similarly, dynamic if guards are not allowed + if gen.ifs: + raise GuppyError( + "If guards are not allowed in array comprehensions", gen.ifs[0] + ) + # Extract the iterator size + match gen.iter_assign: + case ast.Assign(value=MakeIter() as make_iter): + make_iter.value, iter_ty = ExprSynthesizer(self.ctx).synthesize( + make_iter.value + ) + # The iterator must have a static size hint + if not is_sized_iter_type(iter_ty): + raise GuppyError( + f"Iterator of type `{iter_ty}` doesn't have a static size " + "require for array comprehensions", + make_iter, + ) + size = get_iter_size(iter_ty) + case _: + raise InternalGuppyError("Invalid iterator assign statement") + # Finally, type check the comprehension + [gen], elt, elt_ty = synthesize_comprehension(compr, [gen], compr.elt, self.ctx) + array_compr = DesugaredArrayComp( + elt=elt, generator=gen, length=size, elt_ty=elt_ty + ) + return with_loc(compr, array_compr), array_type(elt_ty, size) + #: Maximum length of a tag in the `result` function. TAG_MAX_LEN = 200 diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 486dda11..73acc883 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -92,7 +92,7 @@ class nat: class array(Generic[_T, _n]): """Class to import in order to use arrays.""" - def __init__(self, *args: _T): + def __init__(self, *args: Any): pass diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index f201ca61..70d7299d 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -243,8 +243,8 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]: def get_element_type(ty: Type) -> Type: assert isinstance(ty, OpaqueType) - assert ty.defn == list_type_def - (arg,) = ty.args + assert ty.defn == list_type_def or ty.defn == array_type_def + (arg, *_) = ty.args assert isinstance(arg, TypeArg) return arg.ty diff --git a/tests/error/array_errors/comprehension_wrong_length.err b/tests/error/array_errors/comprehension_wrong_length.err new file mode 100644 index 00000000..ab679359 --- /dev/null +++ b/tests/error/array_errors/comprehension_wrong_length.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main() -> array[int, 42]: +13: return array(i for i in range(10)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyTypeError: Expected expression of type `array[int, 42]`, got `array[int, 10]` diff --git a/tests/error/array_errors/comprehension_wrong_length.py b/tests/error/array_errors/comprehension_wrong_length.py new file mode 100644 index 00000000..fae641c7 --- /dev/null +++ b/tests/error/array_errors/comprehension_wrong_length.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> array[int, 42]: + return array(i for i in range(10)) + + +module.compile() diff --git a/tests/error/array_errors/guarded_comprehension.err b/tests/error/array_errors/guarded_comprehension.err new file mode 100644 index 00000000..b387d0e7 --- /dev/null +++ b/tests/error/array_errors/guarded_comprehension.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main() -> None: +13: array(i for i in range(100) if i % 2 == 0) + ^^^^^^^^^^ +GuppyError: If guards are not allowed in array comprehensions diff --git a/tests/error/array_errors/guarded_comprehension.py b/tests/error/array_errors/guarded_comprehension.py new file mode 100644 index 00000000..49bf9769 --- /dev/null +++ b/tests/error/array_errors/guarded_comprehension.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> None: + array(i for i in range(100) if i % 2 == 0) + + +module.compile() diff --git a/tests/error/array_errors/nested_comprehension.err b/tests/error/array_errors/nested_comprehension.err new file mode 100644 index 00000000..b316cb83 --- /dev/null +++ b/tests/error/array_errors/nested_comprehension.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main() -> array[int, 50]: +13: return array(0 for _ in range(10) for _ in range(5)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Nested array comprehensions are not supported diff --git a/tests/error/array_errors/nested_comprehension.py b/tests/error/array_errors/nested_comprehension.py new file mode 100644 index 00000000..f107b22c --- /dev/null +++ b/tests/error/array_errors/nested_comprehension.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> array[int, 50]: + return array(0 for _ in range(10) for _ in range(5)) + + +module.compile() diff --git a/tests/error/array_errors/non_static_comprehension.err b/tests/error/array_errors/non_static_comprehension.err new file mode 100644 index 00000000..7834391a --- /dev/null +++ b/tests/error/array_errors/non_static_comprehension.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main(n: int) -> None: +13: array(i for i in range(n)) + ^^^^^^^^ +GuppyError: Iterator of type `Range` doesn't have a static size require for array comprehensions diff --git a/tests/error/array_errors/non_static_comprehension.py b/tests/error/array_errors/non_static_comprehension.py new file mode 100644 index 00000000..ad53c7c9 --- /dev/null +++ b/tests/error/array_errors/non_static_comprehension.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main(n: int) -> None: + array(i for i in range(n)) + + +module.compile()