Skip to content

Commit

Permalink
feat: Only lower definitions to Hugr if they are used
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Sep 13, 2024
1 parent f9aaaa9 commit b65f837
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 45 deletions.
44 changes: 41 additions & 3 deletions guppylang/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,55 @@
from typing import cast

from hugr import Wire, ops
from hugr.build.dfg import DP, DfBase
from hugr.build.dfg import DP, DefinitionBuilder, DfBase

from guppylang.checker.core import FieldAccess, Place, PlaceId, Variable
from guppylang.definition.common import CompiledDef, DefId
from guppylang.definition.common import CheckedDef, CompilableDef, CompiledDef, DefId
from guppylang.error import InternalGuppyError
from guppylang.tys.ty import StructType

CompiledGlobals = dict[DefId, CompiledDef]
CompiledLocals = dict[PlaceId, Wire]


class CompiledGlobals:
"""Compilation context containing all available definitions.
This context drives the Hugr lowering by keeping track of which definitions are
used when lowering the definitions provided via the `required` keyword. The
`worklist` field contains all definitions whose contents still need to lowered.
"""

module: DefinitionBuilder[ops.Module]
checked: dict[DefId, CheckedDef]
compiled: dict[DefId, CompiledDef]
worklist: set[DefId]

def __init__(
self,
checked: dict[DefId, CheckedDef],
required: set[DefId],
module: DefinitionBuilder[ops.Module],
) -> None:
self.module = module
self.checked = checked
self.worklist = required
self.compiled = {}
for def_id in required:
self.compiled[def_id] = self._compile(checked[def_id])

def __getitem__(self, def_id: DefId) -> CompiledDef:
if def_id not in self.compiled:
defn = self.checked[def_id]
self.compiled[def_id] = self._compile(defn)
self.worklist.add(def_id)
return self.compiled[def_id]

def _compile(self, defn: CheckedDef) -> CompiledDef:
if isinstance(defn, CompilableDef):
return defn.compile_outer(self.module)
return defn


@dataclass
class DFContainer:
"""A dataflow graph under construction.
Expand Down
31 changes: 15 additions & 16 deletions guppylang/compiler/func_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,25 @@ def compile_local_func_def(

call_args.append(partial)
func.cfg.input_tys.append(func.ty)

# Compile the CFG
cfg = compile_cfg(func.cfg, func_builder, call_args, globals)
func_builder.set_outputs(*cfg)
else:
# Otherwise, we treat the function like a normal global variable
from guppylang.definition.function import CompiledFunctionDef

globals = globals | {
func.def_id: CompiledFunctionDef(
func.def_id,
func.name,
func,
func.ty,
{},
None,
func.cfg,
func_builder,
)
}

# Compile the CFG
cfg = compile_cfg(func.cfg, func_builder, call_args, globals)
func_builder.set_outputs(*cfg)
globals.compiled[func.def_id] = CompiledFunctionDef(
func.def_id,
func.name,
func,
func.ty,
{},
None,
func.cfg,
func_builder,
)
globals.worklist.add(func.def_id)

# Finally, load the function into the local data-flow graph
loaded = dfg.builder.load_function(func_builder, closure_ty)
Expand Down
5 changes: 5 additions & 0 deletions guppylang/definition/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, ClassVar, TypeAlias

from hugr.build.dfg import DefinitionBuilder, OpVar
from hugr.ext import Package

if TYPE_CHECKING:
from guppylang.checker.core import Globals
Expand Down Expand Up @@ -72,6 +73,10 @@ def description(self) -> str:
a function, but got {description of this definition} instead".
"""

def compile(self) -> Package:
assert self.id.module is not None
return self.id.module.compile()


class ParsableDef(Definition):
"""Abstract base class for raw definitions that still require parsing.
Expand Down
17 changes: 9 additions & 8 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import guppylang.compiler.hugr_extension
from guppylang.checker.core import Globals, PyScope
from guppylang.compiler.core import CompiledGlobals
from guppylang.definition.common import (
CheckableDef,
CheckedDef,
Expand Down Expand Up @@ -325,19 +326,19 @@ def compile(self) -> Package:
return self._compiled_hugr

self.check()
checked_defs = self._imported_checked_defs | self._checked_defs

# Prepare Hugr for this module
graph = Module()
graph.metadata["name"] = self.name

# Compile definitions to Hugr
compiled_defs = self._compile_defs(self._imported_checked_defs, graph)
compiled_defs |= self._compile_defs(self._checked_defs, graph)

# Finally, compile the definition contents to Hugr. For example, this compiles
# the bodies of functions.
for defn in compiled_defs.values():
defn.compile_inner(compiled_defs)
# Lower definitions to Hugr
required = set(self._checked_defs.keys())
ctx = CompiledGlobals(checked_defs, required, graph)
while ctx.worklist:
next_id = ctx.worklist.pop()
next_def = ctx[next_id]
next_def.compile_inner(ctx)

hugr = graph.hugr

Expand Down
6 changes: 1 addition & 5 deletions tests/integration/test_array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from hugr import ops
from hugr.std.int import IntVal

Expand All @@ -22,10 +21,7 @@ def main(xs: array[float, 42]) -> int:
validate(package)

hg = package.modules[0]
vals = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(vals) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[val] = vals
[val] = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)]
assert isinstance(val, ops.Const)
assert isinstance(val.val, IntVal)
assert val.val.v == 42
Expand Down
6 changes: 1 addition & 5 deletions tests/integration/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from hugr import ops

from guppylang.decorator import guppy
Expand Down Expand Up @@ -69,14 +68,11 @@ def test_func_def_name():
def func_name() -> None:
return

defs = [
[def_op] = [
data.op
for n, data in func_name.modules[0].nodes()
if isinstance(data.op, ops.FuncDefn)
]
if len(defs) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[def_op] = defs
assert isinstance(def_op, ops.FuncDefn)
assert def_op.f_name == "func_name"

Expand Down
10 changes: 2 additions & 8 deletions tests/integration/test_extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ def main() -> float:
validate(package)

hg = package.modules[0]
consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(consts) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[c] = consts
[c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
assert isinstance(c.val, val.Extension)
assert c.val.val["symbol"] == "ext"

Expand All @@ -39,10 +36,7 @@ def main() -> int:
validate(package)

hg = package.modules[0]
consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(consts) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[c] = consts
[c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
assert isinstance(c.val, val.Extension)
assert c.val.val["symbol"] == "foo"

Expand Down

0 comments on commit b65f837

Please sign in to comment.