diff --git a/src/abstract_logic_nodes.jl b/src/abstract_logic_nodes.jl index 3e0c2dcb..02ebf159 100644 --- a/src/abstract_logic_nodes.jl +++ b/src/abstract_logic_nodes.jl @@ -7,7 +7,7 @@ export LogicCircuit, GateType, InnerGate, LeafGate, fully_factorized_circuit, ⋁_nodes, ⋀_nodes, or_nodes, and_nodes, canonical_literals, canonical_constants, tree_formula_string, - isflat, iscnf, isdnf + isflat, iscnf, isdnf, has_vars_contiguous ##################### # Abstract infrastructure for logic circuit nodes @@ -249,4 +249,10 @@ iscnf(circuit) = isdnf(circuit) = is⋁gate(circuit) && all(children(circuit)) do clause is⋀gate(clause) && all(isliteralgate, children(clause)) - end \ No newline at end of file + end + +"Does the circuit have a contiguously indexed set of variables" +has_vars_contiguous(circuit) = begin + vars = variables(circuit) + (maximum(vars) == length(vars)) +end diff --git a/src/io/nnf_io.jl b/src/io/nnf_io.jl index a3ecedd3..db757e73 100644 --- a/src/io/nnf_io.jl +++ b/src/io/nnf_io.jl @@ -99,7 +99,7 @@ Base.read(io::IO, ::Type{PlainLogicCircuit}, ::NnfFormat) = function Base.write(io::IO, circuit::LogicCircuit, ::NnfFormat) labeling = label_nodes(circuit) - map!(x -> x-1, values(labeling)) # vtree nodes are 0-based indexed + map!(x -> x-1, values(labeling)) # nodes are 0-based indexed println(io, "nnf $(num_nodes(circuit)) $(num_edges(circuit)) $(num_variables(circuit))") foreach(circuit) do n diff --git a/src/transformations.jl b/src/transformations.jl index c7bf4391..7a55c5a6 100644 --- a/src/transformations.jl +++ b/src/transformations.jl @@ -1,6 +1,6 @@ export smooth, forget, propagate_constants, deepcopy, conjoin, replace_node, split, clone, merge, split_candidates, random_split, split_step, struct_learn, - clone_candidates, standardize_circuit + clone_candidates, standardize_circuit, make_vars_contiguous """ @@ -589,3 +589,30 @@ function standardize_circuit(circuit::LogicCircuit) foldup_aggregate(circuit, f_con, f_lit, f_a, f_o, LogicCircuit) end + + +""" +Make all variables in this circuit contiguously numbered. Return new circuit and the variable mapping. +""" +function make_vars_contiguous(root::Node) + var_bijection = [(v, i) for (i,v) in enumerate(variables(root))] + var_dict = Dict(var_bijection) + var2lits = Dict(map(var_bijection) do (v,i) + pos_lit = compile(typeof(root), Lit(i)) + neg_lit = compile(typeof(root), -Lit(i)) + (v => (pos_lit, neg_lit)) + end) + f_con(n) = n + f_lit(n) = begin + if var_dict[variable(n)] == variable(n) + n + else + lits = var2lits[variable(n)] + ispositive(n) ? lits[1] : lits[2] + end + end + f_a(n, cn) = conjoin([cn...]; reuse=n) + f_o(n, cn) = disjoin([cn...]; reuse=n) + root2 = foldup_aggregate(root, f_con, f_lit, f_a, f_o, Node) + root2, var_bijection +end \ No newline at end of file diff --git a/test/transformations_test.jl b/test/transformations_test.jl index b5d4ea80..ba483b36 100644 --- a/test/transformations_test.jl +++ b/test/transformations_test.jl @@ -268,3 +268,19 @@ end @test circuit.children[3].literal == Lit(3) @test circuit.children[4].literal == Lit(4) end + +@testset "Contiguous variable test" begin + c1 = zoo_sdd_random + @test has_vars_contiguous(c1) + + c2 = forget(c1, isodd) + @test variables(c2) == BitSet(2:2:num_variables(c1)) + @test !has_vars_contiguous(c2) + + c3, bijection = make_vars_contiguous(c2) + @test num_nodes(c2) == num_nodes(c3) + @test num_edges(c2) == num_edges(c3) + @test has_vars_contiguous(c3) + + @test make_vars_contiguous(c1)[1] === c1 +end \ No newline at end of file