Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Structured smoothing #69

Merged
merged 11 commits into from
Mar 4, 2021
3 changes: 2 additions & 1 deletion src/LogicCircuits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ include("abstract_logic_nodes.jl")
include("bit_circuit.jl")
include("queries.jl")
include("satisfies_flow.jl")
include("transformations.jl")
include("plain_logic_nodes.jl")

include("structured/abstract_vtrees.jl")
include("structured/plain_vtrees.jl")
include("structured/structured_logic_nodes.jl")

include("transformations.jl")

include("sdd/sdds.jl")
include("sdd/sdd_functions.jl")
include("sdd/apply.jl")
Expand Down
1 change: 1 addition & 0 deletions src/Utils/trees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ lca(::Nothing, v::Tree, ::Function)::Tree = v
lca(v::Tree, ::Nothing, ::Function)::Tree = v
lca(::Nothing, ::Nothing, ::Function)::Nothing = nothing
lca(v::Tree, w::Tree, u::Tree, r::Tree...) = lca(lca(v,w), u, r...)
lca(v::Tree)::Tree = v

"Find the leaf in the tree by follwing the branching function"
find_leaf(n::Tree, branch::Function) = find_leaf(n, NodeType(n), branch)
Expand Down
10 changes: 7 additions & 3 deletions src/queries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ function implied_literals_rec(root::LogicCircuit, lcache::Dict{LogicCircuit, Uni
lcache[root] = BitSet([literal(root)])
elseif isinnergate(root)
for c in root.children
implied_literals_rec(c, lcache)
if !haskey(lcache, c)
implied_literals_rec(c, lcache)
end
end
if is⋀gate(root)
# If there's a false in here then this is false too
Expand Down Expand Up @@ -331,9 +333,11 @@ const Signature = Vector{Rational{BigInt}}
Get a signature for each node using probabilistic equivalence checking.
Note that this implentation may not have any formal guarantees as such.
"""
function prob_equiv_signature(circuit::LogicCircuit, k::Int)::Dict{Union{Var,Node},Signature}
function prob_equiv_signature(circuit::LogicCircuit, k::Int, signs=Dict{Union{Var,Node},Signature}())::Dict{Union{Var,Node},Signature}
# uses probability instead of integers to circumvent smoothing, no mod though
signs::Dict{Union{Var,Node},Signature} = Dict{Union{Var,Node},Signature}()
if signs === nothing
signs = Dict{Union{Var,Node},Signature}()
end
prime::Int = 7919 #TODO set as smallest prime larger than num_variables
randprob() = BigInt(1) .// rand(1:prime,k)
do_signs(v::Var) = get!(randprob, signs, v)
Expand Down
77 changes: 77 additions & 0 deletions src/transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,43 @@ export smooth, forget, propagate_constants, deepcopy, condition, replace_node,
split, clone, merge, split_candidates, random_split, split_step, struct_learn,
clone_candidates, standardize_circuit


"""
smooth(root::StructLogicCircuit)::StructLogicCircuit

Create an equivalent smooth circuit from the given circuit.
"""
function smooth(root::StructLogicCircuit)::StructLogicCircuit
(false_node, true_node) = canonical_constants(root)
@assert (false_node === nothing && true_node === nothing) "You should propagate constants before smoothing a structured circuit!"

lit_nodes = canonical_literals(root)
f_con(n) = (n, BitSet())
f_lit(n) = (n, BitSet(variable(n)))
f_a(n, call) = begin
(prime, pscope) = call(n.prime)
prime = fill_missing_vtree(prime, vtree(n).left, vtree(prime), lit_nodes)
pscope = variables(vtree(n).left)
(sub, sscope) = call(n.sub)
sub = fill_missing_vtree(sub, vtree(n).right, vtree(sub), lit_nodes)
sscope = variables(vtree(n).right)

parent_scope = pscope ∪ sscope
smoothed = conjoin([prime, sub])
@assert variables(vtree(smoothed)) == parent_scope
(smoothed, parent_scope)
end
f_o(n, call) = begin
parent_scope = mapreduce(c -> call(c)[2], union, children(n))
smooth_children = Vector{Node}(undef, num_children(n))
map!(smooth_children, children(n)) do child
(smooth_child, scope) = call(child)
smooth_node(smooth_child, parent_scope, scope, lit_nodes)
end
return (disjoin([smooth_children...]; reuse=n), parent_scope)
end
foldup(root, f_con, f_lit, f_a, f_o, Tuple{Node, BitSet})[1]
end
"""
smooth(root::Node)::Node

Expand Down Expand Up @@ -33,6 +70,46 @@ function smooth(root::Node)::Node
foldup(root, f_con, f_lit, f_a, f_o, Tuple{Node, BitSet})[1]
end


"""
smooth_node(node::StructLogicCircuit, parent_scope, scope, lit_nodes)

Return a smooth version of the node where
the are added to the scope by filling the gap in vtrees, using literals from `lit_nodes`
"""
function smooth_node(node::StructLogicCircuit, parent_scope, scope, lit_nodes)
# Compute the vtrees based on variables used
target_vtr = lca(map(x -> vtree(lit_nodes[x]), collect(parent_scope))...)
curr_vtr = lca(map(x -> vtree(lit_nodes[x]), collect(scope))...)
if curr_vtr == target_vtr
@assert isempty(setdiff(parent_scope, scope))
node # If the node is where it should be on the vtree, we're done
else
fill_missing_vtree(node, target_vtr, curr_vtr, lit_nodes)
end
end


"Construct a smoothed node from start_vtr, when you get to end_vtr insert the original node"
function fill_missing_vtree(node::StructLogicCircuit, start_vtr, end_vtr, lit_nodes)
# If we're at the end just return
if start_vtr == end_vtr
node
elseif isleaf(start_vtr)
# If we're at a leaf node, just return the disjunction of literals
get_lit(l,vtr) = get!(() -> compile(StructLogicCircuit, vtr, l), lit_nodes, l)
lit = var2lit(start_vtr.var)
lit_node = get_lit(lit, start_vtr)
not_lit_node = get_lit(-lit, start_vtr)
disjoin([lit_node, not_lit_node]; use_vtree=start_vtr)
else
# We're at an inner node, so recurse and conjoin the results?
left = fill_missing_vtree(node, start_vtr.left, end_vtr, lit_nodes)
right = fill_missing_vtree(node, start_vtr.right, end_vtr, lit_nodes)
conjoin(left, right, use_vtree=start_vtr)
end
end

"""
smooth_node(node::Node, missing_scope, lit_nodes)

Expand Down
17 changes: 17 additions & 0 deletions test/transformations_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ include("helper/plain_logic_circuits.jl")
end
end

@testset "Structured smooth test" begin
sdd = zoo_sdd("random.sdd")
vtr = zoo_vtree("random.vtree")
slc = smooth(sdd)
plc = propagate_constants(sdd, remove_unary=true)
structplc = compile(StructLogicCircuit, vtr, plc)
sstructplc = smooth(structplc)
Comment on lines +21 to +24
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to smooth it twice? Is it because propagage constants need smooth circuit?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The smoothed logic circuit is totally orthogonal, it's just for testing. I'm using probabilistic equivalence checking to make sure the smoothing didn't somehow change the meaning of the circuit (that you get the same thing regardless of how you smooth).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, makes sense, I see now, initially read line 22 wrong.


@test !issmooth(sdd)
@test issmooth(slc)
@test issmooth(sstructplc)

e1 = prob_equiv_signature(slc, 3)
e2 = prob_equiv_signature(sstructplc, 3, e1)
@test e1[slc] == e2[sstructplc]
end

@testset "Forget test" begin
for file in [zoo_sdd_file("random.sdd")] # save some test time;zoo_psdd_file("plants.psdd"),
c1 = load_logic_circuit(file)
Expand Down