diff --git a/src/QuantEcon.jl b/src/QuantEcon.jl index 1abfcf4b..aefc22ca 100644 --- a/src/QuantEcon.jl +++ b/src/QuantEcon.jl @@ -42,6 +42,7 @@ export # discrete_rv DiscreteRV, + MVDiscreteRV, draw, # mc_tools diff --git a/src/discrete_rv.jl b/src/discrete_rv.jl index 23593b26..6d1d4334 100644 --- a/src/discrete_rv.jl +++ b/src/discrete_rv.jl @@ -78,3 +78,69 @@ function Base.rand!(out::AbstractArray{T}, d::DiscreteRV) where T<:Integer end @deprecate draw Base.rand + + +struct MVDiscreteRV{TV1<:AbstractArray,TV2<:AbstractVector,K,TI<:Integer} + q::TV1 + Q::TV2 + dims::NTuple{K,TI} + + function MVDiscreteRV{TV1,TV2,K,TI}(q::TV1, Q::TV2, dims::NTuple{K,TI}) where {TV1,TV2,K,TI} + abs(sum(q) - 1.0) > 1e-10 && error("q should sum to 1") + abs(Q[end] - 1.0) > 1e-10 && error("Q[end] should be 1") + length(Q) != prod(dims) && error("Number of elements is inconsistent") + + new{TV1,TV2,K,TI}(q, Q, dims) + end +end + + +function MVDiscreteRV(q::TV1) where TV1<:AbstractArray + Q = cumsum(vec(q)) + dims = size(q) + + return MVDiscreteRV{typeof(q),typeof(Q),length(dims),eltype(dims)}(q, Q, dims) +end + + +""" +Make a single draw from the multivariate discrete distribution. + +##### Arguments + +- `d::MVDiscreteRV`: The `MVDiscreteRV` type represetning the distribution + +##### Returns + +- `out::NTuple{Int}`: One draw from the discrete distribution +""" +function Base.rand(d::MVDiscreteRV) + x = rand() + i = searchsortedfirst(d.Q, x) + + return ind2sub(d.dims, i) +end + +""" +Make multiple draws from the discrete distribution represented by a +`MVDiscreteRV` instance + +##### Arguments + +- `d::MVDiscreteRV`: The `DiscreteRV` type representing the distribution +- `k::Int` + +##### Returns + +- `out::Vector{NTuple{Int}}`: `k` draws from `d` +""" +Base.rand(d::MVDiscreteRV{T1,T2,K,TI}, k::V) where {T1,T2,K,TI,V} = + NTuple{K,TI}[rand(d) for i in 1:k] + +function Base.rand!(out::AbstractArray{NTuple{K,TI}}, d::MVDiscreteRV) where {K,TI} + @inbounds for I in eachindex(out) + out[I] = rand(d) + end + + return out +end diff --git a/test/test_discrete_rv.jl b/test/test_discrete_rv.jl index 914b68c1..7bdf1df3 100644 --- a/test/test_discrete_rv.jl +++ b/test/test_discrete_rv.jl @@ -1,34 +1,69 @@ @testset "Testing discrete_rv.jl" begin - # set up - n = 10 - x = rand(n) - x ./= sum(x) - drv = DiscreteRV(x) - - # test Q sums to 1 - @test drv.Q[end] ≈ 1.0 - - # test lln - draws = rand(drv, 100_000) - c = counter(draws) - counts = Array{Float64}(n) - for i=1:n - counts[i] = c[i] - end - counts ./= sum(counts) + @testset "Testing univariate discrete rv" begin + # set up + n = 10 + x = rand(n) + x ./= sum(x) + drv = DiscreteRV(x) + + # test Q sums to 1 + @test drv.Q[end] ≈ 1.0 + + # test lln + draws = rand(drv, 100_000) + c = counter(draws) + counts = Array{Float64}(n) + for i=1:n + counts[i] = c[i] + end + counts ./= sum(counts) - @test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2) + @test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2) - draws = Array{Int}(100_000) - rand!(draws, drv) - c = counter(draws) - counts = Array{Float64}(n) - for i=1:n - counts[i] = c[i] + draws = Array{Int}(100_000) + rand!(draws, drv) + c = counter(draws) + counts = Array{Float64}(n) + for i=1:n + counts[i] = c[i] + end + counts ./= sum(counts) + + @test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2) end - counts ./= sum(counts) - @test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2) + @testset "Testing multivariate discrete rv" begin + # Do tests for various sizes + for dims in [(5, 3), (5, 10, 3), (5, 7, 5, 10)] + # How many dimensions + n = length(dims) + + # Make some distributional matrix + q = rand(dims...) + q ./= sum(q) # Normalize to sum to 1 + + # Create mv rv + rv = MVDiscreteRV(q) + + # Make sure it doesn't draw numbers that don't make sense... Must + # be between 1 and n + for i in 1:n + @test rand(rv)[i] >= 1 + @test rand(rv)[i] <= dims[i] + end + + ndraws = 1_000_000 + draws = rand(rv, ndraws) + counter = zeros(dims...) + for i in 1:ndraws + draw = draws[i] + counter[draw...] += 1.0 + end + counter ./= ndraws + @test isapprox(Base.maximum(abs, counter - rv.q), 0.0; atol=1e-2) + end + + end end # testset