Skip to content

Commit

Permalink
Merge pull request #1240 from alanlujan91/parameters
Browse files Browse the repository at this point in the history
Parameters
  • Loading branch information
sbenthall authored Sep 25, 2023
2 parents dce9901 + 8198530 commit da30b24
Show file tree
Hide file tree
Showing 13 changed files with 771 additions and 159 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ jobs:
pytest --cov=./ --cov-report=xml
- name: upload coverage report
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
13 changes: 6 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,28 @@ exclude: Documentation/example_notebooks/

repos:
- repo: https://github.com/mwouts/jupytext
rev: v1.14.5
rev: v1.15.0
hooks:
- id: jupytext
args:
[--sync, --set-formats, "ipynb,py:percent", --pipe, black, --execute]
args: [--sync, --set-formats, "ipynb", --pipe, black, --execute]
additional_dependencies: [jupytext, black, nbconvert]
files: ^examples/.*\.ipynb$

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black
exclude: ^examples/

- repo: https://github.com/asottile/pyupgrade
rev: v3.4.0
rev: v3.10.1
hooks:
- id: pyupgrade
args: ["--py38-plus"]
exclude: ^examples/

- repo: https://github.com/asottile/blacken-docs
rev: 1.13.0
rev: 1.15.0
hooks:
- id: blacken-docs
exclude: ^examples/
Expand All @@ -38,7 +37,7 @@ repos:
exclude: ^examples/

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0-alpha.9-for-vscode
rev: v3.0.1
hooks:
- id: prettier
exclude: ^examples/
Expand Down
1 change: 1 addition & 0 deletions Documentation/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Release Date: TBD
### Major Changes

- Adds `HARK.core.AgentPopulation` class to represent a population of agents with ex-ante heterogeneous parametrizations as distributions. [#1237](https://github.com/econ-ark/HARK/pull/1237)
- Adds `HARK.core.Parameters` class to represent a collection of time varying and time invariant parameters in a model. [#1240](https://github.com/econ-ark/HARK/pull/1240)

### Minor Changes

Expand Down
2 changes: 1 addition & 1 deletion HARK/ConsumptionSaving/ConsGenIncProcessModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def solve(self):
# long run permanent income growth doesn't work yet
init_explicit_perm_inc["PermGroFac"] = [1.0]
init_explicit_perm_inc["aXtraMax"] = 30
init_explicit_perm_inc["aXtraExtra"] = [0.005, 0.01]
init_explicit_perm_inc["aXtraExtra"] = np.array([0.005, 0.01])


class GenIncProcessConsumerType(IndShockConsumerType):
Expand Down
6 changes: 2 additions & 4 deletions HARK/ConsumptionSaving/ConsLaborModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,10 +787,8 @@ def plot_LbrFunc(self, t, bMin=None, bMax=None, ShkSet=None):
1.0,
1.0,
] # Wage rate in a lifecycle
init_labor_lifecycle["LbrCostCoeffs"] = [
-2.0,
0.4,
] # Assume labor cost coeffs is a polynomial of degree 1
# Assume labor cost coeffs is a polynomial of degree 1
init_labor_lifecycle["LbrCostCoeffs"] = np.array([-2.0, 0.4])
init_labor_lifecycle["T_cycle"] = 10
# init_labor_lifecycle['T_retire'] = 7 # IndexError at line 774 in interpolation.py.
init_labor_lifecycle[
Expand Down
6 changes: 2 additions & 4 deletions HARK/ConsumptionSaving/tests/test_ConsGenIncProcessModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@
# Parameters for constructing the "assets above minimum" grid
"aXtraMin": 0.001, # Minimum end-of-period "assets above minimum" value
"aXtraMax": 30, # Maximum end-of-period "assets above minimum" value
"aXtraExtra": [
0.005,
0.01,
], # Some other value of "assets above minimum" to add to the grid
# Some other value of "assets above minimum" to add to the grid
"aXtraExtra": np.array([0.005, 0.01]),
"aXtraNestFac": 3, # Exponential nesting factor when constructing "assets above minimum" grid
"aXtraCount": 48, # Number of points in the grid of "assets above minimum"
# Parameters describing the income process
Expand Down
5 changes: 3 additions & 2 deletions HARK/ConsumptionSaving/tests/test_IndShockConsumerType.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,13 +762,13 @@ def setUp(self):
"LivPrb": LivPrb,
"PermGroFac": [PermGroFac],
"Rfree": Rfree,
"track_vars": ["bNrm", "t_age"],
}
)

def test_NewbornStatesAndShocks(self):
# Make agent, shock and initial condition histories
agent = IndShockConsumerType(**self.base_params)
agent.track_vars = ["bNrm", "t_age"]
agent.make_shock_history()

# Find indices of agents and time periods that correspond to deaths
Expand Down Expand Up @@ -814,13 +814,13 @@ def setUp(self):
{
"AgentCount": agent_count,
"T_sim": t_sim,
"track_vars": ["t_age", "t_cycle"],
}
)

def test_compare_t_age_t_cycle(self):
# Make agent, shock and initial condition histories
agent = IndShockConsumerType(**self.base_params)
agent.track_vars = ["t_age", "t_cycle"]
agent.make_shock_history()

# Solve and simulate the agent
Expand Down Expand Up @@ -855,6 +855,7 @@ def test_compare_t_age_t_cycle_premature_death(self):
par["T_age"] = par["T_age"] - 8
# Make agent, shock and initial condition histories
agent = IndShockConsumerType(**par)
agent.track_vars = ["t_age", "t_cycle"]
agent.make_shock_history()

# Solve and simulate the agent
Expand Down
215 changes: 207 additions & 8 deletions HARK/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
problem by finding a general equilibrium dynamic rule.
"""
import sys
from collections import defaultdict, namedtuple
from copy import copy, deepcopy
from dataclasses import dataclass, field
from time import time
Expand Down Expand Up @@ -57,11 +58,216 @@ def set_verbosity_level(level):
_log.setLevel(level)


class Parameters:
"""
This class defines an object that stores all of the parameters for a model
as an internal dictionary. It is designed to also handle the age-varying
dynamics of parameters.
Attributes
----------
_length : int
The terminal age of the agents in the model.
_invariant_params : list
A list of the names of the parameters that are invariant over time.
_varying_params : list
A list of the names of the parameters that vary over time.
"""

def __init__(self, **parameters: Any):
"""
Initializes a Parameters object and parses the age-varying
dynamics of the parameters.
Parameters
----------
parameters : keyword arguments
Any number of keyword arguments of the form key=value.
To parse a dictionary of parameters, use the ** operator.
"""
params = parameters.copy()
self._length = params.pop("T_cycle", None)
self._invariant_params = set()
self._varying_params = set()
self._parameters: Dict[str, Union[int, float, np.ndarray, list, tuple]] = {}

for key, value in params.items():
self._parameters[key] = self.__infer_dims__(key, value)

def __infer_dims__(
self, key: str, value: Union[int, float, np.ndarray, list, tuple, None]
) -> Union[int, float, np.ndarray, list, tuple]:
"""
Infers the age-varying dimensions of a parameter.
If the parameter is a scalar, numpy array, or None, it is assumed to be
invariant over time. If the parameter is a list or tuple, it is assumed
to be varying over time. If the parameter is a list or tuple of length
greater than 1, the length of the list or tuple must match the
`_term_age` attribute of the Parameters object.
Parameters
----------
key : str
name of parameter
value : Any
value of parameter
"""
if isinstance(value, (int, float, np.ndarray, type(None))):
self.__add_to_invariant__(key)
return value
if isinstance(value, (list, tuple)):
if len(value) == 1:
self.__add_to_invariant__(key)
return value[0]
if self._length is None or self._length == 1:
self._length = len(value)
if len(value) == self._length:
self.__add_to_varying__(key)
return value
raise ValueError(
f"Parameter {key} must be of length 1 or {self._length}, not {len(value)}"
)
raise ValueError(f"Parameter {key} has unsupported type {type(value)}")

def __add_to_invariant__(self, key: str):
"""
Adds parameter name to invariant set and removes from varying set.
"""
self._varying_params.discard(key)
self._invariant_params.add(key)

def __add_to_varying__(self, key: str):
"""
Adds parameter name to varying set and removes from invariant set.
"""
self._invariant_params.discard(key)
self._varying_params.add(key)

def __getitem__(self, item_or_key: Union[int, str]):
"""
If item_or_key is an integer, returns a Parameters object with the parameters
that apply to that age. This includes all invariant parameters and the
`item_or_key`th element of all age-varying parameters. If item_or_key is a string,
it returns the value of the parameter with that name.
"""
if isinstance(item_or_key, int):
if item_or_key >= self._length:
raise ValueError(
f"Age {item_or_key} is greater than or equal to terminal age {self._length}."
)

params = {key: self._parameters[key] for key in self._invariant_params}
params.update(
{
key: self._parameters[key][item_or_key]
for key in self._varying_params
}
)
return Parameters(**params)
elif isinstance(item_or_key, str):
return self._parameters[item_or_key]

def __setitem__(self, key: str, value: Any):
"""
Sets the value of a parameter.
Parameters
----------
key : str
name of parameter
value : Any
value of parameter
"""
if not isinstance(key, str):
raise ValueError("Parameters must be set with a string key")
self._parameters[key] = self.__infer_dims__(key, value)

def keys(self):
"""
Returns a list of the names of the parameters.
"""
return self._invariant_params | self._varying_params

def values(self):
"""
Returns a list of the values of the parameters.
"""
return list(self._parameters.values())

def items(self):
"""
Returns a list of tuples of the form (name, value) for each parameter.
"""
return list(self._parameters.items())

def __iter__(self):
"""
Allows for iterating over the parameter names.
"""
return iter(self.keys())

def __deepcopy__(self, memo):
"""
Returns a deep copy of the Parameters object.
"""
return Parameters(**deepcopy(self.to_dict(), memo))

def to_dict(self):
"""
Returns a dictionary of the parameters.
"""
return {key: self._parameters[key] for key in self.keys()}

def to_namedtuple(self):
"""
Returns a namedtuple of the parameters.
"""
return namedtuple("Parameters", self.keys())(**self.to_dict())

def update(self, other_params):
"""
Updates the parameters with the values from another
Parameters object or a dictionary.
Parameters
----------
other_params : Parameters or dict
Parameters object or dictionary of parameters to update with.
"""
if isinstance(other_params, Parameters):
self._parameters.update(other_params.to_dict())
elif isinstance(other_params, dict):
self._parameters.update(other_params)
else:
raise ValueError("Parameters must be a dict or a Parameters object")

def __str__(self):
"""
Returns a simple string representation of the Parameters object.
"""
return f"Parameters({str(self.to_dict())})"

def __repr__(self):
"""
Returns a detailed string representation of the Parameters object.
"""
return f"Parameters( _age_inv = {self._invariant_params}, _age_var = {self._varying_params}, | {self.to_dict()})"


class Model:
"""
A class with special handling of parameters assignment.
"""

def __init__(self):
if not hasattr(self, "parameters"):
self.parameters = {}

def assign_parameters(self, **kwds):
"""
Assign an arbitrary number of attributes to this agent.
Expand Down Expand Up @@ -102,10 +308,6 @@ def __eq__(self, other):

return NotImplemented

def __init__(self):
if not hasattr(self, "parameters"):
self.parameters = {}

def __str__(self):
type_ = type(self)
module = type_.__module__
Expand Down Expand Up @@ -1468,17 +1670,14 @@ def distribute_params(agent, param_name, param_count, distribution):
return agent_set


Parameters = NewType("ParameterDict", dict)


@dataclass
class AgentPopulation:
"""
A class for representing a population of ex-ante heterogeneous agents.
"""

agent_type: AgentType # type of agent in the population
parameters: Parameters # dictionary of parameters
parameters: dict # dictionary of parameters
seed: int = 0 # random seed
time_var: List[str] = field(init=False)
time_inv: List[str] = field(init=False)
Expand Down
Loading

0 comments on commit da30b24

Please sign in to comment.