From cfe0f42544ca085966059a6e79c22f9a820ddde9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 16 Oct 2024 12:10:10 +0200 Subject: [PATCH] fixup #65 its much better to pass the rng via iterator state, otherwise both identical positions get the same random force thus move in the same direction. This was messing up an example in the docs --- src/sfdp.jl | 10 ++++------ src/spring.jl | 10 ++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/sfdp.jl b/src/sfdp.jl index a052500..b9a9379 100644 --- a/src/sfdp.jl +++ b/src/sfdp.jl @@ -76,12 +76,12 @@ function Base.iterate(iter::LayoutIterator{<:SFDP{Dim,Ptype,T}}) where {Dim,Ptyp pin = [get(algo.pin, i, SVector{Dim,Bool}(false for _ in 1:Dim)) for i in 1:N] # iteratorstate: (#iter, energy, step, progress, old pos, pin, stopflag) - return startpos, (1, typemax(T), one(T), 0, startpos, pin, false) + return startpos, (1, typemax(T), one(T), 0, startpos, pin, rng, false) end function Base.iterate(iter::LayoutIterator{<:SFDP}, state) algo, adj_matrix = iter.algorithm, iter.adj_matrix - iter, energy0, step, progress, locs0, pin, stopflag = state + iter, energy0, step, progress, locs0, pin, rng, stopflag = state K, C, tol = algo.K, algo.C, algo.tol # stop if stopflag (tol reached) or nr of iterations reached @@ -109,9 +109,7 @@ function Base.iterate(iter::LayoutIterator{<:SFDP}, state) end if any(isnan, force) # if two points are at the exact same location use random force in any direction - # copy rng from alg struct to not advance the "initial" rng state - # otherwise algo(g)==algo(g) might be broken - force += randn(copy(algo.rng), Ftype) + force = randn(rng, Ftype) end mask = (!).(pin[i]) # where pin=true mask will multiply with 0 locs[i] = locs[i] .+ (step .* (force ./ norm(force))) .* mask @@ -124,7 +122,7 @@ function Base.iterate(iter::LayoutIterator{<:SFDP}, state) stopflag = true end - return locs, (iter + 1, energy, step, progress, locs, pin, stopflag) + return locs, (iter + 1, energy, step, progress, locs, pin, rng, stopflag) end # Calculate Attractive force diff --git a/src/spring.jl b/src/spring.jl index 5eb14f4..b0b26d0 100644 --- a/src/spring.jl +++ b/src/spring.jl @@ -75,12 +75,12 @@ function Base.iterate(iter::LayoutIterator{<:Spring{Dim,Ptype}}) where {Dim,Ptyp pin = [get(algo.pin, i, SVector{Dim,Bool}(false for _ in 1:Dim)) for i in 1:N] # iteratorstate: #iter nr, old pos, pin - return (startpos, (1, startpos, pin)) + return (startpos, (1, startpos, pin, rng)) end function Base.iterate(iter::LayoutIterator{<:Spring}, state) algo, adj_matrix = iter.algorithm, iter.adj_matrix - iteration, old_pos, pin = state + iteration, old_pos, pin, rng = state iteration >= algo.iterations && return nothing # The optimal distance bewteen vertices @@ -114,9 +114,7 @@ function Base.iterate(iter::LayoutIterator{<:Spring}, state) else # if two points are at the exact same location # use random force in any direction - # copy rng from alg struct to not advance the "initial" rng state - # otherwise algo(g)==algo(g) might be broken - force_vec += randn(copy(algo.rng), Ftype) + force_vec += randn(rng, Ftype) end end @@ -134,5 +132,5 @@ function Base.iterate(iter::LayoutIterator{<:Spring}, state) locs[i] += force[i] .* scale .* mask end - return locs, (iteration + 1, locs, pin) + return locs, (iteration + 1, locs, pin, rng) end