From 827bdfefd0e89eb543cd9b379ae0737ee5da6e85 Mon Sep 17 00:00:00 2001 From: Ethan Lew Date: Sun, 17 Mar 2024 14:44:12 -0700 Subject: [PATCH] play with Lie derivative observables and abstraction (#83) --- notebooks/Koopman-Lie.ipynb | 343 ++++++++++++++++++++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 notebooks/Koopman-Lie.ipynb diff --git a/notebooks/Koopman-Lie.ipynb b/notebooks/Koopman-Lie.ipynb new file mode 100644 index 0000000..ac7339a --- /dev/null +++ b/notebooks/Koopman-Lie.ipynb @@ -0,0 +1,343 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f4f243e3-da0b-47e8-85c4-a9c49f3ae4ee", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"..\")\n", + "\n", + "import autokoopman.benchmark.fhn as pfhn\n", + "import autokoopman as ak\n", + "import sympy as sp\n", + "\n", + "from itertools import product\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa258ddd-cd57-4f9a-8843-cced8edede12", + "metadata": {}, + "outputs": [], + "source": [ + "def lie_derivative(f, F, variables):\n", + " return sp.expand(sum(sp.diff(f, xi)*F[i] for i, xi in enumerate(variables)))\n", + " \n", + "def update_basis(F, basis, variables):\n", + " # generate the candidate function\n", + " basis_coef = sp.symbols(\" \".join([f\"c{i}\" for i in range(0, len(basis))]))\n", + " f = sum(ci*bi for ci, bi in zip(basis_coef, basis))\n", + "\n", + " # lie derivative\n", + " lf = lie_derivative(f, F, variables)\n", + "\n", + " # remove terms that do not belong to span(*F)\n", + " subs = []\n", + " for term in lf.args:\n", + " in_span = False\n", + " for bf in basis:\n", + " if not any([(term / bf).has(xi) for xi in variables]):\n", + " in_span = True\n", + " if not in_span:\n", + " subs.append([term.has(ci) for ci in basis_coef].index(True))\n", + " return [bi for i, bi in enumerate(basis) if i not in subs]\n", + "\n", + "\n", + "def obs(F, basis, variables):\n", + " return [lie_derivative(bf, F, variables) for bf in basis]\n", + "\n", + "\n", + "def generate_monomials(variables, order):\n", + " monomials = set()\n", + " # Generate all combinations of powers including and up to the order for each variable\n", + " for powers in product(range(order + 1), repeat=len(variables)):\n", + " if sum(powers) <= order:\n", + " monomial = prod([var**power for var, power in zip(variables, powers)])\n", + " monomials.add(monomial)\n", + " return list(monomials)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41dac27a-1e16-4150-9206-3bbff7c84b9e", + "metadata": {}, + "outputs": [], + "source": [ + "fhn = pfhn.FitzHughNagumo()\n", + "variables = fhn._variables[1:]\n", + "exprs = fhn._exprs\n", + "x0, x1 = variables\n", + "\n", + "basis = generate_monomials(variables, 3)\n", + "#basis = update_basis(exprs, basis, variables)\n", + "obs(exprs, basis, variables)[1:]\n", + "variables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "948204d9-e295-4e96-baf2-47791e571111", + "metadata": {}, + "outputs": [], + "source": [ + "lie_obs = ak.core.observables.SymbolicObservable(variables, [x0, x1] + obs(exprs, basis, variables)[1:])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c05d7bc-5f71-406c-9c87-bf928737c875", + "metadata": {}, + "outputs": [], + "source": [ + "training_data = fhn.solve_ivps(\n", + " initial_states=np.random.uniform(low=-2.0, high=2.0, size=(30, 2)),\n", + " tspan=[0.0, 1.0],\n", + " sampling_period=0.1\n", + ")\n", + "\n", + "# learn model from data\n", + "experiment_results = ak.auto_koopman(\n", + " training_data, # list of trajectories\n", + " sampling_period=0.1, # sampling period of trajectory snapshots\n", + " obs_type=lie_obs, # use Random Fourier Features Observables\n", + " opt=\"monte-carlo\", # grid search to find best hyperparameters\n", + " n_obs=200, # maximum number of observables to try\n", + " max_opt_iter=200, # maximum number of optimization iterations\n", + " grid_param_slices=5, # for grid search, number of slices for each parameter\n", + " n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n", + " normalize=False,\n", + " rank=(1, 200, 40) # rank range (start, stop, step) DMD hyperparameter\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9ca2671-5dca-4a57-a833-540fb5825406", + "metadata": {}, + "outputs": [], + "source": [ + "# get the model from the experiment results\n", + "model = experiment_results['tuned_model']\n", + "\n", + "testing_data = fhn.solve_ivps(\n", + " initial_states=np.random.uniform(low=-2.0, high=2.0, size=(30, 2)),\n", + " tspan=[0.0, 1.0],\n", + " sampling_period=0.1\n", + ")\n", + "\n", + "# simulate using the learned model\n", + "iv = [0.5, 0.1]\n", + "prediction_data = model.solve_ivps(\n", + " initial_states=[t.states[0] for t in testing_data],\n", + " tspan=(0.0, 1.0),\n", + " sampling_period=0.1\n", + ")\n", + "\n", + "\n", + "\n", + "# plot the results\n", + "for trajectory in prediction_data:\n", + " plt.plot(*trajectory.states.T, 'r')\n", + "for true_trajectory in testing_data:\n", + " plt.plot(*true_trajectory.states.T, 'k')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92f83dbd-8f22-49c0-ac4f-b52e1dc48d6d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cddc661b-db83-49be-a8c3-147b4d3be803", + "metadata": {}, + "outputs": [], + "source": [ + "MAX_ITER = 10\n", + "\n", + "basis = generate_monomials(variables, 3) + [sp.exp(-x0**2)*sp.exp(-x1), sp.exp(-x1**2)]\n", + "print(basis)\n", + "old_len = len(basis)\n", + "for _ in range(MAX_ITER):\n", + " print(_)\n", + " basis = update_basis(exprs, basis, variables)\n", + " if len(basis) == old_len:\n", + " break\n", + " old_len = len(basis)\n", + " print(basis)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3dd932d-1e2e-44be-b157-23fddd38e0aa", + "metadata": {}, + "outputs": [], + "source": [ + "variables = sp.symbols('x y')\n", + "x, y = variables\n", + "F = [x*y+2*x, -1/2*y**2+7*y+1] \n", + "basis = [1, x, y, x*y, x**2, y**2, x**2*y, x*y**2]\n", + "\n", + "MAX_ITER = 10\n", + "\n", + "basis = generate_monomials(variables, 3)\n", + "old_len = len(basis)\n", + "for _ in range(MAX_ITER):\n", + " print(_)\n", + " basis = update_basis(F, basis, variables)\n", + " if len(basis) == old_len:\n", + " break\n", + " old_len = len(basis)\n", + " print(basis)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee5d4a4b-1fc4-42e7-b438-bc53d79f4d06", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c279f0a-e719-4d0f-b64f-49df1d1814bf", + "metadata": {}, + "outputs": [], + "source": [ + "lf = lie_derivative(f, F, variables)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86c58731-3c95-47bb-af43-e97061bab3e1", + "metadata": {}, + "outputs": [], + "source": [ + "subs = []\n", + "for term in lf.args:\n", + " in_span = False\n", + " for bf in basis:\n", + " if not (term / bf).has(x) and not (term / bf).has(y):\n", + " in_span = True\n", + " if not in_span:\n", + " subs.append([term.has(ci) for ci in basis_coef].index(True))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4be5da5-c9d0-4928-a681-ecd4719a2ee7", + "metadata": {}, + "outputs": [], + "source": [ + "basis = [bi for i, bi in enumerate(basis) if i not in subs]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab9cf0c8-32f1-427f-848e-08c5aad775b4", + "metadata": {}, + "outputs": [], + "source": [ + "basis_coef = sp.symbols(\" \".join([f\"c{i}\" for i in range(0, len(basis))]))\n", + "alpha_coef = sp.symbols(\" \".join([f\"a{i}\" for i in range(0, len(basis))]))\n", + "f = sum(ci*bi for ci, bi in zip(basis_coef, basis))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e4bac27-e589-472a-b74b-b9ccf585d6f3", + "metadata": {}, + "outputs": [], + "source": [ + "lf = lie_derivative(f, F, variables)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54cb453d-8d7a-4c3b-8662-2ca5d066e2c7", + "metadata": {}, + "outputs": [], + "source": [ + "subs = []\n", + "for term in lf.args:\n", + " in_span = False\n", + " for bf in basis:\n", + " if not (term / bf).has(x) and not (term / bf).has(y):\n", + " in_span = True\n", + " if not in_span:\n", + " subs.append([term.has(ci) for ci in basis_coef].index(True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea008ea3-a141-435f-a7e5-fba0958da216", + "metadata": {}, + "outputs": [], + "source": [ + "basis = [bi for i, bi in enumerate(basis) if i not in subs]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1306b1cf-08f7-4100-bc8e-67be999f82e5", + "metadata": {}, + "outputs": [], + "source": [ + "basis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0098f2e2-7c17-414e-85e1-5f3bfcade1d7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}