From 61fd2f73df611c3b8d704f5ad387840a8f849c3c Mon Sep 17 00:00:00 2001 From: sb Date: Thu, 14 Nov 2024 10:42:13 -0500 Subject: [PATCH] vbi algo handling cases where stage has 0 controls --- HARK/algos/tests/test_vbi.py | 24 +++++++++++ HARK/algos/vbi.py | 83 +++++++++++++++++++----------------- HARK/model.py | 1 + 3 files changed, 68 insertions(+), 40 deletions(-) diff --git a/HARK/algos/tests/test_vbi.py b/HARK/algos/tests/test_vbi.py index 7bf77e9bb..ff15a76be 100644 --- a/HARK/algos/tests/test_vbi.py +++ b/HARK/algos/tests/test_vbi.py @@ -21,6 +21,21 @@ } ) +block_2 = DBlock( # has no control variable + **{ + "name": "vbi_test_1", + "shocks": { + "coin": Bernoulli(p=0.5), + }, + "dynamics": { + "m": lambda y, coin: y + coin, + "a": lambda m: m - 1, + }, + "reward": {"u": lambda m: 0}, + } +) + + class test_vbi(unittest.TestCase): # def setUp(self): @@ -33,6 +48,15 @@ def test_solve_block_1(self): self.assertAlmostEqual(dr["c"](**{"m": 1}), 0.5) + def test_solve_block_2(self): + # no control variable case. + state_grid = {"m": np.linspace(0, 2, 10)} + + dr, dec_vf, arr_vf = vbi.solve(block_2, lambda a: a, state_grid) + + # arrival value function gives the correct expected value of continuation + self.assertAlmostEqual(arr_vf({"y": 10}), 9.5) + def test_solve_consumption_problem(self): state_grid = {"m": np.linspace(0, 5, 10)} diff --git a/HARK/algos/vbi.py b/HARK/algos/vbi.py index a0a215b06..9d06712d3 100644 --- a/HARK/algos/vbi.py +++ b/HARK/algos/vbi.py @@ -114,50 +114,53 @@ def negated_value(a): # old! (should be negative) # negative, for minimization later return -srv_function(pre_states, dr) - ## get lower bound. - ## assumes only one control currently - lower_bound = -1e-6 ## a really low number! 
- feq = block.dynamics[controls[0]].lower_bound - if feq is not None: - lower_bound = feq(*[pre_states[var] for var in signature(feq).parameters]) - - ## get upper bound - ## assumes only one control currently - upper_bound = 1e-12 # a very high number - feq = block.dynamics[controls[0]].upper_bound - if feq is not None: - upper_bound = feq(*[pre_states[var] for var in signature(feq).parameters]) - - # pseudo - # optimize_action(pre_states, srv_function) - - bounds = ((lower_bound, upper_bound),) - - res = minimize( # choice of - negated_value, - 1, # x0 is starting guess, here arbitrary. - bounds=bounds, - ) - - dr_best = {c: get_action_rule(res.x[i]) for i, c in enumerate(controls)} - - if res.success: - policy_data.sel(**state_vals).variable.data.put( - 0, res.x[0] - ) # will only work for scalar actions - value_data.sel(**state_vals).variable.data.put( - 0, srv_function(pre_states, dr_best) + if len(controls) == 0: + # if no controls, no optimization is necessary + pass + elif len(controls) == 1: + ## get lower bound. + ## assumes only one control currently + lower_bound = -1e-6 ## a really low number! + feq = block.dynamics[controls[0]].lower_bound + if feq is not None: + lower_bound = feq(*[pre_states[var] for var in signature(feq).parameters]) + + ## get upper bound + ## assumes only one control currently + upper_bound = 1e-12 # a very high number + feq = block.dynamics[controls[0]].upper_bound + if feq is not None: + upper_bound = feq(*[pre_states[var] for var in signature(feq).parameters]) + + bounds = ((lower_bound, upper_bound),) + + res = minimize( # choice of + negated_value, + 1, # x0 is starting guess, here arbitrary. + bounds=bounds, ) - else: - print(f"Optimization failure at {state_vals}.") - print(res) dr_best = {c: get_action_rule(res.x[i]) for i, c in enumerate(controls)} - policy_data.sel(**state_vals).variable.data.put(0, res.x[0]) # ? 
- value_data.sel(**state_vals).variable.data.put( - 0, srv_function(pre_states, dr_best) - ) + if res.success: + policy_data.sel(**state_vals).variable.data.put( + 0, res.x[0] + ) # will only work for scalar actions + value_data.sel(**state_vals).variable.data.put( + 0, srv_function(pre_states, dr_best) + ) + else: + print(f"Optimization failure at {state_vals}.") + print(res) + + dr_best = {c: get_action_rule(res.x[i]) for i, c in enumerate(controls)} + + policy_data.sel(**state_vals).variable.data.put(0, res.x[0]) # ? + value_data.sel(**state_vals).variable.data.put( + 0, srv_function(pre_states, dr_best) + ) + elif len(controls) > 1: + raise Exception(f"Value backup iteration is not yet implemented for stages with {len(controls)} > 1 control variables.") # use the xarray interpolator to create a decision rule. dr_from_data = { diff --git a/HARK/model.py b/HARK/model.py index 191546801..c281f42f0 100644 --- a/HARK/model.py +++ b/HARK/model.py @@ -343,6 +343,7 @@ def get_state_rule_value_function_from_continuation( def state_rule_value_function(pre, dr): vals = self.transition(pre, dr, screen=screen) r = list(self.calc_reward(vals).values())[0] # a hack; to be improved + # this assumes a single reward variable; instead, a name could be passed in. cv = continuation( *[vals[var] for var in signature(continuation).parameters] )