Skip to content

Commit

Permalink
Shell BDD inside a namespace to avoid conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
RenatoGeh committed Jun 23, 2021
1 parent e1d819b commit f169790
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions src/structurelearner/sample_psdd.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using StatsFuns
using BinaryDecisionDiagrams
using BinaryDecisionDiagrams: Diagram, BinaryDecisionDiagrams
const BDD = BinaryDecisionDiagrams

"Samples an element from a Binomial distribution with p=0.5."
Expand Down Expand Up @@ -52,8 +52,8 @@ Samples a partial partition.
"""
function sample_partition::Diagram, Sc::BitSet, p::Real, k::Integer, ⊤_k::Integer,
exact::Bool)::Dict{Diagram, Vector{Diagram}}
X = intersect!(scope(ϕ), Sc)
idem = isempty(X) || is_⊤(ϕ)
X = intersect!(BDD.scope(ϕ), Sc)
idem = isempty(X) || BDD.is_⊤(ϕ)
O = shuffle!(idem ? collect(Sc) : X)
E = Dict{Diagram, Vector{Diagram}}()
if idem
Expand All @@ -72,7 +72,7 @@ function sample_idem_primes!(O::Vector{Int}, p::Real, P::Vector{Diagram}, k::Int
while !isempty(Q)
i, V = popfirst!(Q)
if e_count >= k || i > Sc_len
push!(P, and(V))
push!(P, BDD.and(V))
continue
end
x = O[i]
Expand All @@ -97,17 +97,17 @@ function sample_primes!(ϕ::Diagram, O::Vector{Int}, E::Dict{Diagram, Vector{Dia
Q = Tuple{Diagram, Int, Vector{Int}}[(ϕ, 1, Vector{Int}())]
while !isempty(Q)
ψ, i, V = popfirst!(Q)
if (e_count >= k && !exact) || (i > Sc_len) || is_⊤(ψ)
if !haskey(E, ψ) E[ψ] = Diagram[and(V)]
else push!(E[ψ], and(V)) end
if (e_count >= k && !exact) || (i > Sc_len) || BDD.is_⊤(ψ)
if !haskey(E, ψ) E[ψ] = Diagram[BDD.and(V)]
else push!(E[ψ], BDD.and(V)) end
continue
end
x = O[i]
if x ψ push!(Q, (ψ, i+1, V))
else
α, β = ψ|x, ψ|-x
if !is_⊥(α) push!(Q, (α, i+1, push!(copy(V), x))) end
if !is_⊥(β) push!(Q, (β, i+1, push!(copy(V), -x))) end
if !BDD.is_⊥(α) push!(Q, (α, i+1, push!(copy(V), x))) end
if !BDD.is_⊥(β) push!(Q, (β, i+1, push!(copy(V), -x))) end
e_count += 1
end
end
Expand Down Expand Up @@ -148,9 +148,9 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
if isleaf(V)
v, v64 = convert(Int32, V.var), convert(Int, V.var)
if v ϕ
if is_lit(ϕ) return get_lit(to_lit(ϕ), V, leaves) end
if is_⊤|v64) return get_lit(v, V, leaves) end
if is_⊤|-v64) return get_lit(-v, V, leaves) end
if BDD.is_lit(ϕ) return get_lit(BDD.to_lit(ϕ), V, leaves) end
if BDD.is_⊤|v64) return get_lit(v, V, leaves) end
if BDD.is_⊤|-v64) return get_lit(-v, V, leaves) end
S = StructSumNode([get_lit(v, V, leaves), get_lit(-v, V, leaves)], V)
if merge_branch repeats[r_p] = S end
if randomize_weights S.log_probs = random_weights(2) end
Expand All @@ -160,7 +160,7 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
if merge_branch repeats[r_p] = S end
if randomize_weights S.log_probs = random_weights(2) end
return S
elseif fact_on_⊤ && (is_⊤(ϕ) || isempty(intersect!(scopeset(ϕ), variables(V))))
elseif fact_on_⊤ && (BDD.is_⊤(ϕ) || isempty(intersect!(BDD.scopeset(ϕ), variables(V))))
# When ϕ ≡ ⊤ and we want to simplify the circuit, fully factorize.
left = sample_psdd_r(⊤, V.left, k, leaves, randomize_weights, opts, fact_on_⊤, ⊤_k, p_mr,
always_compress, always_merge, repeats, merge_branch, merge_branch_pr,
Expand All @@ -176,7 +176,7 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
C = Vector{StructProbCircuit}()
# Left element.
left_sub_ϕ = ϕ|prime_var
if !is_⊥(left_sub_ϕ)
if !BDD.is_⊥(left_sub_ϕ)
left_prime = sample_psdd_r(BDD.variable(prime_var), V.left, k, leaves,
randomize_weights, opts, fact_on_⊤, ⊤_k, p_mr,
always_compress, always_merge, repeats, merge_branch,
Expand All @@ -187,7 +187,7 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
push!(C, StructMulNode(left_prime, left_sub, V))
end
right_sub_ϕ = ϕ|-prime_var
if !is_⊥(right_sub_ϕ)
if !BDD.is_⊥(right_sub_ϕ)
right_prime = sample_psdd_r(BDD.variable(-prime_var), V.left, k, leaves,
randomize_weights, opts, fact_on_⊤, ⊤_k, p_mr,
always_compress, always_merge, repeats, merge_branch,
Expand All @@ -203,7 +203,7 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
return S
end
L, R = variables(V.left), variables(V.right)
force_exact = (exact || anc_exact) && !is_⊤(ϕ)
force_exact = (exact || anc_exact) && !BDD.is_⊤(ϕ)
E = sample_partition(ϕ, L, p_mr, k, ⊤_k, force_exact)
C = Vector{StructProbCircuit}()
for (s, P) E
Expand All @@ -216,7 +216,7 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
merge_branch, merge_branch_pr, true, true)
sub_node = sample_psdd_r(s, V.right, k, leaves, randomize_weights, opts, fact_on_⊤,
⊤_k, p_mr, always_compress, always_merge, repeats,
merge_branch && isempty(intersect!(scopeset(s), R)),
merge_branch && isempty(intersect!(BDD.scopeset(s), R)),
merge_branch_pr, false, anc_exact)
push!(C, StructMulNode(prime_node, sub_node, V))
continue
Expand All @@ -229,7 +229,7 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
merge_branch, merge_branch_pr, true, true)
sub_node = sample_psdd_r(s, V.right, k, leaves, randomize_weights, opts, fact_on_⊤,
⊤_k, p_mr, always_compress, always_merge, repeats,
merge_branch && isempty(intersect!(scopeset(s), R)),
merge_branch && isempty(intersect!(BDD.scopeset(s), R)),
merge_branch_pr, false, anc_exact)
push!(C, StructMulNode(prime_node, sub_node, V))
end
Expand All @@ -249,7 +249,7 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
else # Compress c randomly selected elements.
K = Vector{Tuple{Diagram, Vector{Int}}}(undef, n-c+1)
comb = sample_comb(n, c)
j, ψ = 2, ⊥
j, ψ = 2, BDD.
@inbounds for i 1:n
if i comb
ψ = BDD.:(ψ, P[i])
Expand Down Expand Up @@ -297,7 +297,7 @@ function sample_psdd_r(ϕ::Diagram, V::Vtree, k::Integer, leaves::Dict{Int32, St
if !isempty(M)
sub_node = sample_psdd_r(s, V.right, k, leaves, randomize_weights, opts, fact_on_⊤,
⊤_k, p_mr, always_compress, always_merge, repeats,
merge_branch && isempty(intersect!(scopeset(s), R)),
merge_branch && isempty(intersect!(BDD.scopeset(s), R)),
merge_branch_pr, false, anc_exact)
for j M subs[j] = sub_node end
end
Expand Down

0 comments on commit f169790

Please sign in to comment.