Skip to content

Commit

Permalink
feat: Drop linear bits, improve pytket encoding/decoding (#420)
Browse files Browse the repository at this point in the history
Removes the ad-hoc `LINEAR_BIT` used for decoding pytket circuits, and
uses non-linear `BOOL_T`s instead.
This now lets us encode guppy circuits with measurements;

```python
module = GuppyModule("test")
module.load(quantum)

@guppy(module)
def my_func(q0: qubit, q1: qubit) -> tuple[bool,]:
    q0 = phased_x(q0, py(math.pi / 2), py(-math.pi / 2))
    q0 = rz(q0, py(math.pi))
    q1 = phased_x(q1, py(math.pi / 2), py(-math.pi / 2))
    q1 = rz(q1, py(math.pi))
    q0, q1 = zz_max(q0, q1)
    _ = measure(q0)
    return (measure(q1),)

circ = guppy_to_circuit(my_func)
print(to_hugr_mermaid(circ))

tk1 = circ.to_tket1()
render_circuit_jupyter(tk1)

circ2 = Tk2Circuit(tk1)
print(to_hugr_mermaid(circ2))
```

Mermaid diagram (rooted on the `DataflowBlock`):
```mermaid
graph LR
    subgraph 0 ["(0) Module"]
        direction LR
        subgraph 7 ["(7) FuncDefn"]
            direction LR
            3["(3) Input"]
            3--"0:0<br>qubit"-->8
            3--"1:1<br>qubit"-->8
            6["(6) Output"]
            subgraph 8 ["(8) CFG"]
                direction LR
                subgraph 1 ["(1) DataflowBlock"]
                    direction LR
                    4["(4) Input"]
                    4--"0:0<br>qubit"-->13
                    4--"1:0<br>qubit"-->21
                    5["(5) Output"]
                    9["(9) const:custom:f64(1.5707963267948966)"]
                    9--"0:0<br>float64"-->10
                    10["(10) LoadConstant"]
                    10--"0:1<br>float64"-->13
                    11["(11) const:custom:f64(-1.5707963267948966)"]
                    11--"0:0<br>float64"-->12
                    12["(12) LoadConstant"]
                    12--"0:2<br>float64"-->13
                    13["(13) quantum.tket2.PhasedX"]
                    13--"0:0<br>qubit"-->16
                    14["(14) const:custom:f64(3.141592653589793)"]
                    14--"0:0<br>float64"-->15
                    15["(15) LoadConstant"]
                    15--"0:1<br>float64"-->16
                    16["(16) quantum.tket2.RzF64"]
                    16--"0:0<br>qubit"-->25
                    17["(17) const:custom:f64(1.5707963267948966)"]
                    17--"0:0<br>float64"-->18
                    18["(18) LoadConstant"]
                    18--"0:1<br>float64"-->21
                    19["(19) const:custom:f64(-1.5707963267948966)"]
                    19--"0:0<br>float64"-->20
                    20["(20) LoadConstant"]
                    20--"0:2<br>float64"-->21
                    21["(21) quantum.tket2.PhasedX"]
                    21--"0:0<br>qubit"-->24
                    22["(22) const:custom:f64(3.141592653589793)"]
                    22--"0:0<br>float64"-->23
                    23["(23) LoadConstant"]
                    23--"0:1<br>float64"-->24
                    24["(24) quantum.tket2.RzF64"]
                    24--"0:1<br>qubit"-->25
                    25["(25) quantum.tket2.ZZMax"]
                    25--"0:0<br>qubit"-->26
                    25--"1:1<br>qubit"-->26
                    26["(26) MakeTuple"]
                    26--"0:0<br>[qubit, qubit]"-->27
                    27["(27) UnpackTuple"]
                    27--"0:0<br>qubit"-->28
                    27--"1:0<br>qubit"-->30
                    28["(28) quantum.tket2.Measure"]
                    28--"0:0<br>qubit"-->29
                    29["(29) quantum.tket2.QFree"]
                    30["(30) quantum.tket2.Measure"]
                    30--"0:0<br>qubit"-->31
                    30--"1:0<br>[]+[]"-->32
                    31["(31) quantum.tket2.QFree"]
                    32["(32) MakeTuple"]
                    32--"0:0<br>[[]+[]]"-->33
                    33["(33) UnpackTuple"]
                    33--"0:1<br>[]+[]"-->5
                    34["(34) Tag"]
                    34--"0:0<br>[]"-->5
                end
                1-."0:0".->2
                2["(2) ExitBlock"]
            end
            8--"0:0<br>[]+[]"-->6
        end
    end
```
tket1 circuit:

![circuit](https://github.com/CQCL/tket2/assets/121866228/62877218-dd24-4e5e-a8f7-ce738e05662c)
Re-extracted circuit:
```mermaid
graph LR
    subgraph 0 ["(0) FuncDefn"]
        direction LR
        1["(1) Input"]
        1--"0:0<br>qubit"-->7
        1--"1:0<br>qubit"-->12
        2["(2) Output"]
        3["(3) const:custom:f64(1.5707963267948966)"]
        3--"0:0<br>float64"-->4
        4["(4) LoadConstant"]
        4--"0:1<br>float64"-->7
        5["(5) const:custom:f64(-1.5707963267948966)"]
        5--"0:0<br>float64"-->6
        6["(6) LoadConstant"]
        6--"0:2<br>float64"-->7
        7["(7) quantum.tket2.PhasedX"]
        7--"0:0<br>qubit"-->15
        8["(8) const:custom:f64(1.5707963267948966)"]
        8--"0:0<br>float64"-->9
        9["(9) LoadConstant"]
        9--"0:1<br>float64"-->12
        10["(10) const:custom:f64(-1.5707963267948966)"]
        10--"0:0<br>float64"-->11
        11["(11) LoadConstant"]
        11--"0:2<br>float64"-->12
        12["(12) quantum.tket2.PhasedX"]
        12--"0:0<br>qubit"-->18
        13["(13) const:custom:f64(3.141592653589793)"]
        13--"0:0<br>float64"-->14
        14["(14) LoadConstant"]
        14--"0:1<br>float64"-->15
        15["(15) quantum.tket2.RzF64"]
        15--"0:0<br>qubit"-->19
        16["(16) const:custom:f64(3.141592653589793)"]
        16--"0:0<br>float64"-->17
        17["(17) LoadConstant"]
        17--"0:1<br>float64"-->18
        18["(18) quantum.tket2.RzF64"]
        18--"0:1<br>qubit"-->19
        19["(19) quantum.tket2.ZZMax"]
        19--"0:0<br>qubit"-->21
        19--"1:0<br>qubit"-->20
        20["(20) quantum.tket2.Measure"]
        20--"0:1<br>qubit"-->2
        20--"1:2<br>[]+[]"-->2
        21["(21) quantum.tket2.Measure"]
        21--"0:0<br>qubit"-->2
        21--"1:3<br>[]+[]"-->2
    end
```


This required multiple improvements to the encoder/decoder logic,
including
- `Tk1Op::Native` operations (backed by a `Tk2Op`) can now have
different number of input/output qubit/bits.
- Added support for tket2 circuits with different input and output
signatures.
- Added support for `QAlloc`/`QFree` operations (generated by guppy) by
adding extra input/outputs to the circuit.
- Added support for pytket's implicit permutations, and recalculates the
value when encoding a tket2 circuit.
- Preserve the `opgroup` value from decoded pytket operations.
- Improved error reporting.

Closes #379

---------

Co-authored-by: Seyon Sivarajah <[email protected]>
  • Loading branch information
aborgna-q and ss2165 authored Jun 25, 2024
1 parent 5499817 commit a6e9e13
Show file tree
Hide file tree
Showing 14 changed files with 1,222 additions and 403 deletions.
7 changes: 1 addition & 6 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use pyo3::prelude::*;
use std::fmt;

use hugr::{type_row, Hugr, HugrView, PortIndex};
use tket2::extension::{LINEAR_BIT, REGISTRY};
use tket2::extension::REGISTRY;
use tket2::rewrite::CircuitRewrite;
use tket2::serialize::TKETDecode;
use tket_json_rs::circuit_json::SerialCircuit;
Expand Down Expand Up @@ -315,11 +315,6 @@ impl PyHugrType {
Self(QB_T)
}

#[staticmethod]
fn linear_bit() -> Self {
Self(LINEAR_BIT.to_owned())
}

#[staticmethod]
fn bool() -> Self {
Self(BOOL_T)
Expand Down
7 changes: 7 additions & 0 deletions tket2-py/test/test_guppy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,10 @@ def my_func(

# The 7 operations in the function, plus two implicit QFree
assert circ.num_operations() == 9

tk1 = circ.to_tket1()
assert tk1.n_gates == 7
assert tk1.n_qubits == 2

gates = list(tk1)
assert gates[4].op.type == pytket.circuit.OpType.ZZMax
4 changes: 0 additions & 4 deletions tket2-py/tket2/_tket2/circuit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,6 @@ class HugrType:
def qubit() -> HugrType:
"""Qubit type from HUGR prelude."""

@staticmethod
def linear_bit() -> HugrType:
"""Linear bit type from TKET1 extension."""

@staticmethod
def bool() -> HugrType:
"""Boolean type (HUGR 2-ary unit sum)."""
Expand Down
3 changes: 1 addition & 2 deletions tket2-py/tket2/circuit/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dataclasses import dataclass

QB_T = HugrType.qubit()
LB_T = HugrType.linear_bit()
BOOL_T = HugrType.bool()


Expand All @@ -22,7 +21,7 @@ def bits(self) -> list[int]:

@classmethod
def op(cls) -> CustomOp:
types = [QB_T] * cls.n_qb + [LB_T] * cls.n_lb
types = [QB_T] * cls.n_qb + [BOOL_T] * cls.n_lb
return CustomOp(cls.extension_name, cls.gate_name, types, types)


Expand Down
4 changes: 2 additions & 2 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,8 @@ mod tests {
assert_eq!(circ.operations().count(), 3);

assert_eq!(circ.units().count(), qubits + bits);
assert_eq!(circ.nonlinear_units().count(), 0);
assert_eq!(circ.linear_units().count(), qubits + bits);
assert_eq!(circ.nonlinear_units().count(), bits);
assert_eq!(circ.linear_units().count(), qubits);
assert_eq!(circ.qubits().count(), qubits);
}

Expand Down
8 changes: 7 additions & 1 deletion tket2/src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::collections::{HashMap, HashSet};
use std::iter::FusedIterator;

use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::hugr::NodeType;
use hugr::hugr::{NodeMetadata, NodeType};
use hugr::ops::{OpTag, OpTrait};
use hugr::{HugrView, IncomingPort, OutgoingPort};
use itertools::Either::{self, Left, Right};
Expand Down Expand Up @@ -161,6 +161,12 @@ impl<'circ, T: HugrView> Command<'circ, T> {
.port_kind(port)
.map_or(false, |kind| kind.is_linear())
}

/// Returns a metadata value associated with the command's node.
#[inline]
pub fn metadata(&self, key: impl AsRef<str>) -> Option<&NodeMetadata> {
self.circ.hugr().get_metadata(self.node, key)
}
}

impl<'a, 'circ, T: HugrView> UnitLabeller for &'a Command<'circ, T> {
Expand Down
16 changes: 1 addition & 15 deletions tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use hugr::extension::{CustomSignatureFunc, ExtensionId, ExtensionRegistry, Signa
use hugr::hugr::IdentList;
use hugr::std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE};
use hugr::types::type_param::{TypeArg, TypeParam};
use hugr::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound};
use hugr::types::{CustomType, FunctionType, PolyFuncType, TypeBound};
use hugr::{type_row, Extension};
use lazy_static::lazy_static;
use smol_str::SmolStr;
Expand All @@ -21,9 +21,6 @@ pub mod angle;
/// The ID of the TKET1 extension.
pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1");

/// The name for the linear bit custom type.
pub const LINEAR_BIT_NAME: SmolStr = SmolStr::new_inline("LBit");

/// The name for opaque TKET1 operations.
pub const TKET1_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op");

Expand All @@ -39,8 +36,6 @@ pub static ref TKET1_OP_PAYLOAD : CustomType =
pub static ref TKET1_EXTENSION: Extension = {
let mut res = Extension::new(TKET1_EXTENSION_ID);

res.add_type(LINEAR_BIT_NAME, vec![], "A linear bit.".into(), TypeBound::Any.into()).unwrap();

let tket1_op_payload_def = res.add_type(TKET1_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap();
let tket1_op_payload = TypeParam::Opaque{ty:tket1_op_payload_def.instantiate([]).unwrap()};
res.add_op(
Expand All @@ -52,15 +47,6 @@ pub static ref TKET1_EXTENSION: Extension = {
res
};

/// The type for linear bits. Part of the TKET1 extension.
pub static ref LINEAR_BIT: Type = {
Type::new_extension(TKET1_EXTENSION
.get_type(&LINEAR_BIT_NAME)
.unwrap()
.instantiate([])
.unwrap())
};

/// Extension registry including the prelude, TKET1 and Tk2Ops extensions.
pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
TKET1_EXTENSION.clone(),
Expand Down
123 changes: 109 additions & 14 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,47 @@ mod op;

use hugr::types::Type;

use hugr::Node;
use itertools::Itertools;
// Required for serialising ops in the tket1 hugr extension.
pub(crate) use op::serialised::OpaqueTk1Op;

#[cfg(test)]
mod tests;

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::{fs, io};

use hugr::ops::{OpType, Value};
use hugr::ops::{NamedOp, OpType, Value};
use hugr::std_extensions::arithmetic::float_types::ConstF64;

use thiserror::Error;
use tket_json_rs::circuit_json::SerialCircuit;
use tket_json_rs::circuit_json::{self, SerialCircuit};
use tket_json_rs::optype::OpType as SerialOpType;

use crate::circuit::Circuit;

use self::decoder::JsonDecoder;
use self::encoder::JsonEncoder;
use self::decoder::Tk1Decoder;
use self::encoder::Tk1Encoder;

pub use crate::passes::pytket::lower_to_pytket;

/// Prefix used for storing metadata in the hugr nodes.
pub const METADATA_PREFIX: &str = "TKET1_JSON";
pub const METADATA_PREFIX: &str = "TKET1";
/// The global phase specified as metadata.
const METADATA_PHASE: &str = "TKET1_JSON.phase";
/// The implicit permutation of qubits.
const METADATA_IMPLICIT_PERM: &str = "TKET1_JSON.implicit_permutation";
const METADATA_PHASE: &str = "TKET1.phase";
/// Explicit names for the input qubit registers.
const METADATA_Q_REGISTERS: &str = "TKET1_JSON.qubit_registers";
const METADATA_Q_REGISTERS: &str = "TKET1.qubit_registers";
/// The reordered qubit registers in the output, if an implicit permutation was applied.
const METADATA_Q_OUTPUT_REGISTERS: &str = "TKET1.qubit_output_registers";
/// Explicit names for the input bit registers.
const METADATA_B_REGISTERS: &str = "TKET1_JSON.bit_registers";
const METADATA_B_REGISTERS: &str = "TKET1.bit_registers";
/// The reordered bit registers in the output, if an implicit permutation was applied.
const METADATA_B_OUTPUT_REGISTERS: &str = "TKET1.bit_output_registers";
/// A tket1 operation "opgroup" field.
const METADATA_OPGROUP: &str = "TKET1.opgroup";

/// A serialized representation of a [`Circuit`].
///
Expand All @@ -59,7 +67,7 @@ impl TKETDecode for SerialCircuit {
type EncodeError = TK1ConvertError;

fn decode(self) -> Result<Circuit, Self::DecodeError> {
let mut decoder = JsonDecoder::try_new(&self)?;
let mut decoder = Tk1Decoder::try_new(&self)?;

if !self.phase.is_empty() {
// TODO - add a phase gate
Expand All @@ -68,18 +76,18 @@ impl TKETDecode for SerialCircuit {
}

for com in self.commands {
decoder.add_command(com);
decoder.add_command(com)?;
}
Ok(decoder.finish().into())
}

fn encode(circ: &Circuit) -> Result<Self, Self::EncodeError> {
let mut encoder = JsonEncoder::new(circ)?;
let mut encoder = Tk1Encoder::new(circ)?;
for com in circ.commands() {
let optype = com.optype();
encoder.add_command(com.clone(), optype)?;
}
Ok(encoder.finish())
Ok(encoder.finish(circ))
}
}

Expand Down Expand Up @@ -156,13 +164,83 @@ pub enum OpConvertError {
/// The serialized operation is not supported.
#[error("Cannot serialize tket2 operation: {0:?}")]
UnsupportedOpSerialization(OpType),
/// The operation has non-serializable inputs.
#[error("Operation {} in {node} has an unsupported input of type {typ}.", optype.name())]
UnsupportedInputType {
/// The unsupported type.
typ: Type,
/// The operation name.
optype: OpType,
/// The node.
node: Node,
},
/// The operation has non-serializable outputs.
#[error("Operation {} in {node} has an unsupported output of type {typ}.", optype.name())]
UnsupportedOutputType {
/// The unsupported type.
typ: Type,
/// The operation name.
optype: OpType,
/// The node.
node: Node,
},
/// A parameter input could not be evaluated.
#[error("The {typ} parameter input for operation {} in {node} could not be resolved.", optype.name())]
UnresolvedParamInput {
/// The parameter type.
typ: Type,
/// The operation with the missing input param.
optype: OpType,
/// The node.
node: Node,
},
/// The operation has output-only qubits.
/// This is not currently supported by the encoder.
#[error("Operation {} in {node} has more output qubits than inputs.", optype.name())]
TooManyOutputQubits {
/// The unsupported type.
typ: Type,
/// The operation name.
optype: OpType,
/// The node.
node: Node,
},
/// The opaque tket1 operation had an invalid type parameter.
#[error("Opaque TKET1 operation had an invalid type parameter. {error}")]
InvalidOpaqueTypeParam {
/// The serialization error.
#[from]
error: serde_yaml::Error,
},
/// Tried to decode a tket1 operation with not enough parameters.
#[error(
"Operation {} is missing encoded parameters. Expected at least {expected} but only \"{}\" were specified.",
optype.name(),
params.iter().join(", "),
)]
MissingSerialisedParams {
/// The operation name.
optype: OpType,
/// The expected number of parameters.
expected: usize,
/// The given of parameters.
params: Vec<String>,
},
/// Tried to decode a tket1 operation with not enough qubit/bit arguments.
#[error(
"Operation {} is missing encoded arguments. Expected {expected_qubits} and {expected_bits}, but only \"{args:?}\" were specified.",
optype.name(),
)]
MissingSerialisedArguments {
/// The operation name.
optype: OpType,
/// The expected number of qubits.
expected_qubits: usize,
/// The expected number of bits.
expected_bits: usize,
/// The given of parameters.
args: Vec<circuit_json::Register>,
},
}

/// Error type for conversion between `Op` and `OpType`.
Expand Down Expand Up @@ -234,3 +312,20 @@ fn try_constant_to_param(val: &Value) -> Option<String> {
let half_turns = radians / std::f64::consts::PI;
Some(half_turns.to_string())
}

/// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map,
/// avoiding string and vector clones on lookup.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
struct RegisterHash {
hash: u64,
}

impl From<&circuit_json::Register> for RegisterHash {
fn from(reg: &circuit_json::Register) -> Self {
let mut hasher = DefaultHasher::new();
reg.hash(&mut hasher);
Self {
hash: hasher.finish(),
}
}
}
Loading

0 comments on commit a6e9e13

Please sign in to comment.