Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Tasklocal xoshiro #34852

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2451,7 +2451,7 @@ void jl_init_types(void) JL_GC_DISABLED
NULL,
jl_any_type,
jl_emptysvec,
jl_perm_symsvec(11,
jl_perm_symsvec(15,
"next",
"queue",
"storage",
Expand All @@ -2461,9 +2461,13 @@ void jl_init_types(void) JL_GC_DISABLED
"backtrace",
"logstate",
"code",
"rngState0",
"rngState1",
"rngState2",
"rngState3",
"_state",
"sticky"),
jl_svec(11,
jl_svec(15,
jl_any_type,
jl_any_type,
jl_any_type,
Expand All @@ -2473,6 +2477,10 @@ void jl_init_types(void) JL_GC_DISABLED
jl_any_type,
jl_any_type,
jl_any_type,
jl_uint64_type,
jl_uint64_type,
jl_uint64_type,
jl_uint64_type,
jl_uint8_type,
jl_bool_type),
0, 1, 8);
Expand Down
4 changes: 4 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1797,6 +1797,10 @@ typedef struct _jl_task_t {
jl_value_t *backtrace;
jl_value_t *logstate;
jl_function_t *start;
uint64_t rngState0; // really rngState[4], but more convenient to split
uint64_t rngState1;
uint64_t rngState2;
uint64_t rngState3;
uint8_t _state;
uint8_t sticky; // record whether this Task can be migrated to a new thread

Expand Down
80 changes: 80 additions & 0 deletions src/task.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "julia_internal.h"
#include "threading.h"
#include "julia_assert.h"
#include "support/hashing.h"

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -602,6 +603,83 @@ JL_DLLEXPORT void jl_rethrow_other(jl_value_t *e JL_MAYBE_UNROOTED)
throw_internal(NULL);
}

/* This is xoshiro256++ 1.0, used for tasklocal random number generation in julia.
This implementation is intended for embedders and internal use by the runtime, and is
based on the reference implementation on http://prng.di.unimi.it

Credits go to Sebastiano Vigna for coming up with this PRNG.

There is a pure julia implementation in stdlib that tends to be faster when used from
within julia, due to inlining and more agressive architecture-specific optimizations.
*/
JL_DLLEXPORT uint64_t jl_tasklocal_genrandom(jl_task_t *task) {
uint64_t s0 = task->rngState0;
uint64_t s1 = task->rngState1;
uint64_t s2 = task->rngState2;
uint64_t s3 = task->rngState3;

uint64_t t = s0 << 17;
uint64_t tmp = s0 + s3;
uint64_t res = ((tmp << 23) | (tmp >> 41)) + s0;
s2 ^=s0;
s3 ^=s1;
s1 ^= s2;
s0 ^= s3;
s2 ^= t;
s3 = (s3 << 45) | (s3 >> 19);

task->rngState0 = s0;
task->rngState1 = s1;
task->rngState2 = s2;
task->rngState3 = s3;
return res;
};

JL_DLLEXPORT void jl_tasklocal_seedrandom(jl_task_t *task,
uint64_t seed0,
uint64_t seed1,
uint64_t seed2,
uint64_t seed3) {
// Fixme: Consider a less ad-hoc construction
// We can afford burning a handful of cycles here, and we don't want any
// surprises with respect to bad seeds / bad interactions
uint64_t s = int64hash(seed0);
task->rngState0 = s;
s += int64hash(seed1);
task->rngState1 = s;
s += int64hash(seed2);
task->rngState2 = s;
s += int64hash(seed3);
task->rngState3 = s;
jl_tasklocal_genrandom(task);
jl_tasklocal_genrandom(task);
jl_tasklocal_genrandom(task);
jl_tasklocal_genrandom(task);
};

void rng_split(jl_task_t *from, jl_task_t *to) {
/* Fixme: consider a less ad-hoc construction
Ideally we could just use the output of the random stream to seed the initial
state of the child. Out of an overabundance of caution we multiply with
effectively random coefficients, to break possible self-interactions.

It is not the goal to mix bits -- we work under the assumption that the
source is well-seeded, and its output looks effectively random.
However, xoshiro has never been studied in the mode where we seed the
initial state with the output of another xoshiro instance.

Constants have nothing up their sleeve:
0x02011ce34bce797f == hash(UInt(1))|0x01
0x5a94851fb48a6e05 == hash(UInt(2))|0x01
0x3688cf5d48899fa7 == hash(UInt(3))|0x01
0x867b4bb4c42e5661 == hash(UInt(4))|0x01
*/
to -> rngState0 = 0x02011ce34bce797f * jl_tasklocal_genrandom(from);
to -> rngState1 = 0x5a94851fb48a6e05 * jl_tasklocal_genrandom(from);
to -> rngState2 = 0x3688cf5d48899fa7 * jl_tasklocal_genrandom(from);
to -> rngState3 = 0x867b4bb4c42e5661 * jl_tasklocal_genrandom(from);
}

JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize)
{
jl_ptls_t ptls = jl_get_ptls_states();
Expand Down Expand Up @@ -637,6 +715,8 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion
t->backtrace = jl_nothing;
// Inherit logger state from parent task
t->logstate = ptls->current_task->logstate;
// Fork task-local random state from parent
rng_split(ptls->current_task, t);
// there is no active exception handler available on this stack yet
t->eh = NULL;
t->sticky = 1;
Expand Down
144 changes: 144 additions & 0 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,149 @@ function fillcache_zeros!(r::MersenneTwister)
r
end

## Xoshiro RNG. Unless this becomes the new default, consider excising to external library
#Lots of implementation is shared with TaskLocal
"""
Xoshiro

Xoshiro256++ is a fast pseudorandom number generator originally developed by Sebastian Vigna. Reference implementation is available on on http://prng.di.unimi.it

Apart from the high speed, Xoshiro has a small memory footprint, making it suitable for applications where many different random states need to be held for long time.

Julia's Xoshiro implementation has a bulk-generation mode; this seeds new virtual PRNGs from the parent, and uses SIMD to generate in parallel (i.e. the bulk stream consists of multiple interleaved xoshiro instances). The virtual PRNGs are discarded once the bulk request has been serviced (and should cause no heap allocations).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please insert newlines (here and elsewhere), this one doesn't even fit on a big screen! This is will make reviewing on Github easier.

"""
mutable struct Xoshiro <: AbstractRNG
s0::UInt64
s1::UInt64
s2::UInt64
s3::UInt64
end

Xoshiro(::Nothing) = Xoshiro()

function Xoshiro(parent::AbstractRNG = RandomDevice())
# Constants have nothing up their sleeve, cf task.c
# 0x02011ce34bce797f == hash(UInt(1))|0x01
# 0x5a94851fb48a6e05 == hash(UInt(2))|0x01
# 0x3688cf5d48899fa7 == hash(UInt(3))|0x01
# 0x867b4bb4c42e5661 == hash(UInt(4))|0x01

Xoshiro(0x02011ce34bce797f * rand(parent, UInt64),
0x5a94851fb48a6e05 * rand(parent, UInt64),
0x3688cf5d48899fa7 * rand(parent, UInt64),
0x867b4bb4c42e5661 * rand(parent, UInt64))
end

rng_native_52(::Xoshiro) = UInt64

function seed!(rng::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
#cf task.c
s = Base.hash_uint64(s0)
rng.s0 = s
s += Base.hash_uint64(s1)
rng.s1 = s
s += Base.hash_uint64(s2)
rng.s2 = s
s += Base.hash_uint64(s3)
rng.s3 = s
end


@inline function rand(rng::Xoshiro, ::Type{UInt64})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@inline function rand(rng::Xoshiro, ::Type{UInt64})
@inline function rand(rng::Xoshiro, ::SamplerType{UInt64})

s0, s1, s2, s3 = rng.s0, rng.s1, rng.s2, rng.s3
tmp = s0 + s3
res = tmp << 23 | tmp >> 41
t = s1 << 17
s2 = xor(s2, s0)
s3 = xor(s3, s1)
s1 = xor(s1, s2)
s0 = xor(s0, s3)
s2 = xor(s2, t)
s3 = s3 << 45 | s3 >> 19
rng.s0, rng.s1, rng.s2, rng.s3 = s0, s1, s2, s3
res
end


## Task local RNG
"""
TaskLocal

The TaskLocal RNG has state that is local to its task, not its thread. It is seeded upon task creation, from the state of its parent task. Therefore, task creation is an event that changes the parent's RNG state.

As an upside, the TaskLocal RNG is pretty fast, and permits reproducible multithreaded simulations (barring race conditions), independent of scheduler decisions. As long as the number of threads is not used to make decisions on task creation, simulation results are also independent of the number of available threads / CPUs. The random stream should not depend on hardware specifics, up to endianness and possibly word size.

Using or seeding the RNG of any other task than the one returned by `current_task()` is undefined behavior: it will work most of the time, and may sometimes fail silently.
"""
struct TaskLocal <: AbstractRNG end
TaskLocal(::Nothing) = TaskLocal()
rng_native_52(::TaskLocal) = UInt64

function seed!(::TaskLocal, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
task = current_task()
ccall(:jl_tasklocal_seedrandom, Nothing, (Ref{Task}, UInt64, UInt64, UInt64, UInt64), task, s0,s1,s2,s3)
TaskLocal()
end


@inline function rand(::TaskLocal, ::Type{UInt64})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@inline function rand(::TaskLocal, ::Type{UInt64})
@inline function rand(::TaskLocal, ::SamplerType{UInt64})

task = current_task()
s0, s1, s2, s3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
tmp = s0 + s3
res = tmp << 23 | tmp >> 41
t = s1 << 17
s2 = xor(s2, s0)
s3 = xor(s3, s1)
s1 = xor(s1, s2)
s0 = xor(s0, s3)
s2 = xor(s2, t)
s3 = s3 << 45 | s3 >> 19
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
res
end

#Shared implementation between Xoshiro and TaskLocal -- seeding
function seed!(rng::Union{TaskLocal, Xoshiro}, seed::Union{Int128, UInt128, BigInt})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I would prefer being more conservative on valid seeds here (like with MersenneTwister): with a big BigInt, high bits will simply be discarded, e.g. two seeds which differ by 2^130 lead to the same initialization state. Similarly, assuming s::Int128 < 0, then s and s % UInt128 lead to the same state, while they are represent distinct numbers.

So I suggest defining this method only for UInt128 (like MT does for Vector{UInt32}), and define other integer methods with something like seed!(rng, seed::Integer) = seed!(rng, UInt128(seed)) (two equal integers should lead to the same state).

seed0 = seed % UInt64
seed1 = (seed>>64) % UInt64
seed!(rng, seed0, seed1, zero(UInt64), zero(UInt64))
end
seed!(rng::Union{TaskLocal, Xoshiro}, seed::Integer) = seed!(rng, seed%UInt64, zero(UInt64), zero(UInt64), zero(UInt64))
seed!(rng::Union{TaskLocal, Xoshiro}, ::Nothing) = seed!(rng)

seed!(rng::Union{TaskLocal, Xoshiro}) = seed!(rng,
rand(RandomDevice(), UInt64),
rand(RandomDevice(), UInt64),
rand(RandomDevice(), UInt64),
rand(RandomDevice(), UInt64))

function seed!(rng::Union{TaskLocal, Xoshiro}, seed::AbstractVector{UInt64})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An error should be thrown if length(seed) > 4.

seed0 = length(seed)>0 ? seed[1] : UInt64(0)
seed1 = length(seed)>1 ? seed[2] : UInt64(0)
seed2 = length(seed)>2 ? seed[3] : UInt64(0)
seed3 = length(seed)>3 ? seed[4] : UInt64(0)
seed!(rng, seed0, seed1, seed2, seed3)
end

function seed!(rng::Union{TaskLocal, Xoshiro}, seed::AbstractVector{UInt32})
#can only process an even length
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't the last entry of an odd-lengthed vector be cast to an UInt64 ? I guess that could be implemented later, but no bits of the input should be discarded (i.e. better to error out on odd length).

len = min(length(seed), 8) & ~1
seed!(rng, reinterpret(UInt64, view(seed, 1:len)))
end

@inline function rand(rng::Union{TaskLocal, Xoshiro}, ::Type{UInt128})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define this one directly for SamplerType{UInt128}.

first = rand(rng, UInt64)
second = rand(rng,UInt64)
second + UInt128(first)<<64
end
@inline rand(rng::Union{TaskLocal, Xoshiro}, ::SamplerType{UInt128}) = rand(rng, UInt128)
@inline rand(rng::Union{TaskLocal, Xoshiro}, ::Type{Int128}) = rand(rng, UInt128) % Int128
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@inline rand(rng::Union{TaskLocal, Xoshiro}, ::Type{Int128}) = rand(rng, UInt128) % Int128

@inline rand(rng::Union{TaskLocal, Xoshiro}, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128

@inline rand(rng::Union{TaskLocal, Xoshiro}, ::Type{T}) where {T<:Union{Bool, UInt8, Int8, UInt16, Int16, UInt32, Int32, Int64}} = rand(rng, UInt64) % T
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@inline rand(rng::Union{TaskLocal, Xoshiro}, ::Type{T}) where {T<:Union{Bool, UInt8, Int8, UInt16, Int16, UInt32, Int32, Int64}} = rand(rng, UInt64) % T

@inline rand(rng::Union{TaskLocal, Xoshiro}, ::SamplerType{T}) where {T<:Union{Bool, UInt8, Int8, UInt16, Int16, UInt32, Int32, Int64, UInt64}} = rand(rng, UInt64) % T
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the UInt64 from the Union as rand(rng, ::SamplerType{UInt64}) should be defined elsewhere as the "native" method.


Sampler(rng::Union{TaskLocal, Xoshiro}, ::Type{T}) where T<:Union{Float64, Float32} = SamplerType{T}()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be necessary, this definition is already the default.


### low level API

Expand Down Expand Up @@ -310,6 +453,7 @@ end

function __init__()
resize!(empty!(THREAD_RNGs), Threads.nthreads()) # ensures that we didn't save a bad object
seed!(TaskLocal())
end


Expand Down
3 changes: 2 additions & 1 deletion stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export rand!, randn!,
shuffle, shuffle!,
randperm, randperm!,
randcycle, randcycle!,
AbstractRNG, MersenneTwister, RandomDevice
AbstractRNG, MersenneTwister, RandomDevice, TaskLocal

## general definitions

Expand Down Expand Up @@ -296,6 +296,7 @@ include("RNGs.jl")
include("generation.jl")
include("normal.jl")
include("misc.jl")
include("XoshiroSimd.jl")

## rand & rand! & seed! docstrings

Expand Down
Loading