Skip to content

Commit

Permalink
fix: Fix array execution bugs (#731)
Browse files Browse the repository at this point in the history
Bumps to Hugr on main and fixes some bugs to make arrays executable.
  • Loading branch information
mark-koch authored Dec 18, 2024
1 parent d0c2da4 commit 0f6ceaa
Show file tree
Hide file tree
Showing 13 changed files with 175 additions and 90 deletions.
18 changes: 5 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ inkwell = "0.5.0"
[patch.crates-io]

# Uncomment these to test the latest dependency version during development
# hugr = { git = "https://github.com/CQCL/hugr", rev = "861183e" }
# hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "861183e" }
# hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "1091755" }
hugr = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
12 changes: 10 additions & 2 deletions execute_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ fn find_funcdef_node(hugr: impl HugrView, fn_name: &str) -> PyResult<hugr::Node>
}
}

fn guppy_pass(hugr: Hugr) -> Hugr {
let hugr = hugr::algorithms::monomorphize(hugr);
hugr::algorithms::remove_polyfuncs(hugr)
}

fn compile_module<'a>(
hugr: &'a hugr::Hugr,
ctx: &'a Context,
Expand All @@ -47,6 +52,7 @@ fn compile_module<'a>(
// TODO: Handle tket2 codegen extension
let extensions = hugr::llvm::custom::CodegenExtsBuilder::default()
.add_int_extensions()
.add_logic_extensions()
.add_default_prelude_extensions()
.add_default_array_extensions()
.add_float_extensions()
Expand All @@ -64,9 +70,10 @@ fn compile_module<'a>(

#[pyfunction]
fn compile_module_to_string(hugr_json: &str) -> PyResult<String> {
let hugr = parse_hugr(hugr_json)?;
let mut hugr = parse_hugr(hugr_json)?;
let ctx = Context::create();

hugr = guppy_pass(hugr);
let module = compile_module(&hugr, &ctx, Default::default())?;

Ok(module.print_to_string().to_str().unwrap().to_string())
Expand All @@ -77,7 +84,8 @@ fn run_function<T>(
fn_name: &str,
parse_result: impl FnOnce(&Context, GenericValue) -> PyResult<T>,
) -> PyResult<T> {
let hugr = parse_hugr(hugr_json)?;
let mut hugr = parse_hugr(hugr_json)?;
hugr = guppy_pass(hugr);
let ctx = Context::create();

let namer = hugr::llvm::emit::Namer::default();
Expand Down
57 changes: 29 additions & 28 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, TypeGuard, TypeVar

import hugr
import hugr.std.collections.array
import hugr.std.float
import hugr.std.int
import hugr.std.logic
Expand All @@ -21,7 +22,7 @@
from guppylang.checker.errors.generic import UnsupportedError
from guppylang.checker.linearity_checker import contains_subscript
from guppylang.compiler.core import CompilerBase, DFContainer
from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp
from guppylang.compiler.hugr_extension import PartialOp
from guppylang.definition.custom import CustomFunctionDef
from guppylang.definition.value import (
CallReturnWires,
Expand All @@ -46,6 +47,7 @@
TensorCall,
TypeApply,
)
from guppylang.std._internal.compiler.arithmetic import convert_ifromusize
from guppylang.std._internal.compiler.array import array_repeat
from guppylang.std._internal.compiler.list import (
list_new,
Expand Down Expand Up @@ -123,7 +125,7 @@ def _new_dfcontainer(
def _new_loop(
self,
loop_vars: list[PlaceNode],
branch: PlaceNode,
continue_predicate: PlaceNode,
) -> Iterator[None]:
"""Context manager to build a graph inside a new `TailLoop` node.
Expand All @@ -134,13 +136,12 @@ def _new_loop(
loop = self.builder.add_tail_loop([], loop_inputs)
with self._new_dfcontainer(loop_vars, loop):
yield
# Output the branch predicate and the inputs for the next iteration
loop.set_loop_outputs(
# Note that we have to do fresh calls to `self.visit` here since we're
# in a new context
self.visit(branch),
*(self.visit(name) for name in loop_vars),
)
# Output the branch predicate and the inputs for the next iteration. Note
# that we have to do fresh calls to `self.visit` here since we're in a new
# context
do_continue = self.visit(continue_predicate)
do_break = loop.add_op(hugr.std.logic.Not, do_continue)
loop.set_loop_outputs(do_break, *(self.visit(name) for name in loop_vars))
# Update the DFG with the outputs from the loop
for node, wire in zip(loop_vars, loop, strict=True):
self.dfg[node.place] = wire
Expand Down Expand Up @@ -172,12 +173,12 @@ def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]:
conditional = self.builder.add_conditional(
self.visit(cond), *(self.visit(inp) for inp in inputs)
)
# If the condition is true, we enter the `with` block
with self._new_case(inputs, inputs, conditional, 0):
yield
# If the condition is false, output the inputs as is
with self._new_case(inputs, inputs, conditional, 1):
with self._new_case(inputs, inputs, conditional, 0):
pass
# If the condition is true, we enter the `with` block
with self._new_case(inputs, inputs, conditional, 1):
yield
# Update the DFG with the outputs from the Conditional node
for node, wire in zip(inputs, conditional, strict=True):
self.dfg[node.place] = wire
Expand Down Expand Up @@ -206,11 +207,16 @@ def visit_GlobalName(self, node: GlobalName) -> Wire:
return defn.load(self.dfg, self.globals, node)

def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
# TODO: We need a way to look up the concrete value of a generic type arg in
# Hugr. For example, a new op that captures the value during monomorphisation
return self.builder.add_op(
UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op
)
match node.param.ty:
case NumericType(NumericType.Kind.Nat):
arg = node.param.to_bound().to_hugr()
load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate(
[arg], ht.FunctionType([], [ht.USize()])
)
usize = self.builder.add_op(load_nat)
return self.builder.add_op(convert_ifromusize(), usize)
case _:
raise NotImplementedError

def visit_Name(self, node: ast.Name) -> Wire:
raise InternalGuppyError("Node should have been removed during type checking.")
Expand Down Expand Up @@ -604,17 +610,12 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None:
return hv.Tuple(*vs)
case list(elts):
assert is_array_type(exp_ty)
vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts]
elem_ty = get_element_type(exp_ty)
vs = [python_value_to_hugr(elt, elem_ty) for elt in elts]
if doesnt_contain_none(vs):
# TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1497
return hv.Extension(
name="ArrayValue",
typ=exp_ty.to_hugr(),
# The value list must be serialized at this point, otherwise the
# `Extension` value would not be serializable.
val=[v._to_serial_root() for v in vs],
extensions=["unsupported"],
)
opt_ty = ht.Option(elem_ty.to_hugr())
opt_vs: list[hv.Value] = [hv.Some(v) for v in vs]
return hugr.std.collections.array.ArrayVal(opt_vs, opt_ty)
case _:
# TODO replace with hugr protocol handling: https://github.com/CQCL/guppylang/issues/563
# Pytket conversion is an experimental feature
Expand Down
19 changes: 15 additions & 4 deletions guppylang/compiler/stmt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,26 @@ def pop(
array: Wire, length: int, pats: list[ast.expr], from_left: bool
) -> tuple[Wire, int]:
err = "Internal error: unpacking of iterable failed"
for pat in pats:
num_pats = len(pats)
# Pop the number of requested elements from the array
elts = []
for i in range(num_pats):
res = self.builder.add_op(
array_pop(opt_elt_ty, length, from_left), array
array_pop(opt_elt_ty, length - i, from_left), array
)
[elt_opt, array] = build_unwrap(self.builder, res, err)
[elt] = build_unwrap(self.builder, elt_opt, err)
elts.append(elt)
# Assign elements to the given patterns
for pat, elt in zip(
pats,
# Assignments are evaluated from left to right, so we need to assign in
# reverse order if we popped from the right
elts if from_left else reversed(elts),
strict=True,
):
self._assign(pat, elt)
length -= 1
return array, length
return array, length - num_pats

self.dfg[lhs.rhs_var.place] = port
array = self.expr_compiler.visit_DesugaredArrayComp(lhs.compr)
Expand Down
43 changes: 30 additions & 13 deletions guppylang/std/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from hugr import tys as ht
from hugr.std.collections.array import EXTENSION

from guppylang.compiler.hugr_extension import UnsupportedOp
from guppylang.definition.custom import CustomCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError
Expand Down Expand Up @@ -92,24 +91,42 @@ def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp:
)


def array_scan(
elem_ty: ht.Type,
length: ht.TypeArg,
new_elem_ty: ht.Type,
accumulators: list[ht.Type],
) -> ops.ExtOp:
"""Returns an operation that maps and folds a function across an array."""
ty_args = [
length,
ht.TypeTypeArg(elem_ty),
ht.TypeTypeArg(new_elem_ty),
ht.SequenceArg([ht.TypeTypeArg(acc) for acc in accumulators]),
ht.ExtensionsArg([]),
]
ins = [
array_type(elem_ty, length),
ht.FunctionType([elem_ty, *accumulators], [new_elem_ty, *accumulators]),
*accumulators,
]
outs = [array_type(new_elem_ty, length), *accumulators]
return EXTENSION.get_op("scan").instantiate(ty_args, ht.FunctionType(ins, outs))


def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp:
"""Returns an operation that maps a function across an array."""
# TODO
return UnsupportedOp(
op_name="array_map",
inputs=[array_type(elem_ty, length), ht.FunctionType([elem_ty], [new_elem_ty])],
outputs=[array_type(new_elem_ty, length)],
).ext_op
return array_scan(elem_ty, length, new_elem_ty, accumulators=[])


def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
"""Returns an array `repeat` operation."""
# TODO
return UnsupportedOp(
op_name="array.repeat",
inputs=[ht.FunctionType([], [elem_ty])],
outputs=[array_type(elem_ty, length)],
).ext_op
return EXTENSION.get_op("repeat").instantiate(
[length, ht.TypeTypeArg(elem_ty), ht.ExtensionsArg([])],
ht.FunctionType(
[ht.FunctionType([], [elem_ty])], [array_type(elem_ty, length)]
),
)


# ------------------------------------------------------
Expand Down
4 changes: 3 additions & 1 deletion guppylang/std/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def __ge__(self: nat, other: nat) -> bool: ...
@guppy.hugr_op(int_op("igt_u"))
def __gt__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(int_op("iu_to_s"))
# TODO: Use "iu_to_s" once we have lowering:
# https://github.com/CQCL/hugr/issues/1806
@guppy.custom(NoopCompiler())
def __int__(self: nat) -> int: ...

@guppy.hugr_op(int_op("inot"))
Expand Down
3 changes: 1 addition & 2 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:
elem_ty = ht.Option(ty_arg.ty.to_hugr())
hugr_arg = len_arg.to_hugr()

# TODO remove type ignore after Array type annotation fixed to include VariableArg
return hugr.std.collections.array.Array(elem_ty, hugr_arg) # type:ignore[arg-type]
return hugr.std.collections.array.Array(elem_ty, hugr_arg)


def _sized_iter_to_hugr(args: Sequence[Argument]) -> ht.Type:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ members = ["execute_llvm"]
execute-llvm = { workspace = true }

# Uncomment these to test the latest dependency version during development
# hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "861183e" }
hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "e40b6c7" }
# tket2-exts = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-exts", rev = "eb7cc63"}
# tket2 = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-py", rev = "eb7cc63"}

Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ def main() -> int:
package = module.compile()
validate(package)

# TODO: Enable execution once lowering for missing ops is implemented
# run_int_fn(package, expected=9)
run_int_fn(package, expected=9)


def test_mem_swap(validate):
Expand Down
Loading

0 comments on commit 0f6ceaa

Please sign in to comment.