Skip to content

Commit

Permalink
fix: Ensure unused classical registers are not omitted (#238)
Browse files Browse the repository at this point in the history
* test: add test for issue #237

* fix: Ensure unused classical registers are not omitted

Also refactor get_decls to use the info from pytket Circuit

* chore: update typos
  • Loading branch information
qartik authored Oct 8, 2024
1 parent a4fdc54 commit 13246af
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- id: debug-statements

- repo: https://github.com/crate-ci/typos
rev: v1.25.0
rev: v1.26.0
hooks:
- id: typos

Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

* Ensure unused classical registers are not omitted

## [0.8.1] - 2024-09-11

### Fixed
Expand Down
4 changes: 2 additions & 2 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def pytket_to_phir(circuit: "Circuit", qtm_machine: QtmMachine | None = None) ->
placed = place_and_route(shards, machine)
# safety check: never run with parallelization on a 1 qubit circuit
if machine and len(circuit.qubits) > 1:
phir_json = genphir_parallel(placed, machine)
phir_json = genphir_parallel(placed, circuit, machine)
else:
phir_json = genphir(placed, machine_ops=bool(machine))
phir_json = genphir(placed, circuit, machine_ops=bool(machine))
if logger.getEffectiveLevel() <= logging.INFO:
print("PHIR JSON:")
print(PHIRModel.model_validate_json(phir_json))
Expand Down
47 changes: 20 additions & 27 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@
RegWiseOp,
)
from pytket.unit_id import Bit as tkBit
from pytket.unit_id import BitRegister
from pytket.unit_id import BitRegister, QubitRegister

if TYPE_CHECKING:
from collections.abc import Sequence

from pytket.unit_id import Qubit, UnitID
from pytket.circuit import Circuit
from pytket.unit_id import UnitID

from .sharding.shard import Cost, Ordering, ShardLayer

Expand Down Expand Up @@ -546,60 +547,52 @@ def make_comment_text(cmd: tk.Command, op: tk.Op) -> str:
return comment


def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | int]]:
"""Format the qvar and cvar define PHIR elements."""
qvar_dim: dict[str, int] = {}
for qbit in qbits:
qvar_dim.setdefault(qbit.reg_name, 0)
qvar_dim[qbit.reg_name] += 1

cvar_dim: dict[str, int] = {}
for cbit in cbits:
cvar_dim.setdefault(cbit.reg_name, 0)
cvar_dim[cbit.reg_name] += 1

def get_decls(
qregs: list[QubitRegister], cregs: list[BitRegister]
) -> list[dict[str, str | int]]:
"""Get PHIR declarations for qubits and classical variables."""
decls: list[dict[str, str | int]] = [
{
"data": "qvar_define",
"data_type": "qubits",
"variable": qvar,
"size": dim,
"variable": qreg.name,
"size": qreg.size,
}
for qvar, dim in qvar_dim.items()
for qreg in qregs
]

decls += [
{
"data": "cvar_define",
"data_type": f"i{WORDSIZE}",
"variable": cvar,
"size": dim,
"variable": creg.name,
"size": creg.size,
}
for cvar, dim in cvar_dim.items()
if cvar != "_w"
for creg in cregs
if creg.name != "_w"
]

return decls


def genphir(
inp: list[tuple["Ordering", "ShardLayer", "Cost"]], *, machine_ops: bool = True
inp: list[tuple["Ordering", "ShardLayer", "Cost"]],
circuit: "Circuit",
*,
machine_ops: bool = True,
) -> str:
"""Convert a list of shards to the equivalent PHIR.
Args:
inp: list of shards
circuit: corresponding tket Circuit
machine_ops: whether to include machine ops
"""
phir = PHIR_HEADER
ops: list[JsonDict] = []

qbits = set()
cbits = set()
for _orders, shard_layer, layer_cost in inp:
for shard in shard_layer:
qbits |= shard.qubits_used
cbits |= shard.bits_read | shard.bits_written
for sub_commands in shard.sub_commands.values():
for sc in sub_commands:
append_cmd(sc, ops)
Expand All @@ -612,7 +605,7 @@ def genphir(
},
)

decls = get_decls(qbits, cbits)
decls = get_decls(circuit.q_registers, circuit.c_registers)

phir["ops"] = decls + ops
PHIRModel.model_validate(phir)
Expand Down
12 changes: 6 additions & 6 deletions pytket/phir/phirgen_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .phirgen import PHIR_HEADER, append_cmd, arg_to_bit, get_decls, tket_gate_to_phir

if TYPE_CHECKING:
from pytket.circuit import Circuit
from pytket.unit_id import UnitID

from .machine import Machine
Expand Down Expand Up @@ -310,12 +311,15 @@ def adjust_phir_transport_time(ops: list["JsonDict"], machine: "Machine") -> Non


def genphir_parallel(
inp: list[tuple["Ordering", "ShardLayer", "Cost"]], machine: "Machine"
inp: list[tuple["Ordering", "ShardLayer", "Cost"]],
circuit: "Circuit",
machine: "Machine",
) -> str:
"""Convert a list of shards to the equivalent PHIR with parallel gating.
Args:
inp: list of shards
circuit: corresponding tket Circuit
machine: a QTM machine on which to simulate the circuit
"""
max_parallel_tq_gates = len(machine.tq_options) // 2
Expand All @@ -325,8 +329,6 @@ def genphir_parallel(
phir["metadata"]["strict_parallelism"] = True
ops: list[JsonDict] = []

qbits = set()
cbits = set()
for _orders, shard_layer, layer_cost in inp:
# within each shard layer, create groups of parallelizable shards
# squash all the sub-commands into the first shard in the group
Expand All @@ -335,8 +337,6 @@ def genphir_parallel(
)
for group in shard_groups.values():
for shard in group:
qbits |= shard.qubits_used
cbits |= shard.bits_read | shard.bits_written
if shard.sub_commands.values():
# sub-commands are always sq gates
subcmd_groups = process_sub_commands(
Expand All @@ -353,7 +353,7 @@ def genphir_parallel(
)
adjust_phir_transport_time(ops, machine)

decls = get_decls(qbits, cbits)
decls = get_decls(circuit.q_registers, circuit.c_registers)

phir["ops"] = decls + ops
PHIRModel.model_validate(phir)
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@
cost_1 = output[1][2]
assert cost_1 == 0.0

phir_json = genphir(output)
phir_json = genphir(output, circuit)

print(PHIRModel.model_validate_json(phir_json)) # type: ignore[misc]
28 changes: 21 additions & 7 deletions tests/test_phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_simple_cond_classical() -> None:

def test_pytket_classical_only() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/61 ."""
c = Circuit(1)
c = Circuit()
a = c.add_c_register("a", 2)
b = c.add_c_register("b", 3)

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_pytket_classical_only() -> None:

def test_classicalexpbox() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/86 ."""
circ = Circuit(1)
circ = Circuit()
a = circ.add_c_register("a", 2)
b = circ.add_c_register("b", 2)
c = circ.add_c_register("c", 3)
Expand All @@ -122,7 +122,7 @@ def test_classicalexpbox() -> None:

def test_nested_arith() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/87 ."""
circ = Circuit(1)
circ = Circuit()
a = circ.add_c_register("a", 2)
b = circ.add_c_register("b", 2)
c = circ.add_c_register("c", 3)
Expand All @@ -138,7 +138,7 @@ def test_nested_arith() -> None:

def test_arith_with_int() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/88 ."""
circ = Circuit(1)
circ = Circuit()
a = circ.add_c_register("a", 2)
circ.add_classicalexpbox_register(a << 1, a.to_list())

Expand All @@ -152,7 +152,7 @@ def test_arith_with_int() -> None:

def test_bitwise_ops() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/91 ."""
circ = Circuit(1)
circ = Circuit()
a = circ.add_c_register("a", 2)
b = circ.add_c_register("b", 2)
c = circ.add_c_register("c", 1)
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_conditional_barrier() -> None:

def test_nested_bitwise_op() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/133 ."""
circ = Circuit(4)
circ = Circuit()
a = circ.add_c_register("a", 4)
b = circ.add_c_register("b", 1)
circ.add_classicalexpbox_bit(a[0] ^ a[1] ^ a[2] ^ a[3], [b[0]])
Expand Down Expand Up @@ -338,7 +338,7 @@ def test_explicit_classical_ops() -> None:
def test_multi_bit_ops() -> None:
"""Test classical ops added to the circuit via tket multi-bit ops."""
# Test from https://github.com/CQCL/tket/blob/a2f6fab8a57da8787dfae94764b7c3a8e5779024/pytket/tests/classical_test.py#L107-L112
c = Circuit(0, 4)
c = Circuit()
c0 = c.add_c_register("c0", 3)
c1 = c.add_c_register("c1", 4)
c2 = c.add_c_register("c2", 5)
Expand Down Expand Up @@ -486,3 +486,17 @@ def test_condition_multiple_bits() -> None:
},
],
}


def test_unused_classical_registers() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/237 ."""
circ = Circuit()
_ = circ.add_c_register("a", 1)
phir = json.loads(pytket_to_phir(circ))

assert phir["ops"][0] == {
"data": "cvar_define",
"data_type": "i64",
"size": 1,
"variable": "a",
}
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_phir_json(qasmfile: QasmFile, *, rebase: bool) -> "JsonDict":
assert machine
shards = Sharder(circuit).shard()
placed = place_and_route(shards, machine)
return json.loads(genphir_parallel(placed, machine)) # type: ignore[misc, no-any-return]
return json.loads(genphir_parallel(placed, circuit, machine)) # type: ignore[misc, no-any-return]


def get_wat_as_wasm_bytes(wat_file: WatFile) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_wasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_pytket_with_wasm() -> None:

w = WasmFileHandler(wasm_file.name)

c = Circuit(6, 6)
c = Circuit()
c0 = c.add_c_register("c0", 3)
c1 = c.add_c_register("c1", 4)
c2 = c.add_c_register("c2", 5)
Expand Down

0 comments on commit 13246af

Please sign in to comment.