diff --git a/guppylang/compiler/core.py b/guppylang/compiler/core.py index d5f96e57..219479bb 100644 --- a/guppylang/compiler/core.py +++ b/guppylang/compiler/core.py @@ -3,17 +3,55 @@ from typing import cast from hugr import Wire, ops -from hugr.build.dfg import DP, DfBase +from hugr.build.dfg import DP, DefinitionBuilder, DfBase from guppylang.checker.core import FieldAccess, Place, PlaceId, Variable -from guppylang.definition.common import CompiledDef, DefId +from guppylang.definition.common import CheckedDef, CompilableDef, CompiledDef, DefId from guppylang.error import InternalGuppyError from guppylang.tys.ty import StructType -CompiledGlobals = dict[DefId, CompiledDef] CompiledLocals = dict[PlaceId, Wire] +class CompiledGlobals: + """Compilation context containing all available definitions. + + This context drives the Hugr lowering by keeping track of which definitions are + used when lowering the definitions provided via the `required` keyword. The + `worklist` field contains all definitions whose contents still need to lowered. + """ + + module: DefinitionBuilder[ops.Module] + checked: dict[DefId, CheckedDef] + compiled: dict[DefId, CompiledDef] + worklist: set[DefId] + + def __init__( + self, + checked: dict[DefId, CheckedDef], + required: set[DefId], + module: DefinitionBuilder[ops.Module], + ) -> None: + self.module = module + self.checked = checked + self.worklist = required + self.compiled = {} + for def_id in required: + self.compiled[def_id] = self._compile(checked[def_id]) + + def __getitem__(self, def_id: DefId) -> CompiledDef: + if def_id not in self.compiled: + defn = self.checked[def_id] + self.compiled[def_id] = self._compile(defn) + self.worklist.add(def_id) + return self.compiled[def_id] + + def _compile(self, defn: CheckedDef) -> CompiledDef: + if isinstance(defn, CompilableDef): + return defn.compile_outer(self.module) + return defn + + @dataclass class DFContainer: """A dataflow graph under construction. diff --git a/guppylang/compiler/func_compiler.py b/guppylang/compiler/func_compiler.py index faa349e0..1e6671fa 100644 --- a/guppylang/compiler/func_compiler.py +++ b/guppylang/compiler/func_compiler.py @@ -62,26 +62,25 @@ def compile_local_func_def( call_args.append(partial) func.cfg.input_tys.append(func.ty) + + # Compile the CFG + cfg = compile_cfg(func.cfg, func_builder, call_args, globals) + func_builder.set_outputs(*cfg) else: # Otherwise, we treat the function like a normal global variable from guppylang.definition.function import CompiledFunctionDef - globals = globals | { - func.def_id: CompiledFunctionDef( - func.def_id, - func.name, - func, - func.ty, - {}, - None, - func.cfg, - func_builder, - ) - } - - # Compile the CFG - cfg = compile_cfg(func.cfg, func_builder, call_args, globals) - func_builder.set_outputs(*cfg) + globals.compiled[func.def_id] = CompiledFunctionDef( + func.def_id, + func.name, + func, + func.ty, + {}, + None, + func.cfg, + func_builder, + ) + globals.worklist.add(func.def_id) # Finally, load the function into the local data-flow graph loaded = dfg.builder.load_function(func_builder, closure_ty) diff --git a/guppylang/definition/common.py b/guppylang/definition/common.py index 14c11880..73e8278b 100644 --- a/guppylang/definition/common.py +++ b/guppylang/definition/common.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, ClassVar, TypeAlias from hugr.build.dfg import DefinitionBuilder, OpVar +from hugr.ext import Package if TYPE_CHECKING: from guppylang.checker.core import Globals @@ -72,6 +73,10 @@ def description(self) -> str: a function, but got {description of this definition} instead". """ + def compile(self) -> Package: + assert self.id.module is not None + return self.id.module.compile() + class ParsableDef(Definition): """Abstract base class for raw definitions that still require parsing. diff --git a/guppylang/module.py b/guppylang/module.py index beb5318e..1c8bab67 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -11,6 +11,7 @@ import guppylang.compiler.hugr_extension from guppylang.checker.core import Globals, PyScope +from guppylang.compiler.core import CompiledGlobals from guppylang.definition.common import ( CheckableDef, CheckedDef, @@ -325,19 +326,19 @@ def compile(self) -> Package: return self._compiled_hugr self.check() + checked_defs = self._imported_checked_defs | self._checked_defs # Prepare Hugr for this module graph = Module() graph.metadata["name"] = self.name - # Compile definitions to Hugr - compiled_defs = self._compile_defs(self._imported_checked_defs, graph) - compiled_defs |= self._compile_defs(self._checked_defs, graph) - - # Finally, compile the definition contents to Hugr. For example, this compiles - # the bodies of functions. - for defn in compiled_defs.values(): - defn.compile_inner(compiled_defs) + # Lower definitions to Hugr + required = set(self._checked_defs.keys()) + ctx = CompiledGlobals(checked_defs, required, graph) + while ctx.worklist: + next_id = ctx.worklist.pop() + next_def = ctx[next_id] + next_def.compile_inner(ctx) hugr = graph.hugr diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 685c7db1..38ca00d3 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -1,4 +1,3 @@ -import pytest from hugr import ops from hugr.std.int import IntVal @@ -22,10 +21,7 @@ def main(xs: array[float, 42]) -> int: validate(package) hg = package.modules[0] - vals = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)] - if len(vals) > 1: - pytest.xfail(reason="hugr-includes-whole-stdlib") - [val] = vals + [val] = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)] assert isinstance(val, ops.Const) assert isinstance(val.val, IntVal) assert val.val.v == 42 diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index 9fe3b11d..6a7108d0 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -1,4 +1,3 @@ -import pytest from hugr import ops from guppylang.decorator import guppy @@ -69,14 +68,11 @@ def test_func_def_name(): def func_name() -> None: return - defs = [ + [def_op] = [ data.op for n, data in func_name.modules[0].nodes() if isinstance(data.op, ops.FuncDefn) ] - if len(defs) > 1: - pytest.xfail(reason="hugr-includes-whole-stdlib") - [def_op] = defs assert isinstance(def_op, ops.FuncDefn) assert def_op.f_name == "func_name" diff --git a/tests/integration/test_extern.py b/tests/integration/test_extern.py index 83379546..e8e56e30 100644 --- a/tests/integration/test_extern.py +++ b/tests/integration/test_extern.py @@ -18,10 +18,7 @@ def main() -> float: validate(package) hg = package.modules[0] - consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] - if len(consts) > 1: - pytest.xfail(reason="hugr-includes-whole-stdlib") - [c] = consts + [c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] assert isinstance(c.val, val.Extension) assert c.val.val["symbol"] == "ext" @@ -39,10 +36,7 @@ def main() -> int: validate(package) hg = package.modules[0] - consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] - if len(consts) > 1: - pytest.xfail(reason="hugr-includes-whole-stdlib") - [c] = consts + [c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] assert isinstance(c.val, val.Extension) assert c.val.val["symbol"] == "foo"