diff --git a/HARK/distribution.py b/HARK/distribution.py index ea3300070c..acf6421819 100644 --- a/HARK/distribution.py +++ b/HARK/distribution.py @@ -1105,3 +1105,54 @@ def calcExpectation(dstn,func=None,values=None): # Compute expectations over the shocks and return it f_exp = np.dot(f_query, dstn.pmf) return f_exp + + +class MarkovProcess(Distribution): + """ + A representation of a discrete Markov process. + + Parameters + ---------- + transition_matrix : np.array + An array of floats representing a probability mass for + each state transition. + seed : int + Seed for random number generator. + + """ + + transition_matrix = None + + def __init__(self, transition_matrix, seed=0): + """ + Initialize a discrete distribution. + + """ + self.transition_matrix = transition_matrix + + # Set up the RNG + super().__init__(seed) + + def draw(self, state): + """ + Draw new states fromt the transition matrix. + + Parameters + ---------- + state : int or nd.array + The state or states (1-D array) from which to draw new states. + + Returns + ------- + new_state : int or nd.array + New states. + """ + def sample(s): + return self.RNG.choice( + self.transition_matrix.shape[1], + p = self.transition_matrix[s,:] + ) + + array_sample = np.frompyfunc(sample, 1, 1) + + return array_sample(state) diff --git a/HARK/tests/test_distribution.py b/HARK/tests/test_distribution.py index 64efb94663..5b48abbeb8 100644 --- a/HARK/tests/test_distribution.py +++ b/HARK/tests/test_distribution.py @@ -6,7 +6,7 @@ class DiscreteDistributionTests(unittest.TestCase): """ - Tests for simulation.py sampling distributions + Tests for distribution.py sampling distributions with default seed. """ @@ -21,7 +21,7 @@ def test_drawDiscrete(self): class DistributionClassTests(unittest.TestCase): """ - Tests for simulation.py sampling distributions + Tests for distribution.py sampling distributions with default seed. """ @@ -64,3 +64,25 @@ def test_Uniform(self): def test_Bernoulli(self): self.assertEqual(distribution.Bernoulli().draw(1)[0], False) + + +class MarkovProcessTests(unittest.TestCase): + """ + Tests for MarkovProcess class. + """ + + def test_draw(self): + + mrkv_array = np.array( + [[.75, .25],[0.1, 0.9]] + ) + + mp = distribution.MarkovProcess(mrkv_array) + + new_state = mp.draw(np.zeros(100).astype(int)) + + self.assertEqual(new_state.sum(), 20) + + new_state = mp.draw(new_state) + + self.assertEqual(new_state.sum(), 39)