Skip to content
This repository has been archived by the owner on Oct 23, 2022. It is now read-only.

Commit

Permalink
Allow the Spectrum.Uniform to accept an external random generator and…
Browse files Browse the repository at this point in the history
… fixed the tests to use the specific generator we are expecting.
  • Loading branch information
Ran Gal committed Jul 14, 2021
1 parent 3793474 commit 60750a6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 25 deletions.
12 changes: 7 additions & 5 deletions src/Optical/Emitters/Spectrum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using DataFrames: DataFrame
using Distributions
import Unitful: Length, ustrip
using Unitful.DefaultSymbols
using Random

const UNIFORMSHORT = 0.450 #um
const UNIFORMLONG = 0.680 #um
Expand All @@ -30,19 +31,20 @@ Uniform(::Type{T} = Float64) where {T<:Real}
struct Uniform{T} <: AbstractSpectrum{T}
low_end::T
high_end::T
rng::Random.AbstractRNG

# user defined range of spectrum
function Uniform(low_end::T, high_end::T) where {T<:Real}
return new{T}(low_end, high_end)
function Uniform(low_end::T, high_end::T; rng=Random.GLOBAL_RNG) where {T<:Real}
return new{T}(low_end, high_end, rng)
end

# with no specific range we will use the constants' values
function Uniform(::Type{T} = Float64) where {T<:Real}
return new{T}(UNIFORMSHORT, UNIFORMLONG)
function Uniform(::Type{T} = Float64; rng=Random.GLOBAL_RNG) where {T<:Real}
return new{T}(UNIFORMSHORT, UNIFORMLONG, rng)
end
end

Emitters.generate(s::Uniform{T}) where {T<:Real} = (one(T), rand(Distributions.Uniform(s.low_end, s.high_end)))
Emitters.generate(s::Uniform{T}) where {T<:Real} = (one(T), rand(s.rng, Distributions.Uniform(s.low_end, s.high_end)))

"""
DeltaFunction{T} <: AbstractSpectrum{T}
Expand Down
31 changes: 11 additions & 20 deletions test/testsets/Emitters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ using StaticArrays
@test Directions.UniformCone(0., 0).vvec === unitY3()

@test Base.length(Directions.UniformCone(0., 1)) === 1
Random.seed!(0)
@test collect(Directions.UniformCone/4, 2)) == [

@test collect(Directions.UniformCone/4, 2, rng=Random.MersenneTwister(0))) == [
[0.30348115383395624, -0.6083405920618145, 0.7333627433388552],
[0.16266571675478964, 0.2733479444418462, 0.9480615833700192],
]
Expand Down Expand Up @@ -131,8 +131,7 @@ using StaticArrays
@test Base.length(Origins.RectUniform(1, 2, 3)) === 3
@test Emitters.visual_size(Origins.RectUniform(1, 2, 3)) === 2

Random.seed!(0)
@test collect(Origins.RectUniform(1, 2, 3)) == [
@test collect(Origins.RectUniform(1, 2, 3, rng=Random.MersenneTwister(0))) == [
[0.3236475079774124, 0.8207130758528729, 0.0],
[-0.3354342018663148, -0.6453423070674709, 0.0],
[-0.221119890668799, -0.5930468839161547, 0.0],
Expand Down Expand Up @@ -191,8 +190,7 @@ using StaticArrays
@test Spectrum.Uniform().low_end === 0.450
@test Spectrum.Uniform().high_end === 0.680

Random.seed!(0)
@test Emitters.generate(Spectrum.Uniform()) === (1.0, 0.6394389268348049)
@test Emitters.generate(Spectrum.Uniform(rng=Random.MersenneTwister(0))) === (1.0, 0.6394389268348049)
end

@testset "DeltaFunction" begin
Expand Down Expand Up @@ -236,24 +234,20 @@ using StaticArrays
directions=Directions.HexapolarCone(0., 1)
)) === 49

Random.seed!(0)
@test Base.iterate(Sources.Source()) === (expected_rays[1], Sources.SourceGenerationState(2, 0, Vec3()))
@test Base.iterate(Sources.Source(spectrum=Spectrum.Uniform(rng=Random.MersenneTwister(0)))) === (expected_rays[1], Sources.SourceGenerationState(2, 0, Vec3()))

Random.seed!(0)
@test Base.getindex(Sources.Source(), 0) === expected_rays[1]
Random.seed!(0)
@test Emitters.generate(Sources.Source(), 0) === expected_rays[1]
@test Base.getindex(Sources.Source(spectrum=Spectrum.Uniform(rng=Random.MersenneTwister(0))), 0) === expected_rays[1]
@test Emitters.generate(Sources.Source(spectrum=Spectrum.Uniform(rng=Random.MersenneTwister(0))), 0) === expected_rays[1]

Random.seed!(0)
@test Emitters.generate(Sources.Source()) === (expected_rays[1], Sources.SourceGenerationState(0, -2, Vec3()))
@test Emitters.generate(Sources.Source(spectrum=Spectrum.Uniform(rng=Random.MersenneTwister(0)))) === (expected_rays[1], Sources.SourceGenerationState(0, -2, Vec3()))

@test Base.firstindex(Sources.Source()) === 0
@test Base.lastindex(Sources.Source()) === 0
@test Base.copy(Sources.Source()) === Sources.Source()
end

@testset "CompositeSource" begin
s() = Sources.Source()
s() = Sources.Source(spectrum=Spectrum.Uniform(rng=Random.MersenneTwister(0)))
tr = Transform()
cs1 = Sources.CompositeSource(tr, [s()])
cs2 = Sources.CompositeSource(tr, [s(), s()])
Expand All @@ -277,14 +271,11 @@ using StaticArrays
@test Base.length(cs2) === 2
@test Base.length(cs3) === 3

Random.seed!(0)
@test Base.iterate(cs1) === (expected_rays[1], Sources.SourceGenerationState(2, 0, Vec3()))

Random.seed!(0)
@test collect(cs2) == expected_rays[1:2]
@test collect(cs2) == vcat(expected_rays[1:1], expected_rays[1:1])

Random.seed!(0)
@test collect(cs3) == expected_rays[1:3]
@test collect(cs3) == vcat(expected_rays[1:1], expected_rays[2:2], expected_rays[2:2])
end
end
end # testset Emitters

0 comments on commit 60750a6

Please sign in to comment.