Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Replace SmolStr identifiers with wrapper types. #959

Merged
merged 7 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hugr/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ use thiserror::Error;
use crate::extension::SignatureError;
use crate::hugr::ValidationError;
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::ops::{OpName, OpType};
use crate::ops::{NamedOp, OpType};
use crate::types::ConstTypeError;
use crate::types::Type;
use crate::{Node, Port, Wire};
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::mem;

use thiserror::Error;

use crate::ops::{OpName, OpType};
use crate::ops::{NamedOp, OpType};
use crate::utils::collect_array;

use super::{BuildError, Dataflow};
Expand Down
51 changes: 30 additions & 21 deletions hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;

use smol_str::SmolStr;
use thiserror::Error;

use crate::hugr::IdentList;
use crate::ops;
use crate::ops::constant::{ValueName, ValueNameRef};
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{self, OpName, OpNameRef};
use crate::types::type_param::{check_type_args, TypeArgError};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::FunctionType;
use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName};
use crate::types::{FunctionType, TypeNameRef};

#[allow(dead_code)]
mod infer;
Expand Down Expand Up @@ -177,18 +177,22 @@ pub enum SignatureError {

/// Concrete instantiations of types and operations defined in extensions.
trait CustomConcrete {
/// The identifier type for the concrete object.
type Identifier;
Comment on lines 179 to +181
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left this associated type even though it always gets instantiated with SmolStr.

It doesn't change the way we use this trait but documents a bit better the behaviour.

/// A generic identifier to the element.
///
/// This may either refer to a [`TypeName`] or an [`OpName`].
fn def_name(&self) -> &SmolStr;
fn def_name(&self) -> &Self::Identifier;
/// The concrete type arguments for the instantiation.
fn type_args(&self) -> &[TypeArg];
/// Extension required by the instantiation.
fn parent_extension(&self) -> &ExtensionId;
}

impl CustomConcrete for OpaqueOp {
fn def_name(&self) -> &SmolStr {
type Identifier = OpName;

fn def_name(&self) -> &OpName {
self.name()
}

Expand All @@ -202,7 +206,9 @@ impl CustomConcrete for OpaqueOp {
}

impl CustomConcrete for CustomType {
fn def_name(&self) -> &SmolStr {
type Identifier = TypeName;

fn def_name(&self) -> &TypeName {
// Casts the `TypeName` to a generic string.
self.name()
}
Expand All @@ -221,7 +227,7 @@ trait TypeParametrised {
/// The concrete object built by binding type arguments to parameters
type Concrete: CustomConcrete;
/// The extension-unique name.
fn name(&self) -> &SmolStr;
fn name(&self) -> &<Self::Concrete as CustomConcrete>::Identifier;
/// Type parameters.
fn params(&self) -> &[TypeParam];
/// The parent extension.
Expand All @@ -237,7 +243,7 @@ trait TypeParametrised {
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ExtensionValue {
extension: ExtensionId,
name: SmolStr,
name: ValueName,
typed_value: ops::Value,
}

Expand All @@ -249,7 +255,7 @@ impl ExtensionValue {

/// Returns a reference to the name of this [`ExtensionValue`].
pub fn name(&self) -> &str {
self.name.as_ref()
self.name.as_str()
}

/// Returns a reference to the extension this [`ExtensionValue`] belongs to.
Expand All @@ -276,14 +282,14 @@ pub struct Extension {
/// Types defined by this extension.
types: HashMap<TypeName, TypeDef>,
/// Static values defined by this extension.
values: HashMap<SmolStr, ExtensionValue>,
values: HashMap<ValueName, ExtensionValue>,
/// Operation declarations with serializable definitions.
// Note: serde will serialize this because we configure with `features=["rc"]`.
// That will clone anything that has multiple references, but each
// OpDef should appear exactly once in this map (keyed by its name),
// and the other references to the OpDef are from ExternalOp's in the Hugr
// (which are serialized as OpaqueOp's i.e. Strings).
operations: HashMap<SmolStr, Arc<op_def::OpDef>>,
operations: HashMap<OpName, Arc<op_def::OpDef>>,
}

impl Extension {
Expand All @@ -304,18 +310,18 @@ impl Extension {
}

/// Allows read-only access to the operations in this Extension
pub fn get_op(&self, op_name: &str) -> Option<&Arc<op_def::OpDef>> {
pub fn get_op(&self, op_name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
self.operations.get(op_name)
}

/// Allows read-only access to the types in this Extension
pub fn get_type(&self, type_name: &str) -> Option<&type_def::TypeDef> {
pub fn get_type(&self, type_name: &TypeNameRef) -> Option<&type_def::TypeDef> {
self.types.get(type_name)
}

/// Allows read-only access to the values in this Extension
pub fn get_value(&self, type_name: &str) -> Option<&ExtensionValue> {
self.values.get(type_name)
pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> {
self.values.get(value_name)
}

/// Returns the name of the extension.
Expand All @@ -324,7 +330,7 @@ impl Extension {
}

/// Iterator over the operations of this [`Extension`].
pub fn operations(&self) -> impl Iterator<Item = (&SmolStr, &Arc<OpDef>)> {
pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
self.operations.iter()
}

Expand All @@ -336,7 +342,7 @@ impl Extension {
/// Add a named static value to the extension.
pub fn add_value(
&mut self,
name: impl Into<SmolStr>,
name: impl Into<ValueName>,
typed_value: ops::Value,
) -> Result<&mut ExtensionValue, ExtensionBuildError> {
let extension_value = ExtensionValue {
Expand All @@ -346,7 +352,7 @@ impl Extension {
};
match self.values.entry(extension_value.name.clone()) {
hash_map::Entry::Occupied(_) => {
Err(ExtensionBuildError::OpDefExists(extension_value.name))
Err(ExtensionBuildError::ValueExists(extension_value.name))
}
hash_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)),
}
Expand All @@ -355,7 +361,7 @@ impl Extension {
/// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension.
pub fn instantiate_extension_op(
&self,
op_name: &str,
op_name: &OpNameRef,
args: impl Into<Vec<TypeArg>>,
ext_reg: &ExtensionRegistry,
) -> Result<ExtensionOp, SignatureError> {
Expand Down Expand Up @@ -396,10 +402,13 @@ pub enum ExtensionRegistryError {
pub enum ExtensionBuildError {
/// Existing [`OpDef`]
#[error("Extension already has an op called {0}.")]
OpDefExists(SmolStr),
OpDefExists(OpName),
/// Existing [`TypeDef`]
#[error("Extension already has an type called {0}.")]
TypeDefExists(SmolStr),
TypeDefExists(TypeName),
/// Existing [`ExtensionValue`]
#[error("Extension already has an extension value called {0}.")]
ValueExists(ValueName),
}

/// A set of extensions identified by their unique [`ExtensionId`].
Expand Down
7 changes: 4 additions & 3 deletions hugr/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use std::fs::File;
use std::path::Path;

use crate::extension::prelude::PRELUDE_ID;
use crate::ops::OpName;
use crate::types::TypeName;
use crate::Extension;

Expand Down Expand Up @@ -221,7 +222,7 @@ pub enum ExtensionDeclarationError {
/// The extension that referenced the unsupported op parameter.
ext: ExtensionId,
/// The operation.
op: SmolStr,
op: OpName,
},
/// Operation definitions with no signature are not currently supported.
///
Expand All @@ -233,7 +234,7 @@ pub enum ExtensionDeclarationError {
/// The extension containing the operation.
ext: ExtensionId,
/// The operation with no signature.
op: SmolStr,
op: OpName,
},
/// An unknown type was specified in a signature.
#[error("Type {ty} is not in scope. In extension {ext}.")]
Expand Down Expand Up @@ -261,7 +262,7 @@ pub enum ExtensionDeclarationError {
/// The extension.
ext: crate::hugr::IdentList,
/// The operation with the lowering definition.
op: SmolStr,
op: OpName,
},
}

Expand Down
3 changes: 2 additions & 1 deletion hugr/src/extension/declarative/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize};
use smol_str::SmolStr;

use crate::extension::{OpDef, SignatureFunc};
use crate::ops::OpName;
use crate::types::type_param::TypeParam;
use crate::Extension;

Expand All @@ -25,7 +26,7 @@ use super::{DeclarationContext, ExtensionDeclarationError};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(super) struct OperationDeclaration {
/// The identifier the operation.
name: SmolStr,
name: OpName,
/// A description for the operation.
#[serde(default)]
#[serde(skip_serializing_if = "crate::utils::is_default")]
Expand Down
8 changes: 2 additions & 6 deletions hugr/src/extension/declarative/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,13 @@ impl TypeDeclaration {
}

// Try to resolve the type in the current extension.
if let Some(ty) = ext.get_type(self.0.as_str()) {
if let Some(ty) = ext.get_type(&self.0) {
return Some(ty);
}

// Try to resolve the type in the other extensions in scope.
for ext in ctx.scope.iter() {
if let Some(ty) = ctx
.registry
.get(ext)
.and_then(|ext| ext.get_type(self.0.as_str()))
{
if let Some(ty) = ctx.registry.get(ext).and_then(|ext| ext.get_type(&self.0)) {
return Some(ty);
}
}
Expand Down
17 changes: 7 additions & 10 deletions hugr/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use smol_str::SmolStr;

use super::{
ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry,
ExtensionSet, SignatureError,
};

use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType};
use crate::Hugr;
Expand Down Expand Up @@ -108,7 +107,7 @@ pub trait CustomLowerFunc: Send + Sync {
/// TODO: some error type to indicate Extensions required?
fn try_lower(
&self,
name: &SmolStr,
name: &OpNameRef,
arg_values: &[TypeArg],
misc: &HashMap<String, serde_yaml::Value>,
available_extensions: &ExtensionSet,
Expand Down Expand Up @@ -295,7 +294,7 @@ pub struct OpDef {
extension: ExtensionId,
/// Unique identifier of the operation. Used to look up OpDefs in the registry
/// when deserializing nodes (which store only the name).
name: SmolStr,
name: OpName,
/// Human readable description of the operation.
description: String,
/// Miscellaneous data associated with the operation.
Expand Down Expand Up @@ -376,7 +375,7 @@ impl OpDef {
}

/// Returns a reference to the name of this [`OpDef`].
pub fn name(&self) -> &SmolStr {
pub fn name(&self) -> &OpName {
&self.name
}

Expand Down Expand Up @@ -442,7 +441,7 @@ impl Extension {
/// function for computing the signature given type arguments (`impl [CustomSignatureFunc]`).
pub fn add_op(
&mut self,
name: SmolStr,
name: OpName,
description: String,
signature_func: impl Into<SignatureFunc>,
) -> Result<&mut OpDef, ExtensionBuildError> {
Expand All @@ -468,15 +467,13 @@ impl Extension {
mod test {
use std::num::NonZeroU64;

use smol_str::SmolStr;

use super::SignatureFromArgs;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::op_def::LowerFunc;
use crate::extension::prelude::USIZE_T;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY};
use crate::ops::CustomOp;
use crate::ops::{CustomOp, OpName};
use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME};
use crate::types::Type;
use crate::types::{type_param::TypeParam, FunctionType, PolyFuncType, TypeArg, TypeBound};
Expand All @@ -494,7 +491,7 @@ mod test {
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
let list_of_var =
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
const OP_NAME: SmolStr = SmolStr::new_inline("Reverse");
const OP_NAME: OpName = OpName::new_inline("Reverse");
let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var]));

let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?;
Expand Down
Loading