Skip to content

Commit

Permalink
code to save a static dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Oct 12, 2023
1 parent d294936 commit 823202f
Showing 1 changed file with 318 additions and 0 deletions.
318 changes: 318 additions & 0 deletions notebooks/save_dataframe.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e0b15caf-2e2b-4b65-a832-01892a353bed",
"metadata": {},
"source": [
"## Static dataset creation\n",
"This notebook walks through how to use the modules to create and save a static dataset for use in all of the statistical and ML methods. The method-focused notebooks show how to import and utilize this static dataset in inference."
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "47611152-0598-4d26-ac4d-d1f243dd0736",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax.numpy as jnp\n",
"from deepbench.physics_object import Pendulum\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "4ca50a6f-8f0e-469f-9993-e1e082133a7f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def save_thetas_and_xs_multiple(params_in):\n",
" # except length will have 4 elements\n",
" lengths, thetas, μ_a_g, σ_a_g = params_in\n",
"\n",
" ag0 = rs.normal(loc=μ_a_g, scale=σ_a_g)\n",
" ag1 = rs.normal(loc=μ_a_g, scale=σ_a_g)\n",
" ags = np.array([np.repeat(ag0,int(len(lengths)/2)), np.repeat(ag1,int(len(lengths)/2))]).flatten()\n",
" #ags = np.array([rs.normal(loc=μ_a_g, scale=σ_a_g, size = int(len(lengths)/2)),\n",
" # rs.normal(loc=μ_a_g, scale=σ_a_g, size = int(len(lengths)/2))]).flatten()\n",
" \n",
" \n",
" xs = []\n",
" for i in range(len(lengths)):\n",
" #print(lengths[i], thetas[i], ags[i])\n",
" pendulum = Pendulum(\n",
" pendulum_arm_length=float(lengths[i]),\n",
" starting_angle_radians=float(thetas[i]),\n",
" acceleration_due_to_gravity=float(ags[i]),\n",
" noise_std_percent={\n",
" \"pendulum_arm_length\": 0.0,\n",
" \"starting_angle_radians\": 0.1,\n",
" \"acceleration_due_to_gravity\": 0.0,\n",
" },\n",
" )\n",
" x = pendulum.create_object(0.75, noiseless=False)\n",
" xs.append(x)\n",
" del pendulum\n",
" return ags, xs"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "9fdc1f49-e453-4526-b18b-814b68f92aca",
"metadata": {},
"outputs": [],
"source": [
"#length0, length1, length2, length3, theta0, theta1, theta2, theta3, μ_a_g, σ_a_g = thetas_in\n",
"length_percent_error_all = 0.0\n",
"theta_percent_error_all = 0.1\n",
"a_g_percent_error_all = 0.0\n",
"pos_err = 0.0\n",
"\n",
"time = 0.75\n",
"\n",
"total_length = 1000\n",
"length_df = int(total_length/4) # divide by four because we want the same total size as above\n",
"\n",
"pendulums_per_planet = 100\n",
"\n",
"# and we get four pendulums per iteration of the below\n",
"thetas = np.zeros((total_length, 5))\n",
"# this needs to have the extra 1 so that SBI is happy\n",
"xs = np.zeros((total_length,1))\n",
"#labels = np.zeros((2*length_df, 2))\n",
"#error = []\n",
"#y_noisy = []\n",
"\n",
" \n",
"rs = np.random.RandomState()#2147483648)# \n",
"\n",
"\n",
"lengths_draw = abs(rs.normal(loc=5, scale=2, size = pendulums_per_planet))\n",
"thetas_draw = abs(rs.normal(loc=jnp.pi/100, scale=jnp.pi/500, size = pendulums_per_planet))\n",
"\n",
"μ_a_g = abs(rs.normal(loc=10, scale=2))\n",
"σ_a_g = abs(rs.normal(loc=1, scale=0.5))\n",
"\n",
"\n",
"params_in = [lengths_draw,\n",
" thetas_draw,\n",
" μ_a_g, σ_a_g]\n",
"\n",
"a_gs, xs_out = save_thetas_and_xs_multiple(params_in)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "9714134e-53a2-42df-b254-e942a2a41314",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>length</th>\n",
" <th>theta</th>\n",
" <th>a_g</th>\n",
" <th>time</th>\n",
" <th>pos</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.939540</td>\n",
" <td>0.040323</td>\n",
" <td>10.348500</td>\n",
" <td>0.75</td>\n",
" <td>0.126648</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3.645948</td>\n",
" <td>0.041587</td>\n",
" <td>10.348500</td>\n",
" <td>0.75</td>\n",
" <td>0.046904</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.494573</td>\n",
" <td>0.032681</td>\n",
" <td>10.348500</td>\n",
" <td>0.75</td>\n",
" <td>0.066548</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3.691717</td>\n",
" <td>0.032449</td>\n",
" <td>10.348500</td>\n",
" <td>0.75</td>\n",
" <td>0.036753</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>8.463276</td>\n",
" <td>0.026023</td>\n",
" <td>10.348500</td>\n",
" <td>0.75</td>\n",
" <td>0.124928</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>3.287364</td>\n",
" <td>0.027606</td>\n",
" <td>9.821157</td>\n",
" <td>0.75</td>\n",
" <td>0.027284</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>4.901377</td>\n",
" <td>0.031559</td>\n",
" <td>9.821157</td>\n",
" <td>0.75</td>\n",
" <td>0.080431</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>6.374146</td>\n",
" <td>0.028311</td>\n",
" <td>9.821157</td>\n",
" <td>0.75</td>\n",
" <td>0.102336</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>1.075661</td>\n",
" <td>0.030635</td>\n",
" <td>9.821157</td>\n",
" <td>0.75</td>\n",
" <td>-0.023095</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>5.731387</td>\n",
" <td>0.023705</td>\n",
" <td>9.821157</td>\n",
" <td>0.75</td>\n",
" <td>0.071557</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>100 rows × 5 columns</p>\n",
"</div>"
],
"text/plain": [
" length theta a_g time pos\n",
"0 5.939540 0.040323 10.348500 0.75 0.126648\n",
"1 3.645948 0.041587 10.348500 0.75 0.046904\n",
"2 4.494573 0.032681 10.348500 0.75 0.066548\n",
"3 3.691717 0.032449 10.348500 0.75 0.036753\n",
"4 8.463276 0.026023 10.348500 0.75 0.124928\n",
".. ... ... ... ... ...\n",
"95 3.287364 0.027606 9.821157 0.75 0.027284\n",
"96 4.901377 0.031559 9.821157 0.75 0.080431\n",
"97 6.374146 0.028311 9.821157 0.75 0.102336\n",
"98 1.075661 0.030635 9.821157 0.75 -0.023095\n",
"99 5.731387 0.023705 9.821157 0.75 0.071557\n",
"\n",
"[100 rows x 5 columns]"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# now make it into a dataframe\n",
"data_params = {\n",
" 'length': lengths_draw,\n",
" 'theta': thetas_draw,\n",
" 'a_g': a_gs,\n",
" 'time': np.repeat(time, len(lengths_draw)),\n",
" 'pos': xs_out,\n",
" \n",
"}\n",
"\n",
"## create the DataFrame\n",
"df = pd.DataFrame(data_params)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "44c6292d-fea9-4693-9173-913fd396bbd5",
"metadata": {},
"outputs": [],
"source": [
"# save the dataframe\n",
"filepath = '../data/'\n",
"df.to_csv(filepath+'static_hierarchical_df.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4efa4d17-45d1-4505-a111-13bcc479452a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 823202f

Please sign in to comment.