diff --git a/docs/Example.ipynb b/docs/Example.ipynb index a99e74a5..0368f777 100644 --- a/docs/Example.ipynb +++ b/docs/Example.ipynb @@ -65,7 +65,7 @@ " Examples:\n", " >>> experiment = weber_fechner_law()\n", " \n", - " # We can run the runner with numpy arrays or DataFrames. Ther return value will\n", + " # The runner can accept numpy arrays or pandas DataFrames, but the return value will\n", " # always be a pandas DataFrame.\n", " >>> experiment.run(np.array([[.1,.2]]), random_state=42)\n", " S1 S2 difference_detected\n", @@ -110,7 +110,7 @@ " Examples:\n", " >>> experiment = weber_fechner_law()\n", "\n", - " # We can run the runner with numpy arrays or DataFrames. Ther return value will\n", + " # The runner can accept numpy arrays or pandas DataFrames, but the return value will\n", " # always be a pandas DataFrame.\n", " >>> experiment.run(np.array([[.1,.2]]), random_state=42)\n", " S1 S2 difference_detected\n", @@ -260,7 +260,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "... the experiment_runner runner which can be called to generate experimental results:" + "... 
the experiment_runner which can be called to generate experimental results:" ] }, { @@ -299,31 +299,31 @@ " 0\n", " 0.010000\n", " 0.010000\n", - " -0.000829\n", + " -0.013734\n", " \n", " \n", " 1\n", " 0.010000\n", " 0.060404\n", - " 1.806354\n", + " 1.788438\n", " \n", " \n", " 2\n", " 0.010000\n", " 0.110808\n", - " 2.406270\n", + " 2.403593\n", " \n", " \n", " 3\n", " 0.010000\n", " 0.161212\n", - " 2.774411\n", + " 2.781474\n", " \n", " \n", " 4\n", " 0.010000\n", " 0.211616\n", - " 3.056933\n", + " 3.055966\n", " \n", " \n", " ...\n", @@ -335,31 +335,31 @@ " 5045\n", " 4.899192\n", " 4.949596\n", - " -0.000753\n", + " 0.025322\n", " \n", " \n", " 5046\n", " 4.899192\n", " 5.000000\n", - " 0.037958\n", + " 0.022726\n", " \n", " \n", " 5047\n", " 4.949596\n", " 4.949596\n", - " -0.013647\n", + " 0.000098\n", " \n", " \n", " 5048\n", " 4.949596\n", " 5.000000\n", - " 0.020839\n", + " 0.002932\n", " \n", " \n", " 5049\n", " 5.000000\n", " 5.000000\n", - " -0.021462\n", + " -0.010160\n", " \n", " \n", "\n", @@ -368,17 +368,17 @@ ], "text/plain": [ " S1 S2 difference_detected\n", - "0 0.010000 0.010000 -0.000829\n", - "1 0.010000 0.060404 1.806354\n", - "2 0.010000 0.110808 2.406270\n", - "3 0.010000 0.161212 2.774411\n", - "4 0.010000 0.211616 3.056933\n", + "0 0.010000 0.010000 -0.013734\n", + "1 0.010000 0.060404 1.788438\n", + "2 0.010000 0.110808 2.403593\n", + "3 0.010000 0.161212 2.781474\n", + "4 0.010000 0.211616 3.055966\n", "... ... ... 
...\n", - "5045 4.899192 4.949596 -0.000753\n", - "5046 4.899192 5.000000 0.037958\n", - "5047 4.949596 4.949596 -0.013647\n", - "5048 4.949596 5.000000 0.020839\n", - "5049 5.000000 5.000000 -0.021462\n", + "5045 4.899192 4.949596 0.025322\n", + "5046 4.899192 5.000000 0.022726\n", + "5047 4.949596 4.949596 0.000098\n", + "5048 4.949596 5.000000 0.002932\n", + "5049 5.000000 5.000000 -0.010160\n", "\n", "[5050 rows x 3 columns]" ] @@ -435,7 +435,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -454,7 +454,34 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "These can be used to run a full experimental cycle" + "We can wrap these functions to use them with the state logic of AutoRA.\n", + "First, we create the state with the variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autora.state import StandardState, on_state, experiment_runner_on_state, estimator_on_state\n", + "from autora.experimentalist.grid import grid_pool\n", + "from autora.experimentalist.random import random_sample\n", + "from functools import partial\n", + "import random\n", + "\n", + "# We can get the variables from the runner\n", + "variables = s.variables\n", + "\n", + "# With the variables, we initialize a StandardState\n", + "state = StandardState(variables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Wrap the experimentalists with the `on_state` function to use them on the state:" ] }, { @@ -466,49 +493,275 @@ "name": "stdout", "output_type": "stream", "text": [ - "finished cycle 1\n", - "finished cycle 2\n", - "finished cycle 3\n", - "I = -0.49 S0 +0.58 S1 -0.21\n" + " S1 S2\n", + "1484 0.715657 4.243939\n", + "5456 2.731818 2.832626\n", + "5539 2.782222 1.975758\n", + "6062 3.034242 3.135051\n", + "1622 0.816465 1.118889\n", + "5901 2.983838 0.060404\n", + "1936 0.967677 1.824545\n", + "6844 3.437475 2.227778\n", + "7073 3.538283 3.689495\n", + "2196 1.068485 4.848788\n", + "1710 0.866869 0.514040\n", + "3251 1.622929 2.580606\n", + "4298 2.126970 4.949596\n", + "7055 3.538283 2.782222\n", + "406 0.211616 0.312424\n", + "3787 1.874949 4.395152\n", + "4728 2.378990 1.421313\n", + "5214 2.631010 0.715657\n", + "1227 0.614848 1.370909\n", + "8482 4.243939 4.143131\n" ] } ], "source": [ - "from autora.workflow.protocol import ResultKind\n", - "from autora.experimentalist.pipeline import make_pipeline\n", - "from autora.experimentalist.pooler.grid import grid_pool\n", - 
"from autora.experimentalist.sampler.random_sampler import random_sample\n", - "from functools import partial\n", - "import random\n", - "variables = s.variables\n", - "pool = partial(grid_pool, ivs=variables.independent_variables)\n", - "random.seed(181) # set the seed for the random sampler\n", - "sampler = partial(random_sample, n=20)\n", - "experimentalist_pipeline = make_pipeline([pool, sampler])\n", + "# Wrap the functions to use on state\n", + "# Experimentalists:\n", + "pool_on_state = on_state(grid_pool, output=['conditions'])\n", + "sample_on_state = on_state(random_sample, output=['conditions'])\n", "\n", - "from autora.workflow import Controller\n", - "theorist = LinearRegression()\n", - "\n", - "cycle = Controller(\n", - " variables=variables, experimentalist=experimentalist_pipeline,\n", - " experiment_runner=s.experiment_runner, theorist=theorist,\n", - " monitor=lambda s: (s.history[-1].kind == ResultKind.MODEL) and\n", - " print(f\"finished cycle {len(s.models)}\"))\n", + "state = pool_on_state(state)\n", + "state = sample_on_state(state, num_samples=20)\n", + "print(state.conditions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Wrap the runner with the `experiment_runner_on_state` wrapper to use it on state:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
S1S2difference_detected
14840.7156574.2439391.777697
54562.7318182.8326260.039988
55392.7822221.975758-0.344299
60623.0342423.1350510.036409
16220.8164651.1188890.287000
59012.9838380.060404-3.903989
19360.9676771.8245450.618795
68443.4374752.227778-0.432813
70733.5382833.6894950.035054
21961.0684854.8487881.514635
17100.8668690.514040-0.527993
32511.6229292.5806060.469571
42982.1269704.9495960.815710
70553.5382832.782222-0.237467
4060.2116160.3124240.373714
37871.8749494.3951520.830938
47282.3789901.421313-0.503627
52142.6310100.715657-1.289548
12270.6148481.3709090.803107
84824.2439394.143131-0.029045
\n", + "
" + ], + "text/plain": [ + " S1 S2 difference_detected\n", + "1484 0.715657 4.243939 1.777697\n", + "5456 2.731818 2.832626 0.039988\n", + "5539 2.782222 1.975758 -0.344299\n", + "6062 3.034242 3.135051 0.036409\n", + "1622 0.816465 1.118889 0.287000\n", + "5901 2.983838 0.060404 -3.903989\n", + "1936 0.967677 1.824545 0.618795\n", + "6844 3.437475 2.227778 -0.432813\n", + "7073 3.538283 3.689495 0.035054\n", + "2196 1.068485 4.848788 1.514635\n", + "1710 0.866869 0.514040 -0.527993\n", + "3251 1.622929 2.580606 0.469571\n", + "4298 2.126970 4.949596 0.815710\n", + "7055 3.538283 2.782222 -0.237467\n", + "406 0.211616 0.312424 0.373714\n", + "3787 1.874949 4.395152 0.830938\n", + "4728 2.378990 1.421313 -0.503627\n", + "5214 2.631010 0.715657 -1.289548\n", + "1227 0.614848 1.370909 0.803107\n", + "8482 4.243939 4.143131 -0.029045" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Runner:\n", + "run_on_state = experiment_runner_on_state(s.run)\n", + "state = run_on_state(state)\n", "\n", - "c = cycle.run(10)\n", - "best_model = c.state.models[-1]\n", - "print(f\"I = \"\n", - " f\"{best_model.coef_[0]:.2f} S0 \"\n", - " f\"{best_model.coef_[1]:+.2f} S1 \"\n", - " f\"{best_model.intercept_:+.2f}\")\n" + "state.experiment_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Wrap the regressor with the `estimator_on_state` wrapper:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I = -0.62*S0 +0.57*S1 -0.09 \n" + ] + } + ], + "source": [ + "theorist = LinearRegression()\n", + "theorist_on_state = estimator_on_state(theorist)\n", + "\n", + "state = theorist_on_state(state)\n", + "# Access the last model:\n", + "model = state.models[-1]\n", + "\n", + "\n", + "print(f\"I = \"\n", + " f\"{model.coef_[0][0]:.2f}*S0 \"\n", + " 
f\"{model.coef_[0][1]:+.2f}*S1 \"\n", + " f\"{model.intercept_[0]:+.2f} \")" + ] } ], "metadata": { diff --git a/src/autora/experiment_runner/synthetic/abstract/lmm.py b/src/autora/experiment_runner/synthetic/abstract/lmm.py index 1fc7c926..a2c869c6 100644 --- a/src/autora/experiment_runner/synthetic/abstract/lmm.py +++ b/src/autora/experiment_runner/synthetic/abstract/lmm.py @@ -37,7 +37,8 @@ >>> formula_2 = 'rt ~ x1' >>> fixed_effects_2 = {'x1': 2.} >>> experiment_2 = lmm_experiment(formula=formula_2,fixed_effects=fixed_effects_2) - >>> experiment_1.ground_truth(conditions=conditions) == experiment_2.ground_truth(conditions=conditions) + >>> experiment_1.ground_truth(conditions=conditions) ==\ +experiment_2.ground_truth(conditions=conditions) x1 rt 0 True True 1 True True @@ -48,7 +49,9 @@ >>> formula = 'rt ~ 1 + (1|subject) + x1' >>> fixed_effects = {'Intercept': 1, 'x1': 2} >>> random_effects = {'subject': {'Intercept': .1}} - >>> experiment = lmm_experiment(formula=formula,fixed_effects=fixed_effects,random_effects=random_effects) + >>> experiment = lmm_experiment(formula=formula, + ... fixed_effects=fixed_effects, + ... random_effects=random_effects) >>> conditions_1 = pd.DataFrame({ ... 'x1':np.linspace(0, 1, 3), ... 'subject': np.repeat(1, 3) @@ -90,7 +93,9 @@ >>> formula = 'rt ~ (x1|subject) + x1' >>> fixed_effects = {'x1': 1.} >>> random_effects = {'subject': {'x1': .01}} - >>> experiment = lmm_experiment(formula=formula,fixed_effects=fixed_effects,random_effects=random_effects) + >>> experiment = lmm_experiment(formula=formula, + ... fixed_effects=fixed_effects, + ... random_effects=random_effects) >>> experiment.ground_truth(conditions=conditions,random_state=42) x1 subject rt 0 0.0 1 0.000000 @@ -106,7 +111,9 @@ ... 'subject': {'1': 0.5, 'x1': 0.3}, ... 'group': {'x2': 0.4} ... } - >>> experiment = lmm_experiment(formula=formula, fixed_effects=fixed_effects,random_effects=random_effects) + >>> experiment = lmm_experiment(formula=formula, + ... 
fixed_effects=fixed_effects, + ... random_effects=random_effects) >>> n_samples = 10 >>> rng = np.random.default_rng(0) >>> conditions = pd.DataFrame({ @@ -143,10 +150,9 @@ """ - -from functools import partial -from typing import Optional, List import re +from functools import partial +from typing import List, Optional import numpy as np import pandas as pd @@ -173,7 +179,7 @@ def lmm_experiment( fixed_effects: dictionary describing the fixed effects (Intercept and slopes) random_effects: nested dictionary describing the random effects of slopes and intercept. These are standard deviasions in a normal distribution with a mean of zero. - X: Independent variable descriptions. Used to add allowed values + X: Independent variable descriptions. Used to add allowed values """ if not fixed_effects: @@ -186,17 +192,17 @@ def lmm_experiment( name=name, formula=formula, fixed_effects=fixed_effects, - random_effects=random_effects + random_effects=random_effects, ) dependent, fixed_variables, random_variables = _extract_variable_names(formula) dependent = DV(name=dependent) - x = [IV(name=f) for f in fixed_variables] + [IV(name=r) for r in random_variables] + # x = [IV(name=f) for f in fixed_variables] + [IV(name=r) for r in random_variables] + # + # if X: + # x = X - if X: - x = X - variables = VariableCollection( independent_variables=[X], dependent_variables=[dependent], @@ -216,28 +222,32 @@ def run( rng_ = np.random.default_rng(random_state) else: rng_ = rng # use the RNG from the outer scope - - - dependent_var, rhs = formula.split('~') + + dependent_var, rhs = formula.split("~") dependent_var = dependent_var.strip() fixed_vars = fixed_variables - # Check for the presence of an intercept in the formula - has_intercept = True if '1' in fixed_effects or re.search(r'\b0\b', rhs) is None else False + has_intercept = ( + True if "1" in fixed_effects or re.search(r"\b0\b", rhs) is None else False + ) experiment_data = conditions.copy() # Initialize the dependent variable - 
experiment_data[dependent_var] = fixed_effects.get('Intercept', 0) if has_intercept else 0 + experiment_data[dependent_var] = ( + fixed_effects.get("Intercept", 0) if has_intercept else 0 + ) # Add fixed effects for var in fixed_vars: if var in experiment_data.columns: - experiment_data[dependent_var] += fixed_effects.get(var, 0) * experiment_data[var] + experiment_data[dependent_var] += ( + fixed_effects.get(var, 0) * experiment_data[var] + ) # Process each random effect term - random_effect_terms = re.findall(r'\((.+?)\|(.+?)\)', formula) + random_effect_terms = re.findall(r"\((.+?)\|(.+?)\)", formula) for term in random_effect_terms: random_effects_, group_var = term group_var = group_var.strip() @@ -247,25 +257,36 @@ def run( raise ValueError(f"Group variable '{group_var}' not found in the data") # Process each part of the random effect (intercept and slopes) - for part in random_effects_.split('+'): - part = 'Intercept' if part == '1' else part + for part in random_effects_.split("+"): + part = "Intercept" if part == "1" else part part = part.strip() std_dev = random_effects[group_var].get(part, 0.5) - random_effect_values = {group: rng_.normal(0, std_dev) for group in experiment_data[group_var].unique()} - if part == 'Intercept': # Random intercept + random_effect_values = { + group: rng_.normal(0, std_dev) + for group in experiment_data[group_var].unique() + } + if part == "Intercept": # Random intercept if has_intercept: - experiment_data[dependent_var] += experiment_data[group_var].map(random_effect_values) + experiment_data[dependent_var] += experiment_data[ + group_var + ].map(random_effect_values) else: # Random slopes if part in experiment_data.columns: - experiment_data[dependent_var] += experiment_data[group_var].map(random_effect_values) * experiment_data[part] + experiment_data[dependent_var] += ( + experiment_data[group_var].map(random_effect_values) + * experiment_data[part] + ) # Add noise - experiment_data[dependent_var] += rng_.normal(0, 
added_noise, len(experiment_data)) + experiment_data[dependent_var] += rng_.normal( + 0, added_noise, len(experiment_data) + ) return experiment_data ground_truth = partial(run, added_noise=0.0) - """A function which simulates perfect observations. This still uses random values for random effects.""" + """A function which simulates perfect observations. + This still uses random values for random effects.""" def domain(): """A function which returns all possible independent variable values as a 2D array.""" @@ -279,9 +300,9 @@ def plotter(model=None): plt.figure() dom = domain() data = ground_truth(dom) - - y = data[depedent] - x = data.drop(depenent, axis=1) + + y = data[dependent] + x = data.drop(dependent, axis=1) if x.shape[1] > 2: Exception( @@ -332,7 +353,8 @@ def _extract_variable_names(formula): formula (str): Formula specifying the model, e.g., 'y ~ x1 + x2 + (1 + x1|group) + (x2|subject)' Returns: - tuple of (list, list): A tuple containing two lists - one for fixed effects and another for random effects. + tuple of (list, list): A tuple containing two lists - one for fixed effects and another for + random effects. 
Examples: >>> formula_1 = 'y ~ x1 + x2 + (1 + x1|group) + (x2|subject)' >>> _extract_variable_names(formula_1) @@ -349,19 +371,24 @@ def _extract_variable_names(formula): """ # Extract the right-hand side of the formula - dependent, rhs = formula.split('~') + dependent, rhs = formula.split("~") dependent = dependent.strip() - fixed_effects = re.findall(r'[a-z]\w*(?![^\(]*\))', rhs) # Matches variables outside parentheses - random_effects = re.findall(r'\(([^\|]+)\|([^\)]+)\)', rhs) # Matches random effects groups + fixed_effects = re.findall( + r"[a-z]\w*(?![^\(]*\))", rhs + ) # Matches variables outside parentheses + random_effects = re.findall( + r"\(([^\|]+)\|([^\)]+)\)", rhs + ) # Matches random effects groups # Include variables from random effects in fixed effects and make unique for reffect in random_effects: - fixed_effects.extend(reffect[0].replace('1 + ', '').split('+')) - + fixed_effects.extend(reffect[0].replace("1 + ", "").split("+")) + # Removing duplicates and stripping whitespaces fixed_effects = sorted(list(set([effect.strip() for effect in fixed_effects]))) - random_groups = sorted(list(set([reffect[1].strip() for reffect in random_effects]))) - + random_groups = sorted( + list(set([reffect[1].strip() for reffect in random_effects])) + ) return dependent, fixed_effects, random_groups diff --git a/src/autora/experiment_runner/synthetic/psychology/q_learning.py b/src/autora/experiment_runner/synthetic/psychology/q_learning.py index c07650f2..ded28d3f 100644 --- a/src/autora/experiment_runner/synthetic/psychology/q_learning.py +++ b/src/autora/experiment_runner/synthetic/psychology/q_learning.py @@ -7,10 +7,13 @@ from autora.experiment_runner.synthetic.utilities import SyntheticExperimentCollection from autora.variable import DV, IV, ValueType, VariableCollection + def _check_in_0_1_range(x, name): - if not (0 <= x <= 1): - raise ValueError( - f'Value of {name} must be in [0, 1] range. 
Found value of {x}.') + if not (0 <= x <= 1): + raise ValueError( + f"Value of {name} must be in [0, 1] range. Found value of {x}." + ) + class AgentQ: """An agent that runs simple Q-learning for an n-armed bandits tasks. @@ -22,13 +25,13 @@ class AgentQ: """ def __init__( - self, - alpha: float = 0.2, - beta: float = 3., - n_actions: int = 2, - forget_rate: float = 0., - perseverance_bias: float = 0., - correlated_reward: bool = False, + self, + alpha: float = 0.2, + beta: float = 3.0, + n_actions: int = 2, + forget_rate: float = 0.0, + perseverance_bias: float = 0.0, + correlated_reward: bool = False, ): """Update the agent after one step of the task. @@ -49,8 +52,8 @@ def __init__( self._q_init = 0.5 self.new_sess() - _check_in_0_1_range(alpha, 'alpha') - _check_in_0_1_range(forget_rate, 'forget_rate') + _check_in_0_1_range(alpha, "alpha") + _check_in_0_1_range(forget_rate, "forget_rate") def new_sess(self): """Reset the agent for the beginning of a new session.""" @@ -69,9 +72,7 @@ def get_choice(self) -> int: choice = np.random.choice(self._n_actions, p=choice_probs) return choice - def update(self, - choice: int, - reward: float): + def update(self, choice: int, reward: float): """Update the agent after one step of the task. 
Args: @@ -82,15 +83,17 @@ def update(self, # Forgetting - restore q-values of non-chosen actions towards the initial value non_chosen_action = np.arange(self._n_actions) != choice self._q[non_chosen_action] = (1 - self._forget_rate) * self._q[ - non_chosen_action] + self._forget_rate * self._q_init + non_chosen_action + ] + self._forget_rate * self._q_init # Reward-based update - Update chosen q for chosen action with observed reward - q_reward_update = - self._alpha * self._q[choice] + self._alpha * reward + q_reward_update = -self._alpha * self._q[choice] + self._alpha * reward # Correlated update - Update non-chosen q for non-chosen action with observed reward if self._correlated_reward: # index_correlated_update = self._n_actions - choice - 1 - # self._q[index_correlated_update] = (1 - self._alpha) * self._q[index_correlated_update] + self._alpha * (1 - reward) + # self._q[index_correlated_update] = + # (1 - self._alpha) * self._q[index_correlated_update] + self._alpha * (1 - reward) # alternative implementation - not dependent on reward but on reward-based update index_correlated_update = self._n_actions - 1 - choice self._q[index_correlated_update] -= 0.5 * q_reward_update @@ -111,10 +114,10 @@ def q(self): def q_learning( name="Q-Learning", learning_rate: float = 0.2, - decision_noise: float = 3., + decision_noise: float = 3.0, n_actions: int = 2, - forget_rate: float = 0., - perseverance_bias: float = 0., + forget_rate: float = 0.0, + perseverance_bias: float = 0.0, correlated_reward: bool = False, ): """ @@ -136,7 +139,8 @@ def q_learning( # The runner can accept numpy arrays or pandas DataFrames, but the return value will # always be a list of numpy arrays. Each array corresponds to the choices made by the agent # for each trial in the input. Thus, arrays have shape (n_trials, n_actions). 
- >>> experiment.run(np.array([[0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0]]), random_state=42) + >>> experiment.run(np.array([[0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0]]), + ... random_state=42) [array([[1., 0.], [0., 1.], [0., 1.], @@ -147,7 +151,10 @@ def q_learning( # The runner can accept pandas DataFrames. Each cell of the DataFrame should contain a # numpy array with shape (n_trials, n_actions). The return value will be a list of numpy # arrays, each corresponding to the choices made by the agent for each trial in the input. - >>> experiment.run(pd.DataFrame({'reward array': [np.array([[0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0]])]}), random_state = 42) + >>> experiment.run( + ... pd.DataFrame( + ... {'reward array': [np.array([[0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0]])]}), + ... random_state = 42) [array([[1., 0.], [0., 1.], [0., 1.], @@ -159,12 +166,12 @@ def q_learning( params = dict( name=name, trials=100, - learning_rate = learning_rate, - decision_noise = decision_noise, - n_actions = n_actions, - forget_rate = forget_rate, - perseverance_bias = perseverance_bias, - correlated_reward = correlated_reward, + learning_rate=learning_rate, + decision_noise=decision_noise, + n_actions=n_actions, + forget_rate=forget_rate, + perseverance_bias=perseverance_bias, + correlated_reward=correlated_reward, ) iv1 = IV( @@ -187,9 +194,11 @@ def q_learning( ) def run_AgentQ(rewards): - if (rewards.shape[1] != n_actions): - Warning("Number of actions in rewards does not match n_actions. Will use " + str(rewards.shape[1] - + " actions.")) + if rewards.shape[1] != n_actions: + Warning( + "Number of actions in rewards does not match n_actions. 
Will use " + + str(rewards.shape[1] + " actions.") + ) num_trials = rewards.shape[0] y = np.zeros(rewards.shape) @@ -216,9 +225,8 @@ def run_AgentQ(rewards): def run( conditions: Union[pd.DataFrame, np.ndarray, np.recarray], random_state: Optional[int] = None, - return_choice_probabilities = False, + return_choice_probabilities=False, ): - if random_state is not None: np.random.seed(random_state) @@ -256,5 +264,3 @@ def domain(): factory_function=q_learning, ) return collection - -