Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ensure unused classical registers are not omitted #238

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- id: debug-statements

- repo: https://github.com/crate-ci/typos
rev: v1.25.0
rev: v1.26.0
hooks:
- id: typos

Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

* Ensure unused classical registers are not omitted

## [0.8.1] - 2024-09-11

### Fixed
Expand Down
4 changes: 2 additions & 2 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def pytket_to_phir(circuit: "Circuit", qtm_machine: QtmMachine | None = None) ->
placed = place_and_route(shards, machine)
# safety check: never run with parallelization on a 1 qubit circuit
if machine and len(circuit.qubits) > 1:
phir_json = genphir_parallel(placed, machine)
phir_json = genphir_parallel(placed, circuit, machine)
else:
phir_json = genphir(placed, machine_ops=bool(machine))
phir_json = genphir(placed, circuit, machine_ops=bool(machine))
if logger.getEffectiveLevel() <= logging.INFO:
print("PHIR JSON:")
print(PHIRModel.model_validate_json(phir_json))
Expand Down
47 changes: 20 additions & 27 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@
RegWiseOp,
)
from pytket.unit_id import Bit as tkBit
from pytket.unit_id import BitRegister
from pytket.unit_id import BitRegister, QubitRegister

if TYPE_CHECKING:
from collections.abc import Sequence

from pytket.unit_id import Qubit, UnitID
from pytket.circuit import Circuit
from pytket.unit_id import UnitID

from .sharding.shard import Cost, Ordering, ShardLayer

Expand Down Expand Up @@ -546,60 +547,52 @@ def make_comment_text(cmd: tk.Command, op: tk.Op) -> str:
return comment


def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | int]]:
"""Format the qvar and cvar define PHIR elements."""
qvar_dim: dict[str, int] = {}
for qbit in qbits:
qvar_dim.setdefault(qbit.reg_name, 0)
qvar_dim[qbit.reg_name] += 1

cvar_dim: dict[str, int] = {}
for cbit in cbits:
cvar_dim.setdefault(cbit.reg_name, 0)
cvar_dim[cbit.reg_name] += 1

def get_decls(
qregs: list[QubitRegister], cregs: list[BitRegister]
) -> list[dict[str, str | int]]:
"""Get PHIR declarations for qubits and classical variables."""
decls: list[dict[str, str | int]] = [
{
"data": "qvar_define",
"data_type": "qubits",
"variable": qvar,
"size": dim,
"variable": qreg.name,
"size": qreg.size,
}
for qvar, dim in qvar_dim.items()
for qreg in qregs
]

decls += [
{
"data": "cvar_define",
"data_type": f"i{WORDSIZE}",
"variable": cvar,
"size": dim,
"variable": creg.name,
"size": creg.size,
}
for cvar, dim in cvar_dim.items()
if cvar != "_w"
for creg in cregs
if creg.name != "_w"
]

return decls


def genphir(
inp: list[tuple["Ordering", "ShardLayer", "Cost"]], *, machine_ops: bool = True
inp: list[tuple["Ordering", "ShardLayer", "Cost"]],
circuit: "Circuit",
*,
machine_ops: bool = True,
) -> str:
"""Convert a list of shards to the equivalent PHIR.

Args:
inp: list of shards
circuit: corresponding tket Circuit
machine_ops: whether to include machine ops
"""
phir = PHIR_HEADER
ops: list[JsonDict] = []

qbits = set()
cbits = set()
for _orders, shard_layer, layer_cost in inp:
for shard in shard_layer:
qbits |= shard.qubits_used
cbits |= shard.bits_read | shard.bits_written
for sub_commands in shard.sub_commands.values():
for sc in sub_commands:
append_cmd(sc, ops)
Expand All @@ -612,7 +605,7 @@ def genphir(
},
)

decls = get_decls(qbits, cbits)
decls = get_decls(circuit.q_registers, circuit.c_registers)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(See other comment.)


phir["ops"] = decls + ops
PHIRModel.model_validate(phir)
Expand Down
12 changes: 6 additions & 6 deletions pytket/phir/phirgen_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .phirgen import PHIR_HEADER, append_cmd, arg_to_bit, get_decls, tket_gate_to_phir

if TYPE_CHECKING:
from pytket.circuit import Circuit
from pytket.unit_id import UnitID

from .machine import Machine
Expand Down Expand Up @@ -310,12 +311,15 @@ def adjust_phir_transport_time(ops: list["JsonDict"], machine: "Machine") -> Non


def genphir_parallel(
inp: list[tuple["Ordering", "ShardLayer", "Cost"]], machine: "Machine"
inp: list[tuple["Ordering", "ShardLayer", "Cost"]],
circuit: "Circuit",
machine: "Machine",
) -> str:
"""Convert a list of shards to the equivalent PHIR with parallel gating.

Args:
inp: list of shards
circuit: corresponding tket Circuit
machine: a QTM machine on which to simulate the circuit
"""
max_parallel_tq_gates = len(machine.tq_options) // 2
Expand All @@ -325,8 +329,6 @@ def genphir_parallel(
phir["metadata"]["strict_parallelism"] = True
ops: list[JsonDict] = []

qbits = set()
cbits = set()
for _orders, shard_layer, layer_cost in inp:
# within each shard layer, create groups of parallelizable shards
# squash all the sub-commands into the first shard in the group
Expand All @@ -335,8 +337,6 @@ def genphir_parallel(
)
for group in shard_groups.values():
for shard in group:
qbits |= shard.qubits_used
cbits |= shard.bits_read | shard.bits_written
if shard.sub_commands.values():
# sub-commands are always sq gates
subcmd_groups = process_sub_commands(
Expand All @@ -353,7 +353,7 @@ def genphir_parallel(
)
adjust_phir_transport_time(ops, machine)

decls = get_decls(qbits, cbits)
decls = get_decls(circuit.q_registers, circuit.c_registers)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A possible "gotcha" here is that if a pytket circuit contains bits that do not form part of a full register -- for example, if it contains a bit c[1] but not c[0] -- then these bits are ignored by circuit.c_registers (similarly for qubits). I am not sure how phirgen handles this case currently (maybe an error, maybe add missing bits?) -- but if such a circuit were passed to this function it would presumably generate invalid PHIR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @cqc-alec. My guess without working with an example is that as long as pytket commands applying operations are referring to the right bits, we will generate correct PHIR commands for them. The only chance of error here, from what I can gather from your comment, is that the declaration command may mismatch the dimension of a register.

Would you be able to provide an example to help make this change more robust?

Copy link
Collaborator

@cqc-alec cqc-alec Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, how about this:

from pytket.circuit import Circuit, Qubit, Bit
from pytket.phir.api import pytket_to_phir

circ = Circuit()
circ.add_qubit(Qubit(1))
circ.add_bit(Bit(1))
circ.H(1)
circ.Measure(1, 1)
phir = pytket_to_phir(circ)
print(phir)

Currently this produces the following:

{
  "format": "PHIR/JSON",
  "version": "0.1.0",
  "metadata": { "source": "pytket-phir v0.8.1" },
  "ops": [
    {
      "data": "qvar_define",
      "data_type": "qubits",
      "variable": "q",
      "size": 1
    },
    { "data": "cvar_define", "data_type": "i64", "variable": "c", "size": 1 },
    { "//": "H q[1];" },
    { "qop": "H", "angles": null, "args": [["q", 1]] },
    { "//": "Measure q[1] --> c[1];" },
    { "qop": "Measure", "returns": [["c", 1]], "args": [["q", 1]] }
  ]
}

This looks incorrect, since it seems to be indexing q[1] in a size-1 register. I am not sure what pecos will do with this, probably error?

With the change in this PR, I suspect it will be differently wrong, since the qubit and bit registers would not be declared at all?

I think it would be OK to reject circuits with incomplete registers.

Copy link
Member Author

@qartik qartik Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

circ.add_qubit(Qubit(1))
circ.add_bit(Bit(1))

One would think that these operations should not be allowed by pytket as they seem inherently unsafe, and no backend may be able to simulate/run such a circuit.

What's the rationale behind supporting them in pytket?

In general, I agree we should reject this circuit in either of the cases -- with or without this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fundamental entities in pytket are not registers but Qubits and Bits -- and each of these is identified by the combination of a string and an index. There's nothing inherently wrong with this but it is confusing because most other languages and representations have the register as their fundamental entity, and in pytket this is a secondary concept.

As for the original rationale -- I'm afraid I don't know!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I will update the PR to reject such programs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, equally happy for that to be a separate PR.

Copy link
Member Author

@qartik qartik Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me make another PR then. #242


phir["ops"] = decls + ops
PHIRModel.model_validate(phir)
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@
cost_1 = output[1][2]
assert cost_1 == 0.0

phir_json = genphir(output)
phir_json = genphir(output, circuit)

print(PHIRModel.model_validate_json(phir_json)) # type: ignore[misc]
28 changes: 21 additions & 7 deletions tests/test_phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_simple_cond_classical() -> None:

def test_pytket_classical_only() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/61 ."""
c = Circuit(1)
c = Circuit()
a = c.add_c_register("a", 2)
b = c.add_c_register("b", 3)

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_pytket_classical_only() -> None:

def test_classicalexpbox() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/86 ."""
circ = Circuit(1)
circ = Circuit()
a = circ.add_c_register("a", 2)
b = circ.add_c_register("b", 2)
c = circ.add_c_register("c", 3)
Expand All @@ -122,7 +122,7 @@ def test_classicalexpbox() -> None:

def test_nested_arith() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/87 ."""
circ = Circuit(1)
circ = Circuit()
a = circ.add_c_register("a", 2)
b = circ.add_c_register("b", 2)
c = circ.add_c_register("c", 3)
Expand All @@ -138,7 +138,7 @@ def test_nested_arith() -> None:

def test_arith_with_int() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/88 ."""
circ = Circuit(1)
circ = Circuit()
a = circ.add_c_register("a", 2)
circ.add_classicalexpbox_register(a << 1, a.to_list())

Expand All @@ -152,7 +152,7 @@ def test_arith_with_int() -> None:

def test_bitwise_ops() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/91 ."""
circ = Circuit(1)
circ = Circuit()
a = circ.add_c_register("a", 2)
b = circ.add_c_register("b", 2)
c = circ.add_c_register("c", 1)
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_conditional_barrier() -> None:

def test_nested_bitwise_op() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/133 ."""
circ = Circuit(4)
circ = Circuit()
a = circ.add_c_register("a", 4)
b = circ.add_c_register("b", 1)
circ.add_classicalexpbox_bit(a[0] ^ a[1] ^ a[2] ^ a[3], [b[0]])
Expand Down Expand Up @@ -338,7 +338,7 @@ def test_explicit_classical_ops() -> None:
def test_multi_bit_ops() -> None:
"""Test classical ops added to the circuit via tket multi-bit ops."""
# Test from https://github.com/CQCL/tket/blob/a2f6fab8a57da8787dfae94764b7c3a8e5779024/pytket/tests/classical_test.py#L107-L112
c = Circuit(0, 4)
c = Circuit()
c0 = c.add_c_register("c0", 3)
c1 = c.add_c_register("c1", 4)
c2 = c.add_c_register("c2", 5)
Expand Down Expand Up @@ -486,3 +486,17 @@ def test_condition_multiple_bits() -> None:
},
],
}


def test_unused_classical_registers() -> None:
"""From https://github.com/CQCL/pytket-phir/issues/237 ."""
circ = Circuit()
_ = circ.add_c_register("a", 1)
phir = json.loads(pytket_to_phir(circ))

assert phir["ops"][0] == {
"data": "cvar_define",
"data_type": "i64",
"size": 1,
"variable": "a",
}
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_phir_json(qasmfile: QasmFile, *, rebase: bool) -> "JsonDict":
assert machine
shards = Sharder(circuit).shard()
placed = place_and_route(shards, machine)
return json.loads(genphir_parallel(placed, machine)) # type: ignore[misc, no-any-return]
return json.loads(genphir_parallel(placed, circuit, machine)) # type: ignore[misc, no-any-return]


def get_wat_as_wasm_bytes(wat_file: WatFile) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_wasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_pytket_with_wasm() -> None:

w = WasmFileHandler(wasm_file.name)

c = Circuit(6, 6)
c = Circuit()
c0 = c.add_c_register("c0", 3)
c1 = c.add_c_register("c1", 4)
c2 = c.add_c_register("c2", 5)
Expand Down