Skip to content

Commit

Permalink
feat: Only lower definitions to Hugr if they are used (#496)
Browse files Browse the repository at this point in the history
Closes  #434 and closes #470.

---------

Co-authored-by: Alan Lawrence <[email protected]>
Co-authored-by: Douglas Wilson <[email protected]>
  • Loading branch information
3 people authored Sep 17, 2024
1 parent c867f48 commit cc2c8a4
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 59 deletions.
41 changes: 38 additions & 3 deletions guppylang/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,52 @@
from typing import cast

from hugr import Wire, ops
from hugr.build.dfg import DP, DfBase
from hugr.build.dfg import DP, DefinitionBuilder, DfBase

from guppylang.checker.core import FieldAccess, Place, PlaceId, Variable
from guppylang.definition.common import CompiledDef, DefId
from guppylang.definition.common import CheckedDef, CompilableDef, CompiledDef, DefId
from guppylang.error import InternalGuppyError
from guppylang.tys.ty import StructType

CompiledGlobals = dict[DefId, CompiledDef]
CompiledLocals = dict[PlaceId, Wire]


class CompiledGlobals:
"""Compilation context containing all available definitions.
Maintains a `worklist` of definitions which have been used by other compiled code
(i.e. `compile_outer` has been called) but have not yet been compiled/lowered
themselves (i.e. `compile_inner` has not yet been called).
"""

module: DefinitionBuilder[ops.Module]
checked: dict[DefId, CheckedDef]
compiled: dict[DefId, CompiledDef]
worklist: set[DefId]

def __init__(
self,
checked: dict[DefId, CheckedDef],
module: DefinitionBuilder[ops.Module],
) -> None:
self.module = module
self.checked = checked
self.worklist = set()
self.compiled = {}

def __getitem__(self, def_id: DefId) -> CompiledDef:
if def_id not in self.compiled:
defn = self.checked[def_id]
self.compiled[def_id] = self._compile(defn)
self.worklist.add(def_id)
return self.compiled[def_id]

def _compile(self, defn: CheckedDef) -> CompiledDef:
if isinstance(defn, CompilableDef):
return defn.compile_outer(self.module)
return defn


@dataclass
class DFContainer:
"""A dataflow graph under construction.
Expand Down
31 changes: 15 additions & 16 deletions guppylang/compiler/func_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,25 @@ def compile_local_func_def(

call_args.append(partial)
func.cfg.input_tys.append(func.ty)

# Compile the CFG
cfg = compile_cfg(func.cfg, func_builder, call_args, globals)
func_builder.set_outputs(*cfg)
else:
# Otherwise, we treat the function like a normal global variable
from guppylang.definition.function import CompiledFunctionDef

globals = globals | {
func.def_id: CompiledFunctionDef(
func.def_id,
func.name,
func,
func.ty,
{},
None,
func.cfg,
func_builder,
)
}

# Compile the CFG
cfg = compile_cfg(func.cfg, func_builder, call_args, globals)
func_builder.set_outputs(*cfg)
globals.compiled[func.def_id] = CompiledFunctionDef(
func.def_id,
func.name,
func,
func.ty,
{},
None,
func.cfg,
func_builder,
)
globals.worklist.add(func.def_id) # will compile the CFG later

# Finally, load the function into the local data-flow graph
loaded = dfg.builder.load_function(func_builder, closure_ty)
Expand Down
32 changes: 10 additions & 22 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@

import guppylang.compiler.hugr_extension
from guppylang.checker.core import Globals, PyScope
from guppylang.compiler.core import CompiledGlobals
from guppylang.definition.common import (
CheckableDef,
CheckedDef,
CompilableDef,
CompiledDef,
DefId,
Definition,
ParsableDef,
Expand Down Expand Up @@ -267,18 +266,6 @@ def _check_defs(
for def_id, defn in parsed.items()
}

@staticmethod
def _compile_defs(
checked_defs: Mapping[DefId, CheckedDef], hugr_module: Module
) -> dict[DefId, CompiledDef]:
"""Helper method to compile checked definitions to Hugr."""
return {
def_id: defn.compile_outer(hugr_module)
if isinstance(defn, CompilableDef)
else defn
for def_id, defn in checked_defs.items()
}

def check(self) -> None:
"""Type-checks the module."""
if self.checked:
Expand Down Expand Up @@ -329,19 +316,20 @@ def compile(self) -> Package:
return self._compiled_hugr

self.check()
checked_defs = self._imported_checked_defs | self._checked_defs

# Prepare Hugr for this module
graph = Module()
graph.metadata["name"] = self.name

# Compile definitions to Hugr
compiled_defs = self._compile_defs(self._imported_checked_defs, graph)
compiled_defs |= self._compile_defs(self._checked_defs, graph)

# Finally, compile the definition contents to Hugr. For example, this compiles
# the bodies of functions.
for defn in compiled_defs.values():
defn.compile_inner(compiled_defs)
# Lower definitions to Hugr
required = set(self._checked_defs.keys())
ctx = CompiledGlobals(checked_defs, graph)
_request_compilation = [ctx[def_id] for def_id in required]
while ctx.worklist:
next_id = ctx.worklist.pop()
next_def = ctx[next_id]
next_def.compile_inner(ctx)

hugr = graph.hugr

Expand Down
6 changes: 1 addition & 5 deletions tests/integration/test_array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from hugr import ops
from hugr.std.int import IntVal

Expand All @@ -22,10 +21,7 @@ def main(xs: array[float, 42]) -> int:
validate(package)

hg = package.modules[0]
vals = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(vals) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[val] = vals
[val] = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)]
assert isinstance(val, ops.Const)
assert isinstance(val.val, IntVal)
assert val.val.v == 42
Expand Down
6 changes: 1 addition & 5 deletions tests/integration/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from hugr import ops

from guppylang.decorator import guppy
Expand Down Expand Up @@ -69,14 +68,11 @@ def test_func_def_name():
def func_name() -> None:
return

defs = [
[def_op] = [
data.op
for n, data in func_name.modules[0].nodes()
if isinstance(data.op, ops.FuncDefn)
]
if len(defs) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[def_op] = defs
assert isinstance(def_op, ops.FuncDefn)
assert def_op.f_name == "func_name"

Expand Down
10 changes: 2 additions & 8 deletions tests/integration/test_extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ def main() -> float:
validate(package)

hg = package.modules[0]
consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(consts) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[c] = consts
[c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
assert isinstance(c.val, val.Extension)
assert c.val.val["symbol"] == "ext"

Expand All @@ -39,10 +36,7 @@ def main() -> int:
validate(package)

hg = package.modules[0]
consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(consts) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[c] = consts
[c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
assert isinstance(c.val, val.Extension)
assert c.val.val["symbol"] == "foo"

Expand Down

0 comments on commit cc2c8a4

Please sign in to comment.