-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d294936
commit 823202f
Showing
1 changed file
with
318 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,318 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e0b15caf-2e2b-4b65-a832-01892a353bed", | ||
"metadata": {}, | ||
"source": [ | ||
"## Static dataset creation\n", | ||
"This notebook walks through how to use the `deepbench` `Pendulum` module to create and save a static dataset for use in all of the statistical and ML methods. The method-focused notebooks show how to import and use this static dataset for inference." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 39, | ||
"id": "47611152-0598-4d26-ac4d-d1f243dd0736", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pathlib import Path\n", | ||
"\n", | ||
"import jax.numpy as jnp\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"from deepbench.physics_object import Pendulum" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 46, | ||
"id": "4ca50a6f-8f0e-469f-9993-e1e082133a7f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"def save_thetas_and_xs_multiple(params_in):\n", | ||
" # except length will have 4 elements\n", | ||
" lengths, thetas, μ_a_g, σ_a_g = params_in\n", | ||
"\n", | ||
" ag0 = rs.normal(loc=μ_a_g, scale=σ_a_g)\n", | ||
" ag1 = rs.normal(loc=μ_a_g, scale=σ_a_g)\n", | ||
" ags = np.array([np.repeat(ag0,int(len(lengths)/2)), np.repeat(ag1,int(len(lengths)/2))]).flatten()\n", | ||
" #ags = np.array([rs.normal(loc=μ_a_g, scale=σ_a_g, size = int(len(lengths)/2)),\n", | ||
" # rs.normal(loc=μ_a_g, scale=σ_a_g, size = int(len(lengths)/2))]).flatten()\n", | ||
" \n", | ||
" \n", | ||
" xs = []\n", | ||
" for i in range(len(lengths)):\n", | ||
" #print(lengths[i], thetas[i], ags[i])\n", | ||
" pendulum = Pendulum(\n", | ||
" pendulum_arm_length=float(lengths[i]),\n", | ||
" starting_angle_radians=float(thetas[i]),\n", | ||
" acceleration_due_to_gravity=float(ags[i]),\n", | ||
" noise_std_percent={\n", | ||
" \"pendulum_arm_length\": 0.0,\n", | ||
" \"starting_angle_radians\": 0.1,\n", | ||
" \"acceleration_due_to_gravity\": 0.0,\n", | ||
" },\n", | ||
" )\n", | ||
" x = pendulum.create_object(0.75, noiseless=False)\n", | ||
" xs.append(x)\n", | ||
" del pendulum\n", | ||
" return ags, xs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 47, | ||
"id": "9fdc1f49-e453-4526-b18b-814b68f92aca", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#length0, length1, length2, length3, theta0, theta1, theta2, theta3, μ_a_g, σ_a_g = thetas_in\n", | ||
"length_percent_error_all = 0.0\n", | ||
"theta_percent_error_all = 0.1\n", | ||
"a_g_percent_error_all = 0.0\n", | ||
"pos_err = 0.0\n", | ||
"\n", | ||
"time = 0.75\n", | ||
"\n", | ||
"total_length = 1000\n", | ||
"length_df = int(total_length/4) # divide by four because we want the same total size as above\n", | ||
"\n", | ||
"pendulums_per_planet = 100\n", | ||
"\n", | ||
"# and we get four pendulums per iteration of the below\n", | ||
"thetas = np.zeros((total_length, 5))\n", | ||
"# this needs to have the extra 1 so that SBI is happy\n", | ||
"xs = np.zeros((total_length,1))\n", | ||
"#labels = np.zeros((2*length_df, 2))\n", | ||
"#error = []\n", | ||
"#y_noisy = []\n", | ||
"\n", | ||
" \n", | ||
"rs = np.random.RandomState()#2147483648)# \n", | ||
"\n", | ||
"\n", | ||
"lengths_draw = abs(rs.normal(loc=5, scale=2, size = pendulums_per_planet))\n", | ||
"thetas_draw = abs(rs.normal(loc=jnp.pi/100, scale=jnp.pi/500, size = pendulums_per_planet))\n", | ||
"\n", | ||
"μ_a_g = abs(rs.normal(loc=10, scale=2))\n", | ||
"σ_a_g = abs(rs.normal(loc=1, scale=0.5))\n", | ||
"\n", | ||
"\n", | ||
"params_in = [lengths_draw,\n", | ||
" thetas_draw,\n", | ||
" μ_a_g, σ_a_g]\n", | ||
"\n", | ||
"a_gs, xs_out = save_thetas_and_xs_multiple(params_in)\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 50, | ||
"id": "9714134e-53a2-42df-b254-e942a2a41314", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<div>\n", | ||
"<style scoped>\n", | ||
" .dataframe tbody tr th:only-of-type {\n", | ||
" vertical-align: middle;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe tbody tr th {\n", | ||
" vertical-align: top;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe thead th {\n", | ||
" text-align: right;\n", | ||
" }\n", | ||
"</style>\n", | ||
"<table border=\"1\" class=\"dataframe\">\n", | ||
" <thead>\n", | ||
" <tr style=\"text-align: right;\">\n", | ||
" <th></th>\n", | ||
" <th>length</th>\n", | ||
" <th>theta</th>\n", | ||
" <th>a_g</th>\n", | ||
" <th>time</th>\n", | ||
" <th>pos</th>\n", | ||
" </tr>\n", | ||
" </thead>\n", | ||
" <tbody>\n", | ||
" <tr>\n", | ||
" <th>0</th>\n", | ||
" <td>5.939540</td>\n", | ||
" <td>0.040323</td>\n", | ||
" <td>10.348500</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.126648</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>1</th>\n", | ||
" <td>3.645948</td>\n", | ||
" <td>0.041587</td>\n", | ||
" <td>10.348500</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.046904</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>2</th>\n", | ||
" <td>4.494573</td>\n", | ||
" <td>0.032681</td>\n", | ||
" <td>10.348500</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.066548</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>3</th>\n", | ||
" <td>3.691717</td>\n", | ||
" <td>0.032449</td>\n", | ||
" <td>10.348500</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.036753</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>4</th>\n", | ||
" <td>8.463276</td>\n", | ||
" <td>0.026023</td>\n", | ||
" <td>10.348500</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.124928</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>...</th>\n", | ||
" <td>...</td>\n", | ||
" <td>...</td>\n", | ||
" <td>...</td>\n", | ||
" <td>...</td>\n", | ||
" <td>...</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>95</th>\n", | ||
" <td>3.287364</td>\n", | ||
" <td>0.027606</td>\n", | ||
" <td>9.821157</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.027284</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>96</th>\n", | ||
" <td>4.901377</td>\n", | ||
" <td>0.031559</td>\n", | ||
" <td>9.821157</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.080431</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>97</th>\n", | ||
" <td>6.374146</td>\n", | ||
" <td>0.028311</td>\n", | ||
" <td>9.821157</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.102336</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>98</th>\n", | ||
" <td>1.075661</td>\n", | ||
" <td>0.030635</td>\n", | ||
" <td>9.821157</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>-0.023095</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>99</th>\n", | ||
" <td>5.731387</td>\n", | ||
" <td>0.023705</td>\n", | ||
" <td>9.821157</td>\n", | ||
" <td>0.75</td>\n", | ||
" <td>0.071557</td>\n", | ||
" </tr>\n", | ||
" </tbody>\n", | ||
"</table>\n", | ||
"<p>100 rows × 5 columns</p>\n", | ||
"</div>" | ||
], | ||
"text/plain": [ | ||
" length theta a_g time pos\n", | ||
"0 5.939540 0.040323 10.348500 0.75 0.126648\n", | ||
"1 3.645948 0.041587 10.348500 0.75 0.046904\n", | ||
"2 4.494573 0.032681 10.348500 0.75 0.066548\n", | ||
"3 3.691717 0.032449 10.348500 0.75 0.036753\n", | ||
"4 8.463276 0.026023 10.348500 0.75 0.124928\n", | ||
".. ... ... ... ... ...\n", | ||
"95 3.287364 0.027606 9.821157 0.75 0.027284\n", | ||
"96 4.901377 0.031559 9.821157 0.75 0.080431\n", | ||
"97 6.374146 0.028311 9.821157 0.75 0.102336\n", | ||
"98 1.075661 0.030635 9.821157 0.75 -0.023095\n", | ||
"99 5.731387 0.023705 9.821157 0.75 0.071557\n", | ||
"\n", | ||
"[100 rows x 5 columns]" | ||
] | ||
}, | ||
"execution_count": 50, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# now make it into a dataframe\n", | ||
"data_params = {\n", | ||
" 'length': lengths_draw,\n", | ||
" 'theta': thetas_draw,\n", | ||
" 'a_g': a_gs,\n", | ||
" 'time': np.repeat(time, len(lengths_draw)),\n", | ||
" 'pos': xs_out,\n", | ||
" \n", | ||
"}\n", | ||
"\n", | ||
"## create the DataFrame\n", | ||
"df = pd.DataFrame(data_params)\n", | ||
"df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 51, | ||
"id": "44c6292d-fea9-4693-9173-913fd396bbd5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# save the dataframe\n", | ||
"filepath = '../data/'\n", | ||
"df.to_csv(filepath+'static_hierarchical_df.csv')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4efa4d17-45d1-4505-a111-13bcc479452a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |