Skip to content

Commit

Permalink
Merge pull request #408 from sbenthall/issue-404
Browse files Browse the repository at this point in the history
Issue 404 - refactoring simulation.py
  • Loading branch information
llorracc authored Oct 29, 2019
2 parents f6e9e91 + b5b35ca commit 8edcebf
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 12 deletions.
15 changes: 3 additions & 12 deletions HARK/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,13 @@ def drawMeanOneLognormal(N, sigma=1.0, seed=0):
T-length list of arrays of mean one lognormal draws each of size N, or
a single array of size N (if sigma is a scalar).
'''
# Set up the RNG
RNG = np.random.RandomState(seed)
mu = -0.5*sigma**2

if isinstance(sigma,float): # Return a single array of length N
mu = -0.5*sigma**2
draws = RNG.lognormal(mean=mu, sigma=sigma, size=N)
else: # Set up empty list to populate, then loop and populate list with draws
draws=[]
for sig in sigma:
mu = -0.5*(sig**2)
draws.append(RNG.lognormal(mean=mu, sigma=sig, size=N))
return draws
return drawLognormal(N,mu=mu,sigma=sigma,seed=seed)

def drawLognormal(N,mu=0.0,sigma=1.0,seed=0):
'''
Generate arrays of mean one lognormal draws. The sigma input can be a number
Generate arrays of lognormal draws. The sigma input can be a number
or list-like. If a number, output is a length N array of draws from the
lognormal distribution with standard deviation sigma. If a list, output is
a length T list whose t-th entry is a length N array of draws from the
Expand Down
46 changes: 46 additions & 0 deletions HARK/tests/test_simulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest

import HARK.simulation as simulation

class SimulationTests(unittest.TestCase):
'''
Tests for simulation.py sampling distributions
with default seed.
'''

def test_drawMeanOneLognormal(self):
self.assertEqual(
simulation.drawMeanOneLognormal(1)[0],
3.5397367004222002)

def test_drawLognormal(self):
self.assertEqual(
simulation.drawLognormal(1)[0],
5.836039190663969)

def test_drawNormal(self):
self.assertEqual(
simulation.drawNormal(1)[0],
1.764052345967664)

def test_drawWeibull(self):
self.assertEqual(
simulation.drawWeibull(1)[0],
0.79587450816311)

def test_drawUniform(self):
self.assertEqual(
simulation.drawUniform(1)[0],
0.5488135039273248)

def test_drawBernoulli(self):
self.assertEqual(
simulation.drawBernoulli(1)[0],
False)


def test_drawDiscrete(self):
self.assertEqual(
simulation.drawDiscrete(1)[0],
0)

0 comments on commit 8edcebf

Please sign in to comment.