Skip to content

Commit

Permalink
Merge pull request #36 from Juice-jl/sdd_mgr_refactoring
Browse files Browse the repository at this point in the history
Sdd mgr refactoring
  • Loading branch information
guyvdbroeck authored Jul 29, 2020
2 parents 9a357ed + 36a7d11 commit bc7a4b1
Show file tree
Hide file tree
Showing 16 changed files with 402 additions and 440 deletions.
4 changes: 2 additions & 2 deletions src/LogicCircuits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ include("structured/plain_vtrees.jl")
include("structured/structured_logic_nodes.jl")

include("sdd/sdds.jl")
include("sdd/trimmed_sdds.jl")
include("sdd/trimmed_apply.jl")
include("sdd/sdd_functions.jl")
include("sdd/apply.jl")

include("LoadSave/LoadSave.jl")
@reexport using .LoadSave
Expand Down
2 changes: 1 addition & 1 deletion src/Utils/trees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ abstract type Tree <: Dag end
import Base: parent

"Get the parent of a given tree node (or nothing if the node is root)"
function parent end
parent(n::Tree) = n.parent

"Does the node have a parent?"
@inline has_parent(n::Tree)::Bool = issomething(parent(n))
Expand Down
77 changes: 39 additions & 38 deletions src/sdd/trimmed_apply.jl → src/sdd/apply.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""
Conjoin two SDDs
"""
@inline conjoin(::SddFalseNode, ::SddTrueNode)::SddFalseNode = trimfalse
@inline conjoin(::SddTrueNode, ::SddFalseNode)::SddFalseNode = trimfalse
@inline conjoin(s::Sdd, ::SddTrueNode)::Sdd = s
@inline conjoin(::Sdd, ::SddFalseNode)::SddFalseNode = trimfalse
@inline conjoin(::SddTrueNode, s::Sdd)::Sdd = s
@inline conjoin(::SddFalseNode, ::Sdd)::SddFalseNode = trimfalse
@inline conjoin(::SddTrueNode, ::SddTrueNode)::Sdd = trimtrue
@inline conjoin(::SddFalseNode, ::SddFalseNode)::Sdd = trimfalse
@inline conjoin(::SddFalseNode, ::SddTrueNode) = false_sdd
@inline conjoin(::SddTrueNode, ::SddFalseNode) = false_sdd
@inline conjoin(s::Sdd, ::SddTrueNode) = s
@inline conjoin(::Sdd, ::SddFalseNode) = false_sdd
@inline conjoin(::SddTrueNode, s::Sdd) = s
@inline conjoin(::SddFalseNode, ::Sdd) = false_sdd
@inline conjoin(::SddTrueNode, ::SddTrueNode) = true_sdd
@inline conjoin(::SddFalseNode, ::SddFalseNode) = false_sdd

# const stats = Dict{Tuple{Int,Int},Int}()

function conjoin(s::SddLiteralNode, t::SddLiteralNode)::Sdd
if tmgr(s) === tmgr(t)
(s === t) ? s : trimfalse
if mgr(s) === mgr(t)
(s === t) ? s : false_sdd
else
conjoin_indep(s,t)
end
Expand All @@ -24,7 +24,7 @@ end
# Note: attempts to make a special cache for conjunctions with literals have not yielded speedups

function conjoin(s::Sdd, t::Sdd)::Sdd
if tmgr(s) === tmgr(t)
if mgr(s) === mgr(t)
conjoin_cartesian(t,s)
elseif varsubset(s,t)
conjoin_descendent(s,t)
Expand All @@ -42,15 +42,15 @@ function conjoin_cartesian(n1::Sdd⋁Node, n2::Sdd⋁Node)::Sdd
if n1 === n2
return n1
elseif n1 === !n2
return trimfalse
return false_sdd
end
get!(tmgr(n1).conjoin_cache, ApplyArgs(n1,n2)) do
get!(mgr(n1).conjoin_cache, ApplyArgs(n1,n2)) do
conjoin_cartesian_general(n1,n2)
end
end::Sdd
end


function conjoin_cartesian_general(n1::Sdd⋁Node, n2::Sdd⋁Node)::Sdd
function conjoin_cartesian_general(n1::Sdd⋁Node, n2::Sdd⋁Node)
# vast majority of cases are 2x2 and 2x3 applies, yet specializing for those cases does not appear to speed things up

out = XYPartition()
Expand All @@ -69,7 +69,7 @@ function conjoin_cartesian_general(n1::Sdd⋁Node, n2::Sdd⋁Node)::Sdd
conjoin_cartesian_cheap(out, elems1, elems2, maski, maskj)
conjoin_cartesian_expensive(out, elems1, elems2, maski, maskj)

canonicalize(out, tmgr(n1))
canonicalize(out, mgr(n1))::Sdd
end

function conjoin_cartesian_cheap(out, elems1, elems2, maski, maskj)
Expand Down Expand Up @@ -118,7 +118,7 @@ function conjoin_cartesian_expensive(out, elems1, elems2, maski, maskj)
if !maski[i] && !maskj[j]
e2 = elems2[j]
newprime = conjoin(prime(e1),prime(e2))
if newprime !== trimfalse
if newprime !== false_sdd
newsub = conjoin(sub(e1),sub(e2))
push!(out, Element(newprime, newsub))
end
Expand All @@ -141,24 +141,24 @@ end
"""
Conjoin two SDDs when one descends from the other
"""
function conjoin_descendent(d::Sdd, n::Sdd)::Sdd # specialize for Literals?
get!(tmgr(n).conjoin_cache, ApplyArgs(d,n)) do
function conjoin_descendent(d::Sdd, n::Sdd)
get!(mgr(n).conjoin_cache, ApplyArgs(d,n)) do
elems = children(n)
if varsubset_left(d, n)
out = XYPartition()
sizehint!(out, length(elems)+1)
i = findfirst(c -> prime(c) === d, elems)
if issomething(i)
# there is a prime equal to d, all other primes will conjoin to false
if sub(elems[i]) === trimfalse
return trimfalse
elseif sub(elems[i]) === trimtrue
if sub(elems[i]) === false_sdd
return false_sdd
elseif sub(elems[i]) === true_sdd
return d
else
push!(out, Element(d, sub(elems[i])))
push!(out, Element(!d, trimfalse))
push!(out, Element(!d, false_sdd))
# since d is not a constant, must be trimmed and compressed
return unique⋁(out, tmgr(n))
return unique⋁(out, mgr(n))
end
end
i = findfirst(c -> prime(c) === !d, elems)
Expand All @@ -170,41 +170,42 @@ function conjoin_descendent(d::Sdd, n::Sdd)::Sdd # specialize for Literals?
else
for e in elems
newprime = conjoin(prime(e),d)
if (newprime !== trimfalse)
if (newprime !== false_sdd)
push!(out, Element(newprime, sub(e)))
elseif newprime === d
# all future conjunctions will yield false
break
end
end
end
push!(out, Element(!d, trimfalse))
push!(out, Element(!d, false_sdd))
else
# @assert varsubset_right(d, n)
# TODO: build vector and compress all at once...
out = [Element(prime(e),conjoin(sub(e),d)) for e in elems]
end
canonicalize(out, tmgr(n))
canonicalize(out, mgr(n))
end
end

"""
Conjoin two SDDs in separate parts of the vtree
"""
function conjoin_indep(s::Sdd, t::Sdd)::Sdd⋁Node
function conjoin_indep(s::Sdd, t::Sdd)
# @assert GateType(s)!=ConstantGate() && GateType(t)!=ConstantGate()
mgr = lca(tmgr(s),tmgr(t))
get!(mgr.conjoin_cache, ApplyArgs(s,t)) do
if varsubset_left(tmgr(s), mgr)
@assert varsubset_right(tmgr(t), mgr)
elements = Element[Element(s,t),Element(!s,trimfalse)]
lca_mgr = lca(mgr(s),mgr(t))
get!(lca_mgr.conjoin_cache, ApplyArgs(s,t)) do
if varsubset_left(mgr(s), lca_mgr)
@assert varsubset_right(mgr(t), lca_mgr)
elements = Element[Element(s,t),Element(!s,false_sdd)]
else
@assert varsubset_left(tmgr(t), mgr)
@assert varsubset_right(tmgr(s), mgr)
elements = Element[Element(t,s),Element(!t,trimfalse)]
@assert varsubset_left(mgr(t), lca_mgr)
@assert varsubset_right(mgr(s), lca_mgr)
elements = Element[Element(t,s),Element(!t,false_sdd)]
end
# TODO: the XY partition must already be compressed and trimmed
unique⋁(elements, mgr)
end
unique⋁(elements, lca_mgr)
end::Sdd⋁Node
end

"""
Expand Down
208 changes: 208 additions & 0 deletions src/sdd/sdd_functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
export prime, sub, sdd_size, sdd_num_nodes, mgr,
compress, unique⋁, canonicalize, negate


#####################
# SDDmgr
#####################

@inline children(n::SddMgrInnerNode) = [n.left, n.right]

@inline variable(n::SddMgrLeafNode)::Var = n.var
@inline variables(n::SddMgrLeafNode)::BitSet = BitSet(n.var)
@inline variables(n::SddMgrInnerNode)::BitSet = n.variables

#####################
# Sdd
#####################

"Get the manager of a `Sdd` node, which is its `SddMgr` vtree"
mgr(s::Sdd) = s.vtree

@inline constant(::SddTrueNode) = true
@inline constant(::SddFalseNode) = false

@inline children(n::Sdd⋀Node) = [n.prime,n.sub]
@inline children(n::Sdd⋁Node) = n.children

@inline varsubset(n::Sdd, m::Sdd) = varsubset(mgr(n), mgr(m))
@inline varsubset_left(n::Sdd, m::Sdd)::Bool = varsubset_left(mgr(n), mgr(m))
@inline varsubset_right(n::Sdd, m::Sdd)::Bool = varsubset_right(mgr(n), mgr(m))


"Count the number of elements in the decision nodes of the SDD"
sdd_size(sdd::Sdd) = mapreduce(n -> num_children(n), +, ⋁_nodes(sdd); init=0) # defined as the number of `elements`; length(⋀_nodes(sdd)) also works but undercounts in case the compiler decides to cache elements

"Count the number of decision nodes in the SDD"
sdd_num_nodes(sdd::Sdd) = length(⋁_nodes(sdd)) # defined as the number of `decisions`

Base.show(io::IO, ::SddTrueNode) = print(io, "")
Base.show(io::IO, ::SddFalseNode) = print(io, "")
Base.show(io::IO, c::SddLiteralNode) = print(io, literal(c))
Base.show(io::IO, c::Sdd⋀Node) = begin
recshow(c::Union{SddConstantNode,SddLiteralNode}) = "$c"
recshow(c::Sdd⋁Node) = "D$(hash(c))"
print(io, "($(recshow(prime(c))),$(recshow(sub(c))))")
end
Base.show(io::IO, c::Sdd⋁Node) = begin
elems = ["$e" for e in children(c)]
print(io, "[$(join(elems,','))]")
end

#############
# Compilation
#############

compile(::Type{<:Sdd}, mgr::SddMgr, arg::Bool) = compile(mgr, arg)
compile(::Type{<:Sdd}, mgr::SddMgr, arg::Lit) = compile(mgr, arg)
compile(::Type{<:Sdd}, mgr::SddMgr, arg::LogicCircuit) = compile(mgr, arg)

"Compile a circuit (e.g., CNF or DNF) into an SDD, bottom up by distributing circuit nodes over vtree nodes"

compile(mgr::SddMgr, c::LogicCircuit, scopes=variables_by_node(c)) =
compile(mgr, c, GateType(c), scopes)
compile(mgr::SddMgr, c::LogicCircuit, ::ConstantGate, _) =
compile(mgr, constant(c))
compile(mgr::SddMgr, c::LogicCircuit, ::LiteralGate, _) =
compile(mgr, literal(c))
compile(mgr::SddMgr, c::LogicCircuit, gt::InnerGate, scopes) =
compile(mgr, NodeType(mgr), children(c), gt, scopes)
compile(mgr::SddMgr, children::Vector{<:LogicCircuit}, gt::InnerGate, scopes) =
compile(mgr, NodeType(mgr), children, gt, scopes)

function compile(mgr::SddMgr, ::Leaf, children::Vector{<:LogicCircuit}, gt::InnerGate, scopes)
isempty(children) && return compile(mgr, neutral(gt))
mapreduce(x -> compile(mgr,x,scopes), op(gt), children)
end

function compile(mgr::SddMgr, ::Inner, children::Vector{<:LogicCircuit}, gt::InnerGate, scopes)
isempty(children) && return compile(mgr, neutral(gt))

# partition children according to vtree
left_children = filter(x -> subseteq_fast(scopes[x], variables(mgr.left)), children)
right_children = filter(x -> subseteq_fast(scopes[x], variables(mgr.right)), children)
middle_children = setdiff(children, left_children, right_children)

# separately compile left and right vtree children
left = compile(mgr.left, left_children, gt, scopes)
right = compile(mgr.right, right_children, gt, scopes)

mapreduce(x -> compile(mgr,x,scopes), op(gt), middle_children; init=op(gt)(left, right))
end

"""
Compile a given variable, literal, or constant
"""
function compile(n::SddMgrLeafNode, l::Lit)::SddLiteralNode
@assert n.var == lit2var(l) "Cannot compile literal $l respecting vtree leaf for variable $(n.var)"
l>0 ? n.positive_literal::SddLiteralNode : n.negative_literal::SddLiteralNode
end

function compile(n::SddMgrInnerNode, l::Lit)
if lit2var(l) variables(n.left)
compile(n.left, l)::SddLiteralNode
else
@assert lit2var(l) variables(n.right) "$l is not contained in this vtree $n with scope $(variables(n))"
compile(n.right, l)::SddLiteralNode
end
end

compile(::SddMgr, constant::Bool) =
constant ? true_sdd : false_sdd

"""
Negate an SDD
"""
@inline negate(::SddFalseNode) = true_sdd
@inline negate(::SddTrueNode) = false_sdd

function negate(s::SddLiteralNode)
if ispositive(s)
mgr(s).negative_literal::SddLiteralNode
else
mgr(s).positive_literal::SddLiteralNode
end
end

negate(node::Sdd⋁Node) = node.negation::Sdd⋁Node

@inline Base.:!(s) = negate(s)

"""
Get the canonical compilation of the given XY Partition
"""
function canonicalize(xy::XYPartition, mgr::SddMgrInnerNode)
# @assert !isempty(xy)
return canonicalize_compressed(compress(xy), mgr)::Sdd
end

"""
Compress a given XY Partition (merge elements with identical subs)
"""
function compress(xy::XYPartition)
compressed = true
for i in eachindex(xy), j in i+1:length(xy)
if (sub(xy[i]) === sub(xy[j]))
compressed = false
break
end
end
compressed && return xy
# make it compressed
out = XYPartition()
sizehint!(out, length(xy))
mask = falses(length(xy))
for i in eachindex(xy)
if !mask[i]
prime_all = prime(xy[i])
sub_i = sub(xy[i])
for j in i+1:length(xy)
sub_j = sub(xy[j])
if !mask[j] && (sub_i === sub_j)
prime_all = prime_all | prime(xy[j])
mask[j] = true
end
end
push!(out,Element(prime_all,sub_i))
end
end
return out::XYPartition
end

"""
Get the canonical compilation of the given compressed XY Partition
"""
function canonicalize_compressed(xy::XYPartition, mgr::SddMgrInnerNode)
# @assert !isempty(xy)
# trim
if length(xy) == 1 && (prime(first(xy)) === true_sdd)
return sub(first(xy))
elseif length(xy) == 2
if (sub(xy[1]) === true_sdd) && (sub(xy[2]) === false_sdd)
return prime(xy[1])
elseif (sub(xy[2]) === true_sdd) && (sub(xy[1]) === false_sdd)
return prime(xy[2])
end
end
# get unique node representation
return unique⋁(xy, mgr)::Sdd⋁Node
end

"""
Construct a unique decision gate for the given vtree
"""
function unique⋁(xy::XYPartition, mgr::SddMgrInnerNode)
#TODO add finalization trigger to remove from the cache when the node is gc'ed + weak value reference
get!(mgr.unique⋁cache, xy) do
xynodes = [Sdd⋀Node(prime(e), sub(e), mgr) for e in xy]
node = Sdd⋁Node(xynodes, mgr)
# some memory allocations can be saved here, by not allocating the intermediate vector of elements
# however, that does not appear to speed things up...
not_xy = [Element(prime(e), !sub(e)) for e in xy]
not_xynodes = [Sdd⋀Node(prime(e), sub(e), mgr) for e in not_xy]
not_node = Sdd⋁Node(not_xynodes, mgr, node)
node.negation = not_node
mgr.unique⋁cache[not_xy] = not_node
node
end::Sdd⋁Node
end
Loading

0 comments on commit bc7a4b1

Please sign in to comment.