From 5ae71edca261ba18714f252091a7b3a6a3a4615b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 22 Aug 2023 14:12:35 +0100 Subject: [PATCH] add HUGR payload to PrimValue::Function (#431) add `get_function_type` method to HugrView to report the function type of the root if valid. Closes #374 --- src/builder.rs | 21 +++++++++++++++++---- src/builder/dataflow.rs | 19 +++++++------------ src/hugr/views.rs | 17 +++++++++++++++-- src/hugr/views/hierarchy.rs | 8 ++++++++ src/types/check.rs | 13 +++++++++++-- src/values.rs | 29 +++++++++++++++++++++++++---- 6 files changed, 83 insertions(+), 24 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index 022caa576..f68311331 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -87,12 +87,17 @@ impl From for PyErr { } #[cfg(test)] -mod test { - use crate::types::{Signature, Type}; - use crate::Hugr; +pub(crate) mod test { + use rstest::fixture; + + use crate::types::{FunctionType, Signature, Type}; + use crate::{type_row, Hugr}; use super::handle::BuildHandle; - use super::{BuildError, Container, FuncID, FunctionBuilder, ModuleBuilder}; + use super::{ + BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, FuncID, FunctionBuilder, + ModuleBuilder, + }; use super::{DataflowSubContainer, HugrBuilder}; pub(super) const NAT: Type = crate::extension::prelude::USIZE_T; @@ -117,4 +122,12 @@ mod test { f(f_builder)?; Ok(module_builder.finish_hugr()?) } + + #[fixture] + pub(crate) fn simple_dfg_hugr() -> Hugr { + let dfg_builder = + DFGBuilder::new(FunctionType::new(type_row![BIT], type_row![BIT])).unwrap(); + let [i1] = dfg_builder.input_wires_arr(); + dfg_builder.finish_hugr_with_outputs([i1]).unwrap() + } } diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 836e9e1ae..dd658af6c 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -212,8 +212,9 @@ impl HugrBuilder for DFGWrapper { } #[cfg(test)] -mod test { +pub(crate) mod test { use cool_asserts::assert_matches; + use rstest::rstest; use serde_json::json; use crate::builder::build_traits::DataflowHugr; @@ -233,6 +234,7 @@ mod test { type_row, Wire, }; + use super::super::test::simple_dfg_hugr; use super::*; #[test] fn nested_identity() -> Result<(), BuildError> { @@ -392,17 +394,10 @@ mod test { Ok(()) } - #[test] - fn dfg_hugr() -> Result<(), BuildError> { - let dfg_builder = DFGBuilder::new(FunctionType::new(type_row![BIT], type_row![BIT]))?; - - let [i1] = dfg_builder.input_wires_arr(); - let hugr = dfg_builder.finish_hugr_with_outputs([i1])?; - - assert_eq!(hugr.node_count(), 3); - assert_matches!(hugr.root_type().tag(), OpTag::Dfg); - - Ok(()) + #[rstest] + fn dfg_hugr(simple_dfg_hugr: Hugr) { + assert_eq!(simple_dfg_hugr.node_count(), 3); + assert_matches!(simple_dfg_hugr.root_type().tag(), OpTag::Dfg); } #[test] diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 101c8f0a0..69b6545f4 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -13,8 +13,8 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; use super::{Hugr, NodeMetadata, NodeType}; use crate::ops::handle::NodeHandle; -use crate::ops::{OpName, OpTag, OpType}; -use crate::types::EdgeKind; +use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpType, DFG}; +use crate::types::{EdgeKind, FunctionType}; use crate::Direction; use crate::{Node, Port}; @@ -154,6 +154,10 @@ pub trait HugrView: sealed::HugrInternals { /// If the node isn't a dataflow parent, then return None fn get_io(&self, node: Node) -> Option<[Node; 2]>; + /// For function-like HUGRs (DFG, FuncDefn, FuncDecl), report the function + /// type. Otherwise return None. + fn get_function_type(&self) -> Option<&FunctionType>; + /// Return dot string showing underlying graph and hierarchy side by side. fn dot_string(&self) -> String { let hugr = self.base_hugr(); @@ -317,6 +321,15 @@ where } } + fn get_function_type(&self) -> Option<&FunctionType> { + let op = self.get_nodetype(self.root()); + match &op.op { + OpType::DFG(DFG { signature }) + | OpType::FuncDecl(FuncDecl { signature, .. }) + | OpType::FuncDefn(FuncDefn { signature, .. }) => Some(signature), + _ => None, + } + } #[inline] fn get_metadata(&self, node: Node) -> &NodeMetadata { self.as_ref().metadata.get(node.index) diff --git a/src/hugr/views/hierarchy.rs b/src/hugr/views/hierarchy.rs index 36a4abdd4..9514a841c 100644 --- a/src/hugr/views/hierarchy.rs +++ b/src/hugr/views/hierarchy.rs @@ -197,6 +197,10 @@ where None } } + + fn get_function_type(&self) -> Option<&crate::types::FunctionType> { + self.base_hugr().get_function_type() + } } type RegionGraph<'g, Base> = portgraph::view::Region<'g, ::Portgraph>; @@ -358,6 +362,10 @@ where fn get_io(&self, node: Node) -> Option<[Node; 2]> { self.base_hugr().get_io(node) } + + fn get_function_type(&self) -> Option<&crate::types::FunctionType> { + self.base_hugr().get_function_type() + } } /// A common trait for views of a HUGR hierarchical subgraph. diff --git a/src/types/check.rs b/src/types/check.rs index d2d27f0d9..3f5ecbae3 100644 --- a/src/types/check.rs +++ b/src/types/check.rs @@ -1,7 +1,10 @@ //! Logic for checking values against types. use thiserror::Error; -use crate::values::{PrimValue, Value}; +use crate::{ + values::{PrimValue, Value}, + HugrView, +}; use super::{primitive::PrimType, CustomType, Type, TypeEnum}; @@ -56,7 +59,13 @@ impl PrimType { e_val.0.check_custom_type(e)?; Ok(()) } - (PrimType::Function(_), PrimValue::Function) => todo!(), + (PrimType::Function(t), PrimValue::Function(v)) + if Some(t.as_ref()) == v.get_function_type() => + { + // exact signature equality, in future this may need to be + // relaxed to be compatibility checks between the signatures. + Ok(()) + } _ => Err(ConstTypeError::ValueCheckFail( Type::new(TypeEnum::Prim(self.clone())), Value::Prim(val.clone()), diff --git a/src/values.rs b/src/values.rs index 3b9f946b7..4d6e4b479 100644 --- a/src/values.rs +++ b/src/values.rs @@ -9,6 +9,7 @@ use downcast_rs::{impl_downcast, Downcast}; use smol_str::SmolStr; use crate::macros::impl_box_clone; +use crate::{Hugr, HugrView}; use crate::types::{CustomCheckFailure, CustomType}; @@ -20,15 +21,20 @@ pub enum PrimValue { // Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808 Extension((Box,)), /// A higher-order function value. - // TODO add HUGR payload - Function, + // TODO use a root parametrised hugr, e.g. Hugr. + Function(Box), } impl PrimValue { fn name(&self) -> String { match self { PrimValue::Extension(e) => format!("const:custom:{}", e.0.name()), - PrimValue::Function => todo!(), + PrimValue::Function(h) => { + let Some(t) = h.get_function_type() else { + panic!("HUGR root node isn't a valid function parent."); + }; + format!("const:function:[{}]", t) + } } } } @@ -201,9 +207,13 @@ impl PartialEq for dyn CustomConst { #[cfg(test)] pub(crate) mod test { - use crate::types::{custom::test::COPYABLE_CUST, TypeBound}; + use rstest::rstest; use super::*; + use crate::builder::test::simple_dfg_hugr; + use crate::type_row; + use crate::types::{custom::test::COPYABLE_CUST, TypeBound}; + use crate::types::{FunctionType, Type}; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -233,4 +243,15 @@ pub(crate) mod test { value: serde_yaml::Value::Number(f.into()), }) } + + #[rstest] + fn function_value(simple_dfg_hugr: Hugr) { + let v = Value::Prim(PrimValue::Function(Box::new(simple_dfg_hugr))); + + let correct_type = Type::new_function(FunctionType::new_linear(type_row![ + crate::extension::prelude::USIZE_T + ])); + + assert!(correct_type.check_type(&v).is_ok()); + } }