-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
play with Lie derivative observables and abstraction (#83)
- Loading branch information
1 parent
4a8a327
commit 827bdfe
Showing
1 changed file
with
343 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,343 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f4f243e3-da0b-47e8-85c4-a9c49f3ae4ee", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import sys\n", | ||
"sys.path.append(\"..\")\n", | ||
"\n", | ||
"import autokoopman.benchmark.fhn as pfhn\n", | ||
"import autokoopman as ak\n", | ||
"import sympy as sp\n", | ||
"\n", | ||
"from itertools import product\n", | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aa258ddd-cd57-4f9a-8843-cced8edede12", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def lie_derivative(f, F, variables):\n", | ||
" return sp.expand(sum(sp.diff(f, xi)*F[i] for i, xi in enumerate(variables)))\n", | ||
" \n", | ||
"def update_basis(F, basis, variables):\n", | ||
" # generate the candidate function\n", | ||
" basis_coef = sp.symbols(\" \".join([f\"c{i}\" for i in range(0, len(basis))]))\n", | ||
" f = sum(ci*bi for ci, bi in zip(basis_coef, basis))\n", | ||
"\n", | ||
" # lie derivative\n", | ||
" lf = lie_derivative(f, F, variables)\n", | ||
"\n", | ||
" # remove terms that do not belong to span(*F)\n", | ||
" subs = []\n", | ||
" for term in lf.args:\n", | ||
" in_span = False\n", | ||
" for bf in basis:\n", | ||
" if not any([(term / bf).has(xi) for xi in variables]):\n", | ||
" in_span = True\n", | ||
" if not in_span:\n", | ||
" subs.append([term.has(ci) for ci in basis_coef].index(True))\n", | ||
" return [bi for i, bi in enumerate(basis) if i not in subs]\n", | ||
"\n", | ||
"\n", | ||
"def obs(F, basis, variables):\n", | ||
" return [lie_derivative(bf, F, variables) for bf in basis]\n", | ||
"\n", | ||
"\n", | ||
"def generate_monomials(variables, order):\n", | ||
" monomials = set()\n", | ||
" # Generate all combinations of powers including and up to the order for each variable\n", | ||
" for powers in product(range(order + 1), repeat=len(variables)):\n", | ||
" if sum(powers) <= order:\n", | ||
" monomial = prod([var**power for var, power in zip(variables, powers)])\n", | ||
" monomials.add(monomial)\n", | ||
" return list(monomials)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "41dac27a-1e16-4150-9206-3bbff7c84b9e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"fhn = pfhn.FitzHughNagumo()\n", | ||
"variables = fhn._variables[1:]\n", | ||
"exprs = fhn._exprs\n", | ||
"x0, x1 = variables\n", | ||
"\n", | ||
"basis = generate_monomials(variables, 3)\n", | ||
"#basis = update_basis(exprs, basis, variables)\n", | ||
"obs(exprs, basis, variables)[1:]\n", | ||
"variables" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "948204d9-e295-4e96-baf2-47791e571111", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"lie_obs = ak.core.observables.SymbolicObservable(variables, [x0, x1] + obs(exprs, basis, variables)[1:])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9c05d7bc-5f71-406c-9c87-bf928737c875", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"training_data = fhn.solve_ivps(\n", | ||
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(30, 2)),\n", | ||
" tspan=[0.0, 1.0],\n", | ||
" sampling_period=0.1\n", | ||
")\n", | ||
"\n", | ||
"# learn model from data\n", | ||
"experiment_results = ak.auto_koopman(\n", | ||
" training_data, # list of trajectories\n", | ||
" sampling_period=0.1, # sampling period of trajectory snapshots\n", | ||
" obs_type=lie_obs, # use Random Fourier Features Observables\n", | ||
" opt=\"monte-carlo\", # grid search to find best hyperparameters\n", | ||
" n_obs=200, # maximum number of observables to try\n", | ||
" max_opt_iter=200, # maximum number of optimization iterations\n", | ||
" grid_param_slices=5, # for grid search, number of slices for each parameter\n", | ||
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n", | ||
" normalize=False,\n", | ||
" rank=(1, 200, 40) # rank range (start, stop, step) DMD hyperparameter\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f9ca2671-5dca-4a57-a833-540fb5825406", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get the model from the experiment results\n", | ||
"model = experiment_results['tuned_model']\n", | ||
"\n", | ||
"testing_data = fhn.solve_ivps(\n", | ||
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(30, 2)),\n", | ||
" tspan=[0.0, 1.0],\n", | ||
" sampling_period=0.1\n", | ||
")\n", | ||
"\n", | ||
"# simulate using the learned model\n", | ||
"iv = [0.5, 0.1]\n", | ||
"prediction_data = model.solve_ivps(\n", | ||
" initial_states=[t.states[0] for t in testing_data],\n", | ||
" tspan=(0.0, 1.0),\n", | ||
" sampling_period=0.1\n", | ||
")\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"# plot the results\n", | ||
"for trajectory in prediction_data:\n", | ||
" plt.plot(*trajectory.states.T, 'r')\n", | ||
"for true_trajectory in testing_data:\n", | ||
" plt.plot(*true_trajectory.states.T, 'k')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "92f83dbd-8f22-49c0-ac4f-b52e1dc48d6d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "cddc661b-db83-49be-a8c3-147b4d3be803", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"MAX_ITER = 10\n", | ||
"\n", | ||
"basis = generate_monomials(variables, 3) + [sp.exp(-x0**2)*sp.exp(-x1), sp.exp(-x1**2)]\n", | ||
"print(basis)\n", | ||
"old_len = len(basis)\n", | ||
"for _ in range(MAX_ITER):\n", | ||
" print(_)\n", | ||
" basis = update_basis(exprs, basis, variables)\n", | ||
" if len(basis) == old_len:\n", | ||
" break\n", | ||
" old_len = len(basis)\n", | ||
" print(basis)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f3dd932d-1e2e-44be-b157-23fddd38e0aa", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"variables = sp.symbols('x y')\n", | ||
"x, y = variables\n", | ||
"F = [x*y+2*x, -1/2*y**2+7*y+1] \n", | ||
"basis = [1, x, y, x*y, x**2, y**2, x**2*y, x*y**2]\n", | ||
"\n", | ||
"MAX_ITER = 10\n", | ||
"\n", | ||
"basis = generate_monomials(variables, 3)\n", | ||
"old_len = len(basis)\n", | ||
"for _ in range(MAX_ITER):\n", | ||
" print(_)\n", | ||
" basis = update_basis(F, basis, variables)\n", | ||
" if len(basis) == old_len:\n", | ||
" break\n", | ||
" old_len = len(basis)\n", | ||
" print(basis)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ee5d4a4b-1fc4-42e7-b438-bc53d79f4d06", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "5c279f0a-e719-4d0f-b64f-49df1d1814bf", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"lf = lie_derivative(f, F, variables)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "86c58731-3c95-47bb-af43-e97061bab3e1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"subs = []\n", | ||
"for term in lf.args:\n", | ||
" in_span = False\n", | ||
" for bf in basis:\n", | ||
" if not (term / bf).has(x) and not (term / bf).has(y):\n", | ||
" in_span = True\n", | ||
" if not in_span:\n", | ||
" subs.append([term.has(ci) for ci in basis_coef].index(True))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f4be5da5-c9d0-4928-a681-ecd4719a2ee7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"basis = [bi for i, bi in enumerate(basis) if i not in subs]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ab9cf0c8-32f1-427f-848e-08c5aad775b4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"basis_coef = sp.symbols(\" \".join([f\"c{i}\" for i in range(0, len(basis))]))\n", | ||
"alpha_coef = sp.symbols(\" \".join([f\"a{i}\" for i in range(0, len(basis))]))\n", | ||
"f = sum(ci*bi for ci, bi in zip(basis_coef, basis))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6e4bac27-e589-472a-b74b-b9ccf585d6f3", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"lf = lie_derivative(f, F, variables)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "54cb453d-8d7a-4c3b-8662-2ca5d066e2c7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"subs = []\n", | ||
"for term in lf.args:\n", | ||
" in_span = False\n", | ||
" for bf in basis:\n", | ||
" if not (term / bf).has(x) and not (term / bf).has(y):\n", | ||
" in_span = True\n", | ||
" if not in_span:\n", | ||
" subs.append([term.has(ci) for ci in basis_coef].index(True))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ea008ea3-a141-435f-a7e5-fba0958da216", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"basis = [bi for i, bi in enumerate(basis) if i not in subs]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1306b1cf-08f7-4100-bc8e-67be999f82e5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"basis" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0098f2e2-7c17-414e-85e1-5f3bfcade1d7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |