Skip to content

Commit

Permalink
feat: Add array literals
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Sep 3, 2024
1 parent e23b666 commit 086d3e1
Show file tree
Hide file tree
Showing 19 changed files with 279 additions and 10 deletions.
11 changes: 9 additions & 2 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
PlaceId,
Variable,
)
from guppylang.definition.custom import CustomFunctionDef
from guppylang.definition.value import CallableDef
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.nodes import (
Expand All @@ -34,7 +35,7 @@
PlaceNode,
TensorCall,
)
from guppylang.tys.ty import FunctionType, InputFlags, StructType
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, StructType


class Scope(Locals[PlaceId, Place]):
Expand Down Expand Up @@ -184,7 +185,13 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
def visit_GlobalCall(self, node: GlobalCall) -> None:
func = self.globals[node.def_id]
assert isinstance(func, CallableDef)
func_ty = func.ty.instantiate(node.type_args)
if isinstance(func, CustomFunctionDef) and not func.has_signature:
func_ty = FunctionType(
[FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args],
get_type(node),
)
else:
func_ty = func.ty.instantiate(node.type_args)
self._visit_call_args(func_ty, node.args)
self._reassign_inout_args(func_ty, node.args)

Expand Down
13 changes: 11 additions & 2 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from guppylang.cfg.builder import tmp_vars
from guppylang.checker.core import Variable
from guppylang.compiler.core import CompilerBase, DFContainer
from guppylang.definition.custom import CustomFunctionDef
from guppylang.definition.value import CompiledCallableDef, CompiledValueDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import (
Expand All @@ -43,6 +44,7 @@
from guppylang.tys.subst import Inst
from guppylang.tys.ty import (
BoundTypeVar,
FuncInput,
FunctionType,
InputFlags,
NoneType,
Expand Down Expand Up @@ -319,8 +321,15 @@ def visit_GlobalCall(self, node: GlobalCall) -> Wire:
rets = func.compile_call(
args, list(node.type_args), self.dfg, self.globals, node
)
self._update_inout_ports(node.args, iter(rets.inout_returns), func.ty)
return self._pack_returns(rets.regular_returns, func.ty.output)
if isinstance(func, CustomFunctionDef) and not func.has_signature:
func_ty = FunctionType(
[FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args],
get_type(node),
)
else:
func_ty = func.ty.instantiate(node.type_args)
self._update_inout_ports(node.args, iter(rets.inout_returns), func_ty)
return self._pack_returns(rets.regular_returns, func_ty.output)

def visit_Call(self, node: ast.Call) -> Wire:
raise InternalGuppyError("Node should have been removed during type checking.")
Expand Down
21 changes: 16 additions & 5 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from hugr import tys as ht
from hugr.dfg import _DfBase

from guppylang.ast_util import AstNode, with_loc, with_type
from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.checker.core import Context, Globals
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import check_signature
Expand All @@ -17,7 +17,7 @@
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NoneType, Type
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType, Type


@dataclass(frozen=True)
Expand Down Expand Up @@ -61,7 +61,8 @@ def parse(self, globals: "Globals") -> "CustomFunctionDef":
code. The only information we need to access is that it's a function type and
that there are no unsolved existential vars.
"""
ty = self._get_signature(globals) or FunctionType([], NoneType())
sig = self._get_signature(globals)
ty = sig or FunctionType([], NoneType())
return CustomFunctionDef(
self.id,
self.name,
Expand All @@ -70,6 +71,7 @@ def parse(self, globals: "Globals") -> "CustomFunctionDef":
self.call_checker,
self.call_compiler,
self.higher_order_value,
sig is not None,
)

def compile_call(
Expand Down Expand Up @@ -131,17 +133,20 @@ class CustomFunctionDef(CompiledCallableDef):
id: The unique definition identifier.
name: The name of the definition.
defined_at: The AST node where the definition was defined.
ty: The type of the function.
ty: The type of the function. This may be a dummy value if `has_signature` is
false.
call_checker: The custom call checker.
call_compiler: The custom call compiler.
higher_order_value: Whether the function may be used as a higher-order value.
has_signature: Whether the function has a declared signature.
"""

defined_at: AstNode
ty: FunctionType
call_checker: "CustomCallChecker"
call_compiler: "CustomInoutCallCompiler"
higher_order_value: bool
has_signature: bool

description: str = field(default="function", init=False)

Expand Down Expand Up @@ -222,7 +227,13 @@ def compile_call(
node: AstNode,
) -> CallReturnWires:
"""Compiles a call to the function."""
concrete_ty = self.ty.instantiate(type_args)
if self.has_signature:
concrete_ty = self.ty.instantiate(type_args)
else:
concrete_ty = FunctionType(
[FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args],
get_type(node),
)
hugr_ty = concrete_ty.to_hugr()

self.call_compiler._setup(type_args, dfg, globals, node, hugr_ty)
Expand Down
1 change: 1 addition & 0 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def compile(self, args: list[Wire]) -> list[Wire]:
call_checker=DefaultCallChecker(),
call_compiler=ConstructorCompiler(),
higher_order_value=True,
has_signature=True,
)
return [constructor_def]

Expand Down
55 changes: 54 additions & 1 deletion guppylang/prelude/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from guppylang.ast_util import AstNode, with_loc
from guppylang.checker.core import Context
from guppylang.checker.expr_checker import (
ExprChecker,
ExprSynthesizer,
check_call,
check_num_args,
Expand All @@ -18,7 +19,13 @@
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.nodes import GlobalCall, ResultExpr
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import bool_type, int_type, is_array_type, is_bool_type
from guppylang.tys.builtin import (
array_type,
bool_type,
int_type,
is_array_type,
is_bool_type,
)
from guppylang.tys.const import Const, ConstValue
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
Expand Down Expand Up @@ -180,6 +187,52 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
return self._get_const_len(inst), subst


class NewArrayChecker(CustomCallChecker):
"""Function call checker for the `array.__new__` function."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
if len(args) == 0:
raise GuppyTypeError(
"Cannot infer the array element type. Consider adding a type "
"annotation.",
self.node,
)
[fst, *rest] = args
fst, ty = ExprSynthesizer(self.ctx).synthesize(fst)
checker = ExprChecker(self.ctx)
for i in range(len(rest)):
rest[i], subst = checker.check(rest[i], ty)
assert len(subst) == 0, "Array element type is closed"
result_ty = array_type(ty, len(args))
call = GlobalCall(
def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args
)
return with_loc(self.node, call), result_ty

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
if not is_array_type(ty):
raise GuppyTypeError(
f"Expected expression of type `{ty}`, got `array`", self.node
)
match ty.args:
case [TypeArg(ty=elem_ty), ConstArg(ConstValue(value=int(length)))]:
subst = {}
checker = ExprChecker(self.ctx)
for i in range(len(args)):
args[i], s = checker.check(args[i], elem_ty.substitute(subst))
subst |= s
if len(args) != length:
raise GuppyTypeError(
f"Expected expression of type `{ty}`, got "
f"`array[{elem_ty}, {len(args)}]`",
self.node,
)
call = GlobalCall(def_id=self.func.id, args=args, type_args=ty.args)
return with_loc(self.node, call), subst
case type_args:
raise InternalGuppyError(f"Invalid array type args: {type_args}")


class ResultChecker(CustomCallChecker):
"""Call checker for the `result` function."""

Expand Down
28 changes: 28 additions & 0 deletions guppylang/prelude/_internal/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from guppylang.definition.custom import (
CustomCallCompiler,
)
from guppylang.error import InternalGuppyError
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import array_type
from guppylang.tys.const import ConstValue
from guppylang.tys.ty import NumericType

# Note: Hugr's INT_T is 64bits, but guppy defaults to 32bits
Expand Down Expand Up @@ -187,6 +191,30 @@ def compile(self, args: list[Wire]) -> list[Wire]:
return list(self.builder.add(ops.MakeTuple()(div, mod)))


class NewArrayCompiler(CustomCallCompiler):
"""Compiler for the `array.__new__` function."""

def compile(self, args: list[Wire]) -> list[Wire]:
match self.type_args:
case [
TypeArg(ty=elem_ty) as ty_arg,
ConstArg(ConstValue(value=int(length))) as len_arg,
]:
sig = ht.FunctionType(
[elem_ty.to_hugr()] * len(args),
[array_type(elem_ty, length).to_hugr()],
)
op = ops.Custom(
extension="prelude",
signature=sig,
name="new_array",
args=[len_arg.to_hugr(), ty_arg.to_hugr()],
)
return [self.builder.add_op(op, *args)]
case type_args:
raise InternalGuppyError(f"Invalid array type args: {type_args}")


class MeasureCompiler(CustomCallCompiler):
"""Compiler for the `measure` function."""

Expand Down
10 changes: 10 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
CoercingChecker,
DunderChecker,
FailingChecker,
NewArrayChecker,
ResultChecker,
ReversingChecker,
UnsupportedChecker,
Expand All @@ -25,6 +26,7 @@
FloatModCompiler,
IntTruedivCompiler,
NatTruedivCompiler,
NewArrayCompiler,
)
from guppylang.prelude._internal.util import (
custom_op,
Expand Down Expand Up @@ -82,6 +84,9 @@ class nat:
class array(Generic[_T, _n]):
"""Class to import in order to use arrays."""

def __init__(self, *args: _T):
pass


@guppy.extend_type(builtins, bool_type_def)
class Bool:
Expand Down Expand Up @@ -660,6 +665,11 @@ def __getitem__(self: array[T, n], idx: int) -> T: ...
@guppy.custom(builtins, checker=ArrayLenChecker())
def __len__(self: array[T, n]) -> int: ...

@guppy.custom(
builtins, NewArrayCompiler(), NewArrayChecker(), higher_order_value=False
)
def __new__(): ...


# TODO: This is a temporary hack until we have implemented the proper results mechanism.
@guppy.custom(builtins, checker=ResultChecker(), higher_order_value=False)
Expand Down
8 changes: 8 additions & 0 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.const import ConstValue
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.ty import (
FunctionType,
Expand Down Expand Up @@ -205,6 +206,13 @@ def linst_type(element_ty: Type) -> OpaqueType:
return OpaqueType([TypeArg(element_ty)], linst_type_def)


def array_type(element_ty: Type, length: int) -> OpaqueType:
nat_type = NumericType(NumericType.Kind.Nat)
return OpaqueType(
[TypeArg(element_ty), ConstArg(ConstValue(nat_type, length))], array_type_def
)


def is_bool_type(ty: Type) -> bool:
return isinstance(ty, OpaqueType) and ty.defn == bool_type_def

Expand Down
7 changes: 7 additions & 0 deletions tests/error/array_errors/new_array_cannot_infer.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> None:
13: xs = array()
^^^^^^^
GuppyTypeError: Cannot infer the array element type. Consider adding a type annotation.
16 changes: 16 additions & 0 deletions tests/error/array_errors/new_array_cannot_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main() -> None:
xs = array()


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/new_array_check_fail.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> int:
13: return array(1)
^^^^^^^^
GuppyTypeError: Expected expression of type `int`, got `array`
16 changes: 16 additions & 0 deletions tests/error/array_errors/new_array_check_fail.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main() -> int:
return array(1)


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/new_array_elem_mismatch1.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:13

11: @guppy(module)
12: def main() -> array[int, 1]:
13: array(False)
^^^^^^^^^^^^
GuppyError: Expected return statement
Loading

0 comments on commit 086d3e1

Please sign in to comment.