Skip to content

Commit

Permalink
play with Lie derivative observables and abstraction (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanJamesLew committed Mar 17, 2024
1 parent 4a8a327 commit 827bdfe
Showing 1 changed file with 343 additions and 0 deletions.
343 changes: 343 additions & 0 deletions notebooks/Koopman-Lie.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,343 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f4f243e3-da0b-47e8-85c4-a9c49f3ae4ee",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"..\")\n",
"\n",
"import autokoopman.benchmark.fhn as pfhn\n",
"import autokoopman as ak\n",
"import sympy as sp\n",
"\n",
"from itertools import product\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa258ddd-cd57-4f9a-8843-cced8edede12",
"metadata": {},
"outputs": [],
"source": [
"def lie_derivative(f, F, variables):\n",
" return sp.expand(sum(sp.diff(f, xi)*F[i] for i, xi in enumerate(variables)))\n",
" \n",
"def update_basis(F, basis, variables):\n",
" # generate the candidate function\n",
" basis_coef = sp.symbols(\" \".join([f\"c{i}\" for i in range(0, len(basis))]))\n",
" f = sum(ci*bi for ci, bi in zip(basis_coef, basis))\n",
"\n",
" # lie derivative\n",
" lf = lie_derivative(f, F, variables)\n",
"\n",
" # remove terms that do not belong to span(*F)\n",
" subs = []\n",
" for term in lf.args:\n",
" in_span = False\n",
" for bf in basis:\n",
" if not any([(term / bf).has(xi) for xi in variables]):\n",
" in_span = True\n",
" if not in_span:\n",
" subs.append([term.has(ci) for ci in basis_coef].index(True))\n",
" return [bi for i, bi in enumerate(basis) if i not in subs]\n",
"\n",
"\n",
"def obs(F, basis, variables):\n",
" return [lie_derivative(bf, F, variables) for bf in basis]\n",
"\n",
"\n",
"def generate_monomials(variables, order):\n",
" monomials = set()\n",
" # Generate all combinations of powers including and up to the order for each variable\n",
" for powers in product(range(order + 1), repeat=len(variables)):\n",
" if sum(powers) <= order:\n",
" monomial = prod([var**power for var, power in zip(variables, powers)])\n",
" monomials.add(monomial)\n",
" return list(monomials)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41dac27a-1e16-4150-9206-3bbff7c84b9e",
"metadata": {},
"outputs": [],
"source": [
"fhn = pfhn.FitzHughNagumo()\n",
"variables = fhn._variables[1:]\n",
"exprs = fhn._exprs\n",
"x0, x1 = variables\n",
"\n",
"basis = generate_monomials(variables, 3)\n",
"#basis = update_basis(exprs, basis, variables)\n",
"obs(exprs, basis, variables)[1:]\n",
"variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "948204d9-e295-4e96-baf2-47791e571111",
"metadata": {},
"outputs": [],
"source": [
"lie_obs = ak.core.observables.SymbolicObservable(variables, [x0, x1] + obs(exprs, basis, variables)[1:])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c05d7bc-5f71-406c-9c87-bf928737c875",
"metadata": {},
"outputs": [],
"source": [
"training_data = fhn.solve_ivps(\n",
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(30, 2)),\n",
" tspan=[0.0, 1.0],\n",
" sampling_period=0.1\n",
")\n",
"\n",
"# learn model from data\n",
"experiment_results = ak.auto_koopman(\n",
" training_data, # list of trajectories\n",
" sampling_period=0.1, # sampling period of trajectory snapshots\n",
" obs_type=lie_obs, # use Random Fourier Features Observables\n",
" opt=\"monte-carlo\", # grid search to find best hyperparameters\n",
" n_obs=200, # maximum number of observables to try\n",
" max_opt_iter=200, # maximum number of optimization iterations\n",
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
" normalize=False,\n",
" rank=(1, 200, 40) # rank range (start, stop, step) DMD hyperparameter\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9ca2671-5dca-4a57-a833-540fb5825406",
"metadata": {},
"outputs": [],
"source": [
"# get the model from the experiment results\n",
"model = experiment_results['tuned_model']\n",
"\n",
"testing_data = fhn.solve_ivps(\n",
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(30, 2)),\n",
" tspan=[0.0, 1.0],\n",
" sampling_period=0.1\n",
")\n",
"\n",
"# simulate using the learned model\n",
"iv = [0.5, 0.1]\n",
"prediction_data = model.solve_ivps(\n",
" initial_states=[t.states[0] for t in testing_data],\n",
" tspan=(0.0, 1.0),\n",
" sampling_period=0.1\n",
")\n",
"\n",
"\n",
"\n",
"# plot the results\n",
"for trajectory in prediction_data:\n",
" plt.plot(*trajectory.states.T, 'r')\n",
"for true_trajectory in testing_data:\n",
" plt.plot(*true_trajectory.states.T, 'k')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92f83dbd-8f22-49c0-ac4f-b52e1dc48d6d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "cddc661b-db83-49be-a8c3-147b4d3be803",
"metadata": {},
"outputs": [],
"source": [
"MAX_ITER = 10\n",
"\n",
"basis = generate_monomials(variables, 3) + [sp.exp(-x0**2)*sp.exp(-x1), sp.exp(-x1**2)]\n",
"print(basis)\n",
"old_len = len(basis)\n",
"for _ in range(MAX_ITER):\n",
" print(_)\n",
" basis = update_basis(exprs, basis, variables)\n",
" if len(basis) == old_len:\n",
" break\n",
" old_len = len(basis)\n",
" print(basis)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3dd932d-1e2e-44be-b157-23fddd38e0aa",
"metadata": {},
"outputs": [],
"source": [
"variables = sp.symbols('x y')\n",
"x, y = variables\n",
"F = [x*y+2*x, -1/2*y**2+7*y+1] \n",
"basis = [1, x, y, x*y, x**2, y**2, x**2*y, x*y**2]\n",
"\n",
"MAX_ITER = 10\n",
"\n",
"basis = generate_monomials(variables, 3)\n",
"old_len = len(basis)\n",
"for _ in range(MAX_ITER):\n",
" print(_)\n",
" basis = update_basis(F, basis, variables)\n",
" if len(basis) == old_len:\n",
" break\n",
" old_len = len(basis)\n",
" print(basis)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee5d4a4b-1fc4-42e7-b438-bc53d79f4d06",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c279f0a-e719-4d0f-b64f-49df1d1814bf",
"metadata": {},
"outputs": [],
"source": [
"lf = lie_derivative(f, F, variables)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86c58731-3c95-47bb-af43-e97061bab3e1",
"metadata": {},
"outputs": [],
"source": [
"subs = []\n",
"for term in lf.args:\n",
" in_span = False\n",
" for bf in basis:\n",
" if not (term / bf).has(x) and not (term / bf).has(y):\n",
" in_span = True\n",
" if not in_span:\n",
" subs.append([term.has(ci) for ci in basis_coef].index(True))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4be5da5-c9d0-4928-a681-ecd4719a2ee7",
"metadata": {},
"outputs": [],
"source": [
"basis = [bi for i, bi in enumerate(basis) if i not in subs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab9cf0c8-32f1-427f-848e-08c5aad775b4",
"metadata": {},
"outputs": [],
"source": [
"basis_coef = sp.symbols(\" \".join([f\"c{i}\" for i in range(0, len(basis))]))\n",
"alpha_coef = sp.symbols(\" \".join([f\"a{i}\" for i in range(0, len(basis))]))\n",
"f = sum(ci*bi for ci, bi in zip(basis_coef, basis))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e4bac27-e589-472a-b74b-b9ccf585d6f3",
"metadata": {},
"outputs": [],
"source": [
"lf = lie_derivative(f, F, variables)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54cb453d-8d7a-4c3b-8662-2ca5d066e2c7",
"metadata": {},
"outputs": [],
"source": [
"subs = []\n",
"for term in lf.args:\n",
" in_span = False\n",
" for bf in basis:\n",
" if not (term / bf).has(x) and not (term / bf).has(y):\n",
" in_span = True\n",
" if not in_span:\n",
" subs.append([term.has(ci) for ci in basis_coef].index(True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea008ea3-a141-435f-a7e5-fba0958da216",
"metadata": {},
"outputs": [],
"source": [
"basis = [bi for i, bi in enumerate(basis) if i not in subs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1306b1cf-08f7-4100-bc8e-67be999f82e5",
"metadata": {},
"outputs": [],
"source": [
"basis"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0098f2e2-7c17-414e-85e1-5f3bfcade1d7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 827bdfe

Please sign in to comment.