From c91e56205a9189594d244e294f45e49644a8fccd Mon Sep 17 00:00:00 2001 From: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:23:22 +0000 Subject: [PATCH] Extend `DecomposeClassicalExp` to handle `ClExprOp` (#1678) --- pytket/binders/passes.cpp | 6 +- pytket/docs/changelog.rst | 2 + pytket/pytket/_tket/passes.pyi | 2 +- pytket/pytket/circuit/decompose_classical.py | 168 +++++++++++++++++- pytket/tests/qasm_test.py | 14 +- .../tests/qasm_test_files/test18_output.qasm | 16 +- 6 files changed, 191 insertions(+), 17 deletions(-) diff --git a/pytket/binders/passes.cpp b/pytket/binders/passes.cpp index 9414ec9ae2..081937323d 100644 --- a/pytket/binders/passes.cpp +++ b/pytket/binders/passes.cpp @@ -111,7 +111,7 @@ static PassPtr gen_default_aas_routing_pass( const PassPtr &DecomposeClassicalExp() { // a special box decomposer for Circuits containing - // ClassicalExpBox + // ClassicalExpBox and ClExprOp static const PassPtr pp([]() { Transform t = Transform([](Circuit &circ) { py::module decomposer = @@ -483,8 +483,8 @@ PYBIND11_MODULE(passes, m) { py::arg("excluded_opgroups") = std::unordered_set()); m.def( "DecomposeClassicalExp", &DecomposeClassicalExp, - "Replaces each :py:class:`ClassicalExpBox` by a sequence of " - "classical gates."); + "Replaces each :py:class:`ClassicalExpBox` and `ClExprOp` by a sequence " + "of classical gates."); m.def( "DecomposeMultiQubitsCX", &DecomposeMultiQubitsCX, "Converts all multi-qubit gates into CX and single-qubit gates."); diff --git a/pytket/docs/changelog.rst b/pytket/docs/changelog.rst index fa8320bcdf..1e78fcedc6 100644 --- a/pytket/docs/changelog.rst +++ b/pytket/docs/changelog.rst @@ -18,6 +18,8 @@ Features: and `flatten_registers` * Implement `dagger()` and `transpose()` for `CustomGate`. * Use `ClExprOp` by default when converting from QASM. +* Extend `DecomposeClassicalExp` to handle `ClExprOp` as well as + `ClassicalExpBox`. Deprecations: diff --git a/pytket/pytket/_tket/passes.pyi b/pytket/pytket/_tket/passes.pyi index 600a954777..e9fd51a094 100644 --- a/pytket/pytket/_tket/passes.pyi +++ b/pytket/pytket/_tket/passes.pyi @@ -342,7 +342,7 @@ def DecomposeBoxes(excluded_types: set[pytket._tket.circuit.OpType] = set(), exc """ def DecomposeClassicalExp() -> BasePass: """ - Replaces each :py:class:`ClassicalExpBox` by a sequence of classical gates. + Replaces each :py:class:`ClassicalExpBox` and `ClExprOp` by a sequence of classical gates. """ def DecomposeMultiQubitsCX() -> BasePass: """ diff --git a/pytket/pytket/circuit/decompose_classical.py b/pytket/pytket/circuit/decompose_classical.py index f62493c1bd..11a7f1a71c 100644 --- a/pytket/pytket/circuit/decompose_classical.py +++ b/pytket/pytket/circuit/decompose_classical.py @@ -17,9 +17,20 @@ import copy from collections.abc import Callable from heapq import heappop, heappush -from typing import Generic, TypeVar - -from pytket._tket.circuit import Circuit, ClassicalExpBox, Conditional, OpType +from typing import Any, Generic, TypeVar + +from pytket._tket.circuit import ( + Circuit, + ClassicalExpBox, + ClBitVar, + ClExpr, + ClExprOp, + ClOp, + ClRegVar, + Conditional, + OpType, + WiredClExpr, +) from pytket._tket.unit_id import ( _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE, @@ -27,6 +38,7 @@ Bit, BitRegister, ) +from pytket.circuit.clexpr import check_register_alignments, has_reg_output from pytket.circuit.logic_exp import ( BitLogicExp, BitWiseOp, @@ -242,8 +254,131 @@ def recursive_walk( return recursive_walk +class ClExprDecomposer: + def __init__( + self, + circ: Circuit, + bit_posn: dict[int, int], + reg_posn: dict[int, list[int]], + args: list[Bit], + bit_heap: BitHeap, + reg_heap: RegHeap, + kwargs: dict[str, Any], + ): + self.circ: Circuit = circ + self.bit_posn: dict[int, int] = bit_posn + self.reg_posn: dict[int, list[int]] = reg_posn + self.args: list[Bit] = args + self.bit_heap: BitHeap = bit_heap + self.reg_heap: RegHeap = reg_heap + self.kwargs: dict[str, Any] = kwargs + # Construct maps from int (i.e. ClBitVar) to Bit, and from int (i.e. ClRegVar) + # to BitRegister: + self.bit_vars = {i: args[p] for i, p in bit_posn.items()} + self.reg_vars = { + i: BitRegister(args[p[0]].reg_name, len(p)) for i, p in reg_posn.items() + } + + def add_var(self, var: Variable) -> None: + """Add a Bit or BitRegister to the circuit if not already present.""" + if isinstance(var, Bit): + self.circ.add_bit(var, reject_dups=False) + else: + assert isinstance(var, BitRegister) + for bit in var.to_list(): + self.circ.add_bit(bit, reject_dups=False) + + def set_bits(self, var: Variable, val: int) -> None: + """Set the value of a Bit or BitRegister.""" + assert val >= 0 + if isinstance(var, Bit): + assert val >> 1 == 0 + self.circ.add_c_setbits([bool(val)], [var], **self.kwargs) + else: + assert isinstance(var, BitRegister) + assert val >> var.size == 0 + self.circ.add_c_setreg(val, var, **self.kwargs) + + def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable: + """Add the decomposed expression to the circuit and return the Bit or + BitRegister that contains the result. + + :param expr: the expression to decompose + :param out_var: where to put the output (if None, create a new scratch location) + """ + op: ClOp = expr.op + heap: VarHeap = self.reg_heap if has_reg_output(op) else self.bit_heap + + # Eliminate (recursively) subsidiary expressions from the arguments, and convert + # all terms to Bit or BitRegister: + terms: list[Variable] = [] + for arg in expr.args: + if isinstance(arg, int): + # Assign to a fresh variable + fresh_var = heap.fresh_var() + self.add_var(fresh_var) + self.set_bits(fresh_var, arg) + terms.append(fresh_var) + elif isinstance(arg, ClBitVar): + terms.append(self.bit_vars[arg.index]) + elif isinstance(arg, ClRegVar): + terms.append(self.reg_vars[arg.index]) + else: + assert isinstance(arg, ClExpr) + terms.append(self.decompose_expr(arg, None)) + + # Enable reuse of temporary terms: + for term in terms: + if heap.is_heap_var(term): + heap.push(term) + + if out_var is None: + out_var = heap.fresh_var() + self.add_var(out_var) + match op: + case ClOp.BitAnd: + self.circ.add_c_and(*terms, out_var, **self.kwargs) # type: ignore + case ClOp.BitNot: + self.circ.add_c_not(*terms, out_var, **self.kwargs) # type: ignore + case ClOp.BitOne: + assert isinstance(out_var, Bit) + self.circ.add_c_setbits([True], [out_var], **self.kwargs) + case ClOp.BitOr: + self.circ.add_c_or(*terms, out_var, **self.kwargs) # type: ignore + case ClOp.BitXor: + self.circ.add_c_xor(*terms, out_var, **self.kwargs) # type: ignore + case ClOp.BitZero: + assert isinstance(out_var, Bit) + self.circ.add_c_setbits([False], [out_var], **self.kwargs) + case ClOp.RegAnd: + self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) # type: ignore + case ClOp.RegNot: + self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) # type: ignore + case ClOp.RegOne: + assert isinstance(out_var, BitRegister) + self.circ.add_c_setbits( + [True] * out_var.size, out_var.to_list(), **self.kwargs + ) + case ClOp.RegOr: + self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) # type: ignore + case ClOp.RegXor: + self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) # type: ignore + case ClOp.RegZero: + assert isinstance(out_var, BitRegister) + self.circ.add_c_setbits( + [False] * out_var.size, out_var.to_list(), **self.kwargs + ) + case _: + raise DecomposeClassicalError( + f"{op} cannot be decomposed to TKET primitives." + ) + return out_var + + def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]: - """Rewrite a circuit command-wise, decomposing ClassicalExpBox.""" + """Rewrite a circuit command-wise, decomposing ClassicalExpBox and ClExprOp.""" + if not check_register_alignments(circ): + raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.") bit_heap = BitHeap() reg_heap = RegHeap() # add already used heap variables to heaps @@ -343,6 +478,31 @@ def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]: replace_targets[out_reg] = comp_reg modified = True continue + + elif optype == OpType.ClExpr: + assert isinstance(op, ClExprOp) + wexpr: WiredClExpr = op.expr + expr: ClExpr = wexpr.expr + bit_posn = wexpr.bit_posn + reg_posn = wexpr.reg_posn + output_posn = wexpr.output_posn + assert len(output_posn) > 0 + output0 = args[output_posn[0]] + assert isinstance(output0, Bit) + out_var: Variable = ( + BitRegister(output0.reg_name, len(output_posn)) + if has_reg_output(expr.op) + else output0 + ) + decomposer = ClExprDecomposer( + newcirc, bit_posn, reg_posn, args, bit_heap, reg_heap, kwargs # type: ignore + ) + comp_var = decomposer.decompose_expr(expr, out_var) + if comp_var != out_var: + replace_targets[out_var] = comp_var + modified = True + continue + if optype == OpType.Barrier: # add_gate doesn't work for metaops newcirc.add_barrier(args) diff --git a/pytket/tests/qasm_test.py b/pytket/tests/qasm_test.py index 574546c552..a3824c03fe 100644 --- a/pytket/tests/qasm_test.py +++ b/pytket/tests/qasm_test.py @@ -37,6 +37,7 @@ reg_lt, reg_neq, ) +from pytket.circuit.decompose_classical import DecomposeClassicalError from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp from pytket.passes import DecomposeBoxes, DecomposeClassicalExp from pytket.qasm import ( @@ -464,14 +465,18 @@ def test_extended_qasm() -> None: assert circuit_to_qasm_str(c2, "hqslib1") - assert not DecomposeClassicalExp().apply(c) + with pytest.raises(DecomposeClassicalError): + DecomposeClassicalExp().apply(c) -def test_decomposable_extended() -> None: +@pytest.mark.parametrize("use_clexpr", [True, False]) +def test_decomposable_extended(use_clexpr: bool) -> None: fname = str(curr_file_path / "qasm_test_files/test18.qasm") out_fname = str(curr_file_path / "qasm_test_files/test18_output.qasm") - c = circuit_from_qasm_wasm(fname, "testfile.wasm", maxwidth=64, use_clexpr=True) + c = circuit_from_qasm_wasm( + fname, "testfile.wasm", maxwidth=64, use_clexpr=use_clexpr + ) DecomposeClassicalExp().apply(c) out_qasm = circuit_to_qasm_str(c, "hqslib1", maxwidth=64) @@ -1233,7 +1238,8 @@ def test_multibitop() -> None: test_hqs_conditional_params() test_barrier() test_barrier_2() - test_decomposable_extended() + test_decomposable_extended(True) + test_decomposable_extended(False) test_alternate_encoding() test_header_stops_gate_definition() test_tk2_definition() diff --git a/pytket/tests/qasm_test_files/test18_output.qasm b/pytket/tests/qasm_test_files/test18_output.qasm index c635f280ab..f2ab0b7d65 100644 --- a/pytket/tests/qasm_test_files/test18_output.qasm +++ b/pytket/tests/qasm_test_files/test18_output.qasm @@ -6,16 +6,22 @@ creg a[2]; creg b[3]; creg c[4]; creg d[1]; +creg tk_SCRATCH_BIT[7]; +creg tk_SCRATCH_BITREG_0[64]; c = 2; +tk_SCRATCH_BITREG_0[0] = b[0] & a[0]; +tk_SCRATCH_BITREG_0[1] = b[1] & a[1]; c[0] = a[0]; c[1] = a[1]; -if(b!=2) c[1] = ((b[1] & a[1]) | a[0]); -c = ((b & a) | d); -d[0] = (a[0] ^ 1); -a = CCE(a, b); +if(b!=2) tk_SCRATCH_BIT[6] = b[1] & a[1]; +c[0] = tk_SCRATCH_BITREG_0[0] | d[0]; +if(b!=2) c[1] = tk_SCRATCH_BIT[6] | a[0]; +tk_SCRATCH_BIT[6] = 1; +d[0] = a[0] ^ tk_SCRATCH_BIT[6]; if(c>=2) h q[0]; -CCE(c); +a = CCE(a, b); if(c<=2) h q[0]; +CCE(c); if(c<=1) h q[0]; if(c>=3) h q[0]; if(c!=2) h q[0];