Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: validate ExtensionRegistry when built, not as we build it #701

Merged
merged 14 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,37 @@ pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Extension>);

impl ExtensionRegistry {
/// Makes a new (empty) registry.
pub const fn new() -> Self {
Self(BTreeMap::new())
}

/// Gets the Extension with the given name
pub fn get(&self, name: &str) -> Option<&Extension> {
self.0.get(name)
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry::new();

impl<T: IntoIterator<Item = Extension>> From<T> for ExtensionRegistry {
fn from(value: T) -> Self {
let mut reg = Self::new();
/// Makes a new ExtensionRegistry, validating all the extensions in it
pub fn try_new(
value: impl IntoIterator<Item = Extension>,
) -> Result<Self, (ExtensionId, SignatureError)> {
let mut exts = BTreeMap::new();
for ext in value.into_iter() {
let prev = reg.0.insert(ext.name.clone(), ext);
let prev = exts.insert(ext.name.clone(), ext);
if let Some(prev) = prev {
panic!("Multiple extensions with same name: {}", prev.name)
};
}
reg
// Note this potentially asks extensions to validate themselves against other extensions that
// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
// or (given parameterized types could be cyclically dependent) at least to validate types
// before ops, but since we are not even validating types yet, this is much simpler....TOOD!
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
let res = ExtensionRegistry(exts);
for ext in res.0.values() {
ext.validate(&res).map_err(|e| (ext.name().clone(), e))?;
}
Ok(res)
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new());

/// An error that can occur in computing the signature of a node.
/// TODO: decide on failure modes
#[derive(Debug, Clone, Error, PartialEq, Eq)]
Expand Down Expand Up @@ -290,6 +294,16 @@ impl Extension {
let op_def = self.get_op(op_name).expect("Op not found.");
ExtensionOp::new(op_def.clone(), args, ext_reg)
}

// Validates against a registry, which we can assume includes this extension itself.
// (TODO deal with the registry itself containing invalid extensions!)
fn validate(&self, all_exts: &ExtensionRegistry) -> Result<(), SignatureError> {
// We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624
for op_def in self.operations.values() {
op_def.validate(all_exts)?;
}
Ok(())
}
}

impl PartialEq for Extension {
Expand Down
28 changes: 16 additions & 12 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,15 @@ impl OpDef {
SignatureFunc::CustomFunc { static_params, .. } => static_params,
}
}

pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> {
// TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
// for both type scheme and custom binary
if let SignatureFunc::TypeScheme(ts) = &self.signature_func {
ts.validate(exts, &[])?;
}
Ok(())
}
}

impl Extension {
Expand Down Expand Up @@ -356,7 +365,7 @@ mod test {

use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::prelude::USIZE_T;
use crate::extension::PRELUDE;
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::custom::ExternalOp;
use crate::ops::LeafOp;
use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME};
Expand All @@ -370,34 +379,29 @@ mod test {

#[test]
fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
let reg1 = [PRELUDE.to_owned(), EXTENSION.to_owned()].into();
let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
let mut e = Extension::new(EXT_ID);
const TP: TypeParam = TypeParam::Type(TypeBound::Any);
let list_of_var =
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
const OP_NAME: SmolStr = SmolStr::new_inline("Reverse");
let type_scheme = PolyFuncType::new_validated(
vec![TP],
FunctionType::new_endo(vec![list_of_var]),
&reg1,
)?;
let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var]));
e.add_op_type_scheme(OP_NAME, "".into(), Default::default(), vec![], type_scheme)?;
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned(), e]).unwrap();
let e = reg.get(&EXT_ID).unwrap();

let list_usize =
Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: USIZE_T }])?);
let mut dfg = DFGBuilder::new(FunctionType::new_endo(vec![list_usize]))?;
let rev = dfg.add_dataflow_op(
LeafOp::from(ExternalOp::Extension(
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], &reg1)
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], &reg)
.unwrap(),
)),
dfg.input_wires(),
)?;
dfg.finish_hugr_with_outputs(
rev.outputs(),
&[PRELUDE.to_owned(), EXTENSION.to_owned(), e].into(),
)?;
dfg.finish_hugr_with_outputs(rev.outputs(), &reg)?;

Ok(())
}
Expand Down
3 changes: 2 additions & 1 deletion src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ lazy_static! {
prelude
};
/// An extension registry containing only the prelude
pub static ref PRELUDE_REGISTRY: ExtensionRegistry = [PRELUDE_DEF.to_owned()].into();
pub static ref PRELUDE_REGISTRY: ExtensionRegistry =
ExtensionRegistry::try_new([PRELUDE_DEF.to_owned()]).unwrap();

/// Prelude extension
pub static ref PRELUDE: &'static Extension = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap();
Expand Down
4 changes: 3 additions & 1 deletion src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,9 @@ mod test {

#[test]
fn cfg() -> Result<(), Box<dyn std::error::Error>> {
let reg: ExtensionRegistry = [PRELUDE.to_owned(), collections::EXTENSION.to_owned()].into();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()])
.unwrap();
let listy = Type::new_extension(
collections::EXTENSION
.get_type(collections::LIST_TYPENAME.as_str())
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ fn invalid_types() {
TypeDefBound::Explicit(TypeBound::Any),
)
.unwrap();
let reg: ExtensionRegistry = [e, PRELUDE.to_owned()].into();
let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()]).unwrap();

let validate_to_sig_error = |t: CustomType| {
let (h, def) = identity_hugr_with_type(Type::new_extension(t));
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@
//! lazy_static! {
//! /// Quantum extension definition.
//! pub static ref EXTENSION: Extension = extension();
//! static ref REG: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into();
//! static ref REG: ExtensionRegistry =
//! ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned()]).unwrap();
//!
//! }
//! fn get_gate(gate_name: &str) -> LeafOp {
Expand Down
9 changes: 6 additions & 3 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,10 @@ mod test {
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
extension::{
prelude::{ConstUsize, USIZE_T},
ExtensionId, ExtensionSet,
ExtensionId, ExtensionRegistry, ExtensionSet, PRELUDE,
},
std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE},
std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE},
type_row,
types::test::test_registry,
types::type_param::TypeArg,
types::{CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow},
values::{
Expand All @@ -143,6 +142,10 @@ mod test {

use super::*;

fn test_registry() -> ExtensionRegistry {
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap()
}

#[test]
fn test_tuple_sum() -> Result<(), BuildError> {
use crate::builder::Container;
Expand Down
6 changes: 3 additions & 3 deletions src/ops/leaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ impl DataflowOpTrait for LeafOp {
mod test {
use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::prelude::BOOL_T;
use crate::extension::SignatureError;
use crate::extension::{prelude::USIZE_T, PRELUDE};
use crate::extension::{ExtensionRegistry, SignatureError};
use crate::hugr::ValidationError;
use crate::ops::handle::NodeHandle;
use crate::std_extensions::collections::EXTENSION;
Expand All @@ -206,7 +206,7 @@ mod test {

#[test]
fn hugr_with_type_apply() -> Result<(), Box<dyn std::error::Error>> {
let reg = [PRELUDE.to_owned(), EXTENSION.to_owned()].into();
let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap();
let pf_in = nested_func();
let pf_out = pf_in.instantiate(&[USIZE_TA], &reg)?;
let mut dfg = DFGBuilder::new(FunctionType::new(
Expand All @@ -225,7 +225,7 @@ mod test {

#[test]
fn bad_type_apply() -> Result<(), Box<dyn std::error::Error>> {
let reg = [PRELUDE.to_owned(), EXTENSION.to_owned()].into();
let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap();
let pf = nested_func();
let pf_usz = pf.instantiate_poly(&[USIZE_TA], &reg)?;
let pf_bool = pf.instantiate_poly(&[TypeArg::Type { ty: BOOL_T }], &reg)?;
Expand Down
49 changes: 17 additions & 32 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
//! Conversions between integer and floating-point values.

use crate::{
extension::{
prelude::sum_with_error, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError,
PRELUDE,
},
extension::{prelude::sum_with_error, ExtensionId, ExtensionSet},
type_row,
types::{FunctionType, PolyFuncType},
Extension,
Expand All @@ -16,62 +13,50 @@ use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");

fn ftoi_sig(temp_reg: &ExtensionRegistry) -> Result<PolyFuncType, SignatureError> {
let body = FunctionType::new(
type_row![FLOAT64_TYPE],
vec![sum_with_error(int_type_var(0))],
/// Extension for basic arithmetic operations.
pub fn extension() -> Extension {
let ftoi_sig = PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(
type_row![FLOAT64_TYPE],
vec![sum_with_error(int_type_var(0))],
),
);

PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg)
}

fn itof_sig(temp_reg: &ExtensionRegistry) -> Result<PolyFuncType, SignatureError> {
let body = FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]);

PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg)
}
let itof_sig = PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]),
);

/// Extension for basic arithmetic operations.
pub fn extension() -> Extension {
let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
]),
);
let temp_reg: ExtensionRegistry = [
super::int_types::EXTENSION.to_owned(),
super::float_types::extension(),
PRELUDE.to_owned(),
]
.into();
extension
.add_op_type_scheme_simple(
"trunc_u".into(),
"float to unsigned int".to_owned(),
ftoi_sig(&temp_reg).unwrap(),
ftoi_sig.clone(),
)
.unwrap();
extension
.add_op_type_scheme_simple(
"trunc_s".into(),
"float to signed int".to_owned(),
ftoi_sig(&temp_reg).unwrap(),
)
.add_op_type_scheme_simple("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig)
.unwrap();
extension
.add_op_type_scheme_simple(
"convert_u".into(),
"unsigned int to float".to_owned(),
itof_sig(&temp_reg).unwrap(),
itof_sig.clone(),
)
.unwrap();
extension
.add_op_type_scheme_simple(
"convert_s".into(),
"signed int to float".to_owned(),
itof_sig(&temp_reg).unwrap(),
itof_sig,
)
.unwrap();

Expand Down
Loading