From a6e9e131717be8b34c59ab6ab2044114510fd830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 25 Jun 2024 12:49:11 +0100 Subject: [PATCH] feat: Drop linear bits, improve pytket encoding/decoding (#420) Removes the ad-hoc `LINEAR_BIT` used for decoding pytket circuits, and uses non-linear `BOOL_T`s instead. This now lets us encode guppy circuits with measurements; ```python module = GuppyModule("test") module.load(quantum) @guppy(module) def my_func(q0: qubit, q1: qubit) -> tuple[bool,]: q0 = phased_x(q0, py(math.pi / 2), py(-math.pi / 2)) q0 = rz(q0, py(math.pi)) q1 = phased_x(q1, py(math.pi / 2), py(-math.pi / 2)) q1 = rz(q1, py(math.pi)) q0, q1 = zz_max(q0, q1) _ = measure(q0) return (measure(q1),) circ = guppy_to_circuit(my_func) print(to_hugr_mermaid(circ)) tk1 = circ.to_tket1() render_circuit_jupyter(tk1) circ2 = Tk2Circuit(tk1) print(to_hugr_mermaid(circ2)) ``` Mermaid diagram (rooted on the `DataflowBlock`): ```mermaid graph LR subgraph 0 ["(0) Module"] direction LR subgraph 7 ["(7) FuncDefn"] direction LR 3["(3) Input"] 3--"0:0
qubit"-->8 3--"1:1
qubit"-->8 6["(6) Output"] subgraph 8 ["(8) CFG"] direction LR subgraph 1 ["(1) DataflowBlock"] direction LR 4["(4) Input"] 4--"0:0
qubit"-->13 4--"1:0
qubit"-->21 5["(5) Output"] 9["(9) const:custom:f64(1.5707963267948966)"] 9--"0:0
float64"-->10 10["(10) LoadConstant"] 10--"0:1
float64"-->13 11["(11) const:custom:f64(-1.5707963267948966)"] 11--"0:0
float64"-->12 12["(12) LoadConstant"] 12--"0:2
float64"-->13 13["(13) quantum.tket2.PhasedX"] 13--"0:0
qubit"-->16 14["(14) const:custom:f64(3.141592653589793)"] 14--"0:0
float64"-->15 15["(15) LoadConstant"] 15--"0:1
float64"-->16 16["(16) quantum.tket2.RzF64"] 16--"0:0
qubit"-->25 17["(17) const:custom:f64(1.5707963267948966)"] 17--"0:0
float64"-->18 18["(18) LoadConstant"] 18--"0:1
float64"-->21 19["(19) const:custom:f64(-1.5707963267948966)"] 19--"0:0
float64"-->20 20["(20) LoadConstant"] 20--"0:2
float64"-->21 21["(21) quantum.tket2.PhasedX"] 21--"0:0
qubit"-->24 22["(22) const:custom:f64(3.141592653589793)"] 22--"0:0
float64"-->23 23["(23) LoadConstant"] 23--"0:1
float64"-->24 24["(24) quantum.tket2.RzF64"] 24--"0:1
qubit"-->25 25["(25) quantum.tket2.ZZMax"] 25--"0:0
qubit"-->26 25--"1:1
qubit"-->26 26["(26) MakeTuple"] 26--"0:0
[qubit, qubit]"-->27 27["(27) UnpackTuple"] 27--"0:0
qubit"-->28 27--"1:0
qubit"-->30 28["(28) quantum.tket2.Measure"] 28--"0:0
qubit"-->29 29["(29) quantum.tket2.QFree"] 30["(30) quantum.tket2.Measure"] 30--"0:0
qubit"-->31 30--"1:0
[]+[]"-->32 31["(31) quantum.tket2.QFree"] 32["(32) MakeTuple"] 32--"0:0
[[]+[]]"-->33 33["(33) UnpackTuple"] 33--"0:1
[]+[]"-->5 34["(34) Tag"] 34--"0:0
[]"-->5 end 1-."0:0".->2 2["(2) ExitBlock"] end 8--"0:0
[]+[]"-->6 end end ``` tket1 circuit: ![circuit](https://github.com/CQCL/tket2/assets/121866228/62877218-dd24-4e5e-a8f7-ce738e05662c) Re-extracted circuit: ```mermaid graph LR subgraph 0 ["(0) FuncDefn"] direction LR 1["(1) Input"] 1--"0:0
qubit"-->7 1--"1:0
qubit"-->12 2["(2) Output"] 3["(3) const:custom:f64(1.5707963267948966)"] 3--"0:0
float64"-->4 4["(4) LoadConstant"] 4--"0:1
float64"-->7 5["(5) const:custom:f64(-1.5707963267948966)"] 5--"0:0
float64"-->6 6["(6) LoadConstant"] 6--"0:2
float64"-->7 7["(7) quantum.tket2.PhasedX"] 7--"0:0
qubit"-->15 8["(8) const:custom:f64(1.5707963267948966)"] 8--"0:0
float64"-->9 9["(9) LoadConstant"] 9--"0:1
float64"-->12 10["(10) const:custom:f64(-1.5707963267948966)"] 10--"0:0
float64"-->11 11["(11) LoadConstant"] 11--"0:2
float64"-->12 12["(12) quantum.tket2.PhasedX"] 12--"0:0
qubit"-->18 13["(13) const:custom:f64(3.141592653589793)"] 13--"0:0
float64"-->14 14["(14) LoadConstant"] 14--"0:1
float64"-->15 15["(15) quantum.tket2.RzF64"] 15--"0:0
qubit"-->19 16["(16) const:custom:f64(3.141592653589793)"] 16--"0:0
float64"-->17 17["(17) LoadConstant"] 17--"0:1
float64"-->18 18["(18) quantum.tket2.RzF64"] 18--"0:1
qubit"-->19 19["(19) quantum.tket2.ZZMax"] 19--"0:0
qubit"-->21 19--"1:0
qubit"-->20 20["(20) quantum.tket2.Measure"] 20--"0:1
qubit"-->2 20--"1:2
[]+[]"-->2 21["(21) quantum.tket2.Measure"] 21--"0:0
qubit"-->2 21--"1:3
[]+[]"-->2 end ``` This required multiple improvements to the encoder/decoder logic, including - `Tk1Op::Native` operations (backed by a `Tk2Op`) can now have different number of input/output qubit/bits. - Added support for tket2 circuits with different input and output signatures. - Added support for `QAlloc`/`QFree` operations (generated by guppy) by adding extra input/outputs to the circuit. - Added support for pytket's implicit permutations, and recalculates the value when encoding a tket2 circuit. - Preserve the `opgroup` value from decoded pytket operations. - Improved error reporting. Closes #379 --------- Co-authored-by: Seyon Sivarajah --- tket2-py/src/circuit.rs | 7 +- tket2-py/test/test_guppy.py | 7 + tket2-py/tket2/_tket2/circuit.pyi | 4 - tket2-py/tket2/circuit/build.py | 3 +- tket2/src/circuit.rs | 4 +- tket2/src/circuit/command.rs | 8 +- tket2/src/extension.rs | 16 +- tket2/src/serialize/pytket.rs | 123 +++- tket2/src/serialize/pytket/decoder.rs | 300 +++++--- tket2/src/serialize/pytket/encoder.rs | 720 +++++++++++++++----- tket2/src/serialize/pytket/op.rs | 76 ++- tket2/src/serialize/pytket/op/native.rs | 107 ++- tket2/src/serialize/pytket/op/serialised.rs | 12 +- tket2/src/serialize/pytket/tests.rs | 238 +++++-- 14 files changed, 1222 insertions(+), 403 deletions(-) diff --git a/tket2-py/src/circuit.rs b/tket2-py/src/circuit.rs index 931dbe85..49bd6905 100644 --- a/tket2-py/src/circuit.rs +++ b/tket2-py/src/circuit.rs @@ -15,7 +15,7 @@ use pyo3::prelude::*; use std::fmt; use hugr::{type_row, Hugr, HugrView, PortIndex}; -use tket2::extension::{LINEAR_BIT, REGISTRY}; +use tket2::extension::REGISTRY; use tket2::rewrite::CircuitRewrite; use tket2::serialize::TKETDecode; use tket_json_rs::circuit_json::SerialCircuit; @@ -315,11 +315,6 @@ impl PyHugrType { Self(QB_T) } - #[staticmethod] - fn linear_bit() -> Self { - Self(LINEAR_BIT.to_owned()) - } - #[staticmethod] fn bool() -> Self { Self(BOOL_T) diff --git a/tket2-py/test/test_guppy.py b/tket2-py/test/test_guppy.py index d349495c..22865982 100644 --- a/tket2-py/test/test_guppy.py +++ b/tket2-py/test/test_guppy.py @@ -65,3 +65,10 @@ def my_func( # The 7 operations in the function, plus two implicit QFree assert circ.num_operations() == 9 + + tk1 = circ.to_tket1() + assert tk1.n_gates == 7 + assert tk1.n_qubits == 2 + + gates = list(tk1) + assert gates[4].op.type == pytket.circuit.OpType.ZZMax diff --git a/tket2-py/tket2/_tket2/circuit.pyi b/tket2-py/tket2/_tket2/circuit.pyi index 6c92063e..101718ba 100644 --- a/tket2-py/tket2/_tket2/circuit.pyi +++ b/tket2-py/tket2/_tket2/circuit.pyi @@ -157,10 +157,6 @@ class HugrType: def qubit() -> HugrType: """Qubit type from HUGR prelude.""" - @staticmethod - def linear_bit() -> HugrType: - """Linear bit type from TKET1 extension.""" - @staticmethod def bool() -> HugrType: """Boolean type (HUGR 2-ary unit sum).""" diff --git a/tket2-py/tket2/circuit/build.py b/tket2-py/tket2/circuit/build.py index 83ee5082..cf742720 100644 --- a/tket2-py/tket2/circuit/build.py +++ b/tket2-py/tket2/circuit/build.py @@ -3,7 +3,6 @@ from dataclasses import dataclass QB_T = HugrType.qubit() -LB_T = HugrType.linear_bit() BOOL_T = HugrType.bool() @@ -22,7 +21,7 @@ def bits(self) -> list[int]: @classmethod def op(cls) -> CustomOp: - types = [QB_T] * cls.n_qb + [LB_T] * cls.n_lb + types = [QB_T] * cls.n_qb + [BOOL_T] * cls.n_lb return CustomOp(cls.extension_name, cls.gate_name, types, types) diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index c8b13b78..92b72f6f 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -681,8 +681,8 @@ mod tests { assert_eq!(circ.operations().count(), 3); assert_eq!(circ.units().count(), qubits + bits); - assert_eq!(circ.nonlinear_units().count(), 0); - assert_eq!(circ.linear_units().count(), qubits + bits); + assert_eq!(circ.nonlinear_units().count(), bits); + assert_eq!(circ.linear_units().count(), qubits); assert_eq!(circ.qubits().count(), qubits); } diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs index 50aa5815..d1cd8223 100644 --- a/tket2/src/circuit/command.rs +++ b/tket2/src/circuit/command.rs @@ -7,7 +7,7 @@ use std::collections::{HashMap, HashSet}; use std::iter::FusedIterator; use hugr::hugr::views::{HierarchyView, SiblingGraph}; -use hugr::hugr::NodeType; +use hugr::hugr::{NodeMetadata, NodeType}; use hugr::ops::{OpTag, OpTrait}; use hugr::{HugrView, IncomingPort, OutgoingPort}; use itertools::Either::{self, Left, Right}; @@ -161,6 +161,12 @@ impl<'circ, T: HugrView> Command<'circ, T> { .port_kind(port) .map_or(false, |kind| kind.is_linear()) } + + /// Returns a metadata value associated with the command's node. + #[inline] + pub fn metadata(&self, key: impl AsRef) -> Option<&NodeMetadata> { + self.circ.hugr().get_metadata(self.node, key) + } } impl<'a, 'circ, T: HugrView> UnitLabeller for &'a Command<'circ, T> { diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index 12b47517..66a71e26 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -10,7 +10,7 @@ use hugr::extension::{CustomSignatureFunc, ExtensionId, ExtensionRegistry, Signa use hugr::hugr::IdentList; use hugr::std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE}; use hugr::types::type_param::{TypeArg, TypeParam}; -use hugr::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound}; +use hugr::types::{CustomType, FunctionType, PolyFuncType, TypeBound}; use hugr::{type_row, Extension}; use lazy_static::lazy_static; use smol_str::SmolStr; @@ -21,9 +21,6 @@ pub mod angle; /// The ID of the TKET1 extension. pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1"); -/// The name for the linear bit custom type. -pub const LINEAR_BIT_NAME: SmolStr = SmolStr::new_inline("LBit"); - /// The name for opaque TKET1 operations. pub const TKET1_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op"); @@ -39,8 +36,6 @@ pub static ref TKET1_OP_PAYLOAD : CustomType = pub static ref TKET1_EXTENSION: Extension = { let mut res = Extension::new(TKET1_EXTENSION_ID); - res.add_type(LINEAR_BIT_NAME, vec![], "A linear bit.".into(), TypeBound::Any.into()).unwrap(); - let tket1_op_payload_def = res.add_type(TKET1_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap(); let tket1_op_payload = TypeParam::Opaque{ty:tket1_op_payload_def.instantiate([]).unwrap()}; res.add_op( @@ -52,15 +47,6 @@ pub static ref TKET1_EXTENSION: Extension = { res }; -/// The type for linear bits. Part of the TKET1 extension. -pub static ref LINEAR_BIT: Type = { - Type::new_extension(TKET1_EXTENSION - .get_type(&LINEAR_BIT_NAME) - .unwrap() - .instantiate([]) - .unwrap()) - }; - /// Extension registry including the prelude, TKET1 and Tk2Ops extensions. pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ TKET1_EXTENSION.clone(), diff --git a/tket2/src/serialize/pytket.rs b/tket2/src/serialize/pytket.rs index 8d335780..81d08406 100644 --- a/tket2/src/serialize/pytket.rs +++ b/tket2/src/serialize/pytket.rs @@ -6,39 +6,47 @@ mod op; use hugr::types::Type; +use hugr::Node; +use itertools::Itertools; // Required for serialising ops in the tket1 hugr extension. pub(crate) use op::serialised::OpaqueTk1Op; #[cfg(test)] mod tests; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::path::Path; use std::{fs, io}; -use hugr::ops::{OpType, Value}; +use hugr::ops::{NamedOp, OpType, Value}; use hugr::std_extensions::arithmetic::float_types::ConstF64; use thiserror::Error; -use tket_json_rs::circuit_json::SerialCircuit; +use tket_json_rs::circuit_json::{self, SerialCircuit}; use tket_json_rs::optype::OpType as SerialOpType; use crate::circuit::Circuit; -use self::decoder::JsonDecoder; -use self::encoder::JsonEncoder; +use self::decoder::Tk1Decoder; +use self::encoder::Tk1Encoder; pub use crate::passes::pytket::lower_to_pytket; /// Prefix used for storing metadata in the hugr nodes. -pub const METADATA_PREFIX: &str = "TKET1_JSON"; +pub const METADATA_PREFIX: &str = "TKET1"; /// The global phase specified as metadata. -const METADATA_PHASE: &str = "TKET1_JSON.phase"; -/// The implicit permutation of qubits. -const METADATA_IMPLICIT_PERM: &str = "TKET1_JSON.implicit_permutation"; +const METADATA_PHASE: &str = "TKET1.phase"; /// Explicit names for the input qubit registers. -const METADATA_Q_REGISTERS: &str = "TKET1_JSON.qubit_registers"; +const METADATA_Q_REGISTERS: &str = "TKET1.qubit_registers"; +/// The reordered qubit registers in the output, if an implicit permutation was applied. +const METADATA_Q_OUTPUT_REGISTERS: &str = "TKET1.qubit_output_registers"; /// Explicit names for the input bit registers. -const METADATA_B_REGISTERS: &str = "TKET1_JSON.bit_registers"; +const METADATA_B_REGISTERS: &str = "TKET1.bit_registers"; +/// The reordered bit registers in the output, if an implicit permutation was applied. +const METADATA_B_OUTPUT_REGISTERS: &str = "TKET1.bit_output_registers"; +/// A tket1 operation "opgroup" field. +const METADATA_OPGROUP: &str = "TKET1.opgroup"; /// A serialized representation of a [`Circuit`]. /// @@ -59,7 +67,7 @@ impl TKETDecode for SerialCircuit { type EncodeError = TK1ConvertError; fn decode(self) -> Result { - let mut decoder = JsonDecoder::try_new(&self)?; + let mut decoder = Tk1Decoder::try_new(&self)?; if !self.phase.is_empty() { // TODO - add a phase gate @@ -68,18 +76,18 @@ impl TKETDecode for SerialCircuit { } for com in self.commands { - decoder.add_command(com); + decoder.add_command(com)?; } Ok(decoder.finish().into()) } fn encode(circ: &Circuit) -> Result { - let mut encoder = JsonEncoder::new(circ)?; + let mut encoder = Tk1Encoder::new(circ)?; for com in circ.commands() { let optype = com.optype(); encoder.add_command(com.clone(), optype)?; } - Ok(encoder.finish()) + Ok(encoder.finish(circ)) } } @@ -156,6 +164,47 @@ pub enum OpConvertError { /// The serialized operation is not supported. #[error("Cannot serialize tket2 operation: {0:?}")] UnsupportedOpSerialization(OpType), + /// The operation has non-serializable inputs. + #[error("Operation {} in {node} has an unsupported input of type {typ}.", optype.name())] + UnsupportedInputType { + /// The unsupported type. + typ: Type, + /// The operation name. + optype: OpType, + /// The node. + node: Node, + }, + /// The operation has non-serializable outputs. + #[error("Operation {} in {node} has an unsupported output of type {typ}.", optype.name())] + UnsupportedOutputType { + /// The unsupported type. + typ: Type, + /// The operation name. + optype: OpType, + /// The node. + node: Node, + }, + /// A parameter input could not be evaluated. + #[error("The {typ} parameter input for operation {} in {node} could not be resolved.", optype.name())] + UnresolvedParamInput { + /// The parameter type. + typ: Type, + /// The operation with the missing input param. + optype: OpType, + /// The node. + node: Node, + }, + /// The operation has output-only qubits. + /// This is not currently supported by the encoder. + #[error("Operation {} in {node} has more output qubits than inputs.", optype.name())] + TooManyOutputQubits { + /// The unsupported type. + typ: Type, + /// The operation name. + optype: OpType, + /// The node. + node: Node, + }, /// The opaque tket1 operation had an invalid type parameter. #[error("Opaque TKET1 operation had an invalid type parameter. {error}")] InvalidOpaqueTypeParam { @@ -163,6 +212,35 @@ pub enum OpConvertError { #[from] error: serde_yaml::Error, }, + /// Tried to decode a tket1 operation with not enough parameters. + #[error( + "Operation {} is missing encoded parameters. Expected at least {expected} but only \"{}\" were specified.", + optype.name(), + params.iter().join(", "), + )] + MissingSerialisedParams { + /// The operation name. + optype: OpType, + /// The expected number of parameters. + expected: usize, + /// The given of parameters. + params: Vec, + }, + /// Tried to decode a tket1 operation with not enough qubit/bit arguments. + #[error( + "Operation {} is missing encoded arguments. Expected {expected_qubits} and {expected_bits}, but only \"{args:?}\" were specified.", + optype.name(), + )] + MissingSerialisedArguments { + /// The operation name. + optype: OpType, + /// The expected number of qubits. + expected_qubits: usize, + /// The expected number of bits. + expected_bits: usize, + /// The given of parameters. + args: Vec, + }, } /// Error type for conversion between `Op` and `OpType`. @@ -234,3 +312,20 @@ fn try_constant_to_param(val: &Value) -> Option { let half_turns = radians / std::f64::consts::PI; Some(half_turns.to_string()) } + +/// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map, +/// avoiding string and vector clones on lookup. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +struct RegisterHash { + hash: u64, +} + +impl From<&circuit_json::Register> for RegisterHash { + fn from(reg: &circuit_json::Register) -> Self { + let mut hasher = DefaultHasher::new(); + reg.hash(&mut hasher); + Self { + hash: hasher.finish(), + } + } +} diff --git a/tket2/src/serialize/pytket/decoder.rs b/tket2/src/serialize/pytket/decoder.rs index 04fc6242..b1f2eeb5 100644 --- a/tket2/src/serialize/pytket/decoder.rs +++ b/tket2/src/serialize/pytket/decoder.rs @@ -1,56 +1,51 @@ -//! Intermediate structure for converting decoding [`SerialCircuit`]s into [`Hugr`]s. +//! Intermediate structure for decoding [`SerialCircuit`]s into [`Hugr`]s. -use std::collections::hash_map::DefaultHasher; -use std::collections::HashMap; -use std::hash::{Hash, Hasher}; -use std::mem; +use std::collections::{HashMap, HashSet}; -use hugr::builder::{CircuitBuilder, Container, Dataflow, DataflowHugr, FunctionBuilder}; -use hugr::extension::prelude::QB_T; +use hugr::builder::{Container, Dataflow, DataflowHugr, FunctionBuilder}; +use hugr::extension::prelude::{BOOL_T, QB_T}; +use hugr::ops::handle::NodeHandle; use hugr::ops::OpType; use hugr::types::FunctionType; -use hugr::CircuitUnit; use hugr::{Hugr, Wire}; -use itertools::Itertools; +use itertools::{EitherOrBoth, Itertools}; use serde_json::json; use tket_json_rs::circuit_json; use tket_json_rs::circuit_json::SerialCircuit; use super::op::Tk1Op; -use super::{try_param_to_constant, TK1ConvertError, METADATA_IMPLICIT_PERM, METADATA_PHASE}; -use super::{METADATA_B_REGISTERS, METADATA_Q_REGISTERS}; -use crate::extension::{LINEAR_BIT, REGISTRY, TKET1_EXTENSION_ID}; +use super::{ + try_param_to_constant, OpConvertError, RegisterHash, TK1ConvertError, + METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, + METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS, +}; +use crate::extension::{REGISTRY, TKET1_EXTENSION_ID}; use crate::symbolic_constant_op; /// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`]. /// /// Mostly used to define helper internal methods. #[derive(Debug, PartialEq)] -pub(super) struct JsonDecoder { +pub(super) struct Tk1Decoder { /// The Hugr being built. pub hugr: FunctionBuilder, - /// The dangling wires of the builder. - /// Used to generate [`CircuitBuilder`]s. - dangling_wires: Vec, - /// A map from the json registers to the units in the circuit being built. - register_units: HashMap, - /// The number of qubits in the circuit. - num_qubits: usize, - /// The number of bits in the circuit. - num_bits: usize, + /// A map from the tracked pytket registers to the [`Wire`]s in the circuit. + register_wires: HashMap, + /// The ordered list of register to have at the output. + ordered_registers: Vec, + /// A set of registers that encode qubits. + qubit_registers: HashSet, } -impl JsonDecoder { - /// Initialize a new [`JsonDecoder`], using the metadata from a [`SerialCircuit`]. +impl Tk1Decoder { + /// Initialize a new [`Tk1Decoder`], using the metadata from a [`SerialCircuit`]. pub fn try_new(serialcirc: &SerialCircuit) -> Result { let num_qubits = serialcirc.qubits.len(); let num_bits = serialcirc.bits.len(); - let sig = FunctionType::new_endo( - [vec![QB_T; num_qubits], vec![LINEAR_BIT.clone(); num_bits]].concat(), - ) - .with_extension_delta(TKET1_EXTENSION_ID); + let sig = FunctionType::new_endo([vec![QB_T; num_qubits], vec![BOOL_T; num_bits]].concat()) + .with_extension_delta(TKET1_EXTENSION_ID); let name = serialcirc.name.clone().unwrap_or_default(); let mut dfg = FunctionBuilder::new(name, sig.into()).unwrap(); @@ -59,83 +54,202 @@ impl JsonDecoder { // Metadata. The circuit requires "name", and we store other things that // should pass through the serialization roundtrip. dfg.set_metadata(METADATA_PHASE, json!(serialcirc.phase)); - dfg.set_metadata( - METADATA_IMPLICIT_PERM, - json!(serialcirc.implicit_permutation), - ); dfg.set_metadata(METADATA_Q_REGISTERS, json!(serialcirc.qubits)); dfg.set_metadata(METADATA_B_REGISTERS, json!(serialcirc.bits)); - // Map each register element to their starting `CircuitUnit`. - let mut wire_map: HashMap = - HashMap::with_capacity(num_bits + num_qubits); - for (i, register) in serialcirc.qubits.iter().enumerate() { - check_register(register)?; - wire_map.insert(register.into(), CircuitUnit::Linear(i)); + // Compute the output register reordering, and store it in the metadata. + // + // The `implicit_permutation` field is a dictionary mapping input + // registers to output registers on the same path. + // + // Here we store an ordered list showing the order in which the input + // registers appear in the output. + // + // For a circuit with three qubit registers 0, 1, 2 and an implicit + // permutation {0 -> 1, 1 -> 2, 2 -> 0}, `output_to_input` will be + // {1 -> 0, 2 -> 1, 0 -> 2} and the output order will be [2, 0, 1]. + // That is, at position 0 of the output we'll see the register originally + // named 2, at position 1 the register originally named 0, and so on. + let mut output_qubits = Vec::with_capacity(serialcirc.qubits.len()); + let mut output_bits = Vec::with_capacity(serialcirc.bits.len()); + let output_to_input: HashMap = serialcirc + .implicit_permutation + .iter() + .map(|p| (p.1.clone(), p.0.clone())) + .collect(); + for qubit in &serialcirc.qubits { + // For each output position, find the input register that should be there. + output_qubits.push(output_to_input.get(qubit).unwrap_or(qubit).clone()); } - for (i, register) in serialcirc.bits.iter().enumerate() { - check_register(register)?; - wire_map.insert(register.into(), CircuitUnit::Linear(i + num_qubits)); + for bit in &serialcirc.bits { + // For each output position, find the input register that should be there. + output_bits.push(output_to_input.get(bit).unwrap_or(bit).clone()); } + dfg.set_metadata(METADATA_Q_OUTPUT_REGISTERS, json!(output_qubits)); + dfg.set_metadata(METADATA_B_OUTPUT_REGISTERS, json!(output_bits)); + + let qubit_registers = serialcirc.qubits.iter().map(RegisterHash::from).collect(); + + let ordered_registers = serialcirc + .qubits + .iter() + .chain(&serialcirc.bits) + .map(|reg| { + check_register(reg)?; + Ok(RegisterHash::from(reg)) + }) + .collect::, TK1ConvertError>>()?; + + // Map each register element to their starting wire. + let register_wires: HashMap = ordered_registers + .iter() + .copied() + .zip(dangling_wires) + .collect(); - Ok(JsonDecoder { + Ok(Tk1Decoder { hugr: dfg, - dangling_wires, - register_units: wire_map, - num_qubits, - num_bits, + register_wires, + ordered_registers, + qubit_registers, }) } /// Finish building the [`Hugr`]. - pub fn finish(self) -> Hugr { - // TODO: Throw validation error? + pub fn finish(mut self) -> Hugr { + // Order the final wires according to the serial circuit register order. + let mut outputs = Vec::with_capacity(self.ordered_registers.len()); + for register in self.ordered_registers { + let wire = self.register_wires.remove(®ister).unwrap(); + outputs.push(wire); + } + debug_assert!( + self.register_wires.is_empty(), + "Some output wires were not associated with a register." + ); + self.hugr - .finish_hugr_with_outputs(self.dangling_wires, ®ISTRY) + .finish_hugr_with_outputs(outputs, ®ISTRY) .unwrap() } /// Add a tket1 [`circuit_json::Command`] from the serial circuit to the /// decoder. - pub fn add_command(&mut self, command: circuit_json::Command) { - // TODO Store the command's `opgroup` in the metadata. - let circuit_json::Command { op, args, .. } = command; + pub fn add_command(&mut self, command: circuit_json::Command) -> Result<(), OpConvertError> { + let circuit_json::Command { + op, args, opgroup, .. + } = command; + let op_params = op.params.clone().unwrap_or_default(); + + // Interpret the serialised operation as a [`Tk1Op`]. let num_qubits = args .iter() - .take_while(|&arg| match self.reg_wire(arg) { - CircuitUnit::Linear(i) => i < self.num_qubits, - _ => false, - }) + .take_while(|&arg| self.is_qubit_register(arg)) .count(); let num_input_bits = args.len() - num_qubits; - let op_params = op.params.clone(); let tk1op = Tk1Op::from_serialised_op(op, num_qubits, num_input_bits); - let param_units = tk1op - .param_ports() - .enumerate() - .filter_map(|(i, _port)| op_params.as_ref()?.get(i).map(String::as_ref)) - .map(|p| CircuitUnit::Wire(self.create_param_wire(p))) - .collect_vec(); - let arg_units = args.into_iter().map(|reg| self.reg_wire(®)); - - let append_wires: Vec = arg_units.chain(param_units).collect_vec(); + let (input_wires, output_registers) = self.get_op_wires(&tk1op, &args, op_params)?; let op: OpType = (&tk1op).into(); - self.with_circ_builder(|circ| { - circ.append_and_consume(op, append_wires).unwrap(); - }); + let new_op = self.hugr.add_dataflow_op(op, input_wires).unwrap(); + let wires = new_op.outputs(); + + // Store the opgroup metadata. + if let Some(opgroup) = opgroup { + self.hugr + .set_child_metadata(new_op.node(), METADATA_OPGROUP, json!(opgroup)); + } + + // Assign the new output wires to some register, replacing the previous association. + for (register, wire) in output_registers.into_iter().zip_eq(wires) { + self.set_register_wire(register, wire); + } + + Ok(()) } - /// Apply a function to the internal hugr builder viewed as a [`CircuitBuilder`]. - fn with_circ_builder(&mut self, f: F) -> T - where - F: FnOnce(&mut CircuitBuilder>) -> T, - { - let mut circ = self.hugr.as_circuit(mem::take(&mut self.dangling_wires)); - let res = f(&mut circ); - self.dangling_wires = circ.finish(); - res + /// Returns the input wires to connect to a new operation + /// and the registers to associate with outputs. + /// + /// It may add constant nodes to the Hugr if the operation has constant parameters. + fn get_op_wires( + &mut self, + tk1op: &Tk1Op, + args: &[circuit_json::Register], + params: Vec, + ) -> Result<(Vec, Vec), OpConvertError> { + // Arguments are always ordered with qubits first, and then bits. + let mut inputs: Vec = Vec::with_capacity(args.len() + params.len()); + let mut outputs: Vec = + Vec::with_capacity(tk1op.qubit_outputs() + tk1op.bit_outputs()); + + let mut current_arg = 0; + let mut next_arg = || { + if args.len() <= current_arg { + return Err(OpConvertError::MissingSerialisedArguments { + optype: tk1op.optype(), + expected_qubits: tk1op.qubit_inputs(), + expected_bits: tk1op.bit_inputs(), + args: args.to_owned(), + }); + } + current_arg += 1; + Ok(&args[current_arg - 1]) + }; + + // Qubit wires + assert_eq!( + tk1op.qubit_inputs(), + tk1op.qubit_outputs(), + "Operations with different numbers of input and output qubits are not currently supported." + ); + for _ in 0..tk1op.qubit_inputs() { + let reg = next_arg()?; + inputs.push(self.register_wire(reg)); + outputs.push(reg.into()); + } + + // Bit wires + for zip in (0..tk1op.bit_inputs()).zip_longest(0..tk1op.bit_outputs()) { + let reg = next_arg()?; + match zip { + EitherOrBoth::Both(_inp, _out) => { + // A bit used both as input and output. + inputs.push(self.register_wire(reg)); + outputs.push(reg.into()); + } + EitherOrBoth::Left(_inp) => { + // A bit used only used as input. + inputs.push(self.register_wire(reg)); + } + EitherOrBoth::Right(_out) => { + // A new bit output. + outputs.push(reg.into()); + } + } + } + + // Check that the operation is not missing parameters. + // + // Nb: `Tk1Op::Opaque` operations may not have parameters in their hugr definition. + // In that case, we just store the parameter values in the opaque data. + if tk1op.num_params() > params.len() { + return Err(OpConvertError::MissingSerialisedParams { + optype: tk1op.optype(), + expected: tk1op.num_params(), + params, + }); + } + // Add the parameter wires to the input. + inputs.extend( + tk1op + .param_ports() + .zip(params) + .map(|(_port, param)| self.create_param_wire(¶m)), + ); + + Ok((inputs, outputs)) } /// Returns the wire carrying a parameter. @@ -155,29 +269,19 @@ impl JsonDecoder { } } - /// Return the wire unit for the `elem`th value of a given register. - /// - /// Relies on TKET1 constraint that all registers have unique names. - fn reg_wire(&self, register: &circuit_json::Register) -> CircuitUnit { - self.register_units[®ister.into()] + /// Return the [`Wire`] associated with a register. + fn register_wire(&self, register: impl Into) -> Wire { + self.register_wires[®ister.into()] } -} -/// A hashed register, used to identify registers in the [`JsonDecoder::register_wire`] map, -/// avoiding string clones on lookup. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -struct RegisterHash { - hash: u64, -} + /// Update the tracked [`Wire`] for a register. + fn set_register_wire(&mut self, register: impl Into, unit: Wire) { + self.register_wires.insert(register.into(), unit); + } -impl From<&circuit_json::Register> for RegisterHash { - fn from(reg: &circuit_json::Register) -> Self { - let mut hasher = DefaultHasher::new(); - reg.0.hash(&mut hasher); - reg.1.hash(&mut hasher); - Self { - hash: hasher.finish(), - } + /// Returns `true` if the register is a qubit register. + fn is_qubit_register(&self, register: impl Into) -> bool { + self.qubit_registers.contains(®ister.into()) } } diff --git a/tket2/src/serialize/pytket/encoder.rs b/tket2/src/serialize/pytket/encoder.rs index e64768b8..fa26b2b7 100644 --- a/tket2/src/serialize/pytket/encoder.rs +++ b/tket2/src/serialize/pytket/encoder.rs @@ -1,122 +1,78 @@ -//! Intermediate structure for converting encoding [`Circuit`]s into [`SerialCircuit`]s. +//! Intermediate structure for encoding [`Circuit`]s into [`SerialCircuit`]s. use core::panic; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet, VecDeque}; -use hugr::extension::prelude::QB_T; -use hugr::ops::{NamedOp, OpType}; +use hugr::extension::prelude::{BOOL_T, QB_T}; +use hugr::ops::{OpTrait, OpType}; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::{HugrView, Wire}; -use itertools::{Either, Itertools}; -use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit}; +use itertools::Itertools; +use tket_json_rs::circuit_json::Register as RegisterUnit; +use tket_json_rs::circuit_json::{self, SerialCircuit}; use crate::circuit::command::{CircuitUnit, Command}; use crate::circuit::Circuit; -use crate::extension::LINEAR_BIT; use crate::ops::{match_symb_const_op, op_matches}; +use crate::serialize::pytket::RegisterHash; use crate::Tk2Op; use super::op::Tk1Op; use super::{ - try_constant_to_param, OpConvertError, TK1ConvertError, METADATA_B_REGISTERS, - METADATA_IMPLICIT_PERM, METADATA_PHASE, METADATA_Q_REGISTERS, + try_constant_to_param, OpConvertError, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, + METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, + METADATA_Q_REGISTERS, }; /// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`]. -#[derive(Debug, PartialEq)] -pub(super) struct JsonEncoder { +#[derive(Debug, Clone)] +pub(super) struct Tk1Encoder { /// The name of the circuit being encoded. name: Option, /// Global phase value. Defaults to "0" phase: String, - /// Implicit permutation of output qubits - implicit_permutation: Vec, - /// The current commands + /// The current serialised commands commands: Vec, - /// The TKET1 qubit registers associated to each qubit unit of the circuit. - qubit_to_reg: HashMap, - /// The TKET1 bit registers associated to each linear bit unit of the circuit. - bit_to_reg: HashMap, - /// The ordered TKET1 names for the input qubit registers. - /// - /// Nb: Although `tket-json-rs` calls these "registers", they're actually - /// identifiers for single qubits in the `Register::0` register. - qubit_registers: Vec, - /// The ordered TKET1 names for the input bit registers. - /// - /// Nb: Although `tket-json-rs` calls these "registers", they're actually - /// identifiers for single bits in the `Register::0` register. - bit_registers: Vec, - /// A register of wires with constant values, used to recover TKET1 - /// parameters. - parameters: HashMap, + /// A tracker for the qubits used in the circuit. + qubits: QubitTracker, + /// A tracker for the bits used in the circuit. + bits: BitTracker, + /// A tracker for the operation parameters used in the circuit. + parameters: ParameterTracker, } -impl JsonEncoder { +impl Tk1Encoder { /// Create a new [`JsonEncoder`] from a [`Circuit`]. pub fn new(circ: &Circuit) -> Result { let name = circ.name().map(str::to_string); let hugr = circ.hugr(); - let mut qubit_registers = vec![]; - let mut bit_registers = vec![]; - let mut phase = "0".to_string(); - let mut implicit_permutation = vec![]; + // Check for unsupported input types. + for (_, _, typ) in circ.units() { + if ![FLOAT64_TYPE, QB_T, BOOL_T].contains(&typ) { + return Err(TK1ConvertError::NonSerializableInputs { typ }); + } + } // Recover other parameters stored in the metadata // TODO: Check for invalid encoded metadata - let root = circ.parent(); - if let Some(p) = hugr.get_metadata(root, METADATA_PHASE) { - phase = p.as_str().unwrap().to_string(); - } - if let Some(perm) = hugr.get_metadata(root, METADATA_IMPLICIT_PERM) { - implicit_permutation = serde_json::from_value(perm.clone()).unwrap(); - } - if let Some(q_regs) = hugr.get_metadata(root, METADATA_Q_REGISTERS) { - qubit_registers = serde_json::from_value(q_regs.clone()).unwrap(); - } - if let Some(b_regs) = hugr.get_metadata(root, METADATA_B_REGISTERS) { - bit_registers = serde_json::from_value(b_regs.clone()).unwrap(); - } - - // Map the Hugr units to tket1 register names. - // Uses the names from the metadata if available, or initializes new sequentially-numbered registers. - let mut qubit_to_reg = HashMap::new(); - let mut bit_to_reg = HashMap::new(); - let get_register = |registers: &mut Vec, name: &str, index| { - registers.get(index).cloned().unwrap_or_else(|| { - let r = Register(name.to_string(), vec![index as i64]); - registers.push(r.clone()); - r - }) + let phase = match hugr.get_metadata(circ.parent(), METADATA_PHASE) { + Some(p) => p.as_str().unwrap().to_string(), + None => "0".to_string(), }; - for (unit, _, ty) in circ.units() { - if ty == QB_T { - let index = qubit_to_reg.len(); - let reg = get_register(&mut qubit_registers, "q", index); - qubit_to_reg.insert(unit, reg); - } else if ty == *LINEAR_BIT { - let index = bit_to_reg.len(); - let reg = get_register(&mut bit_registers, "b", index); - bit_to_reg.insert(unit, reg.clone()); - } - } - let mut encoder = Self { + let qubit_tracker = QubitTracker::new(circ); + let bit_tracker = BitTracker::new(circ); + let parameter_tracker = ParameterTracker::new(circ); + + Ok(Self { name, phase, - implicit_permutation, commands: vec![], - qubit_to_reg, - bit_to_reg, - qubit_registers, - bit_registers, - parameters: HashMap::new(), - }; - - encoder.add_input_parameters(circ)?; - - Ok(encoder) + qubits: qubit_tracker, + bits: bit_tracker, + parameters: parameter_tracker, + }) } /// Add a circuit command to the serialization. @@ -126,39 +82,113 @@ impl JsonEncoder { optype: &OpType, ) -> Result<(), OpConvertError> { // Register any output of the command that can be used as a TKET1 parameter. - if self.record_parameters(&command, optype) { + if self.parameters.record_parameters(&command, optype)? { // for now all ops that record parameters should be ignored (are // just constants) return Ok(()); } - let (args, params): (Vec, Vec) = - command - .inputs() - .partition_map(|(u, _, _)| match self.unit_to_register(u) { - Some(r) => Either::Left(r), - None => match u { - CircuitUnit::Wire(w) => Either::Right(w), - CircuitUnit::Linear(_) => { - panic!("No register found for the linear input {u:?}.") - } - }, + // Special case for the QAlloc operation. + // This does not translate to a TKET1 operation, we just start tracking a new qubit register. + if optype == &Tk2Op::QAlloc.into() { + let Some((CircuitUnit::Linear(unit_id), _, _)) = command.outputs().next() else { + panic!("QAlloc should have a single qubit output.") + }; + debug_assert!(self.qubits.get(unit_id).is_none()); + self.qubits.add_qubit_register(unit_id); + return Ok(()); + } + + let Some(tk1op) = Tk1Op::try_from_optype(optype.clone())? else { + // This command should be ignored. + return Ok(()); + }; + + // Get the registers and wires associated with the operation's inputs. + let mut qubit_args = Vec::with_capacity(tk1op.qubit_inputs()); + let mut bit_args = Vec::with_capacity(tk1op.bit_inputs()); + let mut params = Vec::with_capacity(tk1op.num_params()); + for (unit, _, ty) in command.inputs() { + if ty == QB_T { + let reg = self.unit_to_register(unit).unwrap_or_else(|| { + panic!( + "No register found for qubit input {unit} in node {}.", + command.node(), + ) + }); + qubit_args.push(reg); + } else if ty == BOOL_T { + let reg = self.unit_to_register(unit).unwrap_or_else(|| { + panic!( + "No register found for bit input {unit} in node {}.", + command.node(), + ) + }); + bit_args.push(reg); + } else if ty == FLOAT64_TYPE { + let CircuitUnit::Wire(param_wire) = unit else { + unreachable!("Float types are not linear.") + }; + params.push(param_wire); + } else { + return Err(OpConvertError::UnsupportedInputType { + typ: ty.clone(), + optype: optype.clone(), + node: command.node(), }); + } + } - // TODO Restore the opgroup (once the decoding supports it) - let opgroup = None; + for (unit, _, ty) in command.outputs() { + if ty == QB_T { + // If the qubit is not already in the qubit tracker, add it as a + // new register. + let CircuitUnit::Linear(unit_id) = unit else { + panic!("Qubit types are linear.") + }; + if self.qubits.get(unit_id).is_none() { + let reg = self.qubits.add_qubit_register(unit_id); + qubit_args.push(reg.clone()); + } + } else if ty == BOOL_T { + // If the operation has any bit outputs, create a new one bit + // register. + // + // Note that we do not reassign input registers to the new + // output wires as we do not know if the bit value was modified + // by the operation, and the old value may be needed later. + // + // This may cause register duplication for opaque operations + // with input bits. + let CircuitUnit::Wire(wire) = unit else { + panic!("Bool types are not linear.") + }; + let reg = self.bits.add_bit_register(wire); + bit_args.push(reg.clone()); + } else { + return Err(OpConvertError::UnsupportedOutputType { + typ: ty.clone(), + optype: optype.clone(), + node: command.node(), + }); + } + } + + let opgroup: Option = command + .metadata(METADATA_OPGROUP) + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); // Convert the command's operator to a pytket serialized one. This will // return an error for operations that should have been caught by the // `record_parameters` branch above (in addition to other unsupported // ops). - let op: Tk1Op = Tk1Op::try_from_optype(optype.clone())?; - let mut op: circuit_json::Operation = op + let mut serial_op: circuit_json::Operation = tk1op .serialised_op() .ok_or_else(|| OpConvertError::UnsupportedOpSerialization(optype.clone()))?; if !params.is_empty() { - op.params = Some( + serial_op.params = Some( params .into_iter() .filter_map(|w| self.parameters.get(&w)) @@ -169,21 +199,369 @@ impl JsonEncoder { // TODO: ops that contain free variables. // (update decoder to ignore them too, but store them in the wrapped op) - let command = circuit_json::Command { op, args, opgroup }; + let mut args = qubit_args; + args.append(&mut bit_args); + let command = circuit_json::Command { + op: serial_op, + args, + opgroup, + }; self.commands.push(command); + Ok(()) } /// Finish building and return the final [`SerialCircuit`]. - pub fn finish(self) -> SerialCircuit { + pub fn finish(self, circ: &Circuit) -> SerialCircuit { + let (qubits, qubits_permutation) = self.qubits.finish(circ); + let (bits, mut bits_permutation) = self.bits.finish(circ); + + let mut implicit_permutation = qubits_permutation; + implicit_permutation.append(&mut bits_permutation); + SerialCircuit { name: self.name, phase: self.phase, commands: self.commands, - qubits: self.qubit_registers, - bits: self.bit_registers, - implicit_permutation: self.implicit_permutation, + qubits, + bits, + implicit_permutation, + } + } + + /// Translate a linear [`CircuitUnit`] into a [`RegisterUnit`], if possible. + fn unit_to_register(&self, unit: CircuitUnit) -> Option { + match unit { + CircuitUnit::Linear(i) => self.qubits.get(i).cloned(), + CircuitUnit::Wire(wire) => self.bits.get(&wire).cloned(), + } + } +} + +/// A structure for tracking qubits used in the circuit being encoded. +/// +/// Nb: Although `tket-json-rs` has a "Register" struct, it's actually +/// an identifier for single qubits in the `Register::0` register. +/// We rename it to `RegisterUnit` here to avoid confusion. +#[derive(Debug, Clone, Default)] +struct QubitTracker { + /// The ordered TKET1 names for the input qubit registers. + inputs: Vec, + /// The ordered TKET1 names for the output qubit registers. + outputs: Option>, + /// The TKET1 qubit registers associated to each qubit unit of the circuit. + qubit_to_reg: HashMap, + /// A generator of new registers units to use for bit wires. + unit_generator: RegisterUnitGenerator, +} + +impl QubitTracker { + /// Create a new [`QubitTracker`] from the qubit inputs of a [`Circuit`]. + /// Reads the [`METADATA_Q_REGISTERS`] metadata entry with preset pytket qubit register names. + /// + /// If the circuit contains more qubit inputs than the provided list, + /// new registers are created for the remaining qubits. + pub fn new(circ: &Circuit) -> Self { + let mut tracker = QubitTracker::default(); + + if let Some(input_regs) = circ + .hugr() + .get_metadata(circ.parent(), METADATA_Q_REGISTERS) + { + tracker.inputs = serde_json::from_value(input_regs.clone()).unwrap(); + } + let output_regs = circ + .hugr() + .get_metadata(circ.parent(), METADATA_Q_OUTPUT_REGISTERS) + .map(|regs| serde_json::from_value(regs.clone()).unwrap()); + if let Some(output_regs) = output_regs { + tracker.outputs = Some(output_regs); + } + + tracker.unit_generator = RegisterUnitGenerator::new( + "q", + tracker + .inputs + .iter() + .chain(tracker.outputs.iter().flatten()), + ); + + let qubit_count = circ.units().filter(|(_, _, ty)| ty == &QB_T).count(); + + for i in 0..qubit_count { + // Use the given input register names if available, or create new ones. + if let Some(reg) = tracker.inputs.get(i) { + tracker.qubit_to_reg.insert(i, reg.clone()); + } else { + let reg = tracker.add_qubit_register(i).clone(); + tracker.inputs.push(reg); + } + } + + tracker + } + + /// Add a new register unit for a qubit wire. + pub fn add_qubit_register(&mut self, unit_id: usize) -> &RegisterUnit { + let reg = self.unit_generator.next(); + self.qubit_to_reg.insert(unit_id, reg); + self.qubit_to_reg.get(&unit_id).unwrap() + } + + /// Returns the register unit for a qubit wire, if it exists. + pub fn get(&self, unit_id: usize) -> Option<&RegisterUnit> { + self.qubit_to_reg.get(&unit_id) + } + + /// Consumes the tracker and returns the final list of qubit registers, along + /// with the final permutation of the outputs. + pub fn finish( + mut self, + _circ: &Circuit, + ) -> (Vec, Vec) { + // Ensure the input and output lists have the same registers. + let mut outputs = self.outputs.unwrap_or_default(); + let mut input_regs: HashSet = + self.inputs.iter().map(RegisterHash::from).collect(); + let output_regs: HashSet = outputs.iter().map(RegisterHash::from).collect(); + + for inp in &self.inputs { + if !output_regs.contains(&inp.into()) { + outputs.push(inp.clone()); + } + } + for out in &outputs { + if !input_regs.contains(&out.into()) { + self.inputs.push(out.clone()); + } + } + input_regs.extend(output_regs); + + // Add registers defined mid-circuit to both ends. + for reg in self.qubit_to_reg.into_values() { + if !input_regs.contains(&(®).into()) { + self.inputs.push(reg.clone()); + outputs.push(reg); + } + } + + // TODO: Look at the circuit outputs to determine the final permutation. + // + // We don't have the `CircuitUnit::Linear` assignments for the outputs + // here, so that requires some extra piping. + let permutation = outputs + .into_iter() + .zip(&self.inputs) + .map(|(out, inp)| circuit_json::Permutation(inp.clone(), out)) + .collect_vec(); + + (self.inputs, permutation) + } +} + +/// A structure for tracking bits used in the circuit being encoded. +/// +/// Nb: Although `tket-json-rs` has a "Register" struct, it's actually +/// an identifier for single bits in the `Register::0` register. +/// We rename it to `RegisterUnit` here to avoid confusion. +#[derive(Debug, Clone, Default)] +struct BitTracker { + /// The ordered TKET1 names for the bit inputs. + inputs: Vec, + /// The expected order of TKET1 names for the bit outputs, + /// if that was stored in the metadata. + outputs: Option>, + /// Map each bit wire to a TKET1 register element. + bit_to_reg: HashMap, + /// Registers defined in the metadata, but not present in the circuit + /// inputs. + unused_registers: VecDeque, + /// A generator of new registers units to use for bit wires. + unit_generator: RegisterUnitGenerator, +} + +impl BitTracker { + /// Create a new [`BitTracker`] from the bit inputs of a [`Circuit`]. + /// Reads the [`METADATA_B_REGISTERS`] metadata entry with preset pytket bit register names. + /// + /// If the circuit contains more bit inputs than the provided list, + /// new registers are created for the remaining bits. + /// + /// TODO: Compute output bit permutations when finishing the circuit. + pub fn new(circ: &Circuit) -> Self { + let mut tracker = BitTracker::default(); + + if let Some(input_regs) = circ + .hugr() + .get_metadata(circ.parent(), METADATA_B_REGISTERS) + { + tracker.inputs = serde_json::from_value(input_regs.clone()).unwrap(); + } + let output_regs = circ + .hugr() + .get_metadata(circ.parent(), METADATA_B_OUTPUT_REGISTERS) + .map(|regs| serde_json::from_value(regs.clone()).unwrap()); + if let Some(output_regs) = output_regs { + tracker.outputs = Some(output_regs); + } + + tracker.unit_generator = RegisterUnitGenerator::new( + "c", + tracker + .inputs + .iter() + .chain(tracker.outputs.iter().flatten()), + ); + + let bit_input_wires = circ.units().filter_map(|u| match u { + (CircuitUnit::Wire(w), _, ty) if ty == BOOL_T => Some(w), + _ => None, + }); + + let mut unused_registers: HashSet = tracker.inputs.iter().cloned().collect(); + for (i, wire) in bit_input_wires.enumerate() { + // If the input is not used in the circuit, ignore it. + if circ + .hugr() + .linked_inputs(wire.node(), wire.source()) + .next() + .is_none() + { + continue; + } + + // Use the given input register names if available, or create new ones. + if let Some(reg) = tracker.inputs.get(i) { + unused_registers.remove(reg); + tracker.bit_to_reg.insert(wire, reg.clone()); + } else { + let reg = tracker.add_bit_register(wire).clone(); + tracker.inputs.push(reg); + }; } + + // If a register was defined in the metadata but not used in the circuit, + // we keep it so it can be assigned to an operation output. + tracker.unused_registers = unused_registers.into_iter().collect(); + + tracker + } + + /// Add a new register unit for a bit wire. + pub fn add_bit_register(&mut self, wire: Wire) -> &RegisterUnit { + let reg = self + .unused_registers + .pop_front() + .unwrap_or_else(|| self.unit_generator.next()); + + self.bit_to_reg.insert(wire, reg); + self.bit_to_reg.get(&wire).unwrap() + } + + /// Returns the register unit for a bit wire, if it exists. + pub fn get(&self, wire: &Wire) -> Option<&RegisterUnit> { + self.bit_to_reg.get(wire) + } + + /// Consumes the tracker and returns the final list of bit registers, along + /// with the final permutation of the outputs. + pub fn finish( + mut self, + circ: &Circuit, + ) -> (Vec, Vec) { + let mut circuit_output_order: Vec = Vec::with_capacity(self.inputs.len()); + for (node, port) in circ.hugr().all_linked_outputs(circ.output_node()) { + let wire = Wire::new(node, port); + if let Some(reg) = self.bit_to_reg.get(&wire) { + circuit_output_order.push(reg.clone()); + } + } + + // Ensure the input and output lists have the same registers. + let mut outputs = self.outputs.unwrap_or_default(); + let mut input_regs: HashSet = + self.inputs.iter().map(RegisterHash::from).collect(); + let output_regs: HashSet = outputs.iter().map(RegisterHash::from).collect(); + + for inp in &self.inputs { + if !output_regs.contains(&inp.into()) { + outputs.push(inp.clone()); + } + } + for out in &outputs { + if !input_regs.contains(&out.into()) { + self.inputs.push(out.clone()); + } + } + input_regs.extend(output_regs); + + // Add registers defined mid-circuit to both ends. + for reg in self.bit_to_reg.into_values() { + if !input_regs.contains(&(®).into()) { + self.inputs.push(reg.clone()); + outputs.push(reg); + } + } + + // And ensure `circuit_output_order` has all virtual registers added too. + let circuit_outputs: HashSet = circuit_output_order + .iter() + .map(RegisterHash::from) + .collect(); + for out in &outputs { + if !circuit_outputs.contains(&out.into()) { + circuit_output_order.push(out.clone()); + } + } + + // Compute the final permutation. This is a combination of two mappings: + // - First, the original implicit permutation for the circuit, if this was decoded from pytket. + let original_permutation: HashMap = self + .inputs + .iter() + .zip(&outputs) + .map(|(inp, out)| (inp.clone(), RegisterHash::from(out))) + .collect(); + // - Second, the actual reordering of outputs seen at the circuit's output node. + let mut circuit_permutation: HashMap = outputs + .iter() + .zip(circuit_output_order) + .map(|(out, circ_out)| (RegisterHash::from(out), circ_out)) + .collect(); + // The final permutation is the composition of these two mappings. + let permutation = original_permutation + .into_iter() + .map(|(inp, out)| { + circuit_json::Permutation(inp, circuit_permutation.remove(&out).unwrap()) + }) + .collect_vec(); + + (self.inputs, permutation) + } +} + +/// A structure for tracking the parameters of a circuit being encoded. +#[derive(Debug, Clone, Default)] +struct ParameterTracker { + /// The parameters associated with each wire. + parameters: HashMap, +} + +impl ParameterTracker { + /// Create a new [`ParameterTracker`] from the input parameters of a [`Circuit`]. + fn new(circ: &Circuit) -> Self { + let mut tracker = ParameterTracker::default(); + + let float_input_wires = circ.units().filter_map(|u| match u { + (CircuitUnit::Wire(w), _, ty) if ty == FLOAT64_TYPE => Some(w), + _ => None, + }); + + for (i, wire) in float_input_wires.enumerate() { + tracker.add_parameter(wire, format!("f{i}")); + } + + tracker } /// Record any output of the command that can be used as a TKET1 parameter. @@ -193,25 +571,40 @@ impl JsonEncoder { &mut self, command: &Command<'_, T>, optype: &OpType, - ) -> bool { - // Only consider commands where all inputs are parameters. - let inputs = command - .inputs() - .filter_map(|(unit, _, _)| match unit { - CircuitUnit::Wire(wire) => self.parameters.get(&wire), - CircuitUnit::Linear(_) => None, - }) - .collect_vec(); - if inputs.len() != command.input_count() { - debug_assert!( - !matches!(optype, OpType::Const(_) | OpType::LoadConstant(_)), - "Found a {} with {} inputs, of which {} are non-linear. In node {:?}", - optype.name(), - command.input_count(), - inputs.len(), - command.node() - ); - return false; + ) -> Result { + let input_count = if let Some(signature) = optype.dataflow_signature() { + // Only consider commands where all inputs are parameters, + // and some outputs are also parameters. + let all_inputs = signature.input().iter().all(|ty| ty == &FLOAT64_TYPE); + let some_output = signature.output().iter().any(|ty| ty == &FLOAT64_TYPE); + if !all_inputs || !some_output { + return Ok(false); + } + signature.input_count() + } else if let OpType::Const(_) = optype { + // `Const` is a special non-dataflow command we can handle. + // It has zero inputs. + 0 + } else { + // Not a parameter-generating command. + return Ok(false); + }; + + // Collect the input parameters. + let mut inputs = Vec::with_capacity(input_count); + for (unit, _, _) in command.inputs() { + let CircuitUnit::Wire(wire) = unit else { + panic!("Float types are not linear") + }; + let Some(param) = self.parameters.get(&wire) else { + let typ = FLOAT64_TYPE; + return Err(OpConvertError::UnresolvedParamInput { + typ, + optype: optype.clone(), + node: command.node(), + }); + }; + inputs.push(param); } let param = match optype { @@ -219,7 +612,7 @@ impl JsonEncoder { // New constant, register it if it can be interpreted as a parameter. match try_constant_to_param(const_op.value()) { Some(param) => param, - None => return false, + None => return Ok(false), } } OpType::LoadConstant(_op_type) => { @@ -231,30 +624,18 @@ impl JsonEncoder { } _ => { let Some(s) = match_symb_const_op(optype) else { - return false; + return Ok(false); }; s.to_string() } }; for (unit, _, _) in command.outputs() { - match unit { - CircuitUnit::Wire(wire) => self.add_parameter(wire, param.clone()), - CircuitUnit::Linear(_) => panic!( - "Found a non-wire output {unit:?} for a {} command.", - optype.name() - ), + if let CircuitUnit::Wire(wire) = unit { + self.add_parameter(wire, param.clone()) } } - true - } - - /// Translate a linear [`CircuitUnit`] into a [`Register`], if possible. - fn unit_to_register(&self, unit: CircuitUnit) -> Option { - self.qubit_to_reg - .get(&unit) - .or_else(|| self.bit_to_reg.get(&unit)) - .cloned() + Ok(true) } /// Associate a parameter expression with a wire. @@ -262,23 +643,48 @@ impl JsonEncoder { self.parameters.insert(wire, param); } - /// Adds a parameter for each floating-point input to the circuit. - fn add_input_parameters( - &mut self, - circ: &Circuit, - ) -> Result<(), TK1ConvertError> { - let mut num_f64_inputs = 0; - for (wire, _, typ) in circ.units() { - match wire { - CircuitUnit::Linear(_) => {} - CircuitUnit::Wire(wire) if typ == FLOAT64_TYPE => { - let param = format!("f{num_f64_inputs}"); - num_f64_inputs += 1; - self.add_parameter(wire, param); - } - CircuitUnit::Wire(_) => return Err(TK1ConvertError::NonSerializableInputs { typ }), + /// Returns the parameter expression for a wire, if it exists. + fn get(&self, wire: &Wire) -> Option<&String> { + self.parameters.get(wire) + } +} + +/// A utility class for finding new unused qubit/bit names. +#[derive(Debug, Clone, Default)] +struct RegisterUnitGenerator { + /// The next index to use for a new register. + next_unit: u16, + /// The register name to use. + register: String, +} + +impl RegisterUnitGenerator { + /// Create a new [`RegisterUnitGenerator`] + /// + /// Scans the set of existing registers to find the last used index, and + /// starts generating new unit names from there. + pub fn new<'a>( + register: impl ToString, + existing: impl IntoIterator, + ) -> Self { + let register = register.to_string(); + let mut last_unit: Option = None; + for reg in existing { + if reg.0 != register { + continue; } + last_unit = Some(last_unit.unwrap_or_default().max(reg.1[0] as u16)); } - Ok(()) + RegisterUnitGenerator { + register, + next_unit: last_unit.map_or(0, |i| i + 1), + } + } + + /// Returns a fresh register unit. + pub fn next(&mut self) -> RegisterUnit { + let unit = self.next_unit; + self.next_unit += 1; + RegisterUnit(self.register.clone(), vec![unit as i64]) } } diff --git a/tket2/src/serialize/pytket/op.rs b/tket2/src/serialize/pytket/op.rs index 85a04c54..3fff3da1 100644 --- a/tket2/src/serialize/pytket/op.rs +++ b/tket2/src/serialize/pytket/op.rs @@ -13,15 +13,16 @@ use hugr::ops::OpType; use hugr::IncomingPort; use tket_json_rs::circuit_json; +use crate::Tk2Op; + use self::native::NativeOp; use self::serialised::OpaqueTk1Op; use super::OpConvertError; -/// An operation originating from pytket, containing the operation type and all its attributes. -/// -/// Wrapper around [`tket_json_rs::circuit_json::Operation`] with cached number of qubits and bits. +/// An intermediary artifact when converting between TKET1 and TKET2 operations. /// -/// The `Operation` contained by this struct is guaranteed to have a signature. +/// This enum represents either operations that can be represented natively in TKET2, +/// or operations that must be serialised as opaque TKET1 operations. #[derive(Clone, Debug, PartialEq, derive_more::From)] pub enum Tk1Op { /// An operation with a native TKET2 counterpart. @@ -38,15 +39,19 @@ impl Tk1Op { /// # Errors /// /// Returns an error if the operation is not supported by the TKET1 serialization. - pub fn try_from_optype(op: OpType) -> Result { - let res = (&op).try_into(); - let tk1_op = if let Ok(tk2op) = res { - NativeOp::try_from_tk2op(tk2op).map(Tk1Op::Native) + pub fn try_from_optype(op: OpType) -> Result, OpConvertError> { + if let Ok(tk2op) = Tk2Op::try_from(&op) { + let native = NativeOp::try_from_tk2op(tk2op) + .ok_or_else(|| OpConvertError::UnsupportedOpSerialization(op))?; + // Skip serialisation for some special cases. + if native.serial_op().is_none() { + return Ok(None); + } + Ok(Some(Tk1Op::Native(native))) } else { - OpaqueTk1Op::try_from_tket2(&op)?.map(Tk1Op::Opaque) - }; - - tk1_op.ok_or(OpConvertError::UnsupportedOpSerialization(op)) + let opaque = OpaqueTk1Op::try_from_tket2(&op)?; + Ok(opaque.map(Tk1Op::Opaque)) + } } /// Create a new `Tk1Op` from a tket1 `circuit_json::Operation`. @@ -57,11 +62,14 @@ impl Tk1Op { num_qubits: usize, num_bits: usize, ) -> Self { - if let Some(native) = NativeOp::try_from_serial_optype(serial_op.op_type.clone()) { + let op = if let Some(native) = NativeOp::try_from_serial_optype(serial_op.op_type.clone()) { Tk1Op::Native(native) } else { Tk1Op::Opaque(OpaqueTk1Op::new_from_op(serial_op, num_qubits, num_bits)) - } + }; + debug_assert_eq!(num_qubits, op.qubit_inputs().max(op.qubit_outputs())); + debug_assert_eq!(num_bits, op.bit_inputs().max(op.bit_outputs())); + op } /// Get the hugr optype for the operation. @@ -95,6 +103,46 @@ impl Tk1Op { Tk1Op::Opaque(json_op) => itertools::Either::Right(json_op.param_ports()), } } + + /// Returns the number of qubit inputs for this operation. + pub fn qubit_inputs(&self) -> usize { + match self { + Tk1Op::Native(native_op) => native_op.input_qubits, + Tk1Op::Opaque(json_op) => json_op.num_qubits, + } + } + + /// Returns the number of bit inputs for this operation. + pub fn bit_inputs(&self) -> usize { + match self { + Tk1Op::Native(native_op) => native_op.input_bits, + Tk1Op::Opaque(json_op) => json_op.num_bits, + } + } + + /// Returns the number of qubit outputs for this operation. + pub fn qubit_outputs(&self) -> usize { + match self { + Tk1Op::Native(native_op) => native_op.output_qubits, + Tk1Op::Opaque(json_op) => json_op.num_qubits, + } + } + + /// Returns the number of bit outputs for this operation. + pub fn bit_outputs(&self) -> usize { + match self { + Tk1Op::Native(native_op) => native_op.output_bits, + Tk1Op::Opaque(json_op) => json_op.num_bits, + } + } + + /// Returns the number of parameters for this operation. + pub fn num_params(&self) -> usize { + match self { + Tk1Op::Native(native_op) => native_op.num_params, + Tk1Op::Opaque(json_op) => json_op.num_params, + } + } } impl From for OpType { diff --git a/tket2/src/serialize/pytket/op/native.rs b/tket2/src/serialize/pytket/op/native.rs index ad91a0f6..719f9821 100644 --- a/tket2/src/serialize/pytket/op/native.rs +++ b/tket2/src/serialize/pytket/op/native.rs @@ -1,6 +1,6 @@ //! Operations that have corresponding representations in both `pytket` and `tket2`. -use hugr::extension::prelude::QB_T; +use hugr::extension::prelude::{BOOL_T, QB_T}; use hugr::ops::{Noop, OpTrait, OpType}; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; @@ -10,13 +10,12 @@ use hugr::IncomingPort; use tket_json_rs::circuit_json; use tket_json_rs::optype::OpType as Tk1OpType; -use crate::extension::LINEAR_BIT; use crate::Tk2Op; /// An operation with a native TKET2 counterpart. /// /// Note that the signature of the native and serialised operations may differ. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Default)] pub struct NativeOp { /// The tket2 optype. op: OpType, @@ -25,9 +24,30 @@ pub struct NativeOp { /// Some specific operations do not have a direct pytket counterpart, and must be handled /// separately. serial_op: Option, + /// Number of input qubits to the operation. + pub input_qubits: usize, + /// Number of output qubits to the operation. + pub input_bits: usize, + /// Number of parameters to the operation. + pub num_params: usize, + /// Number of output qubits to the operation. + pub output_qubits: usize, + /// Number of output bits to the operation. + pub output_bits: usize, } impl NativeOp { + /// Initialise a new `NativeOp`. + fn new(op: OpType, serial_op: Option) -> Self { + let mut native_op = Self { + op, + serial_op, + ..Default::default() + }; + native_op.compute_counts(); + native_op + } + /// Create a new `NativeOp` from a `circuit_json::Operation`. pub fn try_from_tk2op(tk2op: Tk2Op) -> Option { let serial_op = match tk2op { @@ -48,28 +68,22 @@ impl NativeOp { Tk2Op::ZZPhase => Tk1OpType::ZZPhase, Tk2Op::CZ => Tk1OpType::CZ, Tk2Op::Reset => Tk1OpType::Reset, + Tk2Op::Measure => Tk1OpType::Measure, Tk2Op::AngleAdd => { // These operations should be folded into constant before serialisation, // or replaced by pytket logic expressions. - return Some(Self { - op: tk2op.into(), - serial_op: None, - }); - } - // TKET2 measurements and TKET1 measurements have different semantics. - Tk2Op::Measure => { - return None; + return Some(Self::new(tk2op.into(), None)); } // These operations do not have a direct pytket counterpart. Tk2Op::QAlloc | Tk2Op::QFree => { - return None; + // These operations are implicitly supported by the encoding, + // they do not create an explicit pytket operation but instead + // add new qubits to the circuit input/output. + return Some(Self::new(tk2op.into(), None)); } }; - Some(Self { - op: tk2op.into(), - serial_op: Some(serial_op), - }) + Some(Self::new(tk2op.into(), Some(serial_op))) } /// Returns the translated tket2 optype for this operation, if it exists. @@ -92,35 +106,24 @@ impl NativeOp { Tk1OpType::ZZPhase => Tk2Op::ZZPhase.into(), Tk1OpType::CZ => Tk2Op::CZ.into(), Tk1OpType::Reset => Tk2Op::Reset.into(), + Tk1OpType::Measure => Tk2Op::Measure.into(), Tk1OpType::noop => Noop::new(QB_T).into(), _ => { return None; } }; - Some(Self { - op, - serial_op: Some(serial_op), - }) + Some(Self::new(op, Some(serial_op))) } /// Converts this `NativeOp` into a tket_json_rs operation. pub fn serialised_op(&self) -> Option { let serial_op = self.serial_op.clone()?; - let mut num_qubits = 0; - let mut num_bits = 0; - let mut num_params = 0; - if let Some(sig) = self.signature() { - for ty in sig.input.iter() { - if ty == &QB_T { - num_qubits += 1 - } else if *ty == *LINEAR_BIT { - num_bits += 1 - } else if ty == &FLOAT64_TYPE { - num_params += 1 - } - } - } + // Since pytket operations are always linear, + // use the maximum of input and output bits/qubits. + let num_qubits = self.input_qubits.max(self.output_qubits); + let num_bits = self.input_bits.max(self.output_bits); + let num_params = self.num_params; let params = (num_params > 0).then(|| vec!["".into(); num_params]); @@ -139,6 +142,14 @@ impl NativeOp { self.op.dataflow_signature() } + /// Returns the serial optype for this operation. + /// + /// Some special operations do not have a direct serialised counterpart, and + /// should be skipped during serialisation. + pub fn serial_op(&self) -> Option<&Tk1OpType> { + self.serial_op.as_ref() + } + /// Returns the tket2 optype for this operation. pub fn optype(&self) -> &OpType { &self.op @@ -159,6 +170,34 @@ impl NativeOp { .map(|(port, _)| port) }) } + + /// Update the internal bit/qubit/parameter counts. + fn compute_counts(&mut self) { + self.input_bits = 0; + self.input_qubits = 0; + self.num_params = 0; + self.output_bits = 0; + self.output_qubits = 0; + let Some(sig) = self.signature() else { + return; + }; + for ty in sig.input_types() { + if ty == &QB_T { + self.input_qubits += 1; + } else if ty == &BOOL_T { + self.input_bits += 1; + } else if ty == &FLOAT64_TYPE { + self.num_params += 1; + } + } + for ty in sig.output_types() { + if ty == &QB_T { + self.output_qubits += 1; + } else if ty == &BOOL_T { + self.output_bits += 1; + } + } + } } #[cfg(test)] diff --git a/tket2/src/serialize/pytket/op/serialised.rs b/tket2/src/serialize/pytket/op/serialised.rs index 3f88ef66..463f96c1 100644 --- a/tket2/src/serialize/pytket/op/serialised.rs +++ b/tket2/src/serialize/pytket/op/serialised.rs @@ -1,6 +1,6 @@ //! Wrapper over pytket operations that cannot be represented naturally in tket2. -use hugr::extension::prelude::QB_T; +use hugr::extension::prelude::{BOOL_T, QB_T}; use hugr::ops::custom::{CustomOp, ExtensionOp}; use hugr::ops::{NamedOp, OpType}; @@ -13,7 +13,7 @@ use serde::de::Error; use tket_json_rs::circuit_json; use crate::extension::{ - LINEAR_BIT, REGISTRY, TKET1_EXTENSION, TKET1_EXTENSION_ID, TKET1_OP_NAME, TKET1_OP_PAYLOAD, + REGISTRY, TKET1_EXTENSION, TKET1_EXTENSION_ID, TKET1_OP_NAME, TKET1_OP_PAYLOAD, }; use crate::serialize::pytket::OpConvertError; @@ -32,9 +32,9 @@ pub struct OpaqueTk1Op { /// Internal operation data. op: circuit_json::Operation, /// Number of qubits declared by the operation. - num_qubits: usize, + pub num_qubits: usize, /// Number of bits declared by the operation. - num_bits: usize, + pub num_bits: usize, /// Node input for each parameter in `op.params`. /// /// If the input is `None`, the parameter does not use a Hugr port and is @@ -42,7 +42,7 @@ pub struct OpaqueTk1Op { param_inputs: Vec>, /// The number of non-None inputs in `param_inputs`, corresponding to the /// FLOAT64_TYPE inputs to the Hugr operation. - num_params: usize, + pub num_params: usize, } impl OpaqueTk1Op { @@ -110,7 +110,7 @@ impl OpaqueTk1Op { pub fn signature(&self) -> FunctionType { let linear = [ vec![QB_T; self.num_qubits], - vec![LINEAR_BIT.clone(); self.num_bits], + vec![BOOL_T.clone(); self.num_bits], ] .concat(); let params = vec![FLOAT64_TYPE; self.num_params]; diff --git a/tket2/src/serialize/pytket/tests.rs b/tket2/src/serialize/pytket/tests.rs index 0e347401..f6e8c760 100644 --- a/tket2/src/serialize/pytket/tests.rs +++ b/tket2/src/serialize/pytket/tests.rs @@ -1,17 +1,20 @@ //! General tests. +use std::collections::{HashMap, HashSet}; use std::io::BufReader; use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; -use hugr::extension::prelude::QB_T; +use hugr::extension::prelude::{BOOL_T, QB_T}; +use hugr::hugr::hugrmut::HugrMut; use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use hugr::types::FunctionType; +use hugr::HugrView; use rstest::{fixture, rstest}; use tket_json_rs::circuit_json::{self, SerialCircuit}; use tket_json_rs::optype; -use super::TKETDecode; +use super::{TKETDecode, METADATA_Q_OUTPUT_REGISTERS}; use crate::circuit::Circuit; use crate::extension::REGISTRY; use crate::Tk2Op; @@ -27,6 +30,17 @@ const SIMPLE_JSON: &str = r#"{ "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] }"#; +const MULTI_REGISTER: &str = r#"{ + "phase": "0", + "bits": [], + "qubits": [["q", [2]], ["q", [1]], ["my_qubits", [2]]], + "commands": [ + {"args": [["my_qubits", [2]]], "op": {"type": "H"}}, + {"args": [["q", [2]], ["q", [1]]], "op": {"type": "CX"}} + ], + "implicit_permutation": [] + }"#; + const UNKNOWN_OP: &str = r#"{ "phase": "1/2", "bits": [["c", [0]], ["c", [1]]], @@ -40,7 +54,7 @@ const UNKNOWN_OP: &str = r#"{ "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]], [["q", [2]], ["q", [2]]]] }"#; -const PARAMETRIZED: &str = r#"{ +const PARAMETERIZED: &str = r#"{ "phase": "0.0", "bits": [], "qubits": [["q", [0]], ["q", [1]]], @@ -55,57 +69,121 @@ const PARAMETRIZED: &str = r#"{ "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] }"#; +/// Check some properties of the serial circuit. +fn validate_serial_circ(circ: &SerialCircuit) { + // Check that all commands have valid arguments. + for command in &circ.commands { + for arg in &command.args { + assert!( + circ.qubits.contains(arg) || circ.bits.contains(arg), + "Circuit command {command:?} has an invalid argument '{arg:?}'" + ); + } + } + + // Check that the implicit permutation is valid. + let perm: HashMap = circ + .implicit_permutation + .iter() + .map(|p| (p.0.clone(), p.1.clone())) + .collect(); + for (key, value) in &perm { + let valid_qubits = circ.qubits.contains(key) && circ.qubits.contains(value); + let valid_bits = circ.bits.contains(key) && circ.bits.contains(value); + assert!( + valid_qubits || valid_bits, + "Circuit has an invalid permutation '{key:?} -> {value:?}'" + ); + } + assert_eq!( + perm.len(), + circ.implicit_permutation.len(), + "Circuit has duplicate permutations", + ); + assert_eq!( + HashSet::<&circuit_json::Register>::from_iter(perm.values()).len(), + perm.len(), + "Circuit has duplicate values in permutations" + ); +} + fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) { assert_eq!(a.name, b.name); assert_eq!(a.phase, b.phase); - - let qubits_a: Vec<_> = a.qubits.iter().collect(); - let qubits_b: Vec<_> = b.qubits.iter().collect(); - assert_eq!(qubits_a, qubits_b); - - let bits_a: Vec<_> = a.bits.iter().collect(); - let bits_b: Vec<_> = b.bits.iter().collect(); - assert_eq!(bits_a, bits_b); - - assert_eq!(a.implicit_permutation, b.implicit_permutation); - + assert_eq!(&a.qubits, &b.qubits); + assert_eq!(&a.bits, &b.bits); assert_eq!(a.commands.len(), b.commands.len()); - // the below only works if both serial circuits share a topological ordering - // of commands. + // This comparison only works if both serial circuits share a topological + // ordering of commands. + // + // We also cannot compare the arguments directly, since we may permute them + // internally. + // + // TODO: Do a proper comparison independent of the toposort ordering, and + // track register reordering. for (a, b) in a.commands.iter().zip(b.commands.iter()) { assert_eq!(a.op.op_type, b.op.op_type); - assert_eq!(a.args, b.args); assert_eq!(a.op.params, b.op.params); + assert_eq!(a.args.len(), b.args.len()); } - // TODO: Check commands equality (they only implement PartialEq) } -#[rstest] -#[case::simple(SIMPLE_JSON, 2, 2)] -#[case::unknown_op(UNKNOWN_OP, 2, 3)] -#[case::parametrized(PARAMETRIZED, 4, 2)] -fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num_qubits: usize) { - let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap(); - assert_eq!(ser.commands.len(), num_commands); +/// A simple circuit with some preset qubit registers +#[fixture] +fn circ_preset_qubits() -> Circuit { + let input_t = vec![QB_T]; + let output_t = vec![QB_T, QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); - let circ: Circuit = ser.clone().decode().unwrap(); + let [qb0] = h.input_wires_arr(); + let [qb1] = h.add_dataflow_op(Tk2Op::QAlloc, []).unwrap().outputs_arr(); - assert_eq!(circ.qubit_count(), num_qubits); + let [qb0, qb1] = h + .add_dataflow_op(Tk2Op::CZ, [qb0, qb1]) + .unwrap() + .outputs_arr(); - let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); - compare_serial_circs(&ser, &reser); + let mut hugr = h.finish_hugr_with_outputs([qb0, qb1], ®ISTRY).unwrap(); + + // A preset register for the first qubit output + hugr.set_metadata( + hugr.root(), + METADATA_Q_OUTPUT_REGISTERS, + serde_json::json!([["q", [1]]]), + ); + + hugr.into() } -#[rstest] -#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri -#[case::barenco_tof_10("../test_files/barenco_tof_10.json")] -fn json_file_roundtrip(#[case] circ: impl AsRef) { - let reader = BufReader::new(std::fs::File::open(circ).unwrap()); - let ser: circuit_json::SerialCircuit = serde_json::from_reader(reader).unwrap(); - let circ: Circuit = ser.clone().decode().unwrap(); - let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); - compare_serial_circs(&ser, &reser); +/// A simple circuit with ancillae +#[fixture] +fn circ_measure_ancilla() -> Circuit { + let input_t = vec![QB_T]; + let output_t = vec![BOOL_T, BOOL_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); + + let [qb] = h.input_wires_arr(); + let [anc] = h.add_dataflow_op(Tk2Op::QAlloc, []).unwrap().outputs_arr(); + + let [qb, meas_qb] = h + .add_dataflow_op(Tk2Op::Measure, [qb]) + .unwrap() + .outputs_arr(); + let [anc, meas_anc] = h + .add_dataflow_op(Tk2Op::Measure, [anc]) + .unwrap() + .outputs_arr(); + + let [] = h.add_dataflow_op(Tk2Op::QFree, [qb]).unwrap().outputs_arr(); + let [] = h + .add_dataflow_op(Tk2Op::QFree, [anc]) + .unwrap() + .outputs_arr(); + + h.finish_hugr_with_outputs([meas_qb, meas_anc], ®ISTRY) + .unwrap() + .into() } #[fixture] @@ -114,15 +192,15 @@ fn circ_add_angles_symbolic() -> Circuit { let output_t = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); - let mut inps = h.input_wires(); - let qb = inps.next().unwrap(); - let f1 = inps.next().unwrap(); - let f2 = inps.next().unwrap(); - - let res = h.add_dataflow_op(Tk2Op::AngleAdd, [f1, f2]).unwrap(); - let f12 = res.outputs().next().unwrap(); - let res = h.add_dataflow_op(Tk2Op::RxF64, [qb, f12]).unwrap(); - let qb = res.outputs().next().unwrap(); + let [qb, f1, f2] = h.input_wires_arr(); + let [f12] = h + .add_dataflow_op(Tk2Op::AngleAdd, [f1, f2]) + .unwrap() + .outputs_arr(); + let [qb] = h + .add_dataflow_op(Tk2Op::RxF64, [qb, f12]) + .unwrap() + .outputs_arr(); h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap().into() } @@ -148,6 +226,68 @@ fn circ_add_angles_constants() -> Circuit { h.finish_hugr_with_outputs(qbs, ®ISTRY).unwrap().into() } +#[rstest] +#[case::simple(SIMPLE_JSON, 2, 2)] +#[case::simple(MULTI_REGISTER, 2, 3)] +#[case::unknown_op(UNKNOWN_OP, 2, 3)] +#[case::parametrized(PARAMETERIZED, 4, 2)] +fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num_qubits: usize) { + let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap(); + assert_eq!(ser.commands.len(), num_commands); + + let circ: Circuit = ser.clone().decode().unwrap(); + + assert_eq!(circ.qubit_count(), num_qubits); + + let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); + validate_serial_circ(&reser); + compare_serial_circs(&ser, &reser); +} + +#[rstest] +#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri +#[case::barenco_tof_10("../test_files/barenco_tof_10.json")] +fn json_file_roundtrip(#[case] circ: impl AsRef) { + let reader = BufReader::new(std::fs::File::open(circ).unwrap()); + let ser: circuit_json::SerialCircuit = serde_json::from_reader(reader).unwrap(); + let circ: Circuit = ser.clone().decode().unwrap(); + let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); + validate_serial_circ(&reser); + compare_serial_circs(&ser, &reser); +} + +/// Test the serialisation roundtrip from a tket2 circuit. +/// +/// Note: this is not a pure roundtrip as the encoder may add internal qubits/bits to the circuit. +#[rstest] +#[case::meas_ancilla(circ_measure_ancilla(), FunctionType::new_endo(vec![QB_T, QB_T, BOOL_T, BOOL_T]))] +#[case::preset_qubits(circ_preset_qubits(), FunctionType::new_endo(vec![QB_T, QB_T, QB_T]))] +fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: FunctionType) { + let ser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); + let deser: Circuit = ser.clone().decode().unwrap(); + + let deser_sig = deser.circuit_signature(); + assert_eq!( + &deser_sig.input, &decoded_sig.input, + "Input signature mismatch\n Expected: {}\n Actual: {}", + &deser_sig, &decoded_sig + ); + assert_eq!( + &deser_sig.output, &decoded_sig.output, + "Output signature mismatch\n Expected: {}\n Actual: {}", + &deser_sig, &decoded_sig + ); + + let reser = SerialCircuit::encode(&deser).unwrap(); + validate_serial_circ(&reser); + compare_serial_circs(&ser, &reser); +} + +/// Test serialisation of circuits with a symbolic expression. +/// +/// Note: this is not a proper roundtrip as the symbols f0 and f1 are not +/// converted back to circuit inputs. This would require parsing symbolic +/// expressions. #[rstest] #[case::symbolic(circ_add_angles_symbolic(), "f0 + f1")] #[case::constants(circ_add_angles_constants(), "0.2 + 0.3")] @@ -157,10 +297,8 @@ fn test_add_angle_serialise(#[case] circ_add_angles: Circuit, #[case] param_str: assert_eq!(ser.commands[0].op.op_type, optype::OpType::Rx); assert_eq!(ser.commands[0].op.params, Some(vec![param_str.into()])); - // Note: this is not a proper roundtrip as the symbols f0 and f1 are not - // converted back to circuit inputs. This would require parsing symbolic - // expressions. let deser: Circuit = ser.clone().decode().unwrap(); let reser = SerialCircuit::encode(&deser).unwrap(); + validate_serial_circ(&reser); compare_serial_circs(&ser, &reser); }