Skip to content

Commit

Permalink
Merge pull request #67 from CQCL/copybits
Browse files Browse the repository at this point in the history
  • Loading branch information
qartik authored Dec 14, 2023
2 parents 0a11b11 + 7eccd8c commit 5cd46aa
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 13 deletions.
24 changes: 17 additions & 7 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ def arg_to_bit(arg: "UnitID") -> Bit:
return [arg.reg_name, arg.index[0]]


def assign_cop(into: list[Var] | list[Bit], what: "Sequence[int]") -> dict[str, Any]:
def assign_cop(
lhs: list[Var] | list[Bit], rhs: "Sequence[Var | int]"
) -> dict[str, Any]:
"""PHIR for classical assign operation."""
return {
"cop": "=",
"returns": into,
"args": what,
"returns": lhs,
"args": rhs,
}


Expand Down Expand Up @@ -136,8 +138,14 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> dict[str, Any]:
if len(cmd.bits) != len(op.values):
logger.error("LHS and RHS lengths mismatch for classical assignment")
raise ValueError
return assign_cop([arg_to_bit(bit) for bit in cmd.bits], op.values)

case tk.CopyBitsOp():
if len(cmd.bits) != len(cmd.args) // 2:
logger.warning("LHS and RHS lengths mismatch for CopyBits")
return assign_cop(
[arg_to_bit(cmd.bits[i]) for i in range(len(cmd.bits))], op.values
[arg_to_bit(bit) for bit in cmd.bits],
[arg_to_bit(cmd.args[i]) for i in range(len(cmd.args) // 2)],
)

case _:
Expand All @@ -159,9 +167,6 @@ def append_cmd(cmd: tk.Command, ops: list[dict[str, Any]]) -> None:
else:
op: dict[str, Any] | None = None
match cmd.op:
case tk.SetBitsOp():
op = convert_subcmd(cmd.op, cmd)

case tk.BarrierOp():
# TODO(kartik): confirm with Ciaran/spec
# https://github.com/CQCL/phir/blob/main/spec.md
Expand Down Expand Up @@ -205,6 +210,7 @@ def append_cmd(cmd: tk.Command, ops: list[dict[str, Any]]) -> None:
"condition": cond,
"true_branch": [assign_cop([arg_to_bit(cmd.bits[0])], [1])],
}

case tk.ClassicalExpBox():
exp = cmd.op.get_exp()
match exp.op:
Expand Down Expand Up @@ -243,6 +249,10 @@ def append_cmd(cmd: tk.Command, ops: list[dict[str, Any]]) -> None:
"cop": cop,
"args": [arg["name"] for arg in exp.to_dict()["args"]],
}

case tk.ClassicalEvalOp():
op = convert_subcmd(cmd.op, cmd)

case m:
raise NotImplementedError(m)
if op:
Expand Down
1 change: 1 addition & 0 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
OpType.ClassicalExpBox, # some classical operations are rolled up into a box
OpType.RangePredicate,
OpType.ExplicitPredicate,
OpType.CopyBits,
]

logger = logging.getLogger(__name__)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
#
##############################################################################

import json
import logging

import pytest

from pytket.circuit import Bit, Circuit
from pytket.phir.api import pytket_to_phir
from pytket.phir.qtm_machine import QtmMachine

Expand Down Expand Up @@ -46,3 +48,24 @@ def test_pytket_to_phir_h1_1_all(self, test_file: QasmFile) -> None:
circuit = get_qasm_as_circuit(test_file)

assert pytket_to_phir(circuit, QtmMachine.H1_1)

def test_pytket_classical_only(self) -> None:
c = Circuit(1)
a = c.add_c_register("a", 2)
b = c.add_c_register("b", 3)

c.add_c_copyreg(a, b)
c.add_c_copybits([Bit("b", 2), Bit("a", 1)], [Bit("a", 0), Bit("b", 0)])

phir = json.loads(pytket_to_phir(c)) # type: ignore[misc]

assert phir["ops"][3] == { # type: ignore[misc]
"cop": "=",
"returns": [["b", 0], ["b", 1]],
"args": [["a", 0], ["a", 1]],
}
assert phir["ops"][5] == { # type: ignore[misc]
"cop": "=",
"returns": [["a", 0], ["b", 0]],
"args": [["b", 2], ["a", 1]],
}
16 changes: 12 additions & 4 deletions tests/test_parallel_tk2.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_pll_tk2() -> None:
# it is the correct output for the tk2.qasm file
# if you change the tk2.qasm file, you just re-generate the correct
# phir json and replace the expected or the test will fail
expected = {
expected: dict[str, Any] = {
"ops": [
{"data": "qvar_define", "data_type": "qubits", "variable": "q", "size": 4},
{"data": "cvar_define", "data_type": "u32", "variable": "c", "size": 4},
Expand Down Expand Up @@ -93,10 +93,18 @@ def test_pll_tk2() -> None:
{"mop": "Transport", "duration": [0.0, "ms"]},
{
"qop": "Measure",
"args": [["q", 3], ["q", 0], ["q", 1], ["q", 2]],
"returns": [["c", 3], ["c", 0], ["c", 1], ["c", 2]],
"args": [["q", 0], ["q", 1], ["q", 2], ["q", 3]],
"returns": [["c", 0], ["c", 1], ["c", 2], ["c", 3]],
},
{"mop": "Transport", "duration": [0.0, "ms"]},
],
}
assert actual["ops"] == expected["ops"]

assert actual["ops"][6]["block"] == "qparallel"
for op in expected["ops"][6]["ops"]:
assert op in actual["ops"][6]["ops"]

act_meas_op = actual["ops"][8]
assert act_meas_op["qop"] == "Measure"
assert sorted(act_meas_op["args"]) == expected["ops"][8]["args"]
assert sorted(act_meas_op["returns"]) == expected["ops"][8]["returns"]
12 changes: 10 additions & 2 deletions tests/test_parallelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_phir_json(qasmfile: QasmFile) -> dict[str, Any]:
def test_bv_n10() -> None:
"""Make sure the parallelization is happening properly for the test circuit."""
actual = get_phir_json(QasmFile.parallelization_test)
expected = {
expected: dict[str, Any] = {
"ops": [
{"data": "qvar_define", "data_type": "qubits", "variable": "q", "size": 4},
{"data": "cvar_define", "data_type": "u32", "variable": "c", "size": 4},
Expand Down Expand Up @@ -80,4 +80,12 @@ def test_bv_n10() -> None:
{"mop": "Transport", "duration": [0.0, "ms"]},
],
}
assert actual["ops"] == expected["ops"]

assert actual["ops"][7]["block"] == "qparallel"
for op in expected["ops"][7]["ops"]:
assert op in actual["ops"][7]["ops"]

act_meas_op = actual["ops"][9]
assert act_meas_op["qop"] == "Measure"
assert sorted(act_meas_op["args"]) == expected["ops"][9]["args"]
assert sorted(act_meas_op["returns"]) == expected["ops"][9]["returns"]

0 comments on commit 5cd46aa

Please sign in to comment.