From e76f0f19b8eb4efa785ba41173aa82f4e8d1f7fc Mon Sep 17 00:00:00 2001 From: beckynevin Date: Wed, 11 Oct 2023 15:50:53 -0600 Subject: [PATCH] adding hierarchical SBI to train.py --- notebooks/SBI_hierarchical_csv.ipynb | 59 +++++++++++++--------------- src/scripts/train.py | 23 +++++++++++ 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/notebooks/SBI_hierarchical_csv.ipynb b/notebooks/SBI_hierarchical_csv.ipynb index ecdbcfe..e600554 100644 --- a/notebooks/SBI_hierarchical_csv.ipynb +++ b/notebooks/SBI_hierarchical_csv.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "5e8c8f57", "metadata": {}, "outputs": [], @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "4a76432d", "metadata": {}, "outputs": [ @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "7bfe9a7f", "metadata": {}, "outputs": [], @@ -71,6 +71,19 @@ "import torch" ] }, + { + "cell_type": "code", + "execution_count": 15, + "id": "d24587bc-5777-41bb-b3b3-d1d63b7325c7", + "metadata": {}, + "outputs": [], + "source": [ + "# this is necessary to import modules from this repo\n", + "import sys\n", + "sys.path.append('..')\n", + "from src.scripts import models, utils, train" + ] + }, { "cell_type": "markdown", "id": "d478548e", @@ -82,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "id": "6ff66af4-8bf4-4019-8c1a-8f0ef4428137", "metadata": {}, "outputs": [], @@ -117,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, "id": "37caacc2-d7ea-407e-9583-f94643976037", "metadata": {}, "outputs": [ @@ -142,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "id": "6e5e4866-6cf1-40b9-93db-303988cbf8e2", "metadata": {}, "outputs": [], @@ -152,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 8, "id": "f7baa131-951c-4a69-a7c9-6f433afe6302", "metadata": {}, "outputs": [], @@ -168,6 +181,7 @@ "\n", "length_df = 1000\n", "thetas = np.zeros((length_df, 5))\n", + "# this needs to have the extra 1 so that SBI is happy\n", "xs = np.zeros((length_df,1))\n", "#labels = np.zeros((2*length_df, 2))\n", "#error = []\n", @@ -202,13 +216,13 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 9, "id": "d1f15262-4b67-468a-8fd7-8f6578111f04", "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiaUlEQVR4nO3de3DU5f238XcOJJyyGwIkm2gIIEJAORUkbFWkkpKEeI5TQWrBYaBisKMRhVgEsR0TkamMDsjUqaAdkUpHsUKh0iDgIQRJoRzNABMLFDYgmWQJSjjkfv7ow85vJYCb7LJ3wvWa2TH57r1f7k/XbS43u0uEMcYIAADAIpHh3gAAAMAPESgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArBMd7g00RUNDg44cOaK4uDhFRESEezsAAOBHMMbo5MmTSklJUWTk5Z8jaZGBcuTIEaWmpoZ7GwAAoAkOHTqk66+//rJrWmSgxMXFSfrfgA6HI8y7AQAAP4bX61Vqaqrv5/jltMhAufBrHYfDQaAAANDC/JiXZ/AiWQAAYJ2AAqWoqEi33HKL4uLilJiYqPvuu08VFRV+a0aOHKmIiAi/y2OPPea35uDBg8rNzVX79u2VmJioZ555RufOnWv+NAAAoFUI6Fc8GzduVH5+vm655RadO3dOzz33nEaPHq09e/aoQ4cOvnWTJ0/Wiy++6Pu+ffv2vq/Pnz+v3NxcuVwuffnllzp69Kh+9atfqU2bNnrppZeCMBIAAGjpIowxpqk3Pn78uBITE7Vx40aNGDFC0v+eQRk0aJAWLFjQ6G3WrFmju+66S0eOHFFSUpIkafHixZoxY4aOHz+umJiYK/65Xq9XTqdTtbW1vAYFAIAWIpCf3816DUptba0kKSEhwe/4u+++qy5duujmm29WYWGhvvvuO991paWl6t+/vy9OJCkrK0ter1e7d+9uznYAAEAr0eR38TQ0NOjJJ5/Urbfeqptvvtl3/OGHH1ZaWppSUlK0Y8cOzZgxQxUVFfrggw8kSR6Pxy9OJPm+93g8jf5Z9fX1qq+v933v9Xqbum0AANACNDlQ8vPztWvXLn3++ed+x6dMmeL7un///kpOTtaoUaN04MAB3XDDDU36s4qKijR37tymbhUAALQwTfoVz7Rp07Rq1Sp9+umnV/wkuIyMDEnS/v37JUkul0tVVVV+ay5873K5Gj1HYWGhamtrfZdDhw41ZdsAAKCFCChQjDGaNm2aPvzwQ61fv149evS44m22b98uSUpOTpYkud1u7dy5U8eOHfOtWbdunRwOh/r169foOWJjY30fysaHswEA0PoF9Cue/Px8LVu2TB999JHi4uJ8rxlxOp1q166dDhw4oGXLlmnMmDHq3LmzduzYoaeeekojRozQgAEDJEmjR49Wv3799Mgjj2jevHnyeDyaNWuW8vPzFRsbG/wJAQBAixPQ24wv9dG0S5Ys0cSJE3Xo0CH98pe/1K5du3Tq1Cmlpqbq/vvv16xZs/ye9fjPf/6jqVOnasOGDerQoYMmTJig4uJiRUf/uF7ibcYAALQ8gfz8btbnoIQLgQIAQMtz1T4HBQAAIBQIFAAAYJ0mfw4K0FzdZ64O9xYC9k1xbri3AADXBJ5BAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWCc63BsAWpLuM1eHewsB+6Y4N9xbAICA8QwKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArBNQoBQVFemWW25RXFycEhMTdd9996miosJvzenTp5Wfn6/OnTurY8eOysvLU1VVld+agwcPKjc3V+3bt1diYqKeeeYZnTt3rvnTAACAViGgQNm4caPy8/O1efNmrVu3TmfPntXo0aN16tQp35qnnnpKH3/8sVasWKGNGzfqyJEjeuCBB3zXnz9/Xrm5uTpz5oy+/PJLvf3221q6dKlmz54dvKkAAECLFmGMMU298fHjx5WYmKiNGzdqxIgRqq2tVdeuXbVs2TI9+OCDkqSvv/5affv2VWlpqYYPH641a9borrvu0pEjR5SUlCRJWrx4sWbMmKHjx48rJibmin+u1+uV0+lUbW2tHA5HU7ePMOs+c3W4t3BN+KY4N9xbAABJgf38btZrUGprayVJCQkJkqTy8nKdPXtWmZmZvjXp6enq1q2bSktLJUmlpaXq37+/L04kKSsrS16vV7t37270z6mvr5fX6/W7AACA1qvJgdLQ0KAnn3xSt956q26++WZJksfjUUxMjOLj4/3WJiUlyePx+Nb83zi5cP2F6xpTVFQkp9Ppu6SmpjZ12wAAoAVocqDk5+dr165dWr58eTD306jCwkLV1tb6LocOHQr5nwkAAMInuik3mjZtmlatWqVNmzbp+uuv9x13uVw6c+aMampq/J5Fqaqqksvl8q3ZsmWL3/kuvMvnwpofio2NVWxsbFO2CgAAWqCAnkExxmjatGn68MMPtX79evXo0cPv+iFDhqhNmzYqKSnxHauoqNDBgwfldrslSW63Wzt37tSxY8d8a9atWyeHw6F+/fo1ZxYAANBKBPQMSn5+vpYtW6aPPvpIcXFxvteMOJ1OtWvXTk6nU5MmTVJBQYESEhLkcDj0xBNPyO12a/jw4ZKk0aNHq1+/fnrkkUc0b948eTwezZo1S/n5+TxLAgAAJAUYKG+88YYkaeTIkX7HlyxZookTJ0qSXn31VUVGRiovL0/19fXKysrSokWLfGujoqK0atUqTZ06VW63Wx06dNCECRP04osvNm8SAADQajTrc1DChc9BaR34HJSrg89BAWCLq/Y5KAAAAKFAoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwTnS4N4Dg6D5zdbi3AABA0PAMCgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKwTHe4NAAit7jNXh3sLAfumODfcWwAQZjyDAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsE3CgbNq0SXfffbdSUlIUERGhlStX+l0/ceJERURE+F2ys7P91lRXV2v8+PFyOByKj4/XpEmTVFdX16xBAABA6xFwoJw6dUoDBw7UwoULL7kmOztbR48e9V3ee+89v+vHjx+v3bt3a926dVq1apU2bdqkKVOmBL57AADQKgX8OSg5OTnKycm57JrY2Fi5XK5Gr9u7d6/Wrl2rr776SkOHDpUkvf766xozZozmz5+vlJSUQLcEAABamZC8BmXDhg1KTExUnz59NHXqVJ04ccJ3XWlpqeLj431xIkmZmZmKjIxUWVlZo+err6+X1+v1uwAAgNYr6IGSnZ2td955RyUlJXr55Ze1ceNG5eTk6Pz585Ikj8ejxMREv9tER0crISFBHo+n0XMWFRXJ6XT6LqmpqcHeNgAAsEjQP+p+7Nixvq/79++vAQMG6IYbbtCGDRs0atSoJp2zsLBQBQUFvu+9Xi+RAgBAKxbytxn37NlTXbp00f79+yVJLpdLx44d81tz7tw5VVdXX/J1K7GxsXI4HH4XAADQeoU8UA4fPqwTJ04oOTlZkuR2u1VTU6Py8nLfmvXr16uhoUEZGRmh3g4AAGgBAv4VT11dne/ZEEmqrKzU9u3blZCQoISEBM2dO1d5eXlyuVw6cOCAnn32WfXq1UtZWVmSpL59+yo7O1uTJ0/W4sWLdfbsWU2bNk1jx47lHTwAAEBSE55B2bp1qwYPHqzBgwdLkgoKCjR48GDNnj1bUVFR2rFjh+655x717t1bkyZN0pAhQ/TZZ58pNjbWd453331X6enpGjVqlMaMGaPbbrtNf/zjH4M3FQAAaNECfgZl5MiRMsZc8vp//OMfVzxHQkKCli1bFugfDQAArhH8XTwAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6AQfKpk2bdPfddyslJUURERFauXKl3/XGGM2ePVvJyclq166dMjMztW/fPr811dXVGj9+vBwOh+Lj4zVp0iTV1dU1axAAANB6BBwop06d0sCBA7Vw4cJGr583b55ee+01LV68WGVlZerQoYOysrJ0+vRp35rx48dr9+7dWrdunVatWqVNmzZpypQpTZ8CAAC0KtGB3iAnJ0c5OTmNXmeM0YIFCzRr1izde++9kqR33nlHSUlJWrlypcaOHau9e/dq7dq1+uqrrzR06FBJ0uuvv64xY8Zo/vz5SklJacY4AACgNQjqa1AqKyvl8XiUmZnpO+Z0OpWRkaHS0lJJUmlpqeLj431xIkmZmZmKjIxUWVlZo+etr6+X1+v1uwAAgNYrqIHi8XgkSUlJSX7Hk5KSfNd5PB4lJib6XR8dHa2EhATfmh8qKiqS0+n0XVJTU4O5bQAAYJkW8S6ewsJC1dbW+i6HDh0K95YAAEAIBTVQXC6XJKmqqsrveFVVle86l8ulY8eO+V1/7tw5VVdX+9b8UGxsrBwOh98FAAC0XkENlB49esjlcqmkpMR3zOv1qqysTG63W5LkdrtVU1Oj8vJy35r169eroaFBGRkZwdwOAABooQJ+F09dXZ3279/v+76yslLbt29XQkKCunXrpieffFK///3vdeONN6pHjx56/vnnlZKSovvuu0+S1LdvX2VnZ2vy5MlavHixzp49q2nTpmns2LG8gwcAAEhqQqBs3bpVP/vZz3zfFxQUSJImTJigpUuX6tlnn9WpU6c0ZcoU1dTU6LbbbtPatWvVtm1b323effddTZs2TaNGjVJkZKTy8vL02muvBWEcAADQGkQYY0y4NxEor9crp9Op2tpaXo/y/3WfuTrcWwCC5pvi3HBvAUAIBPLzu0W8iwcAAFxbCBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgnehwbwAAfqj7zNXh3kLAvinODfcWgFaFZ1AAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1osO9ARt1n7k63FsAAOCaxjMoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6wQ9UF544QVFRET4XdLT033Xnz59Wvn5+ercubM6duyovLw8VVVVBXsbAACgBQvJMyg33XSTjh496rt8/vnnvuueeuopffzxx1qxYoU2btyoI0eO6IEHHgjFNgAAQAsVHZKTRkfL5XJddLy2tlZ/+tOftGzZMt15552SpCVLlqhv377avHmzhg8fHortAACAFiYkz6Ds27dPKSkp6tmzp8aPH6+DBw9KksrLy3X27FllZmb61qanp6tbt24qLS295Pnq6+vl9Xr9LgAAoPUKeqBkZGRo6dKlWrt2rd544w1VVlbq9ttv18mTJ+XxeBQTE6P4+Hi/2yQlJcnj8VzynEVFRXI6nb5LampqsLcNAAAsEvRf8eTk5Pi+HjBggDIyMpSWlqb3339f7dq1a9I5CwsLVVBQ4Pve6/USKQAAtGIhf5txfHy8evfurf3798vlcunMmTOqqanxW1NVVdXoa1YuiI2NlcPh8LsAAIDWK+SBUldXpwMHDig5OVlDhgxRmzZtVFJS4ru+oqJCBw8elNvtDvVWAABACxH0X/FMnz5dd999t9LS0nTkyBHNmTNHUVFRGjdunJxOpyZNmqSCggIlJCTI4XDoiSeekNvt5h08AADAJ+iBcvjwYY0bN04nTpxQ165dddttt2nz5s3q2rWrJOnVV19VZGSk8vLyVF9fr6ysLC1atCjY2wAAAC1YhDHGhHsTgfJ6vXI6naqtrQ3J61G6z1wd9HMCaN2+Kc4N9xYA6wXy85u/iwcAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYJzrcGwCA1qD7zNXh3kLAvinODfcWgEviGRQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYJzrcGwAAhEf3mavDvYWAfVOcG+4t4CrhGRQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh4+6BwC0GHw8/7WDZ1AAAIB1CBQAAGAdAgUAAFiHQAEAANbhRbIAAIRQS3xhrxT+F/eG9RmUhQsXqnv37mrbtq0yMjK0ZcuWcG4HAABYImyB8pe//EUFBQWaM2eO/vWvf2ngwIHKysrSsWPHwrUlAABgibAFyh/+8AdNnjxZjz76qPr166fFixerffv2euutt8K1JQAAYImwvAblzJkzKi8vV2Fhoe9YZGSkMjMzVVpaetH6+vp61dfX+76vra2VJHm93pDsr6H+u5CcFwCAliIUP2MvnNMYc8W1YQmUb7/9VufPn1dSUpLf8aSkJH399dcXrS8qKtLcuXMvOp6amhqyPQIAcC1zLgjduU+ePCmn03nZNS3iXTyFhYUqKCjwfd/Q0KDq6mp17txZERERzT6/1+tVamqqDh06JIfD0ezz2epamJMZW4drYUbp2piTGVuPYMxpjNHJkyeVkpJyxbVhCZQuXbooKipKVVVVfserqqrkcrkuWh8bG6vY2Fi/Y/Hx8UHfl8PhaNX/cl1wLczJjK3DtTCjdG3MyYytR3PnvNIzJxeE5UWyMTExGjJkiEpKSnzHGhoaVFJSIrfbHY4tAQAAi4TtVzwFBQWaMGGChg4dqmHDhmnBggU6deqUHn300XBtCQAAWCJsgfLQQw/p+PHjmj17tjwejwYNGqS1a9de9MLZqyE2NlZz5sy56NdIrc21MCcztg7XwozStTEnM7YeV3vOCPNj3usDAABwFfGXBQIAAOsQKAAAwDoECgAAsA6BAgAArHPNBEp1dbXGjx8vh8Oh+Ph4TZo0SXV1dZe9zenTp5Wfn6/OnTurY8eOysvLu+jD5SIiIi66LF++PJSj+CxcuFDdu3dX27ZtlZGRoS1btlx2/YoVK5Senq62bduqf//++vvf/+53vTFGs2fPVnJystq1a6fMzEzt27cvlCNcUbBnnDhx4kX3V3Z2dihH+FECmXP37t3Ky8tT9+7dFRERoQULFjT7nFdDsGd84YUXLrov09PTQzjBlQUy45tvvqnbb79dnTp1UqdOnZSZmXnRehsfk1Lw57TxcRnIjB988IGGDh2q+Ph4dejQQYMGDdKf//xnvzU23pfBnjHo96O5RmRnZ5uBAweazZs3m88++8z06tXLjBs37rK3eeyxx0xqaqopKSkxW7duNcOHDzc//elP/dZIMkuWLDFHjx71Xb7//vtQjmKMMWb58uUmJibGvPXWW2b37t1m8uTJJj4+3lRVVTW6/osvvjBRUVFm3rx5Zs+ePWbWrFmmTZs2ZufOnb41xcXFxul0mpUrV5p///vf5p577jE9evS4KvM0JhQzTpgwwWRnZ/vdX9XV1VdrpEYFOueWLVvM9OnTzXvvvWdcLpd59dVXm33OUAvFjHPmzDE33XST3315/PjxEE9yaYHO+PDDD5uFCxeabdu2mb1795qJEycap9NpDh8+7Ftj22PSmNDMadvjMtAZP/30U/PBBx+YPXv2mP3795sFCxaYqKgos3btWt8a2+7LUMwY7PvxmgiUPXv2GEnmq6++8h1bs2aNiYiIMP/9738bvU1NTY1p06aNWbFihe/Y3r17jSRTWlrqOybJfPjhhyHb+6UMGzbM5Ofn+74/f/68SUlJMUVFRY2u/8UvfmFyc3P9jmVkZJhf//rXxhhjGhoajMvlMq+88orv+pqaGhMbG2vee++9EExwZcGe0Zj/PYDuvffekOy3qQKd8/9KS0tr9Id3c84ZCqGYcc6cOWbgwIFB3GXzNPd/83Pnzpm4uDjz9ttvG2PsfEwaE/w5jbHvcRmMx8/gwYPNrFmzjDF23pfBntGY4N+P18SveEpLSxUfH6+hQ4f6jmVmZioyMlJlZWWN3qa8vFxnz55VZmam71h6erq6deum0tJSv7X5+fnq0qWLhg0bprfeeutH/TXSzXHmzBmVl5f77S0yMlKZmZkX7e2C0tJSv/WSlJWV5VtfWVkpj8fjt8bpdCojI+OS5wylUMx4wYYNG5SYmKg+ffpo6tSpOnHiRPAH+JGaMmc4ztkcodzPvn37lJKSop49e2r8+PE6ePBgc7fbJMGY8bvvvtPZs2eVkJAgyb7HpBSaOS+w5XHZ3BmNMSopKVFFRYVGjBghyb77MhQzXhDM+7FF/G3GzeXxeJSYmOh3LDo6WgkJCfJ4PJe8TUxMzEV/KWFSUpLfbV588UXdeeedat++vT755BM9/vjjqqur029+85ugz3HBt99+q/Pnz1/0qbtJSUn6+uuvG72Nx+NpdP2FWS7883JrrqZQzChJ2dnZeuCBB9SjRw8dOHBAzz33nHJyclRaWqqoqKjgD3IFTZkzHOdsjlDtJyMjQ0uXLlWfPn109OhRzZ07V7fffrt27dqluLi45m47IMGYccaMGUpJSfH90LDtMSmFZk7JrsdlU2esra3Vddddp/r6ekVFRWnRokX6+c9/Lsm++zIUM0rBvx9bdKDMnDlTL7/88mXX7N27N6R7eP75531fDx48WKdOndIrr7wS0kBB040dO9b3df/+/TVgwADdcMMN2rBhg0aNGhXGnSFQOTk5vq8HDBigjIwMpaWl6f3339ekSZPCuLPAFRcXa/ny5dqwYYPatm0b7u2EzKXmbA2Py7i4OG3fvl11dXUqKSlRQUGBevbsqZEjR4Z7a0FzpRmDfT+26F/xPP3009q7d+9lLz179pTL5dKxY8f8bnvu3DlVV1fL5XI1em6Xy6UzZ86opqbG73hVVdUlbyP977/qDh8+rPr6+mbPdyldunRRVFTURe8outzeXC7XZddf+Gcg5wylUMzYmJ49e6pLly7av39/8zfdBE2ZMxznbI6rtZ/4+Hj17t07LPdlc2acP3++iouL9cknn2jAgAG+47Y9JqXQzNmYcD4umzpjZGSkevXqpUGDBunpp5/Wgw8+qKKiIkn23ZehmLExzb0fW3SgdO3aVenp6Ze9xMTEyO12q6amRuXl5b7brl+/Xg0NDcrIyGj03EOGDFGbNm1UUlLiO1ZRUaGDBw/K7XZfck/bt29Xp06dQvqXKcXExGjIkCF+e2toaFBJSckl9+Z2u/3WS9K6det863v06CGXy+W3xuv1qqys7LLzhkooZmzM4cOHdeLECSUnJwdn4wFqypzhOGdzXK391NXV6cCBA2G5L5s647x58/S73/1Oa9eu9XuNnGTfY1IKzZyNCefjMlj/vjY0NPj+Q9W2+zIUMzam2fdj0F5ua7ns7GwzePBgU1ZWZj7//HNz4403+r3N+PDhw6ZPnz6mrKzMd+yxxx4z3bp1M+vXrzdbt241brfbuN1u3/V/+9vfzJtvvml27txp9u3bZxYtWmTat29vZs+eHfJ5li9fbmJjY83SpUvNnj17zJQpU0x8fLzxeDzGGGMeeeQRM3PmTN/6L774wkRHR5v58+ebvXv3mjlz5jT6NuP4+Hjz0UcfmR07dph777037G+DC+aMJ0+eNNOnTzelpaWmsrLS/POf/zQ/+clPzI033mhOnz4dlhmNCXzO+vp6s23bNrNt2zaTnJxspk+fbrZt22b27dv3o895tYVixqefftps2LDBVFZWmi+++MJkZmaaLl26mGPHjl31+YwJfMbi4mITExNj/vrXv/q9LfPkyZN+a2x6TBoT/DltfFwGOuNLL71kPvnkE3PgwAGzZ88eM3/+fBMdHW3efPNN3xrb7stgzxiK+/GaCZQTJ06YcePGmY4dOxqHw2EeffRRv/8jqKysNJLMp59+6jv2/fffm8cff9x06tTJtG/f3tx///3m6NGjvuvXrFljBg0aZDp27Gg6dOhgBg4caBYvXmzOnz9/VWZ6/fXXTbdu3UxMTIwZNmyY2bx5s++6O+64w0yYMMFv/fvvv2969+5tYmJizE033WRWr17td31DQ4N5/vnnTVJSkomNjTWjRo0yFRUVV2OUSwrmjN99950ZPXq06dq1q2nTpo1JS0szkydPDtsP7f8rkDkv/Lv6w8sdd9zxo88ZDsGe8aGHHjLJyckmJibGXHfddeahhx4y+/fvv4oTXSyQGdPS0hqdcc6cOb41Nj4mjQnunLY+LgOZ8be//a3p1auXadu2renUqZNxu91m+fLlfuez8b4M5oyhuB8jjAnxe2IBAAAC1KJfgwIAAFonAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1/h8st7+O1N1AGwAAAABJRU5ErkJggg==", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAfS0lEQVR4nO3de3BU9f3/8VcuJFw3S4BkiYabCEHLrVDCtt4qqSRSLyVOhTIOOAxUDXY0UgWLILbTUGQqIwMydarYjkilo1ihUmkQULtESaEqIANMHKCwAckkISjhks/vj/7Y+a4EZMMu+97wfMyckZzz2cPn03Wbp2dvSc45JwAAAEOS4z0BAACAbyJQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYE5qvCfQEk1NTTp48KA6deqkpKSkeE8HAABcBOecjh07ppycHCUnX/gaSUIGysGDB5WbmxvvaQAAgBbYv3+/rr766guOSchA6dSpk6T/LdDj8cR5NgAA4GLU19crNzc39Hv8QhIyUM4+rePxeAgUAAASzMW8PIMXyQIAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmpMZ7Arhy9ZqxJt5TiNgX88bEewoAcEXgCgoAADCHQAEAAOYQKAAAwBwCBQAAmEOgAAAAcwgUAABgDoECAADMIVAAAIA5BAoAADCHQAEAAObwUfdABPh4fgC4PLiCAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzIkoUMrKyvS9731PnTp1UlZWlu6++27t2rUrbMyJEydUUlKiLl26qGPHjiouLlZ1dXXYmH379mnMmDFq3769srKy9Mtf/lKnT5++9NUAAIBWIaJA2bhxo0pKSrR582atW7dOp06d0m233abjx4+Hxjz66KN6++23tXLlSm3cuFEHDx7U2LFjQ8fPnDmjMWPG6OTJk/rXv/6lV155RcuWLdPs2bOjtyoAAJDQkpxzrqU3PnLkiLKysrRx40bddNNNqqurU7du3bR8+XLdc889kqTPP/9cAwYMUCAQ0MiRI/XOO+/oxz/+sQ4ePKjs7GxJ0tKlS/XEE0/oyJEjSktL+9a/t76+XhkZGaqrq5PH42np9BFnvWasifcUrghfzBsT7ykAgKTIfn9f0mtQ6urqJEmZmZmSpMrKSp06dUoFBQWhMXl5eerRo4cCgYAkKRAIaODAgaE4kaTRo0ervr5e27dvb/bvaWxsVH19fdgGAABarxYHSlNTkx555BH94Ac/0He+8x1JUjAYVFpamrxeb9jY7OxsBYPB0Jj/Gydnj5891pyysjJlZGSEttzc3JZOGwAAJIAWB0pJSYk+++wzrVixIprzadbMmTNVV1cX2vbv3x/zvxMAAMRPaktuNG3aNK1evVqbNm3S1VdfHdrv8/l08uRJ1dbWhl1Fqa6uls/nC4356KOPws539l0+Z8d8U3p6utLT01syVQAAkIAiuoLinNO0adP05ptvav369erdu3fY8WHDhqlNmzYqLy8P7du1a5f27dsnv98vSfL7/fr00091+PDh0Jh169bJ4/Houuuuu5S1AACAViKiKyglJSVavny53nrrLXXq1Cn0mpGMjAy1a9dOGRkZmjx5skpLS5WZmSmPx6OHH35Yfr9fI0eOlCTddtttuu6663Tfffdp/vz5CgaDmjVrlkpKSrhKAgAAJEUYKC+88IIk6ZZbbgnb//LLL2vSpEmSpOeee07JyckqLi5WY2OjRo8erSVLloTGpqSkaPXq1XrwwQfl9/vVoUMHTZw4Uc8888ylrQQAALQal/Q5KPHC56C0DnwOyuXB56AAsOKyfQ4KAABALBAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDmp8Z4AoqPXjDXxngIAAFHDFRQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOnyQLtHKJ+CnDX8wbE+8pAIgzrqAAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYE3GgbNq0SXfccYdycnKUlJSkVatWhR2fNGmSkpKSwrbCwsKwMTU1NZowYYI8Ho+8Xq8mT56shoaGS1oIAABoPSIOlOPHj2vw4MFavHjxeccUFhbq0KFDoe21114LOz5hwgRt375d69at0+rVq7Vp0yZNnTo18tkDAIBWKTXSGxQVFamoqOiCY9LT0+Xz+Zo9tnPnTq1du1Yff/yxhg8fLklatGiRbr/9di1YsEA5OTmRTgkAALQyMXkNyoYNG5SVlaX+/fvrwQcf1NGjR0PHAoGAvF5vKE4kqaCgQMnJyaqoqIjFdAAAQIKJ+ArKtyksLNTYsWPVu3dv7d27V08++aSKiooUCASUkpKiYDCorKys8EmkpiozM1PBYLDZczY2NqqxsTH0c319fbSnDQAADIl6oIwbNy7054EDB2rQoEG65pprtGHDBo0aNapF5ywrK9PcuXOjNUUAAGBczN9m3KdPH3Xt2lV79uyRJPl8Ph0+fDhszOnTp1VTU3Pe163MnDlTdXV1oW3//v2xnjYAAIijmAfKgQMHdPToUXXv3l2S5Pf7VVtbq8rKytCY9evXq6mpSfn5+c2eIz09XR6PJ2wDAACtV8RP8TQ0NISuhkhSVVWVtm3bpszMTGVmZmru3LkqLi6Wz+fT3r179fjjj6tv374aPXq0JGnAgAEqLCzUlClTtHTpUp06dUrTpk3TuHHjeAcPAACQ1IIrKFu2bNHQoUM1dOhQSVJpaamGDh2q2bNnKyUlRZ988onuvPNO9evXT5MnT9awYcP0/vvvKz09PXSOV199VXl5eRo1apRuv/123XDDDfrDH/4QvVUBAICEFvEVlFtuuUXOufMe/8c//vGt58jMzNTy5csj/asBAMAVgu/iAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYE3GgbNq0SXfccYdycnKUlJSkVatWhR13zmn27Nnq3r272rVrp4KCAu3evTtsTE1NjSZMmCCPxyOv16vJkyeroaHhkhYCAABaj4gD5fjx4xo8eLAWL17c7PH58+fr+eef19KlS1VRUaEOHTpo9OjROnHiRGjMhAkTtH37dq1bt06rV6/Wpk2bNHXq1JavAgAAtCqpkd6gqKhIRUVFzR5zzmnhwoWaNWuW7rrrLknSn/70J2VnZ2vVqlUaN26cdu7cqbVr1+rjjz/W8OHDJUmLFi3S7bffrgULFignJ+cSlgMAAFqDqL4GpaqqSsFgUAUFBaF9GRkZys/PVyAQkCQFAgF5vd5QnEhSQUGBkpOTVVFREc3pAACABBXxFZQLCQaDkqTs7Oyw/dnZ2aFjwWBQWVlZ4ZNITVVmZmZozDc1NjaqsbEx9HN9fX00pw0AAIxJiHfxlJWVKSMjI7Tl5ubGe0oAACCGohooPp9PklRdXR22v7q6OnTM5/Pp8OHDYcdPnz6tmpqa0Jhvmjlzpurq6kLb/v37ozltAABgTFQDpXfv3vL5fCovLw/tq6+vV0VFhfx+vyTJ7/ertrZWlZWVoTHr169XU1OT8vPzmz1venq6PB5P2AYAAFqviF+D0tDQoD179oR+rqqq0rZt25SZmakePXrokUce0W9+8xtde+216t27t5566inl5OTo7rvvliQNGDBAhYWFmjJlipYuXapTp05p2rRpGjduHO/gAQAAkloQKFu2bNEPf/jD0M+lpaWSpIkTJ2rZsmV6/PHHdfz4cU2dOlW1tbW64YYbtHbtWrVt2zZ0m1dffVXTpk3TqFGjlJycrOLiYj3//PNRWA4AAGgNkpxzLt6TiFR9fb0yMjJUV1fH0z3/X68Za+I9BSBqvpg3Jt5TABADkfz+Toh38QAAgCsLgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAnKh+WSAAREMivm2et0YD0cUVFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMCc1HhPwKJeM9bEewoAAFzRuIICAADMIVAAAIA5BAoAADCHQAEAAOYQKAAAwBwCBQAAmEOgAAAAcwgUAABgDoECAADMIVAAAIA5BAoAADCHQAEAAOYQKAAAwBwCBQAAmEOgAAAAcwgUAABgDoECAADMIVAAAIA5BAoAADCHQAEAAOYQKAAAwJyoB8rTTz+tpKSksC0vLy90/MSJEyopKVGXLl3UsWNHFRcXq7q6OtrTAAAACSwmV1Cuv/56HTp0KLR98MEHoWOPPvqo3n77ba1cuVIbN27UwYMHNXbs2FhMAwAAJKjUmJw0NVU+n++c/XV1dfrjH/+o5cuX69Zbb5UkvfzyyxowYIA2b96skSNHxmI6AAAgwcTkCsru3buVk5OjPn36aMKECdq3b58kqbKyUqdOnVJBQUFobF5ennr06KFAIHDe8zU2Nqq+vj5sAwAArVfUAyU/P1/Lli3T2rVr9cILL6iqqko33nijjh07pmAwqLS0NHm93rDbZGdnKxgMnvecZWVlysjICG25ubnRnjYAADAk6k/xFBUVhf48aNAg5efnq2fPnnr99dfVrl27Fp1z5syZKi0tDf1cX19PpAAA0IrF/G3GXq9X/fr10549e+Tz+XTy5EnV1taGjamurm72NStnpaeny+PxhG0AAKD1inmgNDQ0aO/everevbuGDRumNm3aqLy8PHR8165d2rdvn/x+f6ynAgAAEkTUn+KZPn267rjjDvXs2VMHDx7UnDlzlJKSovHjxysjI0OTJ09WaWmpMjMz5fF49PDDD8vv9/MOHgAAEBL1QDlw4IDGjx+vo0ePqlu3brrhhhu0efNmdevWTZL03HPPKTk5WcXFxWpsbNTo0aO1ZMmSaE8DAAAksCTnnIv3JCJVX1+vjIwM1dXVxeT1KL1mrIn6OQG0bl/MGxPvKQDmRfL7m+/iAQAA5hAoAADAHAIFAACYQ6AAAABzCBQAAGAOgQIAAMwhUAAAgDkECgAAMIdAAQAA5hAoAADAHAIFAACYE/UvCwSAK1EifocX3x8Ey7iCAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOanxngAAID56zVgT7ylE7It5Y+I9BVwmXEEBAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMCc1HhPAACAi9Vrxpp4TyFiX8wbE+8pJCSuoAAAAHPiegVl8eLFevbZZxUMBjV48GAtWrRII0aMiOeUAACIqkS86iPF/8pP3K6g/OUvf1FpaanmzJmjf//73xo8eLBGjx6tw4cPx2tKAADAiLgFyu9//3tNmTJF999/v6677jotXbpU7du310svvRSvKQEAACPi8hTPyZMnVVlZqZkzZ4b2JScnq6CgQIFA4JzxjY2NamxsDP1cV1cnSaqvr4/J/Joav4rJeQEASBSx+B179pzOuW8dG5dA+fLLL3XmzBllZ2eH7c/Oztbnn39+zviysjLNnTv3nP25ubkxmyMAAFeyjIWxO/exY8eUkZFxwTEJ8TbjmTNnqrS0NPRzU1OTampq1KVLFyUlJbX4vPX19crNzdX+/fvl8XiiMVXzrsQ1S1fmulnzlbFm6cpcN2tOzDU753Ts2DHl5OR869i4BErXrl2VkpKi6urqsP3V1dXy+XznjE9PT1d6enrYPq/XG7X5eDyehL2zW+pKXLN0Za6bNV85rsR1s+bE821XTs6Ky4tk09LSNGzYMJWXl4f2NTU1qby8XH6/Px5TAgAAhsTtKZ7S0lJNnDhRw4cP14gRI7Rw4UIdP35c999/f7ymBAAAjIhboNx77706cuSIZs+erWAwqCFDhmjt2rXnvHA2ltLT0zVnzpxznj5qza7ENUtX5rpZ85XjSlw3a279ktzFvNcHAADgMuK7eAAAgDkECgAAMIdAAQAA5hAoAADAnFYdKDU1NZowYYI8Ho+8Xq8mT56shoaGC97mxIkTKikpUZcuXdSxY0cVFxef84FySUlJ52wrVqyI5VIuaPHixerVq5fatm2r/Px8ffTRRxccv3LlSuXl5alt27YaOHCg/v73v4cdd85p9uzZ6t69u9q1a6eCggLt3r07lkuIWLTXPGnSpHPu08LCwlguIWKRrHn79u0qLi5Wr169lJSUpIULF17yOeMl2ut++umnz7mv8/LyYriCyEWy5hdffFE33nijOnfurM6dO6ugoOCc8YnwmJaiv+7W9rh+4403NHz4cHm9XnXo0EFDhgzRn//857AxiXJfXxTXihUWFrrBgwe7zZs3u/fff9/17dvXjR8//oK3eeCBB1xubq4rLy93W7ZscSNHjnTf//73w8ZIci+//LI7dOhQaPv6669juZTzWrFihUtLS3MvvfSS2759u5syZYrzer2uurq62fEffvihS0lJcfPnz3c7duxws2bNcm3atHGffvppaMy8efNcRkaGW7VqlfvPf/7j7rzzTte7d++4rfGbYrHmiRMnusLCwrD7tKam5nIt6VtFuuaPPvrITZ8+3b322mvO5/O555577pLPGQ+xWPecOXPc9ddfH3ZfHzlyJMYruXiRrvlnP/uZW7x4sdu6davbuXOnmzRpksvIyHAHDhwIjbH+mHYuNutubY/r9957z73xxhtux44dbs+ePW7hwoUuJSXFrV27NjQmEe7ri9VqA2XHjh1Okvv4449D+9555x2XlJTk/vvf/zZ7m9raWtemTRu3cuXK0L6dO3c6SS4QCIT2SXJvvvlmzOYeiREjRriSkpLQz2fOnHE5OTmurKys2fE//elP3ZgxY8L25efnu5///OfOOeeampqcz+dzzz77bOh4bW2tS09Pd6+99loMVhC5aK/Zuf/9H9ldd90Vk/lGQ6Rr/r969uzZ7C/qSznn5RKLdc+ZM8cNHjw4irOMrku9X06fPu06derkXnnlFedcYjymnYv+up1r3Y/rs4YOHepmzZrlnEuc+/pitdqneAKBgLxer4YPHx7aV1BQoOTkZFVUVDR7m8rKSp06dUoFBQWhfXl5eerRo4cCgUDY2JKSEnXt2lUjRozQSy+9dFFfHR1tJ0+eVGVlZdh8k5OTVVBQcM58zwoEAmHjJWn06NGh8VVVVQoGg2FjMjIylJ+ff95zXk6xWPNZGzZsUFZWlvr3768HH3xQR48ejf4CWqAla47HOaMtlnPcvXu3cnJy1KdPH02YMEH79u271OlGRTTW/NVXX+nUqVPKzMyUZP8xLcVm3We11se1c07l5eXatWuXbrrpJkmJcV9HIiG+zbglgsGgsrKywvalpqYqMzNTwWDwvLdJS0s754sIs7Ozw27zzDPP6NZbb1X79u317rvv6qGHHlJDQ4N+8YtfRH0dF/Lll1/qzJkz53z6bnZ2tj7//PNmbxMMBpsdf3Z9Z/95oTHxFIs1S1JhYaHGjh2r3r17a+/evXryySdVVFSkQCCglJSU6C8kAi1ZczzOGW2xmmN+fr6WLVum/v3769ChQ5o7d65uvPFGffbZZ+rUqdOlTvuSRGPNTzzxhHJyckK/pKw/pqXYrFtqnY/ruro6XXXVVWpsbFRKSoqWLFmiH/3oR5IS476ORMIFyowZM/S73/3ugmN27twZ0zk89dRToT8PHTpUx48f17PPPnvZAwXRM27cuNCfBw4cqEGDBumaa67Rhg0bNGrUqDjODNFWVFQU+vOgQYOUn5+vnj176vXXX9fkyZPjOLNLN2/ePK1YsUIbNmxQ27Zt4z2dy+Z8626Nj+tOnTpp27ZtamhoUHl5uUpLS9WnTx/dcsst8Z5a1CXcUzyPPfaYdu7cecGtT58+8vl8Onz4cNhtT58+rZqaGvl8vmbP7fP5dPLkSdXW1obtr66uPu9tpP/9F9mBAwfU2Nh4yeuLRNeuXZWSknLOu4wuNF+fz3fB8Wf/Gck5L6dYrLk5ffr0UdeuXbVnz55Ln/Qlasma43HOaLtcc/R6verXr1/C39cLFizQvHnz9O6772rQoEGh/dYf01Js1t2c1vC4Tk5OVt++fTVkyBA99thjuueee1RWViYpMe7rSCRcoHTr1k15eXkX3NLS0uT3+1VbW6vKysrQbdevX6+mpibl5+c3e+5hw4apTZs2Ki8vD+3btWuX9u3bJ7/ff945bdu2TZ07d77sX+CUlpamYcOGhc23qalJ5eXl552v3+8PGy9J69atC43v3bu3fD5f2Jj6+npVVFRc8H+DyyUWa27OgQMHdPToUXXv3j06E78ELVlzPM4ZbZdrjg0NDdq7d29C39fz58/Xr3/9a61duzbsdXeS/ce0FJt1N6c1Pq6bmppC/3GcCPd1ROL9Kt1YKiwsdEOHDnUVFRXugw8+cNdee23Y24wPHDjg+vfv7yoqKkL7HnjgAdejRw+3fv16t2XLFuf3+53f7w8d/9vf/uZefPFF9+mnn7rdu3e7JUuWuPbt27vZs2df1rWdtWLFCpeenu6WLVvmduzY4aZOneq8Xq8LBoPOOefuu+8+N2PGjND4Dz/80KWmproFCxa4nTt3ujlz5jT7NmOv1+veeust98knn7i77rrL1NvUor3mY8eOuenTp7tAIOCqqqrcP//5T/fd737XXXvtte7EiRNxWeM3RbrmxsZGt3XrVrd161bXvXt3N336dLd161a3e/fuiz6nBbFY92OPPeY2bNjgqqqq3IcffugKCgpc165d3eHDhy/7+poT6ZrnzZvn0tLS3F//+tewt9MeO3YsbIzlx7Rz0V93a3xc//a3v3Xvvvuu27t3r9uxY4dbsGCBS01NdS+++GJoTCLc1xerVQfK0aNH3fjx413Hjh2dx+Nx999/f9iDtqqqykly7733Xmjf119/7R566CHXuXNn1759e/eTn/zEHTp0KHT8nXfecUOGDHEdO3Z0HTp0cIMHD3ZLly51Z86cuZxLC7No0SLXo0cPl5aW5kaMGOE2b94cOnbzzTe7iRMnho1//fXXXb9+/VxaWpq7/vrr3Zo1a8KONzU1uaeeesplZ2e79PR0N2rUKLdr167LsZSLFs01f/XVV+62225z3bp1c23atHE9e/Z0U6ZMMfWL2rnI1nz23+1vbjfffPNFn9OKaK/73nvvdd27d3dpaWnuqquucvfee6/bs2fPZVzRt4tkzT179mx2zXPmzAmNSYTHtHPRXXdrfFz/6le/cn379nVt27Z1nTt3dn6/361YsSLsfIlyX1+MJOfi8P5YAACAC0i416AAAIDWj0ABAADmECgAAMAcAgUAAJhDoAAAAHMIFAAAYA6BAgAAzCFQAACAOQQKAAAwh0ABAADmECgAAMAcAgUAAJjz/wBW5ZmuC6VvaAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] @@ -269,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "id": "f41fc285-883e-4313-bae2-e8e1c10a3e5b", "metadata": {}, "outputs": [ @@ -277,26 +291,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "shape theta torch.Size([1000, 5])\n", - "shape x torch.Size([1000, 1])\n", - " Neural network successfully converged after 224 epochs." + " Training neural network. Epochs trained: 103" ] } ], "source": [ - "# Now let's put them in a tensor form that SBI can read.\n", - "theta = torch.tensor(thetas,dtype=torch.float32)#.reshape(100,1,5)\n", - "x = torch.tensor(xs,dtype=torch.float32)\n", "\n", - "print('shape theta', np.shape(theta))\n", - "print('shape x', np.shape(x))\n", - "#embedding_net = SummaryNet()\n", - "\n", - "# instantiate the neural density estimator\n", - "neural_posterior = sbi.utils.posterior_nn(model='maf')#,\n", - " #embedding_net=embedding_net,\n", - " #hidden_features=hidden_features,\n", - " #num_transforms=num_transforms)\n", "\n", "# make a fake prior\n", "# L, theta_0, a_g, mu, sigma\n", @@ -311,14 +311,9 @@ "prior_low = [0, jnp.pi / 1000, 0, 0, 0]\n", "prior_high = [10, jnp.pi / 10, 20, 20, 4]\n", "\n", - "prior = utils.BoxUniform(low=torch.tensor(prior_low), high=torch.tensor(prior_high), device='cpu')\n", - "\n", - "# setup the inference procedure with the SNPE-C procedure\n", - "inference = SNPE(prior=prior, density_estimator=neural_posterior, device=\"cpu\")\n", + "prior = sbi.utils.BoxUniform(low=torch.tensor(prior_low), high=torch.tensor(prior_high), device='cpu')\n", "\n", - "# Now that we have both the simulated images and parameters defined properly, we can train the SBI.\n", - "density_estimator = inference.append_simulations(theta,x).train()\n", - "posterior = inference.build_posterior(density_estimator)" + "posterior = train.train_SBI_hierarchical(thetas, xs, prior)" ] }, { diff --git a/src/scripts/train.py b/src/scripts/train.py index a24465b..9735959 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -3,6 +3,8 @@ Can leave a default data source, or specify that 'load data' loads the dataset used in the final version """ import argparse +import torch +import sbi def architecture(): @@ -29,6 +31,27 @@ def train_model(data_source, n_epochs): return 0 +def train_SBI_hierarchical(thetas, xs, prior): + # Now let's put them in a tensor form that SBI can read. + theta = torch.tensor(thetas, dtype=torch.float32) + x = torch.tensor(xs, dtype=torch.float32) + + # instantiate the neural density estimator + neural_posterior = sbi.utils.posterior_nn(model='maf')#, + #embedding_net=embedding_net, + #hidden_features=hidden_features, + #num_transforms=num_transforms) + # setup the inference procedure with the SNPE-C procedure + inference = sbi.inference.SNPE(prior=prior, + density_estimator=neural_posterior, + device="cpu") + + # now that we have both the simulated images and + # parameters defined properly, we can train the SBI. + density_estimator = inference.append_simulations(theta, x).train() + return inference.build_posterior(density_estimator) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--data_source", type=str, help="Data used to train the model")