From 086d3e16f2866f5d8640c219cee589de8a91bdc8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 3 Sep 2024 16:31:12 +0200 Subject: [PATCH] feat: Add array literals --- guppylang/checker/linearity_checker.py | 11 +++- guppylang/compiler/expr_compiler.py | 13 ++++- guppylang/definition/custom.py | 21 +++++-- guppylang/definition/struct.py | 1 + guppylang/prelude/_internal/checker.py | 55 ++++++++++++++++++- guppylang/prelude/_internal/compiler.py | 28 ++++++++++ guppylang/prelude/builtins.py | 10 ++++ guppylang/tys/builtin.py | 8 +++ .../array_errors/new_array_cannot_infer.err | 7 +++ .../array_errors/new_array_cannot_infer.py | 16 ++++++ .../array_errors/new_array_check_fail.err | 7 +++ .../array_errors/new_array_check_fail.py | 16 ++++++ .../array_errors/new_array_elem_mismatch1.err | 7 +++ .../array_errors/new_array_elem_mismatch1.py | 16 ++++++ .../array_errors/new_array_elem_mismatch2.err | 7 +++ .../array_errors/new_array_elem_mismatch2.py | 16 ++++++ .../array_errors/new_array_wrong_length.err | 7 +++ .../array_errors/new_array_wrong_length.py | 16 ++++++ tests/integration/test_array.py | 27 +++++++++ 19 files changed, 279 insertions(+), 10 deletions(-) create mode 100644 tests/error/array_errors/new_array_cannot_infer.err create mode 100644 tests/error/array_errors/new_array_cannot_infer.py create mode 100644 tests/error/array_errors/new_array_check_fail.err create mode 100644 tests/error/array_errors/new_array_check_fail.py create mode 100644 tests/error/array_errors/new_array_elem_mismatch1.err create mode 100644 tests/error/array_errors/new_array_elem_mismatch1.py create mode 100644 tests/error/array_errors/new_array_elem_mismatch2.err create mode 100644 tests/error/array_errors/new_array_elem_mismatch2.py create mode 100644 tests/error/array_errors/new_array_wrong_length.err create mode 100644 tests/error/array_errors/new_array_wrong_length.py diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index c95b601a..b467103a 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -20,6 +20,7 @@ PlaceId, Variable, ) +from guppylang.definition.custom import CustomFunctionDef from guppylang.definition.value import CallableDef from guppylang.error import GuppyError, GuppyTypeError from guppylang.nodes import ( @@ -34,7 +35,7 @@ PlaceNode, TensorCall, ) -from guppylang.tys.ty import FunctionType, InputFlags, StructType +from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, StructType class Scope(Locals[PlaceId, Place]): @@ -184,7 +185,13 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N def visit_GlobalCall(self, node: GlobalCall) -> None: func = self.globals[node.def_id] assert isinstance(func, CallableDef) - func_ty = func.ty.instantiate(node.type_args) + if isinstance(func, CustomFunctionDef) and not func.has_signature: + func_ty = FunctionType( + [FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args], + get_type(node), + ) + else: + func_ty = func.ty.instantiate(node.type_args) self._visit_call_args(func_ty, node.args) self._reassign_inout_args(func_ty, node.args) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 58d701d7..dbc35dce 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -19,6 +19,7 @@ from guppylang.cfg.builder import tmp_vars from guppylang.checker.core import Variable from guppylang.compiler.core import CompilerBase, DFContainer +from guppylang.definition.custom import CustomFunctionDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef from guppylang.error import GuppyError, InternalGuppyError from guppylang.nodes import ( @@ -43,6 +44,7 @@ from guppylang.tys.subst import Inst from guppylang.tys.ty import ( BoundTypeVar, + FuncInput, FunctionType, InputFlags, NoneType, @@ -319,8 +321,15 @@ def visit_GlobalCall(self, node: GlobalCall) -> Wire: rets = func.compile_call( args, list(node.type_args), self.dfg, self.globals, node ) - self._update_inout_ports(node.args, iter(rets.inout_returns), func.ty) - return self._pack_returns(rets.regular_returns, func.ty.output) + if isinstance(func, CustomFunctionDef) and not func.has_signature: + func_ty = FunctionType( + [FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args], + get_type(node), + ) + else: + func_ty = func.ty.instantiate(node.type_args) + self._update_inout_ports(node.args, iter(rets.inout_returns), func_ty) + return self._pack_returns(rets.regular_returns, func_ty.output) def visit_Call(self, node: ast.Call) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/definition/custom.py b/guppylang/definition/custom.py index 913677aa..e19085cc 100644 --- a/guppylang/definition/custom.py +++ b/guppylang/definition/custom.py @@ -7,7 +7,7 @@ from hugr import tys as ht from hugr.dfg import _DfBase -from guppylang.ast_util import AstNode, with_loc, with_type +from guppylang.ast_util import AstNode, get_type, with_loc, with_type from guppylang.checker.core import Context, Globals from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.checker.func_checker import check_signature @@ -17,7 +17,7 @@ from guppylang.error import GuppyError, InternalGuppyError from guppylang.nodes import GlobalCall from guppylang.tys.subst import Inst, Subst -from guppylang.tys.ty import FunctionType, NoneType, Type +from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType, Type @dataclass(frozen=True) @@ -61,7 +61,8 @@ def parse(self, globals: "Globals") -> "CustomFunctionDef": code. The only information we need to access is that it's a function type and that there are no unsolved existential vars. """ - ty = self._get_signature(globals) or FunctionType([], NoneType()) + sig = self._get_signature(globals) + ty = sig or FunctionType([], NoneType()) return CustomFunctionDef( self.id, self.name, @@ -70,6 +71,7 @@ def parse(self, globals: "Globals") -> "CustomFunctionDef": self.call_checker, self.call_compiler, self.higher_order_value, + sig is not None, ) def compile_call( @@ -131,10 +133,12 @@ class CustomFunctionDef(CompiledCallableDef): id: The unique definition identifier. name: The name of the definition. defined_at: The AST node where the definition was defined. - ty: The type of the function. + ty: The type of the function. This may be a dummy value if `has_signature` is + false. call_checker: The custom call checker. call_compiler: The custom call compiler. higher_order_value: Whether the function may be used as a higher-order value. + has_signature: Whether the function has a declared signature. """ defined_at: AstNode @@ -142,6 +146,7 @@ class CustomFunctionDef(CompiledCallableDef): call_checker: "CustomCallChecker" call_compiler: "CustomInoutCallCompiler" higher_order_value: bool + has_signature: bool description: str = field(default="function", init=False) @@ -222,7 +227,13 @@ def compile_call( node: AstNode, ) -> CallReturnWires: """Compiles a call to the function.""" - concrete_ty = self.ty.instantiate(type_args) + if self.has_signature: + concrete_ty = self.ty.instantiate(type_args) + else: + concrete_ty = FunctionType( + [FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args], + get_type(node), + ) hugr_ty = concrete_ty.to_hugr() self.call_compiler._setup(type_args, dfg, globals, node, hugr_ty) diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 2d20c37a..a1fb1a97 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -219,6 +219,7 @@ def compile(self, args: list[Wire]) -> list[Wire]: call_checker=DefaultCallChecker(), call_compiler=ConstructorCompiler(), higher_order_value=True, + has_signature=True, ) return [constructor_def] diff --git a/guppylang/prelude/_internal/checker.py b/guppylang/prelude/_internal/checker.py index f6c59995..dd5735bf 100644 --- a/guppylang/prelude/_internal/checker.py +++ b/guppylang/prelude/_internal/checker.py @@ -3,6 +3,7 @@ from guppylang.ast_util import AstNode, with_loc from guppylang.checker.core import Context from guppylang.checker.expr_checker import ( + ExprChecker, ExprSynthesizer, check_call, check_num_args, @@ -18,7 +19,13 @@ from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError from guppylang.nodes import GlobalCall, ResultExpr from guppylang.tys.arg import ConstArg, TypeArg -from guppylang.tys.builtin import bool_type, int_type, is_array_type, is_bool_type +from guppylang.tys.builtin import ( + array_type, + bool_type, + int_type, + is_array_type, + is_bool_type, +) from guppylang.tys.const import Const, ConstValue from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( @@ -180,6 +187,52 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: return self._get_const_len(inst), subst +class NewArrayChecker(CustomCallChecker): + """Function call checker for the `array.__new__` function.""" + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: + if len(args) == 0: + raise GuppyTypeError( + "Cannot infer the array element type. Consider adding a type " + "annotation.", + self.node, + ) + [fst, *rest] = args + fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) + checker = ExprChecker(self.ctx) + for i in range(len(rest)): + rest[i], subst = checker.check(rest[i], ty) + assert len(subst) == 0, "Array element type is closed" + result_ty = array_type(ty, len(args)) + call = GlobalCall( + def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args + ) + return with_loc(self.node, call), result_ty + + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: + if not is_array_type(ty): + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got `array`", self.node + ) + match ty.args: + case [TypeArg(ty=elem_ty), ConstArg(ConstValue(value=int(length)))]: + subst = {} + checker = ExprChecker(self.ctx) + for i in range(len(args)): + args[i], s = checker.check(args[i], elem_ty.substitute(subst)) + subst |= s + if len(args) != length: + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got " + f"`array[{elem_ty}, {len(args)}]`", + self.node, + ) + call = GlobalCall(def_id=self.func.id, args=args, type_args=ty.args) + return with_loc(self.node, call), subst + case type_args: + raise InternalGuppyError(f"Invalid array type args: {type_args}") + + class ResultChecker(CustomCallChecker): """Call checker for the `result` function.""" diff --git a/guppylang/prelude/_internal/compiler.py b/guppylang/prelude/_internal/compiler.py index 798031e0..d9bde883 100644 --- a/guppylang/prelude/_internal/compiler.py +++ b/guppylang/prelude/_internal/compiler.py @@ -6,6 +6,10 @@ from guppylang.definition.custom import ( CustomCallCompiler, ) +from guppylang.error import InternalGuppyError +from guppylang.tys.arg import ConstArg, TypeArg +from guppylang.tys.builtin import array_type +from guppylang.tys.const import ConstValue from guppylang.tys.ty import NumericType # Note: Hugr's INT_T is 64bits, but guppy defaults to 32bits @@ -187,6 +191,30 @@ def compile(self, args: list[Wire]) -> list[Wire]: return list(self.builder.add(ops.MakeTuple()(div, mod))) +class NewArrayCompiler(CustomCallCompiler): + """Compiler for the `array.__new__` function.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + match self.type_args: + case [ + TypeArg(ty=elem_ty) as ty_arg, + ConstArg(ConstValue(value=int(length))) as len_arg, + ]: + sig = ht.FunctionType( + [elem_ty.to_hugr()] * len(args), + [array_type(elem_ty, length).to_hugr()], + ) + op = ops.Custom( + extension="prelude", + signature=sig, + name="new_array", + args=[len_arg.to_hugr(), ty_arg.to_hugr()], + ) + return [self.builder.add_op(op, *args)] + case type_args: + raise InternalGuppyError(f"Invalid array type args: {type_args}") + + class MeasureCompiler(CustomCallCompiler): """Compiler for the `measure` function.""" diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index d3303323..111a0626 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -14,6 +14,7 @@ CoercingChecker, DunderChecker, FailingChecker, + NewArrayChecker, ResultChecker, ReversingChecker, UnsupportedChecker, @@ -25,6 +26,7 @@ FloatModCompiler, IntTruedivCompiler, NatTruedivCompiler, + NewArrayCompiler, ) from guppylang.prelude._internal.util import ( custom_op, @@ -82,6 +84,9 @@ class nat: class array(Generic[_T, _n]): """Class to import in order to use arrays.""" + def __init__(self, *args: _T): + pass + @guppy.extend_type(builtins, bool_type_def) class Bool: @@ -660,6 +665,11 @@ def __getitem__(self: array[T, n], idx: int) -> T: ... @guppy.custom(builtins, checker=ArrayLenChecker()) def __len__(self: array[T, n]) -> int: ... + @guppy.custom( + builtins, NewArrayCompiler(), NewArrayChecker(), higher_order_value=False + ) + def __new__(): ... + # TODO: This is a temporary hack until we have implemented the proper results mechanism. @guppy.custom(builtins, checker=ResultChecker(), higher_order_value=False) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 4fefa11c..427ea1e3 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -9,6 +9,7 @@ from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import GuppyError, InternalGuppyError from guppylang.tys.arg import Argument, ConstArg, TypeArg +from guppylang.tys.const import ConstValue from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.ty import ( FunctionType, @@ -205,6 +206,13 @@ def linst_type(element_ty: Type) -> OpaqueType: return OpaqueType([TypeArg(element_ty)], linst_type_def) +def array_type(element_ty: Type, length: int) -> OpaqueType: + nat_type = NumericType(NumericType.Kind.Nat) + return OpaqueType( + [TypeArg(element_ty), ConstArg(ConstValue(nat_type, length))], array_type_def + ) + + def is_bool_type(ty: Type) -> bool: return isinstance(ty, OpaqueType) and ty.defn == bool_type_def diff --git a/tests/error/array_errors/new_array_cannot_infer.err b/tests/error/array_errors/new_array_cannot_infer.err new file mode 100644 index 00000000..f5f9411c --- /dev/null +++ b/tests/error/array_errors/new_array_cannot_infer.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main() -> None: +13: xs = array() + ^^^^^^^ +GuppyTypeError: Cannot infer the array element type. Consider adding a type annotation. diff --git a/tests/error/array_errors/new_array_cannot_infer.py b/tests/error/array_errors/new_array_cannot_infer.py new file mode 100644 index 00000000..e6e92f15 --- /dev/null +++ b/tests/error/array_errors/new_array_cannot_infer.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> None: + xs = array() + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/new_array_check_fail.err b/tests/error/array_errors/new_array_check_fail.err new file mode 100644 index 00000000..63978c5d --- /dev/null +++ b/tests/error/array_errors/new_array_check_fail.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main() -> int: +13: return array(1) + ^^^^^^^^ +GuppyTypeError: Expected expression of type `int`, got `array` diff --git a/tests/error/array_errors/new_array_check_fail.py b/tests/error/array_errors/new_array_check_fail.py new file mode 100644 index 00000000..cd179a4d --- /dev/null +++ b/tests/error/array_errors/new_array_check_fail.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> int: + return array(1) + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/new_array_elem_mismatch1.err b/tests/error/array_errors/new_array_elem_mismatch1.err new file mode 100644 index 00000000..8d587c7e --- /dev/null +++ b/tests/error/array_errors/new_array_elem_mismatch1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main() -> array[int, 1]: +13: array(False) + ^^^^^^^^^^^^ +GuppyError: Expected return statement diff --git a/tests/error/array_errors/new_array_elem_mismatch1.py b/tests/error/array_errors/new_array_elem_mismatch1.py new file mode 100644 index 00000000..89c0653f --- /dev/null +++ b/tests/error/array_errors/new_array_elem_mismatch1.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> array[int, 1]: + array(False) + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/new_array_elem_mismatch2.err b/tests/error/array_errors/new_array_elem_mismatch2.err new file mode 100644 index 00000000..6cdee565 --- /dev/null +++ b/tests/error/array_errors/new_array_elem_mismatch2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main() -> None: +13: array(1, False) + ^^^^^ +GuppyTypeError: Expected expression of type `int`, got `bool` diff --git a/tests/error/array_errors/new_array_elem_mismatch2.py b/tests/error/array_errors/new_array_elem_mismatch2.py new file mode 100644 index 00000000..3dba7541 --- /dev/null +++ b/tests/error/array_errors/new_array_elem_mismatch2.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> None: + array(1, False) + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/new_array_wrong_length.err b/tests/error/array_errors/new_array_wrong_length.err new file mode 100644 index 00000000..e0a577b5 --- /dev/null +++ b/tests/error/array_errors/new_array_wrong_length.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def main() -> array[int, 2]: +13: return array(1, 2, 3) + ^^^^^^^^^^^^^^ +GuppyTypeError: Expected expression of type `array[int, 2]`, got `array[int, 3]` diff --git a/tests/error/array_errors/new_array_wrong_length.py b/tests/error/array_errors/new_array_wrong_length.py new file mode 100644 index 00000000..e3dc51db --- /dev/null +++ b/tests/error/array_errors/new_array_wrong_length.py @@ -0,0 +1,16 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> array[int, 2]: + return array(1, 2, 3) + + +module.compile() \ No newline at end of file diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 49ac3fc7..12481ac6 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -29,3 +29,30 @@ def main(xs: array[int, 5], i: int) -> int: return xs[0] + xs[i] + xs[xs[2 * i]] validate(main) + + +def test_new_array(validate): + @compile_guppy + def main(x: int, y: int) -> array[int, 3]: + xs = array(x, y, x) + return xs + + validate(main) + + +def test_new_array_infer_empty(validate): + @compile_guppy + def main() -> array[float, 0]: + return array() + + validate(main) + + +def test_new_array_infer_nested(validate): + @compile_guppy + def main(ys: array[int, 0]) -> array[array[int, 0], 2]: + xs = array(ys, array()) + return xs + + validate(main) +