# Code cells of notebooks/save_dataframe.ipynb, written out as a plain script.

# %% [markdown]
# ## Static dataset creation
# This notebook walks through how to use the modules to create and save a
# static dataset for use in all of the statistical and ML methods. The
# method-focused notebooks show how to import and utilize this static dataset
# in inference.

# %%
import numpy as np
import jax.numpy as jnp
from deepbench.physics_object import Pendulum
import pandas as pd


# %%
def save_thetas_and_xs_multiple(params_in, *, rng=None, time=0.75, noise_std_percent=None):
    """Simulate one noisy observation per pendulum, split across two a_g draws.

    Two values of the acceleration due to gravity are drawn from a normal
    distribution; the first half of the pendulums is assigned the first draw
    and the remaining pendulums the second draw. One ``Pendulum`` is then
    built per (length, starting angle, a_g) triple and ``create_object`` is
    called on it with ``noiseless=False``.

    Parameters
    ----------
    params_in : sequence of 4 items
        ``(lengths, thetas, μ_a_g, σ_a_g)``: per-pendulum arm lengths,
        per-pendulum starting angles, and the mean and standard deviation of
        the normal distribution the two a_g values are drawn from.
        ``thetas`` must have at least ``len(lengths)`` entries.
    rng : object with a ``normal(loc=, scale=)`` method, optional
        Random generator used for the two a_g draws. When omitted, the
        notebook-level ``rs`` is used, which must already be defined.
    time : float, optional
        First positional argument handed to ``Pendulum.create_object``
        (presumably the observation time -- confirm against deepbench).
    noise_std_percent : dict, optional
        Passed to ``Pendulum`` as ``noise_std_percent``. Defaults to 0.0 for
        the arm length, 0.1 for the starting angle and 0.0 for a_g.

    Returns
    -------
    ags : numpy.ndarray
        The a_g value assigned to each pendulum, ``len(lengths)`` entries.
    xs : list
        The value returned by ``create_object`` for each pendulum, in order.
    """
    lengths, thetas, mu_a_g, sigma_a_g = params_in

    if rng is None:
        # Backward-compatible fallback: the notebook creates `rs` in a later
        # cell, so that cell has to run before the first call.
        rng = rs
    if noise_std_percent is None:
        noise_std_percent = {
            "pendulum_arm_length": 0.0,
            "starting_angle_radians": 0.1,
            "acceleration_due_to_gravity": 0.0,
        }

    # Draw order (first a_g, then second) is kept so a seeded generator
    # yields the same values as before.
    ag0 = rng.normal(loc=mu_a_g, scale=sigma_a_g)
    ag1 = rng.normal(loc=mu_a_g, scale=sigma_a_g)

    # The second group takes the remainder so that an odd number of pendulums
    # still gets one a_g per pendulum (previously the array came up one short
    # and indexing the last pendulum raised IndexError).
    n_pendulums = len(lengths)
    n_first = n_pendulums // 2
    ags = np.concatenate(
        [np.repeat(ag0, n_first), np.repeat(ag1, n_pendulums - n_first)]
    )

    xs = []
    for i, length in enumerate(lengths):
        pendulum = Pendulum(
            pendulum_arm_length=float(length),
            starting_angle_radians=float(thetas[i]),
            acceleration_due_to_gravity=float(ags[i]),
            # Fresh copy per pendulum, mirroring the per-iteration literal
            # the notebook used.
            noise_std_percent=dict(noise_std_percent),
        )
        xs.append(pendulum.create_object(time, noiseless=False))
    return ags, xs


# %%
# Noise levels handed to Pendulum via `noise_std_percent`.
length_percent_error_all = 0.0
theta_percent_error_all = 0.1
a_g_percent_error_all = 0.0
# NOTE(review): `pos_err` is not used in the cells visible here.
pos_err = 0.0

time = 0.75

# NOTE(review): `total_length`, `length_df`, `thetas` and `xs` are not used in
# the cells visible here; the original comments on them describe a layout with
# four pendulums per iteration. Confirm whether later cells still need them.
total_length = 1000
length_df = int(total_length / 4)

pendulums_per_planet = 100

thetas = np.zeros((total_length, 5))
# this needs to have the extra 1 so that SBI is happy
xs = np.zeros((total_length, 1))

# Unseeded on purpose in the original (a seed was left commented out), so each
# run produces a different dataset; pass an integer seed for reproducibility.
rs = np.random.RandomState()

# abs() keeps the drawn lengths and starting angles non-negative.
lengths_draw = abs(rs.normal(loc=5, scale=2, size=pendulums_per_planet))
thetas_draw = abs(
    rs.normal(loc=jnp.pi / 100, scale=jnp.pi / 500, size=pendulums_per_planet)
)

# Mean and standard deviation of the a_g distribution.
μ_a_g = abs(rs.normal(loc=10, scale=2))
σ_a_g = abs(rs.normal(loc=1, scale=0.5))

params_in = [lengths_draw, thetas_draw, μ_a_g, σ_a_g]

a_gs, xs_out = save_thetas_and_xs_multiple(
    params_in,
    rng=rs,
    time=time,
    noise_std_percent={
        "pendulum_arm_length": length_percent_error_all,
        "starting_angle_radians": theta_percent_error_all,
        "acceleration_due_to_gravity": a_g_percent_error_all,
    },
)

# %%
# The next notebook cell (execution_count 50) is only partially visible here;
# what follows is its rendered table output (columns length, theta, a_g, time,
# pos), not code.
\n", + " | length | \n", + "theta | \n", + "a_g | \n", + "time | \n", + "pos | \n", + "
---|---|---|---|---|---|
0 | \n", + "5.939540 | \n", + "0.040323 | \n", + "10.348500 | \n", + "0.75 | \n", + "0.126648 | \n", + "
1 | \n", + "3.645948 | \n", + "0.041587 | \n", + "10.348500 | \n", + "0.75 | \n", + "0.046904 | \n", + "
2 | \n", + "4.494573 | \n", + "0.032681 | \n", + "10.348500 | \n", + "0.75 | \n", + "0.066548 | \n", + "
3 | \n", + "3.691717 | \n", + "0.032449 | \n", + "10.348500 | \n", + "0.75 | \n", + "0.036753 | \n", + "
4 | \n", + "8.463276 | \n", + "0.026023 | \n", + "10.348500 | \n", + "0.75 | \n", + "0.124928 | \n", + "
... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "
95 | \n", + "3.287364 | \n", + "0.027606 | \n", + "9.821157 | \n", + "0.75 | \n", + "0.027284 | \n", + "
96 | \n", + "4.901377 | \n", + "0.031559 | \n", + "9.821157 | \n", + "0.75 | \n", + "0.080431 | \n", + "
97 | \n", + "6.374146 | \n", + "0.028311 | \n", + "9.821157 | \n", + "0.75 | \n", + "0.102336 | \n", + "
98 | \n", + "1.075661 | \n", + "0.030635 | \n", + "9.821157 | \n", + "0.75 | \n", + "-0.023095 | \n", + "
99 | \n", + "5.731387 | \n", + "0.023705 | \n", + "9.821157 | \n", + "0.75 | \n", + "0.071557 | \n", + "
100 rows × 5 columns
\n", + "