Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code for a multivariate discrete random variable #204

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/QuantEcon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export

# discrete_rv
DiscreteRV,
MVDiscreteRV,
draw,

# mc_tools
Expand Down
66 changes: 66 additions & 0 deletions src/discrete_rv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,69 @@ function Base.rand!(out::AbstractArray{T}, d::DiscreteRV) where T<:Integer
end

@deprecate draw Base.rand


struct MVDiscreteRV{TV1<:AbstractArray,TV2<:AbstractVector,K,TI<:Integer}
q::TV1
Q::TV2
dims::NTuple{K,TI}

function MVDiscreteRV{TV1,TV2,K,TI}(q::TV1, Q::TV2, dims::NTuple{K,TI}) where {TV1,TV2,K,TI}
abs(sum(q) - 1.0) > 1e-10 && error("q should sum to 1")
abs(Q[end] - 1.0) > 1e-10 && error("Q[end] should be 1")
length(Q) != prod(dims) && error("Number of elements is inconsistent")

new{TV1,TV2,K,TI}(q, Q, dims)
end
end


function MVDiscreteRV(q::TV1) where TV1<:AbstractArray
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to have the only inner constructor take just q. We wouldn't have to do the length(Q) == prod(dims) check

We don't really want people to create these in any other way

Q = cumsum(vec(q))
dims = size(q)

return MVDiscreteRV{typeof(q),typeof(Q),length(dims),eltype(dims)}(q, Q, dims)
end


"""
Make a single draw from the multivariate discrete distribution.

##### Arguments

- `d::MVDiscreteRV`: The `MVDiscreteRV` type represetning the distribution

##### Returns

- `out::NTuple{Int}`: One draw from the discrete distribution
"""
function Base.rand(d::MVDiscreteRV)
x = rand()
i = searchsortedfirst(d.Q, x)

return ind2sub(d.dims, i)
end

"""
Make multiple draws from the discrete distribution represented by a
`MVDiscreteRV` instance

##### Arguments

- `d::MVDiscreteRV`: The `DiscreteRV` type representing the distribution
- `k::Int`

##### Returns

- `out::Vector{NTuple{Int}}`: `k` draws from `d`
"""
Base.rand(d::MVDiscreteRV{T1,T2,K,TI}, k::V) where {T1,T2,K,TI,V} =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is k supposed to be? I think it should have k::Integer instead of a free type param

NTuple{K,TI}[rand(d) for i in 1:k]

function Base.rand!(out::AbstractArray{NTuple{K,TI}}, d::MVDiscreteRV) where {K,TI}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will you please add a tests for this method?

@inbounds for I in eachindex(out)
out[I] = rand(d)
end

return out
end
87 changes: 61 additions & 26 deletions test/test_discrete_rv.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,69 @@
@testset "Testing discrete_rv.jl" begin

# set up
n = 10
x = rand(n)
x ./= sum(x)
drv = DiscreteRV(x)

# test Q sums to 1
@test drv.Q[end] ≈ 1.0

# test lln
draws = rand(drv, 100_000)
c = counter(draws)
counts = Array{Float64}(n)
for i=1:n
counts[i] = c[i]
end
counts ./= sum(counts)
@testset "Testing univariate discrete rv" begin
# set up
n = 10
x = rand(n)
x ./= sum(x)
drv = DiscreteRV(x)

# test Q sums to 1
@test drv.Q[end] ≈ 1.0

# test lln
draws = rand(drv, 100_000)
c = counter(draws)
counts = Array{Float64}(n)
for i=1:n
counts[i] = c[i]
end
counts ./= sum(counts)

@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)

draws = Array{Int}(100_000)
rand!(draws, drv)
c = counter(draws)
counts = Array{Float64}(n)
for i=1:n
counts[i] = c[i]
draws = Array{Int}(100_000)
rand!(draws, drv)
c = counter(draws)
counts = Array{Float64}(n)
for i=1:n
counts[i] = c[i]
end
counts ./= sum(counts)

@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
end
counts ./= sum(counts)

@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
@testset "Testing multivariate discrete rv" begin
# Do tests for various sizes
for dims in [(5, 3), (5, 10, 3), (5, 7, 5, 10)]
# How many dimensions
n = length(dims)

# Make some distributional matrix
q = rand(dims...)
q ./= sum(q) # Normalize to sum to 1

# Create mv rv
rv = MVDiscreteRV(q)

# Make sure it doesn't draw numbers that don't make sense... Must
# be between 1 and n
for i in 1:n
@test rand(rv)[i] >= 1
@test rand(rv)[i] <= dims[i]
end

ndraws = 1_000_000
draws = rand(rv, ndraws)
counter = zeros(dims...)
for i in 1:ndraws
draw = draws[i]
counter[draw...] += 1.0
end
counter ./= ndraws
@test isapprox(Base.maximum(abs, counter - rv.q), 0.0; atol=1e-2)
end

end

end # testset