Skip to content

Commit

Permalink
feat: Generic function definitions (#618)
Browse files Browse the repository at this point in the history
Allow function definitions that are generic over parameters of kind
`type` or `nat`.

* Store currently available type parameters in the type checking context
* Refactor type parsing logic to explicitly opt in to add free variables
to the parameter mapping
* Add a new `GenericParamValue` AST node to represent usages of generic
nat params `n` inside function bodies. We need a way to encode this in
Hugr. For now, we just emit a dummy node. See
CQCL/hugr#1629

Note: Nested generic functions are not supported yet, so we don't have
to worry about scoping of type params.

Closes #522
  • Loading branch information
mark-koch authored Nov 14, 2024
1 parent f527703 commit 7519b90
Show file tree
Hide file tree
Showing 19 changed files with 293 additions and 61 deletions.
17 changes: 13 additions & 4 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from guppylang.checker.stmt_checker import StmtChecker
from guppylang.diagnostic import Error, Note
from guppylang.error import GuppyError
from guppylang.tys.param import Parameter
from guppylang.tys.ty import InputFlags, Type

Row = Sequence[V]
Expand Down Expand Up @@ -60,7 +61,12 @@ def __init__(self, input_tys: list[Type], output_ty: Type) -> None:


def check_cfg(
cfg: CFG, inputs: Row[Variable], return_ty: Type, func_name: str, globals: Globals
cfg: CFG,
inputs: Row[Variable],
return_ty: Type,
generic_params: dict[str, Parameter],
func_name: str,
globals: Globals,
) -> CheckedCFG[Place]:
"""Type checks a control-flow graph.
Expand All @@ -76,7 +82,7 @@ def check_cfg(
# We start by compiling the entry BB
checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty)
checked_cfg.entry_bb = check_bb(
cfg.entry_bb, checked_cfg, inputs, return_ty, globals
cfg.entry_bb, checked_cfg, inputs, return_ty, generic_params, globals
)
compiled = {cfg.entry_bb: checked_cfg.entry_bb}

Expand All @@ -102,7 +108,9 @@ def check_cfg(
check_rows_match(input_row, compiled[bb].sig.input_row, bb)
else:
# Otherwise, check the BB and enqueue its successors
checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals)
checked_bb = check_bb(
bb, checked_cfg, input_row, return_ty, generic_params, globals
)
queue += [
# We enumerate the successor starting from the back, so we start with
# the `True` branch. This way, we find errors in a more natural order
Expand Down Expand Up @@ -174,6 +182,7 @@ def check_bb(
checked_cfg: CheckedCFG[Variable],
inputs: Row[Variable],
return_ty: Type,
generic_params: dict[str, Parameter],
globals: Globals,
) -> CheckedBB[Variable]:
cfg = bb.containing_cfg
Expand All @@ -187,7 +196,7 @@ def check_bb(
raise GuppyError(VarNotDefinedError(use, x))

# Check the basic block
ctx = Context(globals, Locals({v.name: v for v in inputs}))
ctx = Context(globals, Locals({v.name: v for v in inputs}), generic_params)
checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements)

# If we branch, we also have to check the branch predicate
Expand Down
2 changes: 2 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
sized_iter_type_def,
tuple_type_def,
)
from guppylang.tys.param import Parameter
from guppylang.tys.ty import (
BoundTypeVar,
ExistentialTypeVar,
Expand Down Expand Up @@ -381,6 +382,7 @@ class Context(NamedTuple):

globals: Globals
locals: Locals[str, Variable]
generic_params: dict[str, Parameter]


class DummyEvalDict(PyScope):
Expand Down
19 changes: 17 additions & 2 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from dataclasses import replace
from typing import Any, NoReturn, cast

from typing_extensions import assert_never

from guppylang.ast_util import (
AstNode,
AstVisitor,
Expand Down Expand Up @@ -85,6 +87,7 @@
DesugaredGenerator,
DesugaredListComp,
FieldAccessAndDrop,
GenericParamValue,
GlobalName,
InoutReturnSentinel,
IterEnd,
Expand All @@ -108,7 +111,7 @@
is_list_type,
list_type,
)
from guppylang.tys.param import TypeParam
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
ExistentialTypeVar,
Expand Down Expand Up @@ -368,6 +371,18 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
if x in self.ctx.locals:
var = self.ctx.locals[x]
return with_loc(node, PlaceNode(place=var)), var.ty
elif x in self.ctx.generic_params:
param = self.ctx.generic_params[x]
match param:
case ConstParam() as param:
ast_node = with_loc(node, GenericParamValue(id=x, param=param))
return ast_node, param.ty
case TypeParam() as param:
raise GuppyError(
ExpectedError(node, "a value", got=f"type `{param.name}`")
)
case _:
return assert_never(param)
elif x in self.ctx.globals:
defn = self.ctx.globals[x]
return self._check_global(defn, x, node)
Expand Down Expand Up @@ -1031,7 +1046,7 @@ def synthesize_comprehension(
# The rest is checked in a new nested context to ensure that variables don't escape
# their scope
inner_locals: Locals[str, Variable] = Locals({}, parent_scope=ctx.locals)
inner_ctx = Context(ctx.globals, inner_locals)
inner_ctx = Context(ctx.globals, inner_locals, ctx.generic_params)
expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx)
gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign)
gen.next_assign = stmt_chk.visit_Assign(gen.next_assign)
Expand Down
14 changes: 11 additions & 3 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def check_global_func_def(
Variable(x, inp.ty, loc, inp.flags)
for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True)
]
return check_cfg(cfg, inputs, ty.output, func_def.name, globals)
generic_params = {
param.name: param.with_idx(i) for i, param in enumerate(ty.params)
}
return check_cfg(cfg, inputs, ty.output, generic_params, func_def.name, globals)


def check_nested_func_def(
Expand All @@ -84,6 +87,11 @@ def check_nested_func_def(
func_ty = check_signature(func_def, ctx.globals)
assert func_ty.input_names is not None

if func_ty.parametrized:
raise GuppyError(
UnsupportedError(func_def, "Nested generic function definitions")
)

# We've already built the CFG for this function while building the CFG of the
# enclosing function
cfg = func_def.cfg
Expand Down Expand Up @@ -137,7 +145,7 @@ def check_nested_func_def(
# Otherwise, we treat it like a local name
inputs.append(Variable(func_def.name, func_def.ty, func_def))

checked_cfg = check_cfg(cfg, inputs, func_ty.output, func_def.name, globals)
checked_cfg = check_cfg(cfg, inputs, func_ty.output, {}, func_def.name, globals)
checked_def = CheckedNestedFunctionDef(
def_id,
checked_cfg,
Expand Down Expand Up @@ -188,7 +196,7 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
input_nodes.append(ty_ast)
input_names.append(inp.arg)
inputs, output = parse_function_io_types(
input_nodes, func_def.returns, func_def, globals, param_var_mapping
input_nodes, func_def.returns, func_def, globals, param_var_mapping, True
)
return FunctionType(
inputs,
Expand Down
2 changes: 1 addition & 1 deletion guppylang/checker/stmt_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign:
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
if node.value is None:
raise GuppyError(UnsupportedError(node, "Variable declarations"))
ty = type_from_ast(node.annotation, self.ctx.globals)
ty = type_from_ast(node.annotation, self.ctx.globals, self.ctx.generic_params)
node.value, subst = self._check_expr(node.value, ty)
assert not ty.unsolved_vars # `ty` must be closed!
assert len(subst) == 0
Expand Down
10 changes: 9 additions & 1 deletion guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
from guppylang.checker.errors.generic import UnsupportedError
from guppylang.checker.linearity_checker import contains_subscript
from guppylang.compiler.core import CompilerBase, DFContainer
from guppylang.compiler.hugr_extension import PartialOp
from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp
from guppylang.definition.custom import CustomFunctionDef
from guppylang.definition.value import CompiledCallableDef, CompiledValueDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
FieldAccessAndDrop,
GenericParamValue,
GlobalCall,
GlobalName,
InoutReturnSentinel,
Expand Down Expand Up @@ -196,6 +197,13 @@ def visit_GlobalName(self, node: GlobalName) -> Wire:
raise GuppyError(err)
return defn.load(self.dfg, self.globals, node)

def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
# TODO: We need a way to look up the concrete value of a generic type arg in
# Hugr. For example, a new op that captures the value during monomorphisation
return self.builder.add_op(
UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op
)

def visit_Name(self, node: ast.Name) -> Wire:
raise InternalGuppyError("Node should have been removed during type checking.")

Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ConstDef":
self.id,
self.name,
self.defined_at,
type_from_ast(self.type_ast, globals, None),
type_from_ast(self.type_ast, globals, {}),
self.type_ast,
self.value,
)
Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ExternDef":
self.id,
self.name,
self.defined_at,
type_from_ast(self.type_ast, globals, None),
type_from_ast(self.type_ast, globals, {}),
self.symbol,
self.constant,
self.type_ast,
Expand Down
11 changes: 5 additions & 6 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from guppylang.ast_util import AstNode, annotate_location, with_loc
from guppylang.checker.cfg_checker import CheckedCFG
from guppylang.checker.core import Context, Globals, Place, PyScope
from guppylang.checker.errors.generic import ExpectedError, UnsupportedError
from guppylang.checker.errors.generic import ExpectedError
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import (
check_global_func_def,
Expand Down Expand Up @@ -65,8 +65,6 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
"""Parses and checks the user-provided signature of the function."""
func_ast, docstring = parse_py_func(self.python_func, sources)
ty = check_signature(func_ast, globals.with_python_scope(self.python_scope))
if ty.parametrized:
raise GuppyError(UnsupportedError(func_ast, "Generic function definitions"))
return ParsedFunctionDef(
self.id, self.name, func_ast, ty, self.python_scope, docstring
)
Expand Down Expand Up @@ -160,9 +158,10 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledFunctionDe
access to the other compiled functions yet. The body is compiled later in
`CompiledFunctionDef.compile_inner()`.
"""
func_type = self.ty.to_hugr()
func_def = module.define_function(self.name, func_type.input)
func_def.declare_outputs(func_type.output)
func_type = self.ty.to_hugr_poly()
func_def = module.define_function(
self.name, func_type.body.input, func_type.body.output, func_type.params
)
return CompiledFunctionDef(
self.id,
self.name,
Expand Down
10 changes: 6 additions & 4 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ def check(self, globals: Globals) -> "CheckedStructDef":
# otherwise the code below would not terminate.
# TODO: This is not ideal (see todo in `check_instantiate`)
globals = globals.with_python_scope(self.python_scope)
check_not_recursive(self, globals)

param_var_mapping = {p.name: p for p in self.params}
check_not_recursive(self, globals, param_var_mapping)

fields = [
StructField(f.name, type_from_ast(f.type_ast, globals, param_var_mapping))
for f in self.fields
Expand Down Expand Up @@ -330,7 +330,9 @@ def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Paramet
return params


def check_not_recursive(defn: ParsedStructDef, globals: Globals) -> None:
def check_not_recursive(
defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
) -> None:
"""Throws a user error if the given struct definition is recursive."""

# TODO: The implementation below hijacks the type parsing logic to detect recursive
Expand Down Expand Up @@ -359,4 +361,4 @@ def check_instantiate(
}
dummy_globals = replace(globals, defs=globals.defs | dummy_defs)
for field in defn.fields:
type_from_ast(field.type_ast, dummy_globals, {})
type_from_ast(field.type_ast, dummy_globals, param_var_mapping)
11 changes: 11 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from guppylang.checker.core import Place, Variable
from guppylang.definition.common import DefId
from guppylang.definition.struct import StructField
from guppylang.tys.param import ConstParam


class PlaceNode(ast.expr):
Expand All @@ -33,6 +34,16 @@ class GlobalName(ast.Name):
)


class GenericParamValue(ast.Name):
id: str
param: "ConstParam"

_fields = (
"id",
"param",
)


class LocalCall(ast.expr):
func: ast.expr
args: list[ast.expr]
Expand Down
Loading

0 comments on commit 7519b90

Please sign in to comment.