Skip to content

Commit

Permalink
Merge pull request #702 from econ-ark/SimPerformance
Browse files Browse the repository at this point in the history
Fix simulation performance of AggShockMarkovConsumerType
  • Loading branch information
llorracc authored Jun 18, 2020
2 parents d93e80c + 8262c28 commit 0c5a315
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 43 deletions.
29 changes: 14 additions & 15 deletions HARK/ConsumptionSaving/ConsAggShockModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,17 +477,14 @@ def getShocks(self):
if N > 0:
IncomeDstnNow = self.IncomeDstn[t-1][self.MrkvNow] # set current income distribution
PermGroFacNow = self.PermGroFac[t-1] # and permanent growth factor
Indices = np.arange(IncomeDstnNow.pmf.size) # just a list of integers

# Get random draws of income shocks from the discrete distribution
EventDraws = DiscreteDistribution(
IncomeDstnNow.pmf,
Indices
).drawDiscrete(N,
exact_match=True,
seed=self.RNG.randint(0, 2**31-1))
# permanent "shock" includes expected growth
PermShkNow[these] = IncomeDstnNow.X[0][EventDraws]*PermGroFacNow
TranShkNow[these] = IncomeDstnNow.X[1][EventDraws]
ShockDraws = IncomeDstnNow.drawDiscrete(N,
exact_match=True,
seed=self.RNG.randint(0, 2**31-1))
# Permanent "shock" includes expected growth
PermShkNow[these] = ShockDraws[0]*PermGroFacNow
TranShkNow[these] = ShockDraws[1]

# That procedure used the *last* period in the sequence for newborns, but that's not right
# Redraw shocks for newborns, using the *first* period in the sequence. Approximation.
Expand All @@ -496,12 +493,14 @@ def getShocks(self):
these = newborn
IncomeDstnNow = self.IncomeDstn[0][self.MrkvNow] # set current income distribution
PermGroFacNow = self.PermGroFac[0] # and permanent growth factor

# Get random draws of income shocks from the discrete distribution
EventDraws = IncomeDstnNow.draw_events(N,
seed=self.RNG.randint(0, 2**31-1))
# permanent "shock" includes expected growth
PermShkNow[these] = IncomeDstnNow.X[0][EventDraws]*PermGroFacNow
TranShkNow[these] = IncomeDstnNow.X[1][EventDraws]
ShockDraws = IncomeDstnNow.drawDiscrete(N,
exact_match=True,
seed=self.RNG.randint(0, 2**31-1))
# Permanent "shock" includes expected growth
PermShkNow[these] = ShockDraws[0]*PermGroFacNow
TranShkNow[these] = ShockDraws[1]

# Store the shocks in self
self.EmpNow = np.ones(self.AgentCount, dtype=bool)
Expand Down
4 changes: 2 additions & 2 deletions HARK/ConsumptionSaving/tests/test_ConsAggShockModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_agent(self):
self.agent.getEconomyData(self.economy)
self.agent.solve()
self.assertAlmostEqual(self.agent.solution[0].cFunc[0](10., self.economy.MSS),
2.5635896520991377)
2.5635896520991377)

def test_economy(self):
# Adjust the economy so that it (fake) solves quickly
Expand All @@ -79,5 +79,5 @@ def test_economy(self):

self.economy.AFunc = self.economy.dynamics.AFunc
self.assertAlmostEqual(self.economy.AFunc[0].slope,
1.0845554708377696)
1.0801777346256896)

24 changes: 15 additions & 9 deletions HARK/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,30 +501,36 @@ def drawDiscrete(self, N,X=None,exact_match=False,seed=0):
X = X
J = 1

# Draw indices whose empirical distribution closely matches discrete pmf
if exact_match:
# Set up the RNG
RNG = np.random.RandomState(seed)

events = np.arange(self.pmf.size) # just a list of integers
cutoffs = np.round(np.cumsum(self.pmf)*N).astype(int) # cutoff points between discrete outcomes
top = 0

# Make a list of event indices that closely matches the discrete distribution
event_list = []
for j in range(events.size):
bot = top
top = cutoffs[j]
event_list += (top-bot)*[events[j]]
# Randomly permute the event indices and store the corresponding results
event_draws = RNG.permutation(event_list)
draws = X[event_draws]

# Randomly permute the event indices
indices = RNG.permutation(event_list)

# Draw event indices randomly from the discrete distribution
else:
indices = self.draw_events(N, seed=seed)
if J > 1:
draws = np.zeros((J,N))
for j in range(J):
draws[j,:] = X[j][indices]
else:
draws = np.asarray(X)[indices]

# Create and fill in the output array of draws based on the output of event indices
if J > 1:
draws = np.zeros((J,N))
for j in range(J):
draws[j,:] = X[j][indices]
else:
draws = np.asarray(X)[indices]

return draws

Expand Down
19 changes: 9 additions & 10 deletions examples/ConsumptionSaving/example_ConsAggShockModel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# %%
from time import process_time, time
from time import time
import numpy as np
import matplotlib.pyplot as plt
from HARK.utilities import plotFuncs
Expand All @@ -23,9 +23,9 @@ def mystr(number):
# Solve an AggShockMarkovConsumerType's microeconomic problem
solve_markov_micro = False
# Solve for the equilibrium aggregate saving rule in a CobbDouglasMarkovEconomy
solve_markov_market = False
solve_markov_market = True
# Solve a simple Krusell-Smith-style two state, two shock model
solve_krusell_smith = True
solve_krusell_smith = False
# Solve a CobbDouglasEconomy with many states, potentially utilizing the "state jumper"
solve_poly_state = False

Expand All @@ -48,9 +48,9 @@ def mystr(number):
# %%
if solve_agg_shocks_micro:
# Solve the microeconomic model for the aggregate shocks example type (and display results)
t_start = process_time()
t_start = time()
AggShockExample.solve()
t_end = process_time()
t_end = time()
print(
"Solving an aggregate shocks consumer took "
+ mystr(t_end - t_start)
Expand Down Expand Up @@ -117,9 +117,9 @@ def mystr(number):
# %%
if solve_markov_micro:
# Solve the microeconomic model for the Markov aggregate shocks example type (and display results)
t_start = process_time()
t_start = time()
AggShockMrkvExample.solve()
t_end = process_time()
t_end = time()
print(
"Solving an aggregate shocks Markov consumer took "
+ mystr(t_end - t_start)
Expand All @@ -144,7 +144,6 @@ def mystr(number):
# Solve the "macroeconomic" model by searching for a "fixed point dynamic rule"
t_start = time()
MrkvEconomyExample.verbose = True
MrkvEconomyExample.act_T = 500
print("Now solving a two-state Markov economy. This should take a few minutes...")
MrkvEconomyExample.solve()
t_end = time()
Expand Down Expand Up @@ -240,14 +239,14 @@ def mystr(number):
) # Have the consumers inherit relevant objects from the economy

# Solve the many state model
t_start = process_time()
t_start = time()
print(
"Now solving an economy with "
+ str(StateCount)
+ " Markov states. This might take a while..."
)
PolyStateEconomy.solve()
t_end = process_time()
t_end = time()
print(
"Solving a model with "
+ str(StateCount)
Expand Down
8 changes: 1 addition & 7 deletions examples/ConsumptionSaving/example_ConsIndShock.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
+ " seconds."
)
PFexample.unpackcFunc()
PFexample.timeFwd()

# %%
# Plot the perfect foresight consumption function
Expand Down Expand Up @@ -62,7 +61,6 @@
+ " seconds."
)
IndShockExample.unpackcFunc()
IndShockExample.timeFwd()

# %%
# Plot the consumption function and MPC for the infinite horizon consumer
Expand Down Expand Up @@ -113,7 +111,6 @@
end_time = time()
print("Solving a lifecycle consumer took " + mystr(end_time - start_time) + " seconds.")
LifecycleExample.unpackcFunc()
LifecycleExample.timeFwd()

# %%
# Plot the consumption functions during working life
Expand All @@ -126,8 +123,7 @@
# %%
# Plot the consumption functions during retirement
print("Consumption functions while retired:")
plotFuncs(LifecycleExample.cFunc[LifecycleExample.T_retire :], 0, 5)
LifecycleExample.timeRev()
plotFuncs(LifecycleExample.cFunc[LifecycleExample.T_retire:], 0, 5)

# %%
# Simulate some data; results stored in mNrmNow_hist, cNrmNow_hist, pLvlNow_hist, and t_age_hist
Expand All @@ -149,7 +145,6 @@
end_time = time()
print("Solving a cyclical consumer took " + mystr(end_time - start_time) + " seconds.")
CyclicalExample.unpackcFunc()
CyclicalExample.timeFwd()

# %%
# Plot the consumption functions for the cyclical consumer type
Expand Down Expand Up @@ -177,7 +172,6 @@
print("Solving a kinky consumer took " + mystr(end_time - start_time) + " seconds.")
KinkyExample.unpackcFunc()
print("Kinky consumption function:")
KinkyExample.timeFwd()
plotFuncs(KinkyExample.cFunc[0], KinkyExample.solution[0].mNrmMin, 5)

# %%
Expand Down

0 comments on commit 0c5a315

Please sign in to comment.