fix(Inference): Work harder in variable instantiation #591

Merged
merged 22 commits into from
Nov 8, 2023
Changes from 12 commits
Commits
61ce9a0
refactor: Make `EqGraph` generic over directedness
croyzor Oct 5, 2023
456e55b
fix: Improve instantiate_vars code
croyzor Oct 5, 2023
802b011
tests: Add more tests for looping CFGs
croyzor Oct 5, 2023
c043802
refactor: Remove unnecessary macro
croyzor Oct 10, 2023
fad983e
doc: Update comment
croyzor Oct 10, 2023
d49926d
refactor: Rename `new_{un,}directed` to `new`
croyzor Oct 10, 2023
983bb3a
cosmetic: Move comment
croyzor Oct 10, 2023
e5dbc78
refactor: Rewrite `instantiate_variables` in a functional style
croyzor Oct 10, 2023
80516c3
Reduce mutable variables in search_variable_deps
acl-cqc Oct 10, 2023
84bffa8
refactor: Redo `search_variable_deps`
croyzor Oct 10, 2023
4530168
refactor: Redo `search_variable_deps` in functional style
croyzor Oct 11, 2023
a94d2c8
doc: Move comment
croyzor Oct 11, 2023
52fe101
Fix case of dependent cycles?? Need a test, and some comments to resolve
acl-cqc Oct 11, 2023
fada318
cosmetic: Rename `ccs` to `sccs`
croyzor Oct 13, 2023
a1fb7ec
Missed `resolve`
acl-cqc Oct 23, 2023
4f62200
Drop comment - calling self.resolve enough should handle Equals const…
acl-cqc Oct 23, 2023
ea3c102
Merge remote-tracking branch 'origin/main' into fix/inference-variabl…
croyzor Nov 7, 2023
9f032d6
Add failing test of SCC logic
croyzor Nov 7, 2023
6f288cd
Merge branch 'fix/inference-variable-instantiation' into inference-va…
croyzor Nov 8, 2023
c04e302
Update test case
croyzor Nov 8, 2023
8e87d3e
tests: Add failing test of SCC logic
croyzor Nov 7, 2023
9d49c29
Merge branch 'inference-variable/fix-dependent-sccs' into fix/inferen…
croyzor Nov 8, 2023
163 changes: 116 additions & 47 deletions src/extension/infer.rs
@@ -22,6 +22,7 @@ use crate::{
use super::validate::ExtensionError;

use petgraph::graph as pg;
use petgraph::{Directed, EdgeType, Undirected};

use std::collections::{HashMap, HashSet};

@@ -107,53 +108,65 @@ pub enum InferExtensionError {
EdgeMismatch(#[from] ExtensionError),
}

/// A graph of metavariables which we've found equality constraints for. Edges
/// between nodes represent equality constraints.
struct EqGraph {
equalities: pg::Graph<Meta, (), petgraph::Undirected>,
/// A graph of metavariables connected by constraints.
/// The edges represent `Equal` constraints in the undirected graph and `Plus`
/// constraints in the directed case.
struct GraphContainer<Dir: EdgeType> {
graph: pg::Graph<Meta, (), Dir>,
node_map: HashMap<Meta, pg::NodeIndex>,
}

impl EqGraph {
/// Create a new `EqGraph`
fn new() -> Self {
EqGraph {
equalities: pg::Graph::new_undirected(),
node_map: HashMap::new(),
}
}

impl<T: EdgeType> GraphContainer<T> {
/// Add a metavariable to the graph as a node and return the `NodeIndex`.
/// If it's already there, just return the existing `NodeIndex`
fn add_or_retrieve(&mut self, m: Meta) -> pg::NodeIndex {
self.node_map.get(&m).cloned().unwrap_or_else(|| {
let ix = self.equalities.add_node(m);
let ix = self.graph.add_node(m);
self.node_map.insert(m, ix);
ix
})
}

/// Create an edge between two nodes on the graph, declaring that they stand
/// for metavariables which should be equal.
fn register_eq(&mut self, src: Meta, tgt: Meta) {
/// Create an edge between two nodes on the graph
fn add_edge(&mut self, src: Meta, tgt: Meta) {
let src_ix = self.add_or_retrieve(src);
let tgt_ix = self.add_or_retrieve(tgt);
self.equalities.add_edge(src_ix, tgt_ix, ());
self.graph.add_edge(src_ix, tgt_ix, ());
}

/// Return the connected components of the graph in terms of metavariables
fn ccs(&self) -> Vec<Vec<Meta>> {
petgraph::algo::tarjan_scc(&self.equalities)
petgraph::algo::tarjan_scc(&self.graph)
.into_iter()
.map(|cc| {
cc.into_iter()
.map(|n| *self.equalities.node_weight(n).unwrap())
.map(|n| *self.graph.node_weight(n).unwrap())
.collect()
})
.collect()
}
}

impl GraphContainer<Undirected> {
fn new() -> Self {
GraphContainer {
graph: pg::Graph::new_undirected(),
node_map: HashMap::new(),
}
}
}

impl GraphContainer<Directed> {
fn new() -> Self {
GraphContainer {
graph: pg::Graph::new(),
node_map: HashMap::new(),
}
}
}

type EqGraph = GraphContainer<Undirected>;

/// Our current knowledge about the extensions of the graph
struct UnificationContext {
/// A list of constraints for each metavariable
@@ -504,7 +517,7 @@ impl UnificationContext {
match c {
// Just register the equality in the EqGraph, we'll process it later
Constraint::Equal(other_meta) => {
self.eq_graph.register_eq(meta, *other_meta);
self.eq_graph.add_edge(meta, *other_meta);
}
// N.B. If `meta` is already solved, we can't use that
// information to solve `other_meta`. This is because the Plus
@@ -663,31 +676,90 @@ impl UnificationContext {
self.results()
}

/// Instantiate all variables in the graph with the empty extension set.
/// Gather all the transitive dependencies (induced by constraints) of the
/// variables in the context.
fn search_variable_deps(&self) -> HashSet<Meta> {
let mut seen = HashSet::new();
let mut new_variables: HashSet<Meta> = self.variables.clone();
while !new_variables.is_empty() {
new_variables = new_variables
.into_iter()
.filter(|m| seen.insert(*m))
.flat_map(|m| self.get_constraints(&m).unwrap())
.map(|c| match c {
Constraint::Plus(_, other) => self.resolve(*other),
Constraint::Equal(other) => self.resolve(*other),
})
.collect();
}
seen
}

/// Instantiate all variables in the graph with the empty extension set, or
/// the smallest solution possible given their constraints.
/// This is done to solve metas which depend on variables, which allows
/// us to come up with a fully concrete solution to pass into validation.
///
/// Nodes which loop into themselves must be considered as a "minimum" set
/// of requirements. If we have
/// 1 = 2 + X, ...
/// 2 = 1 + x, ...
/// then 1 and 2 both definitely contain X, even if we don't know what else.
/// So instead of instantiating to the empty set, we'll instantiate to `{X}`
pub fn instantiate_variables(&mut self) {
for m in self.variables.clone().into_iter() {
// A directed graph to keep track of `Plus` constraint relationships
let mut relations = GraphContainer::<Directed>::new();
let mut solutions: HashMap<Meta, ExtensionSet> = HashMap::new();

let variable_scope = self.search_variable_deps();
for m in variable_scope.into_iter() {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
if !self.solved.contains_key(&m) {
// Handle the case where the constraints for `m` contain a self
// reference, i.e. "m = Plus(E, m)", in which case the variable
// should be instantiated to E rather than the empty set.
let solution =
ExtensionSet::from_iter(self.get_constraints(&m).unwrap().iter().filter_map(
|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => {
Some(x.clone())
}
let plus_constraints =
self.get_constraints(&m)
.unwrap()
.iter()
.cloned()
.flat_map(|c| match c {
Constraint::Plus(r, other_m) => Some((r, self.resolve(other_m))),
_ => None,
},
));
self.add_solution(m, solution);
});

let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip();
let solution = ExtensionSet::from_iter(rs.into_iter());
let unresolved_metas = other_ms
.into_iter()
.filter(|other_m| m != *other_m)
.collect::<Vec<_>>();

// If `m` doesn't depend on any other metas then we have all the
// information we need to come up with a solution for it.
if unresolved_metas.is_empty() {
self.add_solution(m, solution.clone());
} else {
unresolved_metas
.iter()
.for_each(|other_m| relations.add_edge(m, *other_m));
}
solutions.insert(m, solution);
}
}

// Strongly connected components are looping constraint dependencies.
// This means that each metavariable in the CC has the same solution.
for cc in relations.ccs() {

Contributor

The thing that bothers me a bit here is that we don't have much of a guarantee that relations.ccs() contains every node that we didn't add a solution for earlier. (That guarantee comes from the idea that only if there was a cyclic dependency would we have failed to find a solution earlier, so every node added to relations must be in a cycle, but still.)

One possible answer?? Don't mutate via self.add_solution but return a complete Map. Then you can assert that the map has the right number of entries (same as variable_scope.len()). Then you can add that Map to self.solutions all in one go (after debug_assert!ing that the keyset is disjoint, as per fn add_solution)

Contributor

(There are other ways, you could record the ms for which the first loop failed to find a solution, then tick them off here. But that'd be more expensive when debug_assert is disabled)

Contributor Author

We could add an assert that the number of nodes in relations.ccs() (flattened) is the same as the length of variable_scope?

Contributor

You'd need to filter out the number of those that were already in self.solved (there's an if !self.solved.contains)...

But the flattening idea is good - how about just that the number of nodes in flattened-relations.ccs() is equal to the number of edges in relations?

Contributor

Ah ok, some edges share endpoints. This works:

Suggested change
for cc in relations.ccs() {
assert_eq!(relations.node_map.len(), relations.ccs().iter().map(Vec::len).sum::<usize>());

Common up the two relations.ccs tho!

Contributor Author

I'm not sure that check really makes sense, since it's apparent that the size of node_map and the number of nodes in the graph will be the same from looking at GraphContainer::add_or_retrieve, and sccs will return a clustered version of every node in the graph

Contributor @acl-cqc, Oct 11, 2023

Agreed size of node map equals number of nodes in the graph, yes the latter is what I meant.

But SCC = Strongly Connected Component, i.e. every node reaches every other? In a directed graph? Ah, are ccs undirectly-connected?

So A -> B -> C -> A is one SCC, with all three nodes. But A -> B -> C is NOT an SCC by that definition, but are you saying that .ccs() would give you [[A],[B],[C]]?

Contributor

Yes, indeed it does. In which case yeah, no assert/check really seems necessary, we're only really protecting against somehow getting in a muddle ourselves between the two loops.

Does make me wonder whether we could just put everything through the relations map, then. A node with no connections will trivially (and cheaply) be its own SCC. So this above:

if unresolved_metas.is_empty() {
   self.add_solution(m, solution.clone())
} else {
  unresolved_metas.iter().for_each(.....)
}

could just be

relations.add_or_retrieve(m);
unresolved_metas.iter().for_each(....)

and then all the add_solution calls would be in the second loop.
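
To make the SCC behaviour discussed in this thread concrete, here is a minimal, self-contained sketch of Tarjan's algorithm over plain node indices. It is not petgraph's implementation; `sccs` and the `Tarjan` struct are hypothetical names used only for illustration. It shows that a node with no cycle through it ends up as its own singleton component, which is what makes the "put everything through the relations map" idea workable.

```rust
/// Hypothetical helper state for a textbook Tarjan SCC search.
struct Tarjan {
    adj: Vec<Vec<usize>>,
    index: Vec<Option<usize>>,
    low: Vec<usize>,
    on_stack: Vec<bool>,
    stack: Vec<usize>,
    next: usize,
    out: Vec<Vec<usize>>,
}

impl Tarjan {
    fn visit(&mut self, v: usize) {
        // Assign the next discovery index and push `v` on the stack.
        self.index[v] = Some(self.next);
        self.low[v] = self.next;
        self.next += 1;
        self.stack.push(v);
        self.on_stack[v] = true;

        for i in 0..self.adj[v].len() {
            let w = self.adj[v][i];
            let w_index = self.index[w];
            match w_index {
                // Unvisited successor: recurse, then inherit its low-link.
                None => {
                    self.visit(w);
                    self.low[v] = self.low[v].min(self.low[w]);
                }
                // Successor still on the stack: it is in the current component.
                Some(ix) if self.on_stack[w] => {
                    self.low[v] = self.low[v].min(ix);
                }
                // Successor already assigned to a finished component.
                Some(_) => {}
            }
        }

        // `v` is the root of a component: pop everything down to `v`.
        if Some(self.low[v]) == self.index[v] {
            let mut component = Vec::new();
            loop {
                let w = self.stack.pop().unwrap();
                self.on_stack[w] = false;
                component.push(w);
                if w == v {
                    break;
                }
            }
            component.sort_unstable();
            self.out.push(component);
        }
    }
}

/// Strongly connected components of a directed graph on nodes `0..n`.
/// Components are emitted in the order their roots finish, so a component
/// appears after every component it has an edge into.
fn sccs(n: usize, edges: &[(usize, usize)]) -> Vec<Vec<usize>> {
    let mut adj = vec![Vec::new(); n];
    for &(src, tgt) in edges {
        adj[src].push(tgt);
    }
    let mut state = Tarjan {
        adj,
        index: vec![None; n],
        low: vec![0; n],
        on_stack: vec![false; n],
        stack: Vec::new(),
        next: 0,
        out: Vec::new(),
    };
    for v in 0..n {
        if state.index[v].is_none() {
            state.visit(v);
        }
    }
    state.out
}
```

With A, B, C numbered 0, 1, 2, the chain `sccs(3, &[(0, 1), (1, 2)])` gives three singletons, `[[2], [1], [0]]`, while adding the back edge `(2, 0)` gives the single component `[[0, 1, 2]]`.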

let mut combined_solution = ExtensionSet::new();
for sol in cc.iter().filter_map(|m| solutions.get(m)) {
combined_solution = combined_solution.union(sol);
}
for m in cc.iter() {
self.add_solution(*m, combined_solution.clone());

Contributor @acl-cqc, Oct 11, 2023

We don't need to add to solutions here do we? What if we had, say, two SCCs with an edge between them:

A -> B -> C -> A
D -> E -> F -> D
A -> D

I note Petgraph's tarjan_scc returns the components in a defined order (reverse topsort or something), so in theory (maybe you have to reverse cc) it's soluble, but I think you need to consider not just solutions.get(m) for each m in the SCC but the solutions to the constraints upon m (which by that topsort ordering have already been computed).

Or, maybe such a structure can't occur, I dunno....????
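
One way to read the concern above is as a propagation problem over components. The sketch below is only an illustration under stated assumptions, not the code on the branch: `Exts` is a hypothetical stand-in for `ExtensionSet`, `solve_components` is a made-up name, edges point from a node to the node it depends on (as with `relations.add_edge(m, other_m)`), and the caller supplies the components dependencies-first.

```rust
use std::collections::BTreeSet;

/// Hypothetical stand-in for an extension set.
type Exts = BTreeSet<&'static str>;

/// Solve each component as the union of its members' own requirements plus
/// the already-computed solutions of anything a member depends on.
///
/// Assumptions: `components` lists SCCs dependencies-first, `edges` contains
/// `(m, dep)` pairs, and `own[m]` holds what `m`'s constraints add directly.
fn solve_components(
    components: &[Vec<usize>],
    edges: &[(usize, usize)],
    own: &[Exts],
) -> Vec<Exts> {
    let mut solved: Vec<Option<Exts>> = vec![None; own.len()];
    for component in components {
        let mut combined = Exts::new();
        for &m in component {
            combined.extend(own[m].iter().copied());
            // Dependencies inside this component are still `None` here and are
            // covered by the loop over `component`; earlier components are solved.
            for &(src, tgt) in edges {
                if src == m {
                    if let Some(dep) = &solved[tgt] {
                        combined.extend(dep.iter().copied());
                    }
                }
            }
        }
        // Every member of a cycle shares the same solution.
        for &m in component {
            solved[m] = Some(combined.clone());
        }
    }
    solved
        .into_iter()
        .map(Option::unwrap_or_default)
        .collect()
}
```

For the A..F structure above (numbered 0..5, with A requiring `X` and D requiring `Y`), passing the components as `[[3, 4, 5], [0, 1, 2]]` gives `{Y}` for D, E, F and `{X, Y}` for A, B, C, whereas unioning only the per-member sets would leave `Y` out of the first cycle.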

Contributor @acl-cqc, Oct 11, 2023

Try branch inference-variable/fix-dependent-sccs (the last commit) - some uncertainties detailed in the comments there, and needs a test of a structure such as the ABCDEF above.

Contributor

Ooops, realize that wasn't passing tests! A missed self.resolve should make it pass now.

}
}
self.variables = HashSet::new();
@@ -1509,17 +1581,14 @@ mod test {
#[test]
fn test_cfg_loops() -> Result<(), Box<dyn Error>> {
let just_a = ExtensionSet::singleton(&A);
let variants = vec![
(
ExtensionSet::new(),
ExtensionSet::new(),
ExtensionSet::new(),
),
(just_a.clone(), ExtensionSet::new(), ExtensionSet::new()),
(ExtensionSet::new(), just_a.clone(), ExtensionSet::new()),
(ExtensionSet::new(), ExtensionSet::new(), just_a.clone()),
];

let mut variants = Vec::new();

Contributor

Hahaha! Neat :-) You might consider flat_map (*3) to avoid mut but like it either way :)
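
For reference, the `flat_map` variant hinted at here might look roughly like the sketch below. `all_triples` is a hypothetical helper, and the generic `T` stands in for `ExtensionSet` so the snippet is self-contained; it is not code from this PR.

```rust
/// Build every `(entry, bb1, bb2)` combination drawn from `options`,
/// in the same order as three nested `for` loops, without a `mut` accumulator.
fn all_triples<T: Clone>(options: &[T]) -> Vec<(T, T, T)> {
    options
        .iter()
        .flat_map(move |entry| {
            options.iter().flat_map(move |bb1| {
                options
                    .iter()
                    .map(move |bb2| (entry.clone(), bb1.clone(), bb2.clone()))
            })
        })
        .collect()
}
```

Called with two options (the empty set and `just_a` in the test), this yields the same 2 x 2 x 2 = 8 variants as the nested loops.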

for entry in [ExtensionSet::new(), just_a.clone()] {
for bb1 in [ExtensionSet::new(), just_a.clone()] {
for bb2 in [ExtensionSet::new(), just_a.clone()] {
variants.push((entry.clone(), bb1.clone(), bb2.clone()));
}
}
}
for (bb0, bb1, bb2) in variants.into_iter() {
let mut hugr = make_looping_cfg(bb0, bb1, bb2)?;
hugr.infer_and_validate(&PRELUDE_REGISTRY)?;