Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Type check array comprehensions #614

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from guppylang.experimental import check_function_tensors_enabled, check_lists_enabled
from guppylang.nodes import (
DesugaredGenerator,
DesugaredGeneratorExpr,
DesugaredListComp,
FieldAccessAndDrop,
GlobalName,
Expand Down Expand Up @@ -228,7 +229,9 @@ def visit_DesugaredListComp(
) -> tuple[ast.expr, Subst]:
if not is_list_type(ty):
return self._fail(ty, node)
node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx)
node.generators, node.elt, elt_ty = synthesize_comprehension(
node, node.generators, node.elt, self.ctx
)
subst = unify(get_element_type(ty), elt_ty, {})
if subst is None:
actual = list_type(elt_ty)
Expand Down Expand Up @@ -446,10 +449,23 @@ def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]:
return node, list_type(el_ty)

def visit_DesugaredListComp(self, node: DesugaredListComp) -> tuple[ast.expr, Type]:
node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx)
node.generators, node.elt, elt_ty = synthesize_comprehension(
node, node.generators, node.elt, self.ctx
)
result_ty = list_type(elt_ty)
return node, result_ty

def visit_DesugaredGeneratorExpr(
self, node: DesugaredGeneratorExpr
) -> tuple[ast.expr, Type]:
# This is a generator in an arbitrary expression position. We don't support
# generators as first-class value yet, so we always error out here. Special
# cases where generator are allowed need to explicitly check for them (e.g. see
# the handling of array comprehensions in the compiler for the `array` function)
raise GuppyError(
"Generator expressions are not supported in this positions", node
)

def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]:
# We need to synthesise the argument type, so we can look up dunder methods
node.operand, op_ty = self.synthesize(node.operand)
Expand Down Expand Up @@ -1017,15 +1033,15 @@ def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type


def synthesize_comprehension(
node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context
) -> tuple[DesugaredListComp, Type]:
node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context
) -> tuple[list[DesugaredGenerator], ast.expr, Type]:
"""Helper function to synthesise the element type of a list comprehension."""
from guppylang.checker.stmt_checker import StmtChecker

# If there are no more generators left, we can check the list element
if not gens:
node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt)
return node, elt_ty
elt, elt_ty = ExprSynthesizer(ctx).synthesize(elt)
return gens, elt, elt_ty

# Check the iterator in the outer context
gen, *gens = gens
Expand All @@ -1049,12 +1065,12 @@ def synthesize_comprehension(
gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx)

# Check remaining generators
node, elt_ty = synthesize_comprehension(node, gens, inner_ctx)
gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx)

# The iter finalizer is again checked in the outer context
gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend)
gen.iterend = with_type(iterend_ty, gen.iterend)
return node, elt_ty
return [gen, *gens], elt, elt_ty


def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
Expand Down
133 changes: 102 additions & 31 deletions guppylang/prelude/_internal/checker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ast
from typing import cast

from typing_extensions import assert_never

from guppylang.ast_util import AstNode, with_loc, with_type
from guppylang.checker.core import Context
from guppylang.checker.expr_checker import (
Expand All @@ -10,6 +12,7 @@
check_num_args,
check_type_against,
synthesize_call,
synthesize_comprehension,
)
from guppylang.definition.custom import (
CustomCallChecker,
Expand All @@ -19,14 +22,22 @@
from guppylang.definition.struct import CheckedStructDef, RawStructDef
from guppylang.definition.value import CallableDef
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.nodes import GlobalCall, ResultExpr
from guppylang.nodes import (
DesugaredArrayComp,
DesugaredGeneratorExpr,
GlobalCall,
MakeIter,
ResultExpr,
)
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import (
array_type,
bool_type,
get_iter_size,
int_type,
is_array_type,
is_bool_type,
is_sized_iter_type,
sized_iter_type,
)
from guppylang.tys.const import Const, ConstValue
Expand Down Expand Up @@ -195,47 +206,107 @@ class NewArrayChecker(CustomCallChecker):
"""Function call checker for the `array.__new__` function."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
if len(args) == 0:
raise GuppyTypeError(
"Cannot infer the array element type. Consider adding a type "
"annotation.",
self.node,
)
[fst, *rest] = args
fst, ty = ExprSynthesizer(self.ctx).synthesize(fst)
checker = ExprChecker(self.ctx)
for i in range(len(rest)):
rest[i], subst = checker.check(rest[i], ty)
assert len(subst) == 0, "Array element type is closed"
result_ty = array_type(ty, len(args))
call = GlobalCall(
def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args
)
return with_loc(self.node, call), result_ty
match args:
case []:
raise GuppyTypeError(
"Cannot infer the array element type. Consider adding a type "
"annotation.",
self.node,
)
# Either an array comprehension
case [DesugaredGeneratorExpr() as compr]:
return self.synthesize_array_comprehension(compr)
# Or a list of array elements
case [fst, *rest]:
fst, ty = ExprSynthesizer(self.ctx).synthesize(fst)
checker = ExprChecker(self.ctx)
for i in range(len(rest)):
rest[i], subst = checker.check(rest[i], ty)
assert len(subst) == 0, "Array element type is closed"
result_ty = array_type(ty, len(args))
call = GlobalCall(
def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args
)
return with_loc(self.node, call), result_ty
case args:
return assert_never(args) # type: ignore[arg-type]
Comment on lines +231 to +232
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy isn't smart enough to infer that this code is unreachable, therefore type: ignore[arg-type]


def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
if not is_array_type(ty):
raise GuppyTypeError(
f"Expected expression of type `{ty}`, got `array`", self.node
)
subst: Subst = {}
match ty.args:
case [TypeArg(ty=elem_ty), ConstArg(ConstValue(value=int(length)))]:
subst: Subst = {}
checker = ExprChecker(self.ctx)
for i in range(len(args)):
args[i], s = checker.check(args[i], elem_ty.substitute(subst))
subst |= s
if len(args) != length:
raise GuppyTypeError(
f"Expected expression of type `{ty}`, got "
f"`array[{elem_ty}, {len(args)}]`",
self.node,
)
call = GlobalCall(def_id=self.func.id, args=args, type_args=ty.args)
return with_loc(self.node, call), subst
match args:
# Either an array comprehension
case [DesugaredGeneratorExpr() as compr]:
# TODO: We could use the type information to infer some stuff
# in the comprehension
arr_compr, res_ty = self.synthesize_array_comprehension(compr)
subst, _ = check_type_against(res_ty, ty, self.node)
return arr_compr, subst
# Or a list of array elements
case args:
checker = ExprChecker(self.ctx)
for i in range(len(args)):
args[i], s = checker.check(
args[i], elem_ty.substitute(subst)
)
subst |= s
if len(args) != length:
raise GuppyTypeError(
f"Expected expression of type `{ty}`, got "
f"`array[{elem_ty}, {len(args)}]`",
self.node,
)
call = GlobalCall(
def_id=self.func.id, args=args, type_args=ty.args
)
return with_loc(self.node, call), subst
case type_args:
raise InternalGuppyError(f"Invalid array type args: {type_args}")

def synthesize_array_comprehension(
self, compr: DesugaredGeneratorExpr
) -> tuple[DesugaredArrayComp, Type]:
# Array comprehensions require a static size. To keep things simple, we'll only
# allow a single generator for now, so we don't have to reason about products
# of iterator sizes.
if len(compr.generators) > 1:
# Individual generator objects unfortunately don't have a span in Python's
# AST, so we have to use the whole expression span
raise GuppyError("Nested array comprehensions are not supported", compr)
[gen] = compr.generators
# Similarly, dynamic if guards are not allowed
if gen.ifs:
raise GuppyError(
"If guards are not allowed in array comprehensions", gen.ifs[0]
)
# Extract the iterator size
match gen.iter_assign:
case ast.Assign(value=MakeIter() as make_iter):
make_iter.value, iter_ty = ExprSynthesizer(self.ctx).synthesize(
make_iter.value
)
# The iterator must have a static size hint
if not is_sized_iter_type(iter_ty):
raise GuppyError(
f"Iterator of type `{iter_ty}` doesn't have a static size "
"require for array comprehensions",
make_iter,
)
size = get_iter_size(iter_ty)
case _:
raise InternalGuppyError("Invalid iterator assign statement")
# Finally, type check the comprehension
[gen], elt, elt_ty = synthesize_comprehension(compr, [gen], compr.elt, self.ctx)
array_compr = DesugaredArrayComp(
elt=elt, generator=gen, length=size, elt_ty=elt_ty
)
return with_loc(compr, array_compr), array_type(elt_ty, size)


#: Maximum length of a tag in the `result` function.
TAG_MAX_LEN = 200
Expand Down
2 changes: 1 addition & 1 deletion guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class nat:
class array(Generic[_T, _n]):
"""Class to import in order to use arrays."""

def __init__(self, *args: _T):
def __init__(self, *args: Any):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To prevent mypy from complaining

pass


Expand Down
4 changes: 2 additions & 2 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]:

def get_element_type(ty: Type) -> Type:
assert isinstance(ty, OpaqueType)
assert ty.defn == list_type_def
(arg,) = ty.args
assert ty.defn == list_type_def or ty.defn == array_type_def
(arg, *_) = ty.args
assert isinstance(arg, TypeArg)
return arg.ty

Expand Down
7 changes: 7 additions & 0 deletions tests/error/array_errors/comprehension_wrong_length.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> array[int, 42]:
13: return array(i for i in range(10))
^^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyTypeError: Expected expression of type `array[int, 42]`, got `array[int, 10]`
16 changes: 16 additions & 0 deletions tests/error/array_errors/comprehension_wrong_length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main() -> array[int, 42]:
return array(i for i in range(10))


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/guarded_comprehension.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> None:
13: array(i for i in range(100) if i % 2 == 0)
^^^^^^^^^^
GuppyError: If guards are not allowed in array comprehensions
16 changes: 16 additions & 0 deletions tests/error/array_errors/guarded_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main() -> None:
array(i for i in range(100) if i % 2 == 0)


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/nested_comprehension.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> array[int, 50]:
13: return array(0 for _ in range(10) for _ in range(5))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Nested array comprehensions are not supported
16 changes: 16 additions & 0 deletions tests/error/array_errors/nested_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main() -> array[int, 50]:
return array(0 for _ in range(10) for _ in range(5))


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/non_static_comprehension.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main(n: int) -> None:
13: array(i for i in range(n))
^^^^^^^^
GuppyError: Iterator of type `Range` doesn't have a static size require for array comprehensions
16 changes: 16 additions & 0 deletions tests/error/array_errors/non_static_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main(n: int) -> None:
array(i for i in range(n))


module.compile()
Loading