Skip to content

Commit

Permalink
Add BacktrackingTree and debug keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
Kris Brown committed Sep 12, 2023
1 parent ff10eff commit 5ca77d9
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 26 deletions.
128 changes: 102 additions & 26 deletions src/categorical_algebra/HomSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ export ACSetHomomorphismAlgorithm, BacktrackingSearch, HomomorphismQuery,
homomorphism, homomorphisms, is_homomorphic,
isomorphism, isomorphisms, is_isomorphic,
@acset_transformation, @acset_transformations,
subobject_graph, partial_overlaps, maximum_common_subobject
subobject_graph, partial_overlaps, maximum_common_subobject,
debug_homomorphisms

using ...Theories, ..CSets, ..FinSets, ..FreeDiagrams, ..Subobjects
using ...Graphs.BasicGraphs
Expand All @@ -20,7 +21,7 @@ using ACSets.DenseACSets: attrtype_type, delete_subobj!
using Random
using CompTime
using MLStyle: @match
using DataStructures: BinaryHeap, DefaultDict
using DataStructures: BinaryHeap, DefaultDict, OrderedDict

# Finding C-set transformations
###############################
Expand Down Expand Up @@ -85,7 +86,7 @@ homomorphism(X::ACSet, Y::ACSet; alg=BacktrackingSearch(), kw...) =
function homomorphism(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...)
result = nothing
backtracking_search(X, Y; kw...) do α
result = α; return true
result = get_hom(α); return true
end
result
end
Expand All @@ -101,11 +102,19 @@ homomorphisms(X::ACSet, Y::ACSet; alg=BacktrackingSearch(), kw...) =
function homomorphisms(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...)
results = []
backtracking_search(X, Y; kw...) do α
push!(results, map_components(deepcopy, α)); return false
push!(results, map_components(deepcopy, get_hom(α))); return false
end
results
end

function debug_homomorphisms(X::ACSet, Y::ACSet; kw...)
results = []
m = backtracking_search(X, Y; debug=true, kw...) do α
push!(results, map_components(deepcopy, get_hom(α))); return false
end
results => m.debug
end

""" Is the first attributed ``C``-set homomorphic to the second?
This function generally reduces to [`homomorphism`](@ref) but certain algorithms
Expand Down Expand Up @@ -152,6 +161,60 @@ is_isomorphic(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...) =
# Backtracking search
#--------------------

"""Keep track of progress through backtracking homomorphism search."""
mutable struct BacktrackingTree
node::Union{Nothing,Pair{Symbol,Int}}
success::Bool
asgn::NamedTuple
children::OrderedDict{Int,BacktrackingTree}
BacktrackingTree() = new(nothing, false, (;), OrderedDict{Int,BacktrackingTree}())
end

"""A backtracking tree plus a pointer to a node in the tree"""
struct BacktrackingTreePt
t::BacktrackingTree
curr::Vector{Int}
BacktrackingTreePt() = new(BacktrackingTree(),Int[])
end

function Base.push!(tc::BacktrackingTreePt, c::Symbol, x::Int, y::Int, asgn)
t = tc.t[tc.curr]
t.node = c => x
t.children[y] = BacktrackingTree()
t.children[y].asgn = deepcopy(asgn)
push!(tc.curr, y)
return true
end

function Base.delete!(tc::BacktrackingTreePt, c::Symbol, x::Int, y::Int)
t = tc.t[tc.curr[1:end-1]]
t.node == (c=>x) || error("Bad remove $c#$x->$y")
pop!(tc.curr)
end

function success(tc::BacktrackingTreePt)
tc.t[tc.curr].success = true
end

function Base.show(io::IO, t::BacktrackingTree)
if !isnothing(t.node)
print(io,"{"); print(io, t.node[1]); print(io, t.node[2]); print(io,"}");

Check warning on line 201 in src/categorical_algebra/HomSearch.jl

View check run for this annotation

Codecov / codecov/patch

src/categorical_algebra/HomSearch.jl#L199-L201

Added lines #L199 - L201 were not covered by tests
end
print(io, "[")
for (k,v) in collect(t.children)
print(io, k); print(io, v); print(io, ",")
end
if !isempty(t.children) print(io,"\b") end
print(io,"]")

Check warning on line 208 in src/categorical_algebra/HomSearch.jl

View check run for this annotation

Codecov / codecov/patch

src/categorical_algebra/HomSearch.jl#L203-L208

Added lines #L203 - L208 were not covered by tests
end

function Base.getindex(t::BacktrackingTree, curr::Vector{Int})
for c in curr
t = t.children[c]
end
t
end

""" Get assignment pairs from partially specified component of C-set morphism.
"""
partial_assignments(x::FinFunction; is_attr=false) = partial_assignments(collect(x))
Expand All @@ -177,10 +240,25 @@ struct BacktrackingState{
dom::Dom
codom::Codom
type_components::LooseFun
debug::Union{Nothing,BacktrackingTreePt}
end

"""Extract an ACSetTransformation from BacktrackingState"""
function get_hom(state::BacktrackingState)
if any(!=(identity), state.type_components)
return LooseACSetTransformation(
state.assignment, state.type_components, state.dom, state.codom)
else
S = acset_schema(state.dom)
od = Dict{Symbol,Vector{Int}}(k=>(state.assignment[k]) for k in objects(S))
ad = Dict(k=>last.(state.assignment[k]) for k in attrtypes(S))
comps = merge(NamedTuple(od),NamedTuple(ad))
return ACSetTransformation(comps, state.dom, state.codom)
end
end

function backtracking_search(f, X::ACSet, Y::ACSet;
monic=false, iso=false, random=false,
monic=false, iso=false, random=false, debug=false,
type_components=(;), initial=(;), error_failures=false)
S, Sy = acset_schema.([X,Y])
S == Sy || error("Schemas must match for morphism search")
Expand Down Expand Up @@ -235,9 +313,11 @@ function backtracking_search(f, X::ACSet, Y::ACSet;
inv_assignment = NamedTuple{ObAttr}(
(c in monic ? zeros(Int, nparts(Y, c)) : nothing) for c in ObAttr)
loosefuns = NamedTuple{Attr}(
isnothing(type_components) ? identity : get(type_components, c, identity) for c in Attr)
state = BacktrackingState(assignment, assignment_depth,
inv_assignment, X, Y, loosefuns)
isnothing(type_components) ? identity : get(type_components, c, identity)
for c in Attr)

state = BacktrackingState(assignment, assignment_depth, inv_assignment, X, Y,
loosefuns, debug ? BacktrackingTreePt() : nothing)

# Make any initial assignments, failing immediately if inconsistent.
for (c, c_assignments) in pairs(initial)
Expand All @@ -252,39 +332,32 @@ function backtracking_search(f, X::ACSet, Y::ACSet;
end
end
# Start the main recursion for backtracking search.
backtracking_search(f, state, 1; random=random)
backtracking_search(f, state, 1; random=random, toplevel=true)
end

function backtracking_search(f, state::BacktrackingState, depth::Int;
random=false)
random=false, toplevel=false)
# Choose the next unassigned element.
mrv, mrv_elem = find_mrv_elem(state, depth)
if isnothing(mrv_elem)
# No unassigned elements remain, so we have a complete assignment.
if any(!=(identity), state.type_components)
return f(LooseACSetTransformation(
state.assignment, state.type_components, state.dom, state.codom))
else
S = acset_schema(state.dom)
od = Dict{Symbol,Vector{Int}}(k=>(state.assignment[k]) for k in objects(S))
ad = Dict(k=>last.(state.assignment[k]) for k in attrtypes(S))
comps = merge(NamedTuple(od),NamedTuple(ad))
return f(ACSetTransformation(comps, state.dom, state.codom))
end
isnothing(state.debug) || success(state.debug)
return f(state)
elseif mrv == 0
# An element has no allowable assignment, so we must backtrack.
return false
end
c, x = mrv_elem
c, x, ys = mrv_elem

# Attempt all assignments of the chosen element.
Y = state.codom
for y in (random ? shuffle : identity)(parts(Y, c))
for y in (random ? shuffle : identity)(ys)
(assign_elem!(state, depth, c, x, y)
&& (isnothing(state.debug) ? true : push!(state.debug, c, x, y, state.assignment))
&& backtracking_search(f, state, depth + 1)) && return true
unassign_elem!(state, depth, c, x)
isnothing(state.debug) || delete!(state.debug, c, x, state.assignment[c][x])
end
return false
return toplevel ? state : false # return state to recover debug tree
end

""" Find an unassigned element having the minimum remaining values (MRV).
Expand All @@ -295,9 +368,12 @@ function find_mrv_elem(state::BacktrackingState, depth)
Y = state.codom
for c in ob(S), (x, y) in enumerate(state.assignment[c])
y == 0 || continue
n = count(can_assign_elem(state, depth, c, x, y) for y in parts(Y, c))
ys = filter(parts(Y,c)) do y
can_assign_elem(state, depth, c, x, y)
end
n = length(ys)
if n < mrv
mrv, mrv_elem = n, (c, x)
mrv, mrv_elem = n, (c, x, ys)
end
end
(mrv, mrv_elem)
Expand Down
34 changes: 34 additions & 0 deletions src/graphics/GraphvizCategories.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export to_graphviz, to_graphviz_property_graph
using ...GATs, ...Theories, ...CategoricalAlgebra, ...Graphs, ..GraphvizGraphs
import ..Graphviz
import ..GraphvizGraphs: to_graphviz, to_graphviz_property_graph
using ...CategoricalAlgebra.HomSearch: BacktrackingTree, BacktrackingTreePt

# Presentations
###############
Expand Down Expand Up @@ -143,4 +144,37 @@ function to_graphviz(f::FinFunction{Int,Int}; kw...)
to_graphviz(g; kw...)
end

# Search trees
###############
to_graphviz(t::BacktrackingTreePt) = to_graphviz(t.t)

Check warning on line 149 in src/graphics/GraphvizCategories.jl

View check run for this annotation

Codecov / codecov/patch

src/graphics/GraphvizCategories.jl#L149

Added line #L149 was not covered by tests

function to_graphviz(t::BacktrackingTree)
pg = PropertyGraph{Any}(;

Check warning on line 152 in src/graphics/GraphvizCategories.jl

View check run for this annotation

Codecov / codecov/patch

src/graphics/GraphvizCategories.jl#L151-L152

Added lines #L151 - L152 were not covered by tests
prog = "dot",
graph = Dict(),
node = merge!(Dict(:shape => "box", :width => ".1", :height => ".1",
:margin => "0.025", :style=>"filled")),
edge = Dict())
kwargs(tr::BacktrackingTree) = (

Check warning on line 158 in src/graphics/GraphvizCategories.jl

View check run for this annotation

Codecov / codecov/patch

src/graphics/GraphvizCategories.jl#L158

Added line #L158 was not covered by tests
fillcolor=tr.success ? "green" : "red",
tooltip=isempty(tr.asgn) ? "" : string(tr.asgn),
label = isnothing(tr.node) ? "" : join(string.([tr.node...])))
add_vertex!(pg; kwargs(t)...)
queue = [Int[]]
paths = Dict([Int[]=>1]) # path to vertex
while !isempty(queue)
curr = popfirst!(queue)
subt = t[curr]

Check warning on line 167 in src/graphics/GraphvizCategories.jl

View check run for this annotation

Codecov / codecov/patch

src/graphics/GraphvizCategories.jl#L162-L167

Added lines #L162 - L167 were not covered by tests
# We ought print the index too, but graphviz renders edges in right order
for (_,e) in enumerate(keys(subt.children))
new_pth = [curr...,e]
v = add_vertex!(pg; kwargs(t[new_pth])...)
paths[new_pth] = v
add_edge!(pg, paths[curr], v; label=string("$e"))
push!(queue, new_pth)
end
end
to_graphviz(pg)

Check warning on line 177 in src/graphics/GraphvizCategories.jl

View check run for this annotation

Codecov / codecov/patch

src/graphics/GraphvizCategories.jl#L169-L177

Added lines #L169 - L177 were not covered by tests
end

end
36 changes: 36 additions & 0 deletions test/categorical_algebra/HomSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,42 @@ end
@test length(@acset_transformations x x begin V = Dict(1=>1) end monic = [:E]) == 2
@test_throws ErrorException @acset_transformation g h begin V = [4,3,2,1]; E = [1,2,3,4] end

# Debug graph
#------------
@present SchTri <: SchGraph begin
T::Ob
(t1,t2,t3)::Hom(T,E)
t1 src == t2 src
t1 tgt == t3 tgt
t2 src == t3 src
end

@acset_type Tri(SchTri)

""" e₃
2 ← 4
e₁↑ ↖ ↓ e₄
1 → 3
e₂
"""
quad = @acset Tri begin V=4; E=5; T=2;
src=[1,1,4,4,3]; tgt=[2,3,2,3,2];
t1=[1,3]; t2=[2,4]; t3=[5,5]
end

term = apex(terminal(Tri))

tri5 = @acset Tri begin
V=2; E=3; T=5; src=[1,1,2]; tgt=[2,2,2]; t1=1; t2=2; t3=3
end

tri = @acset Tri begin
V=3; E=3; T=1; src=[1,1,2]; tgt=[3,2,3]; t1=1; t2=2; t3=3
end

hs, t = debug_homomorphisms(tri, quad tri5; monic=false)
@test length(hs) == length(homomorphisms(tri, quad tri5))
# to_graphviz(t)

# Enumeration of subobjects
###########################
Expand Down

0 comments on commit 5ca77d9

Please sign in to comment.