Skip to content

Commit

Permalink
vbi algorithm: handle the case where a stage has zero control variables
Browse files Browse the repository at this point in the history
  • Loading branch information
sbenthall committed Nov 14, 2024
1 parent 1e22c86 commit 61fd2f7
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 40 deletions.
24 changes: 24 additions & 0 deletions HARK/algos/tests/test_vbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@
}
)

# A stage with shocks and dynamics but *no* control variables, used to
# exercise the zero-control path of the VBI solver.
block_2 = DBlock(
    name="vbi_test_1",
    shocks={
        # Fair coin flip: adds 0 or 1 to the arrival state with equal probability.
        "coin": Bernoulli(p=0.5),
    },
    dynamics={
        "m": lambda y, coin: y + coin,
        "a": lambda m: m - 1,
    },
    # Reward is identically zero; all value comes from the continuation.
    reward={"u": lambda m: 0},
)



class test_vbi(unittest.TestCase):
# def setUp(self):
Expand All @@ -33,6 +48,15 @@ def test_solve_block_1(self):

self.assertAlmostEqual(dr["c"](**{"m": 1}), 0.5)

def test_solve_block_2(self):
# no control variable case.
state_grid = {"m": np.linspace(0, 2, 10)}

dr, dec_vf, arr_vf = vbi.solve(block_2, lambda a: a, state_grid)

# arrival value function gives the correct expect value of continuation
self.assertAlmostEqual(arr_vf({"y": 10}), 9.5)

def test_solve_consumption_problem(self):
state_grid = {"m": np.linspace(0, 5, 10)}

Expand Down
83 changes: 43 additions & 40 deletions HARK/algos/vbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,50 +114,53 @@ def negated_value(a): # old! (should be negative)
# negative, for minimization later
return -srv_function(pre_states, dr)

## get lower bound.
## assumes only one control currently
lower_bound = -1e-6 ## a really low number!
feq = block.dynamics[controls[0]].lower_bound
if feq is not None:
lower_bound = feq(*[pre_states[var] for var in signature(feq).parameters])

## get upper bound
## assumes only one control currently
upper_bound = 1e-12 # a very high number
feq = block.dynamics[controls[0]].upper_bound
if feq is not None:
upper_bound = feq(*[pre_states[var] for var in signature(feq).parameters])

# pseudo
# optimize_action(pre_states, srv_function)

bounds = ((lower_bound, upper_bound),)

res = minimize( # choice of
negated_value,
1, # x0 is starting guess, here arbitrary.
bounds=bounds,
)

dr_best = {c: get_action_rule(res.x[i]) for i, c in enumerate(controls)}

if res.success:
policy_data.sel(**state_vals).variable.data.put(
0, res.x[0]
) # will only work for scalar actions
value_data.sel(**state_vals).variable.data.put(
0, srv_function(pre_states, dr_best)
if len(controls) == 0:
# if no controls, no optimization is necessary
pass
elif len(controls) == 1:
## get lower bound.
## assumes only one control currently
lower_bound = -1e-6 ## a really low number!
feq = block.dynamics[controls[0]].lower_bound
if feq is not None:
lower_bound = feq(*[pre_states[var] for var in signature(feq).parameters])

## get upper bound
## assumes only one control currently
upper_bound = 1e-12 # a very high number
feq = block.dynamics[controls[0]].upper_bound
if feq is not None:
upper_bound = feq(*[pre_states[var] for var in signature(feq).parameters])

bounds = ((lower_bound, upper_bound),)

res = minimize( # choice of
negated_value,
1, # x0 is starting guess, here arbitrary.
bounds=bounds,
)
else:
print(f"Optimization failure at {state_vals}.")
print(res)

dr_best = {c: get_action_rule(res.x[i]) for i, c in enumerate(controls)}

policy_data.sel(**state_vals).variable.data.put(0, res.x[0]) # ?
value_data.sel(**state_vals).variable.data.put(
0, srv_function(pre_states, dr_best)
)
if res.success:
policy_data.sel(**state_vals).variable.data.put(
0, res.x[0]
) # will only work for scalar actions
value_data.sel(**state_vals).variable.data.put(
0, srv_function(pre_states, dr_best)
)
else:
print(f"Optimization failure at {state_vals}.")
print(res)

dr_best = {c: get_action_rule(res.x[i]) for i, c in enumerate(controls)}

policy_data.sel(**state_vals).variable.data.put(0, res.x[0]) # ?
value_data.sel(**state_vals).variable.data.put(
0, srv_function(pre_states, dr_best)
)
elif len(controls) > 1:
raise Exception(f"Value backup iteration is not yet implemented for stages with {len(controls)} > 1 control variables.")

# use the xarray interpolator to create a decision rule.
dr_from_data = {
Expand Down
1 change: 1 addition & 0 deletions HARK/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def get_state_rule_value_function_from_continuation(
def state_rule_value_function(pre, dr):
vals = self.transition(pre, dr, screen=screen)
r = list(self.calc_reward(vals).values())[0] # a hack; to be improved
# this assumes a single reward variable; instead, a named could be passed in.
cv = continuation(
*[vals[var] for var in signature(continuation).parameters]
)
Expand Down

0 comments on commit 61fd2f7

Please sign in to comment.