Skip to content

Commit

Permalink
created a demo notebook for #84
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanJamesLew committed Mar 11, 2024
1 parent 6866e31 commit a777adf
Showing 1 changed file with 179 additions and 0 deletions.
179 changes: 179 additions & 0 deletions notebooks/weighted-cost-func.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "01a2f054-0d0e-497a-8bb9-d5d395ca7c04",
"metadata": {},
"source": [
"# Weighted Cost Function\n",
"\n",
"Shows how to use the cost function requested in [issue #84](https://github.com/EthanJamesLew/AutoKoopman/issues/84)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfdb47e2-0ea0-4bf8-8279-8500ff3cf21f",
"metadata": {},
"outputs": [],
"source": [
"# the notebook imports\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"import sys\n",
"sys.path.append(\"..\")\n",
"# this is the convenience function\n",
"from autokoopman import auto_koopman"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "291d3409-1c8c-44cb-8380-44f08019b57d",
"metadata": {},
"outputs": [],
"source": [
"# for a complete example, let's create an example dataset using an included benchmark system\n",
"import autokoopman.benchmark.fhn as fhn\n",
"fhn = fhn.FitzHughNagumo()\n",
"training_data = fhn.solve_ivps(\n",
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(10, 2)),\n",
" tspan=[0.0, 10.0],\n",
" sampling_period=0.1\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2d42e41-46c2-467c-9ce7-9bd6a7c509a1",
"metadata": {},
"outputs": [],
"source": [
"# create trajectories as numpy array and create a weights array\n",
"trajectories = []\n",
"weights = []\n",
"\n",
"# create weights for every time point\n",
"for idx, traj in enumerate(training_data):\n",
" trajectories.append(traj.states)\n",
" # uniform weights\n",
" #weights.append(np.ones(len(traj.states)) / len(traj.states) )\n",
" \n",
" # distance weights\n",
" weights.append(traj.norm().states / (np.max(traj.norm().states) * len(traj.states)) )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98510aa7-3416-4181-a493-00500be53f61",
"metadata": {},
"outputs": [],
"source": [
"# learn model from data\n",
"experiment_results = auto_koopman(\n",
" trajectories, # list of trajectories\n",
" sampling_period=0.1, # sampling period of trajectory snapshots\n",
" obs_type=\"rff\", # use Random Fourier Features Observables\n",
" cost_func=\"weighted\", # use \"weighted\" cost function\n",
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
" opt=\"grid\", # grid search to find best hyperparameters\n",
" n_obs=200, # maximum number of observables to try\n",
" max_opt_iter=200, # maximum number of optimization iterations\n",
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
" rank=(1, 200, 40) # rank range (start, stop, step) DMD hyperparameter\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "476c496d-56c5-477b-9579-2c0121b3247d",
"metadata": {},
"outputs": [],
"source": [
"# view our custom weighted cost\n",
"experiment_results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb4dfdcf-01ca-4cbc-966d-3f531d8475ea",
"metadata": {},
"outputs": [],
"source": [
"# get the model from the experiment results\n",
"model = experiment_results['tuned_model']\n",
"\n",
"# simulate using the learned model\n",
"iv = [0.5, 0.1]\n",
"trajectory = model.solve_ivp(\n",
" initial_state=iv,\n",
" tspan=(0.0, 10.0),\n",
" sampling_period=0.1\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b1e329c-c25c-442d-8b76-146924a6e46b",
"metadata": {},
"outputs": [],
"source": [
"# simulate the ground truth for comparison\n",
"true_trajectory = fhn.solve_ivp(\n",
" initial_state=iv,\n",
" tspan=(0.0, 10.0),\n",
" sampling_period=0.1\n",
")\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"\n",
"# plot the results\n",
"plt.plot(*trajectory.states.T, label='Trajectory Prediction')\n",
"plt.plot(*true_trajectory.states.T, label='Ground Truth')\n",
"\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$x_2$\")\n",
"plt.grid()\n",
"plt.legend()\n",
"plt.title(\"FHN Test Trajectory Plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1458259-6c92-46e5-91a3-f56e53633b35",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit a777adf

Please sign in to comment.