Skip to content

Commit

Permalink
feat!: Allow linear data inside lists (#531)
Browse files Browse the repository at this point in the history
Closes #530 and closes #524.

* Changes the signature of list methods to allow linear list elements
(implicitly using the inout feature).
* Removes methods that are not supported for now (see #528)
* Small change to the unification logic to be a bit laxer about
requiring matching `@owned` flags if one of the arguments is not linear
(otherwise we couldn't unify `L @owned` with any classical type). We'll
need to revisit that once we have flags that can be put on classical
types.

BREAKING CHANGE: Unsupported list methods have been removed.
  • Loading branch information
mark-koch authored Oct 3, 2024
1 parent b4fae3f commit 229be2e
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 101 deletions.
79 changes: 18 additions & 61 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
CallableChecker,
CoercingChecker,
DunderChecker,
FailingChecker,
NewArrayChecker,
ResultChecker,
ReversingChecker,
Expand Down Expand Up @@ -505,77 +504,35 @@ def __trunc__(self: float) -> float: ...

@guppy.extend_type(list_type_def)
class List:
@guppy.hugr_op(unsupported_op("Append"))
def __add__(self: list[T], other: list[T]) -> list[T]: ...

@guppy.hugr_op(unsupported_op("IsEmpty"))
def __bool__(self: list[T]) -> bool: ...

@guppy.hugr_op(unsupported_op("Contains"))
def __contains__(self: list[T], el: T) -> bool: ...

@guppy.hugr_op(unsupported_op("AssertEmpty"))
def __end__(self: list[T]) -> None: ...

@guppy.hugr_op(unsupported_op("Lookup"))
def __getitem__(self: list[T], idx: int) -> T: ...

@guppy.hugr_op(unsupported_op("IsNotEmpty"))
def __hasnext__(self: list[T]) -> tuple[bool, list[T]]: ...

@guppy.custom(NoopCompiler())
def __iter__(self: list[T]) -> list[T]: ...
@guppy.hugr_op(unsupported_op("pop")) # TODO: unwrap and swap None
def __getitem__(self: list[L], idx: int) -> L: ...

@guppy.hugr_op(unsupported_op("Length"))
def __len__(self: list[T]) -> int: ...

@guppy.hugr_op(unsupported_op("Repeat"))
def __mul__(self: list[T], other: int) -> list[T]: ...
@guppy.hugr_op(unsupported_op("set")) # TODO: check None and unwrap
def __setitem__(self: list[L], idx: int, value: L @ owned) -> None: ...

@guppy.hugr_op(unsupported_op("Pop"))
def __next__(self: list[T]) -> tuple[T, list[T]]: ...
@guppy.hugr_op(unsupported_op("length")) # TODO: inout return in wrong order
def __len__(self: list[L]) -> int: ...

@guppy.custom(checker=UnsupportedChecker(), higher_order_value=False)
def __new__(x): ...

@guppy.custom(checker=FailingChecker("Guppy lists are immutable"))
def __setitem__(self: list[T], idx: int, value: T) -> None: ...

@guppy.hugr_op(unsupported_op("Append"), ReversingChecker())
def __radd__(self: list[T], other: list[T]) -> list[T]: ...

@guppy.hugr_op(unsupported_op("Repeat"))
def __rmul__(self: list[T], other: int) -> list[T]: ...

@guppy.custom(checker=FailingChecker("Guppy lists are immutable"))
def append(self: list[T], elt: T) -> None: ...

@guppy.custom(checker=FailingChecker("Guppy lists are immutable"))
def clear(self: list[T]) -> None: ...

@guppy.custom(NoopCompiler()) # Can be noop since lists are immutable
def copy(self: list[T]) -> list[T]: ...

@guppy.hugr_op(unsupported_op("Count"))
def count(self: list[T], elt: T) -> int: ...

@guppy.custom(checker=FailingChecker("Guppy lists are immutable"))
def extend(self: list[T], seq: None) -> None: ...
@guppy.custom(NoopCompiler()) # TODO: define via Guppy source instead
def __iter__(self: list[L] @ owned) -> list[L]: ...

@guppy.hugr_op(unsupported_op("Find"))
def index(self: list[T], elt: T) -> int: ...
@guppy.hugr_op(unsupported_op("IsNotEmpty")) # TODO
def __hasnext__(self: list[L] @ owned) -> tuple[bool, list[L]]: ...

@guppy.custom(checker=FailingChecker("Guppy lists are immutable"))
def pop(self: list[T], idx: int) -> None: ...
@guppy.hugr_op(unsupported_op("AssertEmpty")) # TODO
def __end__(self: list[L] @ owned) -> None: ...

@guppy.custom(checker=FailingChecker("Guppy lists are immutable"))
def remove(self: list[T], elt: T) -> None: ...
@guppy.hugr_op(unsupported_op("pop"))
def __next__(self: list[L] @ owned) -> tuple[L, list[L]]: ...

@guppy.custom(checker=FailingChecker("Guppy lists are immutable"))
def reverse(self: list[T]) -> None: ...
@guppy.hugr_op(unsupported_op("push"))
def append(self: list[L], item: L @ owned) -> None: ...

@guppy.custom(checker=FailingChecker("Guppy lists are immutable"))
def sort(self: list[T]) -> None: ...
@guppy.hugr_op(unsupported_op("pop")) # TODO
def pop(self: list[L]) -> L: ...


linst = list
Expand Down
12 changes: 3 additions & 9 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,14 @@ def check_instantiate(
class _ListTypeDef(OpaqueTypeDef):
"""Type definition associated with the builtin `list` type.
We have a custom definition to give a nicer error message if the user tries to put
linear data into a regular list.
We have a custom definition to disable usage of lists unless experimental features
are enabled.
"""

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
check_lists_enabled(loc)
if len(args) == 1:
[arg] = args
if isinstance(arg, TypeArg) and arg.ty.linear:
raise GuppyError(
"Type `list` cannot store linear data, use `linst` instead", loc
)
return super().check_instantiate(args, globals, loc)


Expand Down Expand Up @@ -192,7 +186,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:
id=DefId.fresh(),
name="list",
defined_at=None,
params=[TypeParam(0, "T", can_be_linear=False)],
params=[TypeParam(0, "T", can_be_linear=True)],
always_linear=False,
to_hugr=_list_to_hugr,
)
Expand Down
7 changes: 4 additions & 3 deletions guppylang/tys/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,10 +667,11 @@ def unify(s: Type | Const, t: Type | Const, subst: "Subst | None") -> "Subst | N
case NoneType(), NoneType():
return subst
case FunctionType() as s, FunctionType() as t if s.params == t.params:
if len(s.inputs) != len(t.inputs) or any(
a.flags != b.flags for a, b in zip(s.inputs, t.inputs, strict=True)
):
if len(s.inputs) != len(t.inputs):
return None
for a, b in zip(s.inputs, t.inputs, strict=True):
if a.ty.linear and b.ty.linear and a.flags != b.flags:
return None
return _unify_args(s, t, subst)
case TupleType() as s, TupleType() as t:
return _unify_args(s, t, subst)
Expand Down
6 changes: 0 additions & 6 deletions tests/error/misc_errors/list_linear.err

This file was deleted.

17 changes: 0 additions & 17 deletions tests/error/misc_errors/list_linear.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/error/poly_errors/non_linear2.err
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:23
21: @guppy(module)
22: def main() -> None:
23: foo(h)
^
GuppyTypeError: Expected argument of type `?T -> ?T`, got `qubit @owned -> qubit`
^^^^^^
GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. (T -> T) -> None` with linear type `qubit`
4 changes: 2 additions & 2 deletions tests/error/poly_errors/non_linear3.err
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:25
23: @guppy(module)
24: def main() -> None:
25: foo(h)
^
GuppyTypeError: Expected argument of type `?T -> None`, got `qubit -> None`
^^^^^^
GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. (T -> None) -> None` with linear type `qubit`
26 changes: 25 additions & 1 deletion tests/integration/test_list.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
from guppylang import qubit, guppy, GuppyModule
from guppylang.prelude.builtins import owned

from tests.util import compile_guppy

Expand Down Expand Up @@ -29,6 +31,17 @@ def test(x: float) -> list[float]:
validate(test)


def test_push_pop(validate):
@compile_guppy
def test(xs: list[int]) -> bool:
xs.append(3)
x = xs.pop()
return x == 3

validate(test)


@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/528")
def test_arith(validate):
@compile_guppy
def test(xs: list[int]) -> list[int]:
Expand All @@ -39,10 +52,21 @@ def test(xs: list[int]) -> list[int]:
validate(test)


@pytest.mark.skip("Requires updating lists to use inout")
def test_subscript(validate):
@compile_guppy
def test(xs: list[float], i: int) -> float:
return xs[2 * i]

validate(test)


def test_linear(validate):
module = GuppyModule("test")
module.load(qubit)

@guppy(module)
def test(xs: list[qubit], q: qubit @owned) -> int:
xs.append(q)
return len(xs)

validate(module.compile())

0 comments on commit 229be2e

Please sign in to comment.