Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove Signature struct #714

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,8 @@ pub(crate) mod test {
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&abc_extensions);
let mut parent = DFGBuilder::new(parent_sig)?;

let add_c_sig = FunctionType::new(type_row![BIT], type_row![BIT])
.with_extension_delta(&c_extensions)
.with_input_extensions(ab_extensions.clone());
let add_c_sig =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&c_extensions);

let [w] = parent.input_wires_arr();

Expand Down Expand Up @@ -476,8 +475,7 @@ pub(crate) mod test {

// Add another node (a sibling to add_ab) which adds extension C
// via a child lift node
let mut add_c =
parent.dfg_builder(add_c_sig.signature, Some(add_c_sig.input_extensions), [w])?;
let mut add_c = parent.dfg_builder(add_c_sig, Some(ab_extensions.clone()), [w])?;
let [w] = add_c.input_wires_arr();
let lift_c = add_c.add_dataflow_node(
NodeType::new(
Expand Down
39 changes: 11 additions & 28 deletions src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,10 @@ fn missing_lift_node() -> Result<(), BuildError> {
)?;
let [main_input] = main.input_wires_arr();

let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT])
// Inner DFG has extension requirements that the wire wont satisfy
.with_input_extensions(ExtensionSet::from_iter([XA, XB]));

let f_builder = main.dfg_builder(
inner_sig.signature,
Some(inner_sig.input_extensions),
FunctionType::new(type_row![NAT], type_row![NAT]),
// Inner DFG has extension requirements that the wire wont satisfy
Some(ExtensionSet::from_iter([XA, XB])),
[main_input],
)?;
let f_inputs = f_builder.input_wires();
Expand Down Expand Up @@ -491,14 +488,9 @@ fn too_many_extension() -> Result<(), BuildError> {
let [main_input] = main.input_wires_arr();

let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&XA))
.with_input_extensions(ExtensionSet::new());
.with_extension_delta(&ExtensionSet::singleton(&XA));

let f_builder = main.dfg_builder(
inner_sig.signature,
Some(inner_sig.input_extensions),
[main_input],
)?;
let f_builder = main.dfg_builder(inner_sig, Some(ExtensionSet::new()), [main_input])?;
let f_inputs = f_builder.input_wires();
let f_handle = f_builder.finish_with_outputs(f_inputs)?;
let [f_output] = f_handle.outputs_arr();
Expand Down Expand Up @@ -529,36 +521,27 @@ fn extensions_mismatch() -> Result<(), BuildError> {

let mut main = module_builder.define_function("main", main_sig)?;

let inner_left_sig = FunctionType::new(type_row![], type_row![NAT])
.with_input_extensions(ExtensionSet::singleton(&XA));

let inner_right_sig = FunctionType::new(type_row![], type_row![NAT])
.with_input_extensions(ExtensionSet::singleton(&XB));

let inner_mult_sig =
FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_input_extensions(all_rs);

let [left_wire] = main
.dfg_builder(
inner_left_sig.signature,
Some(inner_left_sig.input_extensions),
FunctionType::new(type_row![], type_row![NAT]),
Some(ExtensionSet::singleton(&XA)),
[],
)?
.finish_with_outputs([])?
.outputs_arr();

let [right_wire] = main
.dfg_builder(
inner_right_sig.signature,
Some(inner_right_sig.input_extensions),
FunctionType::new(type_row![], type_row![NAT]),
Some(ExtensionSet::singleton(&XB)),
[],
)?
.finish_with_outputs([])?
.outputs_arr();

let builder = main.dfg_builder(
inner_mult_sig.signature,
Some(inner_mult_sig.input_extensions),
FunctionType::new(type_row![NAT, NAT], type_row![NAT]),
Some(all_rs),
[left_wire, right_wire],
)?;
let [_left, _right] = builder.input_wires_arr();
Expand Down
2 changes: 1 addition & 1 deletion src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub(crate) use impl_box_clone;
/// const U: Type = Type::UNIT;
/// let static_row: TypeRow = type_row![U, U];
/// let dynamic_row: TypeRow = vec![U, U, U].into();
/// let sig = FunctionType::new(static_row, dynamic_row).pure();
/// let sig = FunctionType::new(static_row, dynamic_row);
///
/// let repeated_row: TypeRow = type_row![U; 3];
/// assert_eq!(repeated_row, *sig.output());
Expand Down
2 changes: 1 addition & 1 deletion src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod type_row;
pub use check::{ConstTypeError, CustomCheckFailure};
pub use custom::CustomType;
pub use poly_func::PolyFuncType;
pub use signature::{FunctionType, Signature};
pub use signature::FunctionType;
pub use type_param::TypeArg;
pub use type_row::TypeRow;

Expand Down
82 changes: 1 addition & 81 deletions src/types/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use itertools::Either;

use delegate::delegate;
use std::fmt::{self, Display, Write};

use super::type_param::TypeParam;
Expand All @@ -23,35 +22,13 @@ pub struct FunctionType {
pub extension_reqs: ExtensionSet,
}

#[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
/// A combination of a FunctionType and a set of input extensions, used for declaring functions
pub struct Signature {
/// The underlying signature
pub signature: FunctionType,
/// The extensions which are associated with all the inputs and carried through
pub input_extensions: ExtensionSet,
}

impl FunctionType {
/// Builder method, add extension_reqs to an FunctionType
pub fn with_extension_delta(mut self, rs: &ExtensionSet) -> Self {
self.extension_reqs = self.extension_reqs.union(rs);
self
}

/// Instantiate an FunctionType, converting it to a concrete one
pub fn with_input_extensions(self, es: ExtensionSet) -> Signature {
Signature {
signature: self,
input_extensions: es,
}
}

/// Instantiate a signature with the empty set of extensions
pub fn pure(self) -> Signature {
self.with_input_extensions(ExtensionSet::new())
}

pub(crate) fn validate(
&self,
extension_registry: &ExtensionRegistry,
Expand All @@ -73,21 +50,6 @@ impl FunctionType {
}
}

impl From<Signature> for FunctionType {
fn from(sig: Signature) -> Self {
sig.signature
}
}

impl Signature {
/// Calculate the extension requirements of the output wires
pub fn output_extensions(&self) -> ExtensionSet {
self.input_extensions
.clone()
.union(&self.signature.extension_reqs)
}
}

impl FunctionType {
/// The number of wires in the signature.
#[inline(always)]
Expand Down Expand Up @@ -239,17 +201,6 @@ impl FunctionType {
}
}

impl Signature {
delegate! {
to self.signature {
/// Inputs of the function type
pub fn input(&self) -> &TypeRow;
/// Outputs of the function type
pub fn output(&self) -> &TypeRow;
}
}
}

impl Display for FunctionType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.input.is_empty() {
Expand All @@ -263,20 +214,9 @@ impl Display for FunctionType {
}
}

impl Display for Signature {
delegate! {
to self.signature {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result;
}
}
}

#[cfg(test)]
mod test {
use crate::{
extension::{prelude::USIZE_T, ExtensionId},
type_row,
};
use crate::{extension::prelude::USIZE_T, type_row};

use super::*;
#[test]
Expand All @@ -300,24 +240,4 @@ mod test {
assert_eq!(f_type.input_types(), &[Type::UNIT]);
assert_eq!(f_type.output_types(), &[USIZE_T]);
}

#[test]
fn test_signature() {
let f_type = FunctionType::new(type_row![Type::UNIT], type_row![USIZE_T]);

let sig: Signature = f_type.pure();

assert_eq!(sig.input(), &type_row![Type::UNIT]);
assert_eq!(sig.output(), &type_row![USIZE_T]);
}

#[test]
fn test_display() {
let f_type = FunctionType::new(type_row![Type::UNIT], type_row![USIZE_T]);
assert_eq!(f_type.to_string(), "[Tuple([])] -> [[]][usize([])]");
let sig: Signature = f_type.with_input_extensions(ExtensionSet::singleton(
&ExtensionId::new("Example").unwrap(),
));
assert_eq!(sig.to_string(), "[Tuple([])] -> [[]][usize([])]");
}
}