Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding WASM support #77

Merged
merged 31 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4c4418d
WASM support
nealerickson-qtm Jan 2, 2024
4fe3a7c
Merge branch 'main' into issue-50-wasm-support
qartik Jan 3, 2024
fa7a818
feedback
nealerickson-qtm Jan 3, 2024
251d32f
Merge branch 'main' into issue-50-wasm-support
qartik Jan 3, 2024
e3af6c7
Update pytket/phir/phirgen.py
neal-erickson Jan 4, 2024
12647cf
cleanup
nealerickson-qtm Jan 4, 2024
e7ca093
Merge branch 'main' into issue-50-wasm-support
qartik Jan 8, 2024
f5009f7
Merge branch 'main' into issue-50-wasm-support
qartik Jan 10, 2024
2874c39
Merge branch 'main' into issue-50-wasm-support
qartik Jan 13, 2024
3b3a01d
Merge branch 'main' into issue-50-wasm-support
qartik Jan 15, 2024
10b482c
removing direct wasm usage, adding wasmtime, improving testing
nealerickson-qtm Jan 17, 2024
c834934
Merge branch 'main' into issue-50-wasm-support
qartik Jan 17, 2024
fd76eae
cleanup
nealerickson-qtm Jan 17, 2024
8d84539
removing large file exemption
nealerickson-qtm Jan 17, 2024
6906eeb
Merge branch 'wasm-support' into issue-50-wasm-support
nealerickson-qtm Jan 17, 2024
1d2f0ff
style: ignore the whole file vs per line for misc
qartik Jan 18, 2024
749ad20
build(mypy): include wasmtime in pre-commit mypy
qartik Jan 18, 2024
9265dee
style: remove unneeded import
qartik Jan 18, 2024
69d9d09
style(mypy): remove unneeded cast
qartik Jan 18, 2024
47c4699
feedback
nealerickson-qtm Jan 18, 2024
1734e43
making the temp files work on windows
nealerickson-qtm Jan 18, 2024
a0347ff
Merge branch 'main' into issue-50-wasm-support
nealerickson-qtm Jan 21, 2024
73f80ef
fixing test
nealerickson-qtm Jan 22, 2024
b3f71f9
Merge branch 'main' into issue-50-wasm-support
qartik Jan 22, 2024
f2e72af
Fixing windows on test
nealerickson-qtm Jan 22, 2024
f3e8b62
style: move misc ignore to whole file
qartik Jan 22, 2024
e42c997
test(wasm): avoid magic hash for testing WASM module uid
qartik Jan 22, 2024
1f401de
docs: update README, add dev-all target for phirc CLI deps
qartik Jan 22, 2024
74227d3
fix(cli): use -v for verbose and simplify
qartik Jan 23, 2024
6849cbf
Using PECOS wasm correctly
nealerickson-qtm Jan 23, 2024
f8dd59a
style: avoid unneeded import
qartik Jan 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ repos:
pytket-quantinuum,
pytket,
types-setuptools,
wasmtime,
]
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ options:
-h, --help show this help message and exit
-m {H1-1,H1-2}, --machine {H1-1,H1-2}
machine name, H1-1 by default
-o, --tket-opt-level select TKET optimization level (0 to 2, default: 0)
-w, --wasm-file path to an optional WASM file to include
-v, --version show program's version number and exit

```

## Development
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"phir>=0.2.1",
"pytket-quantinuum>=0.25.0",
"pytket>=1.21.0",
"wasmtime>=15.0.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -55,8 +56,9 @@ pythonpath = [
]
log_cli = true
log_cli_level = "INFO"
log_level = "DEBUG"
filterwarnings = ["ignore:::lark.s*"]
log_format = "%(asctime)s.%(msecs)03d %(levelname)s %(message)s"
log_format = "%(asctime)s.%(msecs)03d %(levelname)s %(name)s:%(lineno)s %(message)s"
log_date_format = "%Y-%m-%d %H:%M:%S"

[tool.setuptools_scm]
Expand Down
25 changes: 23 additions & 2 deletions pytket/phir/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
##############################################################################

import logging
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING

from rich import print

from phir.model import PHIRModel
from pytket.qasm.qasm import circuit_from_qasm_str
from pytket.qasm.qasm import circuit_from_qasm_str, circuit_from_qasm_wasm

from .phirgen import genphir
from .phirgen_parallel import genphir_parallel
Expand Down Expand Up @@ -82,6 +84,7 @@ def qasm_to_phir(
qasm: str,
qtm_machine: QtmMachine | None = None,
tket_optimization_level: int = DEFAULT_TKET_OPT_LEVEL,
wasm_bytes: bytes | None = None,
) -> str:
"""Converts a QASM circuit string into its PHIR representation.

Expand All @@ -91,6 +94,24 @@ def qasm_to_phir(
:param circuit: Circuit object to be converted
:param qtm_machine: (Optional) Quantinuum machine architecture to rebase against
:param tket_optimization_level: (Default=0) TKET circuit optimization level
:param wasm_bytes (Optional) WASM as bytes to include as part of circuit
"""
circuit = circuit_from_qasm_str(qasm)
circuit: Circuit
if wasm_bytes:
qciaran marked this conversation as resolved.
Show resolved Hide resolved
try:
qasm_file = NamedTemporaryFile(suffix=".qasm", delete=False)
wasm_file = NamedTemporaryFile(suffix=".wasm", delete=False)
qasm_file.write(qasm.encode())
qasm_file.flush()
qasm_file.close()
wasm_file.write(wasm_bytes)
wasm_file.flush()
wasm_file.close()

circuit = circuit_from_qasm_wasm(qasm_file.name, wasm_file.name)
finally:
Path.unlink(Path(qasm_file.name)) # type: ignore[misc]
Path.unlink(Path(wasm_file.name)) # type: ignore[misc]
else:
circuit = circuit_from_qasm_str(qasm)
return pytket_to_phir(circuit, qtm_machine, tket_optimization_level)
38 changes: 24 additions & 14 deletions pytket/phir/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,20 @@
##############################################################################

# mypy: disable-error-code="misc"
neal-erickson marked this conversation as resolved.
Show resolved Hide resolved
# ruff: noqa: T201

from argparse import ArgumentParser
from argparse import ArgumentParser, BooleanOptionalAction
from importlib.metadata import version

from pecos.engines.hybrid_engine import HybridEngine # type:ignore [import-not-found]

from phir.model import PHIRModel
from pytket.qasm.qasm import (
circuit_from_qasm,
circuit_from_qasm_str,
circuit_to_qasm_str,
circuit_from_qasm_wasm,
)

from .api import pytket_to_phir
from .qtm_machine import QtmMachine
from .rebasing.rebaser import rebase_to_qtm_machine


def main() -> None:
Expand All @@ -34,6 +32,12 @@ def main() -> None:
parser.add_argument(
"qasm_files", nargs="+", default=None, help="One or more QASM files to emulate"
)
parser.add_argument(
"-w",
"--wasm-file",
default=None,
help="Optional WASM file for use by the QASM programs",
)
parser.add_argument(
"-m",
"--machine",
Expand All @@ -48,6 +52,7 @@ def main() -> None:
default="0",
help="TKET optimization level, 0 by default",
)
parser.add_argument("--verbose", action=BooleanOptionalAction)
parser.add_argument(
"-v",
"--version",
Expand All @@ -57,19 +62,24 @@ def main() -> None:
args = parser.parse_args()

for file in args.qasm_files:
print(f"Processing {file}") # noqa: T201
c = circuit_from_qasm(file)
tket_opt_level = int(args.tk)
rc = rebase_to_qtm_machine(c, args.machine, tket_opt_level)
qartik marked this conversation as resolved.
Show resolved Hide resolved
qasm = circuit_to_qasm_str(rc, header="hqslib1")
circ = circuit_from_qasm_str(qasm)
print(f"Processing {file}")
circuit = None
if args.wasm_file:
print(f"Including WASM from file {args.wasm_file}")
circuit = circuit_from_qasm_wasm(file, args.wasm_file)
else:
circuit = circuit_from_qasm(file)

match args.machine:
case "H1-1":
machine = QtmMachine.H1_1
case "H1-2":
machine = QtmMachine.H1_2
phir = pytket_to_phir(circ, machine)
PHIRModel.model_validate_json(phir)

HybridEngine(qsim="state-vector").run(program=phir, shots=10)
phir = pytket_to_phir(circuit, machine, int(args.tket_opt_level))
if args.verbose:
print("\nPHIR to be simulated:")
print(phir)

print("\nPECOS results:")
print(HybridEngine(qsim="state-vector").run(program=phir, shots=10))
57 changes: 56 additions & 1 deletion pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None:
[arg_to_bit(cmd.args[i]) for i in range(len(cmd.args) // 2)],
)

case tk.WASMOp():
return create_wasm_op(cmd, op)

case _:
# TODO(kartik): NYI
# https://github.com/CQCL/pytket-phir/issues/25
Expand All @@ -267,12 +270,63 @@ def append_cmd(cmd: tk.Command, ops: list[JsonDict]) -> None:
cmd: pytket command obtained from pytket-phir
ops: the list of ops to append to
"""
ops.append({"//": str(cmd)})
ops.append({"//": make_comment_text(cmd, cmd.op)})
op: JsonDict | None = convert_subcmd(cmd.op, cmd)
if op:
ops.append(op)


def create_wasm_op(cmd: tk.Command, wasm_op: tk.WASMOp) -> JsonDict:
"""Creates a PHIR operation for a WASM command."""
args, returns = extract_wasm_args_and_returns(cmd, wasm_op)
op = {
"cop": "ffcall",
"function": wasm_op.func_name,
"args": args,
"metadata": {
"ff_object": f"WASM module uid: {wasm_op.wasm_uid}",
},
}
if cmd.bits:
op["returns"] = returns

return op


def extract_wasm_args_and_returns(
command: tk.Command, op: tk.WASMOp
) -> tuple[list[str], list[str]]:
"""Extract the wasm args and return values as whole register names."""
# This slice removes the extra `_w` cregs (wires) that are not part of the
# circuit, and the output args which are appended after the input args
slice_index = op.num_w + sum(op.output_widths)
only_args = command.args[:-slice_index]
return (
dedupe_bits_to_registers(only_args),
dedupe_bits_to_registers(command.bits),
)


def dedupe_bits_to_registers(bits: "Sequence[UnitID]") -> list[str]:
"""Dedupes a list of bits to their registers, keeping order intact."""
return list(dict.fromkeys([bit.reg_name for bit in bits]))


def make_comment_text(command: tk.Command, op: tk.Op) -> str:
"""Converts a command + op to the PHIR comment spec."""
match op:
case tk.Conditional():
conditional_text = str(command)
cleaned = conditional_text[: conditional_text.find("THEN") + 4]
return f"{cleaned} {make_comment_text(command, op.op)}"

case tk.WASMOp():
args, returns = extract_wasm_args_and_returns(command, op)
return f"WASM function={op.func_name} args={args} returns={returns}"
case _:
return str(command)


def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | int]]:
"""Format the qvar and cvar define PHIR elements."""
# TODO(kartik): this may not always be accurate
Expand Down Expand Up @@ -305,6 +359,7 @@ def get_decls(qbits: set["Qubit"], cbits: set[tkBit]) -> list[dict[str, str | in
"size": dim,
}
for cvar, dim in cvar_dim.items()
if cvar != "_w"
]

return decls
Expand Down
17 changes: 13 additions & 4 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .shard import Shard

NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox, OpType.WASM]
NOT_IMPLEMENTED_OP_TYPES = [OpType.CircBox]

SHARD_TRIGGER_OP_TYPES = [
OpType.Measure,
Expand All @@ -25,6 +25,7 @@
OpType.RangePredicate,
OpType.ExplicitPredicate,
OpType.CopyBits,
OpType.WASM,
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -98,7 +99,6 @@ def _process_command(self, command: Command) -> None:
return

if self.should_op_create_shard(command.op):
logger.debug("Building shard for command: %s", command)
self._build_shard(command)
else:
self._add_pending_sub_command(command)
Expand All @@ -112,6 +112,7 @@ def _build_shard(self, command: Command) -> None:
Args:
command: tket command (operation, bits, etc)
"""
logger.debug("Building shard for command: %s", command)
# Rollup any sub commands (SQ gates) that interact with the same qubits
sub_commands: dict[UnitID, list[Command]] = {}
for key in (
Expand All @@ -123,6 +124,7 @@ def _build_shard(self, command: Command) -> None:
for sub_command_list in sub_commands.values():
all_commands.extend(sub_command_list)

logger.debug("All shard commands: %s", all_commands)
qubits_used = set(command.qubits)
bits_written = set(command.bits)
bits_read: set[Bit] = set()
Expand Down Expand Up @@ -185,7 +187,9 @@ def _resolve_shard_dependencies(

for bit_read in bits_read:
if bit_read in self._bit_written_by:
logger.debug("...adding shard dep %s -> RAW")
logger.debug(
"...adding shard dep %s -> RAW", self._bit_written_by[bit_read]
)
depends_upon.add(self._bit_written_by[bit_read])

for bit_written in bits_written:
Expand Down Expand Up @@ -220,6 +224,7 @@ def _mark_dependencies(
self._bit_written_by[bit] = shard.ID
for bit in shard.bits_read:
self._bit_read_by[bit] = shard.ID
logger.debug("... dependencies marked")

def _cleanup_remaining_commands(self) -> None:
"""Cleans up any remaining subcommands.
Expand All @@ -228,7 +233,11 @@ def _cleanup_remaining_commands(self) -> None:
to roll up lingering subcommands.
"""
remaining_qubits = [k for k, v in self._pending_commands.items() if v]
logger.debug(
"Cleaning up remaining subcommands for qubits %s", remaining_qubits
)
for qubit in remaining_qubits:
logger.debug("Adding barrier for subcommands for qubit %s", qubit)
self._circuit.add_barrier([qubit])
# Easiest way to get to a command, since there's no constructor. Could
# create an entire orphan circuit with the matching qubits and the barrier
Expand All @@ -249,7 +258,7 @@ def _add_pending_sub_command(self, command: Command) -> None:
if qubit_key not in self._pending_commands:
self._pending_commands[qubit_key] = []
self._pending_commands[qubit_key].append(command)
logger.debug("Adding pending command %s", command)
logger.debug("Added pending sub-command %s", command)

@staticmethod
def should_op_create_shard(op: Op) -> bool:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ pytket==1.24.0
ruff==0.1.13
setuptools_scm==8.0.4
sphinx==7.2.6
wasmtime==15.0.0
wheel==0.42.0
17 changes: 17 additions & 0 deletions tests/data/wasm/add.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
(module
(type (;0;) (func))
(type (;1;) (func (param i32 i32) (result i32)))
(func $init (type 0))
(func $add (type 1) (param i32 i32) (result i32)
local.get 1
local.get 0
i32.add)
(memory (;0;) 16)
(global $__stack_pointer (mut i32) (i32.const 1048576))
(global (;1;) i32 (i32.const 1048576))
(global (;2;) i32 (i32.const 1048576))
(export "memory" (memory 0))
(export "init" (func $init))
(export "add" (func $add))
(export "__data_end" (global 1))
(export "__heap_base" (global 2)))
38 changes: 38 additions & 0 deletions tests/data/wasm/testfile.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
(module
(type $t0 (func))
(type $t1 (func (param i32) (result i32)))
(type $t2 (func (param i32 i32) (result i32)))
(type $t3 (func (param i64) (result i64)))
(type $t4 (func (param i32)))
(type $t5 (func (result i32)))
(func $init (export "init") (type $t0))
(func $add_one (export "add_one") (type $t1) (param $p0 i32) (result i32)
(i32.add
(local.get $p0)
(i32.const 1)))
(func $multi (export "multi") (type $t2) (param $p0 i32) (param $p1 i32) (result i32)
(i32.mul
(local.get $p1)
(local.get $p0)))
(func $add_two (export "add_two") (type $t1) (param $p0 i32) (result i32)
(i32.add
(local.get $p0)
(i32.const 2)))
(func $add_something (export "add_something") (type $t3) (param $p0 i64) (result i64)
(i64.add
(local.get $p0)
(i64.const 11)))
(func $add_eleven (export "add_eleven") (type $t1) (param $p0 i32) (result i32)
(i32.add
(local.get $p0)
(i32.const 11)))
(func $no_return (export "no_return") (type $t4) (param $p0 i32))
(func $no_parameters (export "no_parameters") (type $t5) (result i32)
(i32.const 11))
(func $new_function (export "new_function") (type $t5) (result i32)
(i32.const 13))
(table $T0 1 1 funcref)
(memory $memory (export "memory") 16)
(global $__stack_pointer (mut i32) (i32.const 1048576))
(global $__data_end (export "__data_end") i32 (i32.const 1048576))
(global $__heap_base (export "__heap_base") i32 (i32.const 1048576)))
Loading