Skip to content

Commit

Permalink
Merge pull request #84 from Juice-jl/circuit-traversal
Browse files Browse the repository at this point in the history
Change circuit traversal API to use `Dict` instead of a `counter` field
  • Loading branch information
guyvdbroeck authored Mar 29, 2021
2 parents 9ffb63a + eee0252 commit 00177e3
Show file tree
Hide file tree
Showing 21 changed files with 282 additions and 505 deletions.
4 changes: 2 additions & 2 deletions src/LoadSave/circuit_line_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ function compile_logical_m(lines::CircuitFormatLines)
PlainLiteralNode(l)
end

true_node = PlainTrueNode()
false_node = PlainFalseNode()
true_node = PlainConstantNode(true)
false_node = PlainConstantNode(false)

function compile(ln::CircuitFormatLine)
error("Compilation of line $ln is not supported")
Expand Down
196 changes: 39 additions & 157 deletions src/Utils/graphs.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
export Node, Dag, NodeType, Leaf, Inner,
children, has_children, num_children, isleaf, isinner,
reset_counter, reset_counter_hard, foreach, foreach_reset, clear_data, filter,
nload, nsave,
foldup, foldup_aggregate,
count_parents, foreach_down,
foreach, foreach_down, filter, foldup, foldup_aggregate,
num_nodes, num_edges, tree_num_nodes, tree_num_edges, in,
inodes, innernodes, leafnodes, linearize,
left_most_descendent, right_most_descendent,
Expand Down Expand Up @@ -40,10 +37,7 @@ struct Inner <: NodeType end
# basic fields and methods
#####################

# Each `Node` is required to have fields
# - `counter::UInt32`
# - `data:Any`
# and a specialized method for the following functions.
# Each `Node` is required to provide a specialized method for the following functions.

"Get the node type trait of the given `Node`"
@inline NodeType(node::Node) = NodeType(typeof(node))
Expand Down Expand Up @@ -76,93 +70,45 @@ function children end
# traversal
#####################

"Set the counter field throughout this graph (assumes no node in the graph already has the value)"
function reset_counter(node::Dag, v::Int=0)
if node.counter != v
node.counter = v
if isinner(node)
for c in children(node)
reset_counter(c, v)
end
end
end
nothing # returning nothing helps save some allocations and time
end

"Set the counter field throughout this graph. Is slower but works even when the node fields are in an arbitrary state."
function reset_counter_hard(node::Dag, v::Int=0, seen::Set{Dag}=Set{Dag}())
if node seen
push!(seen, node)
node.counter = v
if isinner(node)
for c in children(node)
reset_counter_hard(c, v, seen)
end
end
end
nothing
end


import Base.foreach #extend

"Apply a function to each node in a graph, bottom up"
function foreach(f::Function, node::Dag; reset=true)
@assert node.counter == 0 "Another algorithm is already traversing this circuit and using the `counter` field. You can use `reset_counter` to reset the counter. "
foreach_rec(f, node)
reset && reset_counter(node)
nothing # returning nothing helps save some allocations and time
end
foreach(f::Function, node::Dag, seen::Nothing=nothing) =
foreach_dict(f, node, Dict{Dag,Nothing}())

function foreach(node::Dag, f_leaf::Function, f_inner::Function)
foreach(node) do n
isinner(n) ? f_inner(n) : f_leaf(n)
end
nothing # returning nothing helps save some allocations and time
end

"Apply a function to each node in a graph, bottom up, without resetting the counter"
function foreach_rec(f::Function, node::Dag)
if (node.counter += 1) == 1
function foreach_dict(f::Function, node::Dag, seen)
get!(seen, node) do
if isinner(node)
for c in children(node)
foreach_rec(f, c)
foreach_dict(f, c, seen)
end
end
f(node)
nothing
end
nothing # returning nothing helps save some allocations and time
nothing
end

"Apply a function to each node in a graph, bottom up, while resetting the counter"
function foreach_reset(f::Function, node::Dag)
if node.counter != 0
node.counter = 0
if isinner(node)
for c in children(node)
foreach_reset(f, c)
end
end
f(node)
function foreach(node::Dag, f_leaf::Function, f_inner::Function, seen=nothing)
foreach(node, seen) do n
isinner(n) ? f_inner(n) : f_leaf(n)
end
nothing # returning nothing helps save some allocations and time
nothing
end

"Set all the data fields in a circuit to `nothing`"
function clear_data(node::Dag)
foreach(x -> x.data = nothing, node)
"Apply a function to each node in a graph, top down"
function foreach_down(f::Function, node::Dag)
# naive implementation
lin = linearize(node)
foreach(f, Iterators.reverse(lin))
end


# TODO: consider adding a top-down version of foreach, by either linearizing into a List,
# or by keeping a visit counter to identify processing of the last parent.

import Base.filter #extend

"""Retrieve list of nodes in graph matching predicate `p`"""
function filter(p::Function, root::Dag, ::Type{T} = Union{})::Vector where T
function filter(p::Function, root::Dag, ::Type{T} = Union{}, seen=nothing)::Vector where T
results = Vector{T}()
foreach(root) do n
foreach(root, seen) do n
if p(n)
if !(n isa eltype(results))
results = collect(typejoin(eltype(results), typeof(n)), results)
Expand All @@ -173,48 +119,28 @@ function filter(p::Function, root::Dag, ::Type{T} = Union{})::Vector where T
results
end

"Default getter to obtain data associated with a node"
@inline nload(n) = n.data

"Default setter to assign data associated with a node"
@inline nsave(n,v) = n.data = v

"""
foldup(node::Dag,
f_leaf::Function,
f_inner::Function,
::Type{T}; nload = nload, nsave = nsave, reset=true)::T where {T}
::Type{T})::T where {T}
Compute a function bottom-up on the graph.
`f_leaf` is called on leaf nodes, and `f_inner` is called on inner nodes.
Values of type `T` are passed up the circuit and given to `f_inner` as a function on the children.
"""
function foldup(node::Dag, f_leaf::Function, f_inner::Function,
::Type{T}; nload = nload, nsave = nsave, reset=true) where {T}
@assert node.counter == 0 "Another algorithm is already traversing this circuit and using the `counter` field. You can use `reset_counter` to reset the counter. "
v = foldup_rec(node, f_leaf, f_inner, T; nload, nsave)::T
reset && reset_counter(node)
v
function foldup(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}, cache::Nothing=nothing) where {T}
foldup(node, f_leaf, f_inner, T, Dict{Dag,T}())
end


"""
Compute a function bottom-up on the graph, without resetting the counter.
`f_leaf` is called on leaf nodes, and `f_inner` is called on inner nodes.
Values of type `T` are passed up the circuit and given to `f_inner` as a function on the children.
"""
function foldup_rec(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T};
nload = nload, nsave = nsave) where {T}
if (node.counter += 1) != 1
return nload(node)::T
else
v = if isinner(node)
callback(c) = (foldup_rec(c, f_leaf, f_inner, T; nload, nsave)::T)
f_inner(node, callback)::T
else
f_leaf(node)::T
end
return nsave(node, v)::T
function foldup(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}, seen::Dict) where {T}
get!(seen, node) do
if isinner(node)
callback(c) = foldup(c, f_leaf, f_inner, T, seen)::T
f_inner(node, callback)::T
else
f_leaf(node)::T
end
end
end

Expand All @@ -224,65 +150,23 @@ Compute a function bottom-up on the circuit.
Values of type `T` are passed up the circuit and given to `f_inner` in aggregate
as a vector from the children.
"""
# TODO: see whether we could standardize on `foldup` and remove this version?
function foldup_aggregate(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T};
nload = nload, nsave = nsave, reset=true) where {T}
@assert node.counter == 0 "Another algorithm is already traversing this circuit and using the `counter` field. You can use `reset_counter` to reset the counter. "
v = foldup_aggregate_rec(node, f_leaf, f_inner, T; nload, nsave)::T
reset && reset_counter(node)
return v
function foldup_aggregate(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}, cache::Nothing=nothing) where {T}
foldup_aggregate(node, f_leaf, f_inner, T, Dict{Dag,T}())
end

"""
Compute a function bottom-up on the circuit, without resetting the counter.
`f_leaf` is called on leaf nodes, and `f_inner` is called on inner nodes.
Values of type `T` are passed up the circuit and given to `f_inner` in aggregate
as a vector from the children.
"""
# TODO: see whether we could standardize on `foldup` and remove this version?
function foldup_aggregate_rec(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T};
nload = nload, nsave = nsave) where {T}
if (node.counter += 1) != 1
return nload(node)::T
else
v = if isinner(node)
function foldup_aggregate(node::Dag, f_leaf::Function, f_inner::Function, ::Type{T}, seen::Dict) where {T}
get!(seen, node) do
if isinner(node)
child_values = Vector{T}(undef, num_children(node))
map!(c -> foldup_aggregate_rec(c, f_leaf, f_inner, T; nload, nsave)::T,
child_values, children(node))
map!(c -> foldup_aggregate(c, f_leaf, f_inner, T, seen)::T,
child_values, children(node))
f_inner(node, child_values)::T
else
f_leaf(node)::T
end
return nsave(node, v)::T
end
end

"Set the `counter` field of each node to its number of parents"
function count_parents(node::Dag)
foreach(noop,node;reset=false)
end

"Apply a function to each node in a graph, top down"
function foreach_down(f::Function, node::Dag; setcounter=true)
setcounter && count_parents(node)
foreach_down_rec(f, node)
nothing
end

"Apply a function to each node in a graph, top down, without setting the counter first"
function foreach_down_rec(f::Function, n::Node)
if ((n.counter -= 1) == 0)
f(n)
if isinner(n)
for c in children(n)
foreach_down_rec(f, c)
end
end
end
nothing
end


#####################
# methods using circuit traversal
#####################
Expand All @@ -307,7 +191,7 @@ Compute the number of nodes in of a tree-unfolding of the `Dag`.
"""
function tree_num_nodes(node::Dag)::BigInt
@inline f_leaf(n) = one(BigInt)
@inline f_inner(n, call) = (1 + mapreduce(c -> call(c), +, children(n)))
@inline f_inner(n, call) = (1 + mapreduce(call, +, children(n)))
foldup(node, f_leaf, f_inner, BigInt)
end

Expand Down Expand Up @@ -340,7 +224,6 @@ function Base.in(needle::Dag, circuit::Dag)
contained
end


"Get the list of inner nodes in a given graph"
inodes(c::Dag) = filter(isinner, c)

Expand Down Expand Up @@ -373,7 +256,6 @@ function right_most_descendent(root::Dag)::Dag
root
end


#####################
# debugging methods (not performance critical)
#####################
Expand Down
14 changes: 8 additions & 6 deletions src/abstract_logic_nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ import ..Utils.children # make available for extension by concrete types
"Get the logical literal in a given literal leaf node"
@inline literal(n::LogicCircuit)::Lit = n.literal # override when needed

"Get the logical constant in a given constant leaf node"
function constant end

"Conjoin nodes into a single circuit"
function conjoin end

Expand All @@ -87,13 +84,18 @@ function compile end
"Is the node a constant gate?"
@inline isconstantgate(n) = GateType(n) isa ConstantGate

"Get the logical constant in a given constant leaf node"
@inline constant(n::LogicCircuit) = constant(GateType(n), n)
@inline constant(::ConstantGate, n::LogicCircuit) = n.constant::Bool

"Get the logical variable in a given literal leaf node"
@inline variable(n::LogicCircuit)::Var = variable(GateType(n), n)
@inline variable(::LiteralGate, n::LogicCircuit)::Var = lit2var(literal(n))
@inline variable(n::LogicCircuit) = variable(GateType(n), n)
@inline variable(::LiteralGate, n::LogicCircuit)::Var = lit2var(literal(n))::Var

"Get the sign of the literal leaf node"
@inline ispositive(n::LogicCircuit)::Bool = ispositive(GateType(n), n)
@inline ispositive(n::LogicCircuit) = ispositive(GateType(n), n)
@inline ispositive(::LiteralGate, n::LogicCircuit)::Bool = literal(n) >= 0

@inline isnegative(n::LogicCircuit)::Bool = !ispositive(n)

"Is the circuit syntactically equal to true?"
Expand Down
16 changes: 9 additions & 7 deletions src/bit_circuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,15 @@ struct BitCircuit{V,M}
nodes::M
elements::M
parents::V
node2id
end

function BitCircuit(circuit::LogicCircuit, data; reset=true, on_decision=noop)
BitCircuit(circuit, num_features(data); reset, on_decision)
function BitCircuit(circuit::LogicCircuit, data; on_decision=noop)
BitCircuit(circuit, num_features(data); on_decision)
end

"construct a new `BitCircuit` accomodating the given number of features"
function BitCircuit(circuit::LogicCircuit, num_features::Int; reset=true, on_decision=noop)
function BitCircuit(circuit::LogicCircuit, num_features::Int; on_decision=noop)
#TODO: consider not using foldup_aggregate and instead calling twice to ensure order but save allocations
#TODO add inbounds annotations

Expand Down Expand Up @@ -158,7 +159,8 @@ function BitCircuit(circuit::LogicCircuit, num_features::Int; reset=true, on_dec
⋁NodeIds(layer_id, last_dec_id)
end

r = foldup_aggregate(circuit, f_con, f_lit, f_and, f_or, NodeIds; reset)
node2id = Dict{LogicCircuit,NodeIds}()
r = foldup_aggregate(circuit, f_con, f_lit, f_and, f_or, NodeIds, node2id)
to⋁NodeIds(r)

nodes_m = reshape(nodes, 4, :)
Expand All @@ -178,7 +180,7 @@ function BitCircuit(circuit::LogicCircuit, num_features::Int; reset=true, on_dec
end
end

return BitCircuit(layers, nodes_m, elements_m, parents_m)
return BitCircuit(layers, nodes_m, elements_m, parents_m, node2id)
end

import .Utils: num_nodes # extend
Expand Down Expand Up @@ -234,10 +236,10 @@ end
import .Utils: to_gpu, to_cpu, isgpu #extend

to_gpu(c::BitCircuit) =
BitCircuit(map(to_gpu, c.layers), to_gpu(c.nodes), to_gpu(c.elements), to_gpu(c.parents))
BitCircuit(map(to_gpu, c.layers), to_gpu(c.nodes), to_gpu(c.elements), to_gpu(c.parents), c.node2id)

to_cpu(c::BitCircuit) =
BitCircuit(map(to_cpu, c.layers), to_cpu(c.nodes), to_cpu(c.elements), to_cpu(c.parents))
BitCircuit(map(to_cpu, c.layers), to_cpu(c.nodes), to_cpu(c.elements), to_cpu(c.parents), c.node2id)

isgpu(c::BitCircuit{<:CuArray,<:CuArray}) = true
isgpu(c::BitCircuit{<:Array,<:Array}) = false
Loading

0 comments on commit 00177e3

Please sign in to comment.