Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XoshiroSplit type to copy added task RNG state #51271

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ typeof_rng(::_GLOBAL_RNG) = TaskLocalRNG
"""
default_rng() -> rng

Return the default global random number generator (RNG).
Return the default task-local random number generator (RNG).

!!! note
What the default RNG is is an implementation detail. Across different versions of
Expand All @@ -346,8 +346,8 @@ Return the default global random number generator (RNG).
@inline default_rng() = TaskLocalRNG()
@inline default_rng(tid::Int) = TaskLocalRNG()

copy!(dst::Xoshiro, ::_GLOBAL_RNG) = copy!(dst, default_rng())
copy!(::_GLOBAL_RNG, src::Xoshiro) = copy!(default_rng(), src)
copy!(dst::XoshiroSplit, ::_GLOBAL_RNG) = copy!(dst, default_rng())
copy!(::_GLOBAL_RNG, src::XoshiroSplit) = copy!(default_rng(), src)
copy(::_GLOBAL_RNG) = copy(default_rng())

GLOBAL_SEED = 0
Expand Down
211 changes: 146 additions & 65 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

## Xoshiro RNG
# Lots of implementation is shared with TaskLocalRNG
# Lots of implementation is shared with TaskLocalRNG and XoshiroSplit

"""
Xoshiro(seed)
Expand Down Expand Up @@ -53,14 +53,25 @@ mutable struct Xoshiro <: AbstractRNG
Xoshiro(seed=nothing) = seed!(new(), seed)
end

function setstate!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
# NON-PUBLIC
@inline function get_xoshiro_state(x::Xoshiro)
x.s0, x.s1, x.s2, x.s3
end

# NON-PUBLIC
@inline function set_xoshiro_state!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
x.s0 = s0
x.s1 = s1
x.s2 = s2
x.s3 = s3
x
end

# NON-PUBLIC
function seedstate!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
set_xoshiro_state!(x, s0, s1, s2, s3)
end

copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3)

function copy!(dst::Xoshiro, src::Xoshiro)
Expand All @@ -72,21 +83,66 @@ function ==(a::Xoshiro, b::Xoshiro)
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3
end

rng_native_52(::Xoshiro) = UInt64

@inline function rand(rng::Xoshiro, ::SamplerType{UInt64})
s0, s1, s2, s3 = rng.s0, rng.s1, rng.s2, rng.s3
tmp = s0 + s3
res = ((tmp << 23) | (tmp >> 41)) + s0
t = s1 << 17
s2 = xor(s2, s0)
s3 = xor(s3, s1)
s1 = xor(s1, s2)
s0 = xor(s0, s3)
s2 = xor(s2, t)
s3 = s3 << 45 | s3 >> 19
rng.s0, rng.s1, rng.s2, rng.s3 = s0, s1, s2, s3
res
"""
XoshiroSplit(seed)
XoshiroSplit()

Creates the same stream as Xoshiro, but has an additional splitting ability.

For more discussion, cf rng_split in task.c

This is the type currently returned by `copy(default_rng())`.

!!! note
What the default RNG is is an implementation detail. Across different versions of
Julia, you should not expect the default RNG to be always the same, nor that it will
return the same stream of random numbers for a given seed.
"""
mutable struct XoshiroSplit <: AbstractRNG
s0::UInt64
s1::UInt64
s2::UInt64
s3::UInt64
s4::UInt64

XoshiroSplit(
s0::Integer, s1::Integer, s2::Integer, s3::Integer, # xoshiro256 state
s4::Integer, # internal splitmix state
) = new(s0, s1, s2, s3, s4)
XoshiroSplit(seed=nothing) = seed!(new(), seed)
end

# NON-PUBLIC
@inline function get_xoshiro_state(x::XoshiroSplit)
x.s0, x.s1, x.s2, x.s3
end

# NON-PUBLIC
@inline function set_xoshiro_state!(x::XoshiroSplit, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
x.s0 = s0
x.s1 = s1
x.s2 = s2
x.s3 = s3
x
end

# NON-PUBLIC
function seedstate!(x::XoshiroSplit, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
set_xoshiro_state!(x, s0, s1, s2, s3)
x.s4 = 1s0 + 3s1 + 5s2 + 7s3
x
end

copy(rng::XoshiroSplit) = XoshiroSplit(rng.s0, rng.s1, rng.s2, rng.s3, rng.s4)

function copy!(dst::XoshiroSplit, src::XoshiroSplit)
dst.s0, dst.s1, dst.s2, dst.s3, dst.s4 = src.s0, src.s1, src.s2, src.s3, src.s4
dst
end

function ==(a::XoshiroSplit, b::XoshiroSplit)
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3 && a.s4 == b.s4
end


Expand All @@ -111,25 +167,74 @@ is undefined behavior: it will work most of the time, and may sometimes fail sil
"""
struct TaskLocalRNG <: AbstractRNG end
TaskLocalRNG(::Nothing) = TaskLocalRNG()
rng_native_52(::TaskLocalRNG) = UInt64

function setstate!(
x::TaskLocalRNG,
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
s4::UInt64 = 1s0 + 3s1 + 5s2 + 7s3, # internal splitmix state
)
# NON-PUBLIC
@inline function get_xoshiro_state(x::TaskLocalRNG)
t = current_task()
t.rngState0, t.rngState1, t.rngState2, t.rngState3
end

# NON-PUBLIC
@inline function set_xoshiro_state!(x::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
t = current_task()
t.rngState0 = s0
t.rngState1 = s1
t.rngState2 = s2
t.rngState3 = s3
t.rngState4 = s4
x
end

@inline function rand(::TaskLocalRNG, ::SamplerType{UInt64})
task = current_task()
s0, s1, s2, s3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
# NON-PUBLIC
function seedstate!(x::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
t = current_task()
t.rngState0 = s0
t.rngState1 = s1
t.rngState2 = s2
t.rngState3 = s3
t.rngState4 = 1s0 + 3s1 + 5s2 + 7s3
x
end

function copy(rng::TaskLocalRNG)
t = current_task()
XoshiroSplit(t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
end

function copy!(dst::TaskLocalRNG, src::XoshiroSplit)
t = current_task()
t.rngState0 = src.s0
t.rngState1 = src.s1
t.rngState2 = src.s2
t.rngState3 = src.s3
t.rngState4 = src.s4
return dst
end

function copy!(dst::XoshiroSplit, src::TaskLocalRNG)
t = current_task()
dst.s0 = t.rngState0
dst.s1 = t.rngState1
dst.s2 = t.rngState2
dst.s3 = t.rngState3
dst.s4 = t.rngState4
return dst
end

function ==(a::XoshiroSplit, b::TaskLocalRNG)
t = current_task()
a.s0 == t.rngState0 && a.s1 == t.rngState1 && a.s2 == t.rngState2 && a.s3 == t.rngState3 && a.s4 == t.rngState4
end

==(a::TaskLocalRNG, b::XoshiroSplit) = b == a

# Shared implementation between Xoshiro, XoshiroSplit, and TaskLocalRNG

const XoshiroLike = Union{TaskLocalRNG, Xoshiro, XoshiroSplit}

rng_native_52(::XoshiroLike) = UInt64

@inline function rand(rng::XoshiroLike, ::SamplerType{UInt64})
s0, s1, s2, s3 = get_xoshiro_state(rng)
tmp = s0 + s3
res = ((tmp << 23) | (tmp >> 41)) + s0
t = s1 << 17
Expand All @@ -139,78 +244,54 @@ end
s0 ⊻= s3
s2 ⊻= t
s3 = s3 << 45 | s3 >> 19
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
set_xoshiro_state!(rng, s0, s1, s2, s3)
res
end

# Shared implementation between Xoshiro and TaskLocalRNG -- seeding

function seed!(rng::Union{TaskLocalRNG,Xoshiro})
function seed!(rng::XoshiroLike)
# as we get good randomness from RandomDevice, we can skip hashing
rd = RandomDevice()
setstate!(rng, rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64))
seedstate!(rng, rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64))
end

function seed!(rng::Union{TaskLocalRNG,Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
function seed!(rng::XoshiroLike, seed::Union{Vector{UInt32}, Vector{UInt64}})
c = SHA.SHA2_256_CTX()
SHA.update!(c, reinterpret(UInt8, seed))
s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c))
setstate!(rng, s0, s1, s2, s3)
seedstate!(rng, s0, s1, s2, s3)
end

seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))
seed!(rng::XoshiroLike, seed::Integer) = seed!(rng, make_seed(seed))


@inline function rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt128})
@inline function rand(rng::XoshiroLike, ::SamplerType{UInt128})
first = rand(rng, UInt64)
second = rand(rng,UInt64)
second + UInt128(first) << 64
end

@inline rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128
@inline rand(rng::XoshiroLike, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128

@inline function rand(rng::Union{TaskLocalRNG, Xoshiro},
@inline function rand(rng::XoshiroLike,
T::SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64))
S = T[]
# use upper bits
(rand(rng, UInt64) >>> (64 - 8*sizeof(S))) % S
end

function copy(rng::TaskLocalRNG)
t = current_task()
Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3)
end

function copy!(dst::TaskLocalRNG, src::Xoshiro)
t = current_task()
setstate!(dst, src.s0, src.s1, src.s2, src.s3)
return dst
end

function copy!(dst::Xoshiro, src::TaskLocalRNG)
t = current_task()
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3)
return dst
end

function ==(a::Xoshiro, b::TaskLocalRNG)
t = current_task()
a.s0 == t.rngState0 && a.s1 == t.rngState1 && a.s2 == t.rngState2 && a.s3 == t.rngState3
end

==(a::TaskLocalRNG, b::Xoshiro) = b == a

# for partial words, use upper bits from Xoshiro

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52Raw{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())
rand(r::XoshiroLike, ::SamplerTrivial{UInt52Raw{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::XoshiroLike, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::XoshiroLike, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float16}}) =
rand(r::XoshiroLike, ::SamplerTrivial{CloseOpen01{Float16}}) =
Float16(Float32(rand(r, UInt16) >>> 5) * Float32(0x1.0p-11))

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float32}}) =
rand(r::XoshiroLike, ::SamplerTrivial{CloseOpen01{Float32}}) =
Float32(rand(r, UInt32) >>> 8) * Float32(0x1.0p-24)

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01_64}) =
rand(r::XoshiroLike, ::SamplerTrivial{CloseOpen01_64}) =
Float64(rand(r, UInt64) >>> 11) * 0x1.0p-53
Loading