Skip to content

Commit

Permalink
Add code for a multivariate discrete rv
Browse files Browse the repository at this point in the history
  • Loading branch information
cc7768 committed Feb 21, 2018
1 parent 9334727 commit d8798d5
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/QuantEcon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export

# discrete_rv
DiscreteRV,
MVDiscreteRV,
draw,

# mc_tools
Expand Down
66 changes: 66 additions & 0 deletions src/discrete_rv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,69 @@ function Base.rand!(out::AbstractArray{T}, d::DiscreteRV) where T<:Integer
end

@deprecate draw Base.rand


struct MVDiscreteRV{TV1<:AbstractArray,TV2<:AbstractVector,K,TI<:Integer}
q::TV1
Q::TV2
dims::NTuple{K,TI}

function MVDiscreteRV{TV1,TV2,K,TI}(q::TV1, Q::TV2, dims::NTuple{K,TI}) where {TV1,TV2,K,TI}
abs(sum(q) - 1.0) > 1e-10 && error("q should sum to 1")
abs(Q[end] - 1.0) > 1e-10 && error("Q[end] should be 1")
length(Q) != prod(dims) && error("Number of elements is inconsistent")

new{TV1,TV2,K,TI}(q, Q, dims)
end
end


function MVDiscreteRV(q::TV1) where TV1<:AbstractArray
Q = cumsum(vec(q))
dims = size(q)

return MVDiscreteRV{typeof(q),typeof(Q),length(dims),eltype(dims)}(q, Q, dims)
end


"""
Make a single draw from the multivariate discrete distribution.
##### Arguments
- `d::MVDiscreteRV`: The `MVDiscreteRV` type represetning the distribution
##### Returns
- `out::NTuple{Int}`: One draw from the discrete distribution
"""
function Base.rand(d::MVDiscreteRV)
x = rand()
i = searchsortedfirst(d.Q, x)

return ind2sub(d.dims, i)
end

"""
Make multiple draws from the discrete distribution represented by a
`MVDiscreteRV` instance
##### Arguments
- `d::MVDiscreteRV`: The `DiscreteRV` type representing the distribution
- `k::Int`
##### Returns
- `out::Vector{NTuple{Int}}`: `k` draws from `d`
"""
Base.rand(d::MVDiscreteRV{T1,T2,K,TI}, k::V) where {T1,T2,K,TI,V} =
NTuple{K,TI}[rand(d) for i in 1:k]

function Base.rand!(out::AbstractArray{NTuple{K,TI}}, d::MVDiscreteRV) where {K,TI}
@inbounds for I in eachindex(out)
out[I] = rand(d)
end

return out
end
87 changes: 61 additions & 26 deletions test/test_discrete_rv.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,69 @@
@testset "Testing discrete_rv.jl" begin

# set up
n = 10
x = rand(n)
x ./= sum(x)
drv = DiscreteRV(x)

# test Q sums to 1
@test drv.Q[end] 1.0

# test lln
draws = rand(drv, 100_000)
c = counter(draws)
counts = Array{Float64}(n)
for i=1:n
counts[i] = c[i]
end
counts ./= sum(counts)
@testset "Testing univariate discrete rv" begin
# set up
n = 10
x = rand(n)
x ./= sum(x)
drv = DiscreteRV(x)

# test Q sums to 1
@test drv.Q[end] 1.0

# test lln
draws = rand(drv, 100_000)
c = counter(draws)
counts = Array{Float64}(n)
for i=1:n
counts[i] = c[i]
end
counts ./= sum(counts)

@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)

draws = Array{Int}(100_000)
rand!(draws, drv)
c = counter(draws)
counts = Array{Float64}(n)
for i=1:n
counts[i] = c[i]
draws = Array{Int}(100_000)
rand!(draws, drv)
c = counter(draws)
counts = Array{Float64}(n)
for i=1:n
counts[i] = c[i]
end
counts ./= sum(counts)

@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
end
counts ./= sum(counts)

@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
@testset "Testing multivariate discrete rv" begin
# Do tests for various sizes
for dims in [(5, 3), (5, 10, 3), (5, 7, 5, 10)]
# How many dimensions
n = length(dims)

# Make some distributional matrix
q = rand(dims...)
q ./= sum(q) # Normalize to sum to 1

# Create mv rv
rv = MVDiscreteRV(q)

# Make sure it doesn't draw numbers that don't make sense... Must
# be between 1 and n
for i in 1:n
@test rand(rv)[i] >= 1
@test rand(rv)[i] <= dims[i]
end

ndraws = 1_000_000
draws = rand(rv, ndraws)
counter = zeros(dims...)
for i in 1:ndraws
draw = draws[i]
counter[draw...] += 1.0
end
counter ./= ndraws
@test isapprox(Base.maximum(abs, counter - rv.q), 0.0; atol=1e-2)
end

end

end # testset

0 comments on commit d8798d5

Please sign in to comment.