Skip to content

Commit

Permalink
Merge branch 'ordering' of https://github.com/CQCL/pytket-phir into o…
Browse files Browse the repository at this point in the history
…rdering
  • Loading branch information
Asa-Kosto-QTM committed Jan 26, 2024
2 parents 365338e + a3b9f58 commit c6bda61
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 81 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Install additional dependencies needed for the CLI using `pip install pytket-phi

```sh
❯ phirc -h
usage: phirc [-h] [-w WASM_FILE] [-m {H1-1,H1-2}] [-o {0,1,2}] [-v] [--version] qasm_files [qasm_files ...]
usage: phirc [-h] [-w WASM_FILE] [-m {H1-1,H1-2}] [-v] [--version] qasm_files [qasm_files ...]

Emulates QASM program execution via PECOS

Expand All @@ -39,8 +39,6 @@ options:
Optional WASM file for use by the QASM programs
-m {H1-1,H1-2}, --machine {H1-1,H1-2}
Machine name, H1-1 by default
-o {0,1,2}, --tket-opt-level {0,1,2}
TKET optimization level, 0 by default
-v, --verbose
--version show program's version number and exit
```
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ classifiers = [
dynamic = ["version"]
dependencies = [
"phir>=0.2.1",
"pytket-quantinuum>=0.25.0",
"pytket>=1.21.0",
"wasmtime>=15.0.0",
]
Expand Down
18 changes: 4 additions & 14 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,15 @@

logger = logging.getLogger(__name__)

DEFAULT_TKET_OPT_LEVEL = 0


def pytket_to_phir(
circuit: "Circuit",
qtm_machine: QtmMachine | None = None,
tket_optimization_level: int = DEFAULT_TKET_OPT_LEVEL,
) -> str:
def pytket_to_phir(circuit: "Circuit", qtm_machine: QtmMachine | None = None) -> str:
"""Converts a pytket circuit into its PHIR representation.
This can optionally include rebasing against a Quantinuum machine architecture,
and control of the TKET optimization level.
:param circuit: Circuit object to be converted
:param qtm_machine: (Optional) Quantinuum machine architecture to rebase against
:param tket_optimization_level: (Default=0) TKET circuit optimization level
Returns:
PHIR JSON as a str
Expand All @@ -56,9 +49,7 @@ def pytket_to_phir(
machine: Machine | None = None
if qtm_machine:
logger.info("Rebasing to machine %s", qtm_machine)
circuit = rebase_to_qtm_machine(
circuit, qtm_machine.value, tket_optimization_level
)
circuit = rebase_to_qtm_machine(circuit, qtm_machine)
machine = QTM_MACHINES_MAP.get(qtm_machine)
else:
machine = None
Expand All @@ -78,14 +69,14 @@ def pytket_to_phir(
else:
phir_json = genphir(placed, machine_ops=bool(machine))
if logger.getEffectiveLevel() <= logging.INFO:
print("PHIR JSON:")
print(PHIRModel.model_validate_json(phir_json))
return phir_json


def qasm_to_phir(
qasm: str,
qtm_machine: QtmMachine | None = None,
tket_optimization_level: int = DEFAULT_TKET_OPT_LEVEL,
wasm_bytes: bytes | None = None,
) -> str:
"""Converts a QASM circuit string into its PHIR representation.
Expand All @@ -95,7 +86,6 @@ def qasm_to_phir(
:param qasm: QASM input to be converted
:param qtm_machine: (Optional) Quantinuum machine architecture to rebase against
:param tket_optimization_level: (Default=0) TKET circuit optimization level
:param wasm_bytes: (Optional) WASM as bytes to include as part of circuit
"""
circuit: Circuit
Expand All @@ -116,4 +106,4 @@ def qasm_to_phir(
Path.unlink(Path(wasm_file.name))
else:
circuit = circuit_from_qasm_str(qasm)
return pytket_to_phir(circuit, qtm_machine, tket_optimization_level)
return pytket_to_phir(circuit, qtm_machine)
14 changes: 3 additions & 11 deletions pytket/phir/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
# mypy: disable-error-code="misc"
# ruff: noqa: T201

import logging
from argparse import ArgumentParser
from importlib.metadata import version

from pecos.engines.hybrid_engine import HybridEngine # type:ignore [import-not-found]
from pecos.foreign_objects.wasmtime import WasmtimeObj # type:ignore [import-not-found]
from rich import print

from pytket.qasm.qasm import (
circuit_from_qasm,
Expand Down Expand Up @@ -47,13 +47,6 @@ def main() -> None:
default="H1-1",
help="Machine name, H1-1 by default",
)
parser.add_argument(
"-o",
"--tket-opt-level",
choices=["0", "1", "2"],
default="0",
help="TKET optimization level, 0 by default",
)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
"--version",
Expand All @@ -78,10 +71,9 @@ def main() -> None:
case "H1-2":
machine = QtmMachine.H1_2

phir = pytket_to_phir(circuit, machine, int(args.tket_opt_level))
if args.verbose:
print("\nPHIR to be simulated:")
print(phir)
logging.basicConfig(level=logging.INFO)
phir = pytket_to_phir(circuit, machine)

print("\nPECOS results:")
print(
Expand Down
37 changes: 28 additions & 9 deletions pytket/phir/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,52 @@
#
##############################################################################

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pytket.circuit import OpType


@dataclass
class MachineTimings:
"""Gate times for a machine.
tq_time: time for a two qubit gate
sq_time: time for a single qubit gate
qb_swap_time: time it takes to swap to qubits
"""

tq_time: float
sq_time: float
qb_swap_time: float


class Machine:
"""A machine info class for testing."""

def __init__(
self,
size: int,
gateset: "set[OpType]",
tq_options: set[int],
tq_time: float,
sq_time: float,
qb_swap_time: float,
timings: MachineTimings,
):
"""Create Machine object.
Args:
size: number of qubits/slots
gateset: set of supported gates
tq_options: options for where to perform tq gates
tq_time: time for a two qubit gate
sq_time: time for a single qubit gate
qb_swap_time: time it takes to swap to qubits
timings: gate times
"""
self.size = size
self.gateset = gateset
self.tq_options = tq_options
self.sq_options: set[int] = set()
self.tq_time = tq_time
self.sq_time = sq_time
self.qb_swap_time = qb_swap_time
self.tq_time = timings.tq_time
self.sq_time = timings.sq_time
self.qb_swap_time = timings.qb_swap_time

for i in self.tq_options:
self.sq_options.add(i)
Expand Down
16 changes: 9 additions & 7 deletions pytket/phir/qtm_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

from enum import Enum

from .machine import Machine
from pytket.circuit import OpType

from .machine import Machine, MachineTimings


class QtmMachine(Enum):
Expand All @@ -18,23 +20,23 @@ class QtmMachine(Enum):
H1_2 = "H1-2"


QTM_DEFAULT_GATESET = {OpType.Rz, OpType.PhasedX, OpType.ZZPhase}

QTM_MACHINES_MAP = {
QtmMachine.H1_1: Machine(
size=20,
gateset=QTM_DEFAULT_GATESET,
tq_options={0, 2, 4, 6, 8, 10, 12, 14, 16, 18},
# need to get better timing values for below
# but will have to look them up in hqcompiler
tq_time=3.0,
sq_time=1.0,
qb_swap_time=2.0,
timings=MachineTimings(tq_time=3.0, sq_time=1.0, qb_swap_time=2.0),
),
QtmMachine.H1_2: Machine(
size=12,
gateset=QTM_DEFAULT_GATESET,
tq_options={0, 2, 4, 6, 8, 10},
# need to get better timing values for below
# but will have to look them up in hqcompiler
tq_time=3.0,
sq_time=1.0,
qb_swap_time=2.0,
timings=MachineTimings(tq_time=3.0, sq_time=1.0, qb_swap_time=2.0),
),
}
29 changes: 9 additions & 20 deletions pytket/phir/rebasing/rebaser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,19 @@

from typing import TYPE_CHECKING

from pytket.extensions.quantinuum.backends.api_wrappers import QuantinuumAPIOffline
from pytket.extensions.quantinuum.backends.quantinuum import (
QuantinuumBackend,
)
from pytket.passes import DecomposeBoxes
from pytket.passes.auto_rebase import auto_rebase_pass
from pytket.phir.qtm_machine import QTM_DEFAULT_GATESET, QTM_MACHINES_MAP, QtmMachine

if TYPE_CHECKING:
from pytket.circuit import Circuit


def rebase_to_qtm_machine(
circuit: "Circuit", qtm_machine: str, tket_optimization_level: int
) -> "Circuit":
def rebase_to_qtm_machine(circuit: "Circuit", qtm_machine: QtmMachine) -> "Circuit":
"""Rebases a circuit's gate to the gate set appropriate for the given machine."""
qapi_offline = QuantinuumAPIOffline()
backend = QuantinuumBackend(
device_name=qtm_machine,
machine_debug=False,
api_handler=qapi_offline, # type: ignore [arg-type]
)

# Decompose boxes to ensure no problematic phase gates
DecomposeBoxes().apply(circuit)

# Optimization level 0 includes rebasing and little else
# see: https://cqcl.github.io/pytket-quantinuum/api/#default-compilation
return backend.get_compiled_circuit(circuit, tket_optimization_level)
machine = QTM_MACHINES_MAP.get(qtm_machine)
gateset = QTM_DEFAULT_GATESET if machine is None else machine.gateset
c = circuit.copy()
DecomposeBoxes().apply(c)
auto_rebase_pass(gateset, allow_swaps=True).apply(c)
return c
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ phir==0.2.1
pre-commit==3.6.0
pydata_sphinx_theme==0.15.2
pytest==7.4.4
pytket-quantinuum==0.27.0
pytket==1.24.0
ruff==0.1.14
setuptools_scm==8.0.4
Expand Down
12 changes: 3 additions & 9 deletions tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,16 @@
from phir.model import PHIRModel
from rich import print

from pytket.phir.machine import Machine
from pytket.phir.machine import Machine, MachineTimings
from pytket.phir.phirgen import genphir
from pytket.phir.place_and_route import place_and_route
from pytket.phir.placement import placement_check
from pytket.phir.qtm_machine import QTM_MACHINES_MAP, QtmMachine
from pytket.phir.qtm_machine import QTM_DEFAULT_GATESET, QTM_MACHINES_MAP, QtmMachine
from pytket.phir.sharding.sharder import Sharder
from tests.test_utils import QasmFile, get_qasm_as_circuit

if __name__ == "__main__":
machine = Machine(
3,
{1},
3.0,
1.0,
2.0,
)
machine = Machine(3, QTM_DEFAULT_GATESET, {1}, MachineTimings(3.0, 1.0, 2.0))
# force machine options for this test
# machines normally don't like odd numbers of qubits
machine.sq_options = {0, 1, 2}
Expand Down
9 changes: 5 additions & 4 deletions tests/test_placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@

import pytest

from pytket.phir.machine import Machine
from pytket.phir.machine import Machine, MachineTimings
from pytket.phir.placement import (
GateOpportunitiesError,
InvalidParallelOpsError,
place,
placement_check,
)
from pytket.phir.qtm_machine import QTM_DEFAULT_GATESET

m = Machine(4, {1}, 10, 2, 2)
m2 = Machine(6, {1, 3}, 10, 2, 2)
m3 = Machine(8, {0, 6}, 10, 2, 2)
m = Machine(4, QTM_DEFAULT_GATESET, {1}, MachineTimings(10, 2, 2))
m2 = Machine(6, QTM_DEFAULT_GATESET, {1, 3}, MachineTimings(10, 2, 2))
m3 = Machine(8, QTM_DEFAULT_GATESET, {0, 6}, MachineTimings(10, 2, 2))


def test_placement_check() -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_rebaser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging

from pytket.circuit import Circuit, OpType
from pytket.phir.qtm_machine import QtmMachine
from pytket.phir.rebasing.rebaser import rebase_to_qtm_machine

from .test_utils import QasmFile, get_qasm_as_circuit
Expand All @@ -27,7 +28,7 @@
class TestRebaser:
def test_rebaser_happy_path_arc1a(self) -> None:
circ = get_qasm_as_circuit(QasmFile.baby)
rebased: Circuit = rebase_to_qtm_machine(circ, "H1-1", 0)
rebased: Circuit = rebase_to_qtm_machine(circ, QtmMachine.H1_1)

logger.info(rebased)
for command in rebased.get_commands():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_phir_json(qasmfile: QasmFile, *, rebase: bool) -> "JsonDict":
qtm_machine = QtmMachine.H1_1
circuit = get_qasm_as_circuit(qasmfile)
if rebase:
circuit = rebase_to_qtm_machine(circuit, qtm_machine.value, 0)
circuit = rebase_to_qtm_machine(circuit, qtm_machine)
machine = QTM_MACHINES_MAP.get(qtm_machine)
assert machine
shards = Sharder(circuit).shard()
Expand Down

0 comments on commit c6bda61

Please sign in to comment.