Skip to content

Commit

Permalink
remove duplicate Params
Browse files Browse the repository at this point in the history
  • Loading branch information
alanlujan91 committed Nov 7, 2024
1 parent 37e8bce commit b0bc3ed
Showing 1 changed file with 0 additions and 344 deletions.
344 changes: 0 additions & 344 deletions HARK/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,350 +650,6 @@ def describe_constructors(self, *args):
from typing import Any, Dict, Iterator, List, Set, Tuple, Union


class Parameters:
"""
A smart container for model parameters that handles age-varying dynamics.
This class stores parameters as an internal dictionary and manages their
age-varying properties. It provides both attribute-style and dictionary-style
access to parameters.
Attributes:
_length (int): The terminal age of the agents in the model.
_invariant_params (Set[str]): A set of parameter names that are invariant over time.
_varying_params (Set[str]): A set of parameter names that vary over time.
_parameters (Dict[str, Any]): The internal dictionary storing all parameters.
"""

__slots__ = ("_length", "_invariant_params", "_varying_params", "_parameters")

def __init__(self, **parameters: Any) -> None:
"""
Initialize a Parameters object and parse the age-varying dynamics of parameters.
Args:
**parameters (Any): Keyword arguments representing parameter names and values.
"""
self._length: int = parameters.pop("T_cycle", 1)
self._invariant_params: Set[str] = set()
self._varying_params: Set[str] = set()
self._parameters: Dict[str, Any] = {"T_cycle": self._length}

for key, value in parameters.items():
self[key] = value

def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
"""
Access parameters by age index or parameter name.
Args:
item_or_key (Union[int, str]): Age index or parameter name.
Returns:
Union[Parameters, Any]: A new Parameters object for the specified age,
or the value of the specified parameter.
Raises:
ValueError: If the age index is out of bounds.
KeyError: If the parameter name is not found.
TypeError: If the key is neither an integer nor a string.
"""
if isinstance(item_or_key, int):
if item_or_key >= self._length:
raise ValueError(
f"Age {item_or_key} is out of bounds (max: {self._length - 1})."
)

params = {key: self._parameters[key] for key in self._invariant_params}
params.update(
{
key: self._parameters[key][item_or_key]
if isinstance(self._parameters[key], (list, tuple, np.ndarray))
else self._parameters[key]
for key in self._varying_params
}
)
return Parameters(**params)
elif isinstance(item_or_key, str):
return self._parameters[item_or_key]
else:
raise TypeError("Key must be an integer (age) or string (parameter name).")

def __setitem__(self, key: str, value: Any) -> None:
"""
Set parameter values, automatically inferring time variance.
Args:
key (str): Name of the parameter.
value (Any): Value of the parameter.
Raises:
ValueError: If the parameter name is not a string or if the value type is unsupported.
ValueError: If the parameter value is inconsistent with the current model length.
"""
if not isinstance(key, str):
raise ValueError(f"Parameter name must be a string, got {type(key)}")

if isinstance(
value, (int, float, np.ndarray, type(None), Distribution, bool, Callable)
):
self._invariant_params.add(key)
self._varying_params.discard(key)
elif isinstance(value, (list, tuple)):
if len(value) == 1:
value = value[0]
self._invariant_params.add(key)
self._varying_params.discard(key)
elif self._length is None or self._length == 1:
self._length = len(value)
self._varying_params.add(key)
self._invariant_params.discard(key)
elif len(value) == self._length:
self._varying_params.add(key)
self._invariant_params.discard(key)
else:
raise ValueError(
f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
)
else:
raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

self._parameters[key] = value

def __getattr__(self, name: str) -> Any:
"""
Allow attribute-style access to parameters.
Args:
name (str): Name of the parameter to access.
Returns:
Any: The value of the specified parameter.
Raises:
AttributeError: If the parameter name is not found.
"""
if name.startswith("_"):
return super().__getattribute__(name)
try:
return self._parameters[name]
except KeyError:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)

def __setattr__(self, name: str, value: Any) -> None:
"""
Allow attribute-style setting of parameters.
Args:
name (str): Name of the parameter to set.
value (Any): Value to set for the parameter.
"""
if name.startswith("_"):
super().__setattr__(name, value)
else:
self[name] = value

def __contains__(self, key: str) -> bool:
"""
Check if a parameter exists.
Args:
key (str): The name of the parameter.
Returns:
bool: True if the parameter exists, False otherwise.
"""
return key in self._parameters

def __iter__(self) -> Iterator[str]:
"""
Iterate over parameter names.
Returns:
Iterator[str]: An iterator over parameter names.
"""
return iter(self._parameters)

def __len__(self) -> int:
"""
Get the number of parameters.
Returns:
int: The number of parameters.
"""
return len(self._parameters)

def __repr__(self) -> str:
"""
Get a string representation of the Parameters object.
Returns:
str: A string representation of the Parameters object.
"""
return f"Parameters(_length={self._length}, _invariant_params={self._invariant_params}, _varying_params={self._varying_params}, _parameters={self._parameters})"

def __str__(self) -> str:
"""
Get a string representation of the Parameters object.
Returns:
str: A string representation of the Parameters object.
"""
return self.__repr__()

def keys(self) -> Set[str]:
"""
Get the names of all parameters.
Returns:
Set[str]: The names of all parameters.
"""
return set(self._parameters.keys())

def values(self) -> List[Any]:
"""
Get the values of all parameters.
Returns:
List[Any]: The values of all parameters.
"""
return list(self._parameters.values())

def items(self) -> List[Tuple[str, Any]]:
"""
Get the names and values of all parameters.
Returns:
List[Tuple[str, Any]]: The names and values of all parameters.
"""
return list(self._parameters.items())

def to_dict(self) -> Dict[str, Any]:
"""
Convert parameters to a plain dictionary.
Returns:
Dict[str, Any]: A dictionary containing all parameters.
"""
return dict(self._parameters)

def to_namedtuple(self) -> namedtuple:
"""
Convert parameters to a namedtuple.
Returns:
namedtuple: A namedtuple containing all parameters.
"""
return namedtuple("Parameters", self.keys())(**self.to_dict())

def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
"""
Update parameters from another Parameters object or dictionary.
Args:
other (Union[Parameters, Dict[str, Any]]): The source of parameters to update from.
Raises:
TypeError: If the input is neither a Parameters object nor a dictionary.
"""
if isinstance(other, Parameters):
for key, value in other._parameters.items():
self[key] = value
elif isinstance(other, dict):
for key, value in other.items():
self[key] = value
else:
raise TypeError(f"Expected Parameters or dict, got {type(other)}")

def copy(self) -> "Parameters":
"""
Create a deep copy of the Parameters object.
Returns:
Parameters: A new Parameters object with the same contents.
"""
return deepcopy(self)

def add_to_time_vary(self, *params: str) -> None:
"""
Adds any number of parameters to the time-varying set.
Args:
*params (str): Any number of strings naming parameters to be added to time_vary.
"""
for param in params:
if param in self._parameters:
self._varying_params.add(param)

def add_to_time_inv(self, *params: str) -> None:
"""
Adds any number of parameters to the time-invariant set.
Args:
*params (str): Any number of strings naming parameters to be added to time_inv.
"""
for param in params:
if param in self._parameters:
self._invariant_params.add(param)

def del_from_time_vary(self, *params: str) -> None:
"""
Removes any number of parameters from the time-varying set.
Args:
*params (str): Any number of strings naming parameters to be removed from time_vary.
"""
for param in params:
self._varying_params.discard(param)

def del_from_time_inv(self, *params: str) -> None:
"""
Removes any number of parameters from the time-invariant set.
Args:
*params (str): Any number of strings naming parameters to be removed from time_inv.
"""
for param in params:
self._invariant_params.discard(param)

def get(self, key: str, default: Any = None) -> Any:
"""
Get a parameter value, returning a default if not found.
Args:
key (str): The parameter name.
default (Any, optional): The default value to return if the key is not found.
Returns:
Any: The parameter value or the default.
"""
return self._parameters.get(key, default)

def set_many(self, **kwargs: Any) -> None:
"""
Set multiple parameters at once.
Args:
**kwargs: Keyword arguments representing parameter names and values.
"""
for key, value in kwargs.items():
self[key] = value

def is_time_varying(self, key: str) -> bool:
"""
Check if a parameter is time-varying.
Args:
key (str): The parameter name.
Returns:
bool: True if the parameter is time-varying, False otherwise.
"""
return key in self._varying_params


class AgentType(Model):
"""
A superclass for economic agents in the HARK framework. Each model should
Expand Down

0 comments on commit b0bc3ed

Please sign in to comment.