diff --git a/pytket/phir/phirgen.py b/pytket/phir/phirgen.py index 510645e..19d5739 100644 --- a/pytket/phir/phirgen.py +++ b/pytket/phir/phirgen.py @@ -76,12 +76,14 @@ def arg_to_bit(arg: "UnitID") -> Bit: return [arg.reg_name, arg.index[0]] -def assign_cop(into: list[Var] | list[Bit], what: "Sequence[int]") -> dict[str, Any]: +def assign_cop( + lhs: list[Var] | list[Bit], rhs: "Sequence[Var | int]" +) -> dict[str, Any]: """PHIR for classical assign operation.""" return { "cop": "=", - "returns": into, - "args": what, + "returns": lhs, + "args": rhs, } @@ -136,8 +138,14 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> dict[str, Any]: if len(cmd.bits) != len(op.values): logger.error("LHS and RHS lengths mismatch for classical assignment") raise ValueError + return assign_cop([arg_to_bit(bit) for bit in cmd.bits], op.values) + + case tk.CopyBitsOp(): + if len(cmd.bits) != len(cmd.args) // 2: + logger.warning("LHS and RHS lengths mismatch for CopyBits") return assign_cop( - [arg_to_bit(cmd.bits[i]) for i in range(len(cmd.bits))], op.values + [arg_to_bit(bit) for bit in cmd.bits], + [arg_to_bit(cmd.args[i]) for i in range(len(cmd.args) // 2)], ) case _: @@ -159,9 +167,6 @@ def append_cmd(cmd: tk.Command, ops: list[dict[str, Any]]) -> None: else: op: dict[str, Any] | None = None match cmd.op: - case tk.SetBitsOp(): - op = convert_subcmd(cmd.op, cmd) - case tk.BarrierOp(): # TODO(kartik): confirm with Ciaran/spec # https://github.com/CQCL/phir/blob/main/spec.md @@ -205,6 +210,7 @@ def append_cmd(cmd: tk.Command, ops: list[dict[str, Any]]) -> None: "condition": cond, "true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])], } + case tk.ClassicalExpBox(): exp = cmd.op.get_exp() match exp.op: @@ -243,6 +249,10 @@ def append_cmd(cmd: tk.Command, ops: list[dict[str, Any]]) -> None: "cop": cop, "args": [arg["name"] for arg in exp.to_dict()["args"]], } + + case tk.ClassicalEvalOp(): + op = convert_subcmd(cmd.op, cmd) + case m: raise NotImplementedError(m) if op: diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index f51c96d..5271542 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -24,6 +24,7 @@ OpType.ClassicalExpBox, # some classical operations are rolled up into a box OpType.RangePredicate, OpType.ExplicitPredicate, + OpType.CopyBits, ] logger = logging.getLogger(__name__) diff --git a/tests/test_api.py b/tests/test_api.py index c46a4ba..6a01b88 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,10 +6,12 @@ # ############################################################################## +import json import logging import pytest +from pytket.circuit import Bit, Circuit from pytket.phir.api import pytket_to_phir from pytket.phir.qtm_machine import QtmMachine @@ -46,3 +48,24 @@ def test_pytket_to_phir_h1_1_all(self, test_file: QasmFile) -> None: circuit = get_qasm_as_circuit(test_file) assert pytket_to_phir(circuit, QtmMachine.H1_1) + + def test_pytket_classical_only(self) -> None: + c = Circuit(1) + a = c.add_c_register("a", 2) + b = c.add_c_register("b", 3) + + c.add_c_copyreg(a, b) + c.add_c_copybits([Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)]) + + phir = json.loads(pytket_to_phir(c)) # type: ignore[misc] + + assert phir["ops"][3] == { # type: ignore[misc] + "cop": "=", + "returns": [["b", 0], ["b", 1]], + "args": [["a", 0], ["a", 1]], + } + assert phir["ops"][5] == { # type: ignore[misc] + "cop": "=", + "returns": [["a", 0], ["b", 0]], + "args": [["b", 2], ["a", 1]], + } diff --git a/tests/test_parallel_tk2.py b/tests/test_parallel_tk2.py index 3244ada..1df44e7 100644 --- a/tests/test_parallel_tk2.py +++ b/tests/test_parallel_tk2.py @@ -46,7 +46,7 @@ def test_pll_tk2() -> None: # it is the correct output for the tk2.qasm file # if you change the tk2.qasm file, you just re-generate the correct # phir json and replace the expected or the test will fail - expected = { + expected: dict[str, Any] = { "ops": [ {"data": "qvar_define", "data_type": "qubits", "variable": "q", "size": 4}, {"data": "cvar_define", "data_type": "u32", "variable": "c", "size": 4}, @@ -93,10 +93,18 @@ def test_pll_tk2() -> None: {"mop": "Transport", "duration": [0.0, "ms"]}, { "qop": "Measure", - "args": [["q", 3], ["q", 0], ["q", 1], ["q", 2]], - "returns": [["c", 3], ["c", 0], ["c", 1], ["c", 2]], + "args": [["q", 0], ["q", 1], ["q", 2], ["q", 3]], + "returns": [["c", 0], ["c", 1], ["c", 2], ["c", 3]], }, {"mop": "Transport", "duration": [0.0, "ms"]}, ], } - assert actual["ops"] == expected["ops"] + + assert actual["ops"][6]["block"] == "qparallel" + for op in expected["ops"][6]["ops"]: + assert op in actual["ops"][6]["ops"] + + act_meas_op = actual["ops"][8] + assert act_meas_op["qop"] == "Measure" + assert sorted(act_meas_op["args"]) == expected["ops"][8]["args"] + assert sorted(act_meas_op["returns"]) == expected["ops"][8]["returns"] diff --git a/tests/test_parallelization.py b/tests/test_parallelization.py index 16a0f7a..11b081c 100644 --- a/tests/test_parallelization.py +++ b/tests/test_parallelization.py @@ -39,7 +39,7 @@ def get_phir_json(qasmfile: QasmFile) -> dict[str, Any]: def test_bv_n10() -> None: """Make sure the parallelization is happening properly for the test circuit.""" actual = get_phir_json(QasmFile.parallelization_test) - expected = { + expected: dict[str, Any] = { "ops": [ {"data": "qvar_define", "data_type": "qubits", "variable": "q", "size": 4}, {"data": "cvar_define", "data_type": "u32", "variable": "c", "size": 4}, @@ -80,4 +80,12 @@ def test_bv_n10() -> None: {"mop": "Transport", "duration": [0.0, "ms"]}, ], } - assert actual["ops"] == expected["ops"] + + assert actual["ops"][7]["block"] == "qparallel" + for op in expected["ops"][7]["ops"]: + assert op in actual["ops"][7]["ops"] + + act_meas_op = actual["ops"][9] + assert act_meas_op["qop"] == "Measure" + assert sorted(act_meas_op["args"]) == expected["ops"][9]["args"] + assert sorted(act_meas_op["returns"]) == expected["ops"][9]["returns"]