diff --git a/Documentation/CHANGELOG.md b/Documentation/CHANGELOG.md index 49b860a50..e2593eca6 100644 --- a/Documentation/CHANGELOG.md +++ b/Documentation/CHANGELOG.md @@ -12,8 +12,9 @@ For more information on HARK, see [our Github organization](https://github.com/e Release Date: TBD -#### Major Changes +### Major Changes +- Adds a discretize method to DBlocks and RBlocks (#1460)[https://github.com/econ-ark/HARK/pull/1460] - Allows structural equations in model files to be provided in string form [#1427](https://github.com/econ-ark/HARK/pull/1427) - Introduces `HARK.parser' module for parsing configuration files into models [#1427](https://github.com/econ-ark/HARK/pull/1427) diff --git a/HARK/model.py b/HARK/model.py index 1816460fd..19c8df296 100644 --- a/HARK/model.py +++ b/HARK/model.py @@ -2,7 +2,8 @@ Tools for crafting models. """ -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace +from copy import copy, deepcopy from HARK.distribution import ( Distribution, DiscreteDistributionLabeled, @@ -135,8 +136,12 @@ def simulate_dynamics( return vals +class Block: + pass + + @dataclass -class DBlock: +class DBlock(Block): """ Represents a 'block' of model behavior. It prioritizes a representation of the dynamics of the block. @@ -162,6 +167,26 @@ class DBlock: dynamics: dict = field(default_factory=dict) reward: dict = field(default_factory=dict) + def discretize(self, disc_params): + """ + Returns a new DBlock which is a copy of this one, but with shock discretized. + """ + + disc_shocks = {} + + for shockn in self.shocks: + if shockn in disc_params: + disc_shocks[shockn] = self.shocks[shockn].discretize( + **disc_params[shockn] + ) + else: + disc_shocks[shockn] = deepcopy(self.shocks[shockn]) + + # replace returns a modified copy + new_dblock = replace(self, shocks=disc_shocks) + + return new_dblock + def __post_init__(self): for v in self.dynamics: if isinstance(self.dynamics[v], str): @@ -261,7 +286,7 @@ def mod_dvf(shock_value_array): @dataclass -class RBlock: +class RBlock(Block): """ A recursive block. @@ -272,7 +297,24 @@ class RBlock: name: str = "" description: str = "" - blocks: List[DBlock] = field(default_factory=list) + blocks: List[Block] = field(default_factory=list) + + def discretize(self, disc_params): + """ + Recursively discretizes all the blocks. + It replaces any DBlocks with new blocks with discretized shocks. + """ + cbs = copy(self.blocks) + + for i, b in list(enumerate(cbs)): + if isinstance(b, DBlock): + nb = b.discretize(disc_params) + cbs[i] = nb + elif isinstance(b, RBlock): + b.discretize(disc_params) + + # returns a copy of the RBlock with the blocks replaced + return replace(self, blocks=cbs) def get_shocks(self): ### TODO: Bug in here is causing AttributeError: 'set' object has no attribute 'draw' diff --git a/HARK/tests/test_model.py b/HARK/tests/test_model.py index f5b2db434..ad14fe27e 100644 --- a/HARK/tests/test_model.py +++ b/HARK/tests/test_model.py @@ -1,6 +1,6 @@ import unittest -from HARK.distribution import Bernoulli +from HARK.distribution import Bernoulli, DiscreteDistribution import HARK.model as model from HARK.model import Control import HARK.models.consumer as cons @@ -44,6 +44,11 @@ def setUp(self): def test_init(self): self.assertEqual(self.test_block_A.name, "test block A") + def test_discretize(self): + dbl = self.cblock.discretize({"theta": {"N": 5}}) + + self.assertEqual(len(dbl.shocks["theta"].pmv), 5) + def test_transition(self): post = self.cblock.transition(self.dpre, self.dr) @@ -79,6 +84,8 @@ def setUp(self): self.test_block_C = model.DBlock(**test_block_C_data) self.test_block_D = model.DBlock(**test_block_D_data) + self.cpp = cons.cons_portfolio_problem + def test_init(self): r_block_tree = model.RBlock( blocks=[ @@ -89,3 +96,13 @@ def test_init(self): r_block_tree.get_shocks() self.assertEqual(len(r_block_tree.get_shocks()), 3) + + def test_discretize(self): + cppd = self.cpp.discretize({"theta": {"N": 5}, "risky_return": {"N": 6}}) + + self.assertEqual(len(cppd.get_shocks()["theta"].pmv), 5) + self.assertEqual(len(cppd.get_shocks()["risky_return"].pmv), 6) + + self.assertFalse( + isinstance(self.cpp.get_shocks()["theta"], DiscreteDistribution) + )