Skip to content

Commit

Permalink
refactor: use type schemes in extension definitions wherever possible (
Browse files Browse the repository at this point in the history
…#678)

Closes #658 

As a drive by improve extensions test coverage as I went.

Definition can be a bit unintuitive, should improve after #676 is done.
  • Loading branch information
ss2165 authored Nov 13, 2023
1 parent d32d033 commit 201d1a2
Show file tree
Hide file tree
Showing 11 changed files with 387 additions and 373 deletions.
17 changes: 17 additions & 0 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,23 @@ impl Extension {
SignatureFunc::TypeScheme(type_scheme),
)
}

/// Create an OpDef with a signature (inputs+outputs) read from e.g.
/// declarative YAML; and no "misc" or "lowering functions" defined.
pub fn add_op_type_scheme_simple(
&mut self,
name: SmolStr,
description: String,
type_scheme: PolyFuncType,
) -> Result<&OpDef, ExtensionBuildError> {
self.add_op(
name,
description,
Default::default(),
vec![],
SignatureFunc::TypeScheme(type_scheme),
)
}
}

#[cfg(test)]
Expand Down
5 changes: 5 additions & 0 deletions src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple(
/// The string name of the error type.
pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error");

/// Return a Sum type with the first variant as the given type and the second an Error.
pub fn sum_with_error(ty: Type) -> Type {
Type::new_sum(vec![ty, ERROR_TYPE])
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
pub struct ConstUsize(u64);
Expand Down
10 changes: 5 additions & 5 deletions src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::macros::const_extension_ids;
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, LeafOp, OpType};
use crate::std_extensions::logic;
use crate::std_extensions::logic::test::{and_op, not_op};
use crate::std_extensions::logic::test::{and_op, not_op, or_op};
use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
use crate::types::{CustomType, FunctionType, Type, TypeBound, TypeRow};
use crate::{type_row, Direction, IncomingPort, Node};
Expand Down Expand Up @@ -612,12 +612,12 @@ fn dfg_with_cycles() -> Result<(), HugrError> {
type_row![BOOL_T],
));
let [input, output] = h.get_io(h.root()).unwrap();
let and = h.add_node_with_parent(h.root(), and_op())?;
let or = h.add_node_with_parent(h.root(), or_op())?;
let not1 = h.add_node_with_parent(h.root(), not_op())?;
let not2 = h.add_node_with_parent(h.root(), not_op())?;
h.connect(input, 0, and, 0)?;
h.connect(and, 0, not1, 0)?;
h.connect(not1, 0, and, 1)?;
h.connect(input, 0, or, 0)?;
h.connect(or, 0, not1, 0)?;
h.connect(not1, 0, or, 1)?;
h.connect(input, 1, not2, 0)?;
h.connect(not2, 0, output, 0)?;
// The graph contains a cycle:
Expand Down
61 changes: 30 additions & 31 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,34 @@
//! Conversions between integer and floating-point values.
use crate::{
extension::{ExtensionId, ExtensionSet, SignatureError},
extension::{
prelude::sum_with_error, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError,
PRELUDE,
},
type_row,
types::{type_param::TypeArg, FunctionType, Type},
utils::collect_array,
types::{FunctionType, PolyFuncType},
Extension,
};

use super::int_types::int_type;
use super::int_types::int_type_var;
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};

/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");

fn ftoi_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok(FunctionType::new(
fn ftoi_sig(temp_reg: &ExtensionRegistry) -> Result<PolyFuncType, SignatureError> {
let body = FunctionType::new(
type_row![FLOAT64_TYPE],
vec![Type::new_sum(vec![
int_type(arg.clone()),
crate::extension::prelude::ERROR_TYPE,
])],
))
vec![sum_with_error(int_type_var(0))],
);

PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg)
}

fn itof_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok(FunctionType::new(
vec![int_type(arg.clone())],
type_row![FLOAT64_TYPE],
))
fn itof_sig(temp_reg: &ExtensionRegistry) -> Result<PolyFuncType, SignatureError> {
let body = FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]);

PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg)
}

/// Extension for basic arithmetic operations.
Expand All @@ -42,37 +40,38 @@ pub fn extension() -> Extension {
super::float_types::EXTENSION_ID,
]),
);

let temp_reg: ExtensionRegistry = [
super::int_types::EXTENSION.to_owned(),
super::float_types::extension(),
PRELUDE.to_owned(),
]
.into();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"trunc_u".into(),
"float to unsigned int".to_owned(),
vec![LOG_WIDTH_TYPE_PARAM],
ftoi_sig,
ftoi_sig(&temp_reg).unwrap(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"trunc_s".into(),
"float to signed int".to_owned(),
vec![LOG_WIDTH_TYPE_PARAM],
ftoi_sig,
ftoi_sig(&temp_reg).unwrap(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"convert_u".into(),
"unsigned int to float".to_owned(),
vec![LOG_WIDTH_TYPE_PARAM],
itof_sig,
itof_sig(&temp_reg).unwrap(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"convert_s".into(),
"signed int to float".to_owned(),
vec![LOG_WIDTH_TYPE_PARAM],
itof_sig,
itof_sig(&temp_reg).unwrap(),
)
.unwrap();

Expand Down
79 changes: 31 additions & 48 deletions src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Basic floating-point operations.
use crate::{
extension::{ExtensionId, ExtensionSet, SignatureError},
extension::{ExtensionId, ExtensionSet},
type_row,
types::{type_param::TypeArg, FunctionType},
types::{FunctionType, PolyFuncType},
Extension,
};

Expand All @@ -12,106 +12,89 @@ use super::float_types::FLOAT64_TYPE;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");

fn fcmp_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE; 2],
type_row![crate::extension::prelude::BOOL_T],
))
}

fn fbinop_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE; 2],
type_row![FLOAT64_TYPE],
))
}

fn funop_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE],
type_row![FLOAT64_TYPE],
))
}

/// Extension for basic arithmetic operations.
pub fn extension() -> Extension {
let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::singleton(&super::float_types::EXTENSION_ID),
);

let fcmp_sig: PolyFuncType = FunctionType::new(
type_row![FLOAT64_TYPE; 2],
type_row![crate::extension::prelude::BOOL_T],
)
.into();
let fbinop_sig: PolyFuncType =
FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into();
let funop_sig: PolyFuncType =
FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into();
extension
.add_op_custom_sig_simple("feq".into(), "equality test".to_owned(), vec![], fcmp_sig)
.add_op_type_scheme_simple("feq".into(), "equality test".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fne".into(), "inequality test".to_owned(), vec![], fcmp_sig)
.add_op_type_scheme_simple("fne".into(), "inequality test".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("flt".into(), "\"less than\"".to_owned(), vec![], fcmp_sig)
.add_op_type_scheme_simple("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fgt".into(),
"\"greater than\"".to_owned(),
vec![],
fcmp_sig,
fcmp_sig.clone(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fle".into(),
"\"less than or equal\"".to_owned(),
vec![],
fcmp_sig,
fcmp_sig.clone(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fge".into(),
"\"greater than or equal\"".to_owned(),
vec![],
fcmp_sig,
)
.unwrap();
extension
.add_op_custom_sig_simple("fmax".into(), "maximum".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fmax".into(), "maximum".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fmin".into(), "minimum".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fmin".into(), "minimum".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fadd".into(), "addition".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fadd".into(), "addition".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fsub".into(), "subtraction".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fneg".into(), "negation".to_owned(), vec![], funop_sig)
.add_op_type_scheme_simple("fneg".into(), "negation".to_owned(), funop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fabs".into(),
"absolute value".to_owned(),
vec![],
funop_sig,
funop_sig.clone(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fmul".into(),
"multiplication".to_owned(),
vec![],
fbinop_sig,
fbinop_sig.clone(),
)
.unwrap();
extension
.add_op_custom_sig_simple("fdiv".into(), "division".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fdiv".into(), "division".to_owned(), fbinop_sig)
.unwrap();
extension
.add_op_custom_sig_simple("ffloor".into(), "floor".to_owned(), vec![], funop_sig)
.add_op_type_scheme_simple("ffloor".into(), "floor".to_owned(), funop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fceil".into(), "ceiling".to_owned(), vec![], funop_sig)
.add_op_type_scheme_simple("fceil".into(), "ceiling".to_owned(), funop_sig)
.unwrap();

extension
Expand Down
5 changes: 5 additions & 0 deletions src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ mod test {
fn test_float_consts() {
let const_f64_1 = ConstF64::new(1.0);
let const_f64_2 = ConstF64::new(2.0);

assert_eq!(const_f64_1.value(), 1.0);
assert_eq!(*const_f64_2, 2.0);
assert_eq!(const_f64_1.name(), "f64(1)");
assert!(const_f64_1.equal_consts(&ConstF64::new(1.0)));
assert_ne!(const_f64_1, const_f64_2);
assert_eq!(const_f64_1, ConstF64::new(1.0));
}
Expand Down
Loading

0 comments on commit 201d1a2

Please sign in to comment.