Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add angle type to tket2 extension #231

Merged
merged 7 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ license-file = "LICENCE"
[workspace.dependencies]

tket2 = { path = "./tket2" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "d0499ad" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "b256c2b" }
portgraph = { version = "0.10" }
pyo3 = { version = "0.20" }
itertools = { version = "0.11.0" }
Expand Down
5 changes: 5 additions & 0 deletions tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ fn json_op_signature(args: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(op.signature())
}

/// Angle type with given log denominator.
pub fn angle_custom_type(log_denom: u8) -> CustomType {
angle::angle_custom_type(&TKET2_EXTENSION, angle::type_arg(log_denom))
}

/// Name of tket 2 extension.
pub const TKET2_EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2");

Expand Down
139 changes: 85 additions & 54 deletions tket2/src/extension/angle.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{cmp::max, num::NonZeroU64};

use hugr::{
extension::{prelude::ERROR_TYPE, SignatureError},
extension::{prelude::ERROR_TYPE, ExtensionRegistry, SignatureError, TypeDef, PRELUDE},
types::{
type_param::{TypeArgError, TypeParam},
ConstTypeError, CustomCheckFailure, CustomType, FunctionType, Type, TypeArg, TypeBound,
ConstTypeError, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeArg,
TypeBound,
},
values::CustomConst,
Extension,
Expand All @@ -13,26 +14,15 @@ use itertools::Itertools;
use smol_str::SmolStr;
use std::f64::consts::TAU;

use super::TKET2_EXTENSION_ID;

/// Identifier for the angle type.
const ANGLE_TYPE_ID: SmolStr = SmolStr::new_inline("angle");

fn angle_custom_type(log_denom_arg: TypeArg) -> CustomType {
CustomType::new(
ANGLE_TYPE_ID,
[log_denom_arg],
TKET2_EXTENSION_ID,
TypeBound::Eq,
)
pub(super) fn angle_custom_type(extension: &Extension, log_denom_arg: TypeArg) -> CustomType {
angle_def(extension).instantiate([log_denom_arg]).unwrap()
}

/// Angle type with a given log-denominator (specified by the TypeArg).
///
/// This type is capable of representing angles that are multiples of 2π / 2^N where N is the
/// log-denominator.
pub(super) fn angle_type(log_denom_arg: TypeArg) -> Type {
Type::new_extension(angle_custom_type(log_denom_arg))
fn angle_type(log_denom: u8) -> Type {
Type::new_extension(super::angle_custom_type(log_denom))
}

/// The largest permitted log-denominator.
Expand All @@ -47,7 +37,7 @@ pub const LOG_DENOM_TYPE_PARAM: TypeParam =
TypeParam::bounded_nat(NonZeroU64::MIN.saturating_add(LOG_DENOM_MAX as u64));

/// Get the log-denominator of the specified type argument or error if the argument is invalid.
pub(super) fn get_log_denom(arg: &TypeArg) -> Result<u8, TypeArgError> {
fn get_log_denom(arg: &TypeArg) -> Result<u8, TypeArgError> {
match arg {
TypeArg::BoundedNat { n } if is_valid_log_denom(*n as u8) => Ok(*n as u8),
_ => Err(TypeArgError::TypeMismatch {
Expand Down Expand Up @@ -124,7 +114,7 @@ impl CustomConst for ConstAngle {
format!("a(2π*{}/2^{})", self.value, self.log_denom).into()
}
fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> {
if typ.clone() == angle_custom_type(type_arg(self.log_denom)) {
if typ.clone() == super::angle_custom_type(self.log_denom) {
Ok(())
} else {
Err(CustomCheckFailure::Message(
Expand All @@ -136,49 +126,51 @@ impl CustomConst for ConstAngle {
hugr::values::downcast_equal_consts(self, other)
}
}
/// Collect a vector into an array.
fn collect_array<const N: usize, T: std::fmt::Debug>(arr: &[T]) -> [&T; N] {
arr.iter().collect_vec().try_into().unwrap()

fn type_var(var_id: usize, extension: &Extension) -> Result<Type, SignatureError> {
Ok(Type::new_extension(angle_def(extension).instantiate(
vec![TypeArg::new_var_use(var_id, LOG_DENOM_TYPE_PARAM)],
)?))
}
fn atrunc_sig(extension: &Extension) -> Result<FunctionType, SignatureError> {
let in_angle = type_var(0, extension)?;
let out_angle = type_var(1, extension)?;

fn atrunc_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
let m: u8 = get_log_denom(arg0)?;
let n: u8 = get_log_denom(arg1)?;
if m < n {
return Err(SignatureError::InvalidTypeArgs);
}
Ok(FunctionType::new(
vec![angle_type(arg0.clone())],
vec![angle_type(arg1.clone())],
))
Ok(FunctionType::new(vec![in_angle], vec![out_angle]))
}

fn aconvert_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
fn aconvert_sig(extension: &Extension) -> Result<FunctionType, SignatureError> {
let in_angle = type_var(0, extension)?;
let out_angle = type_var(1, extension)?;
Ok(FunctionType::new(
vec![angle_type(arg0.clone())],
vec![Type::new_sum(vec![angle_type(arg1.clone()), ERROR_TYPE])],
vec![in_angle],
vec![Type::new_sum(vec![out_angle, ERROR_TYPE])],
))
}

/// Collect a vector into an array.
fn collect_array<const N: usize, T: std::fmt::Debug>(arr: &[T]) -> [&T; N] {
arr.iter().collect_vec().try_into().unwrap()
}

fn abinop_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
let m: u8 = get_log_denom(arg0)?;
let n: u8 = get_log_denom(arg1)?;
let l: u8 = max(m, n);
Ok(FunctionType::new(
vec![
angle_type(TypeArg::BoundedNat { n: m as u64 }),
angle_type(TypeArg::BoundedNat { n: n as u64 }),
],
vec![angle_type(TypeArg::BoundedNat { n: l as u64 })],
vec![angle_type(m), angle_type(n)],
vec![angle_type(l)],
))
}

fn aunop_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok(FunctionType::new_linear(vec![angle_type(arg.clone())]))
fn aunop_sig(extension: &Extension) -> Result<FunctionType, SignatureError> {
let angle = type_var(0, extension)?;
Ok(FunctionType::new_linear(vec![angle]))
}

fn angle_def(extension: &Extension) -> &TypeDef {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this is a wider question about the extensions API, but it seems strange that this function (and some others -- type_var, angle_custom_type, atrunc_sig etc) need to take their own extension as an argument. I suppose it is because angle_custom_type is exposed at the module level? But it feels like it should be unnecessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inside angle.rs, I use these functions to construct the extension. Outside, it uses the lazy static reference to the extension. If I used the static reference here it would cause unbounded recursion on construction - there is probably a nicer way of doing this where that isn't a problem!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, maybe something to ponder but this LGTM.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have an issue to raise already about needing a temporary register containing types in order to be able to create validated type schemes, perhaps we can rethink the definition API as part of that.

extension.get_type(&ANGLE_TYPE_ID).unwrap()
}

pub(super) fn add_to_extension(extension: &mut Extension) {
Expand All @@ -191,25 +183,38 @@ pub(super) fn add_to_extension(extension: &mut Extension) {
)
.unwrap();

let reg1: ExtensionRegistry = [PRELUDE.to_owned(), extension.to_owned()].into();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme(
"atrunc".into(),
"truncate an angle to one with a lower log-denominator with the same value, rounding \
down in [0, 2π) if necessary"
.to_owned(),
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
atrunc_sig,
Default::default(),
vec![],
PolyFuncType::new_validated(
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
atrunc_sig(extension).unwrap(),
&reg1,
)
.unwrap(),
)
.unwrap();

extension
.add_op_custom_sig_simple(
.add_op_type_scheme(
"aconvert".into(),
"convert an angle to one with another log-denominator having the same value, if \
possible, otherwise return an error"
.to_owned(),
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
aconvert_sig,
Default::default(),
vec![],
PolyFuncType::new_validated(
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
aconvert_sig(extension).unwrap(),
&reg1,
)
.unwrap(),
)
.unwrap();

Expand All @@ -232,11 +237,17 @@ pub(super) fn add_to_extension(extension: &mut Extension) {
.unwrap();

extension
.add_op_custom_sig_simple(
.add_op_type_scheme(
"aneg".into(),
"negation of an angle".to_owned(),
vec![LOG_DENOM_TYPE_PARAM],
aunop_sig,
Default::default(),
vec![],
PolyFuncType::new_validated(
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
aunop_sig(extension).unwrap(),
&reg1,
)
.unwrap(),
)
.unwrap();
}
Expand Down Expand Up @@ -266,6 +277,13 @@ mod test {
assert_ne!(const_a32_7, const_a33_7);
assert_ne!(const_a32_7, const_a32_8);
assert_eq!(const_a32_7, ConstAngle::new(5, 7).unwrap());

assert!(const_a32_7
.check_custom_type(&super::super::angle_custom_type(5))
.is_ok());
assert!(const_a32_7
.check_custom_type(&super::super::angle_custom_type(6))
.is_err());
assert!(matches!(
ConstAngle::new(3, 256),
Err(ConstTypeError::CustomCheckFail(_))
Expand All @@ -277,5 +295,18 @@ mod test {
let const_af1 = ConstAngle::from_radians_rounding(5, 0.21874 * TAU).unwrap();
assert_eq!(const_af1.value(), 7);
assert_eq!(const_af1.log_denom(), 5);

assert!(ConstAngle::from_radians_rounding(54, 0.21874 * TAU).is_err());
}
#[test]
fn test_binop_sig() {
let sig = abinop_sig(&[type_arg(23), type_arg(42)]).unwrap();

assert_eq!(
sig,
FunctionType::new(vec![angle_type(23), angle_type(42)], vec![angle_type(42)])
);

assert!(abinop_sig(&[type_arg(23), type_arg(89)]).is_err());
}
}