Skip to content

Commit

Permalink
Merge branch 'main' into issue-14-nje-phase-gates
Browse files Browse the repository at this point in the history
  • Loading branch information
nealerickson-qtm committed Oct 30, 2023
2 parents 8630c94 + 97c539b commit 5ed97ad
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: black

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.0
rev: v0.1.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [{name = "Quantinuum"}]

dependencies = ["pytket"]
dependencies = ["phir>=0.1.5", "pytket"]

[project.optional-dependencies]
tests = ["pytest"]
Expand Down
15 changes: 7 additions & 8 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,21 @@ def pytket_to_phir(
circuit = rebase_to_qtm_machine(circuit, qtm_machine.value)
machine = QTM_MACHINES_MAP.get(qtm_machine)
else:
msg = "Machine parameter is currently required"
raise NotImplementedError(msg)
machine = None

logger.debug("Sharding input circuit...")
sharder = Sharder(circuit)
shards = sharder.shard()

logger.debug("Performing placement and routing...")
if machine:
placed = place_and_route(machine, shards)
else:
msg = "no machine found"
raise ValueError(msg)
# Only print message if a machine object is passed
# Otherwise, placment and routing are functionally skipped
# The function is called, but the output is just filled with 0s
logger.debug("Performing placement and routing...")
placed = place_and_route(shards, machine)

phir_json = genphir(placed)

if logger.getEffectiveLevel() <= logging.INFO:
print(PHIRModel.model_validate_json(phir_json, strict=True)) # type: ignore[misc]
print(PHIRModel.model_validate_json(phir_json)) # type: ignore[misc]
return phir_json
20 changes: 8 additions & 12 deletions pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from phir.model import PHIRModel
from pytket.circuit import Command
from pytket.phir.sharding.shard import Shard
from pytket.phir.sharding.shard import Cost, Layer, Ordering


def write_cmd(cmd: Command, ops: list[dict[str, Any]]) -> None:
Expand All @@ -14,13 +14,9 @@ def write_cmd(cmd: Command, ops: list[dict[str, Any]]) -> None:
ops: the list of ops to append to
"""
gate = cmd.op.get_name().split("(", 1)[0]
metadata, angles = (
({"angle_multiplier": "π"}, cmd.op.params)
if gate != "Measure" and cmd.op.params
else (None, None)
)
angles = (cmd.op.params, "pi") if cmd.op.is_gate() and cmd.op.params else None

qop: dict[str, Any] = {
"metadata": metadata,
"angles": angles,
"qop": gate,
"args": [],
Expand All @@ -36,7 +32,7 @@ def write_cmd(cmd: Command, ops: list[dict[str, Any]]) -> None:
ops.extend(({"//": str(cmd)}, qop))


def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:
def genphir(inp: list[tuple[Ordering, Layer, Cost]]) -> str:
"""Convert a list of shards to the equivalent PHIR.
Args:
Expand All @@ -51,8 +47,8 @@ def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:

qbits = set()
cbits = set()
for _orders, shard_layers, layer_costs in inp:
for shard in shard_layers:
for _orders, shard_layer, layer_cost in inp:
for shard in shard_layer:
qbits |= shard.qubits_used
cbits |= shard.bits_read | shard.bits_written
for sub_commands in shard.sub_commands.values():
Expand All @@ -62,7 +58,7 @@ def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:
ops.append(
{
"mop": "Transport",
"metadata": {"duration": layer_costs / 1000000}, # microseconds to secs
"duration": (layer_cost, "ms"),
},
)

Expand Down Expand Up @@ -97,5 +93,5 @@ def genphir(inp: list[tuple[list[int], list[Shard], float]]) -> str:
]

phir["ops"] = decls + ops
PHIRModel.model_validate(phir, strict=True)
PHIRModel.model_validate(phir)
return json.dumps(phir)
52 changes: 32 additions & 20 deletions pytket/phir/place_and_route.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,48 @@
from pytket.phir.machine import Machine
from pytket.phir.placement import optimized_place
from pytket.phir.routing import transport_cost
from pytket.phir.sharding.shard import Shard
from pytket.phir.sharding.shard import Cost, Layer, Ordering, Shard
from pytket.phir.sharding.shards2ops import parse_shards_naive


def place_and_route(
machine: Machine,
shards: list[Shard],
) -> list[tuple[list[int], list[Shard], float]]:
machine: Machine | None = None,
) -> list[tuple[Ordering, Layer, Cost]]:
"""Get all the routing info needed for PHIR generation."""
shard_set = set(shards)
circuit_rep, shard_layers = parse_shards_naive(shard_set)
initial_order = list(range(machine.size))
if machine:
initial_order = list(range(machine.size))
layer_num = 0
orders: list[list[int]] = []
layer_costs: list[float] = []
orders: list[Ordering] = []
layer_costs: list[Cost] = []
net_cost: float = 0.0
for layer in circuit_rep:
order = optimized_place(
layer,
machine.tq_options,
machine.sq_options,
machine.size,
initial_order,
)
orders.append(order)
cost = transport_cost(initial_order, order, machine.qb_swap_time)
layer_num += 1
initial_order = order
layer_costs.append(cost)
net_cost += cost
if machine:
for layer in circuit_rep:
order = optimized_place(
layer,
machine.tq_options,
machine.sq_options,
machine.size,
initial_order,
)
orders.append(order)
cost = transport_cost(initial_order, order, machine.qb_swap_time)
layer_num += 1
initial_order = order
layer_costs.append(cost)
net_cost += cost
else:
# If no machine object specified,
# generic lists of qubits with no placement and no routing costs,
# only the shards

# If needed later, write a helper to find the number
# of qubits needed in the circuit
n = len(circuit_rep)
orders = [[]] * n
layer_costs = [0] * n

# don't need a custom error for this, "strict" parameter will throw error if needed
return list(zip(orders, shard_layers, layer_costs, strict=True))
5 changes: 5 additions & 0 deletions pytket/phir/placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def placement_check(
placement_valid = False
inv = inverse(state)

# If there are no operations to place, it does not matter where the
# qubits are and any placement is valid
if len(ops) == 0:
return True

# assume ops look like this [[1,2],[3],[4],[5,6],[7],[8],[9,10]]
for op in ops:
if len(op) == 2: # tq operation # noqa: PLR2004
Expand Down
6 changes: 6 additions & 0 deletions pytket/phir/sharding/shard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
from dataclasses import dataclass, field
from itertools import count
from typing import TypeAlias

from pytket.circuit import Command
from pytket.unit_id import Bit, Qubit, UnitID
Expand Down Expand Up @@ -57,3 +58,8 @@ def pretty_print(self) -> str:
content = output.getvalue()
output.close()
return content


Cost: TypeAlias = float
Layer: TypeAlias = list[Shard]
Ordering: TypeAlias = list[int]
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
black==23.10.0
build==1.0.3
mypy==1.6.1
phir==0.1.3
phir==0.1.5
pre-commit==3.5.0
pytest==7.4.2
pytket-quantinuum==0.25.0
pytket==1.21.0
ruff==0.1.0
ruff==0.1.1
wheel==0.41.2
11 changes: 8 additions & 3 deletions tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pytket.phir.phirgen import genphir
from pytket.phir.place_and_route import place_and_route
from pytket.phir.placement import placement_check
from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine
from pytket.phir.sharding.sharder import Sharder
from tests.sample_data import QasmFile, get_qasm_as_circuit

Expand All @@ -19,10 +20,14 @@
# force machine options for this test
# machines normally don't like odd numbers of qubits
machine.sq_options = {0, 1, 2}
circuit = get_qasm_as_circuit(QasmFile.eztest)

h11 = QTM_MACHINES_MAP[QtmMachine.H1_1]

circuit = get_qasm_as_circuit(QasmFile.classical_hazards)
sharder = Sharder(circuit)
shards = sharder.shard()
output = place_and_route(machine, shards)

output = place_and_route(shards, h11)
ez_ops_0 = [[0, 2], [1]]
ez_ops_1 = [[0], [2]]
state_0 = output[0][0]
Expand All @@ -42,4 +47,4 @@

phir_json = genphir(output)

print(PHIRModel.model_validate_json(phir_json, strict=True)) # type: ignore[misc]
print(PHIRModel.model_validate_json(phir_json)) # type: ignore[misc]
7 changes: 3 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

from pytket.phir.api import pytket_to_phir
from pytket.phir.qtm_machine import QtmMachine

Expand All @@ -8,12 +6,13 @@

class TestApi:
def test_pytket_to_phir_no_machine(self) -> None:
"""Test case when no machine is present."""
circuit = get_qasm_as_circuit(QasmFile.baby)

with pytest.raises(NotImplementedError):
pytket_to_phir(circuit)
assert pytket_to_phir(circuit)

def test_pytket_to_phir_h1_1(self) -> None:
"""Standard case."""
circuit = get_qasm_as_circuit(QasmFile.baby)

# TODO(neal): Make this test more valuable once PHIR is actually returned
Expand Down

0 comments on commit 5ed97ad

Please sign in to comment.