Skip to content

Commit

Permalink
Introduce task-local and free-standing xoshiro RNG
Browse files Browse the repository at this point in the history
  • Loading branch information
chethega authored and JeffBezanson committed Apr 20, 2021
1 parent 592db58 commit bce6397
Show file tree
Hide file tree
Showing 7 changed files with 658 additions and 5 deletions.
16 changes: 12 additions & 4 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2462,7 +2462,7 @@ void jl_init_types(void) JL_GC_DISABLED
NULL,
jl_any_type,
jl_emptysvec,
jl_perm_symsvec(10,
jl_perm_symsvec(14,
"next",
"queue",
"storage",
Expand All @@ -2472,8 +2472,12 @@ void jl_init_types(void) JL_GC_DISABLED
"code",
"_state",
"sticky",
"_isexception"),
jl_svec(10,
"_isexception",
"rngState0",
"rngState1",
"rngState2",
"rngState3"),
jl_svec(14,
jl_any_type,
jl_any_type,
jl_any_type,
Expand All @@ -2483,7 +2487,11 @@ void jl_init_types(void) JL_GC_DISABLED
jl_any_type,
jl_uint8_type,
jl_bool_type,
jl_bool_type),
jl_bool_type,
jl_uint64_type,
jl_uint64_type,
jl_uint64_type,
jl_uint64_type),
0, 1, 6);
jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_nothing_type);
jl_svecset(jl_task_type->types, 0, listt);
Expand Down
4 changes: 4 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1809,6 +1809,10 @@ typedef struct _jl_task_t {
uint8_t _state;
uint8_t sticky; // record whether this Task can be migrated to a new thread
uint8_t _isexception; // set if `result` is an exception to throw or that we exited with
uint64_t rngState0; // really rngState[4], but more convenient to split
uint64_t rngState1;
uint64_t rngState2;
uint64_t rngState3;

// hidden state:
// id of owning thread - does not need to be defined until the task runs
Expand Down
60 changes: 60 additions & 0 deletions src/task.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "julia_internal.h"
#include "threading.h"
#include "julia_assert.h"
#include "support/hashing.h"

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -666,6 +667,63 @@ JL_DLLEXPORT void jl_rethrow_other(jl_value_t *e JL_MAYBE_UNROOTED)
throw_internal(NULL);
}

/* This is xoshiro256++ 1.0, used for tasklocal random number generation in julia.
   This implementation is intended for embedders and internal use by the runtime, and is
   based on the reference implementation on http://prng.di.unimi.it
   Credits go to Sebastiano Vigna for coming up with this PRNG.
   There is a pure julia implementation in stdlib that tends to be faster when used from
   within julia, due to inlining and more aggressive architecture-specific optimizations.

   Returns the next 64-bit output and advances the four state words stored in `task`.
*/
JL_DLLEXPORT uint64_t jl_tasklocal_genrandom(jl_task_t *task) JL_NOTSAFEPOINT
{
    uint64_t s0 = task->rngState0;
    uint64_t s1 = task->rngState1;
    uint64_t s2 = task->rngState2;
    uint64_t s3 = task->rngState3;

    // Output scrambler of the "++" variant: rotl(s0 + s3, 23) + s0
    uint64_t tmp = s0 + s3;
    uint64_t res = ((tmp << 23) | (tmp >> 41)) + s0;

    // State transition. The shifted word must be s1 (as in the reference
    // implementation and the julia implementation in stdlib), not s0.
    uint64_t t = s1 << 17;
    s2 ^= s0;
    s3 ^= s1;
    s1 ^= s2;
    s0 ^= s3;
    s2 ^= t;
    s3 = (s3 << 45) | (s3 >> 19); // rotl(s3, 45)

    task->rngState0 = s0;
    task->rngState1 = s1;
    task->rngState2 = s2;
    task->rngState3 = s3;
    return res;
}

void rng_split(jl_task_t *from, jl_task_t *to) JL_NOTSAFEPOINT
{
    /* Seed the state of `to` from four consecutive outputs of `from`'s generator.

       TODO: consider a less ad-hoc construction
       Ideally we could just use the output of the random stream to seed the initial
       state of the child. Out of an overabundance of caution we multiply with
       effectively random coefficients, to break possible self-interactions.

       It is not the goal to mix bits -- we work under the assumption that the
       source is well-seeded, and its output looks effectively random.
       However, xoshiro has never been studied in the mode where we seed the
       initial state with the output of another xoshiro instance.

       Constants have nothing up their sleeve:
       0x02011ce34bce797f == hash(UInt(1))|0x01
       0x5a94851fb48a6e05 == hash(UInt(2))|0x01
       0x3688cf5d48899fa7 == hash(UInt(3))|0x01
       0x867b4bb4c42e5661 == hash(UInt(4))|0x01
    */
    const uint64_t coeff0 = 0x02011ce34bce797f;
    const uint64_t coeff1 = 0x5a94851fb48a6e05;
    const uint64_t coeff2 = 0x3688cf5d48899fa7;
    const uint64_t coeff3 = 0x867b4bb4c42e5661;

    // Each draw advances `from`; the draws happen in word order 0, 1, 2, 3.
    to->rngState0 = jl_tasklocal_genrandom(from) * coeff0;
    to->rngState1 = jl_tasklocal_genrandom(from) * coeff1;
    to->rngState2 = jl_tasklocal_genrandom(from) * coeff2;
    to->rngState3 = jl_tasklocal_genrandom(from) * coeff3;
}

JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize)
{
jl_ptls_t ptls = jl_get_ptls_states();
Expand Down Expand Up @@ -701,6 +759,8 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion
t->_isexception = 0;
// Inherit logger state from parent task
t->logstate = ptls->current_task->logstate;
// Fork task-local random state from parent
rng_split(ptls->current_task, t);
// there is no active exception handler available on this stack yet
t->eh = NULL;
t->sticky = 1;
Expand Down
211 changes: 211 additions & 0 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,216 @@ function reset_caches!(r::MersenneTwister)
r
end


## Xoshiro RNG
# Lots of implementation is shared with TaskLocal

"""
    Xoshiro

Xoshiro256++ is a fast pseudorandom number generator originally developed by Sebastiano Vigna.
The reference implementation is available at http://prng.di.unimi.it

Apart from the high speed, Xoshiro has a small memory footprint, making it suitable for
applications where many different random states need to be held for a long time.

Julia's Xoshiro implementation has a bulk-generation mode; this seeds new virtual PRNGs
from the parent, and uses SIMD to generate in parallel (i.e. the bulk stream consists of
multiple interleaved xoshiro instances).
The virtual PRNGs are discarded once the bulk request has been serviced (and should cause
no heap allocations).
"""
mutable struct Xoshiro <: AbstractRNG
    # The four 64-bit words of generator state.
    s0::UInt64
    s1::UInt64
    s2::UInt64
    s3::UInt64
end

# Passing `nothing` is the same as passing no argument at all.
Xoshiro(::Nothing) = Xoshiro()

# Build a new generator whose four state words are four consecutive `UInt64` draws from
# `parent`, each multiplied by a fixed odd constant.
function Xoshiro(parent::AbstractRNG = RandomDevice())
    # Constants have nothing up their sleeve, see task.c
    # 0x02011ce34bce797f == hash(UInt(1))|0x01
    # 0x5a94851fb48a6e05 == hash(UInt(2))|0x01
    # 0x3688cf5d48899fa7 == hash(UInt(3))|0x01
    # 0x867b4bb4c42e5661 == hash(UInt(4))|0x01
    draw0 = rand(parent, UInt64)
    draw1 = rand(parent, UInt64)
    draw2 = rand(parent, UInt64)
    draw3 = rand(parent, UInt64)
    return Xoshiro(
        draw0 * 0x02011ce34bce797f,
        draw1 * 0x5a94851fb48a6e05,
        draw2 * 0x3688cf5d48899fa7,
        draw3 * 0x867b4bb4c42e5661,
    )
end

# Return a new, independent generator holding the same four state words as `rng`.
function copy(rng::Xoshiro)
    return Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3)
end

# Overwrite the state of `dst` with the state of `src`; returns `dst`.
function copy!(dst::Xoshiro, src::Xoshiro)
    w0, w1, w2, w3 = src.s0, src.s1, src.s2, src.s3
    dst.s0 = w0
    dst.s1 = w1
    dst.s2 = w2
    dst.s3 = w3
    return dst
end

# Two generators compare equal exactly when all four state words agree.
function ==(a::Xoshiro, b::Xoshiro)
    return a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3
end

# Keep `hash` consistent with `==`: `Xoshiro` is a mutable struct, so without this method
# two generators that compare equal would still hash differently (the fallback hash of a
# mutable object is identity-based), which breaks their use as `Dict`/`Set` keys.
Base.hash(rng::Xoshiro, h::UInt) = hash((rng.s0, rng.s1, rng.s2, rng.s3), h)

rng_native_52(::Xoshiro) = UInt64

# Set the state of `rng` from four seed words: state word k is the running sum of
# `Base.hash_uint64` applied to seed words 0 through k. Returns `rng`.
function seed!(rng::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
    # see task.c
    acc0 = Base.hash_uint64(s0)
    acc1 = acc0 + Base.hash_uint64(s1)
    acc2 = acc1 + Base.hash_uint64(s2)
    acc3 = acc2 + Base.hash_uint64(s3)
    rng.s0, rng.s1, rng.s2, rng.s3 = acc0, acc1, acc2, acc3
    return rng
end

# Produce the next 64-bit output of xoshiro256++ and advance `rng` by one step.
#
# The output of the "++" variant is rotl(s0 + s3, 23) + s0. The trailing `+ s0` is part of
# the reference algorithm and of the C runtime implementation (`jl_tasklocal_genrandom`
# returns `((tmp << 23) | (tmp >> 41)) + s0`); omitting it yields a different stream.
@inline function rand(rng::Xoshiro, ::SamplerType{UInt64})
    s0, s1, s2, s3 = rng.s0, rng.s1, rng.s2, rng.s3
    tmp = s0 + s3
    res = ((tmp << 23) | (tmp >> 41)) + s0
    # State transition (linear engine of xoshiro256).
    t = s1 << 17
    s2 = xor(s2, s0)
    s3 = xor(s3, s1)
    s1 = xor(s1, s2)
    s0 = xor(s0, s3)
    s2 = xor(s2, t)
    s3 = (s3 << 45) | (s3 >> 19) # rotl(s3, 45)
    rng.s0, rng.s1, rng.s2, rng.s3 = s0, s1, s2, s3
    return res
end


## Task local RNG

"""
    TaskLocal

The TaskLocal RNG has state that is local to its task, not its thread.
It is seeded upon task creation, from the state of its parent task.
Therefore, task creation is an event that changes the parent's RNG state.

As an upside, the TaskLocal RNG is pretty fast, and permits reproducible
multithreaded simulations (barring race conditions), independent of scheduler
decisions. As long as the number of threads is not used to make decisions on
task creation, simulation results are also independent of the number of available
threads / CPUs. The random stream should not depend on hardware specifics, up to
endianness and possibly word size.

Using or seeding the RNG of any other task than the one returned by `current_task()`
is undefined behavior: it will work most of the time, and may sometimes fail silently.
"""
struct TaskLocal <: AbstractRNG end
# Passing `nothing` is the same as passing no argument.
TaskLocal(::Nothing) = TaskLocal()
rng_native_52(::TaskLocal) = UInt64

# Seed the current task's RNG state from four seed words: state word k is the running sum
# of `hash` applied to seed words 0 through k. Four outputs are then drawn and discarded.
# Returns `rng`.
function seed!(rng::TaskLocal, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
    # TODO: Consider a less ad-hoc construction
    # We can afford burning a handful of cycles here, and we don't want any
    # surprises with respect to bad seeds / bad interactions.
    acc0 = hash(s0)
    acc1 = acc0 + hash(s1)
    acc2 = acc1 + hash(s2)
    acc3 = acc2 + hash(s3)

    task = current_task()
    task.rngState0 = acc0
    task.rngState1 = acc1
    task.rngState2 = acc2
    task.rngState3 = acc3

    # Advance the freshly seeded generator by four steps.
    for _ in 1:4
        rand(rng, UInt64)
    end
    return rng
end

# Produce the next 64-bit output of xoshiro256++ from the current task's RNG state and
# advance that state by one step.
#
# The output of the "++" variant is rotl(s0 + s3, 23) + s0. The trailing `+ s0` is part of
# the reference algorithm and of the C runtime implementation operating on the same task
# fields (`jl_tasklocal_genrandom` returns `((tmp << 23) | (tmp >> 41)) + s0`).
@inline function rand(::TaskLocal, ::SamplerType{UInt64})
    task = current_task()
    s0, s1, s2, s3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
    tmp = s0 + s3
    res = ((tmp << 23) | (tmp >> 41)) + s0
    # State transition (linear engine of xoshiro256).
    t = s1 << 17
    s2 = xor(s2, s0)
    s3 = xor(s3, s1)
    s1 = xor(s1, s2)
    s0 = xor(s0, s3)
    s2 = xor(s2, t)
    s3 = (s3 << 45) | (s3 >> 19) # rotl(s3, 45)
    task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
    return res
end

# Shared implementation between Xoshiro and TaskLocal -- seeding

# A 128-bit seed supplies the first two seed words (low half first); the rest are zero.
function seed!(rng::Union{TaskLocal, Xoshiro}, seed::UInt128)
    low = seed % UInt64
    high = (seed >>> 64) % UInt64
    return seed!(rng, low, high, zero(UInt64), zero(UInt64))
end

# Any other integer is reduced to its low 64 bits and used as the first seed word.
function seed!(rng::Union{TaskLocal, Xoshiro}, seed::Integer)
    return seed!(rng, seed % UInt64, zero(UInt64), zero(UInt64), zero(UInt64))
end

# `nothing` means "no seed given".
seed!(rng::Union{TaskLocal, Xoshiro}, ::Nothing) = seed!(rng)

# Without a seed, take four `UInt64` values from `RandomDevice`.
function seed!(rng::Union{TaskLocal, Xoshiro})
    w0 = rand(RandomDevice(), UInt64)
    w1 = rand(RandomDevice(), UInt64)
    w2 = rand(RandomDevice(), UInt64)
    w3 = rand(RandomDevice(), UInt64)
    return seed!(rng, w0, w1, w2, w3)
end

# Seed from a vector of up to four `UInt64` words; missing trailing words count as zero.
# Throws `ArgumentError` when more than four words (256 bits) are supplied.
#
# Elements are addressed relative to `firstindex(seed)` rather than with literal indices
# 1..4, so vectors whose indices do not start at 1 are read correctly as well.
function seed!(rng::Union{TaskLocal, Xoshiro}, seed::AbstractVector{UInt64})
    n = length(seed)
    if n > 4
        throw(ArgumentError("seed should have no more than 256 bits"))
    end
    base = firstindex(seed)
    # k-th seed word (0-based), or zero when the vector is shorter than k + 1.
    word(k) = k < n ? seed[base + k] : UInt64(0)
    return seed!(rng, word(0), word(1), word(2), word(3))
end

# Seed from a vector of `UInt32` words by viewing consecutive pairs as `UInt64` words.
# With an odd number of elements, the final element becomes a `UInt64` word of its own.
function seed!(rng::Union{TaskLocal, Xoshiro}, seed::AbstractVector{UInt32})
    # Even length: the whole vector reinterprets cleanly.
    isodd(length(seed)) || return seed!(rng, reinterpret(UInt64, seed))

    paired = reinterpret(UInt64, @view(seed[begin:end-1]))
    leftover = seed[end] % UInt64
    return seed!(rng, UInt64[paired; leftover])
end

# A `UInt128` is assembled from two consecutive 64-bit draws: the first draw becomes the
# high half, the second the low half.
@inline function rand(rng::Union{TaskLocal, Xoshiro}, ::SamplerType{UInt128})
    high = rand(rng, UInt64)
    low = rand(rng, UInt64)
    return (UInt128(high) << 64) + low
end

# `Int128` reuses the `UInt128` draw, reinterpreted modulo 2^128.
@inline function rand(rng::Union{TaskLocal, Xoshiro}, ::SamplerType{Int128})
    return rand(rng, UInt128) % Int128
end

# Integer types of at most 64 bits (and `Bool`) keep the low bits of one 64-bit draw.
@inline function rand(
    rng::Union{TaskLocal, Xoshiro}, ::SamplerType{T}
) where {T<:Union{Bool, UInt8, Int8, UInt16, Int16, UInt32, Int32, Int64}}
    return rand(rng, UInt64) % T
end

# Snapshot the current task's RNG state into a free-standing `Xoshiro`.
function copy(rng::TaskLocal)
    task = current_task()
    return Xoshiro(task.rngState0, task.rngState1, task.rngState2, task.rngState3)
end

# Load the state of `src` into the current task's RNG state; returns `dst`.
function copy!(dst::TaskLocal, src::Xoshiro)
    task = current_task()
    w0, w1, w2, w3 = src.s0, src.s1, src.s2, src.s3
    task.rngState0 = w0
    task.rngState1 = w1
    task.rngState2 = w2
    task.rngState3 = w3
    return dst
end

# Store the current task's RNG state into `dst`; returns `dst`.
function copy!(dst::Xoshiro, src::TaskLocal)
    task = current_task()
    w0, w1, w2, w3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
    dst.s0 = w0
    dst.s1 = w1
    dst.s2 = w2
    dst.s3 = w3
    return dst
end

# A `Xoshiro` equals the task-local RNG when its four state words match the RNG state of
# the current task.
function ==(a::Xoshiro, b::TaskLocal)
    task = current_task()
    return (a.s0, a.s1, a.s2, a.s3) ==
           (task.rngState0, task.rngState1, task.rngState2, task.rngState3)
end

# Symmetric case: delegate to the method above.
function ==(a::TaskLocal, b::Xoshiro)
    return b == a
end

### low level API

#### floats

# Difference between the constant `MT_CACHE_F` and the generator's index `r.idxF`.
# NOTE(review): presumably the number of cached floats not yet consumed -- confirm against
# the definitions of `MT_CACHE_F` and `MersenneTwister`, which are outside this view.
mt_avail(r::MersenneTwister) = MT_CACHE_F - r.idxF
Expand Down Expand Up @@ -382,6 +592,7 @@ end

# Module initializer: reset the `THREAD_RNGs` container to one slot per thread and seed
# the task-local RNG of the task running the initializer.
function __init__()
    # ensures that we didn't save a bad object
    empty!(THREAD_RNGs)
    resize!(THREAD_RNGs, Threads.nthreads())
    seed!(TaskLocal())
end


Expand Down
3 changes: 2 additions & 1 deletion stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export rand!, randn!,
shuffle, shuffle!,
randperm, randperm!,
randcycle, randcycle!,
AbstractRNG, MersenneTwister, RandomDevice
AbstractRNG, MersenneTwister, RandomDevice, TaskLocal, Xoshiro

## general definitions

Expand Down Expand Up @@ -296,6 +296,7 @@ include("RNGs.jl")
include("generation.jl")
include("normal.jl")
include("misc.jl")
include("XoshiroSimd.jl")

## rand & rand! & seed! docstrings

Expand Down
Loading

0 comments on commit bce6397

Please sign in to comment.