diff --git a/src/builder/circuit.rs b/src/builder/circuit.rs index 4b45d34fc..043945ff9 100644 --- a/src/builder/circuit.rs +++ b/src/builder/circuit.rs @@ -139,9 +139,9 @@ mod test { }, extension::prelude::BOOL_T, ops::{custom::OpaqueOp, LeafOp}, - std_extensions::quantum::test::{cx_gate, h_gate, measure}, type_row, types::FunctionType, + utils::test_quantum_extension::{cx_gate, h_gate, measure}, }; #[test] diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index d4fa1eafe..3ef28ac8d 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -218,8 +218,8 @@ pub(crate) mod test { use crate::ops::{handle::NodeHandle, LeafOp, OpTag}; use crate::std_extensions::logic::test::and_op; - use crate::std_extensions::quantum::test::h_gate; use crate::types::Type; + use crate::utils::test_quantum_extension::h_gate; use crate::{ builder::{ test::{n_identity, BIT, NAT, QB}, diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 51d99da8a..4a376604d 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -234,9 +234,9 @@ pub(in crate::hugr::rewrite) mod test { use crate::ops::OpTag; use crate::ops::{OpTrait, OpType}; use crate::std_extensions::logic::test::and_op; - use crate::std_extensions::quantum::test::{cx_gate, h_gate}; use crate::type_row; use crate::types::{FunctionType, Type}; + use crate::utils::test_quantum_extension::{cx_gate, h_gate}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; diff --git a/src/hugr/views/descendants.rs b/src/hugr/views/descendants.rs index bbbfd81ca..bd31107a2 100644 --- a/src/hugr/views/descendants.rs +++ b/src/hugr/views/descendants.rs @@ -202,9 +202,9 @@ pub(super) mod test { use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, ops::handle::NodeHandle, - std_extensions::quantum::test::h_gate, type_row, types::{FunctionType, Type}, + utils::test_quantum_extension::h_gate, }; use super::*; diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 2e1c4aaa0..515392615 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -681,6 +681,7 @@ mod tests { use cool_asserts::assert_matches; use crate::extension::PRELUDE_REGISTRY; + use crate::utils::test_quantum_extension::cx_gate; use crate::{ builder::{ BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, @@ -696,10 +697,7 @@ mod tests { handle::{DfgID, FuncID, NodeHandle}, OpType, }, - std_extensions::{ - logic::test::{and_op, not_op}, - quantum::test::cx_gate, - }, + std_extensions::logic::test::{and_op, not_op}, type_row, }; diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 7597c841a..01935e49d 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -4,9 +4,9 @@ use crate::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::QB_T, ops::handle::NodeHandle, - std_extensions::quantum::test::cx_gate, type_row, types::FunctionType, + utils::test_quantum_extension::cx_gate, HugrView, }; diff --git a/src/std_extensions.rs b/src/std_extensions.rs index bc4af73cb..09552e945 100644 --- a/src/std_extensions.rs +++ b/src/std_extensions.rs @@ -5,4 +5,3 @@ pub mod arithmetic; pub mod collections; pub mod logic; -pub mod quantum; diff --git a/src/std_extensions/quantum.rs b/src/std_extensions/quantum.rs deleted file mode 100644 index 41b9054e4..000000000 --- a/src/std_extensions/quantum.rs +++ /dev/null @@ -1,369 +0,0 @@ -//! Basic HUGR quantum operations - -use std::cmp::max; -use std::f64::consts::TAU; -use std::num::NonZeroU64; - -use smol_str::SmolStr; - -use crate::extension::prelude::{BOOL_T, ERROR_TYPE, QB_T}; -use crate::extension::{ExtensionId, SignatureError}; -use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE; -use crate::type_row; -use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; -use crate::types::{ConstTypeError, CustomCheckFailure, CustomType, FunctionType, Type, TypeBound}; -use crate::utils::collect_array; -use crate::values::CustomConst; -use crate::Extension; - -use lazy_static::lazy_static; - -/// The extension identifier. -pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum"); - -/// Identifier for the angle type. -const ANGLE_TYPE_ID: SmolStr = SmolStr::new_inline("angle"); - -fn angle_custom_type(log_denom_arg: TypeArg) -> CustomType { - CustomType::new(ANGLE_TYPE_ID, [log_denom_arg], EXTENSION_ID, TypeBound::Eq) -} - -/// Angle type with a given log-denominator (specified by the TypeArg). -/// -/// This type is capable of representing angles that are multiples of 2π / 2^N where N is the -/// log-denominator. -pub(super) fn angle_type(log_denom_arg: TypeArg) -> Type { - Type::new_extension(angle_custom_type(log_denom_arg)) -} - -/// The largest permitted log-denominator. -pub const LOG_DENOM_MAX: u8 = 53; - -const fn is_valid_log_denom(n: u8) -> bool { - n <= LOG_DENOM_MAX -} - -/// Type parameter for the log-denominator of an angle. -pub const LOG_DENOM_TYPE_PARAM: TypeParam = - TypeParam::bounded_nat(NonZeroU64::MIN.saturating_add(LOG_DENOM_MAX as u64)); - -/// Get the log-denominator of the specified type argument or error if the argument is invalid. -pub(super) fn get_log_denom(arg: &TypeArg) -> Result { - match arg { - TypeArg::BoundedNat { n } if is_valid_log_denom(*n as u8) => Ok(*n as u8), - _ => Err(TypeArgError::TypeMismatch { - arg: arg.clone(), - param: LOG_DENOM_TYPE_PARAM, - }), - } -} - -pub(super) const fn type_arg(log_denom: u8) -> TypeArg { - TypeArg::BoundedNat { - n: log_denom as u64, - } -} - -/// An angle -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ConstAngle { - log_denom: u8, - value: u64, -} - -impl ConstAngle { - /// Create a new [`ConstAngle`] from a log-denominator and a numerator - pub fn new(log_denom: u8, value: u64) -> Result { - if !is_valid_log_denom(log_denom) { - return Err(ConstTypeError::CustomCheckFail( - crate::types::CustomCheckFailure::Message( - "Invalid angle log-denominator.".to_owned(), - ), - )); - } - if value >= (1u64 << log_denom) { - return Err(ConstTypeError::CustomCheckFail( - crate::types::CustomCheckFailure::Message( - "Invalid unsigned integer value.".to_owned(), - ), - )); - } - Ok(Self { log_denom, value }) - } - - /// Create a new [`ConstAngle`] from a log-denominator and a floating-point value in radians, - /// rounding to the nearest corresponding value. (Ties round away from zero.) - pub fn from_radians_rounding(log_denom: u8, theta: f64) -> Result { - if !is_valid_log_denom(log_denom) { - return Err(ConstTypeError::CustomCheckFail( - crate::types::CustomCheckFailure::Message( - "Invalid angle log-denominator.".to_owned(), - ), - )); - } - let a = (((1u64 << log_denom) as f64) * theta / TAU).round() as i64; - Ok(Self { - log_denom, - value: a.rem_euclid(1i64 << log_denom) as u64, - }) - } - - /// Returns the value of the constant - pub fn value(&self) -> u64 { - self.value - } - - /// Returns the log-denominator of the constant - pub fn log_denom(&self) -> u8 { - self.log_denom - } -} - -#[typetag::serde] -impl CustomConst for ConstAngle { - fn name(&self) -> SmolStr { - format!("a(2π*{}/2^{})", self.value, self.log_denom).into() - } - fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { - if typ.clone() == angle_custom_type(type_arg(self.log_denom)) { - Ok(()) - } else { - Err(CustomCheckFailure::Message( - "Angle constant type mismatch.".into(), - )) - } - } - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::values::downcast_equal_consts(self, other) - } -} - -fn atrunc_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let m: u8 = get_log_denom(arg0)?; - let n: u8 = get_log_denom(arg1)?; - if m < n { - return Err(SignatureError::InvalidTypeArgs); - } - Ok(FunctionType::new( - vec![angle_type(arg0.clone())], - vec![angle_type(arg1.clone())], - )) -} - -fn aconvert_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![angle_type(arg0.clone())], - vec![Type::new_sum(vec![angle_type(arg1.clone()), ERROR_TYPE])], - )) -} - -fn abinop_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let m: u8 = get_log_denom(arg0)?; - let n: u8 = get_log_denom(arg1)?; - let l: u8 = max(m, n); - Ok(FunctionType::new( - vec![ - angle_type(TypeArg::BoundedNat { n: m as u64 }), - angle_type(TypeArg::BoundedNat { n: n as u64 }), - ], - vec![angle_type(TypeArg::BoundedNat { n: l as u64 })], - )) -} - -fn aunop_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new_linear(vec![angle_type(arg.clone())])) -} - -fn one_qb_func(_: &[TypeArg]) -> Result { - Ok(FunctionType::new_linear(type_row![QB_T])) -} - -fn two_qb_func(_: &[TypeArg]) -> Result { - Ok(FunctionType::new_linear(type_row![QB_T, QB_T])) -} - -fn extension() -> Extension { - let mut extension = Extension::new(EXTENSION_ID); - - extension - .add_type( - ANGLE_TYPE_ID, - vec![LOG_DENOM_TYPE_PARAM], - "angle value with a given log-denominator".to_owned(), - TypeBound::Eq.into(), - ) - .unwrap(); - - extension - .add_op_custom_sig_simple( - "atrunc".into(), - "truncate an angle to one with a lower log-denominator with the same value, rounding \ - down in [0, 2π) if necessary" - .to_owned(), - vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], - atrunc_sig, - ) - .unwrap(); - - extension - .add_op_custom_sig_simple( - "aconvert".into(), - "convert an angle to one with another log-denominator having the same value, if \ - possible, otherwise return an error" - .to_owned(), - vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], - aconvert_sig, - ) - .unwrap(); - - extension - .add_op_custom_sig_simple( - "aadd".into(), - "addition of angles".to_owned(), - vec![LOG_DENOM_TYPE_PARAM], - abinop_sig, - ) - .unwrap(); - - extension - .add_op_custom_sig_simple( - "asub".into(), - "subtraction of the second angle from the first".to_owned(), - vec![LOG_DENOM_TYPE_PARAM], - abinop_sig, - ) - .unwrap(); - - extension - .add_op_custom_sig_simple( - "aneg".into(), - "negation of an angle".to_owned(), - vec![LOG_DENOM_TYPE_PARAM], - aunop_sig, - ) - .unwrap(); - - extension - .add_op_custom_sig_simple( - SmolStr::new_inline("H"), - "Hadamard".into(), - vec![], - one_qb_func, - ) - .unwrap(); - extension - .add_op_custom_sig_simple( - SmolStr::new_inline("RzF64"), - "Rotation specified by float".into(), - vec![], - |_: &[_]| { - Ok(FunctionType::new( - type_row![QB_T, FLOAT64_TYPE], - type_row![QB_T], - )) - }, - ) - .unwrap(); - - extension - .add_op_custom_sig_simple(SmolStr::new_inline("CX"), "CX".into(), vec![], two_qb_func) - .unwrap(); - - extension - .add_op_custom_sig_simple( - SmolStr::new_inline("Measure"), - "Measure a qubit, returning the qubit and the measurement result.".into(), - vec![], - |_arg_values: &[TypeArg]| { - Ok(FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T])) - // TODO add logic as an extension delta when inference is - // done? - // https://github.com/CQCL-DEV/hugr/issues/425 - }, - ) - .unwrap(); - - extension -} - -lazy_static! { - /// Quantum extension definition. - pub static ref EXTENSION: Extension = extension(); -} - -#[cfg(test)] -pub(crate) mod test { - use lazy_static::lazy_static; - use std::f64::consts::TAU; - - use cool_asserts::assert_matches; - - use crate::{ - extension::{ExtensionRegistry, PRELUDE}, - ops::LeafOp, - types::{type_param::TypeArgError, ConstTypeError, TypeArg}, - }; - - use super::{get_log_denom, ConstAngle, EXTENSION}; - - lazy_static! { - /// Quantum extension definition. - static ref REG: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); - } - - fn get_gate(gate_name: &str) -> LeafOp { - EXTENSION - .instantiate_extension_op(gate_name, [], ®) - .unwrap() - .into() - } - - pub(crate) fn h_gate() -> LeafOp { - get_gate("H") - } - - pub(crate) fn cx_gate() -> LeafOp { - get_gate("CX") - } - - pub(crate) fn measure() -> LeafOp { - get_gate("Measure") - } - - #[test] - fn test_angle_log_denoms() { - let type_arg_53 = TypeArg::BoundedNat { n: 53 }; - assert_matches!(get_log_denom(&type_arg_53), Ok(53)); - - let type_arg_54 = TypeArg::BoundedNat { n: 54 }; - assert_matches!( - get_log_denom(&type_arg_54), - Err(TypeArgError::TypeMismatch { .. }) - ); - } - - #[test] - fn test_angle_consts() { - let const_a32_7 = ConstAngle::new(5, 7).unwrap(); - let const_a33_7 = ConstAngle::new(6, 7).unwrap(); - let const_a32_8 = ConstAngle::new(6, 8).unwrap(); - assert_ne!(const_a32_7, const_a33_7); - assert_ne!(const_a32_7, const_a32_8); - assert_eq!(const_a32_7, ConstAngle::new(5, 7).unwrap()); - assert_matches!( - ConstAngle::new(3, 256), - Err(ConstTypeError::CustomCheckFail(_)) - ); - assert_matches!( - ConstAngle::new(54, 256), - Err(ConstTypeError::CustomCheckFail(_)) - ); - let const_af1 = ConstAngle::from_radians_rounding(5, 0.21874 * TAU).unwrap(); - assert_eq!(const_af1.value(), 7); - assert_eq!(const_af1.log_denom(), 5); - } -} diff --git a/src/utils.rs b/src/utils.rs index aa15ec087..de37f348c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -27,6 +27,104 @@ pub fn collect_array(arr: &[T]) -> [&T; N] { arr.iter().collect_vec().try_into().unwrap() } +#[cfg(test)] +pub(crate) mod test_quantum_extension { + use smol_str::SmolStr; + + use crate::{ + extension::{ + prelude::{BOOL_T, QB_T}, + ExtensionId, ExtensionRegistry, SignatureError, PRELUDE, + }, + ops::LeafOp, + std_extensions::arithmetic::float_types::FLOAT64_TYPE, + type_row, + types::{FunctionType, TypeArg}, + Extension, + }; + + use lazy_static::lazy_static; + + fn one_qb_func(_: &[TypeArg]) -> Result { + Ok(FunctionType::new_linear(type_row![QB_T])) + } + + fn two_qb_func(_: &[TypeArg]) -> Result { + Ok(FunctionType::new_linear(type_row![QB_T, QB_T])) + } + /// The extension identifier. + pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum"); + fn extension() -> Extension { + let mut extension = Extension::new(EXTENSION_ID); + + extension + .add_op_custom_sig_simple( + SmolStr::new_inline("H"), + "Hadamard".into(), + vec![], + one_qb_func, + ) + .unwrap(); + extension + .add_op_custom_sig_simple( + SmolStr::new_inline("RzF64"), + "Rotation specified by float".into(), + vec![], + |_: &[_]| { + Ok(FunctionType::new( + type_row![QB_T, FLOAT64_TYPE], + type_row![QB_T], + )) + }, + ) + .unwrap(); + + extension + .add_op_custom_sig_simple(SmolStr::new_inline("CX"), "CX".into(), vec![], two_qb_func) + .unwrap(); + + extension + .add_op_custom_sig_simple( + SmolStr::new_inline("Measure"), + "Measure a qubit, returning the qubit and the measurement result.".into(), + vec![], + |_arg_values: &[TypeArg]| { + Ok(FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T])) + // TODO add logic as an extension delta when inference is + // done? + // https://github.com/CQCL-DEV/hugr/issues/425 + }, + ) + .unwrap(); + + extension + } + + lazy_static! { + /// Quantum extension definition. + pub static ref EXTENSION: Extension = extension(); + static ref REG: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); + + } + fn get_gate(gate_name: &str) -> LeafOp { + EXTENSION + .instantiate_extension_op(gate_name, [], ®) + .unwrap() + .into() + } + pub(crate) fn h_gate() -> LeafOp { + get_gate("H") + } + + pub(crate) fn cx_gate() -> LeafOp { + get_gate("CX") + } + + pub(crate) fn measure() -> LeafOp { + get_gate("Measure") + } +} + #[allow(dead_code)] // Test only utils #[cfg(test)]