Skip to content

Commit

Permalink
feat(modules): add import!() expression
Browse files Browse the repository at this point in the history
  • Loading branch information
zzlk committed Sep 6, 2023
1 parent 58bf7ba commit 6f80e4a
Show file tree
Hide file tree
Showing 11 changed files with 452 additions and 48 deletions.
46 changes: 46 additions & 0 deletions hydroflow/examples/modules/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::cell::RefCell;
use std::rc::Rc;

use hydroflow::hydroflow_syntax;
use hydroflow::scheduled::graph::Hydroflow;
use hydroflow::util::multiset::HashMultiSet;

pub fn main() {
let output = Rc::new(RefCell::new(
HashMultiSet::<(usize, usize, usize)>::default(),
));

let mut df: Hydroflow = {
let output = output.clone();
hydroflow_syntax! {
source_iter(0..2) -> [0]cj;
source_iter(0..2) -> [1]cj;
source_iter(0..2) -> [2]cj;

cj = import!("triple_cross_join.hf")
-> for_each(|x| output.borrow_mut().insert(x));
}
};

df.run_available();

#[rustfmt::skip]
assert_eq!(
*output.borrow(),
HashMultiSet::from_iter([
(0, 0, 0),
(0, 0, 1),
(0, 1, 0),
(0, 1, 1),
(1, 0, 0),
(1, 0, 1),
(1, 1, 0),
(1, 1, 1),
])
);
}

#[test]
fn test() {
main();
}
15 changes: 15 additions & 0 deletions hydroflow/examples/modules/triple_cross_join.hf
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
input[0]
-> [0]cj1;

input[1]
-> [1]cj1;

cj1 = cross_join()
-> [0]cj2;

input[2]
-> [1]cj2;

cj2 = cross_join()
-> map(|((a, b), c)| (a, b, c))
-> output;
1 change: 1 addition & 0 deletions hydroflow_datalog_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ pub fn gen_hydroflow_graph(
if !diagnostics.is_empty() {
Err(diagnostics)
} else {
flat_graph.merge_modules();
eliminate_extra_unions_tees(&mut flat_graph);
Ok(flat_graph)
}
Expand Down
20 changes: 20 additions & 0 deletions hydroflow_lang/src/graph/di_mul_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,26 @@ where
Some((new_edge, (pred_edge, succ_edge)))
}

/// Remove an edge from the graph. If the edgeId is found then the edge is removed from the graph and returned.
/// If the edgeId was not found in the graph then nothing is returned and nothing is done.
pub fn remove_edge(&mut self, e: E) -> Option<(V, V)> {
let Some((src, dst)) = self.edges.remove(e) else {
return None;
};

self.succs[src].retain(|x| *x != e);
self.preds[dst].retain(|x| *x != e);

Some((src, dst))
}

/// Remove a vertex from the graph, it must have no edges to or from it when doing this.
pub fn remove_vertex(&mut self, v: V) {
assert!(!self.edges.values().any(|(v1, v2)| *v1 == v || *v2 == v),);
self.preds.remove(v);
self.succs.remove(v);
}

/// Get the source and destination vertex IDs for the given edge ID.
pub fn edge(&self, e: E) -> Option<(V, V)> {
self.edges.get(e).copied()
Expand Down
208 changes: 181 additions & 27 deletions hydroflow_lang/src/graph/flat_graph_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
use std::borrow::Cow;
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, BTreeSet};
use std::path::PathBuf;

use proc_macro2::Span;
use quote::ToTokens;
use syn::spanned::Spanned;
use syn::{Ident, ItemUse};
use syn::{Error, Ident, ItemUse};

use super::ops::find_op_op_constraints;
use super::{GraphNodeId, HydroflowGraph, Node, PortIndexValue};
Expand Down Expand Up @@ -44,6 +45,16 @@ pub struct FlatGraphBuilder {

/// Use statements.
uses: Vec<ItemUse>,

/// In order to make import!() statements relative to the current file, we need to know where the file is that is building the flat graph.
macro_invocation_path: PathBuf,

/// If the flat graph is being loaded as a module, then there are some additional things that happen
/// 1. the varname 'input' and 'output' is reserved and used for the input and output of the module.
/// 2. two initial module boundary nodes are inserted into the graph before statements are processed, one input and one output.
is_module: bool,
input_node: GraphNodeId,
output_node: GraphNodeId,
}

impl FlatGraphBuilder {
Expand All @@ -53,8 +64,42 @@ impl FlatGraphBuilder {
}

/// Convert the Hydroflow code AST into a graph builder.
pub fn from_hfcode(input: HfCode) -> Self {
input.into()
pub fn from_hfcode(input: HfCode, macro_invocation_path: PathBuf) -> Self {
let mut builder = Self {
macro_invocation_path,
..Default::default()
};
builder.process_statements(input.statements);

builder
}

/// Convert the Hydroflow code AST into a graph builder.
pub fn from_hfmodule(input: HfCode) -> Self {
let mut builder = Self::default();
builder.is_module = true;
builder.input_node = builder.flat_graph.insert_node(
Node::ModuleBoundary {
input: true,
import_expr: Span::call_site(),
},
Some(Ident::new("input", Span::call_site())),
);
builder.output_node = builder.flat_graph.insert_node(
Node::ModuleBoundary {
input: false,
import_expr: Span::call_site(),
},
Some(Ident::new("output", Span::call_site())),
);
builder.process_statements(input.statements);
builder
}

fn process_statements(&mut self, statements: impl IntoIterator<Item = HfStatement>) {
for stmt in statements {
self.add_statement(stmt);
}
}

/// Build into an unpartitioned [`HydroflowGraph`], returning a tuple of a `HydroflowGraph` and
Expand All @@ -77,7 +122,13 @@ impl FlatGraphBuilder {
}
HfStatement::Named(named) => {
let stmt_span = named.span();
let ends = self.add_pipeline(named.pipeline, Some(&named.name));
let ends = match self.add_pipeline(named.pipeline, Some(&named.name)) {
Err(err) => {
self.diagnostics.push(err.into());
return;
}
Ok(ends) => ends,
};
match self.varname_ends.entry(named.name) {
Entry::Vacant(vacant_entry) => {
vacant_entry.insert(Ok(ends));
Expand Down Expand Up @@ -106,33 +157,57 @@ impl FlatGraphBuilder {
}
}
HfStatement::Pipeline(pipeline_stmt) => {
self.add_pipeline(pipeline_stmt.pipeline, None);
if let Err(err) = self.add_pipeline(pipeline_stmt.pipeline, None) {
self.diagnostics.push(err.into());
}
}
}
}

/// Helper: Add a pipeline, i.e. `a -> b -> c`. Return the input and output ends for it.
fn add_pipeline(&mut self, pipeline: Pipeline, current_varname: Option<&Ident>) -> Ends {
fn add_pipeline(
&mut self,
pipeline: Pipeline,
current_varname: Option<&Ident>,
) -> syn::Result<Ends> {
match pipeline {
Pipeline::Paren(ported_pipeline_paren) => {
let (inn_port, pipeline_paren, out_port) =
PortIndexValue::from_ported(ported_pipeline_paren);
let og_ends = self.add_pipeline(*pipeline_paren.pipeline, current_varname);
Self::helper_combine_ends(&mut self.diagnostics, og_ends, inn_port, out_port)
let og_ends = self.add_pipeline(*pipeline_paren.pipeline, current_varname)?;
Ok(Self::helper_combine_ends(
&mut self.diagnostics,
og_ends,
inn_port,
out_port,
))
}
Pipeline::Name(pipeline_name) => {
let (inn_port, ident, out_port) = PortIndexValue::from_ported(pipeline_name);
// We could lookup non-forward references immediately, but easier to just have one
// consistent code path. -mingwei
Ends {
inn: Some((inn_port, GraphDet::Undetermined(ident.clone()))),
out: Some((out_port, GraphDet::Undetermined(ident))),

if self.is_module && ident == "input" {
Ok(Ends {
inn: Some((inn_port, GraphDet::Determined(self.input_node))),
out: Some((out_port, GraphDet::Determined(self.input_node))),
})
} else if self.is_module && ident == "output" {
Ok(Ends {
inn: Some((inn_port, GraphDet::Determined(self.output_node))),
out: Some((out_port, GraphDet::Determined(self.output_node))),
})
} else {
// We could lookup non-forward references immediately, but easier to just have one
// consistent code path. -mingwei
Ok(Ends {
inn: Some((inn_port, GraphDet::Undetermined(ident.clone()))),
out: Some((out_port, GraphDet::Undetermined(ident))),
})
}
}
Pipeline::Link(pipeline_link) => {
// Add the nested LHS and RHS of this link.
let lhs_ends = self.add_pipeline(*pipeline_link.lhs, current_varname);
let rhs_ends = self.add_pipeline(*pipeline_link.rhs, current_varname);
let lhs_ends = self.add_pipeline(*pipeline_link.lhs, current_varname)?;
let rhs_ends = self.add_pipeline(*pipeline_link.rhs, current_varname)?;

// Outer (first and last) ends.
let outer_ends = Ends {
Expand All @@ -145,17 +220,103 @@ impl FlatGraphBuilder {
inn: rhs_ends.inn,
};
self.links.push(link_ends);
outer_ends
Ok(outer_ends)
}
Pipeline::Operator(operator) => {
let op_span = Some(operator.span());
let nid = self
.flat_graph
.insert_node(Node::Operator(operator), current_varname.cloned());
Ends {
Ok(Ends {
inn: Some((PortIndexValue::Elided(op_span), GraphDet::Determined(nid))),
out: Some((PortIndexValue::Elided(op_span), GraphDet::Determined(nid))),
})
}
Pipeline::Import(import) => {
// TODO: https://github.com/rust-lang/rfcs/pull/3200
// this would be way better...
let mut dir = self.macro_invocation_path.clone();
dir.pop();

let file_contents = std::fs::read_to_string(dir.join(import.filename.value()))
.map_err(|e| {
Error::new(
import.filename.span(),
format!("filename: {}, err: {e}", import.filename.value()),
)
})?;

// TODO: see also above, parse_str sets all the spans in the resulting parsed token stream to the parent macro invocation span.
// This means that any error inside the imported module will manifest as a giant red squiggly line under the parent hydroflow_syntax!{} call.
let statements = match syn::parse_str::<HfCode>(&file_contents) {
Ok(code) => code,
Err(err) => {
self.diagnostics.push(err.clone().into());
return syn::Result::Err(err);
}
};

let flat_graph_builder = crate::graph::FlatGraphBuilder::from_hfmodule(statements);
let (flat_graph, _uses, diagnostics) = flat_graph_builder.build();
diagnostics
.iter()
.for_each(crate::diagnostic::Diagnostic::emit);

let mut ends = Ends {
inn: None,
out: None,
};

let mut node_mapping = BTreeMap::new();

for (nid, node) in flat_graph.nodes() {
match node {
Node::Operator(_) => {
let varname = flat_graph.node_varname(nid);
let new_id = self.flat_graph.insert_node(node.clone(), varname);
node_mapping.insert(nid, new_id);
}
Node::ModuleBoundary { input, .. } => {
let new_id = self.flat_graph.insert_node(
Node::ModuleBoundary {
input: *input,
import_expr: import.span(),
},
Some(Ident::new(
&format!("module_{}", input.to_string()),
import.span(),
)),
);
node_mapping.insert(nid, new_id);

if *input {
ends.inn = Some((
PortIndexValue::Elided(None),
GraphDet::Determined(new_id),
));
} else {
ends.out = Some((
PortIndexValue::Elided(None),
GraphDet::Determined(new_id),
));
}
}
_ => panic!(),
}
}

for (eid, (src, dst)) in flat_graph.edges() {
let (src_port, dst_port) = flat_graph.edge_ports(eid);

self.flat_graph.insert_edge(
*node_mapping.get(&src).unwrap(),
src_port.clone(),
*node_mapping.get(&dst).unwrap(),
dst_port.clone(),
);
}

Ok(ends)
}
}
}
Expand Down Expand Up @@ -515,6 +676,9 @@ impl FlatGraphBuilder {
);
}
Node::Handoff { .. } => todo!("Node::Handoff"),
Node::ModuleBoundary { .. } => {
// Module boundaries don't require any checking.
}
}
}
}
Expand Down Expand Up @@ -577,13 +741,3 @@ impl FlatGraphBuilder {
}
}
}

impl From<HfCode> for FlatGraphBuilder {
fn from(input: HfCode) -> Self {
let mut builder = Self::default();
for stmt in input.statements {
builder.add_statement(stmt);
}
builder
}
}
Loading

0 comments on commit 6f80e4a

Please sign in to comment.