Skip to content

Commit

Permalink
add HUGR payload to PrimValue::Function (#431)
Browse files Browse the repository at this point in the history
add `get_function_type` method to HugrView to report the function type
of the root if valid.

Closes #374
  • Loading branch information
ss2165 authored Aug 22, 2023
1 parent caacecf commit 5ae71ed
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 24 deletions.
21 changes: 17 additions & 4 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,17 @@ impl From<BuildError> for PyErr {
}

#[cfg(test)]
mod test {
use crate::types::{Signature, Type};
use crate::Hugr;
pub(crate) mod test {
use rstest::fixture;

use crate::types::{FunctionType, Signature, Type};
use crate::{type_row, Hugr};

use super::handle::BuildHandle;
use super::{BuildError, Container, FuncID, FunctionBuilder, ModuleBuilder};
use super::{
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, FuncID, FunctionBuilder,
ModuleBuilder,
};
use super::{DataflowSubContainer, HugrBuilder};

pub(super) const NAT: Type = crate::extension::prelude::USIZE_T;
Expand All @@ -117,4 +122,12 @@ mod test {
f(f_builder)?;
Ok(module_builder.finish_hugr()?)
}

#[fixture]
pub(crate) fn simple_dfg_hugr() -> Hugr {
let dfg_builder =
DFGBuilder::new(FunctionType::new(type_row![BIT], type_row![BIT])).unwrap();
let [i1] = dfg_builder.input_wires_arr();
dfg_builder.finish_hugr_with_outputs([i1]).unwrap()
}
}
19 changes: 7 additions & 12 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,9 @@ impl<T> HugrBuilder for DFGWrapper<Hugr, T> {
}

#[cfg(test)]
mod test {
pub(crate) mod test {
use cool_asserts::assert_matches;
use rstest::rstest;
use serde_json::json;

use crate::builder::build_traits::DataflowHugr;
Expand All @@ -233,6 +234,7 @@ mod test {
type_row, Wire,
};

use super::super::test::simple_dfg_hugr;
use super::*;
#[test]
fn nested_identity() -> Result<(), BuildError> {
Expand Down Expand Up @@ -392,17 +394,10 @@ mod test {
Ok(())
}

#[test]
fn dfg_hugr() -> Result<(), BuildError> {
let dfg_builder = DFGBuilder::new(FunctionType::new(type_row![BIT], type_row![BIT]))?;

let [i1] = dfg_builder.input_wires_arr();
let hugr = dfg_builder.finish_hugr_with_outputs([i1])?;

assert_eq!(hugr.node_count(), 3);
assert_matches!(hugr.root_type().tag(), OpTag::Dfg);

Ok(())
#[rstest]
fn dfg_hugr(simple_dfg_hugr: Hugr) {
assert_eq!(simple_dfg_hugr.node_count(), 3);
assert_matches!(simple_dfg_hugr.root_type().tag(), OpTag::Dfg);
}

#[test]
Expand Down
17 changes: 15 additions & 2 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};

use super::{Hugr, NodeMetadata, NodeType};
use crate::ops::handle::NodeHandle;
use crate::ops::{OpName, OpTag, OpType};
use crate::types::EdgeKind;
use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpType, DFG};
use crate::types::{EdgeKind, FunctionType};
use crate::Direction;
use crate::{Node, Port};

Expand Down Expand Up @@ -154,6 +154,10 @@ pub trait HugrView: sealed::HugrInternals {
/// If the node isn't a dataflow parent, then return None
fn get_io(&self, node: Node) -> Option<[Node; 2]>;

/// For function-like HUGRs (DFG, FuncDefn, FuncDecl), report the function
/// type. Otherwise return None.
fn get_function_type(&self) -> Option<&FunctionType>;

/// Return dot string showing underlying graph and hierarchy side by side.
fn dot_string(&self) -> String {
let hugr = self.base_hugr();
Expand Down Expand Up @@ -317,6 +321,15 @@ where
}
}

fn get_function_type(&self) -> Option<&FunctionType> {
let op = self.get_nodetype(self.root());
match &op.op {
OpType::DFG(DFG { signature })
| OpType::FuncDecl(FuncDecl { signature, .. })
| OpType::FuncDefn(FuncDefn { signature, .. }) => Some(signature),
_ => None,
}
}
#[inline]
fn get_metadata(&self, node: Node) -> &NodeMetadata {
self.as_ref().metadata.get(node.index)
Expand Down
8 changes: 8 additions & 0 deletions src/hugr/views/hierarchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ where
None
}
}

fn get_function_type(&self) -> Option<&crate::types::FunctionType> {
self.base_hugr().get_function_type()
}
}

type RegionGraph<'g, Base> = portgraph::view::Region<'g, <Base as HugrInternals>::Portgraph>;
Expand Down Expand Up @@ -358,6 +362,10 @@ where
fn get_io(&self, node: Node) -> Option<[Node; 2]> {
self.base_hugr().get_io(node)
}

fn get_function_type(&self) -> Option<&crate::types::FunctionType> {
self.base_hugr().get_function_type()
}
}

/// A common trait for views of a HUGR hierarchical subgraph.
Expand Down
13 changes: 11 additions & 2 deletions src/types/check.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! Logic for checking values against types.
use thiserror::Error;

use crate::values::{PrimValue, Value};
use crate::{
values::{PrimValue, Value},
HugrView,
};

use super::{primitive::PrimType, CustomType, Type, TypeEnum};

Expand Down Expand Up @@ -56,7 +59,13 @@ impl PrimType {
e_val.0.check_custom_type(e)?;
Ok(())
}
(PrimType::Function(_), PrimValue::Function) => todo!(),
(PrimType::Function(t), PrimValue::Function(v))
if Some(t.as_ref()) == v.get_function_type() =>
{
// exact signature equality, in future this may need to be
// relaxed to be compatibility checks between the signatures.
Ok(())
}
_ => Err(ConstTypeError::ValueCheckFail(
Type::new(TypeEnum::Prim(self.clone())),
Value::Prim(val.clone()),
Expand Down
29 changes: 25 additions & 4 deletions src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use downcast_rs::{impl_downcast, Downcast};
use smol_str::SmolStr;

use crate::macros::impl_box_clone;
use crate::{Hugr, HugrView};

use crate::types::{CustomCheckFailure, CustomType};

Expand All @@ -20,15 +21,20 @@ pub enum PrimValue {
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Extension((Box<dyn CustomConst>,)),
/// A higher-order function value.
// TODO add HUGR<DFG> payload
Function,
// TODO use a root parametrised hugr, e.g. Hugr<DFG>.
Function(Box<Hugr>),
}

impl PrimValue {
fn name(&self) -> String {
match self {
PrimValue::Extension(e) => format!("const:custom:{}", e.0.name()),
PrimValue::Function => todo!(),
PrimValue::Function(h) => {
let Some(t) = h.get_function_type() else {
panic!("HUGR root node isn't a valid function parent.");
};
format!("const:function:[{}]", t)
}
}
}
}
Expand Down Expand Up @@ -201,9 +207,13 @@ impl PartialEq for dyn CustomConst {

#[cfg(test)]
pub(crate) mod test {
use crate::types::{custom::test::COPYABLE_CUST, TypeBound};
use rstest::rstest;

use super::*;
use crate::builder::test::simple_dfg_hugr;
use crate::type_row;
use crate::types::{custom::test::COPYABLE_CUST, TypeBound};
use crate::types::{FunctionType, Type};

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]

Expand Down Expand Up @@ -233,4 +243,15 @@ pub(crate) mod test {
value: serde_yaml::Value::Number(f.into()),
})
}

#[rstest]
fn function_value(simple_dfg_hugr: Hugr) {
let v = Value::Prim(PrimValue::Function(Box::new(simple_dfg_hugr)));

let correct_type = Type::new_function(FunctionType::new_linear(type_row![
crate::extension::prelude::USIZE_T
]));

assert!(correct_type.check_type(&v).is_ok());
}
}

0 comments on commit 5ae71ed

Please sign in to comment.