Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Cleanup tket1 serialized op structures #419

Merged
merged 9 commits into from
Jun 20, 2024
6 changes: 5 additions & 1 deletion tket2-py/src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ impl PyCircuitChunks {
fn update_circuit(&mut self, index: usize, new_circ: &Bound<PyAny>) -> PyResult<()> {
try_with_hugr(new_circ, |hugr, _| {
let circ: Circuit = hugr.into();
if circ.circuit_signature() != self.chunks[index].circuit_signature() {
let circuit_sig = circ.circuit_signature();
let chunk_sig = self.chunks[index].circuit_signature();
if circuit_sig.input() != chunk_sig.input()
|| circuit_sig.output() != chunk_sig.output()
{
return Err(PyAttributeError::new_err(
"The new circuit has a different signature.",
));
Expand Down
65 changes: 11 additions & 54 deletions tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
//!
//! This includes a extension for the opaque TKET1 operations.

use super::serialize::pytket::JsonOp;
use crate::serialize::pytket::OpaqueTk1Op;
use crate::Tk2Op;
use hugr::extension::prelude::PRELUDE;
use hugr::extension::simple_op::MakeOpDef;
use hugr::extension::{CustomSignatureFunc, ExtensionId, ExtensionRegistry, SignatureError};
use hugr::hugr::IdentList;
use hugr::ops::custom::{CustomOp, OpaqueOp};
use hugr::ops::NamedOp;
use hugr::std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE};
use hugr::types::type_param::{CustomTypeArg, TypeArg, TypeParam};
use hugr::types::type_param::{TypeArg, TypeParam};
use hugr::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound};
use hugr::{type_row, Extension};
use lazy_static::lazy_static;
Expand All @@ -27,28 +25,28 @@ pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1");
pub const LINEAR_BIT_NAME: SmolStr = SmolStr::new_inline("LBit");

/// The name for opaque TKET1 operations.
pub const JSON_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op");
pub const TKET1_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op");

/// The ID of an opaque TKET1 operation metadata.
pub const JSON_PAYLOAD_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Payload");
pub const TKET1_PAYLOAD_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Payload");

lazy_static! {
/// A custom type for the encoded TKET1 operation
static ref TKET1_OP_PAYLOAD : CustomType =
TKET1_EXTENSION.get_type(&JSON_PAYLOAD_NAME).unwrap().instantiate([]).unwrap();
pub static ref TKET1_OP_PAYLOAD : CustomType =
TKET1_EXTENSION.get_type(&TKET1_PAYLOAD_NAME).unwrap().instantiate([]).unwrap();

/// The TKET1 extension, containing the opaque TKET1 operations.
pub static ref TKET1_EXTENSION: Extension = {
let mut res = Extension::new(TKET1_EXTENSION_ID);

res.add_type(LINEAR_BIT_NAME, vec![], "A linear bit.".into(), TypeBound::Any.into()).unwrap();

let json_op_payload_def = res.add_type(JSON_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap();
let json_op_payload = TypeParam::Opaque{ty:json_op_payload_def.instantiate([]).unwrap()};
let tket1_op_payload_def = res.add_type(TKET1_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap();
let tket1_op_payload = TypeParam::Opaque{ty:tket1_op_payload_def.instantiate([]).unwrap()};
res.add_op(
JSON_OP_NAME,
TKET1_OP_NAME,
"An opaque TKET1 operation.".into(),
JsonOpSignature([json_op_payload])
JsonOpSignature([tket1_op_payload])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we rename JsonOpSignature?

).unwrap();

res
Expand All @@ -72,47 +70,6 @@ pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
]).unwrap();


}
/// Create a new opaque operation
pub(crate) fn wrap_json_op(op: &JsonOp) -> CustomOp {
// TODO: This throws an error
//let op = serde_yaml::to_value(op).unwrap();
//let payload = TypeArg::Opaque(CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap());
//TKET1_EXTENSION
// .get_op(&JSON_OP_NAME)
// .unwrap()
// .instantiate_opaque([payload])
// .unwrap()
// .into()
let sig = op.signature();
let op = serde_yaml::to_value(op).unwrap();
let payload = TypeArg::Opaque {
arg: CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap(),
};
OpaqueOp::new(
TKET1_EXTENSION_ID,
JSON_OP_NAME,
"".into(),
vec![payload],
sig,
)
.into()
}

/// Extract a json-encoded TKET1 operation from an opaque operation, if
/// possible.
pub(crate) fn try_unwrap_json_op(ext: &CustomOp) -> Option<JsonOp> {
// TODO: Check `extensions.contains(&TKET1_EXTENSION_ID)`
// (but the ext op extensions are an empty set?)
if ext.name() != format!("{TKET1_EXTENSION_ID}.{JSON_OP_NAME}") {
return None;
}
let Some(TypeArg::Opaque { arg }) = ext.args().first() else {
// TODO: Throw an error? We should never get here if the name matches.
return None;
};
let op = serde_yaml::from_value(arg.value.clone()).ok()?;
Some(op)
}

struct JsonOpSignature([TypeParam; 1]);
Expand All @@ -128,7 +85,7 @@ impl CustomSignatureFunc for JsonOpSignature {
// This should have already been checked.
panic!("Wrong number of arguments");
};
let op: JsonOp = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors!
let op: OpaqueTk1Op = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors!
Ok(op.signature().into())
}

Expand Down
44 changes: 26 additions & 18 deletions tket2/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,12 @@ pub enum Pauli {
Z,
}

#[derive(Debug, Error, PartialEq, Clone, Copy)]
#[error("Not a Tk2Op.")]
pub struct NotTk2Op;
#[derive(Debug, Error, PartialEq, Clone)]
#[error("{} is not a Tk2Op.", op.name())]
pub struct NotTk2Op {
/// The offending operation.
pub op: OpType,
}

impl Pauli {
/// Check if this pauli commutes with another.
Expand Down Expand Up @@ -226,29 +229,34 @@ impl TryFrom<&OpType> for Tk2Op {
type Error = NotTk2Op;

fn try_from(op: &OpType) -> Result<Self, Self::Error> {
let OpType::CustomOp(custom_op) = op else {
return Err(NotTk2Op);
};

match custom_op {
CustomOp::Extension(ext) => Tk2Op::from_extension_op(ext),
CustomOp::Opaque(opaque) => {
if opaque.extension() != &EXTENSION_ID {
return Err(NotTk2Op);
}
try_from_name(opaque.name())
}
}
.map_err(|_| NotTk2Op)
optype_to_tk2op(op).ok_or_else(|| NotTk2Op { op: op.clone() })
}
}

impl TryFrom<OpType> for Tk2Op {
type Error = NotTk2Op;

fn try_from(op: OpType) -> Result<Self, Self::Error> {
Self::try_from(&op)
optype_to_tk2op(&op).ok_or_else(|| NotTk2Op { op })
}
}

// Internal implementation for `TryFrom<Optype> for Tk2Op` that doesn't copy the `OpType` when it errors.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Internal implementation for `TryFrom<Optype> for Tk2Op` that doesn't copy the `OpType` when it errors.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the original function that returned an Option was meant as an optimisation to avoid cloning the optype on the error path of TryFrom<OpType>. I guess it's too much of a micro optimisation, so I'll just leave the TryFrom<&OpType impl as we had before.

fn optype_to_tk2op(op: &OpType) -> Option<Tk2Op> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both callsites create the same error anyway, I would prefer to do it once here. Not a big deal though.

let OpType::CustomOp(custom_op) = op else {
return None;
};

match custom_op {
CustomOp::Extension(ext) => Tk2Op::from_extension_op(ext),
CustomOp::Opaque(opaque) => {
if opaque.extension() != &EXTENSION_ID {
return None;
}
try_from_name(opaque.name())
}
}
.ok()
}

#[cfg(test)]
Expand Down
32 changes: 15 additions & 17 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ mod encoder;
mod op;

use hugr::types::Type;
pub(crate) use op::JsonOp;

// Required for serialising ops in the tket1 hugr extension.
pub(crate) use op::serialised::OpaqueTk1Op;

#[cfg(test)]
mod tests;

use hugr::CircuitUnit;

use std::path::Path;
use std::{fs, io};

use hugr::ops::{OpType, Value};
use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use hugr::std_extensions::arithmetic::float_types::ConstF64;

use thiserror::Error;
use tket_json_rs::circuit_json::SerialCircuit;
Expand Down Expand Up @@ -59,7 +59,7 @@ impl TKETDecode for SerialCircuit {
type EncodeError = TK1ConvertError;

fn decode(self) -> Result<Circuit, Self::DecodeError> {
let mut decoder = JsonDecoder::new(&self);
let mut decoder = JsonDecoder::try_new(&self)?;

if !self.phase.is_empty() {
// TODO - add a phase gate
Expand All @@ -74,17 +74,7 @@ impl TKETDecode for SerialCircuit {
}

fn encode(circ: &Circuit) -> Result<Self, Self::EncodeError> {
let mut encoder = JsonEncoder::new(circ);
let f64_inputs = circ.units().filter_map(|(wire, _, t)| match (wire, t) {
(CircuitUnit::Wire(wire), t) if t == FLOAT64_TYPE => Some(Ok(wire)),
(CircuitUnit::Linear(_), _) => None,
(_, typ) => Some(Err(TK1ConvertError::NonSerializableInputs { typ })),
});
for (i, wire) in f64_inputs.enumerate() {
let wire = wire?;
let param = format!("f{i}");
encoder.add_parameter(wire, param);
}
let mut encoder = JsonEncoder::new(circ)?;
for com in circ.commands() {
let optype = com.optype();
encoder.add_command(com.clone(), optype)?;
Expand Down Expand Up @@ -171,14 +161,22 @@ pub fn save_tk1_json_str(circ: &Circuit) -> Result<String, TK1ConvertError> {
#[derive(Debug, Error)]
pub enum TK1ConvertError {
/// Operation conversion error.
#[error("{0}")]
#[error(transparent)]
OpConversionError(#[from] OpConvertError),
/// The circuit has non-serializable inputs.
#[error("Circuit contains non-serializable input of type {typ}.")]
NonSerializableInputs {
/// The unsupported type.
typ: Type,
},
/// The circuit uses multi-indexed registers.
//
// This could be supported in the future, if there is a need for it.
#[error("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers.")]
MultiIndexedRegister {
/// The register name.
register: String,
},
/// Invalid JSON,
#[error("Invalid pytket JSON. {0}")]
InvalidJson(#[from] serde_json::Error),
Expand Down
Loading