Skip to content

Commit

Permalink
feat: Type check array comprehensions (#614)
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch authored Nov 11, 2024
1 parent 3716b7e commit e80c1a0
Show file tree
Hide file tree
Showing 12 changed files with 221 additions and 42 deletions.
32 changes: 24 additions & 8 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from guppylang.experimental import check_function_tensors_enabled, check_lists_enabled
from guppylang.nodes import (
DesugaredGenerator,
DesugaredGeneratorExpr,
DesugaredListComp,
FieldAccessAndDrop,
GlobalName,
Expand Down Expand Up @@ -228,7 +229,9 @@ def visit_DesugaredListComp(
) -> tuple[ast.expr, Subst]:
if not is_list_type(ty):
return self._fail(ty, node)
node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx)
node.generators, node.elt, elt_ty = synthesize_comprehension(
node, node.generators, node.elt, self.ctx
)
subst = unify(get_element_type(ty), elt_ty, {})
if subst is None:
actual = list_type(elt_ty)
Expand Down Expand Up @@ -446,10 +449,23 @@ def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]:
return node, list_type(el_ty)

def visit_DesugaredListComp(self, node: DesugaredListComp) -> tuple[ast.expr, Type]:
node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx)
node.generators, node.elt, elt_ty = synthesize_comprehension(
node, node.generators, node.elt, self.ctx
)
result_ty = list_type(elt_ty)
return node, result_ty

def visit_DesugaredGeneratorExpr(
self, node: DesugaredGeneratorExpr
) -> tuple[ast.expr, Type]:
# This is a generator in an arbitrary expression position. We don't support
# generators as first-class value yet, so we always error out here. Special
# cases where generator are allowed need to explicitly check for them (e.g. see
# the handling of array comprehensions in the compiler for the `array` function)
raise GuppyError(
"Generator expressions are not supported in this positions", node
)

def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]:
# We need to synthesise the argument type, so we can look up dunder methods
node.operand, op_ty = self.synthesize(node.operand)
Expand Down Expand Up @@ -1017,15 +1033,15 @@ def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type


def synthesize_comprehension(
node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context
) -> tuple[DesugaredListComp, Type]:
node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context
) -> tuple[list[DesugaredGenerator], ast.expr, Type]:
"""Helper function to synthesise the element type of a list comprehension."""
from guppylang.checker.stmt_checker import StmtChecker

# If there are no more generators left, we can check the list element
if not gens:
node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt)
return node, elt_ty
elt, elt_ty = ExprSynthesizer(ctx).synthesize(elt)
return gens, elt, elt_ty

# Check the iterator in the outer context
gen, *gens = gens
Expand All @@ -1049,12 +1065,12 @@ def synthesize_comprehension(
gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx)

# Check remaining generators
node, elt_ty = synthesize_comprehension(node, gens, inner_ctx)
gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx)

# The iter finalizer is again checked in the outer context
gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend)
gen.iterend = with_type(iterend_ty, gen.iterend)
return node, elt_ty
return [gen, *gens], elt, elt_ty


def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
Expand Down
133 changes: 102 additions & 31 deletions guppylang/prelude/_internal/checker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ast
from typing import cast

from typing_extensions import assert_never

from guppylang.ast_util import AstNode, with_loc, with_type
from guppylang.checker.core import Context
from guppylang.checker.expr_checker import (
Expand All @@ -10,6 +12,7 @@
check_num_args,
check_type_against,
synthesize_call,
synthesize_comprehension,
)
from guppylang.definition.custom import (
CustomCallChecker,
Expand All @@ -19,14 +22,22 @@
from guppylang.definition.struct import CheckedStructDef, RawStructDef
from guppylang.definition.value import CallableDef
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.nodes import GlobalCall, ResultExpr
from guppylang.nodes import (
DesugaredArrayComp,
DesugaredGeneratorExpr,
GlobalCall,
MakeIter,
ResultExpr,
)
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import (
array_type,
bool_type,
get_iter_size,
int_type,
is_array_type,
is_bool_type,
is_sized_iter_type,
sized_iter_type,
)
from guppylang.tys.const import Const, ConstValue
Expand Down Expand Up @@ -195,47 +206,107 @@ class NewArrayChecker(CustomCallChecker):
"""Function call checker for the `array.__new__` function."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
if len(args) == 0:
raise GuppyTypeError(
"Cannot infer the array element type. Consider adding a type "
"annotation.",
self.node,
)
[fst, *rest] = args
fst, ty = ExprSynthesizer(self.ctx).synthesize(fst)
checker = ExprChecker(self.ctx)
for i in range(len(rest)):
rest[i], subst = checker.check(rest[i], ty)
assert len(subst) == 0, "Array element type is closed"
result_ty = array_type(ty, len(args))
call = GlobalCall(
def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args
)
return with_loc(self.node, call), result_ty
match args:
case []:
raise GuppyTypeError(
"Cannot infer the array element type. Consider adding a type "
"annotation.",
self.node,
)
# Either an array comprehension
case [DesugaredGeneratorExpr() as compr]:
return self.synthesize_array_comprehension(compr)
# Or a list of array elements
case [fst, *rest]:
fst, ty = ExprSynthesizer(self.ctx).synthesize(fst)
checker = ExprChecker(self.ctx)
for i in range(len(rest)):
rest[i], subst = checker.check(rest[i], ty)
assert len(subst) == 0, "Array element type is closed"
result_ty = array_type(ty, len(args))
call = GlobalCall(
def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args
)
return with_loc(self.node, call), result_ty
case args:
return assert_never(args) # type: ignore[arg-type]

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
if not is_array_type(ty):
raise GuppyTypeError(
f"Expected expression of type `{ty}`, got `array`", self.node
)
subst: Subst = {}
match ty.args:
case [TypeArg(ty=elem_ty), ConstArg(ConstValue(value=int(length)))]:
subst: Subst = {}
checker = ExprChecker(self.ctx)
for i in range(len(args)):
args[i], s = checker.check(args[i], elem_ty.substitute(subst))
subst |= s
if len(args) != length:
raise GuppyTypeError(
f"Expected expression of type `{ty}`, got "
f"`array[{elem_ty}, {len(args)}]`",
self.node,
)
call = GlobalCall(def_id=self.func.id, args=args, type_args=ty.args)
return with_loc(self.node, call), subst
match args:
# Either an array comprehension
case [DesugaredGeneratorExpr() as compr]:
# TODO: We could use the type information to infer some stuff
# in the comprehension
arr_compr, res_ty = self.synthesize_array_comprehension(compr)
subst, _ = check_type_against(res_ty, ty, self.node)
return arr_compr, subst
# Or a list of array elements
case args:
checker = ExprChecker(self.ctx)
for i in range(len(args)):
args[i], s = checker.check(
args[i], elem_ty.substitute(subst)
)
subst |= s
if len(args) != length:
raise GuppyTypeError(
f"Expected expression of type `{ty}`, got "
f"`array[{elem_ty}, {len(args)}]`",
self.node,
)
call = GlobalCall(
def_id=self.func.id, args=args, type_args=ty.args
)
return with_loc(self.node, call), subst
case type_args:
raise InternalGuppyError(f"Invalid array type args: {type_args}")

def synthesize_array_comprehension(
self, compr: DesugaredGeneratorExpr
) -> tuple[DesugaredArrayComp, Type]:
# Array comprehensions require a static size. To keep things simple, we'll only
# allow a single generator for now, so we don't have to reason about products
# of iterator sizes.
if len(compr.generators) > 1:
# Individual generator objects unfortunately don't have a span in Python's
# AST, so we have to use the whole expression span
raise GuppyError("Nested array comprehensions are not supported", compr)
[gen] = compr.generators
# Similarly, dynamic if guards are not allowed
if gen.ifs:
raise GuppyError(
"If guards are not allowed in array comprehensions", gen.ifs[0]
)
# Extract the iterator size
match gen.iter_assign:
case ast.Assign(value=MakeIter() as make_iter):
make_iter.value, iter_ty = ExprSynthesizer(self.ctx).synthesize(
make_iter.value
)
# The iterator must have a static size hint
if not is_sized_iter_type(iter_ty):
raise GuppyError(
f"Iterator of type `{iter_ty}` doesn't have a static size "
"require for array comprehensions",
make_iter,
)
size = get_iter_size(iter_ty)
case _:
raise InternalGuppyError("Invalid iterator assign statement")
# Finally, type check the comprehension
[gen], elt, elt_ty = synthesize_comprehension(compr, [gen], compr.elt, self.ctx)
array_compr = DesugaredArrayComp(
elt=elt, generator=gen, length=size, elt_ty=elt_ty
)
return with_loc(compr, array_compr), array_type(elt_ty, size)


#: Maximum length of a tag in the `result` function.
TAG_MAX_LEN = 200
Expand Down
2 changes: 1 addition & 1 deletion guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class nat:
class array(Generic[_T, _n]):
"""Class to import in order to use arrays."""

def __init__(self, *args: _T):
def __init__(self, *args: Any):
pass


Expand Down
4 changes: 2 additions & 2 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]:

def get_element_type(ty: Type) -> Type:
assert isinstance(ty, OpaqueType)
assert ty.defn == list_type_def
(arg,) = ty.args
assert ty.defn == list_type_def or ty.defn == array_type_def
(arg, *_) = ty.args
assert isinstance(arg, TypeArg)
return arg.ty

Expand Down
7 changes: 7 additions & 0 deletions tests/error/array_errors/comprehension_wrong_length.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> array[int, 42]:
13: return array(i for i in range(10))
^^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyTypeError: Expected expression of type `array[int, 42]`, got `array[int, 10]`
16 changes: 16 additions & 0 deletions tests/error/array_errors/comprehension_wrong_length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main() -> array[int, 42]:
return array(i for i in range(10))


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/guarded_comprehension.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> None:
13: array(i for i in range(100) if i % 2 == 0)
^^^^^^^^^^
GuppyError: If guards are not allowed in array comprehensions
16 changes: 16 additions & 0 deletions tests/error/array_errors/guarded_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main() -> None:
array(i for i in range(100) if i % 2 == 0)


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/nested_comprehension.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> array[int, 50]:
13: return array(0 for _ in range(10) for _ in range(5))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Nested array comprehensions are not supported
16 changes: 16 additions & 0 deletions tests/error/array_errors/nested_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main() -> array[int, 50]:
return array(0 for _ in range(10) for _ in range(5))


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/non_static_comprehension.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main(n: int) -> None:
13: array(i for i in range(n))
^^^^^^^^
GuppyError: Iterator of type `Range` doesn't have a static size require for array comprehensions
16 changes: 16 additions & 0 deletions tests/error/array_errors/non_static_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main(n: int) -> None:
array(i for i in range(n))


module.compile()

0 comments on commit e80c1a0

Please sign in to comment.