From a19dd8b18e25e8024c71385c4cf515204eae211f Mon Sep 17 00:00:00 2001 From: Guy Van den Broeck Date: Sun, 28 Mar 2021 14:25:42 -0700 Subject: [PATCH 1/6] use `Dict` instead of `counter` for circuit traversal --- src/LogicCircuits.jl | 1 + src/Utils/graphs.jl | 191 ++++++++------------------------------- src/bit_circuit.jl | 8 +- src/queries/queries.jl | 72 +++++++-------- src/small_logic_nodes.jl | 95 +++++++++++++++++++ 5 files changed, 173 insertions(+), 194 deletions(-) create mode 100644 src/small_logic_nodes.jl diff --git a/src/LogicCircuits.jl b/src/LogicCircuits.jl index ff8c8c5a..ef48e1cd 100644 --- a/src/LogicCircuits.jl +++ b/src/LogicCircuits.jl @@ -15,6 +15,7 @@ include("queries/satisfies.jl") include("queries/satisfies_flow.jl") include("plain_logic_nodes.jl") +include("small_logic_nodes.jl") include("structured/abstract_vtrees.jl") include("structured/plain_vtrees.jl") diff --git a/src/Utils/graphs.jl b/src/Utils/graphs.jl index d845efa5..f86a1ccb 100644 --- a/src/Utils/graphs.jl +++ b/src/Utils/graphs.jl @@ -1,9 +1,6 @@ export Node, Dag, NodeType, Leaf, Inner, children, has_children, num_children, isleaf, isinner, - reset_counter, reset_counter_hard, foreach, foreach_reset, clear_data, filter, - nload, nsave, - foldup, foldup_aggregate, - count_parents, foreach_down, + foreach, foreach_down, filter, foldup, foldup_aggregate, num_nodes, num_edges, tree_num_nodes, tree_num_edges, in, inodes, innernodes, leafnodes, linearize, left_most_descendent, right_most_descendent, @@ -76,93 +73,45 @@ function children end # traversal ##################### -"Set the counter field throughout this graph (assumes no node in the graph already has the value)" -function reset_counter(node::Dag, v::Int=0) - if node.counter != v - node.counter = v - if isinner(node) - for c in children(node) - reset_counter(c, v) - end - end - end - nothing # returning nothing helps save some allocations and time -end - -"Set the counter field throughout this graph. Is slower but works even when the node fields are in an arbitrary state." -function reset_counter_hard(node::Dag, v::Int=0, seen::Set{Dag}=Set{Dag}()) - if node ∉ seen - push!(seen, node) - node.counter = v - if isinner(node) - for c in children(node) - reset_counter_hard(c, v, seen) - end - end - end - nothing -end - - import Base.foreach #extend "Apply a function to each node in a graph, bottom up" -function foreach(f::Function, node::Dag; reset=true) - @assert node.counter == 0 "Another algorithm is already traversing this circuit and using the `counter` field. You can use `reset_counter` to reset the counter. " - foreach_rec(f, node) - reset && reset_counter(node) - nothing # returning nothing helps save some allocations and time -end +foreach(f::Function, node::Dag, seen::Nothing=nothing) = + foreach_dict(f, node, Dict{Dag,Nothing}()) -function foreach(node::Dag, f_leaf::Function, f_inner::Function) - foreach(node) do n - isinner(n) ? f_inner(n) : f_leaf(n) - end - nothing # returning nothing helps save some allocations and time -end - -"Apply a function to each node in a graph, bottom up, without resetting the counter" -function foreach_rec(f::Function, node::Dag) - if (node.counter += 1) == 1 +function foreach_dict(f::Function, node::Dag, seen) + get!(seen, node) do if isinner(node) for c in children(node) - foreach_rec(f, c) + foreach_dict(f, c, seen) end end f(node) + nothing end - nothing # returning nothing helps save some allocations and time + nothing end -"Apply a function to each node in a graph, bottom up, while resetting the counter" -function foreach_reset(f::Function, node::Dag) - if node.counter != 0 - node.counter = 0 - if isinner(node) - for c in children(node) - foreach_reset(f, c) - end - end - f(node) +function foreach(node::Dag, f_leaf::Function, f_inner::Function, seen=nothing) + foreach(node, seen) do n + isinner(n) ? f_inner(n) : f_leaf(n) end - nothing # returning nothing helps save some allocations and time + nothing end -"Set all the data fields in a circuit to `nothing`" -function clear_data(node::Dag) - foreach(x -> x.data = nothing, node) +"Apply a function to each node in a graph, top down" +function foreach_down(f::Function, node::Dag) + # naive implementation + lin = linearize(node) + foreach(f, Iterators.reverse(lin)) end - -# TODO: consider adding a top-down version of foreach, by either linearizing into a List, -# or by keeping a visit counter to identify processing of the last parent. - import Base.filter #extend """Retrieve list of nodes in graph matching predicate `p`""" -function filter(p::Function, root::Dag, ::Type{T} = Union{})::Vector where T +function filter(p::Function, root::Dag, ::Type{T} = Union{}, seen=nothing)::Vector where T results = Vector{T}() - foreach(root) do n + foreach(root, seen) do n if p(n) if !(n isa eltype(results)) results = collect(typejoin(eltype(results), typeof(n)), results) @@ -173,48 +122,28 @@ function filter(p::Function, root::Dag, ::Type{T} = Union{})::Vector where T results end -"Default getter to obtain data associated with a node" -@inline nload(n) = n.data - -"Default setter to assign data associated with a node" -@inline nsave(n,v) = n.data = v - """ foldup(node::Dag, f_leaf::Function, f_inner::Function, - ::Type{T}; nload = nload, nsave = nsave, reset=true)::T where {T} + ::Type{T})::T where {T} Compute a function bottom-up on the graph. `f_leaf` is called on leaf nodes, and `f_inner` is called on inner nodes. Values of type `T` are passed up the circuit and given to `f_inner` as a function on the children. """ -function foldup(node::Dag, f_leaf::Function, f_inner::Function, - ::Type{T}; nload = nload, nsave = nsave, reset=true) where {T} - @assert node.counter == 0 "Another algorithm is already traversing this circuit and using the `counter` field. You can use `reset_counter` to reset the counter. " - v = foldup_rec(node, f_leaf, f_inner, T; nload, nsave)::T - reset && reset_counter(node) - v +function foldup(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}, cache::Nothing=nothing) where {T} + foldup(node, f_leaf, f_inner, T, Dict{Dag,T}()) end - -""" -Compute a function bottom-up on the graph, without resetting the counter. -`f_leaf` is called on leaf nodes, and `f_inner` is called on inner nodes. -Values of type `T` are passed up the circuit and given to `f_inner` as a function on the children. -""" -function foldup_rec(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}; - nload = nload, nsave = nsave) where {T} - if (node.counter += 1) != 1 - return nload(node)::T - else - v = if isinner(node) - callback(c) = (foldup_rec(c, f_leaf, f_inner, T; nload, nsave)::T) - f_inner(node, callback)::T - else - f_leaf(node)::T - end - return nsave(node, v)::T +function foldup(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}, seen::Dict) where {T} + get!(seen, node) do + if isinner(node) + callback(c) = foldup(c, f_leaf, f_inner, T, seen)::T + f_inner(node, callback)::T + else + f_leaf(node)::T + end end end @@ -224,65 +153,23 @@ Compute a function bottom-up on the circuit. Values of type `T` are passed up the circuit and given to `f_inner` in aggregate as a vector from the children. """ -# TODO: see whether we could standardize on `foldup` and remove this version? -function foldup_aggregate(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}; - nload = nload, nsave = nsave, reset=true) where {T} - @assert node.counter == 0 "Another algorithm is already traversing this circuit and using the `counter` field. You can use `reset_counter` to reset the counter. " - v = foldup_aggregate_rec(node, f_leaf, f_inner, T; nload, nsave)::T - reset && reset_counter(node) - return v +function foldup_aggregate(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}, cache::Nothing=nothing) where {T} + foldup_aggregate(node, f_leaf, f_inner, T, Dict{Dag,T}()) end -""" -Compute a function bottom-up on the circuit, without resetting the counter. -`f_leaf` is called on leaf nodes, and `f_inner` is called on inner nodes. -Values of type `T` are passed up the circuit and given to `f_inner` in aggregate -as a vector from the children. -""" -# TODO: see whether we could standardize on `foldup` and remove this version? -function foldup_aggregate_rec(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}; - nload = nload, nsave = nsave) where {T} - if (node.counter += 1) != 1 - return nload(node)::T - else - v = if isinner(node) +function foldup_aggregate(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}, seen::Dict) where {T} + get!(seen, node) do + if isinner(node) child_values = Vector{T}(undef, num_children(node)) - map!(c -> foldup_aggregate_rec(c, f_leaf, f_inner, T; nload, nsave)::T, - child_values, children(node)) + map!(c -> foldup_aggregate(c, f_leaf, f_inner, T, seen)::T, + child_values, children(node)) f_inner(node, child_values)::T else f_leaf(node)::T end - return nsave(node, v)::T - end -end - -"Set the `counter` field of each node to its number of parents" -function count_parents(node::Dag) - foreach(noop,node;reset=false) -end - -"Apply a function to each node in a graph, top down" -function foreach_down(f::Function, node::Dag; setcounter=true) - setcounter && count_parents(node) - foreach_down_rec(f, node) - nothing -end - -"Apply a function to each node in a graph, top down, without setting the counter first" -function foreach_down_rec(f::Function, n::Node) - if ((n.counter -= 1) == 0) - f(n) - if isinner(n) - for c in children(n) - foreach_down_rec(f, c) - end - end end - nothing end - ##################### # methods using circuit traversal ##################### @@ -307,7 +194,7 @@ Compute the number of nodes in of a tree-unfolding of the `Dag`. """ function tree_num_nodes(node::Dag)::BigInt @inline f_leaf(n) = one(BigInt) - @inline f_inner(n, call) = (1 + mapreduce(c -> call(c), +, children(n))) + @inline f_inner(n, call) = (1 + mapreduce(call, +, children(n))) foldup(node, f_leaf, f_inner, BigInt) end @@ -340,7 +227,6 @@ function Base.in(needle::Dag, circuit::Dag) contained end - "Get the list of inner nodes in a given graph" inodes(c::Dag) = filter(isinner, c) @@ -373,7 +259,6 @@ function right_most_descendent(root::Dag)::Dag root end - ##################### # debugging methods (not performance critical) ##################### diff --git a/src/bit_circuit.jl b/src/bit_circuit.jl index 3b5bc389..9a056e25 100644 --- a/src/bit_circuit.jl +++ b/src/bit_circuit.jl @@ -70,12 +70,12 @@ struct BitCircuit{V,M} parents::V end -function BitCircuit(circuit::LogicCircuit, data; reset=true, on_decision=noop) - BitCircuit(circuit, num_features(data); reset, on_decision) +function BitCircuit(circuit::LogicCircuit, data; on_decision=noop) + BitCircuit(circuit, num_features(data); on_decision) end "construct a new `BitCircuit` accomodating the given number of features" -function BitCircuit(circuit::LogicCircuit, num_features::Int; reset=true, on_decision=noop) +function BitCircuit(circuit::LogicCircuit, num_features::Int; on_decision=noop) #TODO: consider not using foldup_aggregate and instead calling twice to ensure order but save allocations #TODO add inbounds annotations @@ -158,7 +158,7 @@ function BitCircuit(circuit::LogicCircuit, num_features::Int; reset=true, on_dec ⋁NodeIds(layer_id, last_dec_id) end - r = foldup_aggregate(circuit, f_con, f_lit, f_and, f_or, NodeIds; reset) + r = foldup_aggregate(circuit, f_con, f_lit, f_and, f_or, NodeIds) to⋁NodeIds(r) nodes_m = reshape(nodes, 4, :) diff --git a/src/queries/queries.jl b/src/queries/queries.jl index a870a1a1..a8bb95d4 100644 --- a/src/queries/queries.jl +++ b/src/queries/queries.jl @@ -35,8 +35,7 @@ import ..Utils: foldup # extend f_con::Function, f_lit::Function, f_a::Function, - f_o::Function, - ::Type{T}; nload = nload, nsave = nsave, reset=true)::T where {T} + f_o::Function)::T where {T} Compute a function bottom-up on the circuit. `f_con` is called on constant gates, `f_lit` is called on literal gates, @@ -44,11 +43,10 @@ Compute a function bottom-up on the circuit. Values of type `T` are passed up the circuit and given to `f_a` and `f_o` through a callback from the children. """ function foldup(node::LogicCircuit, f_con::Function, f_lit::Function, - f_a::Function, f_o::Function, ::Type{T}; - nload = nload, nsave = nsave, reset=true)::T where {T} + f_a::Function, f_o::Function, ::Type{T})::T where {T} f_leaf(n) = isliteralgate(n) ? f_lit(n)::T : f_con(n)::T f_inner(n, call) = is⋀gate(n) ? f_a(n, call)::T : f_o(n, call) - foldup(node, f_leaf, f_inner, T; nload, nsave, reset)::T + foldup(node, f_leaf, f_inner, T)::T end import ..Utils: foldup_aggregate # extend @@ -59,7 +57,7 @@ import ..Utils: foldup_aggregate # extend f_lit::Function, f_a::Function, f_o::Function, - ::Type{T}; nload = nload, nsave = nsave, reset=true)::T where T + ::Type{T})::T where T Compute a function bottom-up on the circuit. `f_con` is called on constant gates, `f_lit` is called on literal gates, @@ -67,15 +65,14 @@ Compute a function bottom-up on the circuit. Values of type `T` are passed up the circuit and given to `f_a` and `f_o` in an aggregate vector from the children. """ function foldup_aggregate(node::LogicCircuit, f_con::Function, f_lit::Function, - f_a::Function, f_o::Function, ::Type{T}; - nload = nload, nsave = nsave, reset=true) where T + f_a::Function, f_o::Function, ::Type{T}) where T function f_leaf(n) isliteralgate(n) ? f_lit(n)::T : f_con(n)::T end function f_inner(n, cs) is⋀gate(n) ? f_a(n, cs)::T : f_o(n, cs)::T end - foldup_aggregate(node, f_leaf::Function, f_inner::Function, T; nload, nsave, reset)::T + foldup_aggregate(node, f_leaf::Function, f_inner::Function, T)::T end ##################### @@ -108,11 +105,11 @@ function variables_by_node(root::LogicCircuit)::Dict{LogicCircuit,BitSet} end "Get the variable in the circuit with the largest index" -function max_variable(root::LogicCircuit; reset=true)::Var +function max_variable(root::LogicCircuit)::Var f_con(n) = Var(0) f_lit(n) = variable(n) f_inner(n, call) = mapreduce(call, max, children(n)) - foldup(root, f_con, f_lit, f_inner, f_inner, Var; reset) + foldup(root, f_con, f_lit, f_inner, f_inner, Var) end """ @@ -212,6 +209,33 @@ function infer_vtree(root::LogicCircuit)::Union{Vtree, Nothing} end +##################### +# structural properties +##################### + +""" + iscanonical(circuit::LogicCircuit, k::Int; verbose = false) + +Does the given circuit have canonical Or gates, as determined by a probabilistic equivalence check? +""" +function iscanonical(circuit::LogicCircuit, k::Int; verbose = false) + signatures = prob_equiv_signature(circuit, k) + decision_nodes_by_signature = groupby(n -> signatures[n], ⋁_nodes(circuit)) + for (signature, nodes) in decision_nodes_by_signature + if length(nodes) > 1 + if verbose + println("Equivalent Nodes:") + for node in nodes + println(" - Node: $node MC: $(model_count(node))") + end + end + return false + end + end + return true +end + + """ isdeterministic(root::LogicCircuit)::Bool @@ -260,32 +284,6 @@ function implied_literals(root::LogicCircuit) end -##################### -# structural properties -##################### - -""" - iscanonical(circuit::LogicCircuit, k::Int; verbose = false) - -Does the given circuit have canonical Or gates, as determined by a probabilistic equivalence check? -""" -function iscanonical(circuit::LogicCircuit, k::Int; verbose = false) - signatures = prob_equiv_signature(circuit, k) - decision_nodes_by_signature = groupby(n -> signatures[n], ⋁_nodes(circuit)) - for (signature, nodes) in decision_nodes_by_signature - if length(nodes) > 1 - if verbose - println("Equivalent Nodes:") - for node in nodes - println(" - Node: $node MC: $(model_count(node))") - end - end - return false - end - end - return true -end - ##################### # algebraic model counting queries ##################### diff --git a/src/small_logic_nodes.jl b/src/small_logic_nodes.jl new file mode 100644 index 00000000..e936a875 --- /dev/null +++ b/src/small_logic_nodes.jl @@ -0,0 +1,95 @@ + + +""" +Root of the Small logic circuit node hierarchy +""" +abstract type SmallLogicCircuit <: LogicCircuit end + +""" +A Small logical leaf node +""" +abstract type SmallLogicLeafNode <: SmallLogicCircuit end + +""" +A Small logical inner node +""" +abstract type SmallLogicInnerNode <: SmallLogicCircuit end + +""" +A Small logical literal leaf node, representing the positive or negative literal of its variable +""" +struct SmallLiteralNode <: SmallLogicLeafNode + literal::Lit +end + +""" +A Small logical constant leaf node, representing true or false +""" +struct SmallConstantNode <: SmallLogicInnerNode + constant::Bool +end + +""" +A Small logical conjunction node (And node) +""" +mutable struct Small⋀Node <: SmallLogicInnerNode + children::Vector{SmallLogicCircuit} +end + +""" +A Small logical disjunction node (Or node) +""" +mutable struct Small⋁Node <: SmallLogicInnerNode + children::Vector{SmallLogicCircuit} +end + +##################### +# traits +##################### + +@inline GateType(::Type{<:SmallLiteralNode}) = LiteralGate() +@inline GateType(::Type{<:SmallConstantNode}) = ConstantGate() +@inline GateType(::Type{<:Small⋀Node}) = ⋀Gate() +@inline GateType(::Type{<:Small⋁Node}) = ⋁Gate() + +##################### +# methods +##################### + +"Get the logical constant in a given constant leaf node" +@inline constant(n::SmallConstantNode)::Bool = n.constant + +"Get the children of a given inner node" +@inline children(n::SmallLogicInnerNode) = n.children::Vector{SmallLogicCircuit} + +function conjoin(arguments::Vector{<:SmallLogicCircuit}; + reuse=nothing) + @assert length(arguments) > 0 + reuse isa Small⋀Node && children(reuse) == arguments && return reuse + return Small⋀Node(arguments) +end + + +function disjoin(arguments::Vector{<:SmallLogicCircuit}; + reuse=nothing) + @assert length(arguments) > 0 + reuse isa Small⋁Node && children(reuse) == arguments && return reuse + return Small⋁Node(arguments) +end + +negate(a::SmallLiteralNode) = compile(SmallLiteralNode, -a.literal) + +# claim `SmallLogicCircuit` as the default `LogicCircuit` implementation + +compile(::Type{<:SmallLogicCircuit}, b::Bool) = SmallConstantNode(b) + +compile(::Type{<:SmallLogicCircuit}, l::Lit) = + SmallLiteralNode(l) + +function compile(::Type{<:SmallLogicCircuit}, circuit::LogicCircuit) + f_con(n) = compile(SmallLogicCircuit, constant(n)) + f_lit(n) = compile(SmallLogicCircuit, literal(n)) + f_a(_, cns) = conjoin(cns) + f_o(_, cns) = disjoin(cns) + foldup_aggregate(circuit, f_con, f_lit, f_a, f_o, SmallLogicCircuit) +end From 0d7b1d3783404bb1b5f70966d1e46915c0e73310 Mon Sep 17 00:00:00 2001 From: Guy Van den Broeck Date: Sun, 28 Mar 2021 16:11:48 -0700 Subject: [PATCH 2/6] fix unit tests to no longer use data field --- src/bit_circuit.jl | 10 +- src/queries/queries.jl | 8 +- src/queries/satisfies_flow.jl | 54 +++++------ test/Utils/graphs_test.jl | 22 ----- test/{ => queries}/queries_test.jl | 13 +-- test/{ => queries}/satisfies_flow_test.jl | 106 +++++++++++----------- 6 files changed, 98 insertions(+), 115 deletions(-) rename test/{ => queries}/queries_test.jl (93%) rename test/{ => queries}/satisfies_flow_test.jl (74%) diff --git a/src/bit_circuit.jl b/src/bit_circuit.jl index 9a056e25..e18f8839 100644 --- a/src/bit_circuit.jl +++ b/src/bit_circuit.jl @@ -68,6 +68,7 @@ struct BitCircuit{V,M} nodes::M elements::M parents::V + node2id end function BitCircuit(circuit::LogicCircuit, data; on_decision=noop) @@ -158,7 +159,8 @@ function BitCircuit(circuit::LogicCircuit, num_features::Int; on_decision=noop) ⋁NodeIds(layer_id, last_dec_id) end - r = foldup_aggregate(circuit, f_con, f_lit, f_and, f_or, NodeIds) + node2id = Dict{LogicCircuit,NodeIds}() + r = foldup_aggregate(circuit, f_con, f_lit, f_and, f_or, NodeIds, node2id) to⋁NodeIds(r) nodes_m = reshape(nodes, 4, :) @@ -178,7 +180,7 @@ function BitCircuit(circuit::LogicCircuit, num_features::Int; on_decision=noop) end end - return BitCircuit(layers, nodes_m, elements_m, parents_m) + return BitCircuit(layers, nodes_m, elements_m, parents_m, node2id) end import .Utils: num_nodes # extend @@ -234,10 +236,10 @@ end import .Utils: to_gpu, to_cpu, isgpu #extend to_gpu(c::BitCircuit) = - BitCircuit(map(to_gpu, c.layers), to_gpu(c.nodes), to_gpu(c.elements), to_gpu(c.parents)) + BitCircuit(map(to_gpu, c.layers), to_gpu(c.nodes), to_gpu(c.elements), to_gpu(c.parents), c.node2id) to_cpu(c::BitCircuit) = - BitCircuit(map(to_cpu, c.layers), to_cpu(c.nodes), to_cpu(c.elements), to_cpu(c.parents)) + BitCircuit(map(to_cpu, c.layers), to_cpu(c.nodes), to_cpu(c.elements), to_cpu(c.parents), c.node2id) isgpu(c::BitCircuit{<:CuArray,<:CuArray}) = true isgpu(c::BitCircuit{<:Array,<:Array}) = false diff --git a/src/queries/queries.jl b/src/queries/queries.jl index a8bb95d4..956ee7ad 100644 --- a/src/queries/queries.jl +++ b/src/queries/queries.jl @@ -65,14 +65,14 @@ Compute a function bottom-up on the circuit. Values of type `T` are passed up the circuit and given to `f_a` and `f_o` in an aggregate vector from the children. """ function foldup_aggregate(node::LogicCircuit, f_con::Function, f_lit::Function, - f_a::Function, f_o::Function, ::Type{T}) where T + f_a::Function, f_o::Function, ::Type{T}, cache=nothing) where T function f_leaf(n) isliteralgate(n) ? f_lit(n)::T : f_con(n)::T end function f_inner(n, cs) is⋀gate(n) ? f_a(n, cs)::T : f_o(n, cs)::T end - foldup_aggregate(node, f_leaf::Function, f_inner::Function, T)::T + foldup_aggregate(node, f_leaf::Function, f_inner::Function, T, cache)::T end ##################### @@ -269,7 +269,7 @@ Compute at each node literals that are implied by the formula. This algorithm is sound but not complete - all literals returned are correct, but some true implied literals may be missing. """ -function implied_literals(root::LogicCircuit) +function implied_literals(root::LogicCircuit, cache=nothing) f_con(c) = constant(c) ? BitSet() : nothing f_lit(n) = BitSet([literal(n)]) f_a(_, cs) = if any(isnothing, cs) @@ -280,7 +280,7 @@ function implied_literals(root::LogicCircuit) f_o(_, cs) = begin reduce(intersect, filter(issomething, cs)) end - foldup_aggregate(root, f_con, f_lit, f_a, f_o, Union{BitSet, Nothing}) + foldup_aggregate(root, f_con, f_lit, f_a, f_o, Union{BitSet, Nothing}, cache) end diff --git a/src/queries/satisfies_flow.jl b/src/queries/satisfies_flow.jl index 5b02c296..fe2112a0 100644 --- a/src/queries/satisfies_flow.jl +++ b/src/queries/satisfies_flow.jl @@ -2,14 +2,13 @@ using CUDA: CUDA, @cuda using DataFrames: DataFrame using LoopVectorization: @avx -export model_var_prob, satisfies_flows_down, satisfies_flows, -count_downflow, downflow_all +export model_var_prob, satisfies_flows_down, satisfies_flows, count_downflow, downflow_all "Compute the probability of each variable for a random satisfying assignment of the logical circuit" function model_var_prob(root::LogicCircuit) nvars = num_variables(root) - v, f = satisfies_flows(root, DataFrame(fill(0.5, 1, nvars))) + v, f, _ = satisfies_flows(root, DataFrame(fill(0.5, 1, nvars))) f[3:2+nvars]./v[end] end @@ -22,7 +21,8 @@ function satisfies_flows(circuit::LogicCircuit, data, reuse_values=nothing, reuse_flows=nothing; on_node=noop, on_edge=noop, weights=nothing) bc = same_device(BitCircuit(circuit, data), data) - satisfies_flows(bc, data, reuse_values, reuse_flows; on_node, on_edge, weights) + v, f = satisfies_flows(bc, data, reuse_values, reuse_flows; on_node, on_edge, weights) + v,f, bc.node2id end function satisfies_flows(circuit::BitCircuit, data, @@ -290,30 +290,32 @@ end # API to get flows -function count_downflow(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::LogicCircuit) - dec_id = n.data.node_id +#TODO: these functions need documentation! + +function count_downflow(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::LogicCircuit, node2id) + dec_id = node2id[n].node_id sum(1:size(flows,1)) do i count_ones(flows[i, dec_id]) end end -function downflow_all(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::LogicCircuit) - dec_id = n.data.node_id +function downflow_all(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::LogicCircuit, node2id) + dec_id = node2id[n].node_id indices = map(1:size(flows,1)) do i digits(Bool, flows[i, dec_id], base=2, pad=64) end BitArray(vcat(indices...)[1:N]) end -function count_downflow(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::LogicCircuit, c::LogicCircuit) - grandpa = n.data.node_id +function count_downflow(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::LogicCircuit, c::LogicCircuit, node2id) + grandpa = node2id[n].node_id if isleafgate(c) - par = c.data.node_id + par = node2id[c].node_id return sum(1:size(flows,1)) do i count_ones(values[i, par] & flows[i, grandpa]) end else - ids = [x.data.node_id for x in children(c)] + ids = [node2id[x].node_id for x in children(c)] return sum(1:size(flows,1)) do i indices = flows[i, grandpa] for id in ids @@ -324,15 +326,15 @@ function count_downflow(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::Log end end -function downflow_all(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::LogicCircuit, c::LogicCircuit) - grandpa = n.data.node_id +function downflow_all(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::LogicCircuit, c::LogicCircuit, node2id) + grandpa = node2id[n].node_id if isleafgate(c) - par = c.data.node_id + par = node2id[c].node_id edge = map(1:size(flows,1)) do i digits(Bool, values[i, par] & flows[i, grandpa], base=2, pad=64) end else - ids = [x.data.node_id for x in children(c)] + ids = [node2id[x].node_id for x in children(c)] edge = map(1:size(flows,1)) do i indices = flows[i, grandpa] for id in ids @@ -344,31 +346,31 @@ function downflow_all(values::Matrix{UInt64}, flows::Matrix{UInt64}, N, n::Logic BitArray(vcat(edge...)[1:N]) end -function count_downflow(values::Matrix{<:AbstractFloat}, flows::Matrix{<:AbstractFloat}, N, n::LogicCircuit) - return sum(downflow_all(values, flows, N, n)) +function count_downflow(values::Matrix{<:AbstractFloat}, flows::Matrix{<:AbstractFloat}, N, n::LogicCircuit, node2id) + return sum(downflow_all(values, flows, N, n, node2id)) end -function downflow_all(values::Matrix{<:AbstractFloat}, flows::Matrix{<:AbstractFloat}, N, n::LogicCircuit) - dec_id = n.data.node_id +function downflow_all(values::Matrix{<:AbstractFloat}, flows::Matrix{<:AbstractFloat}, N, n::LogicCircuit, node2id) + dec_id = node2id[n].node_id map(1:size(flows, 1)) do i flows[i, dec_id] end end -function count_downflow(values::Matrix{<:AbstractFloat}, flows::Matrix{<:AbstractFloat}, N, n::LogicCircuit, c::LogicCircuit) - sum(downflow_all(values, flows, N, n, c)) +function count_downflow(values::Matrix{<:AbstractFloat}, flows::Matrix{<:AbstractFloat}, N, n::LogicCircuit, c::LogicCircuit, node2id) + sum(downflow_all(values, flows, N, n, c, node2id)) end -function downflow_all(values::Matrix{<:AbstractFloat}, flows::Matrix{<:AbstractFloat}, N, n::LogicCircuit, c::LogicCircuit) - grandpa = n.data.node_id +function downflow_all(values::Matrix{<:AbstractFloat}, flows::Matrix{<:AbstractFloat}, N, n::LogicCircuit, c::LogicCircuit, node2id) + grandpa = node2id[n].node_id if isleafgate(c) - par = c.data.node_id + par = node2id[c].node_id edge_flows = map(1:size(flows,1)) do i values[i, par] * flows[i, grandpa] / values[i, grandpa] end return edge_flows else - ids = [x.data.node_id for x in children(c)] + ids = [node2id[x].node_id for x in children(c)] edge_flows = map(1:size(flows,1)) do i n_down = flows[i, grandpa] n_up = values[i, grandpa] diff --git a/test/Utils/graphs_test.jl b/test/Utils/graphs_test.jl index f303a38c..ef98a7be 100644 --- a/test/Utils/graphs_test.jl +++ b/test/Utils/graphs_test.jl @@ -54,28 +54,6 @@ module TestNodes @test has_children(r) @test num_children(r) == 3 - reset_counter(r,5) - @test r.counter == 5 - @test l1.counter == 5 - @test i12.counter == 5 - - reset_counter(r) - @test r.counter == 0 - @test l1.counter == 0 - @test i12.counter == 0 - - i12.counter = 5 # break counters - reset_counter_hard(r,5) - @test r.counter == 5 - @test l1.counter == 5 - @test i12.counter == 5 - - i12.counter = 42 # break counters - reset_counter_hard(r) - @test r.counter == 0 - @test l1.counter == 0 - @test i12.counter == 0 - foreach(r) do n n.id += 1 end diff --git a/test/queries_test.jl b/test/queries/queries_test.jl similarity index 93% rename from test/queries_test.jl rename to test/queries/queries_test.jl index 312bb183..70d5745f 100644 --- a/test/queries_test.jl +++ b/test/queries/queries_test.jl @@ -2,7 +2,7 @@ using Test using Suppressor using LogicCircuits -include("helper/plain_logic_circuits.jl") +include("../helper/plain_logic_circuits.jl") @testset "Queries test" begin @@ -131,17 +131,18 @@ end plc = PlainLogicCircuit(circuit) # store implied literals in each data field - implied_literals(plc) + data = Dict() + implied_literals(plc, data) for ornode in or_nodes(plc) @test num_children(ornode) == 2 # If there's a nothing just continue, it'll always work - if ornode.children[1].data === nothing || - ornode.children[2].data === nothing + if data[ornode.children[1]] === nothing || + data[ornode.children[2]] === nothing continue end - implied1 = ornode.children[1].data - implied2 = ornode.children[2].data + implied1 = data[ornode.children[1]] + implied2 = data[ornode.children[2]] neg_implied2 = BitSet(map(x -> -x, collect(implied2))) @test !isempty(intersect(implied1, neg_implied2)) end diff --git a/test/satisfies_flow_test.jl b/test/queries/satisfies_flow_test.jl similarity index 74% rename from test/satisfies_flow_test.jl rename to test/queries/satisfies_flow_test.jl index b52090f5..8c8096a3 100644 --- a/test/satisfies_flow_test.jl +++ b/test/queries/satisfies_flow_test.jl @@ -3,7 +3,7 @@ using LogicCircuits using Random: bitrand, rand using DataFrames: DataFrame -include("helper/gpu.jl") +include("../helper/gpu.jl") @testset "Binary flows test" begin @@ -20,9 +20,9 @@ include("helper/gpu.jl") r = fully_factorized_circuit(StructLogicCircuit, vtree) @test r(input) == BitVector([1,1,1,1]) - v, f = satisfies_flows(r, input) + v, f, node2id= satisfies_flows(r, input) foreach(literal_nodes(r)) do n - id = n.data.node_id + id = node2id[n].node_id @test v[:,id] == f[:,id] # invariant of logically valid circuits end @@ -45,9 +45,9 @@ include("helper/gpu.jl") r = smooth(PlainLogicCircuit(c)) # flows don't make sense unless the circuit is smooth; cannot smooth trimmed SDDs - v, f = satisfies_flows(r, input) + v, f, node2id = satisfies_flows(r, input) foreach(literal_nodes(r)) do n - id = n.data.node_id + id = node2id[n].node_id @test v[:,id] .& v[:,end] == f[:,id] # invariant of all circuits end @@ -66,9 +66,9 @@ include("helper/gpu.jl") r = (l1 & l2) | (l3 & l4) input = DataFrame(bitrand(4,2)) - v, f = satisfies_flows(r, input) + v, f, node2id = satisfies_flows(r, input) foreach(literal_nodes(r)) do n - id = n.data.node_id + id = node2id[n].node_id @test v[:,id] .& v[:,end] == f[:,id] # invariant of all circuits end @@ -101,9 +101,9 @@ end r = fully_factorized_circuit(StructLogicCircuit, vtree) @test r(input) == BitVector([1,1,1,1]) - v, f = satisfies_flows(r, input; weights = weights) + v, f, node2id = satisfies_flows(r, input; weights = weights) foreach(literal_nodes(r)) do n - id = n.data.node_id + id = node2id[n].node_id @test v[:,id] == f[:,id] # invariant of logically valid circuits end @@ -126,9 +126,9 @@ end r = smooth(PlainLogicCircuit(c)) # flows don't make sense unless the circuit is smooth; cannot smooth trimmed SDDs - v, f = satisfies_flows(r, input; weights = weights) + v, f, node2id = satisfies_flows(r, input; weights = weights) foreach(literal_nodes(r)) do n - id = n.data.node_id + id = node2id[n].node_id @test v[:,id] .& v[:,end] == f[:,id] # invariant of all circuits end @@ -145,9 +145,9 @@ end r = (l1 & l2) | (l3 & l4) input = DataFrame(bitrand(4,2)) - v, f = satisfies_flows(r, input; weights = weights) + v, f, node2id = satisfies_flows(r, input; weights = weights) foreach(literal_nodes(r)) do n - id = n.data.node_id + id = node2id[n].node_id @test v[:,id] .& v[:,end] == f[:,id] # invariant of all circuits end @@ -176,17 +176,17 @@ end df = DataFrame(BitMatrix([true true; false true; false false])) sdf = soften(df, 0.001; scale_by_marginal = false) - v, f = satisfies_flows(r, sdf) - @test all(v[:,o_c.data.node_id] .≈ sdf[:, 2]) - @test all(v[:,o_c.data.node_id] .≈ f[:,o_c.data.node_id]) - @test all(v[:,o_d.data.node_id] .≈ sdf[:, 1]) - @test all(v[:,o_d.data.node_id] .≈ f[:,o_d.data.node_id]) + v, f, node2id = satisfies_flows(r, sdf) + @test all(v[:,node2id[o_c].node_id] .≈ sdf[:, 2]) + @test all(v[:,node2id[o_c].node_id] .≈ f[:,node2id[o_c].node_id]) + @test all(v[:,node2id[o_d].node_id] .≈ sdf[:, 1]) + @test all(v[:,node2id[o_d].node_id] .≈ f[:,node2id[o_d].node_id]) - v, f = satisfies_flows(r, sdf; weights = weights) - @test all(v[:,o_c.data.node_id] .≈ sdf[:, 2]) - @test all(v[:,o_c.data.node_id] .≈ f[:,o_c.data.node_id]) - @test all(v[:,o_d.data.node_id] .≈ sdf[:, 1]) - @test all(v[:,o_d.data.node_id] .≈ f[:,o_d.data.node_id]) + v, f, node2id = satisfies_flows(r, sdf; weights = weights) + @test all(v[:,node2id[o_c].node_id] .≈ sdf[:, 2]) + @test all(v[:,node2id[o_c].node_id] .≈ f[:,node2id[o_c].node_id]) + @test all(v[:,node2id[o_d].node_id] .≈ sdf[:, 1]) + @test all(v[:,node2id[o_d].node_id] .≈ f[:,node2id[o_d].node_id]) if CUDA.functional() @test all(satisfies_flows(r, sdf; weights = weights)[1] .≈ to_cpu(satisfies_flows(r, to_gpu(sdf); weights = to_gpu(weights))[1])) @@ -205,9 +205,9 @@ end @test r(input) ≈ [1.0, 1.0, 1.0, 1.0] - v, f = satisfies_flows(r, input) + v, f, node2id = satisfies_flows(r, input) foreach(literal_nodes(r)) do n - id = n.data.node_id + id = node2id[n].node_id @test v[:,id] ≈ f[:,id] # invariant of logically valid circuits end @@ -234,17 +234,17 @@ end @test all(r(input) .≈ 1.0) - v, f = satisfies_flows(r, input) + v, f, node2id = satisfies_flows(r, input) @test all(v[:,end] .≈ 1.0) @test all(f[:,end].≈ 1.0) - @test all(v[:,o_c.data.node_id] .≈ input[:, 3]) + @test all(v[:,node2id[o_c].node_id] .≈ input[:, 3]) - @test all(f[:,o_c.data.node_id] .≈ input[:, 3]) - @test all(v[:,o_nc.data.node_id] .+ input[:, 3] .≈ 1.0) - @test all(f[:,o_nc.data.node_id] .+ input[:, 3] .≈ 1.0) + @test all(f[:,node2id[o_c].node_id] .≈ input[:, 3]) + @test all(v[:,node2id[o_nc].node_id] .+ input[:, 3] .≈ 1.0) + @test all(f[:,node2id[o_nc].node_id] .+ input[:, 3] .≈ 1.0) foreach(literal_nodes(r)) do n - id = n.data.node_id + id = node2id[n].node_id @test v[:,id] ≈ f[:,id] # invariant of logically valid circuits end @@ -263,17 +263,17 @@ end input = DataFrame(rand(Float64, (10,3))) input[1:5, 3] .= 0.0 - v, f = satisfies_flows(r, input) - @test all(v[:,o_c.data.node_id] .≈ input[:, 2] .* input[:, 3]) - @test all(f[:,o_c.data.node_id] .≈ v[:,o_c.data.node_id]) - @test all(v[:,o_nc.data.node_id] .≈ (1.0 .- input[:, 2]) .* (1.0 .- input[:, 3])) - @test all(f[:,o_nc.data.node_id] .≈ v[:,o_nc.data.node_id]) - @test all(f[:,l_a.data.node_id] .≈ v[:,l_a.data.node_id] .* f[:,r.data.node_id]) - @test all(f[:,l_na.data.node_id] .≈ v[:,l_na.data.node_id] .* f[:,r.data.node_id]) - @test all(f[:,l_b.data.node_id] .≈ input[:, 2] .* input[:, 3]) - @test all(f[:,l_b.data.node_id] .≈ f[:,l_c.data.node_id]) - @test all(f[:,l_nb.data.node_id] .≈ (1.0 .- input[:, 2]) .* (1.0 .- input[:, 3])) - @test all(f[:,l_nb.data.node_id] .≈ f[:,l_nc.data.node_id]) + v, f, node2id = satisfies_flows(r, input) + @test all(v[:,node2id[o_c].node_id] .≈ input[:, 2] .* input[:, 3]) + @test all(f[:,node2id[o_c].node_id] .≈ v[:,node2id[o_c].node_id]) + @test all(v[:,node2id[o_nc].node_id] .≈ (1.0 .- input[:, 2]) .* (1.0 .- input[:, 3])) + @test all(f[:,node2id[o_nc].node_id] .≈ v[:,node2id[o_nc].node_id]) + @test all(f[:,node2id[l_a].node_id] .≈ v[:,node2id[l_a].node_id] .* f[:,node2id[r].node_id]) + @test all(f[:,node2id[l_na].node_id] .≈ v[:,node2id[l_na].node_id] .* f[:,node2id[r].node_id]) + @test all(f[:,node2id[l_b].node_id] .≈ input[:, 2] .* input[:, 3]) + @test all(f[:,node2id[l_b].node_id] .≈ f[:,node2id[l_c].node_id]) + @test all(f[:,node2id[l_nb].node_id] .≈ (1.0 .- input[:, 2]) .* (1.0 .- input[:, 3])) + @test all(f[:,node2id[l_nb].node_id] .≈ f[:,node2id[l_nc].node_id]) cpu_gpu_agree_approx(input) do d satisfies_flows(r, d)[1] # same value @@ -294,18 +294,18 @@ end 0 1 1 0 1 0 0 1 0 1])) input_f = DataFrame(Float64.(Matrix(input_b))) - - f_b, v_b = satisfies_flows(r, input_b) - f_f, v_f = satisfies_flows(r, input_f) + + f_b, v_b, node2id = satisfies_flows(r, input_b) + f_f, v_f, node2id = satisfies_flows(r, input_f) N = num_examples(input_b) foreach(r) do n if is⋁gate(n) - @test all(Float64.(downflow_all(f_b, v_b, N, n)) .≈ downflow_all(f_f, v_f, N, n)) - @test Float64.(count_downflow(f_b, v_b, N, n)) ≈ count_downflow(f_f, v_f, N, n) + @test all(Float64.(downflow_all(f_b, v_b, N, n, node2id)) .≈ downflow_all(f_f, v_f, N, n, node2id)) + @test Float64.(count_downflow(f_b, v_b, N, n, node2id)) ≈ count_downflow(f_f, v_f, N, n, node2id) for c in children(n) - @test all(Float64.(downflow_all(f_b, v_b, N, n, c)) .≈ downflow_all(f_f, v_f, N, n, c)) - @test Float64.(count_downflow(f_b, v_b, N, n, c)) ≈ count_downflow(f_f, v_f, N, n, c) + @test all(Float64.(downflow_all(f_b, v_b, N, n, c, node2id)) .≈ downflow_all(f_f, v_f, N, n, c, node2id)) + @test Float64.(count_downflow(f_b, v_b, N, n, c, node2id)) ≈ count_downflow(f_f, v_f, N, n, c, node2id) end end end @@ -335,15 +335,15 @@ end N = num_examples(input_b) foreach(r) do n if is⋁gate(n) - df = [count_downflow(v, f, N, n) for (v, f) in vfs] + df = [count_downflow(v, f, N, n, node2id) for (v, f) in vfs] @test 0.3 * df[1] + 0.1 * df[2] + 0.6 * df[3] ≈ df[4] - df_all = [downflow_all(v, f, N, n) for (v, f) in vfs] + df_all = [downflow_all(v, f, N, n, node2id) for (v, f) in vfs] @test all(0.3 * Float64.(df_all[1]) + 0.1 * Float64.(df_all[2]) + 0.6 * Float64.(df_all[3]) .≈ df_all[4]) for c in children(n) - df = [count_downflow(v, f, N, n, c) for (v, f) in vfs] + df = [count_downflow(v, f, N, n, c, node2id) for (v, f) in vfs] @test 0.3 * df[1] + 0.1 * df[2] + 0.6 * df[3] ≈ df[4] - df_all = [downflow_all(v, f, N, n, c) for (v, f) in vfs] + df_all = [downflow_all(v, f, N, n, c, node2id) for (v, f) in vfs] @test all(0.3 * Float64.(df_all[1]) + 0.1 * Float64.(df_all[2]) + 0.6 * Float64.(df_all[3]) .≈ df_all[4]) end end From 8219120f90291cc114164f08a5a5d9f8f95fc0cf Mon Sep 17 00:00:00 2001 From: Guy Van den Broeck Date: Sun, 28 Mar 2021 18:40:45 -0700 Subject: [PATCH 3/6] remove `data` and `counter` from Plain circuit nodes --- src/LoadSave/circuit_line_compiler.jl | 4 +- src/LogicCircuits.jl | 1 - src/plain_logic_nodes.jl | 40 ++-------- src/small_logic_nodes.jl | 95 ------------------------ src/structured/structured_logic_nodes.jl | 34 +++------ test/Utils/graphs_test.jl | 3 +- test/plain_logic_nodes_test.jl | 6 +- 7 files changed, 21 insertions(+), 162 deletions(-) delete mode 100644 src/small_logic_nodes.jl diff --git a/src/LoadSave/circuit_line_compiler.jl b/src/LoadSave/circuit_line_compiler.jl index e92bd913..b563df54 100644 --- a/src/LoadSave/circuit_line_compiler.jl +++ b/src/LoadSave/circuit_line_compiler.jl @@ -24,8 +24,8 @@ function compile_logical_m(lines::CircuitFormatLines) PlainLiteralNode(l) end - true_node = PlainTrueNode() - false_node = PlainFalseNode() + true_node = PlainConstantNode(true) + false_node = PlainConstantNode(false) function compile(ln::CircuitFormatLine) error("Compilation of line $ln is not supported") diff --git a/src/LogicCircuits.jl b/src/LogicCircuits.jl index ef48e1cd..ff8c8c5a 100644 --- a/src/LogicCircuits.jl +++ b/src/LogicCircuits.jl @@ -15,7 +15,6 @@ include("queries/satisfies.jl") include("queries/satisfies_flow.jl") include("plain_logic_nodes.jl") -include("small_logic_nodes.jl") include("structured/abstract_vtrees.jl") include("structured/plain_vtrees.jl") diff --git a/src/plain_logic_nodes.jl b/src/plain_logic_nodes.jl index aa9dce81..5df1fd7e 100644 --- a/src/plain_logic_nodes.jl +++ b/src/plain_logic_nodes.jl @@ -1,6 +1,5 @@ export PlainLogicCircuit, PlainLogicLeafNode, PlainLogicInnerNode, - PlainLiteralNode, PlainConstantNode, PlainTrueNode, PlainFalseNode, - Plain⋀Node, Plain⋁Node + PlainLiteralNode, PlainConstantNode, Plain⋀Node, Plain⋁Node ##################### # Plain logic nodes without additional fields @@ -24,34 +23,15 @@ abstract type PlainLogicInnerNode <: PlainLogicCircuit end """ A plain logical literal leaf node, representing the positive or negative literal of its variable """ -mutable struct PlainLiteralNode <: PlainLogicLeafNode +struct PlainLiteralNode <: PlainLogicLeafNode literal::Lit - data - counter::UInt32 - PlainLiteralNode(l) = new(l, nothing, 0) end """ A plain logical constant leaf node, representing true or false """ -abstract type PlainConstantNode <: PlainLogicInnerNode end - -""" -Plain constant true node -""" -mutable struct PlainTrueNode <: PlainConstantNode - data - counter::UInt32 - PlainTrueNode() = new(nothing, 0) -end - -""" -Plain constant false node -""" -mutable struct PlainFalseNode <: PlainConstantNode - data - counter::UInt32 - PlainFalseNode() = new(nothing, 0) +struct PlainConstantNode <: PlainLogicInnerNode + constant::Bool end """ @@ -59,9 +39,6 @@ A plain logical conjunction node (And node) """ mutable struct Plain⋀Node <: PlainLogicInnerNode children::Vector{PlainLogicCircuit} - data - counter::UInt32 - Plain⋀Node(c) = new(c, nothing, 0) end """ @@ -69,9 +46,6 @@ A plain logical disjunction node (Or node) """ mutable struct Plain⋁Node <: PlainLogicInnerNode children::Vector{PlainLogicCircuit} - data - counter::UInt32 - Plain⋁Node(c) = new(c, nothing, 0) end ##################### @@ -88,8 +62,7 @@ end ##################### "Get the logical constant in a given constant leaf node" -@inline constant(n::PlainTrueNode)::Bool = true -@inline constant(n::PlainFalseNode)::Bool = false +@inline constant(n::PlainConstantNode) = n.constant::Bool "Get the children of a given inner node" @inline children(n::PlainLogicInnerNode) = n.children @@ -116,7 +89,7 @@ compile(::Type{LogicCircuit}, args...) = compile(PlainLogicCircuit, args...) compile(::Type{<:PlainLogicCircuit}, b::Bool) = - b ? PlainTrueNode() : PlainFalseNode() + PlainConstantNode(b) compile(::Type{<:PlainLogicCircuit}, l::Lit) = PlainLiteralNode(l) @@ -129,7 +102,6 @@ function compile(::Type{<:PlainLogicCircuit}, circuit::LogicCircuit) foldup_aggregate(circuit, f_con, f_lit, f_a, f_o, PlainLogicCircuit) end - fully_factorized_circuit(::Type{LogicCircuit}, n::Int) = fully_factorized_circuit(PlainLogicCircuit, n) diff --git a/src/small_logic_nodes.jl b/src/small_logic_nodes.jl deleted file mode 100644 index e936a875..00000000 --- a/src/small_logic_nodes.jl +++ /dev/null @@ -1,95 +0,0 @@ - - -""" -Root of the Small logic circuit node hierarchy -""" -abstract type SmallLogicCircuit <: LogicCircuit end - -""" -A Small logical leaf node -""" -abstract type SmallLogicLeafNode <: SmallLogicCircuit end - -""" -A Small logical inner node -""" -abstract type SmallLogicInnerNode <: SmallLogicCircuit end - -""" -A Small logical literal leaf node, representing the positive or negative literal of its variable -""" -struct SmallLiteralNode <: SmallLogicLeafNode - literal::Lit -end - -""" -A Small logical constant leaf node, representing true or false -""" -struct SmallConstantNode <: SmallLogicInnerNode - constant::Bool -end - -""" -A Small logical conjunction node (And node) -""" -mutable struct Small⋀Node <: SmallLogicInnerNode - children::Vector{SmallLogicCircuit} -end - -""" -A Small logical disjunction node (Or node) -""" -mutable struct Small⋁Node <: SmallLogicInnerNode - children::Vector{SmallLogicCircuit} -end - -##################### -# traits -##################### - -@inline GateType(::Type{<:SmallLiteralNode}) = LiteralGate() -@inline GateType(::Type{<:SmallConstantNode}) = ConstantGate() -@inline GateType(::Type{<:Small⋀Node}) = ⋀Gate() -@inline GateType(::Type{<:Small⋁Node}) = ⋁Gate() - -##################### -# methods -##################### - -"Get the logical constant in a given constant leaf node" -@inline constant(n::SmallConstantNode)::Bool = n.constant - -"Get the children of a given inner node" -@inline children(n::SmallLogicInnerNode) = n.children::Vector{SmallLogicCircuit} - -function conjoin(arguments::Vector{<:SmallLogicCircuit}; - reuse=nothing) - @assert length(arguments) > 0 - reuse isa Small⋀Node && children(reuse) == arguments && return reuse - return Small⋀Node(arguments) -end - - -function disjoin(arguments::Vector{<:SmallLogicCircuit}; - reuse=nothing) - @assert length(arguments) > 0 - reuse isa Small⋁Node && children(reuse) == arguments && return reuse - return Small⋁Node(arguments) -end - -negate(a::SmallLiteralNode) = compile(SmallLiteralNode, -a.literal) - -# claim `SmallLogicCircuit` as the default `LogicCircuit` implementation - -compile(::Type{<:SmallLogicCircuit}, b::Bool) = SmallConstantNode(b) - -compile(::Type{<:SmallLogicCircuit}, l::Lit) = - SmallLiteralNode(l) - -function compile(::Type{<:SmallLogicCircuit}, circuit::LogicCircuit) - f_con(n) = compile(SmallLogicCircuit, constant(n)) - f_lit(n) = compile(SmallLogicCircuit, literal(n)) - f_a(_, cns) = conjoin(cns) - f_o(_, cns) = disjoin(cns) - foldup_aggregate(circuit, f_con, f_lit, f_a, f_o, SmallLogicCircuit) -end diff --git a/src/structured/structured_logic_nodes.jl b/src/structured/structured_logic_nodes.jl index 77239aa9..6fae94da 100644 --- a/src/structured/structured_logic_nodes.jl +++ b/src/structured/structured_logic_nodes.jl @@ -1,6 +1,6 @@ export StructLogicCircuit, PlainStructLogicCircuit, PlainStructLogicLeafNode, PlainStructLogicInnerNode, - PlainStructLiteralNode, PlainStructConstantNode, PlainStructTrueNode, PlainStructFalseNode, + PlainStructLiteralNode, PlainStructConstantNode, PlainStruct⋀Node, PlainStruct⋁Node, vtree, vtree_safe, prime, sub @@ -25,11 +25,9 @@ abstract type PlainStructLogicInnerNode <: PlainStructLogicCircuit end mutable struct PlainStructLiteralNode <: PlainStructLogicLeafNode literal::Lit vtree::Vtree - data - counter::UInt32 PlainStructLiteralNode(l,v) = begin @assert lit2var(l) ∈ v - new(l, v, nothing, 0) + new(l, v) end end @@ -37,32 +35,21 @@ end A plain structured logical constant leaf node, representing true or false. These are the only structured nodes that don't have an associated vtree node (cf. SDD file format) """ -abstract type PlainStructConstantNode <: PlainStructLogicInnerNode end - -"A plain structured logical true constant. Never construct one, use `structtrue` to access its unique instance" -mutable struct PlainStructTrueNode <: PlainStructConstantNode - data - counter::UInt32 +struct PlainStructConstantNode <: PlainStructLogicInnerNode + constant::Bool end -"A plain structured logical false constant. Never construct one, use `structfalse` to access its unique instance" -mutable struct PlainStructFalseNode <: PlainStructConstantNode - data - counter::UInt32 -end "A plain structured logical conjunction node" mutable struct PlainStruct⋀Node <: PlainStructLogicInnerNode prime::PlainStructLogicCircuit sub::PlainStructLogicCircuit vtree::Vtree - data - counter::UInt32 PlainStruct⋀Node(p,s,v) = begin @assert isinner(v) "Structured conjunctions must respect inner vtree node" @assert isconstantgate(p) || varsubset_left(vtree(p),v) "$p does not go left in $v" @assert isconstantgate(s) || varsubset_right(vtree(s),v) "$s does not go right in $v" - new(p,s, v, nothing, 0) + new(p,s,v) end end @@ -70,16 +57,14 @@ end mutable struct PlainStruct⋁Node <: PlainStructLogicInnerNode children::Vector{PlainStructLogicCircuit} vtree::Vtree # could be leaf or inner - data - counter::UInt32 - PlainStruct⋁Node(c,v) = new(c, v, nothing, 0) end +#TODO remove these now that they are unique and not mutable? "The unique plain structured logical true constant" -const structtrue = PlainStructTrueNode(nothing, 0) +const structtrue = PlainStructConstantNode(true) "The unique splain tructured logical false constant" -const structfalse = PlainStructFalseNode(nothing, 0) +const structfalse = PlainStructConstantNode(false) ##################### # traits @@ -94,8 +79,7 @@ const structfalse = PlainStructFalseNode(nothing, 0) # methods ##################### -@inline constant(n::PlainStructTrueNode)::Bool = true -@inline constant(n::PlainStructFalseNode)::Bool = false +@inline constant(n::PlainStructConstantNode)::Bool = n.constant @inline children(n::PlainStruct⋁Node) = n.children @inline children(n::PlainStruct⋀Node) = [n.prime,n.sub] diff --git a/test/Utils/graphs_test.jl b/test/Utils/graphs_test.jl index ef98a7be..992c6df9 100644 --- a/test/Utils/graphs_test.jl +++ b/test/Utils/graphs_test.jl @@ -122,8 +122,7 @@ end istats = inode_stats(lc); nstats = node_stats(lc); - @test !(PlainTrueNode in keys(lstats)); - @test !(PlainFalseNode in keys(lstats)); + @test !(PlainConstantNode in keys(lstats)); @test lstats[PlainLiteralNode] == 8; @test istats[(Plain⋀Node, 2)] == 9; diff --git a/test/plain_logic_nodes_test.jl b/test/plain_logic_nodes_test.jl index 6cf7d8d3..f9e302c8 100644 --- a/test/plain_logic_nodes_test.jl +++ b/test/plain_logic_nodes_test.jl @@ -62,7 +62,7 @@ include("helper/plain_logic_circuits.jl") @test canonical_constants(r1) == (nothing, nothing) - @test node_stats(n0c)[PlainFalseNode] == 1 + @test node_stats(n0c)[PlainConstantNode] == 2 io = IOBuffer() show(io,n0c) @@ -73,14 +73,14 @@ include("helper/plain_logic_circuits.jl") @test n0c2 !== n0c @test num_edges(n0c2) == num_edges(n0c) @test num_nodes(n0c2) == num_nodes(n0c) - @test isempty(intersect(linearize(n0c2),linearize(n0c))) + @test all(isleaf, intersect(linearize(n0c2),linearize(n0c))) n0c2 = PlainLogicCircuit(n0c) @test n0c2 isa PlainLogicCircuit @test n0c2 !== n0c @test num_edges(n0c2) == num_edges(n0c) @test num_nodes(n0c2) == num_nodes(n0c) - @test isempty(intersect(linearize(n0c2),linearize(n0c))) + @test all(isleaf, intersect(linearize(n0c2),linearize(n0c))) end From 7d893631b5257f9d73846f2535767b7bad6c3a20 Mon Sep 17 00:00:00 2001 From: Guy Van den Broeck Date: Sun, 28 Mar 2021 19:20:45 -0700 Subject: [PATCH 4/6] remove `data` and `counter` from `Sdd` nodes --- src/abstract_logic_nodes.jl | 14 +++++----- src/plain_logic_nodes.jl | 3 --- src/sdd/apply.jl | 14 +++++----- src/sdd/sdd_functions.jl | 11 +++----- src/sdd/sdds.jl | 34 +++++------------------- src/structured/structured_logic_nodes.jl | 1 - 6 files changed, 25 insertions(+), 52 deletions(-) diff --git a/src/abstract_logic_nodes.jl b/src/abstract_logic_nodes.jl index d07f0a62..2be4f773 100644 --- a/src/abstract_logic_nodes.jl +++ b/src/abstract_logic_nodes.jl @@ -58,9 +58,6 @@ import ..Utils.children # make available for extension by concrete types "Get the logical literal in a given literal leaf node" @inline literal(n::LogicCircuit)::Lit = n.literal # override when needed -"Get the logical constant in a given constant leaf node" -function constant end - "Conjoin nodes into a single circuit" function conjoin end @@ -87,13 +84,18 @@ function compile end "Is the node a constant gate?" @inline isconstantgate(n) = GateType(n) isa ConstantGate +"Get the logical constant in a given constant leaf node" +@inline constant(n::LogicCircuit) = constant(GateType(n), n) +@inline constant(::ConstantGate, n::LogicCircuit) = n.constant::Bool + "Get the logical variable in a given literal leaf node" -@inline variable(n::LogicCircuit)::Var = variable(GateType(n), n) -@inline variable(::LiteralGate, n::LogicCircuit)::Var = lit2var(literal(n)) +@inline variable(n::LogicCircuit) = variable(GateType(n), n) +@inline variable(::LiteralGate, n::LogicCircuit)::Var = lit2var(literal(n))::Var "Get the sign of the literal leaf node" -@inline ispositive(n::LogicCircuit)::Bool = ispositive(GateType(n), n) +@inline ispositive(n::LogicCircuit) = ispositive(GateType(n), n) @inline ispositive(::LiteralGate, n::LogicCircuit)::Bool = literal(n) >= 0 + @inline isnegative(n::LogicCircuit)::Bool = !ispositive(n) "Is the circuit syntactically equal to true?" diff --git a/src/plain_logic_nodes.jl b/src/plain_logic_nodes.jl index 5df1fd7e..65406365 100644 --- a/src/plain_logic_nodes.jl +++ b/src/plain_logic_nodes.jl @@ -61,9 +61,6 @@ end # methods ##################### -"Get the logical constant in a given constant leaf node" -@inline constant(n::PlainConstantNode) = n.constant::Bool - "Get the children of a given inner node" @inline children(n::PlainLogicInnerNode) = n.children diff --git a/src/sdd/apply.jl b/src/sdd/apply.jl index 6fc2a0c1..eb921aa4 100644 --- a/src/sdd/apply.jl +++ b/src/sdd/apply.jl @@ -1,14 +1,12 @@ """ Conjoin two SDDs """ -@inline conjoin(::SddFalseNode, ::SddTrueNode) = false_sdd -@inline conjoin(::SddTrueNode, ::SddFalseNode) = false_sdd -@inline conjoin(s::Sdd, ::SddTrueNode) = s -@inline conjoin(::Sdd, ::SddFalseNode) = false_sdd -@inline conjoin(::SddTrueNode, s::Sdd) = s -@inline conjoin(::SddFalseNode, ::Sdd) = false_sdd -@inline conjoin(::SddTrueNode, ::SddTrueNode) = true_sdd -@inline conjoin(::SddFalseNode, ::SddFalseNode) = false_sdd +@inline conjoin(x::SddConstantNode, y::SddConstantNode) = + (isfalse(x) || isfalse(y)) ? false_sdd : true_sdd +@inline conjoin(x::Sdd, y::SddConstantNode) = + isfalse(y) ? false_sdd : x +@inline conjoin(x::SddConstantNode, y::Sdd) = + conjoin(y, x) # const stats = Dict{Tuple{Int,Int},Int}() diff --git a/src/sdd/sdd_functions.jl b/src/sdd/sdd_functions.jl index ba3dec3f..5b28f25f 100644 --- a/src/sdd/sdd_functions.jl +++ b/src/sdd/sdd_functions.jl @@ -20,9 +20,6 @@ export prime, sub, sdd_size, sdd_num_nodes, mgr, "Get the manager of a `Sdd` node, which is its `SddMgr` vtree" mgr(s::Sdd) = s.vtree -@inline constant(::SddTrueNode) = true -@inline constant(::SddFalseNode) = false - @inline children(n::Sdd⋀Node) = [n.prime,n.sub] @inline children(n::Sdd⋁Node) = n.children @@ -37,8 +34,8 @@ sdd_size(sdd) = mapreduce(n -> num_children(n), +, ⋁_nodes(sdd); init=0) # def "Count the number of decision nodes in the SDD" sdd_num_nodes(sdd) = length(⋁_nodes(sdd)) # defined as the number of `decisions` -Base.show(io::IO, ::SddTrueNode) = print(io, "⊤") -Base.show(io::IO, ::SddFalseNode) = print(io, "⊥") +Base.show(io::IO, n::SddConstantNode) = + print(io, (isfalse(n) ? "⊥" : "⊤")) Base.show(io::IO, c::SddLiteralNode) = print(io, literal(c)) Base.show(io::IO, c::Sdd⋀Node) = begin recshow(c::Union{SddConstantNode,SddLiteralNode}) = "$c" @@ -114,8 +111,8 @@ compile(::SddMgr, constant::Bool) = """ Negate an SDD """ -@inline negate(::SddFalseNode) = true_sdd -@inline negate(::SddTrueNode) = false_sdd +@inline negate(c::SddConstantNode) = + isfalse(c) ? true_sdd : false_sdd function negate(s::SddLiteralNode) if ispositive(s) diff --git a/src/sdd/sdds.jl b/src/sdd/sdds.jl index d187b1aa..fc6b337e 100644 --- a/src/sdd/sdds.jl +++ b/src/sdd/sdds.jl @@ -1,6 +1,6 @@ export Sdd, SddMgr, SddLeafNode, SddInnerNode, SddLiteralNode, SddConstantNode, - Sdd⋀Node, Sdd⋁Node, SddTrueNode, SddFalseNode, + Sdd⋀Node, Sdd⋁Node, sdd_mgr_for ############# @@ -127,56 +127,36 @@ abstract type SddInnerNode <: Sdd end mutable struct SddLiteralNode <: SddLeafNode literal::Lit vtree::SddMgrLeafNode - data - counter::UInt32 - SddLiteralNode(l,v) = new(l,v,nothing,false) end """ A SDD logical constant leaf node, representing true or false. These are the only structured nodes that don't have an associated vtree node (cf. SDD file format) """ -abstract type SddConstantNode <: SddLeafNode end - -"A SDD logical true constant." -mutable struct SddTrueNode <: SddConstantNode - counter::UInt32 - data +struct SddConstantNode <: SddLeafNode + constant::Bool end -# there is an issue with using this canonical node: if someone breaks the `bit` field in one SDD manager, it will be broken for all SDD managers... "Canonical true Sdd node" -const true_sdd = SddTrueNode(false, nothing) - -"A SDD logical false constant." -mutable struct SddFalseNode <: SddConstantNode - counter::UInt32 - data -end +const true_sdd = SddConstantNode(true) -# there is an issue with using this canonical node: if someone breaks the `bit` field in one SDD manager, it will be broken for all SDD managers... "Canonical false Sdd node" -const false_sdd = SddFalseNode(false, nothing) +const false_sdd = SddConstantNode(false) "A SDD logical conjunction node" mutable struct Sdd⋀Node <: SddInnerNode prime::Sdd sub::Sdd vtree::SddMgrInnerNode - counter::UInt32 - data - Sdd⋀Node(p,s,v) = new(p,s,v,false) end "A SDD logical disjunction node" mutable struct Sdd⋁Node <: SddInnerNode children::Vector{Sdd⋀Node} vtree::SddMgrInnerNode - counter::UInt32 negation::Sdd⋁Node - data - Sdd⋁Node(ch,v) = new(ch, v, false) # leave negation uninitialized - Sdd⋁Node(ch,v,neg) = new(ch, v, false, neg) + Sdd⋁Node(ch,v) = new(ch, v) # leave negation uninitialized + Sdd⋁Node(ch,v,neg) = new(ch, v, neg) end ##################### diff --git a/src/structured/structured_logic_nodes.jl b/src/structured/structured_logic_nodes.jl index 6fae94da..b876a590 100644 --- a/src/structured/structured_logic_nodes.jl +++ b/src/structured/structured_logic_nodes.jl @@ -79,7 +79,6 @@ const structfalse = PlainStructConstantNode(false) # methods ##################### -@inline constant(n::PlainStructConstantNode)::Bool = n.constant @inline children(n::PlainStruct⋁Node) = n.children @inline children(n::PlainStruct⋀Node) = [n.prime,n.sub] From ccdd504b8b064aba0ab2b2baf2f9c2e73f9e31dd Mon Sep 17 00:00:00 2001 From: Guy Van den Broeck Date: Sun, 28 Mar 2021 19:30:00 -0700 Subject: [PATCH 5/6] remove `counter` and `data` from test structs --- src/Utils/graphs.jl | 5 +---- test/Utils/graphs_test.jl | 14 ++------------ 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/src/Utils/graphs.jl b/src/Utils/graphs.jl index f86a1ccb..e1f9d99a 100644 --- a/src/Utils/graphs.jl +++ b/src/Utils/graphs.jl @@ -37,10 +37,7 @@ struct Inner <: NodeType end # basic fields and methods ##################### -# Each `Node` is required to have fields -# - `counter::UInt32` -# - `data:Any` -# and a specialized method for the following functions. +# Each `Node` is required to provide a specialized method for the following functions. "Get the node type trait of the given `Node`" @inline NodeType(node::Node) = NodeType(typeof(node)) diff --git a/test/Utils/graphs_test.jl b/test/Utils/graphs_test.jl index 992c6df9..8dafbcc6 100644 --- a/test/Utils/graphs_test.jl +++ b/test/Utils/graphs_test.jl @@ -9,16 +9,12 @@ module TestNodes mutable struct TestINode <: Dag id::Int children::Vector{Dag} - data - counter::UInt32 - TestINode(i,c) = new(i,c,nothing,false) + TestINode(i,c) = new(i,c) end mutable struct TestLNode <: Dag id::Int - data - counter::UInt32 - TestLNode(i) = new(i,nothing,false) + TestLNode(i) = new(i) end LogicCircuits.NodeType(::Type{<:TestINode}) = Inner() @@ -62,9 +58,6 @@ module TestNodes @test i12.id == 4 @test j2.id == 3 @test r.id == 6 - @test r.counter == 0 - @test l1.counter == 0 - @test i12.counter == 0 foreach(r, l -> l.id += 1, i -> i.id -= 1) @test l1.id == 2+1 @@ -72,9 +65,6 @@ module TestNodes @test i12.id == 4-1 @test j2.id == 3-1 @test r.id == 6-1 - @test r.counter == 0 - @test l1.counter == 0 - @test i12.counter == 0 @test filter(n -> iseven(n.id), r) == [l2,i2,j2] From eee02523d1d50c6f338c3ad62283da89358f04d9 Mon Sep 17 00:00:00 2001 From: Guy Van den Broeck Date: Sun, 28 Mar 2021 21:15:31 -0700 Subject: [PATCH 6/6] various sdd and vtree fixes --- src/queries/queries.jl | 4 +- src/sdd/apply.jl | 5 +- src/sdd/sdd_functions.jl | 4 +- src/structured/abstract_vtrees.jl | 3 - test/sdd/sdds_test.jl | 2 +- test/sdd/trimmed_apply_test.jl | 6 +- test/sdd/trimmed_sdds_test.jl | 108 +++++++++--------- .../structured/structured_logic_nodes_test.jl | 2 +- test/structured/vtree_tests.jl | 2 +- 9 files changed, 66 insertions(+), 70 deletions(-) diff --git a/src/queries/queries.jl b/src/queries/queries.jl index 956ee7ad..0359db27 100644 --- a/src/queries/queries.jl +++ b/src/queries/queries.jl @@ -245,8 +245,8 @@ Note: this function is generally intractable for large circuits. function isdeterministic(root::LogicCircuit)::Bool mgr = sdd_mgr_for(root) result::Bool = true - f_con(c) = mgr(constant(c)) - f_lit(n) = mgr(literal(n)) + f_con(c) = compile(mgr,constant(c)) + f_lit(n) = compile(mgr,literal(n)) f_a(_, cs) = reduce(&, cs) f_o(_, cs) = begin for i = 1:length(cs) diff --git a/src/sdd/apply.jl b/src/sdd/apply.jl index eb921aa4..253f4a85 100644 --- a/src/sdd/apply.jl +++ b/src/sdd/apply.jl @@ -2,9 +2,9 @@ Conjoin two SDDs """ @inline conjoin(x::SddConstantNode, y::SddConstantNode) = - (isfalse(x) || isfalse(y)) ? false_sdd : true_sdd + ((x === false_sdd) || (y === false_sdd)) ? false_sdd : true_sdd @inline conjoin(x::Sdd, y::SddConstantNode) = - isfalse(y) ? false_sdd : x + (y === false_sdd) ? false_sdd : x @inline conjoin(x::SddConstantNode, y::Sdd) = conjoin(y, x) @@ -18,7 +18,6 @@ function conjoin(s::SddLiteralNode, t::SddLiteralNode)::Sdd end end - # Note: attempts to make a special cache for conjunctions with literals have not yielded speedups function conjoin(s::Sdd, t::Sdd)::Sdd diff --git a/src/sdd/sdd_functions.jl b/src/sdd/sdd_functions.jl index 5b28f25f..ef6aa495 100644 --- a/src/sdd/sdd_functions.jl +++ b/src/sdd/sdd_functions.jl @@ -112,7 +112,7 @@ compile(::SddMgr, constant::Bool) = Negate an SDD """ @inline negate(c::SddConstantNode) = - isfalse(c) ? true_sdd : false_sdd + (c === false_sdd) ? true_sdd : false_sdd function negate(s::SddLiteralNode) if ispositive(s) @@ -210,7 +210,7 @@ model_count(root::StructLogicCircuit)::BigInt = model_count(root, length(global_scope(vtree(root)))) "Decide whether one sentence logically entails another" -entails(x::Sdd, y::Sdd) = isfalse(x & !y) +entails(x::Sdd, y::Sdd) = ((x & !y) === false_sdd) "Decide whether two sentences are logically equivalent" equivalent(x::Sdd, y::Sdd) = begin diff --git a/src/structured/abstract_vtrees.jl b/src/structured/abstract_vtrees.jl index 0c5dc67f..08c917ff 100644 --- a/src/structured/abstract_vtrees.jl +++ b/src/structured/abstract_vtrees.jl @@ -106,9 +106,6 @@ Base.show(io::IO, c::Vtree) = print(io, "$(typeof(c))($(join(variables(c), ',')) # Constructors ############# -# Syntactic sugar to compile circuits using a vtree -(vtree::Vtree)(arg) = compile(vtree, arg) - # construct vtrees from other vtrees function (::Type{V})(vtree::Vtree)::V where V<:Vtree f_leaf(l) = V(variable(l)) diff --git a/test/sdd/sdds_test.jl b/test/sdd/sdds_test.jl index c89f6eb1..4615db33 100644 --- a/test/sdd/sdds_test.jl +++ b/test/sdd/sdds_test.jl @@ -12,7 +12,7 @@ include("../helper/validate_sdd.jl") @test_throws Exception compile(mgr.left, cnf) @test_throws Exception compile(right_most_descendent(mgr), cnf) - r = mgr(cnf) + r = compile(mgr,cnf) @test compile(mgr, cnf) === r @test respects_vtree(r) diff --git a/test/sdd/trimmed_apply_test.jl b/test/sdd/trimmed_apply_test.jl index 9e4c7b2d..bd599f72 100644 --- a/test/sdd/trimmed_apply_test.jl +++ b/test/sdd/trimmed_apply_test.jl @@ -16,12 +16,12 @@ using LogicCircuits notx_c = compile(mgr,notx) true_c = compile(mgr,true) @test true_c isa Sdd - @test true_c === mgr(true) + @test true_c === compile(mgr,true) @test true_c === compile(Sdd,mgr,true) @test true_c === (Sdd,mgr)(true) @test true_c === (mgr,Sdd)(true) false_c = compile(mgr,false) - @test false_c === mgr(false) + @test false_c === compile(mgr,false) @test false_c & true_c == false_c @test false_c & notx_c == false_c @@ -139,7 +139,7 @@ using LogicCircuits @test model_count(f4c) == model_count(f4) @test f4c.vtree === f4.vtree - f4c = mgr(StructLogicCircuit(mgr,f4)) + f4c = compile(mgr,StructLogicCircuit(mgr,f4)) @test Sdd(mgr,f4c) === f4c @test f4 === f4c @test f4c isa Sdd diff --git a/test/sdd/trimmed_sdds_test.jl b/test/sdd/trimmed_sdds_test.jl index c0d5a914..fed07085 100644 --- a/test/sdd/trimmed_sdds_test.jl +++ b/test/sdd/trimmed_sdds_test.jl @@ -5,60 +5,60 @@ using LogicCircuits: Element # test some internals @testset "Trimmed SDD test" begin num_vars = 7 - mgr = SddMgr(num_vars, :balanced) + manager = SddMgr(num_vars, :balanced) - @test num_variables(mgr) == num_vars - @test num_nodes(mgr) == 2*num_vars-1 - @test num_edges(mgr) == 2*num_vars-2 - @test mgr isa SddMgr - - @test varsubset(left_most_descendent(mgr), mgr) - @test varsubset(mgr.left, mgr) - @test varsubset(mgr.right, mgr) - @test varsubset_left(mgr.left, mgr) - @test varsubset_left(mgr.left.left, mgr) - @test varsubset_left(mgr.left.right, mgr) - @test varsubset_right(mgr.right, mgr) - @test varsubset_right(mgr.right.right, mgr) - @test varsubset_right(mgr.right.left, mgr) - - @test !varsubset(mgr, left_most_descendent(mgr)) - @test !varsubset_left(mgr.right, mgr) - @test !varsubset_left(mgr.right.left, mgr) - @test !varsubset_left(mgr.right.right, mgr) - @test !varsubset_left(mgr, mgr) - @test !varsubset_left(mgr, mgr.left) - @test !varsubset_left(mgr, mgr.right) - @test !varsubset_right(mgr.left, mgr) - @test !varsubset_right(mgr.left.right, mgr) - @test !varsubset_right(mgr.left.left, mgr) - @test !varsubset_right(mgr, mgr) - @test !varsubset_right(mgr, mgr.left) - @test !varsubset_right(mgr, mgr.right) + @test num_variables(manager) == num_vars + @test num_nodes(manager) == 2*num_vars-1 + @test num_edges(manager) == 2*num_vars-2 + @test manager isa SddMgr + + @test varsubset(left_most_descendent(manager), manager) + @test varsubset(manager.left, manager) + @test varsubset(manager.right, manager) + @test varsubset_left(manager.left, manager) + @test varsubset_left(manager.left.left, manager) + @test varsubset_left(manager.left.right, manager) + @test varsubset_right(manager.right, manager) + @test varsubset_right(manager.right.right, manager) + @test varsubset_right(manager.right.left, manager) + + @test !varsubset(manager, left_most_descendent(manager)) + @test !varsubset_left(manager.right, manager) + @test !varsubset_left(manager.right.left, manager) + @test !varsubset_left(manager.right.right, manager) + @test !varsubset_left(manager, manager) + @test !varsubset_left(manager, manager.left) + @test !varsubset_left(manager, manager.right) + @test !varsubset_right(manager.left, manager) + @test !varsubset_right(manager.left.right, manager) + @test !varsubset_right(manager.left.left, manager) + @test !varsubset_right(manager, manager) + @test !varsubset_right(manager, manager.left) + @test !varsubset_right(manager, manager.right) x = Var(1) y = Var(2) - x_c = compile(mgr, var2lit(x)) - y_c = compile(mgr, var2lit(y)) + x_c = compile(manager, var2lit(x)) + y_c = compile(manager, var2lit(y)) @test x_c != y_c @test variable(x_c) == x @test literal(x_c) == var2lit(x) - @test vtree(x_c) ∈ mgr + @test vtree(x_c) ∈ manager @test ispositive(x_c) - @test x_c == compile(mgr, var2lit(x)) + @test x_c == compile(manager, var2lit(x)) @test variable(y_c) == y @test literal(y_c) == var2lit(y) - @test vtree(y_c) ∈ mgr + @test vtree(y_c) ∈ manager @test ispositive(y_c) - @test y_c == compile(mgr, var2lit(y)) + @test y_c == compile(manager, var2lit(y)) notx = -var2lit(x) - notx_c = compile(mgr,notx) + notx_c = compile(manager,notx) @test sat_prob(x_c) == 1//2 @test sat_prob(notx_c) == 1//2 @@ -69,16 +69,16 @@ using LogicCircuits: Element # test some internals @test variable(notx_c) == x @test literal(notx_c) == notx - @test vtree(notx_c) ∈ mgr + @test vtree(notx_c) ∈ manager @test isnegative(notx_c) - @test notx_c == compile(mgr, notx) + @test notx_c == compile(manager, notx) - true_c = compile(mgr,true) + true_c = compile(manager,true) @test istrue(true_c) @test constant(true_c) == true - false_c = compile(mgr,false) + false_c = compile(manager,false) @test isfalse(false_c) @test constant(false_c) == false @@ -91,26 +91,26 @@ using LogicCircuits: Element # test some internals @test model_count(true_c,num_vars) == BigInt(2)^(num_vars) @test model_count(false_c,num_vars) == BigInt(0) - v1 = compile(mgr, Lit(1)) - v2 = compile(mgr, Lit(2)) - v3 = compile(mgr, Lit(3)) - v4 = compile(mgr, Lit(4)) - v5 = compile(mgr, Lit(5)) - v6 = compile(mgr, Lit(6)) - v7 = compile(mgr, Lit(7)) - @test_throws Exception compile(mgr, Lit(8)) + v1 = compile(manager, Lit(1)) + v2 = compile(manager, Lit(2)) + v3 = compile(manager, Lit(3)) + v4 = compile(manager, Lit(4)) + v5 = compile(manager, Lit(5)) + v6 = compile(manager, Lit(6)) + v7 = compile(manager, Lit(7)) + @test_throws Exception compile(manager, Lit(8)) p1 = [Element(true_c,v3)] - @test canonicalize(p1, mgr.left.right) === v3 + @test canonicalize(p1, manager.left.right) === v3 p2 = [Element(v1,true_c), Element(!v1,false_c)] - @test canonicalize(p2, mgr.left) === v1 + @test canonicalize(p2, manager.left) === v1 p3 = [Element(v1,v4), Element(!v1,v7)] - n1 = canonicalize(p3, mgr) + n1 = canonicalize(p3, manager) p4 = [Element(!v1,v7), Element(v1,v4)] - n2 = canonicalize(p4,mgr) - @test n1.vtree.left === mgr.left - @test n1.vtree.right === mgr.right + n2 = canonicalize(p4,manager) + @test n1.vtree.left === manager.left + @test n1.vtree.right === manager.right @test n1 === n2 @test isdeterministic(n1) @test n1(true, false, false, true, false, false, false) diff --git a/test/structured/structured_logic_nodes_test.jl b/test/structured/structured_logic_nodes_test.jl index 30a978d8..6db230f2 100644 --- a/test/structured/structured_logic_nodes_test.jl +++ b/test/structured/structured_logic_nodes_test.jl @@ -37,7 +37,7 @@ using DataFrames: DataFrame @test !istrue(f) @test !isfalse(f) - @test literal(vtree(Lit(-5))) == Lit(-5) + @test literal(compile(vtree,Lit(-5))) == Lit(-5) @test literal((PlainStructLogicCircuit,vtree)(Lit(-5))) == Lit(-5) @test constant((PlainStructLogicCircuit,vtree)(false)) == false diff --git a/test/structured/vtree_tests.jl b/test/structured/vtree_tests.jl index b73f5065..b8d9f33b 100644 --- a/test/structured/vtree_tests.jl +++ b/test/structured/vtree_tests.jl @@ -21,7 +21,7 @@ using LogicCircuits @test lca(v1,i1) == i1 @test lca(v1,i1,v1) == i1 @test lca(v1,v2,v3) == r - @test lca(vtree_safe(r(true)), vtree_safe(r(false))) === nothing + @test lca(vtree_safe(compile(r,true)), vtree_safe(compile(r,false))) === nothing @test_throws Exception lca(i1,PlainVtree(Var(4))) @test varsubset_left(v1,r)