From 6126f105d8765a4245b0a71002f903d7c8eb3e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 23 Sep 2024 13:59:00 +0100 Subject: [PATCH] feat: Add an explicit struct for the tket2 sympy op (#616) Replaces the ad-hoc definition of sympy operations with an opdef / concrete op pair that can be `cast`ed to. We need this for the _tket2->pytket_ encoder. Added temporary re-exports to keep this change non-breaking. Includes #615 --- .pre-commit-config.yaml | 2 +- justfile | 4 +- tket2/src/extension.rs | 29 ++---- tket2/src/extension/sympy.rs | 176 +++++++++++++++++++++++++++++++++++ tket2/src/ops.rs | 10 +- uv.lock | 4 +- 6 files changed, 192 insertions(+), 33 deletions(-) create mode 100644 tket2/src/extension/sympy.rs diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f038bc6b..f2f696b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -72,7 +72,7 @@ repos: - id: cargo-test name: cargo test description: Run tests with `cargo test`. - entry: uv run -- cargo test --all-features --workspace + entry: uv run -- cargo test --all-features language: system files: \.rs$ pass_filenames: false diff --git a/justfile b/justfile index 7e33e32f..8b9d0035 100644 --- a/justfile +++ b/justfile @@ -18,8 +18,8 @@ build: # Run all the tests. test language="[rust|python]" : (_run_lang language \ - "uv run cargo test --all-features --workspace" \ - "uv run maturin develop && uv run pytest" + "uv run cargo test --all-features" \ + "uv run maturin develop --uv && uv run pytest" ) # Auto-fix all clippy warnings. diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index dc60295b..e5fe8e61 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -12,14 +12,19 @@ use hugr::extension::{ use hugr::hugr::IdentList; use hugr::std_extensions::arithmetic::float_types::EXTENSION as FLOAT_TYPES; use hugr::types::type_param::{TypeArg, TypeParam}; -use hugr::types::{CustomType, PolyFuncType, PolyFuncTypeRV, Signature}; -use hugr::{type_row, Extension}; +use hugr::types::{CustomType, PolyFuncType, PolyFuncTypeRV}; +use hugr::Extension; use lazy_static::lazy_static; -use rotation::ROTATION_TYPE; use smol_str::SmolStr; /// Definition for Angle ops and types. pub mod rotation; +pub mod sympy; + +use sympy::SympyOpDef; +/// Backwards compatible exports. +/// TODO: Remove in a breaking release. +pub use sympy::{SYM_EXPR_NAME, SYM_EXPR_T, SYM_OP_ID}; /// The ID of the TKET1 extension. pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1"); @@ -90,31 +95,15 @@ impl CustomSignatureFunc for Tk1Signature { /// Name of tket 2 extension. pub const TKET2_EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket2.quantum"); -/// The name of the symbolic expression opaque type arg. -pub const SYM_EXPR_NAME: SmolStr = SmolStr::new_inline("SymExpr"); - -/// The name of the symbolic expression opaque type arg. -pub const SYM_OP_ID: SmolStr = SmolStr::new_inline("symbolic_angle"); - /// Current version of the TKET 2 extension pub const TKET2_EXTENSION_VERSION: Version = Version::new(0, 1, 0); lazy_static! { -/// The type of the symbolic expression opaque type arg. -pub static ref SYM_EXPR_T: CustomType = - TKET2_EXTENSION.get_type(&SYM_EXPR_NAME).unwrap().instantiate([]).unwrap(); - /// The extension definition for TKET2 ops and types. pub static ref TKET2_EXTENSION: Extension = { let mut e = Extension::new(TKET2_EXTENSION_ID, TKET2_EXTENSION_VERSION); Tk2Op::load_all_ops(&mut e).expect("add fail"); - - e.add_op( - SYM_OP_ID, - "Store a sympy expression that can be evaluated to an angle.".to_string(), - PolyFuncType::new(vec![TypeParam::String], Signature::new(type_row![], type_row![ROTATION_TYPE])), - ) - .unwrap(); + SympyOpDef.add_to_extension(&mut e).unwrap(); e }; } diff --git a/tket2/src/extension/sympy.rs b/tket2/src/extension/sympy.rs new file mode 100644 index 00000000..e163093c --- /dev/null +++ b/tket2/src/extension/sympy.rs @@ -0,0 +1,176 @@ +//! Opaque operations encoding sympy expressions. +//! +//! Part of the TKET2 extension. + +use std::str::FromStr; + +use hugr::extension::simple_op::{ + try_from_name, HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, +}; +use hugr::extension::{ExtensionId, ExtensionRegistry, SignatureError}; +use hugr::ops::{ExtensionOp, NamedOp, OpName}; +use hugr::type_row; +use hugr::types::type_param::TypeParam; +use hugr::types::{CustomType, PolyFuncType, Signature, TypeArg}; +use lazy_static::lazy_static; +use smol_str::SmolStr; + +use crate::extension::TKET2_EXTENSION; + +use super::rotation::ROTATION_TYPE; +use super::{REGISTRY, TKET2_EXTENSION_ID}; + +/// The name of the symbolic expression opaque type arg. +pub const SYM_EXPR_NAME: SmolStr = SmolStr::new_inline("SymExpr"); + +/// The name of the symbolic expression opaque type arg. +pub const SYM_OP_ID: SmolStr = SmolStr::new_inline("symbolic_angle"); + +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +/// An operation hardcoding a Sympy expression in its parameter. +/// +/// Returns the expression as an angle. +pub struct SympyOpDef; + +impl SympyOpDef { + /// Create a new concrete sympy definition using the given sympy expression. + pub fn with_expr(self, expr: String) -> SympyOp { + SympyOp { expr } + } +} + +impl NamedOp for SympyOpDef { + fn name(&self) -> hugr::ops::OpName { + SYM_OP_ID.to_owned() + } +} + +impl FromStr for SympyOpDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == SYM_OP_ID { + Ok(Self) + } else { + Err(()) + } + } +} + +impl MakeOpDef for SympyOpDef { + fn from_def( + op_def: &hugr::extension::OpDef, + ) -> Result + where + Self: Sized, + { + try_from_name(op_def.name(), op_def.extension()) + } + + fn signature(&self) -> hugr::extension::SignatureFunc { + PolyFuncType::new( + vec![TypeParam::String], + Signature::new(type_row![], type_row![ROTATION_TYPE]), + ) + .into() + } + + fn description(&self) -> String { + "Store a sympy expression that can be evaluated to an angle.".to_string() + } + + fn extension(&self) -> hugr::extension::ExtensionId { + TKET2_EXTENSION_ID + } +} + +/// A concrete operation hardcoding a Sympy expression in its parameter. +/// +/// Returns the expression as an angle. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct SympyOp { + /// The expression to evaluate. + pub expr: String, +} + +impl NamedOp for SympyOp { + fn name(&self) -> OpName { + SYM_OP_ID.to_owned() + } +} + +impl MakeExtensionOp for SympyOp { + fn from_extension_op(ext_op: &ExtensionOp) -> Result { + let def = SympyOpDef::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![self.expr.clone().into()] + } +} + +impl MakeRegisteredOp for SympyOp { + fn extension_id(&self) -> ExtensionId { + TKET2_EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + ®ISTRY + } +} + +impl HasConcrete for SympyOpDef { + type Concrete = SympyOp; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + let ty = match type_args { + [TypeArg::String { arg }] => arg.clone(), + _ => return Err(SignatureError::InvalidTypeArgs.into()), + }; + + Ok(self.with_expr(ty)) + } +} + +impl HasDef for SympyOp { + type Def = SympyOpDef; +} + +lazy_static! { + +/// The type of the symbolic expression opaque type arg. +pub static ref SYM_EXPR_T: CustomType = + TKET2_EXTENSION.get_type(&SYM_EXPR_NAME).unwrap().instantiate([]).unwrap(); + +} + +#[cfg(test)] +mod tests { + use hugr::extension::simple_op::MakeOpDef; + use hugr::ops::NamedOp; + + use super::*; + use crate::extension::TKET2_EXTENSION; + + #[test] + fn test_extension() { + assert_eq!(TKET2_EXTENSION.name(), &SympyOpDef.extension()); + + let opdef = TKET2_EXTENSION.get_op(&SympyOpDef.name()); + assert_eq!(SympyOpDef::from_def(opdef.unwrap()), Ok(SympyOpDef)); + } + + #[test] + fn test_op() { + let op = SympyOp { + expr: "cos(pi/2)".to_string(), + }; + + let op_t: ExtensionOp = op.clone().to_extension_op().unwrap(); + assert!(SympyOpDef::from_op(&op_t).is_ok()); + + let new_op = SympyOp::from_op(&op_t).unwrap(); + assert_eq!(new_op, op); + } +} diff --git a/tket2/src/ops.rs b/tket2/src/ops.rs index b6cb23ed..7e57aeb5 100644 --- a/tket2/src/ops.rs +++ b/tket2/src/ops.rs @@ -1,7 +1,6 @@ use crate::extension::rotation::ROTATION_TYPE; -use crate::extension::{ - SYM_OP_ID, TKET2_EXTENSION as EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID, -}; +use crate::extension::sympy::SympyOpDef; +use crate::extension::{SYM_OP_ID, TKET2_EXTENSION_ID as EXTENSION_ID}; use hugr::ops::custom::ExtensionOp; use hugr::ops::NamedOp; use hugr::{ @@ -176,10 +175,7 @@ impl Tk2Op { /// Initialize a new custom symbolic expression constant op from a string. pub fn symbolic_constant_op(arg: String) -> OpType { - EXTENSION - .instantiate_extension_op(&SYM_OP_ID, vec![arg.into()], ®ISTRY) - .unwrap() - .into() + SympyOpDef.with_expr(arg).into() } /// match against a symbolic constant diff --git a/uv.lock b/uv.lock index ddcc4d4c..db059f7f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,8 +1,6 @@ version = 1 requires-python = ">=3.10, <3.13" resolution-markers = [ - "python_full_version < '3.13'", - "python_full_version >= '3.13'", ] [manifest] @@ -784,7 +782,7 @@ wheels = [ [[package]] name = "tket2" -version = "0.2.1" +version = "0.3.0" source = { editable = "tket2-py" } dependencies = [ { name = "hugr" },