Skip to content

Commit

Permalink
feat: constant folding for list operations
Browse files Browse the repository at this point in the history
+ utility enums and structs for dealing with list ops and types
  • Loading branch information
ss2165 committed Jan 9, 2024
1 parent 94be2b7 commit 9d7243d
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 42 deletions.
43 changes: 42 additions & 1 deletion src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,14 @@ mod test {
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::OpType;
use crate::std_extensions::arithmetic;
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};
use crate::std_extensions::collections::{make_list_const, ListOp, ListValue};
use crate::std_extensions::logic::{self, const_from_bool, NaryLogic};
use crate::std_extensions::{arithmetic, collections};
use crate::types::TypeArg;
use rstest::rstest;

/// int to constant
Expand Down Expand Up @@ -332,6 +334,45 @@ mod test {
Ok(())
}

#[test]
fn test_list_ops() -> Result<(), Box<dyn std::error::Error>> {
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
logic::EXTENSION.to_owned(),
collections::EXTENSION.to_owned(),
])
.unwrap();
let list = make_list_const(
ListValue::new(vec![Value::unit_sum(1)]),
&[TypeArg::Type { ty: BOOL_T }],
);
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![list.const_type().clone()],
))
.unwrap();

let list_wire = build.add_load_const(list.clone())?;

let pop = build.add_dataflow_op(
ListOp::Pop.with_type(BOOL_T).to_extension_op(&reg).unwrap(),
[list_wire],
)?;

let push = build.add_dataflow_op(
ListOp::Push
.with_type(BOOL_T)
.to_extension_op(&reg)
.unwrap(),
pop.outputs(),
)?;
let mut h = build.finish_hugr_with_outputs(push.outputs(), &reg)?;
constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &list);
Ok(())
}

fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
// check the hugr just loads and returns a single const
let mut node_count = 0;
Expand Down
224 changes: 183 additions & 41 deletions src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ use serde::{Deserialize, Serialize};
use smol_str::SmolStr;

use crate::{
extension::{ExtensionId, ExtensionSet, TypeDef, TypeDefBound},
algorithm::const_fold::sorted_consts,
extension::{
simple_op::{MakeExtensionOp, OpLoadError},
ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef,
TypeDefBound,
},
ops::{self, custom::ExtensionOp, OpName},
types::{
type_param::{TypeArg, TypeParam},
CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound,
Expand Down Expand Up @@ -47,7 +53,9 @@ impl CustomConst for ListValue {
CustomCheckFailure::Message("List type check fail.".to_string())
};

get_type(&LIST_TYPENAME)
EXTENSION
.get_type(&LIST_TYPENAME)
.unwrap()
.check_custom(typ)
.map_err(|_| error())?;

Expand All @@ -72,6 +80,55 @@ impl CustomConst for ListValue {
.union(&ExtensionSet::singleton(&EXTENSION_NAME))
}
}

struct PopFold;

impl ConstFold for PopFold {
fn fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Const)],
) -> crate::extension::ConstFoldResult {
let [TypeArg::Type { ty }] = type_args else {
return None;
};
let [list]: [&ops::Const; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
let elem = list.0.pop()?; // empty list fails to evaluate "pop"
let list = make_list_const(list, type_args);
let elem = ops::Const::new(elem, ty.clone()).unwrap();

Some(vec![(0.into(), list), (1.into(), elem)])
}
}

pub(crate) fn make_list_const(list: ListValue, type_args: &[TypeArg]) -> ops::Const {
let list_type_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
ops::Const::new(
list.into(),
Type::new_extension(list_type_def.instantiate(type_args).unwrap()),
)
.unwrap()
}

struct PushFold;

impl ConstFold for PushFold {
fn fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Const)],
) -> crate::extension::ConstFoldResult {
let [list, elem]: [&ops::Const; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
list.0.push(elem.value().clone());
let list = make_list_const(list, type_args);

Some(vec![(0.into(), list)])
}
}
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };

fn extension() -> Extension {
Expand All @@ -87,7 +144,7 @@ fn extension() -> Extension {
.unwrap();
let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap();

let (l, e) = list_and_elem_type(list_type_def);
let (l, e) = list_and_elem_type_vars(list_type_def);
extension
.add_op(
POP_NAME,
Expand All @@ -97,14 +154,17 @@ fn extension() -> Extension {
FunctionType::new(vec![l.clone()], vec![l.clone(), e.clone()]),
),
)
.unwrap();
.unwrap()
.set_constant_folder(PopFold);
extension
.add_op(
PUSH_NAME,
"Push to back of list".into(),
PolyFuncType::new(vec![TP], FunctionType::new(vec![l.clone(), e], vec![l])),
)
.unwrap();
.unwrap()
.set_constant_folder(PushFold);

extension
}

Expand All @@ -113,11 +173,18 @@ lazy_static! {
pub static ref EXTENSION: Extension = extension();
}

fn get_type(name: &str) -> &TypeDef {
EXTENSION.get_type(name).unwrap()
/// Get the type of a list of `elem_type`
pub fn list_type(elem_type: Type) -> Type {
Type::new_extension(
EXTENSION
.get_type(&LIST_TYPENAME)
.unwrap()
.instantiate(vec![TypeArg::Type { ty: elem_type }])
.unwrap(),
)
}

fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) {
fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) {
let elem_type = Type::new_var_use(0, TypeBound::Any);
let list_type = Type::new_extension(
list_type_def
Expand All @@ -126,22 +193,107 @@ fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) {
);
(list_type, elem_type)
}

/// A list operation
#[derive(Debug, Clone, PartialEq)]
pub enum ListOp {
/// Pop from end of list
Pop,
/// Push to end of list
Push,
}

impl ListOp {
/// Instantiate a list operation with an `element_type`
pub fn with_type(self, element_type: Type) -> ListOpInst {
ListOpInst {
elem_type: element_type,
op: self,
}
}
}

/// A list operation with a concrete element type.
#[derive(Debug, Clone, PartialEq)]
pub struct ListOpInst {
op: ListOp,
elem_type: Type,
}

impl OpName for ListOpInst {
fn name(&self) -> SmolStr {
match self.op {
ListOp::Pop => POP_NAME,
ListOp::Push => PUSH_NAME,
}
}
}

impl MakeExtensionOp for ListOpInst {
fn from_extension_op(
ext_op: &ExtensionOp,
) -> Result<Self, crate::extension::simple_op::OpLoadError> {
let [TypeArg::Type { ty }] = ext_op.args() else {
return Err(SignatureError::InvalidTypeArgs.into());
};
let name = ext_op.def().name();
let op = match name {
// can't use const SmolStr in pattern
_ if name == &POP_NAME => ListOp::Pop,
_ if name == &PUSH_NAME => ListOp::Push,
_ => return Err(OpLoadError::NotMember(name.to_string())),
};

Ok(Self {
elem_type: ty.clone(),
op,
})
}

fn type_args(&self) -> Vec<TypeArg> {
vec![TypeArg::Type {
ty: self.elem_type.clone(),
}]
}
}

impl ListOpInst {
/// Convert this list operation to an [`ExtensionOp`] by providing a
/// registry to validate the element tyoe against.
pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option<ExtensionOp> {
let registry = ExtensionRegistry::try_new(
elem_type_registry
.clone()
.into_iter()
// ignore self if already in registry
.filter_map(|(_, ext)| (ext.name() != EXTENSION.name()).then_some(ext))
.chain(std::iter::once(EXTENSION.to_owned())),
)
.unwrap();
ExtensionOp::new(
registry.get(&EXTENSION_NAME)?.get_op(&self.name())?.clone(),
self.type_args(),
&registry,
)
.ok()
}
}

#[cfg(test)]
mod test {
use crate::{
extension::{
prelude::{ConstUsize, QB_T, USIZE_T},
ExtensionRegistry, OpDef, PRELUDE,
ExtensionRegistry, PRELUDE,
},
ops::OpTrait,
std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE},
types::{type_param::TypeArg, Type, TypeRow},
types::{type_param::TypeArg, TypeRow},
Extension,
};

use super::*;
fn get_op(name: &str) -> &OpDef {
EXTENSION.get_op(name).unwrap()
}

#[test]
fn test_extension() {
let r: Extension = extension();
Expand Down Expand Up @@ -174,40 +326,30 @@ mod test {

#[test]
fn test_list_ops() {
let reg = ExtensionRegistry::try_new([
EXTENSION.to_owned(),
PRELUDE.to_owned(),
float_types::EXTENSION.to_owned(),
])
.unwrap();
let pop_sig = get_op(&POP_NAME)
.compute_signature(&[TypeArg::Type { ty: QB_T }], &reg)
.unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()])
.unwrap();
let pop_op = ListOp::Pop.with_type(QB_T);
let pop_ext = pop_op.clone().to_extension_op(&reg).unwrap();
assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op);
let pop_sig = pop_ext.dataflow_signature().unwrap();

let list_type = Type::new_extension(CustomType::new(
LIST_TYPENAME,
vec![TypeArg::Type { ty: QB_T }],
EXTENSION_NAME,
TypeBound::Any,
));
let list_t = list_type(QB_T);

let both_row: TypeRow = vec![list_type.clone(), QB_T].into();
let just_list_row: TypeRow = vec![list_type].into();
let both_row: TypeRow = vec![list_t.clone(), QB_T].into();
let just_list_row: TypeRow = vec![list_t].into();
assert_eq!(pop_sig.input(), &just_list_row);
assert_eq!(pop_sig.output(), &both_row);

let push_sig = get_op(&PUSH_NAME)
.compute_signature(&[TypeArg::Type { ty: FLOAT64_TYPE }], &reg)
.unwrap();
let push_op = ListOp::Push.with_type(FLOAT64_TYPE);
let push_ext = push_op.clone().to_extension_op(&reg).unwrap();
assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op);
let push_sig = push_ext.dataflow_signature().unwrap();

let list_type = Type::new_extension(CustomType::new(
LIST_TYPENAME,
vec![TypeArg::Type { ty: FLOAT64_TYPE }],
EXTENSION_NAME,
TypeBound::Copyable,
));
let both_row: TypeRow = vec![list_type.clone(), FLOAT64_TYPE].into();
let just_list_row: TypeRow = vec![list_type].into();
let list_t = list_type(FLOAT64_TYPE);

let both_row: TypeRow = vec![list_t.clone(), FLOAT64_TYPE].into();
let just_list_row: TypeRow = vec![list_t].into();

assert_eq!(push_sig.input(), &both_row);
assert_eq!(push_sig.output(), &just_list_row);
Expand Down

0 comments on commit 9d7243d

Please sign in to comment.