Skip to content

Commit

Permalink
refactor: Cleanup tket1 serialized op structures (#419)
Browse files Browse the repository at this point in the history
This is a noisy internal refactor of `::serialize::pytket::op::JsonOp`,
extracted from the work towards #379.

`JsonOp` was a temporary structure used during the encoding/decoding of
pytket circuits that represented two different kinds of operation:
- pytket operations with a direct tket2 counterpart
- other operations that have to be encoded as OpaqueOps

This mixed up the two definitions, and made applying custom logic to one
of the variants more annoying.
(E.g. the special handling of bit input/outputs for tket2 ops needed for
#379).

This PR splits the structs into a `Native` and an `Opaque` variant, so
we can keep the implementation clean. The code is functionally the same.

---------

Co-authored-by: doug-q <[email protected]>
  • Loading branch information
aborgna-q and doug-q authored Jun 20, 2024
1 parent 8c5a487 commit a46e63f
Show file tree
Hide file tree
Showing 12 changed files with 643 additions and 450 deletions.
6 changes: 5 additions & 1 deletion tket2-py/src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ impl PyCircuitChunks {
fn update_circuit(&mut self, index: usize, new_circ: &Bound<PyAny>) -> PyResult<()> {
try_with_hugr(new_circ, |hugr, _| {
let circ: Circuit = hugr.into();
if circ.circuit_signature() != self.chunks[index].circuit_signature() {
let circuit_sig = circ.circuit_signature();
let chunk_sig = self.chunks[index].circuit_signature();
if circuit_sig.input() != chunk_sig.input()
|| circuit_sig.output() != chunk_sig.output()
{
return Err(PyAttributeError::new_err(
"The new circuit has a different signature.",
));
Expand Down
69 changes: 13 additions & 56 deletions tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
//!
//! This includes a extension for the opaque TKET1 operations.

use super::serialize::pytket::JsonOp;
use crate::serialize::pytket::OpaqueTk1Op;
use crate::Tk2Op;
use hugr::extension::prelude::PRELUDE;
use hugr::extension::simple_op::MakeOpDef;
use hugr::extension::{CustomSignatureFunc, ExtensionId, ExtensionRegistry, SignatureError};
use hugr::hugr::IdentList;
use hugr::ops::custom::{CustomOp, OpaqueOp};
use hugr::ops::NamedOp;
use hugr::std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE};
use hugr::types::type_param::{CustomTypeArg, TypeArg, TypeParam};
use hugr::types::type_param::{TypeArg, TypeParam};
use hugr::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound};
use hugr::{type_row, Extension};
use lazy_static::lazy_static;
Expand All @@ -27,28 +25,28 @@ pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1");
pub const LINEAR_BIT_NAME: SmolStr = SmolStr::new_inline("LBit");

/// The name for opaque TKET1 operations.
pub const JSON_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op");
pub const TKET1_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op");

/// The ID of an opaque TKET1 operation metadata.
pub const JSON_PAYLOAD_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Payload");
pub const TKET1_PAYLOAD_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Payload");

lazy_static! {
/// A custom type for the encoded TKET1 operation
static ref TKET1_OP_PAYLOAD : CustomType =
TKET1_EXTENSION.get_type(&JSON_PAYLOAD_NAME).unwrap().instantiate([]).unwrap();
pub static ref TKET1_OP_PAYLOAD : CustomType =
TKET1_EXTENSION.get_type(&TKET1_PAYLOAD_NAME).unwrap().instantiate([]).unwrap();

/// The TKET1 extension, containing the opaque TKET1 operations.
pub static ref TKET1_EXTENSION: Extension = {
let mut res = Extension::new(TKET1_EXTENSION_ID);

res.add_type(LINEAR_BIT_NAME, vec![], "A linear bit.".into(), TypeBound::Any.into()).unwrap();

let json_op_payload_def = res.add_type(JSON_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap();
let json_op_payload = TypeParam::Opaque{ty:json_op_payload_def.instantiate([]).unwrap()};
let tket1_op_payload_def = res.add_type(TKET1_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap();
let tket1_op_payload = TypeParam::Opaque{ty:tket1_op_payload_def.instantiate([]).unwrap()};
res.add_op(
JSON_OP_NAME,
TKET1_OP_NAME,
"An opaque TKET1 operation.".into(),
JsonOpSignature([json_op_payload])
Tk1Signature([tket1_op_payload])
).unwrap();

res
Expand All @@ -72,52 +70,11 @@ pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
]).unwrap();


}
/// Create a new opaque operation
pub(crate) fn wrap_json_op(op: &JsonOp) -> CustomOp {
// TODO: This throws an error
//let op = serde_yaml::to_value(op).unwrap();
//let payload = TypeArg::Opaque(CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap());
//TKET1_EXTENSION
// .get_op(&JSON_OP_NAME)
// .unwrap()
// .instantiate_opaque([payload])
// .unwrap()
// .into()
let sig = op.signature();
let op = serde_yaml::to_value(op).unwrap();
let payload = TypeArg::Opaque {
arg: CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap(),
};
OpaqueOp::new(
TKET1_EXTENSION_ID,
JSON_OP_NAME,
"".into(),
vec![payload],
sig,
)
.into()
}

/// Extract a json-encoded TKET1 operation from an opaque operation, if
/// possible.
pub(crate) fn try_unwrap_json_op(ext: &CustomOp) -> Option<JsonOp> {
// TODO: Check `extensions.contains(&TKET1_EXTENSION_ID)`
// (but the ext op extensions are an empty set?)
if ext.name() != format!("{TKET1_EXTENSION_ID}.{JSON_OP_NAME}") {
return None;
}
let Some(TypeArg::Opaque { arg }) = ext.args().first() else {
// TODO: Throw an error? We should never get here if the name matches.
return None;
};
let op = serde_yaml::from_value(arg.value.clone()).ok()?;
Some(op)
}

struct JsonOpSignature([TypeParam; 1]);
struct Tk1Signature([TypeParam; 1]);

impl CustomSignatureFunc for JsonOpSignature {
impl CustomSignatureFunc for Tk1Signature {
fn compute_signature<'o, 'a: 'o>(
&'a self,
arg_values: &[TypeArg],
Expand All @@ -128,7 +85,7 @@ impl CustomSignatureFunc for JsonOpSignature {
// This should have already been checked.
panic!("Wrong number of arguments");
};
let op: JsonOp = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors!
let op: OpaqueTk1Op = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors!
Ok(op.signature().into())
}

Expand Down
41 changes: 18 additions & 23 deletions tket2/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,12 @@ pub enum Pauli {
Z,
}

#[derive(Debug, Error, PartialEq, Clone, Copy)]
#[error("Not a Tk2Op.")]
pub struct NotTk2Op;
#[derive(Debug, Error, PartialEq, Clone)]
#[error("{} is not a Tk2Op.", op.name())]
pub struct NotTk2Op {
/// The offending operation.
pub op: OpType,
}

impl Pauli {
/// Check if this pauli commutes with another.
Expand Down Expand Up @@ -226,28 +229,20 @@ impl TryFrom<&OpType> for Tk2Op {
type Error = NotTk2Op;

fn try_from(op: &OpType) -> Result<Self, Self::Error> {
let OpType::CustomOp(custom_op) = op else {
return Err(NotTk2Op);
};

match custom_op {
CustomOp::Extension(ext) => Tk2Op::from_extension_op(ext),
CustomOp::Opaque(opaque) => {
if opaque.extension() != &EXTENSION_ID {
return Err(NotTk2Op);
}
try_from_name(opaque.name())
{
let OpType::CustomOp(custom_op) = op else {
return Err(NotTk2Op { op: op.clone() });
};

match custom_op {
CustomOp::Extension(ext) => Tk2Op::from_extension_op(ext).ok(),
CustomOp::Opaque(opaque) => match opaque.extension() == &EXTENSION_ID {
true => try_from_name(opaque.name()).ok(),
false => None,
},
}
.ok_or_else(|| NotTk2Op { op: op.clone() })
}
.map_err(|_| NotTk2Op)
}
}

impl TryFrom<OpType> for Tk2Op {
type Error = NotTk2Op;

fn try_from(op: OpType) -> Result<Self, Self::Error> {
Self::try_from(&op)
}
}

Expand Down
16 changes: 3 additions & 13 deletions tket2/src/passes/commutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ fn load_slices(circ: &Circuit<impl HugrView>) -> SliceVec {

/// check if node is one we want to put in to a slice.
fn is_slice_op(h: &impl HugrView, node: Node) -> bool {
let op: Result<Tk2Op, _> = h.get_optype(node).clone().try_into();
let op: Result<Tk2Op, _> = h.get_optype(node).try_into();
op.is_ok()
}

Expand Down Expand Up @@ -156,22 +156,12 @@ fn commutes_at_slice(

let port = command.port_of_qb(q, Direction::Incoming)?;

let op: Tk2Op = circ
.hugr()
.get_optype(command.node())
.clone()
.try_into()
.ok()?;
let op: Tk2Op = circ.hugr().get_optype(command.node()).try_into().ok()?;
// TODO: if not tk2op, might still have serialized commutation data we
// can use.
let pauli = commutation_on_port(&op.qubit_commutation(), port)?;

let other_op: Tk2Op = circ
.hugr()
.get_optype(other_com.node())
.clone()
.try_into()
.ok()?;
let other_op: Tk2Op = circ.hugr().get_optype(other_com.node()).try_into().ok()?;
let other_pauli = commutation_on_port(
&other_op.qubit_commutation(),
other_com.port_of_qb(q, Direction::Outgoing)?,
Expand Down
3 changes: 2 additions & 1 deletion tket2/src/passes/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ pub fn lower_to_pytket(circ: &Circuit) -> Result<Circuit, PytketLoweringError> {
}

/// Errors that can occur during the lowering process.
#[derive(Clone, PartialEq, Debug, thiserror::Error)]
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PytketLoweringError {
/// An error occurred during the conversion of an operation.
#[error("operation conversion error: {0}")]
Expand Down
65 changes: 36 additions & 29 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@ mod encoder;
mod op;

use hugr::types::Type;
pub(crate) use op::JsonOp;

// Required for serialising ops in the tket1 hugr extension.
pub(crate) use op::serialised::OpaqueTk1Op;

#[cfg(test)]
mod tests;

use hugr::CircuitUnit;

use std::path::Path;
use std::{fs, io};

use hugr::ops::{OpType, Value};
use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use hugr::std_extensions::arithmetic::float_types::ConstF64;

use thiserror::Error;
use tket_json_rs::circuit_json::SerialCircuit;
use tket_json_rs::optype::OpType as JsonOpType;
use tket_json_rs::optype::OpType as SerialOpType;

use crate::circuit::Circuit;

Expand Down Expand Up @@ -59,7 +59,7 @@ impl TKETDecode for SerialCircuit {
type EncodeError = TK1ConvertError;

fn decode(self) -> Result<Circuit, Self::DecodeError> {
let mut decoder = JsonDecoder::new(&self);
let mut decoder = JsonDecoder::try_new(&self)?;

if !self.phase.is_empty() {
// TODO - add a phase gate
Expand All @@ -74,17 +74,7 @@ impl TKETDecode for SerialCircuit {
}

fn encode(circ: &Circuit) -> Result<Self, Self::EncodeError> {
let mut encoder = JsonEncoder::new(circ);
let f64_inputs = circ.units().filter_map(|(wire, _, t)| match (wire, t) {
(CircuitUnit::Wire(wire), t) if t == FLOAT64_TYPE => Some(Ok(wire)),
(CircuitUnit::Linear(_), _) => None,
(_, typ) => Some(Err(TK1ConvertError::NonSerializableInputs { typ })),
});
for (i, wire) in f64_inputs.enumerate() {
let wire = wire?;
let param = format!("f{i}");
encoder.add_parameter(wire, param);
}
let mut encoder = JsonEncoder::new(circ)?;
for com in circ.commands() {
let optype = com.optype();
encoder.add_command(com.clone(), optype)?;
Expand All @@ -93,17 +83,6 @@ impl TKETDecode for SerialCircuit {
}
}

/// Error type for conversion between `Op` and `OpType`.
#[derive(Clone, PartialEq, Debug, Error)]
pub enum OpConvertError {
/// The serialized operation is not supported.
#[error("Unsupported serialized pytket operation: {0:?}")]
UnsupportedSerializedOp(JsonOpType),
/// The serialized operation is not supported.
#[error("Cannot serialize tket2 operation: {0:?}")]
UnsupportedOpSerialization(OpType),
}

/// Load a TKET1 circuit from a JSON file.
pub fn load_tk1_json_file(path: impl AsRef<Path>) -> Result<Circuit, TK1ConvertError> {
let file = fs::File::open(path)?;
Expand Down Expand Up @@ -169,16 +148,44 @@ pub fn save_tk1_json_str(circ: &Circuit) -> Result<String, TK1ConvertError> {

/// Error type for conversion between `Op` and `OpType`.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum OpConvertError {
/// The serialized operation is not supported.
#[error("Unsupported serialized pytket operation: {0:?}")]
UnsupportedSerializedOp(SerialOpType),
/// The serialized operation is not supported.
#[error("Cannot serialize tket2 operation: {0:?}")]
UnsupportedOpSerialization(OpType),
/// The opaque tket1 operation had an invalid type parameter.
#[error("Opaque TKET1 operation had an invalid type parameter. {error}")]
InvalidOpaqueTypeParam {
/// The serialization error.
#[from]
error: serde_yaml::Error,
},
}

/// Error type for conversion between `Op` and `OpType`.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TK1ConvertError {
/// Operation conversion error.
#[error("{0}")]
#[error(transparent)]
OpConversionError(#[from] OpConvertError),
/// The circuit has non-serializable inputs.
#[error("Circuit contains non-serializable input of type {typ}.")]
NonSerializableInputs {
/// The unsupported type.
typ: Type,
},
/// The circuit uses multi-indexed registers.
//
// This could be supported in the future, if there is a need for it.
#[error("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers.")]
MultiIndexedRegister {
/// The register name.
register: String,
},
/// Invalid JSON,
#[error("Invalid pytket JSON. {0}")]
InvalidJson(#[from] serde_json::Error),
Expand Down
Loading

0 comments on commit a46e63f

Please sign in to comment.