Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend DecomposeClassicalExp to handle ClExprOp #1678

Merged
merged 7 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pytket/binders/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ static PassPtr gen_default_aas_routing_pass(

const PassPtr &DecomposeClassicalExp() {
// a special box decomposer for Circuits containing
// ClassicalExpBox<py::object>
// ClassicalExpBox<py::object> and ClExprOp
static const PassPtr pp([]() {
Transform t = Transform([](Circuit &circ) {
py::module decomposer =
Expand Down Expand Up @@ -483,8 +483,8 @@ PYBIND11_MODULE(passes, m) {
py::arg("excluded_opgroups") = std::unordered_set<std::string>());
m.def(
"DecomposeClassicalExp", &DecomposeClassicalExp,
"Replaces each :py:class:`ClassicalExpBox` by a sequence of "
"classical gates.");
"Replaces each :py:class:`ClassicalExpBox` and `ClExprOp` by a sequence "
"of classical gates.");
m.def(
"DecomposeMultiQubitsCX", &DecomposeMultiQubitsCX,
"Converts all multi-qubit gates into CX and single-qubit gates.");
Expand Down
2 changes: 2 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Features:
and `flatten_registers`
* Implement `dagger()` and `transpose()` for `CustomGate`.
* Use `ClExprOp` by default when converting from QASM.
* Extend `DecomposeClassicalExp` to handle `ClExprOp` as well as
`ClassicalExpBox`.

Deprecations:

Expand Down
2 changes: 1 addition & 1 deletion pytket/pytket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def DecomposeBoxes(excluded_types: set[pytket._tket.circuit.OpType] = set(), exc
"""
def DecomposeClassicalExp() -> BasePass:
"""
Replaces each :py:class:`ClassicalExpBox` by a sequence of classical gates.
Replaces each :py:class:`ClassicalExpBox` and `ClExprOp` by a sequence of classical gates.
"""
def DecomposeMultiQubitsCX() -> BasePass:
"""
Expand Down
168 changes: 164 additions & 4 deletions pytket/pytket/circuit/decompose_classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,28 @@
import copy
from collections.abc import Callable
from heapq import heappop, heappush
from typing import Generic, TypeVar

from pytket._tket.circuit import Circuit, ClassicalExpBox, Conditional, OpType
from typing import Any, Generic, TypeVar

from pytket._tket.circuit import (
Circuit,
ClassicalExpBox,
ClBitVar,
ClExpr,
ClExprOp,
ClOp,
ClRegVar,
Conditional,
OpType,
WiredClExpr,
)
from pytket._tket.unit_id import (
_TEMP_BIT_NAME,
_TEMP_BIT_REG_BASE,
_TEMP_REG_SIZE,
Bit,
BitRegister,
)
from pytket.circuit.clexpr import check_register_alignments, has_reg_output
from pytket.circuit.logic_exp import (
BitLogicExp,
BitWiseOp,
Expand Down Expand Up @@ -242,8 +254,131 @@ def recursive_walk(
return recursive_walk


class ClExprDecomposer:
def __init__(
self,
circ: Circuit,
bit_posn: dict[int, int],
reg_posn: dict[int, list[int]],
args: list[Bit],
bit_heap: BitHeap,
reg_heap: RegHeap,
kwargs: dict[str, Any],
):
self.circ: Circuit = circ
self.bit_posn: dict[int, int] = bit_posn
self.reg_posn: dict[int, list[int]] = reg_posn
self.args: list[Bit] = args
self.bit_heap: BitHeap = bit_heap
self.reg_heap: RegHeap = reg_heap
self.kwargs: dict[str, Any] = kwargs
# Construct maps from int (i.e. ClBitVar) to Bit, and from int (i.e. ClRegVar)
# to BitRegister:
self.bit_vars = {i: args[p] for i, p in bit_posn.items()}
self.reg_vars = {
i: BitRegister(args[p[0]].reg_name, len(p)) for i, p in reg_posn.items()
}

def add_var(self, var: Variable) -> None:
"""Add a Bit or BitRegister to the circuit if not already present."""
if isinstance(var, Bit):
self.circ.add_bit(var, reject_dups=False)
else:
assert isinstance(var, BitRegister)
for bit in var.to_list():
self.circ.add_bit(bit, reject_dups=False)

def set_bits(self, var: Variable, val: int) -> None:
"""Set the value of a Bit or BitRegister."""
assert val >= 0
if isinstance(var, Bit):
assert val >> 1 == 0
self.circ.add_c_setbits([bool(val)], [var], **self.kwargs)
else:
assert isinstance(var, BitRegister)
assert val >> var.size == 0
self.circ.add_c_setreg(val, var, **self.kwargs)

def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable:
"""Add the decomposed expression to the circuit and return the Bit or
BitRegister that contains the result.

:param expr: the expression to decompose
:param out_var: where to put the output (if None, create a new scratch location)
"""
op: ClOp = expr.op
heap: VarHeap = self.reg_heap if has_reg_output(op) else self.bit_heap

# Eliminate (recursively) subsidiary expressions from the arguments, and convert
# all terms to Bit or BitRegister:
terms: list[Variable] = []
for arg in expr.args:
if isinstance(arg, int):
# Assign to a fresh variable
fresh_var = heap.fresh_var()
self.add_var(fresh_var)
self.set_bits(fresh_var, arg)
terms.append(fresh_var)
elif isinstance(arg, ClBitVar):
terms.append(self.bit_vars[arg.index])
elif isinstance(arg, ClRegVar):
terms.append(self.reg_vars[arg.index])
else:
assert isinstance(arg, ClExpr)
terms.append(self.decompose_expr(arg, None))

# Enable reuse of temporary terms:
for term in terms:
if heap.is_heap_var(term):
heap.push(term)

if out_var is None:
out_var = heap.fresh_var()
self.add_var(out_var)
match op:
case ClOp.BitAnd:
self.circ.add_c_and(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.BitNot:
self.circ.add_c_not(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.BitOne:
assert isinstance(out_var, Bit)
self.circ.add_c_setbits([True], [out_var], **self.kwargs)
case ClOp.BitOr:
self.circ.add_c_or(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.BitXor:
self.circ.add_c_xor(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.BitZero:
assert isinstance(out_var, Bit)
self.circ.add_c_setbits([False], [out_var], **self.kwargs)
case ClOp.RegAnd:
self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.RegNot:
self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.RegOne:
assert isinstance(out_var, BitRegister)
self.circ.add_c_setbits(
[True] * out_var.size, out_var.to_list(), **self.kwargs
)
case ClOp.RegOr:
self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.RegXor:
self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.RegZero:
assert isinstance(out_var, BitRegister)
self.circ.add_c_setbits(
[False] * out_var.size, out_var.to_list(), **self.kwargs
)
case _:
raise DecomposeClassicalError(
f"{op} cannot be decomposed to TKET primitives."
)
return out_var


def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:
"""Rewrite a circuit command-wise, decomposing ClassicalExpBox."""
"""Rewrite a circuit command-wise, decomposing ClassicalExpBox and ClExprOp."""
if not check_register_alignments(circ):
raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.")
bit_heap = BitHeap()
reg_heap = RegHeap()
# add already used heap variables to heaps
Expand Down Expand Up @@ -343,6 +478,31 @@ def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:
replace_targets[out_reg] = comp_reg
modified = True
continue

elif optype == OpType.ClExpr:
assert isinstance(op, ClExprOp)
wexpr: WiredClExpr = op.expr
expr: ClExpr = wexpr.expr
bit_posn = wexpr.bit_posn
reg_posn = wexpr.reg_posn
output_posn = wexpr.output_posn
assert len(output_posn) > 0
output0 = args[output_posn[0]]
assert isinstance(output0, Bit)
out_var: Variable = (
BitRegister(output0.reg_name, len(output_posn))
if has_reg_output(expr.op)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not possible to have an incomplete register here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(for example when using temp bits?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check at line 380 would catch this case and throw an exception. We do not allow this pass to be run on circuits with non-register-aligned ClExprOp (even though they are valid pytket circuits). I think removing this restriction would be possible, but would make decomposition more complicated, so prefer to keep it unless there is a real need.

Temporary (scratch) bits would not normally form part of a ClExprOp. They are introduced when decomposing it into elementary classical operations.

else output0
)
decomposer = ClExprDecomposer(
newcirc, bit_posn, reg_posn, args, bit_heap, reg_heap, kwargs # type: ignore
)
comp_var = decomposer.decompose_expr(expr, out_var)
if comp_var != out_var:
replace_targets[out_var] = comp_var
modified = True
continue

if optype == OpType.Barrier:
# add_gate doesn't work for metaops
newcirc.add_barrier(args)
Expand Down
14 changes: 10 additions & 4 deletions pytket/tests/qasm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
reg_lt,
reg_neq,
)
from pytket.circuit.decompose_classical import DecomposeClassicalError
from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp
from pytket.passes import DecomposeBoxes, DecomposeClassicalExp
from pytket.qasm import (
Expand Down Expand Up @@ -464,14 +465,18 @@ def test_extended_qasm() -> None:

assert circuit_to_qasm_str(c2, "hqslib1")

assert not DecomposeClassicalExp().apply(c)
with pytest.raises(DecomposeClassicalError):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reverts to the test as it was before #1628 .

DecomposeClassicalExp().apply(c)


def test_decomposable_extended() -> None:
@pytest.mark.parametrize("use_clexpr", [True, False])
def test_decomposable_extended(use_clexpr: bool) -> None:
fname = str(curr_file_path / "qasm_test_files/test18.qasm")
out_fname = str(curr_file_path / "qasm_test_files/test18_output.qasm")

c = circuit_from_qasm_wasm(fname, "testfile.wasm", maxwidth=64, use_clexpr=True)
c = circuit_from_qasm_wasm(
fname, "testfile.wasm", maxwidth=64, use_clexpr=use_clexpr
)
DecomposeClassicalExp().apply(c)

out_qasm = circuit_to_qasm_str(c, "hqslib1", maxwidth=64)
Expand Down Expand Up @@ -1233,7 +1238,8 @@ def test_multibitop() -> None:
test_hqs_conditional_params()
test_barrier()
test_barrier_2()
test_decomposable_extended()
test_decomposable_extended(True)
test_decomposable_extended(False)
test_alternate_encoding()
test_header_stops_gate_definition()
test_tk2_definition()
Expand Down
16 changes: 11 additions & 5 deletions pytket/tests/qasm_test_files/test18_output.qasm
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this reverts to the file as it was before #1628 .

Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@ creg a[2];
creg b[3];
creg c[4];
creg d[1];
creg tk_SCRATCH_BIT[7];
creg tk_SCRATCH_BITREG_0[64];
c = 2;
tk_SCRATCH_BITREG_0[0] = b[0] & a[0];
tk_SCRATCH_BITREG_0[1] = b[1] & a[1];
c[0] = a[0];
c[1] = a[1];
if(b!=2) c[1] = ((b[1] & a[1]) | a[0]);
c = ((b & a) | d);
d[0] = (a[0] ^ 1);
a = CCE(a, b);
if(b!=2) tk_SCRATCH_BIT[6] = b[1] & a[1];
c[0] = tk_SCRATCH_BITREG_0[0] | d[0];
if(b!=2) c[1] = tk_SCRATCH_BIT[6] | a[0];
tk_SCRATCH_BIT[6] = 1;
d[0] = a[0] ^ tk_SCRATCH_BIT[6];
if(c>=2) h q[0];
CCE(c);
a = CCE(a, b);
if(c<=2) h q[0];
CCE(c);
if(c<=1) h q[0];
if(c>=3) h q[0];
if(c!=2) h q[0];
Expand Down
Loading