diff --git a/tket2-py/src/passes/chunks.rs b/tket2-py/src/passes/chunks.rs index ad973d7a..edb31ffd 100644 --- a/tket2-py/src/passes/chunks.rs +++ b/tket2-py/src/passes/chunks.rs @@ -55,7 +55,11 @@ impl PyCircuitChunks { fn update_circuit(&mut self, index: usize, new_circ: &Bound) -> PyResult<()> { try_with_hugr(new_circ, |hugr, _| { let circ: Circuit = hugr.into(); - if circ.circuit_signature() != self.chunks[index].circuit_signature() { + let circuit_sig = circ.circuit_signature(); + let chunk_sig = self.chunks[index].circuit_signature(); + if circuit_sig.input() != chunk_sig.input() + || circuit_sig.output() != chunk_sig.output() + { return Err(PyAttributeError::new_err( "The new circuit has a different signature.", )); diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index 3e048c40..12b47517 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -2,16 +2,14 @@ //! //! This includes a extension for the opaque TKET1 operations. -use super::serialize::pytket::JsonOp; +use crate::serialize::pytket::OpaqueTk1Op; use crate::Tk2Op; use hugr::extension::prelude::PRELUDE; use hugr::extension::simple_op::MakeOpDef; use hugr::extension::{CustomSignatureFunc, ExtensionId, ExtensionRegistry, SignatureError}; use hugr::hugr::IdentList; -use hugr::ops::custom::{CustomOp, OpaqueOp}; -use hugr::ops::NamedOp; use hugr::std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE}; -use hugr::types::type_param::{CustomTypeArg, TypeArg, TypeParam}; +use hugr::types::type_param::{TypeArg, TypeParam}; use hugr::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound}; use hugr::{type_row, Extension}; use lazy_static::lazy_static; @@ -27,15 +25,15 @@ pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1"); pub const LINEAR_BIT_NAME: SmolStr = SmolStr::new_inline("LBit"); /// The name for opaque TKET1 operations. -pub const JSON_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op"); +pub const TKET1_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op"); /// The ID of an opaque TKET1 operation metadata. -pub const JSON_PAYLOAD_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Payload"); +pub const TKET1_PAYLOAD_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Payload"); lazy_static! { /// A custom type for the encoded TKET1 operation -static ref TKET1_OP_PAYLOAD : CustomType = - TKET1_EXTENSION.get_type(&JSON_PAYLOAD_NAME).unwrap().instantiate([]).unwrap(); +pub static ref TKET1_OP_PAYLOAD : CustomType = + TKET1_EXTENSION.get_type(&TKET1_PAYLOAD_NAME).unwrap().instantiate([]).unwrap(); /// The TKET1 extension, containing the opaque TKET1 operations. pub static ref TKET1_EXTENSION: Extension = { @@ -43,12 +41,12 @@ pub static ref TKET1_EXTENSION: Extension = { res.add_type(LINEAR_BIT_NAME, vec![], "A linear bit.".into(), TypeBound::Any.into()).unwrap(); - let json_op_payload_def = res.add_type(JSON_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap(); - let json_op_payload = TypeParam::Opaque{ty:json_op_payload_def.instantiate([]).unwrap()}; + let tket1_op_payload_def = res.add_type(TKET1_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap(); + let tket1_op_payload = TypeParam::Opaque{ty:tket1_op_payload_def.instantiate([]).unwrap()}; res.add_op( - JSON_OP_NAME, + TKET1_OP_NAME, "An opaque TKET1 operation.".into(), - JsonOpSignature([json_op_payload]) + Tk1Signature([tket1_op_payload]) ).unwrap(); res @@ -72,52 +70,11 @@ pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ ]).unwrap(); -} -/// Create a new opaque operation -pub(crate) fn wrap_json_op(op: &JsonOp) -> CustomOp { - // TODO: This throws an error - //let op = serde_yaml::to_value(op).unwrap(); - //let payload = TypeArg::Opaque(CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap()); - //TKET1_EXTENSION - // .get_op(&JSON_OP_NAME) - // .unwrap() - // .instantiate_opaque([payload]) - // .unwrap() - // .into() - let sig = op.signature(); - let op = serde_yaml::to_value(op).unwrap(); - let payload = TypeArg::Opaque { - arg: CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap(), - }; - OpaqueOp::new( - TKET1_EXTENSION_ID, - JSON_OP_NAME, - "".into(), - vec![payload], - sig, - ) - .into() -} - -/// Extract a json-encoded TKET1 operation from an opaque operation, if -/// possible. -pub(crate) fn try_unwrap_json_op(ext: &CustomOp) -> Option { - // TODO: Check `extensions.contains(&TKET1_EXTENSION_ID)` - // (but the ext op extensions are an empty set?) - if ext.name() != format!("{TKET1_EXTENSION_ID}.{JSON_OP_NAME}") { - return None; - } - let Some(TypeArg::Opaque { arg }) = ext.args().first() else { - // TODO: Throw an error? We should never get here if the name matches. - return None; - }; - let op = serde_yaml::from_value(arg.value.clone()).ok()?; - Some(op) } -struct JsonOpSignature([TypeParam; 1]); +struct Tk1Signature([TypeParam; 1]); -impl CustomSignatureFunc for JsonOpSignature { +impl CustomSignatureFunc for Tk1Signature { fn compute_signature<'o, 'a: 'o>( &'a self, arg_values: &[TypeArg], @@ -128,7 +85,7 @@ impl CustomSignatureFunc for JsonOpSignature { // This should have already been checked. panic!("Wrong number of arguments"); }; - let op: JsonOp = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors! + let op: OpaqueTk1Op = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors! Ok(op.signature().into()) } diff --git a/tket2/src/ops.rs b/tket2/src/ops.rs index f941c517..42ea2ab9 100644 --- a/tket2/src/ops.rs +++ b/tket2/src/ops.rs @@ -90,9 +90,12 @@ pub enum Pauli { Z, } -#[derive(Debug, Error, PartialEq, Clone, Copy)] -#[error("Not a Tk2Op.")] -pub struct NotTk2Op; +#[derive(Debug, Error, PartialEq, Clone)] +#[error("{} is not a Tk2Op.", op.name())] +pub struct NotTk2Op { + /// The offending operation. + pub op: OpType, +} impl Pauli { /// Check if this pauli commutes with another. @@ -226,28 +229,20 @@ impl TryFrom<&OpType> for Tk2Op { type Error = NotTk2Op; fn try_from(op: &OpType) -> Result { - let OpType::CustomOp(custom_op) = op else { - return Err(NotTk2Op); - }; - - match custom_op { - CustomOp::Extension(ext) => Tk2Op::from_extension_op(ext), - CustomOp::Opaque(opaque) => { - if opaque.extension() != &EXTENSION_ID { - return Err(NotTk2Op); - } - try_from_name(opaque.name()) + { + let OpType::CustomOp(custom_op) = op else { + return Err(NotTk2Op { op: op.clone() }); + }; + + match custom_op { + CustomOp::Extension(ext) => Tk2Op::from_extension_op(ext).ok(), + CustomOp::Opaque(opaque) => match opaque.extension() == &EXTENSION_ID { + true => try_from_name(opaque.name()).ok(), + false => None, + }, } + .ok_or_else(|| NotTk2Op { op: op.clone() }) } - .map_err(|_| NotTk2Op) - } -} - -impl TryFrom for Tk2Op { - type Error = NotTk2Op; - - fn try_from(op: OpType) -> Result { - Self::try_from(&op) } } diff --git a/tket2/src/passes/commutation.rs b/tket2/src/passes/commutation.rs index 564a6b6a..970f0eb5 100644 --- a/tket2/src/passes/commutation.rs +++ b/tket2/src/passes/commutation.rs @@ -100,7 +100,7 @@ fn load_slices(circ: &Circuit) -> SliceVec { /// check if node is one we want to put in to a slice. fn is_slice_op(h: &impl HugrView, node: Node) -> bool { - let op: Result = h.get_optype(node).clone().try_into(); + let op: Result = h.get_optype(node).try_into(); op.is_ok() } @@ -156,22 +156,12 @@ fn commutes_at_slice( let port = command.port_of_qb(q, Direction::Incoming)?; - let op: Tk2Op = circ - .hugr() - .get_optype(command.node()) - .clone() - .try_into() - .ok()?; + let op: Tk2Op = circ.hugr().get_optype(command.node()).try_into().ok()?; // TODO: if not tk2op, might still have serialized commutation data we // can use. let pauli = commutation_on_port(&op.qubit_commutation(), port)?; - let other_op: Tk2Op = circ - .hugr() - .get_optype(other_com.node()) - .clone() - .try_into() - .ok()?; + let other_op: Tk2Op = circ.hugr().get_optype(other_com.node()).try_into().ok()?; let other_pauli = commutation_on_port( &other_op.qubit_commutation(), other_com.port_of_qb(q, Direction::Outgoing)?, diff --git a/tket2/src/passes/pytket.rs b/tket2/src/passes/pytket.rs index c37051ad..c4dd3ada 100644 --- a/tket2/src/passes/pytket.rs +++ b/tket2/src/passes/pytket.rs @@ -27,7 +27,8 @@ pub fn lower_to_pytket(circ: &Circuit) -> Result { } /// Errors that can occur during the lowering process. -#[derive(Clone, PartialEq, Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum PytketLoweringError { /// An error occurred during the conversion of an operation. #[error("operation conversion error: {0}")] diff --git a/tket2/src/serialize/pytket.rs b/tket2/src/serialize/pytket.rs index e032d67c..8d335780 100644 --- a/tket2/src/serialize/pytket.rs +++ b/tket2/src/serialize/pytket.rs @@ -5,22 +5,22 @@ mod encoder; mod op; use hugr::types::Type; -pub(crate) use op::JsonOp; + +// Required for serialising ops in the tket1 hugr extension. +pub(crate) use op::serialised::OpaqueTk1Op; #[cfg(test)] mod tests; -use hugr::CircuitUnit; - use std::path::Path; use std::{fs, io}; use hugr::ops::{OpType, Value}; -use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; +use hugr::std_extensions::arithmetic::float_types::ConstF64; use thiserror::Error; use tket_json_rs::circuit_json::SerialCircuit; -use tket_json_rs::optype::OpType as JsonOpType; +use tket_json_rs::optype::OpType as SerialOpType; use crate::circuit::Circuit; @@ -59,7 +59,7 @@ impl TKETDecode for SerialCircuit { type EncodeError = TK1ConvertError; fn decode(self) -> Result { - let mut decoder = JsonDecoder::new(&self); + let mut decoder = JsonDecoder::try_new(&self)?; if !self.phase.is_empty() { // TODO - add a phase gate @@ -74,17 +74,7 @@ impl TKETDecode for SerialCircuit { } fn encode(circ: &Circuit) -> Result { - let mut encoder = JsonEncoder::new(circ); - let f64_inputs = circ.units().filter_map(|(wire, _, t)| match (wire, t) { - (CircuitUnit::Wire(wire), t) if t == FLOAT64_TYPE => Some(Ok(wire)), - (CircuitUnit::Linear(_), _) => None, - (_, typ) => Some(Err(TK1ConvertError::NonSerializableInputs { typ })), - }); - for (i, wire) in f64_inputs.enumerate() { - let wire = wire?; - let param = format!("f{i}"); - encoder.add_parameter(wire, param); - } + let mut encoder = JsonEncoder::new(circ)?; for com in circ.commands() { let optype = com.optype(); encoder.add_command(com.clone(), optype)?; @@ -93,17 +83,6 @@ impl TKETDecode for SerialCircuit { } } -/// Error type for conversion between `Op` and `OpType`. -#[derive(Clone, PartialEq, Debug, Error)] -pub enum OpConvertError { - /// The serialized operation is not supported. - #[error("Unsupported serialized pytket operation: {0:?}")] - UnsupportedSerializedOp(JsonOpType), - /// The serialized operation is not supported. - #[error("Cannot serialize tket2 operation: {0:?}")] - UnsupportedOpSerialization(OpType), -} - /// Load a TKET1 circuit from a JSON file. pub fn load_tk1_json_file(path: impl AsRef) -> Result { let file = fs::File::open(path)?; @@ -169,9 +148,29 @@ pub fn save_tk1_json_str(circ: &Circuit) -> Result { /// Error type for conversion between `Op` and `OpType`. #[derive(Debug, Error)] +#[non_exhaustive] +pub enum OpConvertError { + /// The serialized operation is not supported. + #[error("Unsupported serialized pytket operation: {0:?}")] + UnsupportedSerializedOp(SerialOpType), + /// The serialized operation is not supported. + #[error("Cannot serialize tket2 operation: {0:?}")] + UnsupportedOpSerialization(OpType), + /// The opaque tket1 operation had an invalid type parameter. + #[error("Opaque TKET1 operation had an invalid type parameter. {error}")] + InvalidOpaqueTypeParam { + /// The serialization error. + #[from] + error: serde_yaml::Error, + }, +} + +/// Error type for conversion between `Op` and `OpType`. +#[derive(Debug, Error)] +#[non_exhaustive] pub enum TK1ConvertError { /// Operation conversion error. - #[error("{0}")] + #[error(transparent)] OpConversionError(#[from] OpConvertError), /// The circuit has non-serializable inputs. #[error("Circuit contains non-serializable input of type {typ}.")] @@ -179,6 +178,14 @@ pub enum TK1ConvertError { /// The unsupported type. typ: Type, }, + /// The circuit uses multi-indexed registers. + // + // This could be supported in the future, if there is a need for it. + #[error("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers.")] + MultiIndexedRegister { + /// The register name. + register: String, + }, /// Invalid JSON, #[error("Invalid pytket JSON. {0}")] InvalidJson(#[from] serde_json::Error), diff --git a/tket2/src/serialize/pytket/decoder.rs b/tket2/src/serialize/pytket/decoder.rs index 3b589774..04fc6242 100644 --- a/tket2/src/serialize/pytket/decoder.rs +++ b/tket2/src/serialize/pytket/decoder.rs @@ -8,18 +8,20 @@ use std::mem; use hugr::builder::{CircuitBuilder, Container, Dataflow, DataflowHugr, FunctionBuilder}; use hugr::extension::prelude::QB_T; +use hugr::ops::OpType; use hugr::types::FunctionType; use hugr::CircuitUnit; use hugr::{Hugr, Wire}; +use itertools::Itertools; use serde_json::json; use tket_json_rs::circuit_json; use tket_json_rs::circuit_json::SerialCircuit; -use super::op::JsonOp; -use super::{try_param_to_constant, METADATA_IMPLICIT_PERM, METADATA_PHASE}; +use super::op::Tk1Op; +use super::{try_param_to_constant, TK1ConvertError, METADATA_IMPLICIT_PERM, METADATA_PHASE}; use super::{METADATA_B_REGISTERS, METADATA_Q_REGISTERS}; -use crate::extension::{LINEAR_BIT, REGISTRY}; +use crate::extension::{LINEAR_BIT, REGISTRY, TKET1_EXTENSION_ID}; use crate::symbolic_constant_op; /// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`]. @@ -32,8 +34,8 @@ pub(super) struct JsonDecoder { /// The dangling wires of the builder. /// Used to generate [`CircuitBuilder`]s. dangling_wires: Vec, - /// A map from the json registers to flat wire indices. - register_wire: HashMap, + /// A map from the json registers to the units in the circuit being built. + register_units: HashMap, /// The number of qubits in the circuit. num_qubits: usize, /// The number of bits in the circuit. @@ -42,32 +44,17 @@ pub(super) struct JsonDecoder { impl JsonDecoder { /// Initialize a new [`JsonDecoder`], using the metadata from a [`SerialCircuit`]. - pub fn new(serialcirc: &SerialCircuit) -> Self { + pub fn try_new(serialcirc: &SerialCircuit) -> Result { let num_qubits = serialcirc.qubits.len(); let num_bits = serialcirc.bits.len(); - - // Map each (register name, index) pair to an offset in the signature. - let mut wire_map: HashMap = - HashMap::with_capacity(num_bits + num_qubits); - for (i, register) in serialcirc - .qubits - .iter() - .chain(serialcirc.bits.iter()) - .enumerate() - { - if register.1.len() != 1 { - // TODO: Support multi-index registers? - panic!("Register {} has more than one index", register.0); - } - wire_map.insert((register, 0).into(), i); - } let sig = FunctionType::new_endo( [vec![QB_T; num_qubits], vec![LINEAR_BIT.clone(); num_bits]].concat(), - ); - // .with_extension_delta(&ExtensionSet::singleton(&TKET1_EXTENSION_ID)); + ) + .with_extension_delta(TKET1_EXTENSION_ID); let name = serialcirc.name.clone().unwrap_or_default(); let mut dfg = FunctionBuilder::new(name, sig.into()).unwrap(); + let dangling_wires = dfg.input_wires().collect::>(); // Metadata. The circuit requires "name", and we store other things that // should pass through the serialization roundtrip. @@ -79,14 +66,25 @@ impl JsonDecoder { dfg.set_metadata(METADATA_Q_REGISTERS, json!(serialcirc.qubits)); dfg.set_metadata(METADATA_B_REGISTERS, json!(serialcirc.bits)); - let dangling_wires = dfg.input_wires().collect::>(); - JsonDecoder { + // Map each register element to their starting `CircuitUnit`. + let mut wire_map: HashMap = + HashMap::with_capacity(num_bits + num_qubits); + for (i, register) in serialcirc.qubits.iter().enumerate() { + check_register(register)?; + wire_map.insert(register.into(), CircuitUnit::Linear(i)); + } + for (i, register) in serialcirc.bits.iter().enumerate() { + check_register(register)?; + wire_map.insert(register.into(), CircuitUnit::Linear(i + num_qubits)); + } + + Ok(JsonDecoder { hugr: dfg, dangling_wires, - register_wire: wire_map, + register_units: wire_map, num_qubits, num_bits, - } + }) } /// Finish building the [`Hugr`]. @@ -97,41 +95,47 @@ impl JsonDecoder { .unwrap() } - /// Add a [`Command`] from the serial circuit to the [`JsonDecoder`]. - /// - /// - [`Command`]: circuit_json::Command + /// Add a tket1 [`circuit_json::Command`] from the serial circuit to the + /// decoder. pub fn add_command(&mut self, command: circuit_json::Command) { // TODO Store the command's `opgroup` in the metadata. let circuit_json::Command { op, args, .. } = command; let num_qubits = args .iter() - .take_while(|&arg| self.reg_wire(arg, 0) < self.num_qubits) + .take_while(|&arg| match self.reg_wire(arg) { + CircuitUnit::Linear(i) => i < self.num_qubits, + _ => false, + }) .count(); - let num_bits = args.len() - num_qubits; - let op = JsonOp::new_from_op(op, num_qubits, num_bits); + let num_input_bits = args.len() - num_qubits; + let op_params = op.params.clone(); + let tk1op = Tk1Op::from_serialised_op(op, num_qubits, num_input_bits); - let args: Vec<_> = args.into_iter().map(|reg| self.reg_wire(®, 0)).collect(); - - let param_wires: Vec = op - .param_inputs() - .map(|p| self.create_param_wire(p)) - .collect(); + let param_units = tk1op + .param_ports() + .enumerate() + .filter_map(|(i, _port)| op_params.as_ref()?.get(i).map(String::as_ref)) + .map(|p| CircuitUnit::Wire(self.create_param_wire(p))) + .collect_vec(); + let arg_units = args.into_iter().map(|reg| self.reg_wire(®)); - let append_wires = args - .into_iter() - .map(CircuitUnit::Linear) - .chain(param_wires.into_iter().map(CircuitUnit::Wire)); + let append_wires: Vec = arg_units.chain(param_units).collect_vec(); + let op: OpType = (&tk1op).into(); self.with_circ_builder(|circ| { - circ.append_and_consume(&op, append_wires).unwrap(); + circ.append_and_consume(op, append_wires).unwrap(); }); } /// Apply a function to the internal hugr builder viewed as a [`CircuitBuilder`]. - fn with_circ_builder(&mut self, f: impl FnOnce(&mut CircuitBuilder>)) { + fn with_circ_builder(&mut self, f: F) -> T + where + F: FnOnce(&mut CircuitBuilder>) -> T, + { let mut circ = self.hugr.as_circuit(mem::take(&mut self.dangling_wires)); - f(&mut circ); + let res = f(&mut circ); self.dangling_wires = circ.finish(); + res } /// Returns the wire carrying a parameter. @@ -151,11 +155,11 @@ impl JsonDecoder { } } - /// Return the wire index for the `elem`th value of a given register. + /// Return the wire unit for the `elem`th value of a given register. /// /// Relies on TKET1 constraint that all registers have unique names. - fn reg_wire(&self, register: &circuit_json::Register, elem: usize) -> usize { - self.register_wire[&(register, elem).into()] + fn reg_wire(&self, register: &circuit_json::Register) -> CircuitUnit { + self.register_units[®ister.into()] } } @@ -166,13 +170,24 @@ struct RegisterHash { hash: u64, } -impl From<(&circuit_json::Register, usize)> for RegisterHash { - fn from((reg, elem): (&circuit_json::Register, usize)) -> Self { +impl From<&circuit_json::Register> for RegisterHash { + fn from(reg: &circuit_json::Register) -> Self { let mut hasher = DefaultHasher::new(); reg.0.hash(&mut hasher); - reg.1[elem].hash(&mut hasher); + reg.1.hash(&mut hasher); Self { hash: hasher.finish(), } } } + +/// Only single-indexed registers are supported. +fn check_register(register: &circuit_json::Register) -> Result<(), TK1ConvertError> { + if register.1.len() != 1 { + Err(TK1ConvertError::MultiIndexedRegister { + register: register.0.clone(), + }) + } else { + Ok(()) + } +} diff --git a/tket2/src/serialize/pytket/encoder.rs b/tket2/src/serialize/pytket/encoder.rs index 5704fb04..e64768b8 100644 --- a/tket2/src/serialize/pytket/encoder.rs +++ b/tket2/src/serialize/pytket/encoder.rs @@ -5,6 +5,7 @@ use std::collections::HashMap; use hugr::extension::prelude::QB_T; use hugr::ops::{NamedOp, OpType}; +use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::{HugrView, Wire}; use itertools::{Either, Itertools}; use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit}; @@ -15,10 +16,10 @@ use crate::extension::LINEAR_BIT; use crate::ops::{match_symb_const_op, op_matches}; use crate::Tk2Op; -use super::op::JsonOp; +use super::op::Tk1Op; use super::{ - try_constant_to_param, OpConvertError, METADATA_B_REGISTERS, METADATA_IMPLICIT_PERM, - METADATA_PHASE, METADATA_Q_REGISTERS, + try_constant_to_param, OpConvertError, TK1ConvertError, METADATA_B_REGISTERS, + METADATA_IMPLICIT_PERM, METADATA_PHASE, METADATA_Q_REGISTERS, }; /// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`]. @@ -37,8 +38,14 @@ pub(super) struct JsonEncoder { /// The TKET1 bit registers associated to each linear bit unit of the circuit. bit_to_reg: HashMap, /// The ordered TKET1 names for the input qubit registers. + /// + /// Nb: Although `tket-json-rs` calls these "registers", they're actually + /// identifiers for single qubits in the `Register::0` register. qubit_registers: Vec, /// The ordered TKET1 names for the input bit registers. + /// + /// Nb: Although `tket-json-rs` calls these "registers", they're actually + /// identifiers for single bits in the `Register::0` register. bit_registers: Vec, /// A register of wires with constant values, used to recover TKET1 /// parameters. @@ -47,7 +54,7 @@ pub(super) struct JsonEncoder { impl JsonEncoder { /// Create a new [`JsonEncoder`] from a [`Circuit`]. - pub fn new(circ: &Circuit) -> Self { + pub fn new(circ: &Circuit) -> Result { let name = circ.name().map(str::to_string); let hugr = circ.hugr(); @@ -74,11 +81,11 @@ impl JsonEncoder { // Map the Hugr units to tket1 register names. // Uses the names from the metadata if available, or initializes new sequentially-numbered registers. - let mut bit_to_reg = HashMap::new(); let mut qubit_to_reg = HashMap::new(); - let get_register = |registers: &mut Vec, prefix: &str, index| { + let mut bit_to_reg = HashMap::new(); + let get_register = |registers: &mut Vec, name: &str, index| { registers.get(index).cloned().unwrap_or_else(|| { - let r = Register(prefix.to_string(), vec![index as i64]); + let r = Register(name.to_string(), vec![index as i64]); registers.push(r.clone()); r }) @@ -95,7 +102,7 @@ impl JsonEncoder { } } - Self { + let mut encoder = Self { name, phase, implicit_permutation, @@ -105,7 +112,11 @@ impl JsonEncoder { qubit_registers, bit_registers, parameters: HashMap::new(), - } + }; + + encoder.add_input_parameters(circ)?; + + Ok(encoder) } /// Add a circuit command to the serialization. @@ -136,8 +147,16 @@ impl JsonEncoder { // TODO Restore the opgroup (once the decoding supports it) let opgroup = None; - let op: JsonOp = optype.try_into()?; - let mut op: circuit_json::Operation = op.into_operation(); + + // Convert the command's operator to a pytket serialized one. This will + // return an error for operations that should have been caught by the + // `record_parameters` branch above (in addition to other unsupported + // ops). + let op: Tk1Op = Tk1Op::try_from_optype(optype.clone())?; + let mut op: circuit_json::Operation = op + .serialised_op() + .ok_or_else(|| OpConvertError::UnsupportedOpSerialization(optype.clone()))?; + if !params.is_empty() { op.params = Some( params @@ -155,6 +174,7 @@ impl JsonEncoder { Ok(()) } + /// Finish building and return the final [`SerialCircuit`]. pub fn finish(self) -> SerialCircuit { SerialCircuit { name: self.name, @@ -237,7 +257,28 @@ impl JsonEncoder { .cloned() } - pub(super) fn add_parameter(&mut self, wire: Wire, param: String) { + /// Associate a parameter expression with a wire. + fn add_parameter(&mut self, wire: Wire, param: String) { self.parameters.insert(wire, param); } + + /// Adds a parameter for each floating-point input to the circuit. + fn add_input_parameters( + &mut self, + circ: &Circuit, + ) -> Result<(), TK1ConvertError> { + let mut num_f64_inputs = 0; + for (wire, _, typ) in circ.units() { + match wire { + CircuitUnit::Linear(_) => {} + CircuitUnit::Wire(wire) if typ == FLOAT64_TYPE => { + let param = format!("f{num_f64_inputs}"); + num_f64_inputs += 1; + self.add_parameter(wire, param); + } + CircuitUnit::Wire(_) => return Err(TK1ConvertError::NonSerializableInputs { typ }), + } + } + Ok(()) + } } diff --git a/tket2/src/serialize/pytket/op.rs b/tket2/src/serialize/pytket/op.rs index 5dc898a1..85a04c54 100644 --- a/tket2/src/serialize/pytket/op.rs +++ b/tket2/src/serialize/pytket/op.rs @@ -1,279 +1,110 @@ -//! This module defines the internal `JsonOp` struct wrapping the logic for +//! This module defines the internal [`Tk1Op`] struct wrapping the logic for //! going between `tket_json_rs::optype::OpType` and `hugr::ops::OpType`. //! -//! The `JsonOp` tries to homogenize the +//! The `Tk1Op` tries to homogenize the //! `tket_json_rs::circuit_json::Operation`s coming from the encoded TKET1 //! circuits by ensuring they always define a signature, and computing the //! explicit count of qubits and linear bits. -use hugr::extension::prelude::QB_T; +mod native; +pub(crate) mod serialised; -use hugr::ops::custom::CustomOp; -use hugr::ops::{Noop, OpTrait, OpType}; -use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; -use hugr::types::FunctionType; - -use itertools::Itertools; +use hugr::ops::OpType; +use hugr::IncomingPort; use tket_json_rs::circuit_json; -use tket_json_rs::optype::OpType as JsonOpType; +use self::native::NativeOp; +use self::serialised::OpaqueTk1Op; use super::OpConvertError; -use crate::extension::{try_unwrap_json_op, LINEAR_BIT}; -use crate::Tk2Op; -/// A serialized operation, containing the operation type and all its attributes. +/// An operation originating from pytket, containing the operation type and all its attributes. /// /// Wrapper around [`tket_json_rs::circuit_json::Operation`] with cached number of qubits and bits. /// /// The `Operation` contained by this struct is guaranteed to have a signature. -#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -pub(crate) struct JsonOp { - op: circuit_json::Operation, - num_qubits: usize, - num_bits: usize, - /// Node input for each parameter in `op.params`. - /// - /// If the input is `None`, the parameter does not use a Hugr port and is - /// instead stored purely as metadata for the `Operation`. - param_inputs: Vec>, - /// The number of non-None inputs in `param_inputs`, corresponding to the - /// FLOAT64_TYPE inputs to the Hugr operation. - num_params: usize, +#[derive(Clone, Debug, PartialEq, derive_more::From)] +pub enum Tk1Op { + /// An operation with a native TKET2 counterpart. + Native(NativeOp), + /// An operation without a native TKET2 counterpart. + Opaque(OpaqueTk1Op), } -impl JsonOp { - /// Create a new `JsonOp` from a `circuit_json::Operation`, computing its - /// number of qubits from the signature +impl Tk1Op { + /// Create a new `Tk1Op` from a hugr optype. /// - /// Fails if the operation does not define a signature. See - /// [`JsonOp::new_from_op`] for a version that generates a signature if none - /// is defined. - #[allow(unused)] - #[allow(clippy::question_mark)] - pub fn new(op: circuit_json::Operation) -> Option { - let Some(sig) = &op.signature else { - return None; - }; - let input_counts = sig.iter().map(String::as_ref).counts(); - let num_qubits = input_counts.get("Q").copied().unwrap_or(0); - let num_bits = input_counts.get("B").copied().unwrap_or(0); - let mut op = Self { - op, - num_qubits, - num_bits, - param_inputs: Vec::new(), - num_params: 0, + /// Supports either native `Tk2Op`s or serialised tket1 `CustomOps`s. + /// + /// # Errors + /// + /// Returns an error if the operation is not supported by the TKET1 serialization. + pub fn try_from_optype(op: OpType) -> Result { + let res = (&op).try_into(); + let tk1_op = if let Ok(tk2op) = res { + NativeOp::try_from_tk2op(tk2op).map(Tk1Op::Native) + } else { + OpaqueTk1Op::try_from_tket2(&op)?.map(Tk1Op::Opaque) }; - op.compute_param_fields(); - Some(op) + + tk1_op.ok_or(OpConvertError::UnsupportedOpSerialization(op)) } - /// Create a new `JsonOp` from a `circuit_json::Operation`, with the number - /// of qubits and bits explicitly specified. + /// Create a new `Tk1Op` from a tket1 `circuit_json::Operation`. /// - /// If the operation does not define a signature, one is generated with the - /// given amounts. - pub fn new_from_op( - mut op: circuit_json::Operation, + /// If `serial_op` defines a signature then `num_qubits` and `num_qubits` are ignored. Otherwise, a signature is synthesised from those parameters. + pub fn from_serialised_op( + serial_op: circuit_json::Operation, num_qubits: usize, num_bits: usize, ) -> Self { - if op.signature.is_none() { - op.signature = - Some([vec!["Q".into(); num_qubits], vec!["B".into(); num_bits]].concat()); + if let Some(native) = NativeOp::try_from_serial_optype(serial_op.op_type.clone()) { + Tk1Op::Native(native) + } else { + Tk1Op::Opaque(OpaqueTk1Op::new_from_op(serial_op, num_qubits, num_bits)) } - let mut op = Self { - op, - num_qubits, - num_bits, - param_inputs: Vec::new(), - num_params: 0, - }; - op.compute_param_fields(); - op } - /// Create a new `JsonOp` from the optype and the number of parameters. - pub fn new_with_counts( - json_optype: JsonOpType, - num_qubits: usize, - num_bits: usize, - num_params: usize, - ) -> Self { - let mut params = None; - let mut param_inputs = vec![]; - if num_params > 0 { - let offset = num_qubits + num_bits; - params = Some(vec!["".into(); num_params]); - param_inputs = (offset..offset + num_params).map(Option::Some).collect(); - } - let op = circuit_json::Operation { - op_type: json_optype, - n_qb: Some(num_qubits as u32), - params, - op_box: None, - signature: Some([vec!["Q".into(); num_qubits], vec!["B".into(); num_bits]].concat()), - conditional: None, - }; - Self { - op, - num_qubits, - num_bits, - param_inputs, - num_params, + /// Get the hugr optype for the operation. + pub fn optype(&self) -> OpType { + match self { + Tk1Op::Native(native_op) => native_op.optype().clone(), + Tk1Op::Opaque(json_op) => json_op.as_custom_op().into(), } } - /// Compute the signature of the operation. - #[inline] - pub fn signature(&self) -> FunctionType { - let linear = [ - vec![QB_T; self.num_qubits], - vec![LINEAR_BIT.clone(); self.num_bits], - ] - .concat(); - let params = vec![FLOAT64_TYPE; self.num_params]; - FunctionType::new([linear.clone(), params].concat(), linear) - // .with_extension_delta(&ExtensionSet::singleton(&TKET1_EXTENSION_ID)) - } - - /// List of parameters in the operation that should be exposed as inputs. - #[inline] - pub fn param_inputs(&self) -> impl Iterator { - self.param_inputs - .iter() - .filter_map(|&i| self.op.params.as_ref()?.get(i?).map(String::as_ref)) - } - - pub fn into_operation(self) -> circuit_json::Operation { - self.op + /// Consumes the operation and returns a hugr optype. + pub fn into_optype(self) -> OpType { + match self { + Tk1Op::Native(native_op) => native_op.into_optype(), + Tk1Op::Opaque(json_op) => json_op.as_custom_op().into(), + } } - /// Wraps the op into a Hugr opaque operation - fn as_custom_op(&self) -> CustomOp { - crate::extension::wrap_json_op(self) + /// Get the [`tket_json_rs::circuit_json::Operation`] for the operation. + pub fn serialised_op(&self) -> Option { + match self { + Tk1Op::Native(native_op) => native_op.serialised_op(), + Tk1Op::Opaque(json_op) => Some(json_op.serialised_op().clone()), + } } - /// Compute the `parameter_input` and `num_params` fields by looking for - /// parameters in `op.params` that can be mapped to input wires in the Hugr. - /// - /// Updates the internal `num_params` and `param_inputs` fields. - fn compute_param_fields(&mut self) { - let Some(params) = self.op.params.as_ref() else { - self.param_inputs = vec![]; - self.num_params = 0; - return; - }; - - self.num_params = params.len(); - self.param_inputs = (0..params.len()).map(Some).collect(); + /// Returns the ports corresponding to parameters for this operation. + pub fn param_ports(&self) -> impl Iterator + '_ { + match self { + Tk1Op::Native(native_op) => itertools::Either::Left(native_op.param_ports()), + Tk1Op::Opaque(json_op) => itertools::Either::Right(json_op.param_ports()), + } } } -impl From<&JsonOp> for OpType { - /// Convert the operation into a HUGR operation. - /// - /// We only translate operations that have a 1:1 mapping between TKET and HUGR. - /// Any other operation is wrapped in an `OpaqueOp`. - fn from(json_op: &JsonOp) -> Self { - match json_op.op.op_type { - JsonOpType::H => Tk2Op::H.into(), - JsonOpType::CX => Tk2Op::CX.into(), - JsonOpType::T => Tk2Op::T.into(), - JsonOpType::Tdg => Tk2Op::Tdg.into(), - JsonOpType::X => Tk2Op::X.into(), - JsonOpType::Y => Tk2Op::Y.into(), - JsonOpType::Z => Tk2Op::Z.into(), - JsonOpType::Rz => Tk2Op::RzF64.into(), - JsonOpType::Rx => Tk2Op::RxF64.into(), - JsonOpType::TK1 => Tk2Op::TK1.into(), - JsonOpType::PhasedX => Tk2Op::PhasedX.into(), - JsonOpType::ZZMax => Tk2Op::ZZMax.into(), - JsonOpType::ZZPhase => Tk2Op::ZZPhase.into(), - JsonOpType::CZ => Tk2Op::CZ.into(), - JsonOpType::Reset => Tk2Op::Reset.into(), - JsonOpType::noop => { - // TODO: Replace with `Noop::new` once that is published. - let mut noop = Noop::default(); - noop.ty = QB_T; - noop.into() - } - _ => json_op.as_custom_op().into(), - } +impl From for OpType { + fn from(tk1_op: Tk1Op) -> Self { + tk1_op.into_optype() } } -impl TryFrom<&OpType> for JsonOp { - type Error = OpConvertError; - - fn try_from(op: &OpType) -> Result { - // We only translate operations that have a 1:1 mapping between TKET and TKET2 - // - // Other TKET1 operations are wrapped in an `OpaqueOp`. - // - // Non-supported Hugr operations throw an error. - let err = || OpConvertError::UnsupportedOpSerialization(op.clone()); - - let Ok(tk2op) = op.try_into() else { - if let OpType::CustomOp(custom_op) = op { - return try_unwrap_json_op(custom_op).ok_or_else(err); - } else { - return Err(err()); - } - }; - - let json_optype = match tk2op { - Tk2Op::H => JsonOpType::H, - Tk2Op::CX => JsonOpType::CX, - Tk2Op::T => JsonOpType::T, - Tk2Op::S => JsonOpType::S, - Tk2Op::X => JsonOpType::X, - Tk2Op::Y => JsonOpType::Y, - Tk2Op::Z => JsonOpType::Z, - Tk2Op::Tdg => JsonOpType::Tdg, - Tk2Op::Sdg => JsonOpType::Sdg, - Tk2Op::ZZMax => JsonOpType::ZZMax, - Tk2Op::Measure => { - unimplemented!( - "Cannot convert TKET2 Measure to TKET1 due to mismatching semantics." - ) - } - Tk2Op::RzF64 => JsonOpType::Rz, - Tk2Op::RxF64 => JsonOpType::Rx, - // TODO: Use a TK2 opaque op once we update the tket-json-rs dependency. - Tk2Op::AngleAdd => { - unimplemented!("Serialising AngleAdd not supported. Are all constants folded?") - } - Tk2Op::TK1 => JsonOpType::TK1, - Tk2Op::PhasedX => JsonOpType::PhasedX, - Tk2Op::ZZPhase => JsonOpType::ZZPhase, - Tk2Op::CZ => JsonOpType::CZ, - Tk2Op::Reset => JsonOpType::Reset, - Tk2Op::QAlloc | Tk2Op::QFree => { - unimplemented!("TKET1 does not support dynamic qubit allocation/discarding.") - } - }; - - let mut num_qubits = 0; - let mut num_bits = 0; - let mut num_params = 0; - if let Some(sig) = op.dataflow_signature() { - for ty in sig.input.iter() { - if ty == &QB_T { - num_qubits += 1 - } else if *ty == *LINEAR_BIT { - num_bits += 1 - } else if ty == &FLOAT64_TYPE { - num_params += 1 - } - } - } - - Ok(JsonOp::new_with_counts( - json_optype, - num_qubits, - num_bits, - num_params, - )) +impl From<&Tk1Op> for OpType { + fn from(tk1_op: &Tk1Op) -> Self { + tk1_op.optype() } } diff --git a/tket2/src/serialize/pytket/op/native.rs b/tket2/src/serialize/pytket/op/native.rs new file mode 100644 index 00000000..ad91a0f6 --- /dev/null +++ b/tket2/src/serialize/pytket/op/native.rs @@ -0,0 +1,195 @@ +//! Operations that have corresponding representations in both `pytket` and `tket2`. + +use hugr::extension::prelude::QB_T; + +use hugr::ops::{Noop, OpTrait, OpType}; +use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; +use hugr::types::FunctionType; + +use hugr::IncomingPort; +use tket_json_rs::circuit_json; +use tket_json_rs::optype::OpType as Tk1OpType; + +use crate::extension::LINEAR_BIT; +use crate::Tk2Op; + +/// An operation with a native TKET2 counterpart. +/// +/// Note that the signature of the native and serialised operations may differ. +#[derive(Clone, Debug, PartialEq)] +pub struct NativeOp { + /// The tket2 optype. + op: OpType, + /// The corresponding serialised optype. + /// + /// Some specific operations do not have a direct pytket counterpart, and must be handled + /// separately. + serial_op: Option, +} + +impl NativeOp { + /// Create a new `NativeOp` from a `circuit_json::Operation`. + pub fn try_from_tk2op(tk2op: Tk2Op) -> Option { + let serial_op = match tk2op { + Tk2Op::H => Tk1OpType::H, + Tk2Op::CX => Tk1OpType::CX, + Tk2Op::T => Tk1OpType::T, + Tk2Op::S => Tk1OpType::S, + Tk2Op::X => Tk1OpType::X, + Tk2Op::Y => Tk1OpType::Y, + Tk2Op::Z => Tk1OpType::Z, + Tk2Op::Tdg => Tk1OpType::Tdg, + Tk2Op::Sdg => Tk1OpType::Sdg, + Tk2Op::ZZMax => Tk1OpType::ZZMax, + Tk2Op::RzF64 => Tk1OpType::Rz, + Tk2Op::RxF64 => Tk1OpType::Rx, + Tk2Op::TK1 => Tk1OpType::TK1, + Tk2Op::PhasedX => Tk1OpType::PhasedX, + Tk2Op::ZZPhase => Tk1OpType::ZZPhase, + Tk2Op::CZ => Tk1OpType::CZ, + Tk2Op::Reset => Tk1OpType::Reset, + Tk2Op::AngleAdd => { + // These operations should be folded into constant before serialisation, + // or replaced by pytket logic expressions. + return Some(Self { + op: tk2op.into(), + serial_op: None, + }); + } + // TKET2 measurements and TKET1 measurements have different semantics. + Tk2Op::Measure => { + return None; + } + // These operations do not have a direct pytket counterpart. + Tk2Op::QAlloc | Tk2Op::QFree => { + return None; + } + }; + + Some(Self { + op: tk2op.into(), + serial_op: Some(serial_op), + }) + } + + /// Returns the translated tket2 optype for this operation, if it exists. + pub fn try_from_serial_optype(serial_op: Tk1OpType) -> Option { + let op = match serial_op { + Tk1OpType::H => Tk2Op::H.into(), + Tk1OpType::CX => Tk2Op::CX.into(), + Tk1OpType::T => Tk2Op::T.into(), + Tk1OpType::S => Tk2Op::S.into(), + Tk1OpType::X => Tk2Op::X.into(), + Tk1OpType::Y => Tk2Op::Y.into(), + Tk1OpType::Z => Tk2Op::Z.into(), + Tk1OpType::Tdg => Tk2Op::Tdg.into(), + Tk1OpType::Sdg => Tk2Op::Sdg.into(), + Tk1OpType::Rz => Tk2Op::RzF64.into(), + Tk1OpType::Rx => Tk2Op::RxF64.into(), + Tk1OpType::TK1 => Tk2Op::TK1.into(), + Tk1OpType::PhasedX => Tk2Op::PhasedX.into(), + Tk1OpType::ZZMax => Tk2Op::ZZMax.into(), + Tk1OpType::ZZPhase => Tk2Op::ZZPhase.into(), + Tk1OpType::CZ => Tk2Op::CZ.into(), + Tk1OpType::Reset => Tk2Op::Reset.into(), + Tk1OpType::noop => Noop::new(QB_T).into(), + _ => { + return None; + } + }; + Some(Self { + op, + serial_op: Some(serial_op), + }) + } + + /// Converts this `NativeOp` into a tket_json_rs operation. + pub fn serialised_op(&self) -> Option { + let serial_op = self.serial_op.clone()?; + + let mut num_qubits = 0; + let mut num_bits = 0; + let mut num_params = 0; + if let Some(sig) = self.signature() { + for ty in sig.input.iter() { + if ty == &QB_T { + num_qubits += 1 + } else if *ty == *LINEAR_BIT { + num_bits += 1 + } else if ty == &FLOAT64_TYPE { + num_params += 1 + } + } + } + + let params = (num_params > 0).then(|| vec!["".into(); num_params]); + + Some(circuit_json::Operation { + op_type: serial_op, + n_qb: Some(num_qubits as u32), + params, + op_box: None, + signature: Some([vec!["Q".into(); num_qubits], vec!["B".into(); num_bits]].concat()), + conditional: None, + }) + } + + /// Returns the dataflow signature for this operation. + pub fn signature(&self) -> Option { + self.op.dataflow_signature() + } + + /// Returns the tket2 optype for this operation. + pub fn optype(&self) -> &OpType { + &self.op + } + + /// Consumes the `NativeOp` and returns the underlying `OpType`. + pub fn into_optype(self) -> OpType { + self.op + } + + /// Returns the ports corresponding to parameters for this operation. + pub fn param_ports(&self) -> impl Iterator + '_ { + self.signature().into_iter().flat_map(|sig| { + let types = sig.input_types().to_owned(); + sig.input_ports() + .zip(types) + .filter(|(_, ty)| ty == &FLOAT64_TYPE) + .map(|(port, _)| port) + }) + } +} + +#[cfg(test)] +mod cfg { + use super::*; + use hugr::ops::NamedOp; + use rstest::rstest; + use strum::IntoEnumIterator; + + #[rstest] + fn tk2_optype_correspondence() { + for tk2op in Tk2Op::iter() { + let Some(native_op) = NativeOp::try_from_tk2op(tk2op) else { + // Ignore unsupported ops. + continue; + }; + + let Some(serial_op) = native_op.serial_op.clone() else { + // Ignore ops that do not have a serialised equivalent. + // (But are still handled by the encoder). + continue; + }; + + let Some(native_op2) = NativeOp::try_from_serial_optype(serial_op.clone()) else { + panic!( + "{} serialises into {serial_op:?}, but failed to be deserialised.", + tk2op.name() + ) + }; + + assert_eq!(native_op, native_op2); + } + } +} diff --git a/tket2/src/serialize/pytket/op/serialised.rs b/tket2/src/serialize/pytket/op/serialised.rs new file mode 100644 index 00000000..3f88ef66 --- /dev/null +++ b/tket2/src/serialize/pytket/op/serialised.rs @@ -0,0 +1,157 @@ +//! Wrapper over pytket operations that cannot be represented naturally in tket2. + +use hugr::extension::prelude::QB_T; + +use hugr::ops::custom::{CustomOp, ExtensionOp}; +use hugr::ops::{NamedOp, OpType}; +use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; +use hugr::types::type_param::CustomTypeArg; +use hugr::types::{FunctionType, TypeArg}; + +use hugr::IncomingPort; +use serde::de::Error; +use tket_json_rs::circuit_json; + +use crate::extension::{ + LINEAR_BIT, REGISTRY, TKET1_EXTENSION, TKET1_EXTENSION_ID, TKET1_OP_NAME, TKET1_OP_PAYLOAD, +}; +use crate::serialize::pytket::OpConvertError; + +/// A serialized operation, containing the operation type and all its attributes. +/// +/// This value is only used if the operation does not have a native TKET2 +/// counterpart that can be represented as a [`NativeOp`]. +/// +/// Wrapper around [`tket_json_rs::circuit_json::Operation`] with cached number of qubits and bits. +/// +/// The `Operation` contained by this struct is guaranteed to have a signature. +/// +/// [`NativeOp`]: super::native::NativeOp +#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct OpaqueTk1Op { + /// Internal operation data. + op: circuit_json::Operation, + /// Number of qubits declared by the operation. + num_qubits: usize, + /// Number of bits declared by the operation. + num_bits: usize, + /// Node input for each parameter in `op.params`. + /// + /// If the input is `None`, the parameter does not use a Hugr port and is + /// instead stored purely as metadata for the `Operation`. + param_inputs: Vec>, + /// The number of non-None inputs in `param_inputs`, corresponding to the + /// FLOAT64_TYPE inputs to the Hugr operation. + num_params: usize, +} + +impl OpaqueTk1Op { + /// Create a new `OpaqueTk1Op` from a `circuit_json::Operation`, with the number + /// of qubits and bits explicitly specified. + /// + /// If the operation does not define a signature, one is generated with the + /// given amounts. + pub fn new_from_op( + mut op: circuit_json::Operation, + num_qubits: usize, + num_bits: usize, + ) -> Self { + if op.signature.is_none() { + op.signature = + Some([vec!["Q".into(); num_qubits], vec!["B".into(); num_bits]].concat()); + } + let mut op = Self { + op, + num_qubits, + num_bits, + param_inputs: Vec::new(), + num_params: 0, + }; + op.compute_param_fields(); + op + } + + /// Try to convert a tket2 operation into a `OpaqueTk1Op`. + /// + /// Only succeeds if the operation is a [`CustomOp`] containing a tket1 operation + /// from the [`TKET1_EXTENSION_ID`] extension. Returns `None` if the operation + /// is not a tket1 operation. + /// + /// # Errors + /// + /// Returns an [`OpConvertError`] if the operation is a tket1 operation, but it + /// contains invalid data. + pub fn try_from_tket2(op: &OpType) -> Result, OpConvertError> { + let OpType::CustomOp(custom_op) = op else { + return Ok(None); + }; + + // TODO: Check `extensions.contains(&TKET1_EXTENSION_ID)` + // (but the ext op extensions are an empty set?) + if custom_op.name() != format!("{TKET1_EXTENSION_ID}.{TKET1_OP_NAME}") { + return Ok(None); + } + let Some(TypeArg::Opaque { arg }) = custom_op.args().first() else { + return Err(serde_yaml::Error::custom( + "Opaque TKET1 operation did not have a yaml-encoded type argument.", + ) + .into()); + }; + let op = serde_yaml::from_value(arg.value.clone())?; + Ok(Some(op)) + } + + /// Compute the signature of the operation. + /// + /// The signature returned has `num_qubits` qubit inputs, followed by + /// `num_bits` bit inputs, followed by `num_params` `f64` inputs. It has + /// `num_qubits` qubit outputs followed by `num_bits` bit outputs. + #[inline] + pub fn signature(&self) -> FunctionType { + let linear = [ + vec![QB_T; self.num_qubits], + vec![LINEAR_BIT.clone(); self.num_bits], + ] + .concat(); + let params = vec![FLOAT64_TYPE; self.num_params]; + FunctionType::new([linear.clone(), params].concat(), linear) + .with_extension_delta(TKET1_EXTENSION_ID) + } + + /// Returns the ports corresponding to parameters for this operation. + pub fn param_ports(&self) -> impl Iterator + '_ { + self.param_inputs.iter().filter_map(|&i| i) + } + + /// Returns the lower level `circuit_json::Operation` contained by this struct. + pub fn serialised_op(&self) -> &circuit_json::Operation { + &self.op + } + + /// Wraps the op into a [`TKET1_OP_NAME`] opaque operation. + pub fn as_custom_op(&self) -> CustomOp { + let op = serde_yaml::to_value(self).unwrap(); + let payload = TypeArg::Opaque { + arg: CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap(), + }; + let op_def = TKET1_EXTENSION.get_op(&TKET1_OP_NAME).unwrap(); + ExtensionOp::new(op_def.clone(), vec![payload], ®ISTRY) + .unwrap_or_else(|e| panic!("{e}")) + .into() + } + + /// Compute the `parameter_input` and `num_params` fields by looking for + /// parameters in `op.params` that can be mapped to input wires in the Hugr. + /// + /// Updates the internal `num_params` and `param_inputs` fields. + fn compute_param_fields(&mut self) { + let Some(params) = self.op.params.as_ref() else { + self.param_inputs = vec![]; + self.num_params = 0; + return; + }; + + self.num_params = params.len(); + self.param_inputs = (0..params.len()).map(|i| Some(i.into())).collect(); + } +} diff --git a/tket2/src/serialize/pytket/tests.rs b/tket2/src/serialize/pytket/tests.rs index 7c87eabd..0e347401 100644 --- a/tket2/src/serialize/pytket/tests.rs +++ b/tket2/src/serialize/pytket/tests.rs @@ -55,6 +55,32 @@ const PARAMETRIZED: &str = r#"{ "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] }"#; +fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) { + assert_eq!(a.name, b.name); + assert_eq!(a.phase, b.phase); + + let qubits_a: Vec<_> = a.qubits.iter().collect(); + let qubits_b: Vec<_> = b.qubits.iter().collect(); + assert_eq!(qubits_a, qubits_b); + + let bits_a: Vec<_> = a.bits.iter().collect(); + let bits_b: Vec<_> = b.bits.iter().collect(); + assert_eq!(bits_a, bits_b); + + assert_eq!(a.implicit_permutation, b.implicit_permutation); + + assert_eq!(a.commands.len(), b.commands.len()); + + // the below only works if both serial circuits share a topological ordering + // of commands. + for (a, b) in a.commands.iter().zip(b.commands.iter()) { + assert_eq!(a.op.op_type, b.op.op_type); + assert_eq!(a.args, b.args); + assert_eq!(a.op.params, b.op.params); + } + // TODO: Check commands equality (they only implement PartialEq) +} + #[rstest] #[case::simple(SIMPLE_JSON, 2, 2)] #[case::unknown_op(UNKNOWN_OP, 2, 3)] @@ -138,29 +164,3 @@ fn test_add_angle_serialise(#[case] circ_add_angles: Circuit, #[case] param_str: let reser = SerialCircuit::encode(&deser).unwrap(); compare_serial_circs(&ser, &reser); } - -fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) { - assert_eq!(a.name, b.name); - assert_eq!(a.phase, b.phase); - - let qubits_a: Vec<_> = a.qubits.iter().collect(); - let qubits_b: Vec<_> = b.qubits.iter().collect(); - assert_eq!(qubits_a, qubits_b); - - let bits_a: Vec<_> = a.bits.iter().collect(); - let bits_b: Vec<_> = b.bits.iter().collect(); - assert_eq!(bits_a, bits_b); - - assert_eq!(a.implicit_permutation, b.implicit_permutation); - - assert_eq!(a.commands.len(), b.commands.len()); - - // the below only works if both serial circuits share a topological ordering - // of commands. - for (a, b) in a.commands.iter().zip(b.commands.iter()) { - assert_eq!(a.op.op_type, b.op.op_type); - assert_eq!(a.args, b.args); - assert_eq!(a.op.params, b.op.params); - } - // TODO: Check commands equality (they only implement PartialEq) -}