From a6e9e131717be8b34c59ab6ab2044114510fd830 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Agust=C3=ADn=20Borgna?=
<121866228+aborgna-q@users.noreply.github.com>
Date: Tue, 25 Jun 2024 12:49:11 +0100
Subject: [PATCH] feat: Drop linear bits, improve pytket encoding/decoding
(#420)
Removes the ad-hoc `LINEAR_BIT` used for decoding pytket circuits, and
uses non-linear `BOOL_T`s instead.
This now lets us encode guppy circuits with measurements;
```python
module = GuppyModule("test")
module.load(quantum)
@guppy(module)
def my_func(q0: qubit, q1: qubit) -> tuple[bool,]:
q0 = phased_x(q0, py(math.pi / 2), py(-math.pi / 2))
q0 = rz(q0, py(math.pi))
q1 = phased_x(q1, py(math.pi / 2), py(-math.pi / 2))
q1 = rz(q1, py(math.pi))
q0, q1 = zz_max(q0, q1)
_ = measure(q0)
return (measure(q1),)
circ = guppy_to_circuit(my_func)
print(to_hugr_mermaid(circ))
tk1 = circ.to_tket1()
render_circuit_jupyter(tk1)
circ2 = Tk2Circuit(tk1)
print(to_hugr_mermaid(circ2))
```
Mermaid diagram (rooted on the `DataflowBlock`):
```mermaid
graph LR
subgraph 0 ["(0) Module"]
direction LR
subgraph 7 ["(7) FuncDefn"]
direction LR
3["(3) Input"]
3--"0:0 qubit"-->8
3--"1:1 qubit"-->8
6["(6) Output"]
subgraph 8 ["(8) CFG"]
direction LR
subgraph 1 ["(1) DataflowBlock"]
direction LR
4["(4) Input"]
4--"0:0 qubit"-->13
4--"1:0 qubit"-->21
5["(5) Output"]
9["(9) const:custom:f64(1.5707963267948966)"]
9--"0:0 float64"-->10
10["(10) LoadConstant"]
10--"0:1 float64"-->13
11["(11) const:custom:f64(-1.5707963267948966)"]
11--"0:0 float64"-->12
12["(12) LoadConstant"]
12--"0:2 float64"-->13
13["(13) quantum.tket2.PhasedX"]
13--"0:0 qubit"-->16
14["(14) const:custom:f64(3.141592653589793)"]
14--"0:0 float64"-->15
15["(15) LoadConstant"]
15--"0:1 float64"-->16
16["(16) quantum.tket2.RzF64"]
16--"0:0 qubit"-->25
17["(17) const:custom:f64(1.5707963267948966)"]
17--"0:0 float64"-->18
18["(18) LoadConstant"]
18--"0:1 float64"-->21
19["(19) const:custom:f64(-1.5707963267948966)"]
19--"0:0 float64"-->20
20["(20) LoadConstant"]
20--"0:2 float64"-->21
21["(21) quantum.tket2.PhasedX"]
21--"0:0 qubit"-->24
22["(22) const:custom:f64(3.141592653589793)"]
22--"0:0 float64"-->23
23["(23) LoadConstant"]
23--"0:1 float64"-->24
24["(24) quantum.tket2.RzF64"]
24--"0:1 qubit"-->25
25["(25) quantum.tket2.ZZMax"]
25--"0:0 qubit"-->26
25--"1:1 qubit"-->26
26["(26) MakeTuple"]
26--"0:0 [qubit, qubit]"-->27
27["(27) UnpackTuple"]
27--"0:0 qubit"-->28
27--"1:0 qubit"-->30
28["(28) quantum.tket2.Measure"]
28--"0:0 qubit"-->29
29["(29) quantum.tket2.QFree"]
30["(30) quantum.tket2.Measure"]
30--"0:0 qubit"-->31
30--"1:0 []+[]"-->32
31["(31) quantum.tket2.QFree"]
32["(32) MakeTuple"]
32--"0:0 [[]+[]]"-->33
33["(33) UnpackTuple"]
33--"0:1 []+[]"-->5
34["(34) Tag"]
34--"0:0 []"-->5
end
1-."0:0".->2
2["(2) ExitBlock"]
end
8--"0:0 []+[]"-->6
end
end
```
tket1 circuit:
![circuit](https://github.com/CQCL/tket2/assets/121866228/62877218-dd24-4e5e-a8f7-ce738e05662c)
Re-extracted circuit:
```mermaid
graph LR
subgraph 0 ["(0) FuncDefn"]
direction LR
1["(1) Input"]
1--"0:0 qubit"-->7
1--"1:0 qubit"-->12
2["(2) Output"]
3["(3) const:custom:f64(1.5707963267948966)"]
3--"0:0 float64"-->4
4["(4) LoadConstant"]
4--"0:1 float64"-->7
5["(5) const:custom:f64(-1.5707963267948966)"]
5--"0:0 float64"-->6
6["(6) LoadConstant"]
6--"0:2 float64"-->7
7["(7) quantum.tket2.PhasedX"]
7--"0:0 qubit"-->15
8["(8) const:custom:f64(1.5707963267948966)"]
8--"0:0 float64"-->9
9["(9) LoadConstant"]
9--"0:1 float64"-->12
10["(10) const:custom:f64(-1.5707963267948966)"]
10--"0:0 float64"-->11
11["(11) LoadConstant"]
11--"0:2 float64"-->12
12["(12) quantum.tket2.PhasedX"]
12--"0:0 qubit"-->18
13["(13) const:custom:f64(3.141592653589793)"]
13--"0:0 float64"-->14
14["(14) LoadConstant"]
14--"0:1 float64"-->15
15["(15) quantum.tket2.RzF64"]
15--"0:0 qubit"-->19
16["(16) const:custom:f64(3.141592653589793)"]
16--"0:0 float64"-->17
17["(17) LoadConstant"]
17--"0:1 float64"-->18
18["(18) quantum.tket2.RzF64"]
18--"0:1 qubit"-->19
19["(19) quantum.tket2.ZZMax"]
19--"0:0 qubit"-->21
19--"1:0 qubit"-->20
20["(20) quantum.tket2.Measure"]
20--"0:1 qubit"-->2
20--"1:2 []+[]"-->2
21["(21) quantum.tket2.Measure"]
21--"0:0 qubit"-->2
21--"1:3 []+[]"-->2
end
```
This required multiple improvements to the encoder/decoder logic,
including
- `Tk1Op::Native` operations (backed by a `Tk2Op`) can now have
different number of input/output qubit/bits.
- Added support for tket2 circuits with different input and output
signatures.
- Added support for `QAlloc`/`QFree` operations (generated by guppy) by
adding extra input/outputs to the circuit.
- Added support for pytket's implicit permutations, and recalculates the
value when encoding a tket2 circuit.
- Preserve the `opgroup` value from decoded pytket operations.
- Improved error reporting.
Closes #379
---------
Co-authored-by: Seyon Sivarajah
---
tket2-py/src/circuit.rs | 7 +-
tket2-py/test/test_guppy.py | 7 +
tket2-py/tket2/_tket2/circuit.pyi | 4 -
tket2-py/tket2/circuit/build.py | 3 +-
tket2/src/circuit.rs | 4 +-
tket2/src/circuit/command.rs | 8 +-
tket2/src/extension.rs | 16 +-
tket2/src/serialize/pytket.rs | 123 +++-
tket2/src/serialize/pytket/decoder.rs | 300 +++++---
tket2/src/serialize/pytket/encoder.rs | 720 +++++++++++++++-----
tket2/src/serialize/pytket/op.rs | 76 ++-
tket2/src/serialize/pytket/op/native.rs | 107 ++-
tket2/src/serialize/pytket/op/serialised.rs | 12 +-
tket2/src/serialize/pytket/tests.rs | 238 +++++--
14 files changed, 1222 insertions(+), 403 deletions(-)
diff --git a/tket2-py/src/circuit.rs b/tket2-py/src/circuit.rs
index 931dbe85..49bd6905 100644
--- a/tket2-py/src/circuit.rs
+++ b/tket2-py/src/circuit.rs
@@ -15,7 +15,7 @@ use pyo3::prelude::*;
use std::fmt;
use hugr::{type_row, Hugr, HugrView, PortIndex};
-use tket2::extension::{LINEAR_BIT, REGISTRY};
+use tket2::extension::REGISTRY;
use tket2::rewrite::CircuitRewrite;
use tket2::serialize::TKETDecode;
use tket_json_rs::circuit_json::SerialCircuit;
@@ -315,11 +315,6 @@ impl PyHugrType {
Self(QB_T)
}
- #[staticmethod]
- fn linear_bit() -> Self {
- Self(LINEAR_BIT.to_owned())
- }
-
#[staticmethod]
fn bool() -> Self {
Self(BOOL_T)
diff --git a/tket2-py/test/test_guppy.py b/tket2-py/test/test_guppy.py
index d349495c..22865982 100644
--- a/tket2-py/test/test_guppy.py
+++ b/tket2-py/test/test_guppy.py
@@ -65,3 +65,10 @@ def my_func(
# The 7 operations in the function, plus two implicit QFree
assert circ.num_operations() == 9
+
+ tk1 = circ.to_tket1()
+ assert tk1.n_gates == 7
+ assert tk1.n_qubits == 2
+
+ gates = list(tk1)
+ assert gates[4].op.type == pytket.circuit.OpType.ZZMax
diff --git a/tket2-py/tket2/_tket2/circuit.pyi b/tket2-py/tket2/_tket2/circuit.pyi
index 6c92063e..101718ba 100644
--- a/tket2-py/tket2/_tket2/circuit.pyi
+++ b/tket2-py/tket2/_tket2/circuit.pyi
@@ -157,10 +157,6 @@ class HugrType:
def qubit() -> HugrType:
"""Qubit type from HUGR prelude."""
- @staticmethod
- def linear_bit() -> HugrType:
- """Linear bit type from TKET1 extension."""
-
@staticmethod
def bool() -> HugrType:
"""Boolean type (HUGR 2-ary unit sum)."""
diff --git a/tket2-py/tket2/circuit/build.py b/tket2-py/tket2/circuit/build.py
index 83ee5082..cf742720 100644
--- a/tket2-py/tket2/circuit/build.py
+++ b/tket2-py/tket2/circuit/build.py
@@ -3,7 +3,6 @@
from dataclasses import dataclass
QB_T = HugrType.qubit()
-LB_T = HugrType.linear_bit()
BOOL_T = HugrType.bool()
@@ -22,7 +21,7 @@ def bits(self) -> list[int]:
@classmethod
def op(cls) -> CustomOp:
- types = [QB_T] * cls.n_qb + [LB_T] * cls.n_lb
+ types = [QB_T] * cls.n_qb + [BOOL_T] * cls.n_lb
return CustomOp(cls.extension_name, cls.gate_name, types, types)
diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs
index c8b13b78..92b72f6f 100644
--- a/tket2/src/circuit.rs
+++ b/tket2/src/circuit.rs
@@ -681,8 +681,8 @@ mod tests {
assert_eq!(circ.operations().count(), 3);
assert_eq!(circ.units().count(), qubits + bits);
- assert_eq!(circ.nonlinear_units().count(), 0);
- assert_eq!(circ.linear_units().count(), qubits + bits);
+ assert_eq!(circ.nonlinear_units().count(), bits);
+ assert_eq!(circ.linear_units().count(), qubits);
assert_eq!(circ.qubits().count(), qubits);
}
diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs
index 50aa5815..d1cd8223 100644
--- a/tket2/src/circuit/command.rs
+++ b/tket2/src/circuit/command.rs
@@ -7,7 +7,7 @@ use std::collections::{HashMap, HashSet};
use std::iter::FusedIterator;
use hugr::hugr::views::{HierarchyView, SiblingGraph};
-use hugr::hugr::NodeType;
+use hugr::hugr::{NodeMetadata, NodeType};
use hugr::ops::{OpTag, OpTrait};
use hugr::{HugrView, IncomingPort, OutgoingPort};
use itertools::Either::{self, Left, Right};
@@ -161,6 +161,12 @@ impl<'circ, T: HugrView> Command<'circ, T> {
.port_kind(port)
.map_or(false, |kind| kind.is_linear())
}
+
+ /// Returns a metadata value associated with the command's node.
+ #[inline]
+ pub fn metadata(&self, key: impl AsRef) -> Option<&NodeMetadata> {
+ self.circ.hugr().get_metadata(self.node, key)
+ }
}
impl<'a, 'circ, T: HugrView> UnitLabeller for &'a Command<'circ, T> {
diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs
index 12b47517..66a71e26 100644
--- a/tket2/src/extension.rs
+++ b/tket2/src/extension.rs
@@ -10,7 +10,7 @@ use hugr::extension::{CustomSignatureFunc, ExtensionId, ExtensionRegistry, Signa
use hugr::hugr::IdentList;
use hugr::std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE};
use hugr::types::type_param::{TypeArg, TypeParam};
-use hugr::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound};
+use hugr::types::{CustomType, FunctionType, PolyFuncType, TypeBound};
use hugr::{type_row, Extension};
use lazy_static::lazy_static;
use smol_str::SmolStr;
@@ -21,9 +21,6 @@ pub mod angle;
/// The ID of the TKET1 extension.
pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1");
-/// The name for the linear bit custom type.
-pub const LINEAR_BIT_NAME: SmolStr = SmolStr::new_inline("LBit");
-
/// The name for opaque TKET1 operations.
pub const TKET1_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op");
@@ -39,8 +36,6 @@ pub static ref TKET1_OP_PAYLOAD : CustomType =
pub static ref TKET1_EXTENSION: Extension = {
let mut res = Extension::new(TKET1_EXTENSION_ID);
- res.add_type(LINEAR_BIT_NAME, vec![], "A linear bit.".into(), TypeBound::Any.into()).unwrap();
-
let tket1_op_payload_def = res.add_type(TKET1_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap();
let tket1_op_payload = TypeParam::Opaque{ty:tket1_op_payload_def.instantiate([]).unwrap()};
res.add_op(
@@ -52,15 +47,6 @@ pub static ref TKET1_EXTENSION: Extension = {
res
};
-/// The type for linear bits. Part of the TKET1 extension.
-pub static ref LINEAR_BIT: Type = {
- Type::new_extension(TKET1_EXTENSION
- .get_type(&LINEAR_BIT_NAME)
- .unwrap()
- .instantiate([])
- .unwrap())
- };
-
/// Extension registry including the prelude, TKET1 and Tk2Ops extensions.
pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
TKET1_EXTENSION.clone(),
diff --git a/tket2/src/serialize/pytket.rs b/tket2/src/serialize/pytket.rs
index 8d335780..81d08406 100644
--- a/tket2/src/serialize/pytket.rs
+++ b/tket2/src/serialize/pytket.rs
@@ -6,39 +6,47 @@ mod op;
use hugr::types::Type;
+use hugr::Node;
+use itertools::Itertools;
// Required for serialising ops in the tket1 hugr extension.
pub(crate) use op::serialised::OpaqueTk1Op;
#[cfg(test)]
mod tests;
+use std::collections::hash_map::DefaultHasher;
+use std::hash::{Hash, Hasher};
use std::path::Path;
use std::{fs, io};
-use hugr::ops::{OpType, Value};
+use hugr::ops::{NamedOp, OpType, Value};
use hugr::std_extensions::arithmetic::float_types::ConstF64;
use thiserror::Error;
-use tket_json_rs::circuit_json::SerialCircuit;
+use tket_json_rs::circuit_json::{self, SerialCircuit};
use tket_json_rs::optype::OpType as SerialOpType;
use crate::circuit::Circuit;
-use self::decoder::JsonDecoder;
-use self::encoder::JsonEncoder;
+use self::decoder::Tk1Decoder;
+use self::encoder::Tk1Encoder;
pub use crate::passes::pytket::lower_to_pytket;
/// Prefix used for storing metadata in the hugr nodes.
-pub const METADATA_PREFIX: &str = "TKET1_JSON";
+pub const METADATA_PREFIX: &str = "TKET1";
/// The global phase specified as metadata.
-const METADATA_PHASE: &str = "TKET1_JSON.phase";
-/// The implicit permutation of qubits.
-const METADATA_IMPLICIT_PERM: &str = "TKET1_JSON.implicit_permutation";
+const METADATA_PHASE: &str = "TKET1.phase";
/// Explicit names for the input qubit registers.
-const METADATA_Q_REGISTERS: &str = "TKET1_JSON.qubit_registers";
+const METADATA_Q_REGISTERS: &str = "TKET1.qubit_registers";
+/// The reordered qubit registers in the output, if an implicit permutation was applied.
+const METADATA_Q_OUTPUT_REGISTERS: &str = "TKET1.qubit_output_registers";
/// Explicit names for the input bit registers.
-const METADATA_B_REGISTERS: &str = "TKET1_JSON.bit_registers";
+const METADATA_B_REGISTERS: &str = "TKET1.bit_registers";
+/// The reordered bit registers in the output, if an implicit permutation was applied.
+const METADATA_B_OUTPUT_REGISTERS: &str = "TKET1.bit_output_registers";
+/// A tket1 operation "opgroup" field.
+const METADATA_OPGROUP: &str = "TKET1.opgroup";
/// A serialized representation of a [`Circuit`].
///
@@ -59,7 +67,7 @@ impl TKETDecode for SerialCircuit {
type EncodeError = TK1ConvertError;
fn decode(self) -> Result {
- let mut decoder = JsonDecoder::try_new(&self)?;
+ let mut decoder = Tk1Decoder::try_new(&self)?;
if !self.phase.is_empty() {
// TODO - add a phase gate
@@ -68,18 +76,18 @@ impl TKETDecode for SerialCircuit {
}
for com in self.commands {
- decoder.add_command(com);
+ decoder.add_command(com)?;
}
Ok(decoder.finish().into())
}
fn encode(circ: &Circuit) -> Result {
- let mut encoder = JsonEncoder::new(circ)?;
+ let mut encoder = Tk1Encoder::new(circ)?;
for com in circ.commands() {
let optype = com.optype();
encoder.add_command(com.clone(), optype)?;
}
- Ok(encoder.finish())
+ Ok(encoder.finish(circ))
}
}
@@ -156,6 +164,47 @@ pub enum OpConvertError {
/// The serialized operation is not supported.
#[error("Cannot serialize tket2 operation: {0:?}")]
UnsupportedOpSerialization(OpType),
+ /// The operation has non-serializable inputs.
+ #[error("Operation {} in {node} has an unsupported input of type {typ}.", optype.name())]
+ UnsupportedInputType {
+ /// The unsupported type.
+ typ: Type,
+ /// The operation name.
+ optype: OpType,
+ /// The node.
+ node: Node,
+ },
+ /// The operation has non-serializable outputs.
+ #[error("Operation {} in {node} has an unsupported output of type {typ}.", optype.name())]
+ UnsupportedOutputType {
+ /// The unsupported type.
+ typ: Type,
+ /// The operation name.
+ optype: OpType,
+ /// The node.
+ node: Node,
+ },
+ /// A parameter input could not be evaluated.
+ #[error("The {typ} parameter input for operation {} in {node} could not be resolved.", optype.name())]
+ UnresolvedParamInput {
+ /// The parameter type.
+ typ: Type,
+ /// The operation with the missing input param.
+ optype: OpType,
+ /// The node.
+ node: Node,
+ },
+ /// The operation has output-only qubits.
+ /// This is not currently supported by the encoder.
+ #[error("Operation {} in {node} has more output qubits than inputs.", optype.name())]
+ TooManyOutputQubits {
+ /// The unsupported type.
+ typ: Type,
+ /// The operation name.
+ optype: OpType,
+ /// The node.
+ node: Node,
+ },
/// The opaque tket1 operation had an invalid type parameter.
#[error("Opaque TKET1 operation had an invalid type parameter. {error}")]
InvalidOpaqueTypeParam {
@@ -163,6 +212,35 @@ pub enum OpConvertError {
#[from]
error: serde_yaml::Error,
},
+ /// Tried to decode a tket1 operation with not enough parameters.
+ #[error(
+ "Operation {} is missing encoded parameters. Expected at least {expected} but only \"{}\" were specified.",
+ optype.name(),
+ params.iter().join(", "),
+ )]
+ MissingSerialisedParams {
+ /// The operation name.
+ optype: OpType,
+ /// The expected number of parameters.
+ expected: usize,
+ /// The given of parameters.
+ params: Vec,
+ },
+ /// Tried to decode a tket1 operation with not enough qubit/bit arguments.
+ #[error(
+ "Operation {} is missing encoded arguments. Expected {expected_qubits} and {expected_bits}, but only \"{args:?}\" were specified.",
+ optype.name(),
+ )]
+ MissingSerialisedArguments {
+ /// The operation name.
+ optype: OpType,
+ /// The expected number of qubits.
+ expected_qubits: usize,
+ /// The expected number of bits.
+ expected_bits: usize,
+ /// The given of parameters.
+ args: Vec,
+ },
}
/// Error type for conversion between `Op` and `OpType`.
@@ -234,3 +312,20 @@ fn try_constant_to_param(val: &Value) -> Option {
let half_turns = radians / std::f64::consts::PI;
Some(half_turns.to_string())
}
+
+/// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map,
+/// avoiding string and vector clones on lookup.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+struct RegisterHash {
+ hash: u64,
+}
+
+impl From<&circuit_json::Register> for RegisterHash {
+ fn from(reg: &circuit_json::Register) -> Self {
+ let mut hasher = DefaultHasher::new();
+ reg.hash(&mut hasher);
+ Self {
+ hash: hasher.finish(),
+ }
+ }
+}
diff --git a/tket2/src/serialize/pytket/decoder.rs b/tket2/src/serialize/pytket/decoder.rs
index 04fc6242..b1f2eeb5 100644
--- a/tket2/src/serialize/pytket/decoder.rs
+++ b/tket2/src/serialize/pytket/decoder.rs
@@ -1,56 +1,51 @@
-//! Intermediate structure for converting decoding [`SerialCircuit`]s into [`Hugr`]s.
+//! Intermediate structure for decoding [`SerialCircuit`]s into [`Hugr`]s.
-use std::collections::hash_map::DefaultHasher;
-use std::collections::HashMap;
-use std::hash::{Hash, Hasher};
-use std::mem;
+use std::collections::{HashMap, HashSet};
-use hugr::builder::{CircuitBuilder, Container, Dataflow, DataflowHugr, FunctionBuilder};
-use hugr::extension::prelude::QB_T;
+use hugr::builder::{Container, Dataflow, DataflowHugr, FunctionBuilder};
+use hugr::extension::prelude::{BOOL_T, QB_T};
+use hugr::ops::handle::NodeHandle;
use hugr::ops::OpType;
use hugr::types::FunctionType;
-use hugr::CircuitUnit;
use hugr::{Hugr, Wire};
-use itertools::Itertools;
+use itertools::{EitherOrBoth, Itertools};
use serde_json::json;
use tket_json_rs::circuit_json;
use tket_json_rs::circuit_json::SerialCircuit;
use super::op::Tk1Op;
-use super::{try_param_to_constant, TK1ConvertError, METADATA_IMPLICIT_PERM, METADATA_PHASE};
-use super::{METADATA_B_REGISTERS, METADATA_Q_REGISTERS};
-use crate::extension::{LINEAR_BIT, REGISTRY, TKET1_EXTENSION_ID};
+use super::{
+ try_param_to_constant, OpConvertError, RegisterHash, TK1ConvertError,
+ METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE,
+ METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS,
+};
+use crate::extension::{REGISTRY, TKET1_EXTENSION_ID};
use crate::symbolic_constant_op;
/// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`].
///
/// Mostly used to define helper internal methods.
#[derive(Debug, PartialEq)]
-pub(super) struct JsonDecoder {
+pub(super) struct Tk1Decoder {
/// The Hugr being built.
pub hugr: FunctionBuilder,
- /// The dangling wires of the builder.
- /// Used to generate [`CircuitBuilder`]s.
- dangling_wires: Vec,
- /// A map from the json registers to the units in the circuit being built.
- register_units: HashMap,
- /// The number of qubits in the circuit.
- num_qubits: usize,
- /// The number of bits in the circuit.
- num_bits: usize,
+ /// A map from the tracked pytket registers to the [`Wire`]s in the circuit.
+ register_wires: HashMap,
+ /// The ordered list of register to have at the output.
+ ordered_registers: Vec,
+ /// A set of registers that encode qubits.
+ qubit_registers: HashSet,
}
-impl JsonDecoder {
- /// Initialize a new [`JsonDecoder`], using the metadata from a [`SerialCircuit`].
+impl Tk1Decoder {
+ /// Initialize a new [`Tk1Decoder`], using the metadata from a [`SerialCircuit`].
pub fn try_new(serialcirc: &SerialCircuit) -> Result {
let num_qubits = serialcirc.qubits.len();
let num_bits = serialcirc.bits.len();
- let sig = FunctionType::new_endo(
- [vec![QB_T; num_qubits], vec![LINEAR_BIT.clone(); num_bits]].concat(),
- )
- .with_extension_delta(TKET1_EXTENSION_ID);
+ let sig = FunctionType::new_endo([vec![QB_T; num_qubits], vec![BOOL_T; num_bits]].concat())
+ .with_extension_delta(TKET1_EXTENSION_ID);
let name = serialcirc.name.clone().unwrap_or_default();
let mut dfg = FunctionBuilder::new(name, sig.into()).unwrap();
@@ -59,83 +54,202 @@ impl JsonDecoder {
// Metadata. The circuit requires "name", and we store other things that
// should pass through the serialization roundtrip.
dfg.set_metadata(METADATA_PHASE, json!(serialcirc.phase));
- dfg.set_metadata(
- METADATA_IMPLICIT_PERM,
- json!(serialcirc.implicit_permutation),
- );
dfg.set_metadata(METADATA_Q_REGISTERS, json!(serialcirc.qubits));
dfg.set_metadata(METADATA_B_REGISTERS, json!(serialcirc.bits));
- // Map each register element to their starting `CircuitUnit`.
- let mut wire_map: HashMap =
- HashMap::with_capacity(num_bits + num_qubits);
- for (i, register) in serialcirc.qubits.iter().enumerate() {
- check_register(register)?;
- wire_map.insert(register.into(), CircuitUnit::Linear(i));
+ // Compute the output register reordering, and store it in the metadata.
+ //
+ // The `implicit_permutation` field is a dictionary mapping input
+ // registers to output registers on the same path.
+ //
+ // Here we store an ordered list showing the order in which the input
+ // registers appear in the output.
+ //
+ // For a circuit with three qubit registers 0, 1, 2 and an implicit
+ // permutation {0 -> 1, 1 -> 2, 2 -> 0}, `output_to_input` will be
+ // {1 -> 0, 2 -> 1, 0 -> 2} and the output order will be [2, 0, 1].
+ // That is, at position 0 of the output we'll see the register originally
+ // named 2, at position 1 the register originally named 0, and so on.
+ let mut output_qubits = Vec::with_capacity(serialcirc.qubits.len());
+ let mut output_bits = Vec::with_capacity(serialcirc.bits.len());
+ let output_to_input: HashMap = serialcirc
+ .implicit_permutation
+ .iter()
+ .map(|p| (p.1.clone(), p.0.clone()))
+ .collect();
+ for qubit in &serialcirc.qubits {
+ // For each output position, find the input register that should be there.
+ output_qubits.push(output_to_input.get(qubit).unwrap_or(qubit).clone());
}
- for (i, register) in serialcirc.bits.iter().enumerate() {
- check_register(register)?;
- wire_map.insert(register.into(), CircuitUnit::Linear(i + num_qubits));
+ for bit in &serialcirc.bits {
+ // For each output position, find the input register that should be there.
+ output_bits.push(output_to_input.get(bit).unwrap_or(bit).clone());
}
+ dfg.set_metadata(METADATA_Q_OUTPUT_REGISTERS, json!(output_qubits));
+ dfg.set_metadata(METADATA_B_OUTPUT_REGISTERS, json!(output_bits));
+
+ let qubit_registers = serialcirc.qubits.iter().map(RegisterHash::from).collect();
+
+ let ordered_registers = serialcirc
+ .qubits
+ .iter()
+ .chain(&serialcirc.bits)
+ .map(|reg| {
+ check_register(reg)?;
+ Ok(RegisterHash::from(reg))
+ })
+ .collect::, TK1ConvertError>>()?;
+
+ // Map each register element to their starting wire.
+ let register_wires: HashMap = ordered_registers
+ .iter()
+ .copied()
+ .zip(dangling_wires)
+ .collect();
- Ok(JsonDecoder {
+ Ok(Tk1Decoder {
hugr: dfg,
- dangling_wires,
- register_units: wire_map,
- num_qubits,
- num_bits,
+ register_wires,
+ ordered_registers,
+ qubit_registers,
})
}
/// Finish building the [`Hugr`].
- pub fn finish(self) -> Hugr {
- // TODO: Throw validation error?
+ pub fn finish(mut self) -> Hugr {
+ // Order the final wires according to the serial circuit register order.
+ let mut outputs = Vec::with_capacity(self.ordered_registers.len());
+ for register in self.ordered_registers {
+ let wire = self.register_wires.remove(®ister).unwrap();
+ outputs.push(wire);
+ }
+ debug_assert!(
+ self.register_wires.is_empty(),
+ "Some output wires were not associated with a register."
+ );
+
self.hugr
- .finish_hugr_with_outputs(self.dangling_wires, ®ISTRY)
+ .finish_hugr_with_outputs(outputs, ®ISTRY)
.unwrap()
}
/// Add a tket1 [`circuit_json::Command`] from the serial circuit to the
/// decoder.
- pub fn add_command(&mut self, command: circuit_json::Command) {
- // TODO Store the command's `opgroup` in the metadata.
- let circuit_json::Command { op, args, .. } = command;
+ pub fn add_command(&mut self, command: circuit_json::Command) -> Result<(), OpConvertError> {
+ let circuit_json::Command {
+ op, args, opgroup, ..
+ } = command;
+ let op_params = op.params.clone().unwrap_or_default();
+
+ // Interpret the serialised operation as a [`Tk1Op`].
let num_qubits = args
.iter()
- .take_while(|&arg| match self.reg_wire(arg) {
- CircuitUnit::Linear(i) => i < self.num_qubits,
- _ => false,
- })
+ .take_while(|&arg| self.is_qubit_register(arg))
.count();
let num_input_bits = args.len() - num_qubits;
- let op_params = op.params.clone();
let tk1op = Tk1Op::from_serialised_op(op, num_qubits, num_input_bits);
- let param_units = tk1op
- .param_ports()
- .enumerate()
- .filter_map(|(i, _port)| op_params.as_ref()?.get(i).map(String::as_ref))
- .map(|p| CircuitUnit::Wire(self.create_param_wire(p)))
- .collect_vec();
- let arg_units = args.into_iter().map(|reg| self.reg_wire(®));
-
- let append_wires: Vec = arg_units.chain(param_units).collect_vec();
+ let (input_wires, output_registers) = self.get_op_wires(&tk1op, &args, op_params)?;
let op: OpType = (&tk1op).into();
- self.with_circ_builder(|circ| {
- circ.append_and_consume(op, append_wires).unwrap();
- });
+ let new_op = self.hugr.add_dataflow_op(op, input_wires).unwrap();
+ let wires = new_op.outputs();
+
+ // Store the opgroup metadata.
+ if let Some(opgroup) = opgroup {
+ self.hugr
+ .set_child_metadata(new_op.node(), METADATA_OPGROUP, json!(opgroup));
+ }
+
+ // Assign the new output wires to some register, replacing the previous association.
+ for (register, wire) in output_registers.into_iter().zip_eq(wires) {
+ self.set_register_wire(register, wire);
+ }
+
+ Ok(())
}
- /// Apply a function to the internal hugr builder viewed as a [`CircuitBuilder`].
- fn with_circ_builder(&mut self, f: F) -> T
- where
- F: FnOnce(&mut CircuitBuilder>) -> T,
- {
- let mut circ = self.hugr.as_circuit(mem::take(&mut self.dangling_wires));
- let res = f(&mut circ);
- self.dangling_wires = circ.finish();
- res
+ /// Returns the input wires to connect to a new operation
+ /// and the registers to associate with outputs.
+ ///
+ /// It may add constant nodes to the Hugr if the operation has constant parameters.
+ fn get_op_wires(
+ &mut self,
+ tk1op: &Tk1Op,
+ args: &[circuit_json::Register],
+ params: Vec,
+ ) -> Result<(Vec, Vec), OpConvertError> {
+ // Arguments are always ordered with qubits first, and then bits.
+ let mut inputs: Vec = Vec::with_capacity(args.len() + params.len());
+ let mut outputs: Vec =
+ Vec::with_capacity(tk1op.qubit_outputs() + tk1op.bit_outputs());
+
+ let mut current_arg = 0;
+ let mut next_arg = || {
+ if args.len() <= current_arg {
+ return Err(OpConvertError::MissingSerialisedArguments {
+ optype: tk1op.optype(),
+ expected_qubits: tk1op.qubit_inputs(),
+ expected_bits: tk1op.bit_inputs(),
+ args: args.to_owned(),
+ });
+ }
+ current_arg += 1;
+ Ok(&args[current_arg - 1])
+ };
+
+ // Qubit wires
+ assert_eq!(
+ tk1op.qubit_inputs(),
+ tk1op.qubit_outputs(),
+ "Operations with different numbers of input and output qubits are not currently supported."
+ );
+ for _ in 0..tk1op.qubit_inputs() {
+ let reg = next_arg()?;
+ inputs.push(self.register_wire(reg));
+ outputs.push(reg.into());
+ }
+
+ // Bit wires
+ for zip in (0..tk1op.bit_inputs()).zip_longest(0..tk1op.bit_outputs()) {
+ let reg = next_arg()?;
+ match zip {
+ EitherOrBoth::Both(_inp, _out) => {
+ // A bit used both as input and output.
+ inputs.push(self.register_wire(reg));
+ outputs.push(reg.into());
+ }
+ EitherOrBoth::Left(_inp) => {
+ // A bit used only used as input.
+ inputs.push(self.register_wire(reg));
+ }
+ EitherOrBoth::Right(_out) => {
+ // A new bit output.
+ outputs.push(reg.into());
+ }
+ }
+ }
+
+ // Check that the operation is not missing parameters.
+ //
+ // Nb: `Tk1Op::Opaque` operations may not have parameters in their hugr definition.
+ // In that case, we just store the parameter values in the opaque data.
+ if tk1op.num_params() > params.len() {
+ return Err(OpConvertError::MissingSerialisedParams {
+ optype: tk1op.optype(),
+ expected: tk1op.num_params(),
+ params,
+ });
+ }
+ // Add the parameter wires to the input.
+ inputs.extend(
+ tk1op
+ .param_ports()
+ .zip(params)
+ .map(|(_port, param)| self.create_param_wire(¶m)),
+ );
+
+ Ok((inputs, outputs))
}
/// Returns the wire carrying a parameter.
@@ -155,29 +269,19 @@ impl JsonDecoder {
}
}
- /// Return the wire unit for the `elem`th value of a given register.
- ///
- /// Relies on TKET1 constraint that all registers have unique names.
- fn reg_wire(&self, register: &circuit_json::Register) -> CircuitUnit {
- self.register_units[®ister.into()]
+ /// Return the [`Wire`] associated with a register.
+ fn register_wire(&self, register: impl Into) -> Wire {
+ self.register_wires[®ister.into()]
}
-}
-/// A hashed register, used to identify registers in the [`JsonDecoder::register_wire`] map,
-/// avoiding string clones on lookup.
-#[derive(Clone, Debug, PartialEq, Eq, Hash)]
-struct RegisterHash {
- hash: u64,
-}
+ /// Update the tracked [`Wire`] for a register.
+ fn set_register_wire(&mut self, register: impl Into, unit: Wire) {
+ self.register_wires.insert(register.into(), unit);
+ }
-impl From<&circuit_json::Register> for RegisterHash {
- fn from(reg: &circuit_json::Register) -> Self {
- let mut hasher = DefaultHasher::new();
- reg.0.hash(&mut hasher);
- reg.1.hash(&mut hasher);
- Self {
- hash: hasher.finish(),
- }
+ /// Returns `true` if the register is a qubit register.
+ fn is_qubit_register(&self, register: impl Into) -> bool {
+ self.qubit_registers.contains(®ister.into())
}
}
diff --git a/tket2/src/serialize/pytket/encoder.rs b/tket2/src/serialize/pytket/encoder.rs
index e64768b8..fa26b2b7 100644
--- a/tket2/src/serialize/pytket/encoder.rs
+++ b/tket2/src/serialize/pytket/encoder.rs
@@ -1,122 +1,78 @@
-//! Intermediate structure for converting encoding [`Circuit`]s into [`SerialCircuit`]s.
+//! Intermediate structure for encoding [`Circuit`]s into [`SerialCircuit`]s.
use core::panic;
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet, VecDeque};
-use hugr::extension::prelude::QB_T;
-use hugr::ops::{NamedOp, OpType};
+use hugr::extension::prelude::{BOOL_T, QB_T};
+use hugr::ops::{OpTrait, OpType};
use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use hugr::{HugrView, Wire};
-use itertools::{Either, Itertools};
-use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit};
+use itertools::Itertools;
+use tket_json_rs::circuit_json::Register as RegisterUnit;
+use tket_json_rs::circuit_json::{self, SerialCircuit};
use crate::circuit::command::{CircuitUnit, Command};
use crate::circuit::Circuit;
-use crate::extension::LINEAR_BIT;
use crate::ops::{match_symb_const_op, op_matches};
+use crate::serialize::pytket::RegisterHash;
use crate::Tk2Op;
use super::op::Tk1Op;
use super::{
- try_constant_to_param, OpConvertError, TK1ConvertError, METADATA_B_REGISTERS,
- METADATA_IMPLICIT_PERM, METADATA_PHASE, METADATA_Q_REGISTERS,
+ try_constant_to_param, OpConvertError, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS,
+ METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS,
+ METADATA_Q_REGISTERS,
};
/// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`].
-#[derive(Debug, PartialEq)]
-pub(super) struct JsonEncoder {
+#[derive(Debug, Clone)]
+pub(super) struct Tk1Encoder {
/// The name of the circuit being encoded.
name: Option,
/// Global phase value. Defaults to "0"
phase: String,
- /// Implicit permutation of output qubits
- implicit_permutation: Vec,
- /// The current commands
+ /// The current serialised commands
commands: Vec,
- /// The TKET1 qubit registers associated to each qubit unit of the circuit.
- qubit_to_reg: HashMap,
- /// The TKET1 bit registers associated to each linear bit unit of the circuit.
- bit_to_reg: HashMap,
- /// The ordered TKET1 names for the input qubit registers.
- ///
- /// Nb: Although `tket-json-rs` calls these "registers", they're actually
- /// identifiers for single qubits in the `Register::0` register.
- qubit_registers: Vec,
- /// The ordered TKET1 names for the input bit registers.
- ///
- /// Nb: Although `tket-json-rs` calls these "registers", they're actually
- /// identifiers for single bits in the `Register::0` register.
- bit_registers: Vec,
- /// A register of wires with constant values, used to recover TKET1
- /// parameters.
- parameters: HashMap,
+ /// A tracker for the qubits used in the circuit.
+ qubits: QubitTracker,
+ /// A tracker for the bits used in the circuit.
+ bits: BitTracker,
+ /// A tracker for the operation parameters used in the circuit.
+ parameters: ParameterTracker,
}
-impl JsonEncoder {
+impl Tk1Encoder {
/// Create a new [`JsonEncoder`] from a [`Circuit`].
pub fn new(circ: &Circuit) -> Result {
let name = circ.name().map(str::to_string);
let hugr = circ.hugr();
- let mut qubit_registers = vec![];
- let mut bit_registers = vec![];
- let mut phase = "0".to_string();
- let mut implicit_permutation = vec![];
+ // Check for unsupported input types.
+ for (_, _, typ) in circ.units() {
+ if ![FLOAT64_TYPE, QB_T, BOOL_T].contains(&typ) {
+ return Err(TK1ConvertError::NonSerializableInputs { typ });
+ }
+ }
// Recover other parameters stored in the metadata
// TODO: Check for invalid encoded metadata
- let root = circ.parent();
- if let Some(p) = hugr.get_metadata(root, METADATA_PHASE) {
- phase = p.as_str().unwrap().to_string();
- }
- if let Some(perm) = hugr.get_metadata(root, METADATA_IMPLICIT_PERM) {
- implicit_permutation = serde_json::from_value(perm.clone()).unwrap();
- }
- if let Some(q_regs) = hugr.get_metadata(root, METADATA_Q_REGISTERS) {
- qubit_registers = serde_json::from_value(q_regs.clone()).unwrap();
- }
- if let Some(b_regs) = hugr.get_metadata(root, METADATA_B_REGISTERS) {
- bit_registers = serde_json::from_value(b_regs.clone()).unwrap();
- }
-
- // Map the Hugr units to tket1 register names.
- // Uses the names from the metadata if available, or initializes new sequentially-numbered registers.
- let mut qubit_to_reg = HashMap::new();
- let mut bit_to_reg = HashMap::new();
- let get_register = |registers: &mut Vec, name: &str, index| {
- registers.get(index).cloned().unwrap_or_else(|| {
- let r = Register(name.to_string(), vec![index as i64]);
- registers.push(r.clone());
- r
- })
+ let phase = match hugr.get_metadata(circ.parent(), METADATA_PHASE) {
+ Some(p) => p.as_str().unwrap().to_string(),
+ None => "0".to_string(),
};
- for (unit, _, ty) in circ.units() {
- if ty == QB_T {
- let index = qubit_to_reg.len();
- let reg = get_register(&mut qubit_registers, "q", index);
- qubit_to_reg.insert(unit, reg);
- } else if ty == *LINEAR_BIT {
- let index = bit_to_reg.len();
- let reg = get_register(&mut bit_registers, "b", index);
- bit_to_reg.insert(unit, reg.clone());
- }
- }
- let mut encoder = Self {
+ let qubit_tracker = QubitTracker::new(circ);
+ let bit_tracker = BitTracker::new(circ);
+ let parameter_tracker = ParameterTracker::new(circ);
+
+ Ok(Self {
name,
phase,
- implicit_permutation,
commands: vec![],
- qubit_to_reg,
- bit_to_reg,
- qubit_registers,
- bit_registers,
- parameters: HashMap::new(),
- };
-
- encoder.add_input_parameters(circ)?;
-
- Ok(encoder)
+ qubits: qubit_tracker,
+ bits: bit_tracker,
+ parameters: parameter_tracker,
+ })
}
/// Add a circuit command to the serialization.
@@ -126,39 +82,113 @@ impl JsonEncoder {
optype: &OpType,
) -> Result<(), OpConvertError> {
// Register any output of the command that can be used as a TKET1 parameter.
- if self.record_parameters(&command, optype) {
+ if self.parameters.record_parameters(&command, optype)? {
// for now all ops that record parameters should be ignored (are
// just constants)
return Ok(());
}
- let (args, params): (Vec, Vec) =
- command
- .inputs()
- .partition_map(|(u, _, _)| match self.unit_to_register(u) {
- Some(r) => Either::Left(r),
- None => match u {
- CircuitUnit::Wire(w) => Either::Right(w),
- CircuitUnit::Linear(_) => {
- panic!("No register found for the linear input {u:?}.")
- }
- },
+ // Special case for the QAlloc operation.
+ // This does not translate to a TKET1 operation, we just start tracking a new qubit register.
+ if optype == &Tk2Op::QAlloc.into() {
+ let Some((CircuitUnit::Linear(unit_id), _, _)) = command.outputs().next() else {
+ panic!("QAlloc should have a single qubit output.")
+ };
+ debug_assert!(self.qubits.get(unit_id).is_none());
+ self.qubits.add_qubit_register(unit_id);
+ return Ok(());
+ }
+
+ let Some(tk1op) = Tk1Op::try_from_optype(optype.clone())? else {
+ // This command should be ignored.
+ return Ok(());
+ };
+
+ // Get the registers and wires associated with the operation's inputs.
+ let mut qubit_args = Vec::with_capacity(tk1op.qubit_inputs());
+ let mut bit_args = Vec::with_capacity(tk1op.bit_inputs());
+ let mut params = Vec::with_capacity(tk1op.num_params());
+ for (unit, _, ty) in command.inputs() {
+ if ty == QB_T {
+ let reg = self.unit_to_register(unit).unwrap_or_else(|| {
+ panic!(
+ "No register found for qubit input {unit} in node {}.",
+ command.node(),
+ )
+ });
+ qubit_args.push(reg);
+ } else if ty == BOOL_T {
+ let reg = self.unit_to_register(unit).unwrap_or_else(|| {
+ panic!(
+ "No register found for bit input {unit} in node {}.",
+ command.node(),
+ )
+ });
+ bit_args.push(reg);
+ } else if ty == FLOAT64_TYPE {
+ let CircuitUnit::Wire(param_wire) = unit else {
+ unreachable!("Float types are not linear.")
+ };
+ params.push(param_wire);
+ } else {
+ return Err(OpConvertError::UnsupportedInputType {
+ typ: ty.clone(),
+ optype: optype.clone(),
+ node: command.node(),
});
+ }
+ }
- // TODO Restore the opgroup (once the decoding supports it)
- let opgroup = None;
+ for (unit, _, ty) in command.outputs() {
+ if ty == QB_T {
+ // If the qubit is not already in the qubit tracker, add it as a
+ // new register.
+ let CircuitUnit::Linear(unit_id) = unit else {
+ panic!("Qubit types are linear.")
+ };
+ if self.qubits.get(unit_id).is_none() {
+ let reg = self.qubits.add_qubit_register(unit_id);
+ qubit_args.push(reg.clone());
+ }
+ } else if ty == BOOL_T {
+ // If the operation has any bit outputs, create a new one bit
+ // register.
+ //
+ // Note that we do not reassign input registers to the new
+ // output wires as we do not know if the bit value was modified
+ // by the operation, and the old value may be needed later.
+ //
+ // This may cause register duplication for opaque operations
+ // with input bits.
+ let CircuitUnit::Wire(wire) = unit else {
+ panic!("Bool types are not linear.")
+ };
+ let reg = self.bits.add_bit_register(wire);
+ bit_args.push(reg.clone());
+ } else {
+ return Err(OpConvertError::UnsupportedOutputType {
+ typ: ty.clone(),
+ optype: optype.clone(),
+ node: command.node(),
+ });
+ }
+ }
+
+ let opgroup: Option = command
+ .metadata(METADATA_OPGROUP)
+ .and_then(serde_json::Value::as_str)
+ .map(ToString::to_string);
// Convert the command's operator to a pytket serialized one. This will
// return an error for operations that should have been caught by the
// `record_parameters` branch above (in addition to other unsupported
// ops).
- let op: Tk1Op = Tk1Op::try_from_optype(optype.clone())?;
- let mut op: circuit_json::Operation = op
+ let mut serial_op: circuit_json::Operation = tk1op
.serialised_op()
.ok_or_else(|| OpConvertError::UnsupportedOpSerialization(optype.clone()))?;
if !params.is_empty() {
- op.params = Some(
+ serial_op.params = Some(
params
.into_iter()
.filter_map(|w| self.parameters.get(&w))
@@ -169,21 +199,369 @@ impl JsonEncoder {
// TODO: ops that contain free variables.
// (update decoder to ignore them too, but store them in the wrapped op)
- let command = circuit_json::Command { op, args, opgroup };
+ let mut args = qubit_args;
+ args.append(&mut bit_args);
+ let command = circuit_json::Command {
+ op: serial_op,
+ args,
+ opgroup,
+ };
self.commands.push(command);
+
Ok(())
}
/// Finish building and return the final [`SerialCircuit`].
- pub fn finish(self) -> SerialCircuit {
+ pub fn finish(self, circ: &Circuit) -> SerialCircuit {
+ let (qubits, qubits_permutation) = self.qubits.finish(circ);
+ let (bits, mut bits_permutation) = self.bits.finish(circ);
+
+ let mut implicit_permutation = qubits_permutation;
+ implicit_permutation.append(&mut bits_permutation);
+
SerialCircuit {
name: self.name,
phase: self.phase,
commands: self.commands,
- qubits: self.qubit_registers,
- bits: self.bit_registers,
- implicit_permutation: self.implicit_permutation,
+ qubits,
+ bits,
+ implicit_permutation,
+ }
+ }
+
+ /// Translate a linear [`CircuitUnit`] into a [`RegisterUnit`], if possible.
+ fn unit_to_register(&self, unit: CircuitUnit) -> Option {
+ match unit {
+ CircuitUnit::Linear(i) => self.qubits.get(i).cloned(),
+ CircuitUnit::Wire(wire) => self.bits.get(&wire).cloned(),
+ }
+ }
+}
+
+/// A structure for tracking qubits used in the circuit being encoded.
+///
+/// Nb: Although `tket-json-rs` has a "Register" struct, it's actually
+/// an identifier for single qubits in the `Register::0` register.
+/// We rename it to `RegisterUnit` here to avoid confusion.
+#[derive(Debug, Clone, Default)]
+struct QubitTracker {
+ /// The ordered TKET1 names for the input qubit registers.
+ inputs: Vec,
+ /// The ordered TKET1 names for the output qubit registers.
+ outputs: Option>,
+ /// The TKET1 qubit registers associated to each qubit unit of the circuit.
+ qubit_to_reg: HashMap,
+ /// A generator of new registers units to use for bit wires.
+ unit_generator: RegisterUnitGenerator,
+}
+
+impl QubitTracker {
+ /// Create a new [`QubitTracker`] from the qubit inputs of a [`Circuit`].
+ /// Reads the [`METADATA_Q_REGISTERS`] metadata entry with preset pytket qubit register names.
+ ///
+ /// If the circuit contains more qubit inputs than the provided list,
+ /// new registers are created for the remaining qubits.
+ pub fn new(circ: &Circuit) -> Self {
+ let mut tracker = QubitTracker::default();
+
+ if let Some(input_regs) = circ
+ .hugr()
+ .get_metadata(circ.parent(), METADATA_Q_REGISTERS)
+ {
+ tracker.inputs = serde_json::from_value(input_regs.clone()).unwrap();
+ }
+ let output_regs = circ
+ .hugr()
+ .get_metadata(circ.parent(), METADATA_Q_OUTPUT_REGISTERS)
+ .map(|regs| serde_json::from_value(regs.clone()).unwrap());
+ if let Some(output_regs) = output_regs {
+ tracker.outputs = Some(output_regs);
+ }
+
+ tracker.unit_generator = RegisterUnitGenerator::new(
+ "q",
+ tracker
+ .inputs
+ .iter()
+ .chain(tracker.outputs.iter().flatten()),
+ );
+
+ let qubit_count = circ.units().filter(|(_, _, ty)| ty == &QB_T).count();
+
+ for i in 0..qubit_count {
+ // Use the given input register names if available, or create new ones.
+ if let Some(reg) = tracker.inputs.get(i) {
+ tracker.qubit_to_reg.insert(i, reg.clone());
+ } else {
+ let reg = tracker.add_qubit_register(i).clone();
+ tracker.inputs.push(reg);
+ }
+ }
+
+ tracker
+ }
+
+ /// Add a new register unit for a qubit wire.
+ pub fn add_qubit_register(&mut self, unit_id: usize) -> &RegisterUnit {
+ let reg = self.unit_generator.next();
+ self.qubit_to_reg.insert(unit_id, reg);
+ self.qubit_to_reg.get(&unit_id).unwrap()
+ }
+
+ /// Returns the register unit for a qubit wire, if it exists.
+ pub fn get(&self, unit_id: usize) -> Option<&RegisterUnit> {
+ self.qubit_to_reg.get(&unit_id)
+ }
+
+ /// Consumes the tracker and returns the final list of qubit registers, along
+ /// with the final permutation of the outputs.
+ pub fn finish(
+ mut self,
+ _circ: &Circuit,
+ ) -> (Vec, Vec) {
+ // Ensure the input and output lists have the same registers.
+ let mut outputs = self.outputs.unwrap_or_default();
+ let mut input_regs: HashSet =
+ self.inputs.iter().map(RegisterHash::from).collect();
+ let output_regs: HashSet = outputs.iter().map(RegisterHash::from).collect();
+
+ for inp in &self.inputs {
+ if !output_regs.contains(&inp.into()) {
+ outputs.push(inp.clone());
+ }
+ }
+ for out in &outputs {
+ if !input_regs.contains(&out.into()) {
+ self.inputs.push(out.clone());
+ }
+ }
+ input_regs.extend(output_regs);
+
+ // Add registers defined mid-circuit to both ends.
+ for reg in self.qubit_to_reg.into_values() {
+ if !input_regs.contains(&(®).into()) {
+ self.inputs.push(reg.clone());
+ outputs.push(reg);
+ }
+ }
+
+ // TODO: Look at the circuit outputs to determine the final permutation.
+ //
+ // We don't have the `CircuitUnit::Linear` assignments for the outputs
+ // here, so that requires some extra piping.
+ let permutation = outputs
+ .into_iter()
+ .zip(&self.inputs)
+ .map(|(out, inp)| circuit_json::Permutation(inp.clone(), out))
+ .collect_vec();
+
+ (self.inputs, permutation)
+ }
+}
+
+/// A structure for tracking bits used in the circuit being encoded.
+///
+/// Nb: Although `tket-json-rs` has a "Register" struct, it's actually
+/// an identifier for single bits in the `Register::0` register.
+/// We rename it to `RegisterUnit` here to avoid confusion.
+#[derive(Debug, Clone, Default)]
+struct BitTracker {
+ /// The ordered TKET1 names for the bit inputs.
+ inputs: Vec,
+ /// The expected order of TKET1 names for the bit outputs,
+ /// if that was stored in the metadata.
+ outputs: Option>,
+ /// Map each bit wire to a TKET1 register element.
+ bit_to_reg: HashMap,
+ /// Registers defined in the metadata, but not present in the circuit
+ /// inputs.
+ unused_registers: VecDeque,
+ /// A generator of new registers units to use for bit wires.
+ unit_generator: RegisterUnitGenerator,
+}
+
+impl BitTracker {
+ /// Create a new [`BitTracker`] from the bit inputs of a [`Circuit`].
+ /// Reads the [`METADATA_B_REGISTERS`] metadata entry with preset pytket bit register names.
+ ///
+ /// If the circuit contains more bit inputs than the provided list,
+ /// new registers are created for the remaining bits.
+ ///
+ /// TODO: Compute output bit permutations when finishing the circuit.
+ pub fn new(circ: &Circuit) -> Self {
+ let mut tracker = BitTracker::default();
+
+ if let Some(input_regs) = circ
+ .hugr()
+ .get_metadata(circ.parent(), METADATA_B_REGISTERS)
+ {
+ tracker.inputs = serde_json::from_value(input_regs.clone()).unwrap();
+ }
+ let output_regs = circ
+ .hugr()
+ .get_metadata(circ.parent(), METADATA_B_OUTPUT_REGISTERS)
+ .map(|regs| serde_json::from_value(regs.clone()).unwrap());
+ if let Some(output_regs) = output_regs {
+ tracker.outputs = Some(output_regs);
+ }
+
+ tracker.unit_generator = RegisterUnitGenerator::new(
+ "c",
+ tracker
+ .inputs
+ .iter()
+ .chain(tracker.outputs.iter().flatten()),
+ );
+
+ let bit_input_wires = circ.units().filter_map(|u| match u {
+ (CircuitUnit::Wire(w), _, ty) if ty == BOOL_T => Some(w),
+ _ => None,
+ });
+
+ let mut unused_registers: HashSet = tracker.inputs.iter().cloned().collect();
+ for (i, wire) in bit_input_wires.enumerate() {
+ // If the input is not used in the circuit, ignore it.
+ if circ
+ .hugr()
+ .linked_inputs(wire.node(), wire.source())
+ .next()
+ .is_none()
+ {
+ continue;
+ }
+
+ // Use the given input register names if available, or create new ones.
+ if let Some(reg) = tracker.inputs.get(i) {
+ unused_registers.remove(reg);
+ tracker.bit_to_reg.insert(wire, reg.clone());
+ } else {
+ let reg = tracker.add_bit_register(wire).clone();
+ tracker.inputs.push(reg);
+ };
}
+
+ // If a register was defined in the metadata but not used in the circuit,
+ // we keep it so it can be assigned to an operation output.
+ tracker.unused_registers = unused_registers.into_iter().collect();
+
+ tracker
+ }
+
+ /// Add a new register unit for a bit wire.
+ pub fn add_bit_register(&mut self, wire: Wire) -> &RegisterUnit {
+ let reg = self
+ .unused_registers
+ .pop_front()
+ .unwrap_or_else(|| self.unit_generator.next());
+
+ self.bit_to_reg.insert(wire, reg);
+ self.bit_to_reg.get(&wire).unwrap()
+ }
+
+ /// Returns the register unit for a bit wire, if it exists.
+ pub fn get(&self, wire: &Wire) -> Option<&RegisterUnit> {
+ self.bit_to_reg.get(wire)
+ }
+
+ /// Consumes the tracker and returns the final list of bit registers, along
+ /// with the final permutation of the outputs.
+ pub fn finish(
+ mut self,
+ circ: &Circuit,
+ ) -> (Vec, Vec) {
+ let mut circuit_output_order: Vec = Vec::with_capacity(self.inputs.len());
+ for (node, port) in circ.hugr().all_linked_outputs(circ.output_node()) {
+ let wire = Wire::new(node, port);
+ if let Some(reg) = self.bit_to_reg.get(&wire) {
+ circuit_output_order.push(reg.clone());
+ }
+ }
+
+ // Ensure the input and output lists have the same registers.
+ let mut outputs = self.outputs.unwrap_or_default();
+ let mut input_regs: HashSet =
+ self.inputs.iter().map(RegisterHash::from).collect();
+ let output_regs: HashSet = outputs.iter().map(RegisterHash::from).collect();
+
+ for inp in &self.inputs {
+ if !output_regs.contains(&inp.into()) {
+ outputs.push(inp.clone());
+ }
+ }
+ for out in &outputs {
+ if !input_regs.contains(&out.into()) {
+ self.inputs.push(out.clone());
+ }
+ }
+ input_regs.extend(output_regs);
+
+ // Add registers defined mid-circuit to both ends.
+ for reg in self.bit_to_reg.into_values() {
+ if !input_regs.contains(&(®).into()) {
+ self.inputs.push(reg.clone());
+ outputs.push(reg);
+ }
+ }
+
+ // And ensure `circuit_output_order` has all virtual registers added too.
+ let circuit_outputs: HashSet = circuit_output_order
+ .iter()
+ .map(RegisterHash::from)
+ .collect();
+ for out in &outputs {
+ if !circuit_outputs.contains(&out.into()) {
+ circuit_output_order.push(out.clone());
+ }
+ }
+
+ // Compute the final permutation. This is a combination of two mappings:
+ // - First, the original implicit permutation for the circuit, if this was decoded from pytket.
+ let original_permutation: HashMap = self
+ .inputs
+ .iter()
+ .zip(&outputs)
+ .map(|(inp, out)| (inp.clone(), RegisterHash::from(out)))
+ .collect();
+ // - Second, the actual reordering of outputs seen at the circuit's output node.
+ let mut circuit_permutation: HashMap = outputs
+ .iter()
+ .zip(circuit_output_order)
+ .map(|(out, circ_out)| (RegisterHash::from(out), circ_out))
+ .collect();
+ // The final permutation is the composition of these two mappings.
+ let permutation = original_permutation
+ .into_iter()
+ .map(|(inp, out)| {
+ circuit_json::Permutation(inp, circuit_permutation.remove(&out).unwrap())
+ })
+ .collect_vec();
+
+ (self.inputs, permutation)
+ }
+}
+
+/// A structure for tracking the parameters of a circuit being encoded.
+#[derive(Debug, Clone, Default)]
+struct ParameterTracker {
+ /// The parameters associated with each wire.
+ parameters: HashMap,
+}
+
+impl ParameterTracker {
+ /// Create a new [`ParameterTracker`] from the input parameters of a [`Circuit`].
+ fn new(circ: &Circuit) -> Self {
+ let mut tracker = ParameterTracker::default();
+
+ let float_input_wires = circ.units().filter_map(|u| match u {
+ (CircuitUnit::Wire(w), _, ty) if ty == FLOAT64_TYPE => Some(w),
+ _ => None,
+ });
+
+ for (i, wire) in float_input_wires.enumerate() {
+ tracker.add_parameter(wire, format!("f{i}"));
+ }
+
+ tracker
}
/// Record any output of the command that can be used as a TKET1 parameter.
@@ -193,25 +571,40 @@ impl JsonEncoder {
&mut self,
command: &Command<'_, T>,
optype: &OpType,
- ) -> bool {
- // Only consider commands where all inputs are parameters.
- let inputs = command
- .inputs()
- .filter_map(|(unit, _, _)| match unit {
- CircuitUnit::Wire(wire) => self.parameters.get(&wire),
- CircuitUnit::Linear(_) => None,
- })
- .collect_vec();
- if inputs.len() != command.input_count() {
- debug_assert!(
- !matches!(optype, OpType::Const(_) | OpType::LoadConstant(_)),
- "Found a {} with {} inputs, of which {} are non-linear. In node {:?}",
- optype.name(),
- command.input_count(),
- inputs.len(),
- command.node()
- );
- return false;
+ ) -> Result {
+ let input_count = if let Some(signature) = optype.dataflow_signature() {
+ // Only consider commands where all inputs are parameters,
+ // and some outputs are also parameters.
+ let all_inputs = signature.input().iter().all(|ty| ty == &FLOAT64_TYPE);
+ let some_output = signature.output().iter().any(|ty| ty == &FLOAT64_TYPE);
+ if !all_inputs || !some_output {
+ return Ok(false);
+ }
+ signature.input_count()
+ } else if let OpType::Const(_) = optype {
+ // `Const` is a special non-dataflow command we can handle.
+ // It has zero inputs.
+ 0
+ } else {
+ // Not a parameter-generating command.
+ return Ok(false);
+ };
+
+ // Collect the input parameters.
+ let mut inputs = Vec::with_capacity(input_count);
+ for (unit, _, _) in command.inputs() {
+ let CircuitUnit::Wire(wire) = unit else {
+ panic!("Float types are not linear")
+ };
+ let Some(param) = self.parameters.get(&wire) else {
+ let typ = FLOAT64_TYPE;
+ return Err(OpConvertError::UnresolvedParamInput {
+ typ,
+ optype: optype.clone(),
+ node: command.node(),
+ });
+ };
+ inputs.push(param);
}
let param = match optype {
@@ -219,7 +612,7 @@ impl JsonEncoder {
// New constant, register it if it can be interpreted as a parameter.
match try_constant_to_param(const_op.value()) {
Some(param) => param,
- None => return false,
+ None => return Ok(false),
}
}
OpType::LoadConstant(_op_type) => {
@@ -231,30 +624,18 @@ impl JsonEncoder {
}
_ => {
let Some(s) = match_symb_const_op(optype) else {
- return false;
+ return Ok(false);
};
s.to_string()
}
};
for (unit, _, _) in command.outputs() {
- match unit {
- CircuitUnit::Wire(wire) => self.add_parameter(wire, param.clone()),
- CircuitUnit::Linear(_) => panic!(
- "Found a non-wire output {unit:?} for a {} command.",
- optype.name()
- ),
+ if let CircuitUnit::Wire(wire) = unit {
+ self.add_parameter(wire, param.clone())
}
}
- true
- }
-
- /// Translate a linear [`CircuitUnit`] into a [`Register`], if possible.
- fn unit_to_register(&self, unit: CircuitUnit) -> Option {
- self.qubit_to_reg
- .get(&unit)
- .or_else(|| self.bit_to_reg.get(&unit))
- .cloned()
+ Ok(true)
}
/// Associate a parameter expression with a wire.
@@ -262,23 +643,48 @@ impl JsonEncoder {
self.parameters.insert(wire, param);
}
- /// Adds a parameter for each floating-point input to the circuit.
- fn add_input_parameters(
- &mut self,
- circ: &Circuit,
- ) -> Result<(), TK1ConvertError> {
- let mut num_f64_inputs = 0;
- for (wire, _, typ) in circ.units() {
- match wire {
- CircuitUnit::Linear(_) => {}
- CircuitUnit::Wire(wire) if typ == FLOAT64_TYPE => {
- let param = format!("f{num_f64_inputs}");
- num_f64_inputs += 1;
- self.add_parameter(wire, param);
- }
- CircuitUnit::Wire(_) => return Err(TK1ConvertError::NonSerializableInputs { typ }),
+ /// Returns the parameter expression for a wire, if it exists.
+ fn get(&self, wire: &Wire) -> Option<&String> {
+ self.parameters.get(wire)
+ }
+}
+
+/// A utility class for finding new unused qubit/bit names.
+#[derive(Debug, Clone, Default)]
+struct RegisterUnitGenerator {
+ /// The next index to use for a new register.
+ next_unit: u16,
+ /// The register name to use.
+ register: String,
+}
+
+impl RegisterUnitGenerator {
+ /// Create a new [`RegisterUnitGenerator`]
+ ///
+ /// Scans the set of existing registers to find the last used index, and
+ /// starts generating new unit names from there.
+ pub fn new<'a>(
+ register: impl ToString,
+ existing: impl IntoIterator,
+ ) -> Self {
+ let register = register.to_string();
+ let mut last_unit: Option = None;
+ for reg in existing {
+ if reg.0 != register {
+ continue;
}
+ last_unit = Some(last_unit.unwrap_or_default().max(reg.1[0] as u16));
}
- Ok(())
+ RegisterUnitGenerator {
+ register,
+ next_unit: last_unit.map_or(0, |i| i + 1),
+ }
+ }
+
+ /// Returns a fresh register unit.
+ pub fn next(&mut self) -> RegisterUnit {
+ let unit = self.next_unit;
+ self.next_unit += 1;
+ RegisterUnit(self.register.clone(), vec![unit as i64])
}
}
diff --git a/tket2/src/serialize/pytket/op.rs b/tket2/src/serialize/pytket/op.rs
index 85a04c54..3fff3da1 100644
--- a/tket2/src/serialize/pytket/op.rs
+++ b/tket2/src/serialize/pytket/op.rs
@@ -13,15 +13,16 @@ use hugr::ops::OpType;
use hugr::IncomingPort;
use tket_json_rs::circuit_json;
+use crate::Tk2Op;
+
use self::native::NativeOp;
use self::serialised::OpaqueTk1Op;
use super::OpConvertError;
-/// An operation originating from pytket, containing the operation type and all its attributes.
-///
-/// Wrapper around [`tket_json_rs::circuit_json::Operation`] with cached number of qubits and bits.
+/// An intermediary artifact when converting between TKET1 and TKET2 operations.
///
-/// The `Operation` contained by this struct is guaranteed to have a signature.
+/// This enum represents either operations that can be represented natively in TKET2,
+/// or operations that must be serialised as opaque TKET1 operations.
#[derive(Clone, Debug, PartialEq, derive_more::From)]
pub enum Tk1Op {
/// An operation with a native TKET2 counterpart.
@@ -38,15 +39,19 @@ impl Tk1Op {
/// # Errors
///
/// Returns an error if the operation is not supported by the TKET1 serialization.
- pub fn try_from_optype(op: OpType) -> Result {
- let res = (&op).try_into();
- let tk1_op = if let Ok(tk2op) = res {
- NativeOp::try_from_tk2op(tk2op).map(Tk1Op::Native)
+ pub fn try_from_optype(op: OpType) -> Result