Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix generic array functions #630

Merged
merged 8 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from guppylang.checker.stmt_checker import StmtChecker
from guppylang.diagnostic import Error, Note
from guppylang.error import GuppyError
from guppylang.tys.param import Parameter
from guppylang.tys.ty import InputFlags, Type

Row = Sequence[V]
Expand Down Expand Up @@ -60,7 +61,12 @@ def __init__(self, input_tys: list[Type], output_ty: Type) -> None:


def check_cfg(
cfg: CFG, inputs: Row[Variable], return_ty: Type, func_name: str, globals: Globals
cfg: CFG,
inputs: Row[Variable],
return_ty: Type,
generic_params: dict[str, Parameter],
func_name: str,
globals: Globals,
) -> CheckedCFG[Place]:
"""Type checks a control-flow graph.

Expand All @@ -76,7 +82,7 @@ def check_cfg(
# We start by compiling the entry BB
checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty)
checked_cfg.entry_bb = check_bb(
cfg.entry_bb, checked_cfg, inputs, return_ty, globals
cfg.entry_bb, checked_cfg, inputs, return_ty, generic_params, globals
)
compiled = {cfg.entry_bb: checked_cfg.entry_bb}

Expand All @@ -102,7 +108,9 @@ def check_cfg(
check_rows_match(input_row, compiled[bb].sig.input_row, bb)
else:
# Otherwise, check the BB and enqueue its successors
checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals)
checked_bb = check_bb(
bb, checked_cfg, input_row, return_ty, generic_params, globals
)
queue += [
# We enumerate the successor starting from the back, so we start with
# the `True` branch. This way, we find errors in a more natural order
Expand Down Expand Up @@ -174,6 +182,7 @@ def check_bb(
checked_cfg: CheckedCFG[Variable],
inputs: Row[Variable],
return_ty: Type,
generic_params: dict[str, Parameter],
globals: Globals,
) -> CheckedBB[Variable]:
cfg = bb.containing_cfg
Expand All @@ -187,7 +196,7 @@ def check_bb(
raise GuppyError(VarNotDefinedError(use, x))

# Check the basic block
ctx = Context(globals, Locals({v.name: v for v in inputs}))
ctx = Context(globals, Locals({v.name: v for v in inputs}), generic_params)
checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements)

# If we branch, we also have to check the branch predicate
Expand Down
2 changes: 2 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
sized_iter_type_def,
tuple_type_def,
)
from guppylang.tys.param import Parameter
from guppylang.tys.ty import (
BoundTypeVar,
ExistentialTypeVar,
Expand Down Expand Up @@ -381,6 +382,7 @@ class Context(NamedTuple):

globals: Globals
locals: Locals[str, Variable]
generic_params: dict[str, Parameter]


class DummyEvalDict(PyScope):
Expand Down
19 changes: 17 additions & 2 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from dataclasses import replace
from typing import Any, NoReturn, cast

from typing_extensions import assert_never

from guppylang.ast_util import (
AstNode,
AstVisitor,
Expand Down Expand Up @@ -85,6 +87,7 @@
DesugaredGenerator,
DesugaredListComp,
FieldAccessAndDrop,
GenericParamValue,
GlobalName,
InoutReturnSentinel,
IterEnd,
Expand All @@ -108,7 +111,7 @@
is_list_type,
list_type,
)
from guppylang.tys.param import TypeParam
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
ExistentialTypeVar,
Expand Down Expand Up @@ -368,6 +371,18 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
if x in self.ctx.locals:
var = self.ctx.locals[x]
return with_loc(node, PlaceNode(place=var)), var.ty
elif x in self.ctx.generic_params:
param = self.ctx.generic_params[x]
match param:
case ConstParam() as param:
ast_node = with_loc(node, GenericParamValue(id=x, param=param))
return ast_node, param.ty
case TypeParam() as param:
raise GuppyError(
ExpectedError(node, "a value", got=f"type `{param.name}`")
)
case _:
return assert_never(param)
elif x in self.ctx.globals:
defn = self.ctx.globals[x]
return self._check_global(defn, x, node)
Expand Down Expand Up @@ -1031,7 +1046,7 @@ def synthesize_comprehension(
# The rest is checked in a new nested context to ensure that variables don't escape
# their scope
inner_locals: Locals[str, Variable] = Locals({}, parent_scope=ctx.locals)
inner_ctx = Context(ctx.globals, inner_locals)
inner_ctx = Context(ctx.globals, inner_locals, ctx.generic_params)
expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx)
gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign)
gen.next_assign = stmt_chk.visit_Assign(gen.next_assign)
Expand Down
14 changes: 11 additions & 3 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def check_global_func_def(
Variable(x, inp.ty, loc, inp.flags)
for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True)
]
return check_cfg(cfg, inputs, ty.output, func_def.name, globals)
generic_params = {
param.name: param.with_idx(i) for i, param in enumerate(ty.params)
}
return check_cfg(cfg, inputs, ty.output, generic_params, func_def.name, globals)


def check_nested_func_def(
Expand All @@ -84,6 +87,11 @@ def check_nested_func_def(
func_ty = check_signature(func_def, ctx.globals)
assert func_ty.input_names is not None

if func_ty.parametrized:
raise GuppyError(
UnsupportedError(func_def, "Nested generic function definitions")
)

# We've already built the CFG for this function while building the CFG of the
# enclosing function
cfg = func_def.cfg
Expand Down Expand Up @@ -137,7 +145,7 @@ def check_nested_func_def(
# Otherwise, we treat it like a local name
inputs.append(Variable(func_def.name, func_def.ty, func_def))

checked_cfg = check_cfg(cfg, inputs, func_ty.output, func_def.name, globals)
checked_cfg = check_cfg(cfg, inputs, func_ty.output, {}, func_def.name, globals)
checked_def = CheckedNestedFunctionDef(
def_id,
checked_cfg,
Expand Down Expand Up @@ -188,7 +196,7 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
input_nodes.append(ty_ast)
input_names.append(inp.arg)
inputs, output = parse_function_io_types(
input_nodes, func_def.returns, func_def, globals, param_var_mapping
input_nodes, func_def.returns, func_def, globals, param_var_mapping, True
)
return FunctionType(
inputs,
Expand Down
2 changes: 1 addition & 1 deletion guppylang/checker/stmt_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign:
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
if node.value is None:
raise GuppyError(UnsupportedError(node, "Variable declarations"))
ty = type_from_ast(node.annotation, self.ctx.globals)
ty = type_from_ast(node.annotation, self.ctx.globals, self.ctx.generic_params)
node.value, subst = self._check_expr(node.value, ty)
assert not ty.unsolved_vars # `ty` must be closed!
assert len(subst) == 0
Expand Down
10 changes: 9 additions & 1 deletion guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
from guppylang.checker.errors.generic import UnsupportedError
from guppylang.checker.linearity_checker import contains_subscript
from guppylang.compiler.core import CompilerBase, DFContainer
from guppylang.compiler.hugr_extension import PartialOp
from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp
from guppylang.definition.custom import CustomFunctionDef
from guppylang.definition.value import CompiledCallableDef, CompiledValueDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
FieldAccessAndDrop,
GenericParamValue,
GlobalCall,
GlobalName,
InoutReturnSentinel,
Expand Down Expand Up @@ -196,6 +197,13 @@ def visit_GlobalName(self, node: GlobalName) -> Wire:
raise GuppyError(err)
return defn.load(self.dfg, self.globals, node)

def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
# TODO: We need a way to look up the concrete value of a generic type arg in
# Hugr. For example, a new op that captures the value during monomorphisation
return self.builder.add_op(
UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op
)

def visit_Name(self, node: ast.Name) -> Wire:
raise InternalGuppyError("Node should have been removed during type checking.")

Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ConstDef":
self.id,
self.name,
self.defined_at,
type_from_ast(self.type_ast, globals, None),
type_from_ast(self.type_ast, globals, {}),
self.type_ast,
self.value,
)
Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ExternDef":
self.id,
self.name,
self.defined_at,
type_from_ast(self.type_ast, globals, None),
type_from_ast(self.type_ast, globals, {}),
self.symbol,
self.constant,
self.type_ast,
Expand Down
11 changes: 5 additions & 6 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from guppylang.ast_util import AstNode, annotate_location, with_loc
from guppylang.checker.cfg_checker import CheckedCFG
from guppylang.checker.core import Context, Globals, Place, PyScope
from guppylang.checker.errors.generic import ExpectedError, UnsupportedError
from guppylang.checker.errors.generic import ExpectedError
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import (
check_global_func_def,
Expand Down Expand Up @@ -65,8 +65,6 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
"""Parses and checks the user-provided signature of the function."""
func_ast, docstring = parse_py_func(self.python_func, sources)
ty = check_signature(func_ast, globals.with_python_scope(self.python_scope))
if ty.parametrized:
raise GuppyError(UnsupportedError(func_ast, "Generic function definitions"))
return ParsedFunctionDef(
self.id, self.name, func_ast, ty, self.python_scope, docstring
)
Expand Down Expand Up @@ -160,9 +158,10 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledFunctionDe
access to the other compiled functions yet. The body is compiled later in
`CompiledFunctionDef.compile_inner()`.
"""
func_type = self.ty.to_hugr()
func_def = module.define_function(self.name, func_type.input)
func_def.declare_outputs(func_type.output)
func_type = self.ty.to_hugr_poly()
func_def = module.define_function(
self.name, func_type.body.input, func_type.body.output, func_type.params
)
return CompiledFunctionDef(
self.id,
self.name,
Expand Down
10 changes: 6 additions & 4 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ def check(self, globals: Globals) -> "CheckedStructDef":
# otherwise the code below would not terminate.
# TODO: This is not ideal (see todo in `check_instantiate`)
globals = globals.with_python_scope(self.python_scope)
check_not_recursive(self, globals)

param_var_mapping = {p.name: p for p in self.params}
check_not_recursive(self, globals, param_var_mapping)

fields = [
StructField(f.name, type_from_ast(f.type_ast, globals, param_var_mapping))
for f in self.fields
Expand Down Expand Up @@ -330,7 +330,9 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet
return params


def check_not_recursive(defn: ParsedStructDef, globals: Globals) -> None:
def check_not_recursive(
defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
) -> None:
"""Throws a user error if the given struct definition is recursive."""

# TODO: The implementation below hijacks the type parsing logic to detect recursive
Expand Down Expand Up @@ -359,4 +361,4 @@ def check_instantiate(
}
dummy_globals = replace(globals, defs=globals.defs | dummy_defs)
for field in defn.fields:
type_from_ast(field.type_ast, dummy_globals, {})
type_from_ast(field.type_ast, dummy_globals, param_var_mapping)
11 changes: 11 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from guppylang.checker.core import Place, Variable
from guppylang.definition.common import DefId
from guppylang.definition.struct import StructField
from guppylang.tys.param import ConstParam


class PlaceNode(ast.expr):
Expand All @@ -33,6 +34,16 @@ class GlobalName(ast.Name):
)


class GenericParamValue(ast.Name):
id: str
param: "ConstParam"

_fields = (
"id",
"param",
)


class LocalCall(ast.expr):
func: ast.expr
args: list[ast.expr]
Expand Down
6 changes: 4 additions & 2 deletions guppylang/std/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,10 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
assert isinstance(len_arg, ConstArg)
if not self._is_numeric_or_bool_type(ty_arg.ty):
raise GuppyError(err)
base_ty = ty_arg.ty
array_len = len_arg.const
_base_ty = ty_arg.ty
_array_len = len_arg.const
# See https://github.com/CQCL/guppylang/issues/631
raise GuppyError(UnsupportedError(value, "Array results"))
else:
raise GuppyError(err)
node = ResultExpr(value, base_ty, array_len, tag.value)
Expand Down
15 changes: 11 additions & 4 deletions guppylang/std/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ class NewArrayCompiler(ArrayCompiler):

def build_classical_array(self, elems: list[Wire]) -> Wire:
"""Lowers a call to `array.__new__` for classical arrays."""
return self.builder.add_op(array_new(self.elem_ty, len(elems)), *elems)
# See https://github.com/CQCL/guppylang/issues/629
return self.build_linear_array(elems)

def build_linear_array(self, elems: list[Wire]) -> Wire:
"""Lowers a call to `array.__new__` for linear arrays."""
Expand All @@ -121,9 +122,12 @@ class ArrayGetitemCompiler(ArrayCompiler):

def build_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
"""Lowers a call to `array.__getitem__` for classical arrays."""
# See https://github.com/CQCL/guppylang/issues/629
elem_opt_ty = ht.Option(self.elem_ty)
idx = self.builder.add_op(convert_itousize(), idx)
result = self.builder.add_op(array_get(self.elem_ty, self.length), array, idx)
elem = build_unwrap(self.builder, result, "Array index out of bounds")
result = self.builder.add_op(array_get(elem_opt_ty, self.length), array, idx)
elem_opt = build_unwrap(self.builder, result, "Array index out of bounds")
elem = build_unwrap(self.builder, elem_opt, "array.__getitem__: Internal error")
return CallReturnWires(regular_returns=[elem], inout_returns=[array])

def build_linear_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
Expand Down Expand Up @@ -163,9 +167,12 @@ def build_classical_setitem(
self, array: Wire, idx: Wire, elem: Wire
) -> CallReturnWires:
"""Lowers a call to `array.__setitem__` for classical arrays."""
# See https://github.com/CQCL/guppylang/issues/629
elem_opt_ty = ht.Option(self.elem_ty)
idx = self.builder.add_op(convert_itousize(), idx)
elem_opt = self.builder.add_op(ops.Tag(1, elem_opt_ty), elem)
result = self.builder.add_op(
array_set(self.elem_ty, self.length), array, idx, elem
array_set(elem_opt_ty, self.length), array, idx, elem_opt
)
# Unwrap the result, but we don't have to hold onto the returned old value
_, array = build_unwrap_right(self.builder, result, "Array index out of bounds")
Expand Down
5 changes: 2 additions & 3 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,8 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:

# Linear elements are turned into an optional to enable unsafe indexing.
# See `ArrayGetitemCompiler` for details.
elem_ty = (
ht.Option(ty_arg.ty.to_hugr()) if ty_arg.ty.linear else ty_arg.ty.to_hugr()
)
# Same also for classical arrays, see https://github.com/CQCL/guppylang/issues/629
elem_ty = ht.Option(ty_arg.ty.to_hugr())

array = hugr.std.PRELUDE.get_type("array")
return array.instantiate([len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)])
Expand Down
Loading
Loading