diff --git a/docs/cycle/Basic Introduction to Functions and States.ipynb b/docs/cycle/Basic Introduction to Functions and States.ipynb new file mode 100644 index 00000000..1d2f5c28 --- /dev/null +++ b/docs/cycle/Basic Introduction to Functions and States.ipynb @@ -0,0 +1,564 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basic Introduction to Functions and States" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using the functions and objects in `autora.state`, we can build flexible pipelines and cycles which operate on state\n", + "objects.\n", + "\n", + "## Theoretical Overview\n", + "\n", + "The fundamental idea is this:\n", + "- We define a \"state\" object $S$ which can be modified with a \"delta\" (a new result) $\\Delta S$.\n", + "- A new state at some point $i+1$ is $$S_{i+1} = S_i + \\Delta S_{i+1}$$\n", + "- The cycle state after $n$ steps is thus $$S_n = S_{0} + \\sum^{n}_{i=1} \\Delta S_{i}$$\n", + "\n", + "To represent $S$ and $\\Delta S$ in code, you can use `autora.state.delta.State` and `autora.state.delta.Delta`\n", + "respectively. To operate on these, we define functions.\n", + "\n", + "- Each operation in an AER cycle (theorist, experimentalist, experiment_runner, etc.) is implemented as a\n", + "function with $n$ arguments $s_j$ which are members of $S$ and $m$ others $a_k$ which are not.\n", + " $$ f(s_0, ..., s_n, a_0, ..., a_m) \\rightarrow \\Delta S_{i+1}$$\n", + "- There is a wrapper function $h$ (`autora.state.delta.wrap_to_use_state`) which changes the signature of $f$ to\n", + "require $S$ and aggregates the resulting $\\Delta S_{i+1}$\n", + " $$h\\left[f(s_0, ..., s_n, a_0, ..., a_m) \\rightarrow \\Delta\n", + "S_{i+1}\\right] \\rightarrow \\left[ f^\\prime(S_i, a_0, ..., a_m) \\rightarrow S_{i} + \\Delta\n", + "S_{i+1} = S_{i+1}\\right]$$\n", + "\n", + "- Assuming that the other arguments $a_k$ are provided by partial evaluation of the $f^\\prime$, the full AER cycle can\n", + "then be represented as:\n", + " $$S_n = f_n^\\prime(...f_2^\\prime(f_1^\\prime(S_0)))$$\n", + "\n", + "There are additional helper functions to wrap common experimentalists, experiment runners and theorists so that we\n", + "can define a full AER cycle using python notation as shown in the following example." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example\n", + "\n", + "First initialize the State. There are two variables `x` with a range [-10, 10] and `y` with an unspecified range." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autora.state.bundled import BasicAERState\n", + "from autora.variable import VariableCollection, Variable\n", + "\n", + "s_0 = BasicAERState(\n", + " variables=VariableCollection(\n", + " independent_variables=[Variable(\"x\", value_range=(-10, 10))],\n", + " dependent_variables=[Variable(\"y\")]\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify the experimentalist. Use a standard function `random_pool_executor`.\n", + "This gets 5 independent random samples (by default, configurable using an argument)\n", + "from the value_range of the independent variables, and returns them in a DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autora.experimentalist.pooler.random_pooler import random_pool_executor\n", + "experimentalist = random_pool_executor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify the experiment runner. This calculates a linear function, adds noise, assigns the value to the `y` column\n", + " in a new DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from autora.state.delta import Delta, wrap_to_use_state\n", + "\n", + "rng = np.random.default_rng(180)\n", + "\n", + "@wrap_to_use_state\n", + "def experiment_runner(conditions: pd.DataFrame, c=[2, 4]):\n", + " x = conditions[\"x\"]\n", + " noise = rng.normal(0, 1, len(x))\n", + " y = c[0] + (c[1] * x) + noise\n", + " experiment_data = conditions.assign(y = y)\n", + " return Delta(experiment_data=experiment_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify a theorist, using a standard LinearRegression from scikit-learn." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import LinearRegression\n", + "from autora.state.wrapper import theorist_from_estimator\n", + "\n", + "theorist = theorist_from_estimator(LinearRegression(fit_intercept=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define the cycle: run the experimentalist, experiment_runner and theorist ten times." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s_ = s_0\n", + "for i in range(10):\n", + " s_ = experimentalist(s_)\n", + " s_ = experiment_runner(s_)\n", + " s_ = theorist(s_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The experiment_data has 50 entries (10 cycles and 5 samples per cycle):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xy
0-4.451978-15.373958
10.3234872.561481
2-2.867211-10.516852
3-2.030568-5.247614
42.91379712.957584
5-7.340735-27.820030
6-6.019243-21.600574
7-8.893466-31.496807
86.61305627.020377
94.82541721.875249
10-9.992198-36.097453
11-1.097681-3.538933
126.57204529.078863
13-3.039432-9.749266
146.31386628.311789
152.80455512.014208
16-7.008751-27.139038
173.28621314.225707
18-8.826214-30.646008
19-9.652346-37.233317
20-0.370936-0.088444
21-6.641559-25.624469
22-7.938631-29.646345
231.2774327.965713
242.68448014.171408
25-0.450963-0.932371
26-4.497923-13.955542
27-8.923897-31.592700
28-9.873687-37.661495
295.83115526.193081
302.98574214.107186
31-0.3990531.001974
325.99589326.435367
332.13167011.344637
34-1.639935-4.308918
35-2.326959-5.789104
36-1.035607-3.114820
37-8.758742-31.689823
380.3667474.527129
391.9267329.679125
403.57705215.611630
41-9.588634-37.731120
42-7.100105-27.600941
432.46901511.837649
44-1.727297-5.464983
454.89455121.937380
46-3.799161-12.654000
472.70706211.246337
48-2.013533-7.202246
49-5.757174-22.951716
\n", + "
" + ], + "text/plain": [ + " x y\n", + "0 -4.451978 -15.373958\n", + "1 0.323487 2.561481\n", + "2 -2.867211 -10.516852\n", + "3 -2.030568 -5.247614\n", + "4 2.913797 12.957584\n", + "5 -7.340735 -27.820030\n", + "6 -6.019243 -21.600574\n", + "7 -8.893466 -31.496807\n", + "8 6.613056 27.020377\n", + "9 4.825417 21.875249\n", + "10 -9.992198 -36.097453\n", + "11 -1.097681 -3.538933\n", + "12 6.572045 29.078863\n", + "13 -3.039432 -9.749266\n", + "14 6.313866 28.311789\n", + "15 2.804555 12.014208\n", + "16 -7.008751 -27.139038\n", + "17 3.286213 14.225707\n", + "18 -8.826214 -30.646008\n", + "19 -9.652346 -37.233317\n", + "20 -0.370936 -0.088444\n", + "21 -6.641559 -25.624469\n", + "22 -7.938631 -29.646345\n", + "23 1.277432 7.965713\n", + "24 2.684480 14.171408\n", + "25 -0.450963 -0.932371\n", + "26 -4.497923 -13.955542\n", + "27 -8.923897 -31.592700\n", + "28 -9.873687 -37.661495\n", + "29 5.831155 26.193081\n", + "30 2.985742 14.107186\n", + "31 -0.399053 1.001974\n", + "32 5.995893 26.435367\n", + "33 2.131670 11.344637\n", + "34 -1.639935 -4.308918\n", + "35 -2.326959 -5.789104\n", + "36 -1.035607 -3.114820\n", + "37 -8.758742 -31.689823\n", + "38 0.366747 4.527129\n", + "39 1.926732 9.679125\n", + "40 3.577052 15.611630\n", + "41 -9.588634 -37.731120\n", + "42 -7.100105 -27.600941\n", + "43 2.469015 11.837649\n", + "44 -1.727297 -5.464983\n", + "45 4.894551 21.937380\n", + "46 -3.799161 -12.654000\n", + "47 2.707062 11.246337\n", + "48 -2.013533 -7.202246\n", + "49 -5.757174 -22.951716" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "s_.experiment_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The fitted coefficients are close to the original intercept = 2, gradient = 4" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2.03390614] [[3.97374104]]\n" + ] + } + ], + "source": [ + "print(s_.model.intercept_, s_.model.coef_)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/mkdocs.yml b/mkdocs.yml index dd9f7238..daf8aeec 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -21,4 +21,5 @@ nav: - Home: 'experimentalists/sampler/random/index.md' - Quickstart: 'experimentalists/sampler/random/quickstart.md' - Cycle: + - Home: 'cycle/Basic Introduction to Functions and States.ipynb' - Functional: 'cycle/Linear and Cyclical Workflows using Functions and States.ipynb' diff --git a/src/autora/state/bundled.py b/src/autora/state/bundled.py new file mode 100644 index 00000000..51845b74 --- /dev/null +++ b/src/autora/state/bundled.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass, field +from typing import Optional + +import pandas as pd +from sklearn.base import BaseEstimator + +from autora.state.delta import State +from autora.variable import VariableCollection + + +@dataclass(frozen=True) +class BasicAERState(State): + variables: VariableCollection = field(metadata={"delta": "replace"}) + conditions: pd.Series = field( + default_factory=pd.Series, metadata={"delta": "replace"} + ) + experiment_data: pd.DataFrame = field( + default_factory=pd.DataFrame, metadata={"delta": "extend"} + ) + model: Optional[BaseEstimator] = field(default=None, metadata={"delta": "replace"})