From 0f6ceaa58648f105e601298081170eaec5a0e026 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:11:57 +0000 Subject: [PATCH] fix: Fix array execution bugs (#731) Bumps to Hugr on main and fixes some bugs to make arrays executable. --- Cargo.lock | 18 ++--- Cargo.toml | 6 +- execute_llvm/src/lib.rs | 12 +++- guppylang/compiler/expr_compiler.py | 57 ++++++++-------- guppylang/compiler/stmt_compiler.py | 19 ++++-- guppylang/std/_internal/compiler/array.py | 43 ++++++++---- guppylang/std/builtins.py | 4 +- guppylang/tys/builtin.py | 3 +- pyproject.toml | 2 +- tests/integration/test_array.py | 3 +- tests/integration/test_array_comprehension.py | 68 +++++++++++++++---- tests/integration/test_unpack.py | 22 +++++- uv.lock | 8 +-- 13 files changed, 175 insertions(+), 90 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ddb2de16..0cad9946 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,8 +525,7 @@ dependencies = [ [[package]] name = "hugr" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f209c7cd671de29be8bdf0725e09b2e9d386387f439b13975e158f095e5a0fe" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "hugr-core", "hugr-llvm", @@ -536,24 +535,20 @@ dependencies = [ [[package]] name = "hugr-cli" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab6a94a980d47788908d7f93165846164f8b623b7f382cd3813bd0c0d1188e65" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "clap", "clap-verbosity-flag", "clio", "derive_more", "hugr", - "serde", "serde_json", - "thiserror 2.0.7", ] [[package]] name = "hugr-core" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60c3d5422f76dbec1d6948e68544b134562ec9ec087e8e6a599555b716f555dc" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "bitvec", "bumpalo", @@ -585,12 +580,10 @@ dependencies = [ [[package]] name = "hugr-llvm" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce4117f4f934b1033b82d8cb672b3c33c3a7f8f541c50f7cc7ff53cebb5816d1" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "anyhow", "delegate", - "downcast-rs", "hugr-core", "inkwell", "itertools 0.13.0", @@ -602,8 +595,7 @@ dependencies = [ [[package]] name = "hugr-passes" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec2591767b6fe03074d38de7c4e61d52b37cb2e73b7340bf4ff957ad4554022a" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "ascent", "hugr-core", diff --git a/Cargo.toml b/Cargo.toml index 21c65419..6d600d42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,6 @@ inkwell = "0.5.0" [patch.crates-io] # Uncomment these to test the latest dependency version during development -# hugr = { git = "https://github.com/CQCL/hugr", rev = "861183e" } -# hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "861183e" } -# hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "1091755" } + hugr = { git = "https://github.com/CQCL/hugr", rev = "ab94518" } + hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "ab94518" } + hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "ab94518" } diff --git a/execute_llvm/src/lib.rs b/execute_llvm/src/lib.rs index 5eb06261..3ff6d2d7 100644 --- a/execute_llvm/src/lib.rs +++ b/execute_llvm/src/lib.rs @@ -38,6 +38,11 @@ fn find_funcdef_node(hugr: impl HugrView, fn_name: &str) -> PyResult } } +fn guppy_pass(hugr: Hugr) -> Hugr { + let hugr = hugr::algorithms::monomorphize(hugr); + hugr::algorithms::remove_polyfuncs(hugr) +} + fn compile_module<'a>( hugr: &'a hugr::Hugr, ctx: &'a Context, @@ -47,6 +52,7 @@ fn compile_module<'a>( // TODO: Handle tket2 codegen extension let extensions = hugr::llvm::custom::CodegenExtsBuilder::default() .add_int_extensions() + .add_logic_extensions() .add_default_prelude_extensions() .add_default_array_extensions() .add_float_extensions() @@ -64,9 +70,10 @@ fn compile_module<'a>( #[pyfunction] fn compile_module_to_string(hugr_json: &str) -> PyResult { - let hugr = parse_hugr(hugr_json)?; + let mut hugr = parse_hugr(hugr_json)?; let ctx = Context::create(); + hugr = guppy_pass(hugr); let module = compile_module(&hugr, &ctx, Default::default())?; Ok(module.print_to_string().to_str().unwrap().to_string()) @@ -77,7 +84,8 @@ fn run_function( fn_name: &str, parse_result: impl FnOnce(&Context, GenericValue) -> PyResult, ) -> PyResult { - let hugr = parse_hugr(hugr_json)?; + let mut hugr = parse_hugr(hugr_json)?; + hugr = guppy_pass(hugr); let ctx = Context::create(); let namer = hugr::llvm::emit::Namer::default(); diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index b3d8184f..e724cdec 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -5,6 +5,7 @@ from typing import Any, TypeGuard, TypeVar import hugr +import hugr.std.collections.array import hugr.std.float import hugr.std.int import hugr.std.logic @@ -21,7 +22,7 @@ from guppylang.checker.errors.generic import UnsupportedError from guppylang.checker.linearity_checker import contains_subscript from guppylang.compiler.core import CompilerBase, DFContainer -from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp +from guppylang.compiler.hugr_extension import PartialOp from guppylang.definition.custom import CustomFunctionDef from guppylang.definition.value import ( CallReturnWires, @@ -46,6 +47,7 @@ TensorCall, TypeApply, ) +from guppylang.std._internal.compiler.arithmetic import convert_ifromusize from guppylang.std._internal.compiler.array import array_repeat from guppylang.std._internal.compiler.list import ( list_new, @@ -123,7 +125,7 @@ def _new_dfcontainer( def _new_loop( self, loop_vars: list[PlaceNode], - branch: PlaceNode, + continue_predicate: PlaceNode, ) -> Iterator[None]: """Context manager to build a graph inside a new `TailLoop` node. @@ -134,13 +136,12 @@ def _new_loop( loop = self.builder.add_tail_loop([], loop_inputs) with self._new_dfcontainer(loop_vars, loop): yield - # Output the branch predicate and the inputs for the next iteration - loop.set_loop_outputs( - # Note that we have to do fresh calls to `self.visit` here since we're - # in a new context - self.visit(branch), - *(self.visit(name) for name in loop_vars), - ) + # Output the branch predicate and the inputs for the next iteration. Note + # that we have to do fresh calls to `self.visit` here since we're in a new + # context + do_continue = self.visit(continue_predicate) + do_break = loop.add_op(hugr.std.logic.Not, do_continue) + loop.set_loop_outputs(do_break, *(self.visit(name) for name in loop_vars)) # Update the DFG with the outputs from the loop for node, wire in zip(loop_vars, loop, strict=True): self.dfg[node.place] = wire @@ -172,12 +173,12 @@ def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]: conditional = self.builder.add_conditional( self.visit(cond), *(self.visit(inp) for inp in inputs) ) - # If the condition is true, we enter the `with` block - with self._new_case(inputs, inputs, conditional, 0): - yield # If the condition is false, output the inputs as is - with self._new_case(inputs, inputs, conditional, 1): + with self._new_case(inputs, inputs, conditional, 0): pass + # If the condition is true, we enter the `with` block + with self._new_case(inputs, inputs, conditional, 1): + yield # Update the DFG with the outputs from the Conditional node for node, wire in zip(inputs, conditional, strict=True): self.dfg[node.place] = wire @@ -206,11 +207,16 @@ def visit_GlobalName(self, node: GlobalName) -> Wire: return defn.load(self.dfg, self.globals, node) def visit_GenericParamValue(self, node: GenericParamValue) -> Wire: - # TODO: We need a way to look up the concrete value of a generic type arg in - # Hugr. For example, a new op that captures the value during monomorphisation - return self.builder.add_op( - UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op - ) + match node.param.ty: + case NumericType(NumericType.Kind.Nat): + arg = node.param.to_bound().to_hugr() + load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate( + [arg], ht.FunctionType([], [ht.USize()]) + ) + usize = self.builder.add_op(load_nat) + return self.builder.add_op(convert_ifromusize(), usize) + case _: + raise NotImplementedError def visit_Name(self, node: ast.Name) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") @@ -604,17 +610,12 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None: return hv.Tuple(*vs) case list(elts): assert is_array_type(exp_ty) - vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts] + elem_ty = get_element_type(exp_ty) + vs = [python_value_to_hugr(elt, elem_ty) for elt in elts] if doesnt_contain_none(vs): - # TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1497 - return hv.Extension( - name="ArrayValue", - typ=exp_ty.to_hugr(), - # The value list must be serialized at this point, otherwise the - # `Extension` value would not be serializable. - val=[v._to_serial_root() for v in vs], - extensions=["unsupported"], - ) + opt_ty = ht.Option(elem_ty.to_hugr()) + opt_vs: list[hv.Value] = [hv.Some(v) for v in vs] + return hugr.std.collections.array.ArrayVal(opt_vs, opt_ty) case _: # TODO replace with hugr protocol handling: https://github.com/CQCL/guppylang/issues/563 # Pytket conversion is an experimental feature diff --git a/guppylang/compiler/stmt_compiler.py b/guppylang/compiler/stmt_compiler.py index 6f92854e..9132268c 100644 --- a/guppylang/compiler/stmt_compiler.py +++ b/guppylang/compiler/stmt_compiler.py @@ -115,15 +115,26 @@ def pop( array: Wire, length: int, pats: list[ast.expr], from_left: bool ) -> tuple[Wire, int]: err = "Internal error: unpacking of iterable failed" - for pat in pats: + num_pats = len(pats) + # Pop the number of requested elements from the array + elts = [] + for i in range(num_pats): res = self.builder.add_op( - array_pop(opt_elt_ty, length, from_left), array + array_pop(opt_elt_ty, length - i, from_left), array ) [elt_opt, array] = build_unwrap(self.builder, res, err) [elt] = build_unwrap(self.builder, elt_opt, err) + elts.append(elt) + # Assign elements to the given patterns + for pat, elt in zip( + pats, + # Assignments are evaluated from left to right, so we need to assign in + # reverse order if we popped from the right + elts if from_left else reversed(elts), + strict=True, + ): self._assign(pat, elt) - length -= 1 - return array, length + return array, length - num_pats self.dfg[lhs.rhs_var.place] = port array = self.expr_compiler.visit_DesugaredArrayComp(lhs.compr) diff --git a/guppylang/std/_internal/compiler/array.py b/guppylang/std/_internal/compiler/array.py index e4432360..b3195874 100644 --- a/guppylang/std/_internal/compiler/array.py +++ b/guppylang/std/_internal/compiler/array.py @@ -6,7 +6,6 @@ from hugr import tys as ht from hugr.std.collections.array import EXTENSION -from guppylang.compiler.hugr_extension import UnsupportedOp from guppylang.definition.custom import CustomCallCompiler from guppylang.definition.value import CallReturnWires from guppylang.error import InternalGuppyError @@ -92,24 +91,42 @@ def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp: ) +def array_scan( + elem_ty: ht.Type, + length: ht.TypeArg, + new_elem_ty: ht.Type, + accumulators: list[ht.Type], +) -> ops.ExtOp: + """Returns an operation that maps and folds a function across an array.""" + ty_args = [ + length, + ht.TypeTypeArg(elem_ty), + ht.TypeTypeArg(new_elem_ty), + ht.SequenceArg([ht.TypeTypeArg(acc) for acc in accumulators]), + ht.ExtensionsArg([]), + ] + ins = [ + array_type(elem_ty, length), + ht.FunctionType([elem_ty, *accumulators], [new_elem_ty, *accumulators]), + *accumulators, + ] + outs = [array_type(new_elem_ty, length), *accumulators] + return EXTENSION.get_op("scan").instantiate(ty_args, ht.FunctionType(ins, outs)) + + def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp: """Returns an operation that maps a function across an array.""" - # TODO - return UnsupportedOp( - op_name="array_map", - inputs=[array_type(elem_ty, length), ht.FunctionType([elem_ty], [new_elem_ty])], - outputs=[array_type(new_elem_ty, length)], - ).ext_op + return array_scan(elem_ty, length, new_elem_ty, accumulators=[]) def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: """Returns an array `repeat` operation.""" - # TODO - return UnsupportedOp( - op_name="array.repeat", - inputs=[ht.FunctionType([], [elem_ty])], - outputs=[array_type(elem_ty, length)], - ).ext_op + return EXTENSION.get_op("repeat").instantiate( + [length, ht.TypeTypeArg(elem_ty), ht.ExtensionsArg([])], + ht.FunctionType( + [ht.FunctionType([], [elem_ty])], [array_type(elem_ty, length)] + ), + ) # ------------------------------------------------------ diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index d9843bc2..85a9be7d 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -161,7 +161,9 @@ def __ge__(self: nat, other: nat) -> bool: ... @guppy.hugr_op(int_op("igt_u")) def __gt__(self: nat, other: nat) -> bool: ... - @guppy.hugr_op(int_op("iu_to_s")) + # TODO: Use "iu_to_s" once we have lowering: + # https://github.com/CQCL/hugr/issues/1806 + @guppy.custom(NoopCompiler()) def __int__(self: nat) -> int: ... @guppy.hugr_op(int_op("inot")) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index f9785b21..740869dc 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -138,8 +138,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: elem_ty = ht.Option(ty_arg.ty.to_hugr()) hugr_arg = len_arg.to_hugr() - # TODO remove type ignore after Array type annotation fixed to include VariableArg - return hugr.std.collections.array.Array(elem_ty, hugr_arg) # type:ignore[arg-type] + return hugr.std.collections.array.Array(elem_ty, hugr_arg) def _sized_iter_to_hugr(args: Sequence[Argument]) -> ht.Type: diff --git a/pyproject.toml b/pyproject.toml index 473993fb..29ceb926 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ members = ["execute_llvm"] execute-llvm = { workspace = true } # Uncomment these to test the latest dependency version during development -# hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "861183e" } + hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "e40b6c7" } # tket2-exts = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-exts", rev = "eb7cc63"} # tket2 = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-py", rev = "eb7cc63"} diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index ab462cff..9973539a 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -291,8 +291,7 @@ def main() -> int: package = module.compile() validate(package) - # TODO: Enable execution once lowering for missing ops is implemented - # run_int_fn(package, expected=9) + run_int_fn(package, expected=9) def test_mem_swap(validate): diff --git a/tests/integration/test_array_comprehension.py b/tests/integration/test_array_comprehension.py index e012b3a5..d967c705 100644 --- a/tests/integration/test_array_comprehension.py +++ b/tests/integration/test_array_comprehension.py @@ -9,12 +9,23 @@ from tests.util import compile_guppy -def test_basic(validate): - @compile_guppy +def test_basic_exec(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) def test() -> array[int, 10]: return array(i + 1 for i in range(10)) - validate(test) + @guppy(module) + def main() -> int: + s = 0 + for x in test(): + s += x + return s + + package = module.compile() + validate(package) + run_int_fn(package, expected=sum(i + 1 for i in range(10))) def test_basic_linear(validate): @@ -29,23 +40,42 @@ def test() -> array[qubit, 42]: validate(module.compile()) -def test_zero_length(validate): - @compile_guppy +def test_zero_length(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) def test() -> array[float, 0]: return array(i / 0 for i in range(0)) - validate(test) + @guppy(module) + def main() -> int: + test() + return 0 + package = module.compile() + validate(package) + run_int_fn(package, expected=0) -def test_capture(validate): - @compile_guppy + +def test_capture(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) def test(x: int) -> array[int, 42]: return array(i + x for i in range(42)) - validate(test) + @guppy(module) + def main() -> int: + s = 0 + for x in test(3): + s += x + return s + + package = module.compile() + validate(package) + run_int_fn(package, expected=sum(i + 3 for i in range(42))) -@pytest.mark.skip("See https://github.com/CQCL/hugr/issues/1625") def test_capture_struct(validate): module = GuppyModule("test") @@ -71,12 +101,24 @@ def test() -> float: validate(test) -def test_nested_left(validate): - @compile_guppy +def test_nested_left(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) def test() -> array[array[int, 10], 20]: return array(array(x + y for y in range(10)) for x in range(20)) - validate(test) + @guppy(module) + def main() -> int: + s = 0 + for xs in test(): + for x in xs: + s += x + return s + + package = module.compile() + validate(package) + run_int_fn(package, expected=sum(x + y for y in range(10) for x in range(20))) def test_generic(validate): diff --git a/tests/integration/test_unpack.py b/tests/integration/test_unpack.py index f399f22d..69f019f7 100644 --- a/tests/integration/test_unpack.py +++ b/tests/integration/test_unpack.py @@ -69,8 +69,7 @@ def main() -> int: compiled = module.compile() validate(compiled) - # TODO: Enable execution test once array lowering is fully supported - # run_int_fn(compiled, expected=9) + run_int_fn(compiled, expected=9) def test_unpack_tuple_starred(validate, run_int_fn): @@ -102,3 +101,22 @@ def main( return x, y, z, a, b, c validate(module.compile()) + + +def test_left_to_right(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) + def left() -> int: + [x, x, *_] = range(10) + return x + + @guppy(module) + def right() -> int: + [*_, x, x] = range(10) + return x + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, fn_name="left", expected=1) + run_int_fn(compiled, fn_name="right", expected=9) diff --git a/uv.lock b/uv.lock index 8b5ef284..c61fbdd1 100644 --- a/uv.lock +++ b/uv.lock @@ -614,7 +614,7 @@ test = [ [package.metadata] requires-dist = [ { name = "graphviz", specifier = ">=0.20.1,<0.21" }, - { name = "hugr", specifier = ">=0.10.0,<0.11" }, + { name = "hugr", git = "https://github.com/CQCL/hugr?subdirectory=hugr-py&rev=e40b6c7" }, { name = "networkx", specifier = ">=3.2.1,<4" }, { name = "pydantic", specifier = ">=2.7.0b1,<3" }, { name = "pytket", marker = "extra == 'pytket'", specifier = ">=1.34" }, @@ -679,17 +679,13 @@ test = [ [[package]] name = "hugr" version = "0.10.0" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/CQCL/hugr?subdirectory=hugr-py&rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" } dependencies = [ { name = "graphviz" }, { name = "pydantic" }, { name = "pydantic-extra-types" }, { name = "semver" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/ef/9cd410ff0e3a92c5e88da2ef3c0e051dd971f4f6c5577873c7901ed31dd5/hugr-0.10.0.tar.gz", hash = "sha256:11e5a80ebd4e31cad0cb04d408cdd93a094e6fb817dd81481eedac5a58f86ff7", size = 129441 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f2/71/83556457cfe27f4a1613cd49041cfe4c6e9e087a53b5beec48a8d709c36d/hugr-0.10.0-py3-none-any.whl", hash = "sha256:591e252ef3e4182fd0de99274ebb4999ddd9572a0ec823519de154e4bd9f14ec", size = 83000 }, -] [[package]] name = "identify"