Skip to content

Commit

Permalink
modifying the checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Mar 22, 2024
1 parent 25c03fb commit 520121c
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions notebooks/epistemic_by_epoch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 45,
"id": "6cb14aad-12f0-48c7-80fd-995a68618902",
"metadata": {},
"outputs": [],
"source": [
"def load_model_checkpoint(model, model_name, epoch, device, path='../models'):\n",
"def load_model_checkpoint(model, model_name, nmodel, epoch, device, beta='0.5', path='../models'):\n",
" \"\"\"\n",
" Load PyTorch model checkpoint from a .pt file.\n",
"\n",
Expand All @@ -52,7 +52,11 @@
" :param model: PyTorch model to load the checkpoint into\n",
" :return: Loaded model\n",
" \"\"\"\n",
" file_name = f\"{path}/{model_name}_epoch_{epoch}.pt\"\n",
" if model_name == 'DE_noise_low' or model_name == 'DE_noise_medium' or model_name == 'DE_noise_high':\n",
" file_name = f\"{path}/{model_name}_beta_{beta}_nmodel_{nmodel}_epoch_{epoch}.pt\"\n",
" else:\n",
" file_name = f\"{path}/{model_name}_nmodel_{nmodel}_epoch_{epoch}.pt\"\n",
" checkpoint = torch.load(file_name, map_location=device)\n",
" \n",
" checkpoint = torch.load(file_name, map_location=device)\n",
" return checkpoint\n",
Expand Down Expand Up @@ -529,36 +533,36 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 36,
"id": "889d5e6b-bfee-4e62-b8bd-2feb19d6d326",
"metadata": {},
"outputs": [],
"source": [
"nmodels = 20\n",
"nmodels = 100\n",
"nepochs = 100\n",
"DE_type = 'bnll_loss'\n",
"model, mse_loss_de = models.model_setup_DE(DE_type, DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 47,
"id": "2e1d2056-5a5c-4390-9644-1de59cf4018b",
"metadata": {},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: '../models/DE_noise_medium_nmodel_5_epoch_0.pt'",
"evalue": "[Errno 2] No such file or directory: '../models/DE_noise_low_beta_0.5_nmodel_0_epoch_0.pt'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[33], line 29\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m n \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(nmodels):\n\u001b[1;32m 28\u001b[0m model_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDE_noise_medium_nmodel_\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(n)\n\u001b[0;32m---> 29\u001b[0m chk \u001b[38;5;241m=\u001b[39m \u001b[43mload_model_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 30\u001b[0m _, _, mu_vals, sig_vals \u001b[38;5;241m=\u001b[39m ep_al_checkpoint_DE(chk)\n\u001b[1;32m 31\u001b[0m list_mus\u001b[38;5;241m.\u001b[39mappend(mu_vals\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy())\n",
"Cell \u001b[0;32mIn[31], line 14\u001b[0m, in \u001b[0;36mload_model_checkpoint\u001b[0;34m(model, model_name, epoch, device, path)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124;03mLoad PyTorch model checkpoint from a .pt file.\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;124;03m:return: Loaded model\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 12\u001b[0m file_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_epoch_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 14\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m checkpoint\n",
"Cell \u001b[0;32mIn[47], line 11\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m n \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(nmodels):\n\u001b[1;32m 10\u001b[0m model_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDE_noise_low\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m---> 11\u001b[0m chk \u001b[38;5;241m=\u001b[39m \u001b[43mload_model_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m _, _, mu_vals, sig_vals \u001b[38;5;241m=\u001b[39m ep_al_checkpoint_DE(chk)\n\u001b[1;32m 13\u001b[0m list_mus\u001b[38;5;241m.\u001b[39mappend(mu_vals\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy())\n",
"Cell \u001b[0;32mIn[45], line 16\u001b[0m, in \u001b[0;36mload_model_checkpoint\u001b[0;34m(model, model_name, nmodel, epoch, device, beta, path)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 15\u001b[0m file_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_nmodel_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnmodel\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_epoch_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 16\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 18\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(file_name, map_location\u001b[38;5;241m=\u001b[39mdevice)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m checkpoint\n",
"File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/deepuq-DRzT0TL8-py3.9/lib/python3.9/site-packages/torch/serialization.py:986\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 984\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m--> 986\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[1;32m 987\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[1;32m 988\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[1;32m 989\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[1;32m 990\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[1;32m 991\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
"File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/deepuq-DRzT0TL8-py3.9/lib/python3.9/site-packages/torch/serialization.py:435\u001b[0m, in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[1;32m 434\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[0;32m--> 435\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
"File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/deepuq-DRzT0TL8-py3.9/lib/python3.9/site-packages/torch/serialization.py:416\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[0;32m--> 416\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../models/DE_noise_medium_nmodel_5_epoch_0.pt'"
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../models/DE_noise_low_beta_0.5_nmodel_0_epoch_0.pt'"
]
}
],
Expand All @@ -572,15 +576,16 @@
" list_sigs = []\n",
" for n in range(nmodels):\n",
"\n",
" model_name = 'DE_noise_low_nmodel_' + str(n)\n",
" chk = load_model_checkpoint(model, model_name, e, DEVICE)\n",
" model_name = 'DE_noise_low'\n",
" chk = load_model_checkpoint(model, model_name, 0, e, DEVICE)\n",
" _, _, mu_vals, sig_vals = ep_al_checkpoint_DE(chk)\n",
" list_mus.append(mu_vals.detach().numpy())\n",
" list_sigs.append(sig_vals.detach().numpy()**2)\n",
" low_ep.append(np.median(np.std(list_mus, axis = 0)))\n",
" low_al_var.append(np.median(np.mean(list_sigs, axis = 0)))\n",
" low_ep_std.append(np.std(np.std(list_mus, axis = 0)))\n",
" low_al_var_std.append(np.std(np.mean(list_sigs, axis = 0)))\n",
"'''\n",
"med_ep = []\n",
"med_al_var = []\n",
"med_ep_std = []\n",
Expand Down Expand Up @@ -617,7 +622,8 @@
" high_ep.append(np.median(np.std(list_mus, axis = 0)))\n",
" high_al_var.append(np.median(np.mean(list_sigs, axis = 0)))\n",
" high_ep_std.append(np.std(np.std(list_mus, axis = 0)))\n",
" high_al_var_std.append(np.std(np.mean(list_sigs, axis = 0)))\n"
" high_al_var_std.append(np.std(np.mean(list_sigs, axis = 0)))\n",
"'''"
]
},
{
Expand Down

0 comments on commit 520121c

Please sign in to comment.