Skip to content

Commit

Permalink
docs: make example work with state logic (also make pre-commit run)
Browse files Browse the repository at this point in the history
  • Loading branch information
younesStrittmatter committed Jul 21, 2024
1 parent ab7d890 commit 577445a
Show file tree
Hide file tree
Showing 3 changed files with 418 additions and 132 deletions.
365 changes: 309 additions & 56 deletions docs/Example.ipynb

Large diffs are not rendered by default.

107 changes: 67 additions & 40 deletions src/autora/experiment_runner/synthetic/abstract/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
>>> formula_2 = 'rt ~ x1'
>>> fixed_effects_2 = {'x1': 2.}
>>> experiment_2 = lmm_experiment(formula=formula_2,fixed_effects=fixed_effects_2)
>>> experiment_1.ground_truth(conditions=conditions) == experiment_2.ground_truth(conditions=conditions)
>>> experiment_1.ground_truth(conditions=conditions) ==\
experiment_2.ground_truth(conditions=conditions)
x1 rt
0 True True
1 True True
Expand All @@ -48,7 +49,9 @@
>>> formula = 'rt ~ 1 + (1|subject) + x1'
>>> fixed_effects = {'Intercept': 1, 'x1': 2}
>>> random_effects = {'subject': {'Intercept': .1}}
>>> experiment = lmm_experiment(formula=formula,fixed_effects=fixed_effects,random_effects=random_effects)
>>> experiment = lmm_experiment(formula=formula,
... fixed_effects=fixed_effects,
... random_effects=random_effects)
>>> conditions_1 = pd.DataFrame({
... 'x1':np.linspace(0, 1, 3),
... 'subject': np.repeat(1, 3)
Expand Down Expand Up @@ -90,7 +93,9 @@
>>> formula = 'rt ~ (x1|subject) + x1'
>>> fixed_effects = {'x1': 1.}
>>> random_effects = {'subject': {'x1': .01}}
>>> experiment = lmm_experiment(formula=formula,fixed_effects=fixed_effects,random_effects=random_effects)
>>> experiment = lmm_experiment(formula=formula,
... fixed_effects=fixed_effects,
... random_effects=random_effects)
>>> experiment.ground_truth(conditions=conditions,random_state=42)
x1 subject rt
0 0.0 1 0.000000
Expand All @@ -106,7 +111,9 @@
... 'subject': {'1': 0.5, 'x1': 0.3},
... 'group': {'x2': 0.4}
... }
>>> experiment = lmm_experiment(formula=formula, fixed_effects=fixed_effects,random_effects=random_effects)
>>> experiment = lmm_experiment(formula=formula,
... fixed_effects=fixed_effects,
... random_effects=random_effects)
>>> n_samples = 10
>>> rng = np.random.default_rng(0)
>>> conditions = pd.DataFrame({
Expand Down Expand Up @@ -143,10 +150,9 @@
"""


from functools import partial
from typing import Optional, List
import re
from functools import partial
from typing import List, Optional

import numpy as np
import pandas as pd
Expand All @@ -173,7 +179,7 @@ def lmm_experiment(
fixed_effects: dictionary describing the fixed effects (Intercept and slopes)
random_effects: nested dictionary describing the random effects of slopes and intercept.
These are standard deviasions in a normal distribution with a mean of zero.
X: Independent variable descriptions. Used to add allowed values
X: Independent variable descriptions. Used to add allowed values
"""

if not fixed_effects:
Expand All @@ -186,17 +192,17 @@ def lmm_experiment(
name=name,
formula=formula,
fixed_effects=fixed_effects,
random_effects=random_effects
random_effects=random_effects,
)

dependent, fixed_variables, random_variables = _extract_variable_names(formula)

dependent = DV(name=dependent)
x = [IV(name=f) for f in fixed_variables] + [IV(name=r) for r in random_variables]
# x = [IV(name=f) for f in fixed_variables] + [IV(name=r) for r in random_variables]
#
# if X:
# x = X

if X:
x = X

variables = VariableCollection(
independent_variables=[X],
dependent_variables=[dependent],
Expand All @@ -216,28 +222,32 @@ def run(
rng_ = np.random.default_rng(random_state)
else:
rng_ = rng # use the RNG from the outer scope


dependent_var, rhs = formula.split('~')

dependent_var, rhs = formula.split("~")
dependent_var = dependent_var.strip()
fixed_vars = fixed_variables


# Check for the presence of an intercept in the formula
has_intercept = True if '1' in fixed_effects or re.search(r'\b0\b', rhs) is None else False
has_intercept = (
True if "1" in fixed_effects or re.search(r"\b0\b", rhs) is None else False
)

experiment_data = conditions.copy()

# Initialize the dependent variable
experiment_data[dependent_var] = fixed_effects.get('Intercept', 0) if has_intercept else 0
experiment_data[dependent_var] = (
fixed_effects.get("Intercept", 0) if has_intercept else 0
)

# Add fixed effects
for var in fixed_vars:
if var in experiment_data.columns:
experiment_data[dependent_var] += fixed_effects.get(var, 0) * experiment_data[var]
experiment_data[dependent_var] += (
fixed_effects.get(var, 0) * experiment_data[var]
)

# Process each random effect term
random_effect_terms = re.findall(r'\((.+?)\|(.+?)\)', formula)
random_effect_terms = re.findall(r"\((.+?)\|(.+?)\)", formula)
for term in random_effect_terms:
random_effects_, group_var = term
group_var = group_var.strip()
Expand All @@ -247,25 +257,36 @@ def run(
raise ValueError(f"Group variable '{group_var}' not found in the data")

# Process each part of the random effect (intercept and slopes)
for part in random_effects_.split('+'):
part = 'Intercept' if part == '1' else part
for part in random_effects_.split("+"):
part = "Intercept" if part == "1" else part
part = part.strip()
std_dev = random_effects[group_var].get(part, 0.5)
random_effect_values = {group: rng_.normal(0, std_dev) for group in experiment_data[group_var].unique()}
if part == 'Intercept': # Random intercept
random_effect_values = {
group: rng_.normal(0, std_dev)
for group in experiment_data[group_var].unique()
}
if part == "Intercept": # Random intercept
if has_intercept:
experiment_data[dependent_var] += experiment_data[group_var].map(random_effect_values)
experiment_data[dependent_var] += experiment_data[
group_var
].map(random_effect_values)
else: # Random slopes
if part in experiment_data.columns:
experiment_data[dependent_var] += experiment_data[group_var].map(random_effect_values) * experiment_data[part]
experiment_data[dependent_var] += (
experiment_data[group_var].map(random_effect_values)
* experiment_data[part]
)

# Add noise
experiment_data[dependent_var] += rng_.normal(0, added_noise, len(experiment_data))
experiment_data[dependent_var] += rng_.normal(
0, added_noise, len(experiment_data)
)

return experiment_data

ground_truth = partial(run, added_noise=0.0)
"""A function which simulates perfect observations. This still uses random values for random effects."""
"""A function which simulates perfect observations.
This still uses random values for random effects."""

def domain():
"""A function which returns all possible independent variable values as a 2D array."""
Expand All @@ -279,9 +300,9 @@ def plotter(model=None):
plt.figure()
dom = domain()
data = ground_truth(dom)
y = data[depedent]
x = data.drop(depenent, axis=1)

y = data[dependent]
x = data.drop(dependent, axis=1)

if x.shape[1] > 2:
Exception(
Expand Down Expand Up @@ -332,7 +353,8 @@ def _extract_variable_names(formula):
formula (str): Formula specifying the model, e.g., 'y ~ x1 + x2 + (1 + x1|group) + (x2|subject)'
Returns:
tuple of (list, list): A tuple containing two lists - one for fixed effects and another for random effects.
tuple of (list, list): A tuple containing two lists - one for fixed effects and another for
random effects.
Examples:
>>> formula_1 = 'y ~ x1 + x2 + (1 + x1|group) + (x2|subject)'
>>> _extract_variable_names(formula_1)
Expand All @@ -349,19 +371,24 @@ def _extract_variable_names(formula):
"""
# Extract the right-hand side of the formula
dependent, rhs = formula.split('~')
dependent, rhs = formula.split("~")
dependent = dependent.strip()

fixed_effects = re.findall(r'[a-z]\w*(?![^\(]*\))', rhs) # Matches variables outside parentheses
random_effects = re.findall(r'\(([^\|]+)\|([^\)]+)\)', rhs) # Matches random effects groups
fixed_effects = re.findall(
r"[a-z]\w*(?![^\(]*\))", rhs
) # Matches variables outside parentheses
random_effects = re.findall(
r"\(([^\|]+)\|([^\)]+)\)", rhs
) # Matches random effects groups

# Include variables from random effects in fixed effects and make unique
for reffect in random_effects:
fixed_effects.extend(reffect[0].replace('1 + ', '').split('+'))
fixed_effects.extend(reffect[0].replace("1 + ", "").split("+"))

# Removing duplicates and stripping whitespaces
fixed_effects = sorted(list(set([effect.strip() for effect in fixed_effects])))
random_groups = sorted(list(set([reffect[1].strip() for reffect in random_effects])))

random_groups = sorted(
list(set([reffect[1].strip() for reffect in random_effects]))
)

return dependent, fixed_effects, random_groups
78 changes: 42 additions & 36 deletions src/autora/experiment_runner/synthetic/psychology/q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
from autora.experiment_runner.synthetic.utilities import SyntheticExperimentCollection
from autora.variable import DV, IV, ValueType, VariableCollection


def _check_in_0_1_range(x, name):
if not (0 <= x <= 1):
raise ValueError(
f'Value of {name} must be in [0, 1] range. Found value of {x}.')
if not (0 <= x <= 1):
raise ValueError(
f"Value of {name} must be in [0, 1] range. Found value of {x}."
)


class AgentQ:
"""An agent that runs simple Q-learning for an n-armed bandits tasks.
Expand All @@ -22,13 +25,13 @@ class AgentQ:
"""

def __init__(
self,
alpha: float = 0.2,
beta: float = 3.,
n_actions: int = 2,
forget_rate: float = 0.,
perseverance_bias: float = 0.,
correlated_reward: bool = False,
self,
alpha: float = 0.2,
beta: float = 3.0,
n_actions: int = 2,
forget_rate: float = 0.0,
perseverance_bias: float = 0.0,
correlated_reward: bool = False,
):
"""Update the agent after one step of the task.
Expand All @@ -49,8 +52,8 @@ def __init__(
self._q_init = 0.5
self.new_sess()

_check_in_0_1_range(alpha, 'alpha')
_check_in_0_1_range(forget_rate, 'forget_rate')
_check_in_0_1_range(alpha, "alpha")
_check_in_0_1_range(forget_rate, "forget_rate")

def new_sess(self):
"""Reset the agent for the beginning of a new session."""
Expand All @@ -69,9 +72,7 @@ def get_choice(self) -> int:
choice = np.random.choice(self._n_actions, p=choice_probs)
return choice

def update(self,
choice: int,
reward: float):
def update(self, choice: int, reward: float):
"""Update the agent after one step of the task.
Args:
Expand All @@ -82,15 +83,17 @@ def update(self,
# Forgetting - restore q-values of non-chosen actions towards the initial value
non_chosen_action = np.arange(self._n_actions) != choice
self._q[non_chosen_action] = (1 - self._forget_rate) * self._q[
non_chosen_action] + self._forget_rate * self._q_init
non_chosen_action
] + self._forget_rate * self._q_init

# Reward-based update - Update chosen q for chosen action with observed reward
q_reward_update = - self._alpha * self._q[choice] + self._alpha * reward
q_reward_update = -self._alpha * self._q[choice] + self._alpha * reward

# Correlated update - Update non-chosen q for non-chosen action with observed reward
if self._correlated_reward:
# index_correlated_update = self._n_actions - choice - 1
# self._q[index_correlated_update] = (1 - self._alpha) * self._q[index_correlated_update] + self._alpha * (1 - reward)
# self._q[index_correlated_update] =
# (1 - self._alpha) * self._q[index_correlated_update] + self._alpha * (1 - reward)
# alternative implementation - not dependent on reward but on reward-based update
index_correlated_update = self._n_actions - 1 - choice
self._q[index_correlated_update] -= 0.5 * q_reward_update
Expand All @@ -111,10 +114,10 @@ def q(self):
def q_learning(
name="Q-Learning",
learning_rate: float = 0.2,
decision_noise: float = 3.,
decision_noise: float = 3.0,
n_actions: int = 2,
forget_rate: float = 0.,
perseverance_bias: float = 0.,
forget_rate: float = 0.0,
perseverance_bias: float = 0.0,
correlated_reward: bool = False,
):
"""
Expand All @@ -136,7 +139,8 @@ def q_learning(
# The runner can accept numpy arrays or pandas DataFrames, but the return value will
# always be a list of numpy arrays. Each array corresponds to the choices made by the agent
# for each trial in the input. Thus, arrays have shape (n_trials, n_actions).
>>> experiment.run(np.array([[0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0]]), random_state=42)
>>> experiment.run(np.array([[0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0]]),
... random_state=42)
[array([[1., 0.],
[0., 1.],
[0., 1.],
Expand All @@ -147,7 +151,10 @@ def q_learning(
# The runner can accept pandas DataFrames. Each cell of the DataFrame should contain a
# numpy array with shape (n_trials, n_actions). The return value will be a list of numpy
# arrays, each corresponding to the choices made by the agent for each trial in the input.
>>> experiment.run(pd.DataFrame({'reward array': [np.array([[0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0]])]}), random_state = 42)
>>> experiment.run(
... pd.DataFrame(
... {'reward array': [np.array([[0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0]])]}),
... random_state = 42)
[array([[1., 0.],
[0., 1.],
[0., 1.],
Expand All @@ -159,12 +166,12 @@ def q_learning(
params = dict(
name=name,
trials=100,
learning_rate = learning_rate,
decision_noise = decision_noise,
n_actions = n_actions,
forget_rate = forget_rate,
perseverance_bias = perseverance_bias,
correlated_reward = correlated_reward,
learning_rate=learning_rate,
decision_noise=decision_noise,
n_actions=n_actions,
forget_rate=forget_rate,
perseverance_bias=perseverance_bias,
correlated_reward=correlated_reward,
)

iv1 = IV(
Expand All @@ -187,9 +194,11 @@ def q_learning(
)

def run_AgentQ(rewards):
if (rewards.shape[1] != n_actions):
Warning("Number of actions in rewards does not match n_actions. Will use " + str(rewards.shape[1]
+ " actions."))
if rewards.shape[1] != n_actions:
Warning(
"Number of actions in rewards does not match n_actions. Will use "
+ str(rewards.shape[1] + " actions.")
)
num_trials = rewards.shape[0]

y = np.zeros(rewards.shape)
Expand All @@ -216,9 +225,8 @@ def run_AgentQ(rewards):
def run(
conditions: Union[pd.DataFrame, np.ndarray, np.recarray],
random_state: Optional[int] = None,
return_choice_probabilities = False,
return_choice_probabilities=False,
):

if random_state is not None:
np.random.seed(random_state)

Expand Down Expand Up @@ -256,5 +264,3 @@ def domain():
factory_function=q_learning,
)
return collection


0 comments on commit 577445a

Please sign in to comment.