Skip to content

Commit

Permalink
feat: Type check array subscripts (#420)
Browse files Browse the repository at this point in the history
Closes #416. See #415 for context.

* Adds a new `SubscriptAccess` place that will be used to track array
subscripts during linearity checking
* This place is emitted when checking a subscript AST node
* Ensures that subscripts in inout positions also implement a
`__setitem__` method
  • Loading branch information
mark-koch authored Aug 30, 2024
1 parent 59e82c8 commit 61997cb
Show file tree
Hide file tree
Showing 15 changed files with 220 additions and 36 deletions.
43 changes: 41 additions & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@
#:
#: All places are equipped with a unique id, a type and an optional definition AST
#: location. During linearity checking, they are tracked separately.
Place: TypeAlias = "Variable | FieldAccess"
Place: TypeAlias = "Variable | FieldAccess | SubscriptAccess"

#: Unique identifier for a `Place`.
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id"
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id | SubscriptAccess.Id"


@dataclass(frozen=True)
Expand Down Expand Up @@ -154,6 +154,45 @@ def replace_defined_at(self, node: AstNode | None) -> "FieldAccess":
return replace(self, exact_defined_at=node)


@dataclass(frozen=True)
class SubscriptAccess:
"""A place identifying a subscript `place[item]` access."""

parent: Place
item: Variable
ty: Type
item_expr: ast.expr
getitem_call: ast.expr
#: Only populated if this place occurs in an inout position
setitem_call: ast.expr | None = None

@dataclass(frozen=True)
class Id:
"""Identifier for subscript places."""

parent: PlaceId
item: Variable.Id

@cached_property
def id(self) -> "SubscriptAccess.Id":
"""The unique `PlaceId` identifier for this place."""
return SubscriptAccess.Id(self.parent.id, self.item.id)

@cached_property
def defined_at(self) -> AstNode | None:
"""Optional location where this place was last assigned to."""
return self.parent.defined_at

@property
def describe(self) -> str:
"""A human-readable description of this place for error messages."""
return f"Subscript `{self}`"

def __str__(self) -> str:
"""String representation of this place."""
return f"{self.parent}[...]"


PyScope = dict[str, Any]


Expand Down
82 changes: 74 additions & 8 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import sys
import traceback
from contextlib import suppress
from dataclasses import replace
from typing import Any, NoReturn, cast

from guppylang.ast_util import (
Expand All @@ -35,12 +36,15 @@
with_loc,
with_type,
)
from guppylang.cfg.builder import tmp_vars
from guppylang.checker.core import (
Context,
DummyEvalDict,
FieldAccess,
Globals,
Locals,
Place,
SubscriptAccess,
Variable,
)
from guppylang.definition.ty import TypeDef
Expand All @@ -56,13 +60,15 @@
DesugaredListComp,
FieldAccessAndDrop,
GlobalName,
InoutReturnSentinel,
IterEnd,
IterHasNext,
IterNext,
LocalCall,
MakeIter,
PlaceNode,
PyExpr,
SubscriptAccessAndDrop,
TensorCall,
TypeApply,
)
Expand Down Expand Up @@ -447,7 +453,7 @@ def _synthesize_binary(
node,
)

def _synthesize_instance_func(
def synthesize_instance_func(
self,
node: ast.expr,
args: list[ast.expr],
Expand Down Expand Up @@ -495,16 +501,37 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]:

def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
item_expr, item_ty = self.synthesize(node.slice)
# Give the item a unique name so we can refer to it later in case we also want
# to compile a call to `__setitem__`
item = Variable(next(tmp_vars), item_ty, item_expr)
item_node = with_type(item_ty, with_loc(item_expr, PlaceNode(place=item)))
# Check a call to the `__getitem__` instance function
exp_sig = FunctionType(
[
FuncInput(ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.Inout),
FuncInput(ExistentialTypeVar.fresh("Key", False), InputFlags.NoFlags),
],
ExistentialTypeVar.fresh("Val", False),
)
return self._synthesize_instance_func(
node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig
getitem_expr, result_ty = self.synthesize_instance_func(
node.value, [item_node], "__getitem__", "not subscriptable", exp_sig
)
# Subscripting a place is itself a place
expr: ast.expr
if isinstance(node.value, PlaceNode):
place = SubscriptAccess(
node.value.place, item, result_ty, item_expr, getitem_expr
)
expr = PlaceNode(place=place)
else:
# If the subscript is not on a place, then there is no way to address the
# other indices after this one has been projected out (e.g. `f()[0]` makes
# you loose access to all elements besides 0).
expr = SubscriptAccessAndDrop(
item=item, item_expr=item_expr, getitem_expr=getitem_expr
)
return with_loc(node, expr), result_ty

def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
if len(node.keywords) > 0:
Expand Down Expand Up @@ -550,7 +577,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False)
)
expr, ty = self._synthesize_instance_func(
expr, ty = self.synthesize_instance_func(
node.value, [], "__iter__", "not iterable", exp_sig
)

Expand All @@ -574,7 +601,7 @@ def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty])
)
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__hasnext__", "not an iterator", exp_sig, True
)

Expand All @@ -584,14 +611,14 @@ def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
[FuncInput(ty, InputFlags.NoFlags)],
TupleType([ExistentialTypeVar.fresh("T", False), ty]),
)
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__next__", "not an iterator", exp_sig, True
)

def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType())
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__end__", "not an iterator", exp_sig, True
)

Expand Down Expand Up @@ -714,6 +741,8 @@ def type_check_args(
new_args: list[ast.expr] = []
for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
a.place = check_inout_arg_place(a.place, ctx, a)
new_args.append(a)
subst |= s

Expand All @@ -734,6 +763,43 @@ def type_check_args(
return new_args, subst


def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
"""Performs additional checks for place arguments in @inout position.
In particular, we need to check that places involving `place[item]` subscripts
implement the corresponding `__setitem__` method.
"""
match place:
case Variable():
return place
case FieldAccess(parent=parent):
return replace(place, parent=check_inout_arg_place(parent, ctx, node))
case SubscriptAccess(parent=parent, item=item, ty=ty):
# Check a call to the `__setitem__` instance function
exp_sig = FunctionType(
[
FuncInput(parent.ty, InputFlags.Inout),
FuncInput(item.ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.NoFlags),
],
NoneType(),
)
setitem_args = [
with_type(parent.ty, with_loc(node, PlaceNode(parent))),
with_type(item.ty, with_loc(node, PlaceNode(item))),
with_type(ty, with_loc(node, InoutReturnSentinel(var=place))),
]
setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func(
setitem_args[0],
setitem_args[1:],
"__setitem__",
"not allowed in a subscripted `@inout` position",
exp_sig,
True,
)
return replace(place, setitem_call=setitem_call)


def synthesize_call(
func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context
) -> tuple[list[ast.expr], Type, Inst]:
Expand Down
2 changes: 2 additions & 0 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Locals,
Place,
PlaceId,
SubscriptAccess,
Variable,
)
from guppylang.definition.value import CallableDef
Expand Down Expand Up @@ -169,6 +170,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
match arg:
case PlaceNode(place=place):
for leaf in leaf_places(place):
assert not isinstance(leaf, SubscriptAccess)
leaf = leaf.replace_defined_at(arg)
self.scope.assign(leaf)
case arg if inp.ty.linear:
Expand Down
16 changes: 15 additions & 1 deletion guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ class FieldAccessAndDrop(ast.expr):
)


class SubscriptAccessAndDrop(ast.expr):
"""A subscript element access on an object, dropping all the remaining items."""

item: "Variable"
item_expr: ast.expr
getitem_expr: ast.expr

_fields = (
"item",
"item_expr",
"getitem_expr",
)


class MakeIter(ast.expr):
"""Creates an iterator using the `__iter__` magic method.
Expand Down Expand Up @@ -205,7 +219,7 @@ class InoutReturnSentinel(ast.expr):
"""An invisible expression corresponding to an implicit use of @inout vars whenever
a function returns."""

var: "Variable | str"
var: "Place | str"

_fields = ("var",)

Expand Down
10 changes: 9 additions & 1 deletion guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,15 @@ class Array:
"ArrayGet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0}
),
)
def __getitem__(self: array[T, n], idx: int) -> T: ...
def __getitem__(self: array[L, n] @ inout, idx: int) -> L: ...

@guppy.hugr_op(
builtins,
custom_op(
"ArraySet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0}
),
)
def __setitem__(self: array[L, n] @ inout, idx: int, value: L) -> None: ...

@guppy.custom(builtins, checker=ArrayLenChecker())
def __len__(self: array[T, n]) -> int: ...
Expand Down
7 changes: 0 additions & 7 deletions tests/error/array_errors/linear_index.err

This file was deleted.

17 changes: 0 additions & 17 deletions tests/error/array_errors/linear_index.py

This file was deleted.

7 changes: 7 additions & 0 deletions tests/error/inout_errors/subscript_not_setable.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:24

22: @guppy(module)
23: def test(c: MyImmutableContainer) -> MyImmutableContainer:
24: foo(c[0])
^^^^
GuppyTypeError: Expression of type `MyImmutableContainer` is not allowed in a subscripted `@inout` position since it does not implement the `__setitem__` method
28 changes: 28 additions & 0 deletions tests/error/inout_errors/subscript_not_setable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout
from guppylang.prelude.quantum import qubit, quantum

module = GuppyModule("test")
module.load(quantum)


@guppy.declare(module)
def foo(q: qubit @inout) -> None: ...


@guppy.struct(module)
class MyImmutableContainer:
q: qubit

@guppy.declare(module)
def __getitem__(self: "MyImmutableContainer" @inout, idx: int) -> qubit: ...


@guppy(module)
def test(c: MyImmutableContainer) -> MyImmutableContainer:
foo(c[0])
return c


module.compile()
7 changes: 7 additions & 0 deletions tests/error/type_errors/not_subscriptable.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:9

7: @guppy(module)
8: def foo(x: int) -> None:
9: x[0]
^
GuppyTypeError: Expression of type `int` is not subscriptable
12 changes: 12 additions & 0 deletions tests/error/type_errors/not_subscriptable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def foo(x: int) -> None:
x[0]


module.compile()
7 changes: 7 additions & 0 deletions tests/error/type_errors/subscript_bad_item.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:10

8: @guppy(module)
9: def foo(xs: array[int, 42]) -> int:
10: return xs[1.0]
^^^
GuppyTypeError: Expected argument of type `int`, got `float`
Loading

0 comments on commit 61997cb

Please sign in to comment.