From 9ef975e149ce9104ffeb1c5a1b958a7c624fd72f Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 7 Dec 2023 17:41:11 +0000 Subject: [PATCH] chore: Update devenv + bump hugr to latest (#263) --- Cargo.toml | 2 +- devenv.lock | 36 ++++----- devenv.nix | 2 +- tket2/src/circuit.rs | 10 ++- tket2/src/circuit/command.rs | 2 +- tket2/src/extension.rs | 62 ++++++++------- tket2/src/extension/angle.rs | 142 ++++++++++++++++----------------- tket2/src/json/op.rs | 4 +- tket2/src/ops.rs | 149 +++++++++++------------------------ tket2/src/passes/chunks.rs | 12 +-- 10 files changed, 177 insertions(+), 244 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bf0dd006..a362d2ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ missing_docs = "warn" [workspace.dependencies] tket2 = { path = "./tket2" } -quantinuum-hugr = { git = "https://github.com/CQCL/hugr", rev = "e7473f2" } +quantinuum-hugr = { git = "https://github.com/CQCL/hugr", rev = "2efcfb3" } portgraph = { version = "0.10" } pyo3 = { version = "0.20" } itertools = { version = "0.12.0" } diff --git a/devenv.lock b/devenv.lock index 2c49946a..4e8ca024 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,11 +3,11 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1699492209, - "narHash": "sha256-AhaFZrKIpU6GYUaA26erOQg2X+YHHzJpJ8r1mBHOaM8=", + "lastModified": 1701187605, + "narHash": "sha256-NctguPdUeDVLXFsv6vI1RlEiHLsXkeW3pgZe/mwn1BU=", "owner": "cachix", "repo": "devenv", - "rev": "80e740c7eb91b3d1c82013ec0ba4bfbc9a83734a", + "rev": "a7c4dd8f4eb1f98a6b8f04bf08364954e1e73e4f", "type": "github" }, "original": { @@ -25,11 +25,11 @@ "rust-analyzer-src": "rust-analyzer-src" }, "locked": { - "lastModified": 1699597299, - "narHash": "sha256-uJMCDTKSUB7+K+s7SB2DS6WU2VGDmruXmP9TQwTYGkw=", + "lastModified": 1701325357, + "narHash": "sha256-+CF74n9/AlLwgdCTM5WuKsa/4C1YxJSpRDCfz1ErOl0=", "owner": "nix-community", "repo": "fenix", - "rev": "ae8ecab0dbfe3552bd1a0bf5504416fd07dd2e8a", + "rev": "07a409ce1fe2c6d6e871793394b0cc0e5e262e3b", "type": "github" }, "original": { @@ -95,11 +95,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1699343069, - "narHash": "sha256-s7BBhyLA6MI6FuJgs4F/SgpntHBzz40/qV0xLPW6A1Q=", + "lastModified": 1701237617, + "narHash": "sha256-Ryd8xpNDY9MJnBFDYhB37XSFIxCPVVVXAbInNPa95vs=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "ec750fd01963ab6b20ee1f0cb488754e8036d89d", + "rev": "85306ef2470ba705c97ce72741d56e42d0264015", "type": "github" }, "original": { @@ -111,11 +111,11 @@ }, "nixpkgs-stable": { "locked": { - "lastModified": 1699291058, - "narHash": "sha256-5ggduoaAMPHUy4riL+OrlAZE14Kh7JWX4oLEs22ZqfU=", + "lastModified": 1701053011, + "narHash": "sha256-8QQ7rFbKFqgKgLoaXVJRh7Ik5LtI3pyBBCfOnNOGkF0=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "41de143fda10e33be0f47eab2bfe08a50f234267", + "rev": "5b528f99f73c4fad127118a8c1126b5e003b01a9", "type": "github" }, "original": { @@ -152,11 +152,11 @@ "nixpkgs-stable": "nixpkgs-stable_2" }, "locked": { - "lastModified": 1699271226, - "narHash": "sha256-8Jt1KW3xTjolD6c6OjJm9USx/jmL+VVmbooADCkdDfU=", + "lastModified": 1700922917, + "narHash": "sha256-ej2fch/T584b5K9sk1UhmZF7W6wEfDHuoUYpFN8dtvM=", "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "ea758da1a6dcde6dc36db348ed690d09b9864128", + "rev": "e5ee5c5f3844550c01d2131096c7271cec5e9b78", "type": "github" }, "original": { @@ -177,11 +177,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1699451299, - "narHash": "sha256-7HJMyp62fCS6/aCCCASz8MdJM2/M8d1pBNukyLmPdwA=", + "lastModified": 1701186284, + "narHash": "sha256-euPBY3EmEy7+Jjm2ToRPlSp/qrj0UL9+PRobxVz6+aQ=", "owner": "rust-lang", "repo": "rust-analyzer", - "rev": "7059ae2fc2d55fa20d7e2671597b516431129445", + "rev": "c7c582afb57bb802715262d7f1ba73b8a86c1c5a", "type": "github" }, "original": { diff --git a/devenv.nix b/devenv.nix index 85483656..98779650 100644 --- a/devenv.nix +++ b/devenv.nix @@ -42,7 +42,7 @@ in languages.rust = { enable = true; - channel = "stable"; + channel = "beta"; components = [ "rustc" "cargo" "clippy" "rustfmt" "rust-analyzer" ]; }; diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 6dd3e346..f2e8fd7b 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -24,7 +24,7 @@ use portgraph::Direction; use thiserror::Error; pub use hugr::ops::OpType; -pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; +pub use hugr::types::{EdgeKind, Type, TypeRow}; pub use hugr::{Node, Port, Wire}; use self::units::{filter, FilteredUnits, Units}; @@ -46,9 +46,11 @@ pub trait Circuit: HugrView { /// /// Equivalent to [`HugrView::get_function_type`]. #[inline] - fn circuit_signature(&self) -> &FunctionType { + fn circuit_signature(&self) -> FunctionType { self.get_function_type() .expect("Circuit has no function type") + .body() + .clone() } /// Returns the input node to the circuit. @@ -338,8 +340,8 @@ mod tests { let circ = test_circuit(); assert_eq!(circ.name(), None); - assert_eq!(circ.circuit_signature().input.len(), 3); - assert_eq!(circ.circuit_signature().output.len(), 3); + assert_eq!(circ.circuit_signature().input_count(), 3); + assert_eq!(circ.circuit_signature().output_count(), 3); assert_eq!(circ.qubit_count(), 2); assert_eq!(circ.num_gates(), 3); diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs index 891fdf36..192247b5 100644 --- a/tket2/src/circuit/command.rs +++ b/tket2/src/circuit/command.rs @@ -16,7 +16,7 @@ use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units} use super::Circuit; pub use hugr::ops::OpType; -pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; +pub use hugr::types::{EdgeKind, Type, TypeRow}; pub use hugr::{CircuitUnit, Direction, Node, Port, PortIndex, Wire}; /// An operation applied to specific wires. diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index 2f65d567..d564855a 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -2,18 +2,16 @@ //! //! This includes a extension for the opaque TKET1 operations. -use std::collections::HashMap; - use super::json::op::JsonOp; -use crate::ops::load_all_ops; use crate::Tk2Op; use hugr::extension::prelude::PRELUDE; -use hugr::extension::{ExtensionId, ExtensionRegistry, SignatureError}; +use hugr::extension::simple_op::MakeOpDef; +use hugr::extension::{CustomSignatureFunc, ExtensionId, ExtensionRegistry, SignatureError}; use hugr::hugr::IdentList; use hugr::ops::custom::{ExternalOp, OpaqueOp}; use hugr::std_extensions::arithmetic::float_types::{extension as float_extension, FLOAT64_TYPE}; use hugr::types::type_param::{CustomTypeArg, TypeArg, TypeParam}; -use hugr::types::{CustomType, FunctionType, Type, TypeBound}; +use hugr::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound}; use hugr::{type_row, Extension}; use lazy_static::lazy_static; use smol_str::SmolStr; @@ -45,14 +43,11 @@ pub static ref TKET1_EXTENSION: Extension = { res.add_type(LINEAR_BIT_NAME, vec![], "A linear bit.".into(), TypeBound::Any.into()).unwrap(); let json_op_payload_def = res.add_type(JSON_PAYLOAD_NAME, vec![], "Opaque TKET1 operation metadata.".into(), TypeBound::Eq.into()).unwrap(); - let json_op_payload = TypeParam::Opaque(json_op_payload_def.instantiate([]).unwrap()); - res.add_op_custom_sig( + let json_op_payload = TypeParam::Opaque{ty:json_op_payload_def.instantiate([]).unwrap()}; + res.add_op( JSON_OP_NAME, "An opaque TKET1 operation.".into(), - vec![json_op_payload], - HashMap::new(), - vec![], - json_op_signature, + JsonOpSignature([json_op_payload]) ).unwrap(); res @@ -68,12 +63,12 @@ pub static ref LINEAR_BIT: Type = { }; /// Extension registry including the prelude, TKET1 and Tk2Ops extensions. -pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::from([ +pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ TKET1_EXTENSION.clone(), PRELUDE.clone(), TKET2_EXTENSION.clone(), float_extension(), -]); +]).unwrap(); } @@ -98,7 +93,7 @@ pub(crate) fn wrap_json_op(op: &JsonOp) -> ExternalOp { JSON_OP_NAME, "".into(), vec![payload], - Some(sig), + sig, ) .into() } @@ -111,7 +106,7 @@ pub(crate) fn try_unwrap_json_op(ext: &ExternalOp) -> Option { if ext.name() != format!("{TKET1_EXTENSION_ID}.{JSON_OP_NAME}") { return None; } - let Some(TypeArg::Opaque { arg }) = ext.args().get(0) else { + let Some(TypeArg::Opaque { arg }) = ext.args().first() else { // TODO: Throw an error? We should never get here if the name matches. return None; }; @@ -119,14 +114,26 @@ pub(crate) fn try_unwrap_json_op(ext: &ExternalOp) -> Option { Some(op) } -/// Compute the signature of a json-encoded TKET1 operation. -fn json_op_signature(args: &[TypeArg]) -> Result { - let [TypeArg::Opaque { arg }] = args else { - // This should have already been checked. - panic!("Wrong number of arguments"); - }; - let op: JsonOp = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors! - Ok(op.signature()) +struct JsonOpSignature([TypeParam; 1]); + +impl CustomSignatureFunc for JsonOpSignature { + fn compute_signature<'o, 'a: 'o>( + &'a self, + arg_values: &[TypeArg], + _def: &'o hugr::extension::OpDef, + _extension_registry: &ExtensionRegistry, + ) -> Result { + let [TypeArg::Opaque { arg }] = arg_values else { + // This should have already been checked. + panic!("Wrong number of arguments"); + }; + let op: JsonOp = serde_yaml::from_value(arg.value.clone()).unwrap(); // TODO Errors! + Ok(op.signature().into()) + } + + fn static_params(&self) -> &[TypeParam] { + &self.0 + } } /// Angle type with given log denominator. @@ -151,7 +158,7 @@ pub static ref SYM_EXPR_T: CustomType = /// The extension definition for TKET2 ops and types. pub static ref TKET2_EXTENSION: Extension = { let mut e = Extension::new(TKET2_EXTENSION_ID); - load_all_ops::(&mut e).expect("add fail"); + Tk2Op::load_all_ops(&mut e).expect("add fail"); let sym_expr_opdef = e.add_type( SYM_EXPR_NAME, @@ -160,13 +167,12 @@ pub static ref TKET2_EXTENSION: Extension = { TypeBound::Eq.into(), ) .unwrap(); - let sym_expr_param = TypeParam::Opaque(sym_expr_opdef.instantiate([]).unwrap()); + let sym_expr_param = TypeParam::Opaque{ty:sym_expr_opdef.instantiate([]).unwrap()}; - e.add_op_custom_sig_simple( + e.add_op( SYM_OP_ID, "Store a sympy expression that can be evaluated to a float.".to_string(), - vec![sym_expr_param], - |_: &[TypeArg]| Ok(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])), + PolyFuncType::new(vec![sym_expr_param], FunctionType::new(type_row![], type_row![FLOAT64_TYPE])), ) .unwrap(); diff --git a/tket2/src/extension/angle.rs b/tket2/src/extension/angle.rs index e52c1b3e..5aedb1a7 100644 --- a/tket2/src/extension/angle.rs +++ b/tket2/src/extension/angle.rs @@ -1,7 +1,7 @@ use std::{cmp::max, num::NonZeroU64}; use hugr::{ - extension::{prelude::ERROR_TYPE, ExtensionRegistry, SignatureError, TypeDef, PRELUDE}, + extension::{prelude::ERROR_TYPE, SignatureError, SignatureFromArgs, TypeDef}, types::{ type_param::{TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeArg, @@ -127,127 +127,113 @@ impl CustomConst for ConstAngle { } } -fn type_var(var_id: usize, extension: &Extension) -> Result { - Ok(Type::new_extension(angle_def(extension).instantiate( - vec![TypeArg::new_var_use(var_id, LOG_DENOM_TYPE_PARAM)], - )?)) -} -fn atrunc_sig(extension: &Extension) -> Result { - let in_angle = type_var(0, extension)?; - let out_angle = type_var(1, extension)?; - - Ok(FunctionType::new(vec![in_angle], vec![out_angle])) -} - -fn aconvert_sig(extension: &Extension) -> Result { - let in_angle = type_var(0, extension)?; - let out_angle = type_var(1, extension)?; - Ok(FunctionType::new( - vec![in_angle], - vec![Type::new_sum(vec![out_angle, ERROR_TYPE])], - )) -} - /// Collect a vector into an array. fn collect_array(arr: &[T]) -> [&T; N] { arr.iter().collect_vec().try_into().unwrap() } -fn abinop_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let m: u8 = get_log_denom(arg0)?; - let n: u8 = get_log_denom(arg1)?; - let l: u8 = max(m, n); - Ok(FunctionType::new( - vec![angle_type(m), angle_type(n)], - vec![angle_type(l)], - )) -} +fn abinop_sig() -> impl SignatureFromArgs { + struct BinOp; + const PARAMS: &[TypeParam] = &[LOG_DENOM_TYPE_PARAM]; + + impl SignatureFromArgs for BinOp { + fn compute_signature( + &self, + arg_values: &[TypeArg], + ) -> Result { + let [arg0, arg1] = collect_array(arg_values); + let m: u8 = get_log_denom(arg0)?; + let n: u8 = get_log_denom(arg1)?; + let l: u8 = max(m, n); + Ok(FunctionType::new(vec![angle_type(m), angle_type(n)], vec![angle_type(l)]).into()) + } + + fn static_params(&self) -> &[TypeParam] { + PARAMS + } + } -fn aunop_sig(extension: &Extension) -> Result { - let angle = type_var(0, extension)?; - Ok(FunctionType::new_endo(vec![angle])) + BinOp } fn angle_def(extension: &Extension) -> &TypeDef { extension.get_type(&ANGLE_TYPE_ID).unwrap() } +fn generic_angle_type(var_id: usize, angle_type_def: &TypeDef) -> Type { + Type::new_extension( + angle_type_def + .instantiate(vec![TypeArg::new_var_use(var_id, LOG_DENOM_TYPE_PARAM)]) + .unwrap(), + ) +} pub(super) fn add_to_extension(extension: &mut Extension) { - extension + let angle_type_def = extension .add_type( ANGLE_TYPE_ID, vec![LOG_DENOM_TYPE_PARAM], "angle value with a given log-denominator".to_owned(), TypeBound::Eq.into(), ) - .unwrap(); + .unwrap() + .clone(); - let reg1: ExtensionRegistry = [PRELUDE.to_owned(), extension.to_owned()].into(); extension - .add_op_type_scheme( + .add_op( "atrunc".into(), "truncate an angle to one with a lower log-denominator with the same value, rounding \ down in [0, 2π) if necessary" .to_owned(), - Default::default(), - vec![], - PolyFuncType::new_validated( + PolyFuncType::new( vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], - atrunc_sig(extension).unwrap(), - ®1, - ) - .unwrap(), + // atrunc_sig(extension).unwrap(), + FunctionType::new( + vec![generic_angle_type(0, &angle_type_def)], + vec![generic_angle_type(1, &angle_type_def)], + ), + ), ) .unwrap(); extension - .add_op_type_scheme( + .add_op( "aconvert".into(), "convert an angle to one with another log-denominator having the same value, if \ possible, otherwise return an error" .to_owned(), - Default::default(), - vec![], - PolyFuncType::new_validated( + PolyFuncType::new( vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], - aconvert_sig(extension).unwrap(), - ®1, - ) - .unwrap(), + FunctionType::new( + vec![generic_angle_type(0, &angle_type_def)], + vec![Type::new_sum(vec![ + generic_angle_type(1, &angle_type_def), + ERROR_TYPE, + ])], + ), + ), ) .unwrap(); extension - .add_op_custom_sig_simple( - "aadd".into(), - "addition of angles".to_owned(), - vec![LOG_DENOM_TYPE_PARAM], - abinop_sig, - ) + .add_op("aadd".into(), "addition of angles".to_owned(), abinop_sig()) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op( "asub".into(), "subtraction of the second angle from the first".to_owned(), - vec![LOG_DENOM_TYPE_PARAM], - abinop_sig, + abinop_sig(), ) .unwrap(); extension - .add_op_type_scheme( + .add_op( "aneg".into(), "negation of an angle".to_owned(), - Default::default(), - vec![], - PolyFuncType::new_validated( - vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], - aunop_sig(extension).unwrap(), - ®1, - ) - .unwrap(), + PolyFuncType::new( + vec![LOG_DENOM_TYPE_PARAM], + FunctionType::new_endo(vec![generic_angle_type(0, &angle_type_def)]), + ), ) .unwrap(); } @@ -305,13 +291,19 @@ mod test { } #[test] fn test_binop_sig() { - let sig = abinop_sig(&[type_arg(23), type_arg(42)]).unwrap(); + let binop_sig = abinop_sig(); + + let sig = binop_sig + .compute_signature(&[type_arg(23), type_arg(42)]) + .unwrap(); assert_eq!( sig, - FunctionType::new(vec![angle_type(23), angle_type(42)], vec![angle_type(42)]) + FunctionType::new(vec![angle_type(23), angle_type(42)], vec![angle_type(42)]).into() ); - assert!(abinop_sig(&[type_arg(23), type_arg(89)]).is_err()); + assert!(binop_sig + .compute_signature(&[type_arg(23), type_arg(89)]) + .is_err()); } } diff --git a/tket2/src/json/op.rs b/tket2/src/json/op.rs index 58c9df06..7ae81832 100644 --- a/tket2/src/json/op.rs +++ b/tket2/src/json/op.rs @@ -207,11 +207,11 @@ impl TryFrom<&OpType> for JsonOp { // // Non-supported Hugr operations throw an error. let err = || OpConvertError::UnsupportedOpSerialization(op.clone()); - let OpType::LeafOp(leaf) = op else { + let Some(leaf) = op.as_leaf_op() else { return Err(err()); }; - let json_optype = if let Ok(tk2op) = leaf.clone().try_into() { + let json_optype = if let Ok(tk2op) = leaf.try_into() { match tk2op { Tk2Op::H => JsonOpType::H, Tk2Op::CX => JsonOpType::CX, diff --git a/tket2/src/ops.rs b/tket2/src/ops.rs index 50df8601..9f76820a 100644 --- a/tket2/src/ops.rs +++ b/tket2/src/ops.rs @@ -1,12 +1,11 @@ -use std::collections::HashMap; - use crate::extension::{ SYM_EXPR_T, SYM_OP_ID, TKET2_EXTENSION as EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID, }; use hugr::{ extension::{ prelude::{BOOL_T, QB_T}, - ExtensionBuildError, ExtensionId, OpDef, + simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp}, + ExtensionId, OpDef, SignatureFunc, }, ops::{custom::ExternalOp, LeafOp, OpType}, std_extensions::arithmetic::float_types::FLOAT64_TYPE, @@ -15,13 +14,10 @@ use hugr::{ type_param::{CustomTypeArg, TypeArg}, FunctionType, }, - Extension, }; use serde::{Deserialize, Serialize}; -use std::str::FromStr; -use strum::IntoEnumIterator; use strum_macros::{Display, EnumIter, EnumString, IntoStaticStr}; use thiserror::Error; @@ -72,10 +68,10 @@ pub enum Tk2Op { /// Whether an op is a given Tk2Op. pub fn op_matches(op: &OpType, tk2op: Tk2Op) -> bool { - let Ok(op) = Tk2Op::try_from(op) else { + let Some(ext_op) = tk2op.to_extension_op() else { return false; }; - op == tk2op + op.as_leaf_op().and_then(|op| op.as_extension_op()) == Some(&ext_op) } #[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)] @@ -93,50 +89,14 @@ pub enum Pauli { #[error("Not a Tk2Op.")] pub struct NotTk2Op; -// this trait could be implemented in Hugr -pub(crate) trait SimpleOpEnum: - Into<&'static str> + FromStr + Copy + IntoEnumIterator -{ - type LoadError: std::error::Error; - - fn signature(&self) -> FunctionType; - fn name(&self) -> &str { - (*self).into() - } - fn from_extension_name(extension: &ExtensionId, op_name: &str) - -> Result; - fn try_from_op_def(op_def: &OpDef) -> Result { - Self::from_extension_name(op_def.extension(), op_def.name()) - } - fn add_to_extension<'e>( - &self, - ext: &'e mut Extension, - ) -> Result<&'e OpDef, ExtensionBuildError>; - - fn all_variants() -> ::Iterator { - ::iter() - } -} - -fn from_extension_name( - extension: &ExtensionId, - op_name: &str, -) -> Result { - if extension != &EXTENSION_ID { - return Err(NotTk2Op); - } - T::from_str(op_name).map_err(|_| NotTk2Op) -} - impl Pauli { /// Check if this pauli commutes with another. pub fn commutes_with(&self, other: Self) -> bool { *self == Pauli::I || other == Pauli::I || *self == other } } -impl SimpleOpEnum for Tk2Op { - type LoadError = NotTk2Op; - fn signature(&self) -> FunctionType { +impl MakeOpDef for Tk2Op { + fn signature(&self) -> SignatureFunc { use Tk2Op::*; let one_qb_row = type_row![QB_T]; let two_qb_row = type_row![QB_T, QB_T]; @@ -156,35 +116,28 @@ impl SimpleOpEnum for Tk2Op { one_qb_row, ), } + .into() } - fn add_to_extension<'e>( - &self, - ext: &'e mut Extension, - ) -> Result<&'e OpDef, ExtensionBuildError> { - let name = self.name().into(); - let FunctionType { input, output, .. } = self.signature(); - ext.add_op_custom_sig( - name, - format!("TKET 2 quantum op: {}", self.name()), - vec![], - HashMap::from_iter([( - "commutation".to_string(), - serde_yaml::to_value(self.qubit_commutation()).unwrap(), - )]), - vec![], - move |_: &_| Ok(FunctionType::new(input.clone(), output.clone())), - ) + fn post_opdef(&self, def: &mut OpDef) { + def.add_misc( + "commutation", + serde_yaml::to_value(self.qubit_commutation()).unwrap(), + ); } - fn from_extension_name( - extension: &ExtensionId, - op_name: &str, - ) -> Result { - if extension != &EXTENSION_ID { - return Err(NotTk2Op); - } - Self::from_str(op_name).map_err(|_| NotTk2Op) + fn from_def(op_def: &OpDef) -> Result { + try_from_name(op_def.name()) + } +} + +impl MakeRegisteredOp for Tk2Op { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r hugr::extension::ExtensionRegistry { + ®ISTRY } } @@ -238,7 +191,7 @@ pub(crate) fn match_symb_const_op(op: &OpType) -> Option<&str> { { // TODO also check extension name - let Some(TypeArg::Opaque { arg }) = e.args().get(0) else { + let Some(TypeArg::Opaque { arg }) = e.args().first() else { panic!("should be an opaque type arg.") }; @@ -256,20 +209,9 @@ pub(crate) fn match_symb_const_op(op: &OpType) -> Option<&str> { } } -// From implementations could be made generic over SimpleOpEnum impl From for LeafOp { fn from(op: Tk2Op) -> Self { - EXTENSION - .instantiate_extension_op(op.name(), [], ®ISTRY) - .unwrap() - .into() - } -} - -impl From for OpType { - fn from(op: Tk2Op) -> Self { - let l: LeafOp = op.into(); - l.into() + op.to_extension_op().unwrap().into() } } @@ -277,7 +219,8 @@ impl TryFrom for Tk2Op { type Error = NotTk2Op; fn try_from(op: OpType) -> Result { - Self::try_from(&op) + let leaf: LeafOp = op.try_into().map_err(|_| NotTk2Op)?; + leaf.try_into() } } @@ -296,13 +239,15 @@ impl TryFrom<&LeafOp> for Tk2Op { type Error = NotTk2Op; fn try_from(op: &LeafOp) -> Result { - match op { - LeafOp::CustomOp(b) => match b.as_ref() { - ExternalOp::Extension(e) => Self::try_from_op_def(e.def()), - ExternalOp::Opaque(o) => from_extension_name(o.extension(), o.name()), - }, - _ => Err(NotTk2Op), + let LeafOp::CustomOp(ext) = op else { + return Err(NotTk2Op); + }; + + match ext.as_ref() { + ExternalOp::Extension(ext) => Tk2Op::from_extension_op(ext), + ExternalOp::Opaque(opaque) => try_from_name(opaque.name()), } + .map_err(|_| NotTk2Op) } } @@ -314,35 +259,29 @@ impl TryFrom for Tk2Op { } } -/// load all variants of a `SimpleOpEnum` in to an extension as op defs. -pub(crate) fn load_all_ops( - extension: &mut Extension, -) -> Result<(), ExtensionBuildError> { - for op in T::all_variants() { - op.add_to_extension(extension)?; - } - Ok(()) -} #[cfg(test)] pub(crate) mod test { use std::sync::Arc; + use hugr::extension::simple_op::MakeOpDef; + use hugr::ops::OpName; use hugr::{extension::OpDef, Hugr}; use rstest::{fixture, rstest}; + use strum::IntoEnumIterator; use super::Tk2Op; use crate::extension::{TKET2_EXTENSION as EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID}; - use crate::{circuit::Circuit, ops::SimpleOpEnum, utils::build_simple_circuit}; - fn get_opdef(op: impl SimpleOpEnum) -> Option<&'static Arc> { - EXTENSION.get_op(op.name()) + use crate::{circuit::Circuit, utils::build_simple_circuit}; + fn get_opdef(op: impl OpName) -> Option<&'static Arc> { + EXTENSION.get_op(&op.name()) } #[test] fn create_extension() { assert_eq!(EXTENSION.name(), &EXTENSION_ID); - for o in Tk2Op::all_variants() { - assert_eq!(Tk2Op::try_from_op_def(get_opdef(o).unwrap()), Ok(o)); + for o in Tk2Op::iter() { + assert_eq!(Tk2Op::from_def(get_opdef(o).unwrap()), Ok(o)); } } diff --git a/tket2/src/passes/chunks.rs b/tket2/src/passes/chunks.rs index 706296d4..b89171e4 100644 --- a/tket2/src/passes/chunks.rs +++ b/tket2/src/passes/chunks.rs @@ -8,14 +8,13 @@ use std::ops::{Index, IndexMut}; use derive_more::From; use hugr::builder::{Container, FunctionBuilder}; -use hugr::extension::ExtensionSet; use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::views::sibling_subgraph::ConvexChecker; use hugr::hugr::views::{HierarchyView, SiblingGraph, SiblingSubgraph}; use hugr::hugr::{HugrError, NodeMetadataMap}; use hugr::ops::handle::DataflowParentID; use hugr::ops::OpType; -use hugr::types::{FunctionType, Signature}; +use hugr::types::FunctionType; use hugr::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; use itertools::Itertools; @@ -59,7 +58,7 @@ impl Chunk { ) .expect("Failed to define the chunk subgraph"); let extracted = subgraph - .extract_subgraph(circ, "Chunk", ExtensionSet::new()) + .extract_subgraph(circ, "Chunk") .expect("Failed to extract chunk"); // Transform the subgraph's input/output sets into wires that can be // matched between different chunks. @@ -314,13 +313,8 @@ impl CircuitChunks { .and_then(|map| map.get("name")) .and_then(|s| s.as_str()) .unwrap_or(""); - let signature = Signature { - signature: self.signature, - // TODO: Is this correct? Can a circuit root have a fixed set of input extensions? - input_extensions: ExtensionSet::new(), - }; - let mut builder = FunctionBuilder::new(name, signature).unwrap(); + let mut builder = FunctionBuilder::new(name, self.signature.into()).unwrap(); // Take the unfinished Hugr from the builder, to avoid unnecessary // validation checks that require connecting the inputs an outputs. let mut reassembled = mem::take(builder.hugr_mut());