diff --git a/docs/cycle/Basic Introduction to Functions and States.ipynb b/docs/cycle/Basic Introduction to Functions and States.ipynb new file mode 100644 index 00000000..1d2f5c28 --- /dev/null +++ b/docs/cycle/Basic Introduction to Functions and States.ipynb @@ -0,0 +1,564 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basic Introduction to Functions and States" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using the functions and objects in `autora.state`, we can build flexible pipelines and cycles which operate on state\n", + "objects.\n", + "\n", + "## Theoretical Overview\n", + "\n", + "The fundamental idea is this:\n", + "- We define a \"state\" object $S$ which can be modified with a \"delta\" (a new result) $\\Delta S$.\n", + "- A new state at some point $i+1$ is $$S_{i+1} = S_i + \\Delta S_{i+1}$$\n", + "- The cycle state after $n$ steps is thus $$S_n = S_{0} + \\sum^{n}_{i=1} \\Delta S_{i}$$\n", + "\n", + "To represent $S$ and $\\Delta S$ in code, you can use `autora.state.delta.State` and `autora.state.delta.Delta`\n", + "respectively. To operate on these, we define functions.\n", + "\n", + "- Each operation in an AER cycle (theorist, experimentalist, experiment_runner, etc.) is implemented as a\n", + "function with $n$ arguments $s_j$ which are members of $S$ and $m$ others $a_k$ which are not.\n", + " $$ f(s_0, ..., s_n, a_0, ..., a_m) \\rightarrow \\Delta S_{i+1}$$\n", + "- There is a wrapper function $h$ (`autora.state.delta.wrap_to_use_state`) which changes the signature of $f$ to\n", + "require $S$ and aggregates the resulting $\\Delta S_{i+1}$\n", + " $$h\\left[f(s_0, ..., s_n, a_0, ..., a_m) \\rightarrow \\Delta\n", + "S_{i+1}\\right] \\rightarrow \\left[ f^\\prime(S_i, a_0, ..., a_m) \\rightarrow S_{i} + \\Delta\n", + "S_{i+1} = S_{i+1}\\right]$$\n", + "\n", + "- Assuming that the other arguments $a_k$ are provided by partial evaluation of the $f^\\prime$, the full AER cycle can\n", + "then be represented as:\n", + " $$S_n = f_n^\\prime(...f_2^\\prime(f_1^\\prime(S_0)))$$\n", + "\n", + "There are additional helper functions to wrap common experimentalists, experiment runners and theorists so that we\n", + "can define a full AER cycle using python notation as shown in the following example." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example\n", + "\n", + "First initialize the State. There are two variables `x` with a range [-10, 10] and `y` with an unspecified range." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autora.state.bundled import BasicAERState\n", + "from autora.variable import VariableCollection, Variable\n", + "\n", + "s_0 = BasicAERState(\n", + " variables=VariableCollection(\n", + " independent_variables=[Variable(\"x\", value_range=(-10, 10))],\n", + " dependent_variables=[Variable(\"y\")]\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify the experimentalist. Use a standard function `random_pool_executor`.\n", + "This gets 5 independent random samples (by default, configurable using an argument)\n", + "from the value_range of the independent variables, and returns them in a DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autora.experimentalist.pooler.random_pooler import random_pool_executor\n", + "experimentalist = random_pool_executor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify the experiment runner. This calculates a linear function, adds noise, assigns the value to the `y` column\n", + " in a new DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from autora.state.delta import Delta, wrap_to_use_state\n", + "\n", + "rng = np.random.default_rng(180)\n", + "\n", + "@wrap_to_use_state\n", + "def experiment_runner(conditions: pd.DataFrame, c=[2, 4]):\n", + " x = conditions[\"x\"]\n", + " noise = rng.normal(0, 1, len(x))\n", + " y = c[0] + (c[1] * x) + noise\n", + " experiment_data = conditions.assign(y = y)\n", + " return Delta(experiment_data=experiment_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify a theorist, using a standard LinearRegression from scikit-learn." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import LinearRegression\n", + "from autora.state.wrapper import theorist_from_estimator\n", + "\n", + "theorist = theorist_from_estimator(LinearRegression(fit_intercept=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define the cycle: run the experimentalist, experiment_runner and theorist ten times." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s_ = s_0\n", + "for i in range(10):\n", + " s_ = experimentalist(s_)\n", + " s_ = experiment_runner(s_)\n", + " s_ = theorist(s_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The experiment_data has 50 entries (10 cycles and 5 samples per cycle):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " | x | \n", + "y | \n", + "
---|---|---|
0 | \n", + "-4.451978 | \n", + "-15.373958 | \n", + "
1 | \n", + "0.323487 | \n", + "2.561481 | \n", + "
2 | \n", + "-2.867211 | \n", + "-10.516852 | \n", + "
3 | \n", + "-2.030568 | \n", + "-5.247614 | \n", + "
4 | \n", + "2.913797 | \n", + "12.957584 | \n", + "
5 | \n", + "-7.340735 | \n", + "-27.820030 | \n", + "
6 | \n", + "-6.019243 | \n", + "-21.600574 | \n", + "
7 | \n", + "-8.893466 | \n", + "-31.496807 | \n", + "
8 | \n", + "6.613056 | \n", + "27.020377 | \n", + "
9 | \n", + "4.825417 | \n", + "21.875249 | \n", + "
10 | \n", + "-9.992198 | \n", + "-36.097453 | \n", + "
11 | \n", + "-1.097681 | \n", + "-3.538933 | \n", + "
12 | \n", + "6.572045 | \n", + "29.078863 | \n", + "
13 | \n", + "-3.039432 | \n", + "-9.749266 | \n", + "
14 | \n", + "6.313866 | \n", + "28.311789 | \n", + "
15 | \n", + "2.804555 | \n", + "12.014208 | \n", + "
16 | \n", + "-7.008751 | \n", + "-27.139038 | \n", + "
17 | \n", + "3.286213 | \n", + "14.225707 | \n", + "
18 | \n", + "-8.826214 | \n", + "-30.646008 | \n", + "
19 | \n", + "-9.652346 | \n", + "-37.233317 | \n", + "
20 | \n", + "-0.370936 | \n", + "-0.088444 | \n", + "
21 | \n", + "-6.641559 | \n", + "-25.624469 | \n", + "
22 | \n", + "-7.938631 | \n", + "-29.646345 | \n", + "
23 | \n", + "1.277432 | \n", + "7.965713 | \n", + "
24 | \n", + "2.684480 | \n", + "14.171408 | \n", + "
25 | \n", + "-0.450963 | \n", + "-0.932371 | \n", + "
26 | \n", + "-4.497923 | \n", + "-13.955542 | \n", + "
27 | \n", + "-8.923897 | \n", + "-31.592700 | \n", + "
28 | \n", + "-9.873687 | \n", + "-37.661495 | \n", + "
29 | \n", + "5.831155 | \n", + "26.193081 | \n", + "
30 | \n", + "2.985742 | \n", + "14.107186 | \n", + "
31 | \n", + "-0.399053 | \n", + "1.001974 | \n", + "
32 | \n", + "5.995893 | \n", + "26.435367 | \n", + "
33 | \n", + "2.131670 | \n", + "11.344637 | \n", + "
34 | \n", + "-1.639935 | \n", + "-4.308918 | \n", + "
35 | \n", + "-2.326959 | \n", + "-5.789104 | \n", + "
36 | \n", + "-1.035607 | \n", + "-3.114820 | \n", + "
37 | \n", + "-8.758742 | \n", + "-31.689823 | \n", + "
38 | \n", + "0.366747 | \n", + "4.527129 | \n", + "
39 | \n", + "1.926732 | \n", + "9.679125 | \n", + "
40 | \n", + "3.577052 | \n", + "15.611630 | \n", + "
41 | \n", + "-9.588634 | \n", + "-37.731120 | \n", + "
42 | \n", + "-7.100105 | \n", + "-27.600941 | \n", + "
43 | \n", + "2.469015 | \n", + "11.837649 | \n", + "
44 | \n", + "-1.727297 | \n", + "-5.464983 | \n", + "
45 | \n", + "4.894551 | \n", + "21.937380 | \n", + "
46 | \n", + "-3.799161 | \n", + "-12.654000 | \n", + "
47 | \n", + "2.707062 | \n", + "11.246337 | \n", + "
48 | \n", + "-2.013533 | \n", + "-7.202246 | \n", + "
49 | \n", + "-5.757174 | \n", + "-22.951716 | \n", + "