diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 4f732b13..ce2d6818 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -29,6 +29,7 @@ list_type_def, nat_type_def, none_type_def, + sized_iter_type_def, tuple_type_def, ) from guppylang.tys.ty import ( @@ -222,6 +223,7 @@ def default() -> "Globals": float_type_def, list_type_def, array_type_def, + sized_iter_type_def, ] defs = {defn.id: defn for defn in builtins} names = {defn.name: defn.id for defn in builtins} diff --git a/guppylang/definition/custom.py b/guppylang/definition/custom.py index a3eeef5a..bc670759 100644 --- a/guppylang/definition/custom.py +++ b/guppylang/definition/custom.py @@ -261,12 +261,16 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunctionDef) -> None: self.node = node self.func = func - @abstractmethod def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: """Checks the return value against a given type. Returns a (possibly) transformed and annotated AST node for the call. """ + from guppylang.checker.expr_checker import check_type_against + + expr, res_ty = self.synthesize(args) + subst, _ = check_type_against(res_ty, ty, self.node) + return expr, subst @abstractmethod def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: diff --git a/guppylang/prelude/_internal/checker.py b/guppylang/prelude/_internal/checker.py index e6d30d82..7a2eafe2 100644 --- a/guppylang/prelude/_internal/checker.py +++ b/guppylang/prelude/_internal/checker.py @@ -1,6 +1,7 @@ import ast +from typing import cast -from guppylang.ast_util import AstNode, with_loc +from guppylang.ast_util import AstNode, with_loc, with_type from guppylang.checker.core import Context from guppylang.checker.expr_checker import ( ExprChecker, @@ -15,6 +16,7 @@ CustomFunctionDef, DefaultCallChecker, ) +from guppylang.definition.struct import CheckedStructDef, RawStructDef from guppylang.definition.value import CallableDef from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError from guppylang.nodes import GlobalCall, ResultExpr @@ -25,6 +27,7 @@ int_type, is_array_type, is_bool_type, + sized_iter_type, ) from guppylang.tys.const import Const, ConstValue from guppylang.tys.subst import Inst, Subst @@ -32,6 +35,7 @@ FunctionType, NoneType, NumericType, + StructType, Type, unify, ) @@ -279,3 +283,41 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: @staticmethod def _is_numeric_or_bool_type(ty: Type) -> bool: return isinstance(ty, NumericType) or is_bool_type(ty) + + +class RangeChecker(CustomCallChecker): + """Call checker for the `range` function.""" + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: + check_num_args(1, len(args), self.node) + [stop] = args + stop, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument") + range_iter, range_ty = self.make_range(stop) + if isinstance(stop, ast.Constant): + return to_sized_iter(range_iter, range_ty, stop.value, self.ctx) + return range_iter, range_ty + + def range_ty(self) -> StructType: + from guppylang.prelude.builtins import Range + + def_id = cast(RawStructDef, Range).id + range_type_def = self.ctx.globals.defs[def_id] + assert isinstance(range_type_def, CheckedStructDef) + return StructType([], range_type_def) + + def make_range(self, stop: ast.expr) -> tuple[ast.expr, Type]: + make_range = self.ctx.globals.get_instance_func(self.range_ty(), "__new__") + assert make_range is not None + start = with_type(int_type(), with_loc(self.node, ast.Constant(value=0))) + return make_range.synthesize_call([start, stop], self.node, self.ctx) + + +def to_sized_iter( + iterator: ast.expr, range_ty: Type, size: int, ctx: Context +) -> tuple[ast.expr, Type]: + """Adds a static size annotation to an iterator.""" + sized_iter_ty = sized_iter_type(range_ty, size) + make_sized_iter = ctx.globals.get_instance_func(sized_iter_ty, "__new__") + assert make_sized_iter is not None + sized_iter, _ = make_sized_iter.check_call([iterator], sized_iter_ty, iterator, ctx) + return sized_iter, sized_iter_ty diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index baa079df..57130818 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -14,6 +14,7 @@ CoercingChecker, DunderChecker, NewArrayChecker, + RangeChecker, ResultChecker, ReversingChecker, UnsupportedChecker, @@ -53,6 +54,7 @@ int_type_def, list_type_def, nat_type_def, + sized_iter_type_def, ) guppy.init_module(import_builtins=False) @@ -559,6 +561,32 @@ def __len__(self: array[L, n]) -> int: ... def __new__(): ... +@guppy.extend_type(sized_iter_type_def) +class SizedIter: + """A wrapper around an iterator type `T` promising that the iterator will yield + exactly `n` values. + + Annotating an iterator with an incorrect size is undefined behaviour. + """ + + def __class_getitem__(cls, item: Any) -> type: + # Dummy implementation to allow subscripting of the `SizedIter` type in + # positions that are evaluated by the Python interpreter + return cls + + @guppy.custom(NoopCompiler()) + def __new__(iterator: L @ owned) -> "SizedIter[L, n]": # type: ignore[type-arg] + """Casts an iterator into a `SizedIter`.""" + + @guppy.custom(NoopCompiler()) + def unwrap_iter(self: "SizedIter[L, n]" @ owned) -> L: + """Extracts the actual iterator.""" + + @guppy.custom(NoopCompiler()) + def __iter__(self: "SizedIter[L, n]" @ owned) -> L: + """Extracts the actual iterator.""" + + # TODO: This is a temporary hack until we have implemented the proper results mechanism. @guppy.custom(checker=ResultChecker(), higher_order_value=False) def result(tag, value): ... @@ -769,43 +797,33 @@ def property(x): ... @guppy.struct class Range: - stop: int - - @guppy - def __iter__(self: "Range") -> "RangeIter": - return RangeIter(0, self.stop) # type: ignore[call-arg] - - -@guppy.struct -class RangeIter: next: int stop: int @guppy - def __iter__(self: "RangeIter") -> "RangeIter": + def __iter__(self: "Range") -> "Range": return self @guppy - def __hasnext__(self: "RangeIter") -> tuple[bool, "RangeIter"]: + def __hasnext__(self: "Range") -> tuple[bool, "Range"]: return (self.next < self.stop, self) @guppy - def __next__(self: "RangeIter") -> tuple[int, "RangeIter"]: + def __next__(self: "Range") -> tuple[int, "Range"]: # Fine not to check bounds while we can only be called from inside a `for` loop. # if self.start >= self.stop: # raise StopIteration - return (self.next, RangeIter(self.next + 1, self.stop)) # type: ignore[call-arg] + return (self.next, Range(self.next + 1, self.stop)) # type: ignore[call-arg] @guppy - def __end__(self: "RangeIter") -> None: + def __end__(self: "Range") -> None: pass -@guppy +@guppy.custom(checker=RangeChecker(), higher_order_value=False) def range(stop: int) -> Range: """Limited version of python range(). Only a single argument (stop/limit) is supported.""" - return Range(stop) # type: ignore[call-arg] @guppy.custom(checker=UnsupportedChecker(), higher_order_value=False) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 1634ec14..f201ca61 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -140,6 +140,13 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: return array.instantiate([len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)]) +def _sized_iter_to_hugr(args: Sequence[Argument]) -> ht.Type: + [ty_arg, len_arg] = args + assert isinstance(ty_arg, TypeArg) + assert isinstance(len_arg, ConstArg) + return ty_arg.ty.to_hugr() + + callable_type_def = CallableTypeDef(DefId.fresh(), None) tuple_type_def = _TupleTypeDef(DefId.fresh(), None) none_type_def = _NoneTypeDef(DefId.fresh(), None) @@ -179,6 +186,17 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: always_linear=False, to_hugr=_array_to_hugr, ) +sized_iter_type_def = OpaqueTypeDef( + id=DefId.fresh(), + name="SizedIter", + defined_at=None, + params=[ + TypeParam(0, "T", can_be_linear=True), + ConstParam(1, "n", NumericType(NumericType.Kind.Nat)), + ], + always_linear=False, + to_hugr=_sized_iter_to_hugr, +) def bool_type() -> OpaqueType: @@ -200,6 +218,13 @@ def array_type(element_ty: Type, length: int) -> OpaqueType: ) +def sized_iter_type(iter_type: Type, size: int) -> OpaqueType: + nat_type = NumericType(NumericType.Kind.Nat) + return OpaqueType( + [TypeArg(iter_type), ConstArg(ConstValue(nat_type, size))], sized_iter_type_def + ) + + def is_bool_type(ty: Type) -> bool: return isinstance(ty, OpaqueType) and ty.defn == bool_type_def @@ -212,9 +237,23 @@ def is_array_type(ty: Type) -> TypeGuard[OpaqueType]: return isinstance(ty, OpaqueType) and ty.defn == array_type_def +def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]: + return isinstance(ty, OpaqueType) and ty.defn == sized_iter_type_def + + def get_element_type(ty: Type) -> Type: assert isinstance(ty, OpaqueType) assert ty.defn == list_type_def (arg,) = ty.args assert isinstance(arg, TypeArg) return arg.ty + + +def get_iter_size(ty: Type) -> int: + assert isinstance(ty, OpaqueType) + assert ty.defn == sized_iter_type_def + match ty.args: + case [_, ConstArg(ConstValue(value=int(size)))]: + return size + case _: + raise InternalGuppyError("Unexpected type args") diff --git a/tests/integration/test_range.py b/tests/integration/test_range.py index 32d33a61..da9104c4 100644 --- a/tests/integration/test_range.py +++ b/tests/integration/test_range.py @@ -1,5 +1,5 @@ from guppylang.decorator import guppy -from guppylang.prelude.builtins import nat, range +from guppylang.prelude.builtins import nat, range, SizedIter, Range from guppylang.module import GuppyModule from tests.util import compile_guppy @@ -20,7 +20,27 @@ def negative() -> int: total += 100 + x return total + @guppy(module) + def non_static() -> int: + total = 0 + n = 4 + for x in range(n + 1): + total += x + 100 # Make the initial 0 obvious + return total + compiled = module.compile() validate(compiled) run_int_fn(compiled, expected=510) run_int_fn(compiled, expected=0, fn_name="negative") + run_int_fn(compiled, expected=510, fn_name="non_static") + + +def test_static_size(validate): + module = GuppyModule("test") + + @guppy(module) + def negative() -> SizedIter[Range, 10]: + return range(10) + + validate(module.compile()) +