Skip to content

Commit

Permalink
flatten 1D arrays at the end of calcExpectation, see econ-ark#625
Browse files Browse the repository at this point in the history
  • Loading branch information
sbenthall committed Jan 9, 2021
1 parent 6940a28 commit e9a2550
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
3 changes: 0 additions & 3 deletions HARK/ConsumptionSaving/ConsIndShockModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,9 +793,6 @@ def vp_next(shocks, a_nrm):
)
)

if EndOfPrdvP.shape[0] == EndOfPrdvP.size:
EndOfPrdvP = EndOfPrdvP.flatten()

return EndOfPrdvP

def getPointsForInterpolation(self, EndOfPrdvP, aNrmNow):
Expand Down
4 changes: 3 additions & 1 deletion HARK/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ def calcExpectation(dstn,func=lambda x : x,*args):
The N-valued distribution over which the function is to be evaluated.
func : function
The function to be evaluated.
This function should take a 1D array of size N.
This function should take an array of size N x M.
It may also take other arguments *args
Please see numpy.apply_along_axis() for guidance on
design of func.
Expand Down Expand Up @@ -1049,6 +1049,8 @@ def calcExpectation(dstn,func=lambda x : x,*args):
# a hack.
if f_exp.size == 1:
f_exp = f_exp.flat[0]
elif f_exp.shape[0] == f_exp.size:
f_exp = f_exp.flatten()

return f_exp

Expand Down
2 changes: 1 addition & 1 deletion HARK/tests/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_calcExpectation(self):
)

self.assertAlmostEqual(
ce9[3][0],
ce9[3],
9.518015322143837
)

Expand Down

0 comments on commit e9a2550

Please sign in to comment.