From 5ea8c6ba0ae3106147a19b45b4185ca55effe747 Mon Sep 17 00:00:00 2001 From: Giwan Date: Tue, 10 Nov 2020 09:02:24 +0900 Subject: [PATCH] Modify rainnet notebook --- notebooks/01-baseline-pytorch.ipynb | 14582 +++++++++++++++++++++++++- notebooks/02-rainnet.ipynb | 4373 ++++++-- notebooks/03-unet.ipynb | 5025 +++++++-- 3 files changed, 22062 insertions(+), 1918 deletions(-) diff --git a/notebooks/01-baseline-pytorch.ipynb b/notebooks/01-baseline-pytorch.ipynb index 2d33622..94b7dd7 100644 --- a/notebooks/01-baseline-pytorch.ipynb +++ b/notebooks/01-baseline-pytorch.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -12,45 +12,202 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 2;\n", + " var nbb_unformatted_code = \"%reload_ext autoreload\\n%autoreload 2\\n%matplotlib inline\\n%reload_ext nb_black\";\n", + " var nbb_formatted_code = \"%reload_ext autoreload\\n%autoreload 2\\n%matplotlib inline\\n%reload_ext nb_black\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", + "%matplotlib inline\n", "%reload_ext nb_black" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 3;\n", + " var nbb_unformatted_code = \"import gc\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\nimport warnings\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\\n\\nimport torchvision.transforms as T\\nimport albumentations as A\\nfrom albumentations.pytorch import ToTensorV2\\n\\nimport pytorch_lightning as pl\\nfrom transformers import (\\n AdamW,\\n get_linear_schedule_with_warmup,\\n get_cosine_schedule_with_warmup,\\n)\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_formatted_code = \"import gc\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\nimport warnings\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\\n\\nimport torchvision.transforms as T\\nimport albumentations as A\\nfrom albumentations.pytorch import ToTensorV2\\n\\nimport pytorch_lightning as pl\\nfrom transformers import (\\n AdamW,\\n get_linear_schedule_with_warmup,\\n get_cosine_schedule_with_warmup,\\n)\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import gc\n", "from pathlib import Path\n", + "from tqdm.notebook import tqdm\n", + "import warnings\n", "\n", "import cv2\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", - "from tqdm.notebook import tqdm\n", - "from sklearn import model_selection\n", + "from sklearn import metrics\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", + "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n", "\n", - "import pytorch_lightning as pl" + "import torchvision.transforms as T\n", + "import albumentations as A\n", + "from albumentations.pytorch import ToTensorV2\n", + "\n", + "import pytorch_lightning as pl\n", + "from transformers import (\n", + " AdamW,\n", + " get_linear_schedule_with_warmup,\n", + " get_cosine_schedule_with_warmup,\n", + ")\n", + "\n", + "import optim\n", + "import loss\n", + "from utils import visualize, radar2precipitation, seed_everything" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 4;\n", + " var nbb_unformatted_code = \"warnings.simplefilter(\\\"ignore\\\")\";\n", + " var nbb_formatted_code = \"warnings.simplefilter(\\\"ignore\\\")\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "warnings.simplefilter(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 5;\n", + " var nbb_unformatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train\\\"),\\n test_data_path=Path(\\\"../input/test\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n output_dir=Path(\\\"../output\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=5e-4,\\n max_epochs=50,\\n batch_size=128,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n warmup_epochs=1,\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n)\";\n", + " var nbb_formatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train\\\"),\\n test_data_path=Path(\\\"../input/test\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n output_dir=Path(\\\"../output\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=5e-4,\\n max_epochs=50,\\n batch_size=128,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n warmup_epochs=1,\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "PATH = Path(\"../input\")" + "args = dict(\n", + " seed=42,\n", + " dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\n", + " train_folds_csv=Path(\"../input/train_folds.csv\"),\n", + " train_data_path=Path(\"../input/train\"),\n", + " test_data_path=Path(\"../input/test\"),\n", + " model_dir=Path(\"../models\"),\n", + " output_dir=Path(\"../output\"),\n", + " rng=255.0,\n", + " num_workers=4,\n", + " gpus=1,\n", + " lr=5e-4,\n", + " max_epochs=50,\n", + " batch_size=128,\n", + " precision=16,\n", + " optimizer=\"adamw\",\n", + " scheduler=\"cosine\",\n", + " warmup_epochs=1,\n", + " accumulate_grad_batches=1,\n", + " gradient_clip_val=5.0,\n", + ")" ] }, { @@ -60,6 +217,468 @@ "# 🔥 Baseline ⚡️" ] }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "## Sketch" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
filenamefold
0train_60668.npy0
1train_60911.npy0
2train_14956.npy0
3train_24086.npy0
4train_09805.npy0
\n", + "
" + ], + "text/plain": [ + " filename fold\n", + "0 train_60668.npy 0\n", + "1 train_60911.npy 0\n", + "2 train_14956.npy 0\n", + "3 train_24086.npy 0\n", + "4 train_09805.npy 0" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 6;\n", + " var nbb_unformatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\ndf.head()\";\n", + " var nbb_formatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\ndf.head()\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df = pd.read_csv(args[\"train_folds_csv\"])\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'train_60668.npy'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 7;\n", + " var nbb_unformatted_code = \"fn = df.loc[0, \\\"filename\\\"]\\nfn\";\n", + " var nbb_formatted_code = \"fn = df.loc[0, \\\"filename\\\"]\\nfn\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fn = df.loc[0, \"filename\"]\n", + "fn" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('../input/train/train_60668.npy')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 8;\n", + " var nbb_unformatted_code = \"path = args[\\\"train_data_path\\\"] / fn\\npath\";\n", + " var nbb_formatted_code = \"path = args[\\\"train_data_path\\\"] / fn\\npath\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "path = args[\"train_data_path\"] / fn\n", + "path" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(120, 120, 5)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 9;\n", + " var nbb_unformatted_code = \"data = np.load(path)\\ndata.shape\";\n", + " var nbb_formatted_code = \"data = np.load(path)\\ndata.shape\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "data = np.load(path)\n", + "data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 10;\n", + " var nbb_unformatted_code = \"tfms = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n A.VerticalFlip(p=0.5),\\n A.HorizontalFlip(p=0.5),\\n A.RandomRotate90(p=0.5),\\n A.Transpose(p=0.5),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\";\n", + " var nbb_formatted_code = \"tfms = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n A.VerticalFlip(p=0.5),\\n A.HorizontalFlip(p=0.5),\\n A.RandomRotate90(p=0.5),\\n A.Transpose(p=0.5),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tfms = A.Compose(\n", + " [\n", + " A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\n", + " A.VerticalFlip(p=0.5),\n", + " A.HorizontalFlip(p=0.5),\n", + " A.RandomRotate90(p=0.5),\n", + " A.Transpose(p=0.5),\n", + " ToTensorV2(always_apply=True, p=1),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 11;\n", + " var nbb_unformatted_code = \"augmented = tfms(image=data)\";\n", + " var nbb_formatted_code = \"augmented = tfms(image=data)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "augmented = tfms(image=data)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([5, 128, 128])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 12;\n", + " var nbb_unformatted_code = \"image = augmented[\\\"image\\\"]\\nimage.shape\";\n", + " var nbb_formatted_code = \"image = augmented[\\\"image\\\"]\\nimage.shape\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "image = augmented[\"image\"]\n", + "image.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor(0, dtype=torch.uint8), tensor(220, dtype=torch.uint8))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 13;\n", + " var nbb_unformatted_code = \"image[0].min(), image[0].max()\";\n", + " var nbb_formatted_code = \"image[0].min(), image[0].max()\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "image[0].min(), image[0].max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hidden": true + }, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, @@ -69,13 +688,52 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 14;\n", + " var nbb_unformatted_code = \"class NowcastingDataset(Dataset):\\n def __init__(self, paths, tfms=None, test=False):\\n self.paths = paths\\n if tfms is not None:\\n self.tfms = tfms\\n else:\\n self.tfms = A.Compose(\\n [\\n A.PadIfNeeded(\\n min_height=128, min_width=128, always_apply=True, p=1\\n ),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n )\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n augmented = self.tfms(image=data)\\n data = augmented[\\\"image\\\"]\\n\\n x = data[:4, :, :]\\n x = x / args[\\\"rng\\\"]\\n if self.test:\\n return x\\n else:\\n y = data[4, :, :]\\n y = y / args[\\\"rng\\\"]\\n y = y.unsqueeze(0)\\n\\n return x, y\";\n", + " var nbb_formatted_code = \"class NowcastingDataset(Dataset):\\n def __init__(self, paths, tfms=None, test=False):\\n self.paths = paths\\n if tfms is not None:\\n self.tfms = tfms\\n else:\\n self.tfms = A.Compose(\\n [\\n A.PadIfNeeded(\\n min_height=128, min_width=128, always_apply=True, p=1\\n ),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n )\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n augmented = self.tfms(image=data)\\n data = augmented[\\\"image\\\"]\\n\\n x = data[:4, :, :]\\n x = x / args[\\\"rng\\\"]\\n if self.test:\\n return x\\n else:\\n y = data[4, :, :]\\n y = y / args[\\\"rng\\\"]\\n y = y.unsqueeze(0)\\n\\n return x, y\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "class NowcastingDataset(torch.utils.data.Dataset):\n", - " def __init__(self, paths, test=False):\n", + "class NowcastingDataset(Dataset):\n", + " def __init__(self, paths, tfms=None, test=False):\n", " self.paths = paths\n", + " if tfms is not None:\n", + " self.tfms = tfms\n", + " else:\n", + " self.tfms = A.Compose(\n", + " [\n", + " A.PadIfNeeded(\n", + " min_height=128, min_width=128, always_apply=True, p=1\n", + " ),\n", + " ToTensorV2(always_apply=True, p=1),\n", + " ]\n", + " )\n", " self.test = test\n", "\n", " def __len__(self):\n", @@ -84,70 +742,110 @@ " def __getitem__(self, idx):\n", " path = self.paths[idx]\n", " data = np.load(path)\n", - " x = data[:, :, :4]\n", - " # x = x / 255.0\n", - " x = x.astype(np.float32)\n", - " x = torch.tensor(x, dtype=torch.float)\n", - " x = x.permute(2, 0, 1)\n", + "\n", + " augmented = self.tfms(image=data)\n", + " data = augmented[\"image\"]\n", + "\n", + " x = data[:4, :, :]\n", + " x = x / args[\"rng\"]\n", " if self.test:\n", " return x\n", " else:\n", - " y = data[:, :, 4]\n", - " # y = y / 255.0\n", - " y = y.astype(np.float32)\n", - " y = torch.tensor(y, dtype=torch.float)\n", - " y = y.unsqueeze(-1)\n", - " y = y.permute(2, 0, 1)\n", + " y = data[4, :, :]\n", + " y = y / args[\"rng\"]\n", + " y = y.unsqueeze(0)\n", "\n", " return x, y" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 15;\n", + " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n test=False,\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n self.test = test\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n test=False,\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n self.test = test\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "class NowcastingDataModule(pl.LightningDataModule):\n", - " def __init__(self, batch_size, test=False, num_workers=4):\n", + " def __init__(\n", + " self,\n", + " train_df=None,\n", + " val_df=None,\n", + " batch_size=args[\"batch_size\"],\n", + " num_workers=args[\"num_workers\"],\n", + " test=False,\n", + " ):\n", " super().__init__()\n", - " self.test = test\n", + " self.train_df = train_df\n", + " self.val_df = val_df\n", " self.batch_size = batch_size\n", - " self.num_workers = 4\n", + " self.num_workers = num_workers\n", + " self.test = test\n", "\n", " def setup(self, stage=\"train\"):\n", " if stage == \"train\":\n", - " paths = list((PATH / \"train\").glob(\"*.npy\"))\n", - " train_paths, val_paths = model_selection.train_test_split(\n", - " paths, test_size=0.1, shuffle=True\n", - " )\n", + " train_paths = [\n", + " args[\"train_data_path\"] / fn for fn in self.train_df.filename.values\n", + " ]\n", + " val_paths = [\n", + " args[\"train_data_path\"] / fn for fn in self.val_df.filename.values\n", + " ]\n", " self.train_dataset = NowcastingDataset(train_paths)\n", " self.val_dataset = NowcastingDataset(val_paths)\n", " else:\n", - " test_paths = list((PATH / \"test\").glob(\"*.npy\"))\n", + " test_paths = list(sorted(args[\"test_data_path\"].glob(\"*.npy\")))\n", " self.test_dataset = NowcastingDataset(test_paths, test=True)\n", "\n", " def train_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", + " return DataLoader(\n", " self.train_dataset,\n", " batch_size=self.batch_size,\n", - " shuffle=True,\n", + " sampler=RandomSampler(self.train_dataset),\n", " pin_memory=True,\n", " num_workers=self.num_workers,\n", + " drop_last=True,\n", " )\n", "\n", " def val_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", + " return DataLoader(\n", " self.val_dataset,\n", " batch_size=2 * self.batch_size,\n", + " sampler=SequentialSampler(self.val_dataset),\n", " pin_memory=True,\n", " num_workers=self.num_workers,\n", " )\n", "\n", " def test_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", + " return DataLoader(\n", " self.test_dataset,\n", " batch_size=2 * self.batch_size,\n", + " sampler=SequentialSampler(self.test_dataset),\n", " pin_memory=True,\n", " num_workers=self.num_workers,\n", " )" @@ -155,19 +853,59 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 16, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 16;\n", + " var nbb_unformatted_code = \"# df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n\\n# fold = 0\\n# train_df = df[df[\\\"fold\\\"] != fold]\\n# val_df = df[df[\\\"fold\\\"] == fold]\\n\\n# datamodule = NowcastingDataModule(train_df, val_df)\\n# datamodule.setup()\\n\\n# for batch in datamodule.train_dataloader():\\n# xs, ys = batch\\n# idx = np.random.randint(len(xs))\\n# x, y = xs[idx], ys[idx]\\n# print(x.shape, y.shape)\\n# x = x.detach().cpu().numpy().transpose(1, 2, 0)\\n# y = y.unsqueeze(-1)\\n# y = y.detach().cpu().numpy()\\n# visualize(x, y)\\n# break\";\n", + " var nbb_formatted_code = \"# df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n\\n# fold = 0\\n# train_df = df[df[\\\"fold\\\"] != fold]\\n# val_df = df[df[\\\"fold\\\"] == fold]\\n\\n# datamodule = NowcastingDataModule(train_df, val_df)\\n# datamodule.setup()\\n\\n# for batch in datamodule.train_dataloader():\\n# xs, ys = batch\\n# idx = np.random.randint(len(xs))\\n# x, y = xs[idx], ys[idx]\\n# print(x.shape, y.shape)\\n# x = x.detach().cpu().numpy().transpose(1, 2, 0)\\n# y = y.unsqueeze(-1)\\n# y = y.detach().cpu().numpy()\\n# visualize(x, y)\\n# break\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "datamodule = NowcastingDataModule(batch_size=32)\n", - "datamodule.setup()\n", - "for batch in datamodule.train_dataloader():\n", - " xs, ys = batch\n", - " x, y = xs[0], ys[0]\n", - " x = x.permute(1, 2, 0).numpy()\n", - " y = y.permute(1, 2, 0).numpy()\n", - " visualize(x, y)\n", - " break" + "# df = pd.read_csv(args[\"train_folds_csv\"])\n", + "\n", + "# fold = 0\n", + "# train_df = df[df[\"fold\"] != fold]\n", + "# val_df = df[df[\"fold\"] == fold]\n", + "\n", + "# datamodule = NowcastingDataModule(train_df, val_df)\n", + "# datamodule.setup()\n", + "\n", + "# for batch in datamodule.train_dataloader():\n", + "# xs, ys = batch\n", + "# idx = np.random.randint(len(xs))\n", + "# x, y = xs[idx], ys[idx]\n", + "# print(x.shape, y.shape)\n", + "# x = x.detach().cpu().numpy().transpose(1, 2, 0)\n", + "# y = y.unsqueeze(-1)\n", + "# y = y.detach().cpu().numpy()\n", + "# visualize(x, y)\n", + "# break" ] }, { @@ -177,11 +915,46 @@ "## Model" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Layers" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 17;\n", + " var nbb_unformatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n \\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128]):\\n super().__init__()\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.conv = nn.Conv2d(128, 512, kernel_size=3, padding=1)\\n\\n def forward(self, x):\\n ftrs = []\\n for block in self.blocks:\\n x = block(x)\\n ftrs.append(x)\\n x = self.pool(x)\\n x = self.conv(x)\\n ftrs.append(x)\\n return ftrs\\n \\nclass Decoder(nn.Module):\\n def __init__(self, chs=[512, 128, 64]):\\n super().__init__()\\n self.tr_convs = nn.ModuleList(\\n [\\n nn.ConvTranspose2d(chs[i], chs[i + 1], kernel_size=2, stride=2)\\n for i in range(len(chs) - 1)\\n ]\\n )\\n self.blocks = nn.ModuleList(\\n [Block(2 * chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, ftrs):\\n for i, ftr in enumerate(ftrs):\\n x = self.tr_convs[i](x)\\n x = torch.cat([ftr, x], dim=1)\\n x = self.blocks[i](x)\\n return x\";\n", + " var nbb_formatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128]):\\n super().__init__()\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.conv = nn.Conv2d(128, 512, kernel_size=3, padding=1)\\n\\n def forward(self, x):\\n ftrs = []\\n for block in self.blocks:\\n x = block(x)\\n ftrs.append(x)\\n x = self.pool(x)\\n x = self.conv(x)\\n ftrs.append(x)\\n return ftrs\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[512, 128, 64]):\\n super().__init__()\\n self.tr_convs = nn.ModuleList(\\n [\\n nn.ConvTranspose2d(chs[i], chs[i + 1], kernel_size=2, stride=2)\\n for i in range(len(chs) - 1)\\n ]\\n )\\n self.blocks = nn.ModuleList(\\n [Block(2 * chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, ftrs):\\n for i, ftr in enumerate(ftrs):\\n x = self.tr_convs[i](x)\\n x = torch.cat([ftr, x], dim=1)\\n x = self.blocks[i](x)\\n return x\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "class Block(nn.Module):\n", " def __init__(self, in_ch, out_ch):\n", @@ -193,15 +966,9 @@ " )\n", "\n", " def forward(self, x):\n", - " return self.net(x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + " return self.net(x)\n", + "\n", + "\n", "class Encoder(nn.Module):\n", " def __init__(self, chs=[4, 64, 128]):\n", " super().__init__()\n", @@ -219,15 +986,9 @@ " x = self.pool(x)\n", " x = self.conv(x)\n", " ftrs.append(x)\n", - " return ftrs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + " return ftrs\n", + "\n", + "\n", "class Decoder(nn.Module):\n", " def __init__(self, chs=[512, 128, 64]):\n", " super().__init__()\n", @@ -249,16 +1010,58 @@ " return x" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 18;\n", + " var nbb_unformatted_code = \"class Baseline(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128],\\n dec_chs=[512, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n self.criterion = nn.L1Loss()\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 1, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = ftrs[::-1]\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n\\n crop = T.CenterCrop(120)\\n y = crop(y)\\n y_hat = crop(y_hat)\\n\\n batch_size = len(y)\\n y = y.reshape(batch_size, -1)\\n y = y.detach().cpu().numpy()\\n y *= args[\\\"rng\\\"]\\n y_hat = y_hat.reshape(batch_size, -1)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat *= args[\\\"rng\\\"]\\n\\n y = y[:, args[\\\"dams\\\"]]\\n y_hat = y_hat[:, args[\\\"dams\\\"]]\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1.0, 0.0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1.0, 0.0)\\n\\n y_true = y_true.ravel()\\n y_pred = y_pred.ravel()\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\\n csi = tp / (tp + tn + tp)\\n self.log(\\\"csi\\\", csi)\\n\\n mae = metrics.mean_absolute_error(y_true, y_pred, sample_weight=y_true)\\n self.log(\\\"mae\\\", mae)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(self.optimizer)\\n\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = get_cosine_schedule_with_warmup(\\n self.optimizer,\\n num_warmup_steps=self.num_train_steps * args[\\\"warmup_epochs\\\"],\\n num_training_steps=self.num_train_steps * args[\\\"max_epochs\\\"],\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR(\\n self.optimizer, step_size=10, gamma=0.5\\n )\\n return [self.optimizer], [\\n {\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"epoch\\\"}\\n ]\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\\n self.optimizer, mode=\\\"min\\\", factor=0.1, patience=3, verbose=True\\n )\\n return [self.optimizer], [\\n {\\n \\\"scheduler\\\": self.scheduler,\\n \\\"interval\\\": \\\"epoch\\\",\\n \\\"reduce_on_plateau\\\": True,\\n \\\"monitor\\\": \\\"comp_metric\\\",\\n }\\n ]\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", + " var nbb_formatted_code = \"class Baseline(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128],\\n dec_chs=[512, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n self.criterion = nn.L1Loss()\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 1, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = ftrs[::-1]\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n\\n crop = T.CenterCrop(120)\\n y = crop(y)\\n y_hat = crop(y_hat)\\n\\n batch_size = len(y)\\n y = y.reshape(batch_size, -1)\\n y = y.detach().cpu().numpy()\\n y *= args[\\\"rng\\\"]\\n y_hat = y_hat.reshape(batch_size, -1)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat *= args[\\\"rng\\\"]\\n\\n y = y[:, args[\\\"dams\\\"]]\\n y_hat = y_hat[:, args[\\\"dams\\\"]]\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1.0, 0.0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1.0, 0.0)\\n\\n y_true = y_true.ravel()\\n y_pred = y_pred.ravel()\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\\n csi = tp / (tp + tn + tp)\\n self.log(\\\"csi\\\", csi)\\n\\n mae = metrics.mean_absolute_error(y_true, y_pred, sample_weight=y_true)\\n self.log(\\\"mae\\\", mae)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(self.optimizer)\\n\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = get_cosine_schedule_with_warmup(\\n self.optimizer,\\n num_warmup_steps=self.num_train_steps * args[\\\"warmup_epochs\\\"],\\n num_training_steps=self.num_train_steps * args[\\\"max_epochs\\\"],\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR(\\n self.optimizer, step_size=10, gamma=0.5\\n )\\n return [self.optimizer], [\\n {\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"epoch\\\"}\\n ]\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\\n self.optimizer, mode=\\\"min\\\", factor=0.1, patience=3, verbose=True\\n )\\n return [self.optimizer], [\\n {\\n \\\"scheduler\\\": self.scheduler,\\n \\\"interval\\\": \\\"epoch\\\",\\n \\\"reduce_on_plateau\\\": True,\\n \\\"monitor\\\": \\\"comp_metric\\\",\\n }\\n ]\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "class Baseline(pl.LightningModule):\n", - " def __init__(self, lr=1e-3, enc_chs=[4, 64, 128], dec_chs=[512, 128, 64]):\n", + " def __init__(\n", + " self,\n", + " lr=args[\"lr\"],\n", + " enc_chs=[4, 64, 128],\n", + " dec_chs=[512, 128, 64],\n", + " num_train_steps=None,\n", + " ):\n", " super().__init__()\n", " self.lr = lr\n", + " self.num_train_steps = num_train_steps\n", " self.criterion = nn.L1Loss()\n", " self.encoder = Encoder(enc_chs)\n", " self.decoder = Decoder(dec_chs)\n", @@ -278,25 +1081,106 @@ " x, y = batch\n", " y_hat = self(x)\n", " loss = self.criterion(y_hat, y)\n", - " return loss\n", + " return loss, y, y_hat\n", "\n", " def training_step(self, batch, batch_idx):\n", - " loss = self.shared_step(batch, batch_idx)\n", + " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", " self.log(\"train_loss\", loss)\n", + "\n", + " for i, param_group in enumerate(self.optimizer.param_groups):\n", + " self.log(f\"lr/lr{i}\", param_group[\"lr\"])\n", + "\n", " return {\"loss\": loss}\n", "\n", " def validation_step(self, batch, batch_idx):\n", - " loss = self.shared_step(batch, batch_idx)\n", - " self.log(\"val_loss\", loss)\n", - " return {\"loss\": loss}\n", + " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", + "\n", + " return {\"loss\": loss, \"y\": y.detach(), \"y_hat\": y_hat.detach()}\n", "\n", " def validation_epoch_end(self, outputs):\n", " avg_loss = torch.stack([x[\"loss\"] for x in outputs]).mean()\n", - " print(f\"Epoch {self.current_epoch} | MAE: {avg_loss}\")\n", + " self.log(\"val_loss\", avg_loss)\n", + "\n", + " y = torch.cat([x[\"y\"] for x in outputs])\n", + " y_hat = torch.cat([x[\"y_hat\"] for x in outputs])\n", + "\n", + " crop = T.CenterCrop(120)\n", + " y = crop(y)\n", + " y_hat = crop(y_hat)\n", + "\n", + " batch_size = len(y)\n", + " y = y.reshape(batch_size, -1)\n", + " y = y.detach().cpu().numpy()\n", + " y *= args[\"rng\"]\n", + " y_hat = y_hat.reshape(batch_size, -1)\n", + " y_hat = y_hat.detach().cpu().numpy()\n", + " y_hat *= args[\"rng\"]\n", + "\n", + " y = y[:, args[\"dams\"]]\n", + " y_hat = y_hat[:, args[\"dams\"]]\n", + "\n", + " y_true = radar2precipitation(y)\n", + " y_true = np.where(y_true >= 0.1, 1.0, 0.0)\n", + " y_pred = radar2precipitation(y_hat)\n", + " y_pred = np.where(y_pred >= 0.1, 1.0, 0.0)\n", + "\n", + " y_true = y_true.ravel()\n", + " y_pred = y_pred.ravel()\n", + "\n", + " tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\n", + " csi = tp / (tp + tn + tp)\n", + " self.log(\"csi\", csi)\n", + "\n", + " mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)\n", + " self.log(\"mae\", mae)\n", + "\n", + " comp_metric = mae / (csi + 1e-12)\n", + " self.log(\"comp_metric\", comp_metric)\n", + "\n", + " print(\n", + " f\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\"\n", + " )\n", "\n", " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", - " return optimizer" + " if args[\"optimizer\"] == \"adam\":\n", + " self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"adamw\":\n", + " self.optimizer = AdamW(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"radam\":\n", + " self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"ranger\":\n", + " self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", + " self.optimizer = optim.Lookahead(self.optimizer)\n", + "\n", + " if args[\"scheduler\"] == \"cosine\":\n", + " self.scheduler = get_cosine_schedule_with_warmup(\n", + " self.optimizer,\n", + " num_warmup_steps=self.num_train_steps * args[\"warmup_epochs\"],\n", + " num_training_steps=self.num_train_steps * args[\"max_epochs\"],\n", + " )\n", + " return [self.optimizer], [{\"scheduler\": self.scheduler, \"interval\": \"step\"}]\n", + " elif args[\"scheduler\"] == \"step\":\n", + " self.scheduler = torch.optim.lr_scheduler.StepLR(\n", + " self.optimizer, step_size=10, gamma=0.5\n", + " )\n", + " return [self.optimizer], [\n", + " {\"scheduler\": self.scheduler, \"interval\": \"epoch\"}\n", + " ]\n", + " elif args[\"scheduler\"] == \"plateau\":\n", + " self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " self.optimizer, mode=\"min\", factor=0.1, patience=3, verbose=True\n", + " )\n", + " return [self.optimizer], [\n", + " {\n", + " \"scheduler\": self.scheduler,\n", + " \"interval\": \"epoch\",\n", + " \"reduce_on_plateau\": True,\n", + " \"monitor\": \"comp_metric\",\n", + " }\n", + " ]\n", + " else:\n", + " self.scheduler = None\n", + " return [self.optimizer]" ] }, { @@ -308,142 +1192,13540 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "42" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 19;\n", + " var nbb_unformatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\";\n", + " var nbb_formatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "datamodule = NowcastingDataModule(batch_size=256)\n", - "datamodule.setup()" + "seed_everything(args[\"seed\"])\n", + "pl.seed_everything(args[\"seed\"])" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 20;\n", + " var nbb_unformatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\";\n", + " var nbb_formatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "model = Baseline()" + "df = pd.read_csv(args[\"train_folds_csv\"])" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 14;\n", + " var nbb_unformatted_code = \"def train_fold(df, fold):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(train_df, val_df)\\n datamodule.setup()\\n\\n num_train_steps = np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n model = Baseline(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n benchmark=True,\\n )\\n\\n print(f\\\"Training fold{fold}...\\\")\\n trainer.fit(model, datamodule)\\n \\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"baseline_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}.ckpt\\\"\\n )\\n trainer.save_checkpoint(checkpoint)\\n print(\\\"Model saved at\\\", checkpoint)\\n \\n del model, trainer, datamodule\\n gc.collect()\\n torch.cuda.empty_cache() \";\n", + " var nbb_formatted_code = \"def train_fold(df, fold):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(train_df, val_df)\\n datamodule.setup()\\n\\n num_train_steps = np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n model = Baseline(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n benchmark=True,\\n )\\n\\n print(f\\\"Training fold{fold}...\\\")\\n trainer.fit(model, datamodule)\\n\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"baseline_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}.ckpt\\\"\\n )\\n trainer.save_checkpoint(checkpoint)\\n print(\\\"Model saved at\\\", checkpoint)\\n\\n del model, trainer, datamodule\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "trainer = pl.Trainer(\n", - " gpus=1, max_epochs=10, precision=16, progress_bar_refresh_rate=50, benchmark=True\n", - ")" + "def train_fold(df, fold):\n", + " train_df = df[df.fold != fold]\n", + " val_df = df[df.fold == fold]\n", + "\n", + " datamodule = NowcastingDataModule(train_df, val_df)\n", + " datamodule.setup()\n", + "\n", + " num_train_steps = np.ceil(\n", + " len(train_df) // args[\"batch_size\"] / args[\"accumulate_grad_batches\"]\n", + " )\n", + " model = Baseline(num_train_steps=num_train_steps)\n", + "\n", + " trainer = pl.Trainer(\n", + " gpus=args[\"gpus\"],\n", + " max_epochs=args[\"max_epochs\"],\n", + " precision=args[\"precision\"],\n", + " progress_bar_refresh_rate=50,\n", + " benchmark=True,\n", + " )\n", + "\n", + " print(f\"Training fold {fold}...\")\n", + " trainer.fit(model, datamodule)\n", + " \n", + " checkpoint = (\n", + " args[\"model_dir\"]\n", + " / f\"baseline_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}.ckpt\"\n", + " )\n", + " trainer.save_checkpoint(checkpoint)\n", + " print(\"Model saved at\", checkpoint)\n", + " \n", + " del model, trainer, datamodule\n", + " gc.collect()\n", + " torch.cuda.empty_cache() " ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "aa45faf8cb774ecb9a2293ee66bca401", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "23c130d602014c6ebcf40ce6616a2a0c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 0.8953259817369302 | MAE: 0.10786948176583493 | CSI: 0.12048067850736674 | Loss: 0.014457848854362965\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.0508558724080708 | MAE: 0.12422264875239923 | CSI: 0.11821092883716593 | Loss: 0.012260455638170242\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 0.897048855974516 | MAE: 0.10771593090211132 | CSI: 0.12007810966348786 | Loss: 0.012970300391316414\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 0.9234593276009212 | MAE: 0.11047984644913628 | CSI: 0.1196369381369846 | Loss: 0.011573512107133865\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.0946424615395505 | MAE: 0.12859884836852206 | CSI: 0.1174802301991471 | Loss: 0.011782950721681118\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 0.9862818881438672 | MAE: 0.11715930902111324 | CSI: 0.11878886799859507 | Loss: 0.011379786767065525\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.2269688941892183 | MAE: 0.1421880998080614 | CSI: 0.11588565975895616 | Loss: 0.011886500753462315\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.0524613342678637 | MAE: 0.12406909788867562 | CSI: 0.1178847087754828 | Loss: 0.011312616057693958\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.8743689630123128 | MAE: 0.10502879078694817 | CSI: 0.12011953217579474 | Loss: 0.011175304651260376\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 1.0495967204363328 | MAE: 0.1237619961612284 | CSI: 0.11791385562707277 | Loss: 0.011026876978576183\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.9960218536200095 | MAE: 0.11808061420345489 | CSI: 0.11855223233636758 | Loss: 0.011095545254647732\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.8782308244192888 | MAE: 0.1054126679462572 | CSI: 0.1200284310392781 | Loss: 0.010937790386378765\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.8742338140149913 | MAE: 0.10502879078694817 | CSI: 0.12013810161805627 | Loss: 0.011085813865065575\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.7914715377970224 | MAE: 0.09589251439539348 | CSI: 0.12115724926951726 | Loss: 0.011124382726848125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.8122845658503062 | MAE: 0.0981957773512476 | CSI: 0.12088839487876168 | Loss: 0.010922062210738659\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.832724012562721 | MAE: 0.10042226487523992 | CSI: 0.12059489501852615 | Loss: 0.010781347751617432\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.9140697496979593 | MAE: 0.1092514395393474 | CSI: 0.11952199443700422 | Loss: 0.010753633454442024\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 1.0276737057476453 | MAE: 0.12138195777351247 | CSI: 0.11811332438848178 | Loss: 0.010888705030083656\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 0.90410702954464 | MAE: 0.10817658349328214 | CSI: 0.11965019622384969 | Loss: 0.010728368535637856\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.9914165414434577 | MAE: 0.11754318618042227 | CSI: 0.1185608483073051 | Loss: 0.010847923345863819\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.9764170742744235 | MAE: 0.11593090211132438 | CSI: 0.11873092468860844 | Loss: 0.01069142110645771\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.8930430185237244 | MAE: 0.10694817658349329 | CSI: 0.11975702666529393 | Loss: 0.01066147442907095\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.8887563310135709 | MAE: 0.10648752399232246 | CSI: 0.1198163324136227 | Loss: 0.010664239525794983\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.840922715282812 | MAE: 0.10126679462571977 | CSI: 0.12042342629646013 | Loss: 0.010658983141183853\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.8959247022302149 | MAE: 0.1072552783109405 | CSI: 0.11971461222472743 | Loss: 0.010633627884089947\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.8951603121976484 | MAE: 0.1071785028790787 | CSI: 0.11973107097996417 | Loss: 0.010618860833346844\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.9008450087222826 | MAE: 0.10779270633397313 | CSI: 0.11965732760839795 | Loss: 0.010613663122057915\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.8874501850055542 | MAE: 0.10633397312859885 | CSI: 0.11981965289358286 | Loss: 0.010609721764922142\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.8923343991380076 | MAE: 0.10687140115163148 | CSI: 0.11976608909616909 | Loss: 0.010615041479468346\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.8853361925398303 | MAE: 0.10610364683301343 | CSI: 0.11984559958826557 | Loss: 0.01060927752405405\n", + "\n", + "Model saved at ../models/baseline_bs128_epochs30_lr0.0005_adamw.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 14;\n", + " var nbb_unformatted_code = \"fold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_formatted_code = \"fold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# lr_finder = trainer.tuner.lr_find(model, datamodule)\n", - "# fig = lr_finder.plot(suggest=True)" + "# AdamW\n", + "fold = 0\n", + "train_fold(df, fold)" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 13, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9b53f411c5f84c46b3d2c4d0aa52a36c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "37dbe801d97d477e9d8782bb68e1bf81", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/optim/radam.py:49: UserWarning: This overload of addcmul_ is deprecated:\n", + "\taddcmul_(Number value, Tensor tensor1, Tensor tensor2)\n", + "Consider using one of the following signatures instead:\n", + "\taddcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at /opt/conda/conda-bld/pytorch_1603729047590/work/torch/csrc/utils/python_arg_parser.cpp:882.)\n", + " exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1.1287884420798677 | MAE: 0.1327447216890595 | CSI: 0.11759929207225027 | Loss: 0.016403084620833397\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.2616119527130563 | MAE: 0.14610364683301344 | CSI: 0.11580712002415686 | Loss: 0.013337209820747375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 0.8670467608914513 | MAE: 0.10456813819577736 | CSI: 0.12060265133497404 | Loss: 0.012833792716264725\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 1.0554725725706757 | MAE: 0.12460652591170826 | CSI: 0.11805756885483537 | Loss: 0.012007031589746475\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.3001223097119974 | MAE: 0.14963531669865643 | CSI: 0.11509326128747337 | Loss: 0.012559030205011368\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 0.9668908031778127 | MAE: 0.11516314779270634 | CSI: 0.11910667410760423 | Loss: 0.011626843363046646\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.142794785587942 | MAE: 0.13358925143953934 | CSI: 0.11689697321262094 | Loss: 0.011799349449574947\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.1083582171365651 | MAE: 0.12998080614203456 | CSI: 0.11727328234794937 | Loss: 0.011587237007915974\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.8388044949660333 | MAE: 0.10119001919385796 | CSI: 0.12063600016487366 | Loss: 0.011390800587832928\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 0.9891935240347431 | MAE: 0.11738963531669866 | CSI: 0.11867206210256834 | Loss: 0.011127580888569355\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.9411532947274672 | MAE: 0.11224568138195777 | CSI: 0.11926397326539663 | Loss: 0.01120688859373331\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.8536757691390952 | MAE: 0.10272552783109405 | CSI: 0.12033318918473672 | Loss: 0.011052107438445091\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9470463135836465 | MAE: 0.1128598848368522 | CSI: 0.11917039665023411 | Loss: 0.011067020706832409\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.8655408572526657 | MAE: 0.10403071017274472 | CSI: 0.12019156496215047 | Loss: 0.011069140397012234\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.8477291753956104 | MAE: 0.10211132437619962 | CSI: 0.12045276644831707 | Loss: 0.011056198738515377\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.8269470306746735 | MAE: 0.09980806142034548 | CSI: 0.12069462458567518 | Loss: 0.010897000320255756\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.9089498997078406 | MAE: 0.10871401151631478 | CSI: 0.1196039644763141 | Loss: 0.010795501060783863\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 1.0144721688973737 | MAE: 0.12 | CSI: 0.11828811442842548 | Loss: 0.010921232402324677\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 0.9739822739190495 | MAE: 0.11570057581573896 | CSI: 0.11879125412541254 | Loss: 0.010769343003630638\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 1.012934577465377 | MAE: 0.11984644913627639 | CSI: 0.11831608062501935 | Loss: 0.010915917344391346\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.9746930990857315 | MAE: 0.11577735124760077 | CSI: 0.1187833907465088 | Loss: 0.010718860663473606\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.8865415622549162 | MAE: 0.10625719769673704 | CSI: 0.11985585585585586 | Loss: 0.010684728622436523\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.9431093784188942 | MAE: 0.11239923224568138 | CSI: 0.11917942374104427 | Loss: 0.010712772607803345\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.8567591311220998 | MAE: 0.10303262955854127 | CSI: 0.12025857188442496 | Loss: 0.01069499459117651\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.9041815313239039 | MAE: 0.10817658349328214 | CSI: 0.11964033741541441 | Loss: 0.01064450852572918\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.9142298319658559 | MAE: 0.1092514395393474 | CSI: 0.11950106605415761 | Loss: 0.010635210201144218\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.9199470836152046 | MAE: 0.10986564299424184 | CSI: 0.11942604629124133 | Loss: 0.01062960084527731\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.8950865803184656 | MAE: 0.1071785028790787 | CSI: 0.11974093370950802 | Loss: 0.0106242336332798\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.8978677068424368 | MAE: 0.10748560460652591 | CSI: 0.11971207315566174 | Loss: 0.010626324452459812\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.8928407437743435 | MAE: 0.10694817658349329 | CSI: 0.11978415784487374 | Loss: 0.010623933747410774\n", + "\n", + "Model saved at ../models/baseline_bs128_epochs30_lr0.0005_radam.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:41: RuntimeWarning: overflow encountered in multiply\n", + " z *= np.power(10.0, dbz_max / 10.0)\n", + "/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.\n", + " warnings.warn(SAVE_STATE_WARNING, UserWarning)\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 13;\n", + " var nbb_unformatted_code = \"# RAdam\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_formatted_code = \"# RAdam\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# model.lr = lr_finder.suggestion()\n", - "# model.lr" + "# RAdam\n", + "fold = 0\n", + "train_fold(df, fold)" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a9f80a7e3c2947fa902f13fa98a58573", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d103dec33edb44fa8ad2ff9273414040", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1.1385640966037704 | MAE: 0.13389635316698656 | CSI: 0.11760106748952318 | Loss: 0.01573796570301056\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.085314053724478 | MAE: 0.12806142034548945 | CSI: 0.1179948051948052 | Loss: 0.013925119303166866\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 0.9167696697228916 | MAE: 0.11001919385796545 | CSI: 0.12000745387912293 | Loss: 0.013134698383510113\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 1.0299586717682772 | MAE: 0.12199616122840691 | CSI: 0.11844762762949383 | Loss: 0.012236752547323704\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.013925980609321 | MAE: 0.12023032629558542 | CSI: 0.11857899747506105 | Loss: 0.012843919917941093\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 1.011833885875258 | MAE: 0.12 | CSI: 0.11859654205510776 | Loss: 0.011795842088758945\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.0590260203151582 | MAE: 0.12491362763915546 | CSI: 0.11795142446162284 | Loss: 0.011630662716925144\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 0.9600188360970295 | MAE: 0.1143953934740883 | CSI: 0.1191595301798498 | Loss: 0.011449274607002735\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.9020290172989984 | MAE: 0.10809980806142035 | CSI: 0.11984072129321105 | Loss: 0.011327717453241348\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 0.9770176426018803 | MAE: 0.11616122840690979 | CSI: 0.11889368558031933 | Loss: 0.011341550387442112\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.9844928312476287 | MAE: 0.11692898272552783 | CSI: 0.11877078127258835 | Loss: 0.011181168258190155\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.8644004373845067 | MAE: 0.10395393474088292 | CSI: 0.12026131668160787 | Loss: 0.011257469654083252\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.8851591085350253 | MAE: 0.10618042226487524 | CSI: 0.119956312080122 | Loss: 0.01108333095908165\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 1.013253430079364 | MAE: 0.1199232245681382 | CSI: 0.11835461988787131 | Loss: 0.011116106063127518\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 1.0395499227504847 | MAE: 0.12268714011516314 | CSI: 0.11801947884283692 | Loss: 0.01116656232625246\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.8960707942438373 | MAE: 0.1073320537428023 | CSI: 0.11978077450061296 | Loss: 0.010934371501207352\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.9623145456213387 | MAE: 0.11447216890595009 | CSI: 0.11895504378048906 | Loss: 0.011090419255197048\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 0.896153875733789 | MAE: 0.1073320537428023 | CSI: 0.11976966975009785 | Loss: 0.010892595164477825\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 0.9479289803918486 | MAE: 0.11293666026871402 | CSI: 0.11914042360122915 | Loss: 0.010916203260421753\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.8959230938172563 | MAE: 0.1073320537428023 | CSI: 0.1198005213646152 | Loss: 0.010992695577442646\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.9846853003130904 | MAE: 0.11685220729366602 | CSI: 0.1186695965254351 | Loss: 0.010856914333999157\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.9188840579993098 | MAE: 0.10978886756238004 | CSI: 0.1194806533051677 | Loss: 0.01079531665891409\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.9567146873342918 | MAE: 0.11385796545105566 | CSI: 0.11900932112513404 | Loss: 0.010844394564628601\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.9026124925539706 | MAE: 0.10802303262955854 | CSI: 0.11967819359889573 | Loss: 0.010828354395925999\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.9244720777907447 | MAE: 0.11040307101727448 | CSI: 0.11942282916774027 | Loss: 0.010767893865704536\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.9259653880900643 | MAE: 0.11055662188099807 | CSI: 0.11939606307327631 | Loss: 0.010765370912849903\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.9082203937558084 | MAE: 0.10863723608445297 | CSI: 0.11961549953122264 | Loss: 0.010760427452623844\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.9217818510578139 | MAE: 0.11009596928982726 | CSI: 0.11943820456278466 | Loss: 0.010757229290902615\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.9232172476669039 | MAE: 0.11024952015355087 | CSI: 0.11941882631768767 | Loss: 0.010758249089121819\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.9124960947270578 | MAE: 0.1090978886756238 | CSI: 0.11955984174085064 | Loss: 0.010757893323898315\n", + "\n", + "Model saved at ../models/baseline_bs128_epochs30_lr0.0005_ranger.ckpt\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 14;\n", + " var nbb_unformatted_code = \"# Ranger\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_formatted_code = \"# Ranger\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "trainer.fit(model, datamodule)" + "# Ranger\n", + "fold = 0\n", + "train_fold(df, fold)" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 16, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2ce17f85e18b43779829a708e2e3a885", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "aa6a7ddfe60a43c19823313c21badb46", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1.431734138641048 | MAE: 0.16314779270633398 | CSI: 0.11395117871517432 | Loss: 0.014841477386653423\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.0666386134475698 | MAE: 0.12591170825335893 | CSI: 0.11804533106608879 | Loss: 0.01233157329261303\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 0.8780003852209621 | MAE: 0.1056429942418426 | CSI: 0.12032226411196612 | Loss: 0.012936452403664589\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 0.9439850315398074 | MAE: 0.1127063339731286 | CSI: 0.11939419610111883 | Loss: 0.011566683650016785\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.106750432758379 | MAE: 0.12982725527831093 | CSI: 0.11730490581660112 | Loss: 0.01176687702536583\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 1.0036502619730607 | MAE: 0.11900191938579655 | CSI: 0.11856911106748365 | Loss: 0.011306156404316425\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.2770520777797807 | MAE: 0.1472552783109405 | CSI: 0.11530874963664299 | Loss: 0.011934855952858925\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.073235025239742 | MAE: 0.12629558541266794 | CSI: 0.11767747272633267 | Loss: 0.011313681490719318\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.8434373491310988 | MAE: 0.10165067178502879 | CSI: 0.12051952867501647 | Loss: 0.011227131821215153\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 1.0523852115853143 | MAE: 0.12406909788867562 | CSI: 0.11789323578647157 | Loss: 0.011084862984716892\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.9726492106175404 | MAE: 0.11562380038387716 | CSI: 0.118875128998968 | Loss: 0.011067209765315056\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.8857018008780639 | MAE: 0.10625719769673704 | CSI: 0.11996949491410139 | Loss: 0.010920040309429169\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9229223356706401 | MAE: 0.11024952015355087 | CSI: 0.1194569855897087 | Loss: 0.010936064645648003\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.942111649744198 | MAE: 0.11232245681381958 | CSI: 0.11922414593150954 | Loss: 0.010872152633965015\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.9676800877047635 | MAE: 0.11508637236084453 | CSI: 0.1189301855253111 | Loss: 0.010937022045254707\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.8352548220955233 | MAE: 0.10072936660268714 | CSI: 0.1205971685971686 | Loss: 0.010871059261262417\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.9581314429908468 | MAE: 0.11401151631477927 | CSI: 0.11899360692926376 | Loss: 0.0107908695936203\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 0.8868929365209549 | MAE: 0.10633397312859885 | CSI: 0.11989493742596694 | Loss: 0.010807439684867859\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 1.0813794255930738 | MAE: 0.127063339731286 | CSI: 0.11750116261044799 | Loss: 0.01099312026053667\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.8605000659886831 | MAE: 0.1034932821497121 | CSI: 0.1202710914727724 | Loss: 0.010955577716231346\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.9247674515027908 | MAE: 0.11040307101727448 | CSI: 0.11938468513023409 | Loss: 0.010693227872252464\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.7480731419030813 | MAE: 0.09105566218809981 | CSI: 0.12172026649119921 | Loss: 0.011536278761923313\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.9189219325547425 | MAE: 0.10978886756238004 | CSI: 0.11947572875557708 | Loss: 0.010655155405402184\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.8743660610106783 | MAE: 0.10495201535508637 | CSI: 0.12003212388287138 | Loss: 0.010694632306694984\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.8596956033848641 | MAE: 0.10333973128598849 | CSI: 0.12020502475323953 | Loss: 0.010617670603096485\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.9320987589153724 | MAE: 0.11117082533589251 | CSI: 0.11926936311375765 | Loss: 0.010588045231997967\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.8901725335752227 | MAE: 0.10664107485604607 | CSI: 0.11979820858643056 | Loss: 0.01059663761407137\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.9442341384933582 | MAE: 0.11247600767754319 | CSI: 0.11911876841910023 | Loss: 0.010609036311507225\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.8273726514056701 | MAE: 0.09980806142034548 | CSI: 0.12063253631836701 | Loss: 0.010630769655108452\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.9478179216236259 | MAE: 0.1128598848368522 | CSI: 0.11907338135427294 | Loss: 0.01066933386027813\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.8845922772176007 | MAE: 0.10602687140115163 | CSI: 0.11985959422318754 | Loss: 0.010545131750404835\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.8893270175124783 | MAE: 0.10656429942418426 | CSI: 0.11982577536142346 | Loss: 0.010544568300247192\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.8942856105404078 | MAE: 0.1071017274472169 | CSI: 0.11976232892934743 | Loss: 0.010560829192399979\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 0.8121483608195074 | MAE: 0.0981190019193858 | CSI: 0.12081413526411058 | Loss: 0.010571216233074665\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.9277556442673833 | MAE: 0.11071017274472168 | CSI: 0.11933117672511487 | Loss: 0.010578982532024384\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.9421531481549642 | MAE: 0.11224568138195777 | CSI: 0.11913740520936367 | Loss: 0.010580139234662056\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.8190910415457837 | MAE: 0.09888675623800385 | CSI: 0.12072742982338844 | Loss: 0.01055662240833044\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.9201081769333753 | MAE: 0.10986564299424184 | CSI: 0.11940513707800367 | Loss: 0.010542426258325577\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.9013507058044742 | MAE: 0.10786948176583493 | CSI: 0.11967537282689297 | Loss: 0.010510878637433052\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 0.8654870949149819 | MAE: 0.10395393474088292 | CSI: 0.12011032325124268 | Loss: 0.01051416713744402\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.9463158106190775 | MAE: 0.1127063339731286 | CSI: 0.11910012778762522 | Loss: 0.010519065894186497\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.8654692808571053 | MAE: 0.10395393474088292 | CSI: 0.12011279549641339 | Loss: 0.010517258197069168\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.9022928392247763 | MAE: 0.10794625719769674 | CSI: 0.11963550247116969 | Loss: 0.010511299595236778\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.8773881200914918 | MAE: 0.10525911708253359 | CSI: 0.11996870560622594 | Loss: 0.010500743053853512\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.8986881268806691 | MAE: 0.10756238003838772 | CSI: 0.11968821754754476 | Loss: 0.010506429709494114\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.9135933108757834 | MAE: 0.10917466410748561 | CSI: 0.11950028837439236 | Loss: 0.010512260720133781\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.8986140990898661 | MAE: 0.10756238003838772 | CSI: 0.11969807745775453 | Loss: 0.010502759367227554\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.8943869105405711 | MAE: 0.1071017274472169 | CSI: 0.11974876441515651 | Loss: 0.010501205921173096\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 0.8986696199329683 | MAE: 0.10756238003838772 | CSI: 0.11969068237280805 | Loss: 0.010503779165446758\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 0.8781667057234775 | MAE: 0.10533589251439539 | CSI: 0.11994976788232509 | Loss: 0.010497772134840488\n", + "\n", + "Model saved at ../models/baseline_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 16;\n", + " var nbb_unformatted_code = \"# AdamW bs50\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_formatted_code = \"# AdamW bs50\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "trainer.save_checkpoint(\"baseline_bs256_epoch10.ckpt\")" + "# AdamW bs50\n", + "fold = 0\n", + "train_fold(df, fold)" ] }, { - "cell_type": "markdown", - "metadata": {}, + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bdfb64c4e871464d997eee6447b0018d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "19c5c4bdb2124ca4bcc59d75ed888ccf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1.1213412779907173 | MAE: 0.13205374280230325 | CSI: 0.1177640967946915 | Loss: 0.015904391184449196\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.0232372097056714 | MAE: 0.12153550863723608 | CSI: 0.11877549749307091 | Loss: 0.013309536501765251\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 1.0139998351280197 | MAE: 0.12046065259117082 | CSI: 0.11879750707745274 | Loss: 0.012951242737472057\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 1.1047516526727235 | MAE: 0.12990403071017276 | CSI: 0.11758663623158332 | Loss: 0.012385123409330845\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.117690074479474 | MAE: 0.13120921305182343 | CSI: 0.11739319874680997 | Loss: 0.012682433240115643\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 0.9326559243791501 | MAE: 0.11155470249520154 | CSI: 0.1196097076942159 | Loss: 0.011910993605852127\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.027332020181361 | MAE: 0.1216122840690979 | CSI: 0.11837680679572474 | Loss: 0.011943171732127666\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.0422545127997724 | MAE: 0.12314779270633397 | CSI: 0.11815520220150837 | Loss: 0.011781273409724236\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.976940484692145 | MAE: 0.11623800383877159 | CSI: 0.11898166332458189 | Loss: 0.011739981360733509\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 0.9288001062140463 | MAE: 0.1110172744721689 | CSI: 0.11952762895750106 | Loss: 0.011434326879680157\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.994726636645222 | MAE: 0.11808061420345489 | CSI: 0.11870659722222222 | Loss: 0.011498453095555305\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.86822094408707 | MAE: 0.10441458733205375 | CSI: 0.12026269124499979 | Loss: 0.011420509777963161\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9243556771232797 | MAE: 0.11047984644913628 | CSI: 0.11952092596222288 | Loss: 0.011264418251812458\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.8254324608893246 | MAE: 0.09973128598848369 | CSI: 0.12082307240523024 | Loss: 0.01140605192631483\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.9816671665505846 | MAE: 0.11662188099808062 | CSI: 0.11879981827943667 | Loss: 0.011349151842296124\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.9067617427372024 | MAE: 0.10856046065259117 | CSI: 0.1197232476129591 | Loss: 0.011154781095683575\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.9926448922389418 | MAE: 0.11777351247600767 | CSI: 0.11864616782480304 | Loss: 0.011162293143570423\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 0.9942763377520618 | MAE: 0.11792706333973128 | CSI: 0.11860592358594774 | Loss: 0.011070771142840385\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 1.0252305224848342 | MAE: 0.12122840690978887 | CSI: 0.11824502319238835 | Loss: 0.011332811787724495\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.8948916276541378 | MAE: 0.1072552783109405 | CSI: 0.11985281233572806 | Loss: 0.011093420907855034\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 1.013906822535492 | MAE: 0.12 | CSI: 0.11835407093809695 | Loss: 0.01113690622150898\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.8771994920876025 | MAE: 0.10533589251439539 | CSI: 0.1200820263391109 | Loss: 0.011256896890699863\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.9377515842154777 | MAE: 0.11186180422264876 | CSI: 0.11928724632898861 | Loss: 0.010979496873915195\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.8964470593831949 | MAE: 0.10740882917466411 | CSI: 0.11981614312804023 | Loss: 0.010982400737702847\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.9242243449999964 | MAE: 0.11040307101727448 | CSI: 0.11945483974061588 | Loss: 0.0109475776553154\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.9463299270633719 | MAE: 0.11278310940499041 | CSI: 0.11917948083289502 | Loss: 0.010897496715188026\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.9341393703215161 | MAE: 0.11147792706333973 | CSI: 0.11933757488889118 | Loss: 0.010901767760515213\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.9795402145427297 | MAE: 0.1163147792706334 | CSI: 0.11874426138180767 | Loss: 0.010949915274977684\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.8924269999687918 | MAE: 0.10694817658349329 | CSI: 0.11983969174659757 | Loss: 0.010875929147005081\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 1.0142418426017923 | MAE: 0.12 | CSI: 0.1183149767230612 | Loss: 0.010970874689519405\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.9095483358412575 | MAE: 0.10879078694817658 | CSI: 0.11960968170717885 | Loss: 0.010860883630812168\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.9272108389625103 | MAE: 0.11071017274472168 | CSI: 0.1194012926635673 | Loss: 0.010840944945812225\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.9187514970552954 | MAE: 0.10978886756238004 | CSI: 0.11949789242613186 | Loss: 0.010872705839574337\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 0.8528154604418748 | MAE: 0.10264875239923224 | CSI: 0.12036455383347923 | Loss: 0.01089971512556076\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.9545016953524549 | MAE: 0.11362763915547025 | CSI: 0.11904393644115857 | Loss: 0.010870699770748615\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.9730808299782855 | MAE: 0.11562380038387716 | CSI: 0.1188224007922181 | Loss: 0.010838640853762627\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.8882988317697552 | MAE: 0.10648752399232246 | CSI: 0.11987804124348489 | Loss: 0.010831070132553577\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.9609859280455686 | MAE: 0.11431861804222648 | CSI: 0.11895972116237342 | Loss: 0.010835222899913788\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.9067524560855617 | MAE: 0.10848368522072936 | CSI: 0.11963980300438913 | Loss: 0.01081210095435381\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 0.9167002517731152 | MAE: 0.10955854126679462 | CSI: 0.1195140298630504 | Loss: 0.010801968164741993\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.9488315770034639 | MAE: 0.11301343570057582 | CSI: 0.11910800445379191 | Loss: 0.010811405256390572\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.9066310051806491 | MAE: 0.10848368522072936 | CSI: 0.11965582976969447 | Loss: 0.010802964679896832\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.9316570332554117 | MAE: 0.11117082533589251 | CSI: 0.11932591218305504 | Loss: 0.010797062888741493\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.8995883642404501 | MAE: 0.10771593090211132 | CSI: 0.11973913312246938 | Loss: 0.010799948126077652\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.9295530052067085 | MAE: 0.1109404990403071 | CSI: 0.11934822265967206 | Loss: 0.010795538313686848\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.9423350900232739 | MAE: 0.11232245681381958 | CSI: 0.1191958762886598 | Loss: 0.010799865238368511\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.933066581537057 | MAE: 0.11132437619961612 | CSI: 0.11931021687144389 | Loss: 0.010794575326144695\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.9401394272301328 | MAE: 0.11209213051823416 | CSI: 0.11922926245901977 | Loss: 0.010793674737215042\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 0.9387956837220837 | MAE: 0.11193857965451055 | CSI: 0.11923635951303488 | Loss: 0.01079593040049076\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 0.9216488734398317 | MAE: 0.11009596928982726 | CSI: 0.1194554373814824 | Loss: 0.010792912915349007\n", + "\n", + "Model saved at ../models/baseline_bs128_epochs50_lr0.0001_adamw.ckpt\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 14;\n", + " var nbb_unformatted_code = \"# AdamW bs50 lr_1e-4\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_formatted_code = \"# AdamW bs50 lr_1e-4\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "## Inference" + "# AdamW bs50 lr 1e-4\n", + "fold = 0\n", + "train_fold(df, fold)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e27c63ed9b4f4564bf0c0084b24a2947", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bd9efb2044db4f09a98589959c2da2c2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1.138320426835042 | MAE: 0.13381957773512476 | CSI: 0.11755879502756099 | Loss: 0.01589236781001091\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.002306027157473 | MAE: 0.11930902111324376 | CSI: 0.11903452426660578 | Loss: 0.013353673741221428\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 1.0154380687146456 | MAE: 0.12061420345489443 | CSI: 0.11878046251166649 | Loss: 0.01299264281988144\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 1.114900060171623 | MAE: 0.130978886756238 | CSI: 0.11748038360941587 | Loss: 0.012409305199980736\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.1488281191112761 | MAE: 0.1344337811900192 | CSI: 0.11701818483766503 | Loss: 0.01272483542561531\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 0.9354791495688765 | MAE: 0.11186180422264876 | CSI: 0.11957701491611623 | Loss: 0.011894077993929386\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.0315937112542246 | MAE: 0.12207293666026871 | CSI: 0.11833431643434437 | Loss: 0.011928667314350605\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.007229160999509 | MAE: 0.11946257197696737 | CSI: 0.11860515620637235 | Loss: 0.011697919107973576\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.9667908785376472 | MAE: 0.11516314779270634 | CSI: 0.11911898462047296 | Loss: 0.011760309338569641\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 0.9315854624066506 | MAE: 0.11132437619961612 | CSI: 0.11949990708430551 | Loss: 0.011442412622272968\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.9982712355163732 | MAE: 0.11846449136276392 | CSI: 0.11866964322625986 | Loss: 0.011561795137822628\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.8730246773867592 | MAE: 0.10495201535508637 | CSI: 0.12021655065738593 | Loss: 0.011409180238842964\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9235669002860872 | MAE: 0.11040307101727448 | CSI: 0.11953987413597442 | Loss: 0.011275683529675007\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.81997550304274 | MAE: 0.09911708253358925 | CSI: 0.12087810204691314 | Loss: 0.011408326216042042\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.9859672869815994 | MAE: 0.11708253358925144 | CSI: 0.11874890286339745 | Loss: 0.011363031342625618\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.9031850700262466 | MAE: 0.10817658349328214 | CSI: 0.11977233358079684 | Loss: 0.011150655336678028\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.9805180073263742 | MAE: 0.116468330134357 | CSI: 0.11878244893324938 | Loss: 0.011165974661707878\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 0.9950718900787031 | MAE: 0.1180038387715931 | CSI: 0.11858825472525884 | Loss: 0.011075479909777641\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 1.0223605930731856 | MAE: 0.12092130518234165 | CSI: 0.11827657090912848 | Loss: 0.011321539990603924\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.9021475974230198 | MAE: 0.10802303262955854 | CSI: 0.11973986622280396 | Loss: 0.011085733771324158\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 1.0066560658663206 | MAE: 0.11923224568138195 | CSI: 0.11844387544395804 | Loss: 0.011128334328532219\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.8885852706487493 | MAE: 0.10656429942418426 | CSI: 0.11992579996908333 | Loss: 0.011248618364334106\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.9406901157859805 | MAE: 0.11216890595009597 | CSI: 0.11924108063518252 | Loss: 0.010979549959301949\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.8930616191910925 | MAE: 0.10702495201535508 | CSI: 0.11984050116430028 | Loss: 0.010991481132805347\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.9242719859212941 | MAE: 0.11040307101727448 | CSI: 0.11944868252855552 | Loss: 0.010939003899693489\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.9343801849915703 | MAE: 0.11147792706333973 | CSI: 0.11930681841611514 | Loss: 0.010901150293648243\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.9321022845520779 | MAE: 0.11124760076775432 | CSI: 0.11935128001567155 | Loss: 0.010904178023338318\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.9866086956438601 | MAE: 0.11708253358925144 | CSI: 0.11867170247456298 | Loss: 0.01094411313533783\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.9017261190518742 | MAE: 0.10794625719769674 | CSI: 0.11971069143510649 | Loss: 0.01087891310453415\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.9998807513369863 | MAE: 0.11846449136276392 | CSI: 0.11847861977876836 | Loss: 0.010963994078338146\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.9089405350759588 | MAE: 0.10871401151631478 | CSI: 0.11960519673195207 | Loss: 0.010861601680517197\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.9265165003474861 | MAE: 0.11063339731285989 | CSI: 0.11940790830000413 | Loss: 0.0108433673158288\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.924424436869447 | MAE: 0.11040307101727448 | CSI: 0.119428983714698 | Loss: 0.010874624364078045\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 0.8675104462257137 | MAE: 0.10426103646833014 | CSI: 0.12018418558655074 | Loss: 0.010894048027694225\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.9538233214104781 | MAE: 0.11355086372360844 | CSI: 0.11904811003650011 | Loss: 0.010875188745558262\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.9766951757220987 | MAE: 0.11600767754318618 | CSI: 0.11877572494042646 | Loss: 0.010839647613465786\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.8868015843103648 | MAE: 0.10633397312859885 | CSI: 0.11990728817924286 | Loss: 0.010835183784365654\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.9737914151813813 | MAE: 0.11570057581573896 | CSI: 0.11881453667694783 | Loss: 0.01084214635193348\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.9109421263596105 | MAE: 0.10894433781190019 | CSI: 0.1195952351510655 | Loss: 0.010812942869961262\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 0.9110078346453826 | MAE: 0.10894433781190019 | CSI: 0.11958660910243069 | Loss: 0.010806724429130554\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.9516903840669387 | MAE: 0.11332053742802303 | CSI: 0.11907290367147468 | Loss: 0.010813283734023571\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.9131134429218587 | MAE: 0.10917466410748561 | CSI: 0.11956308928847442 | Loss: 0.010803856886923313\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.9324002005527466 | MAE: 0.11124760076775432 | CSI: 0.11931314547216096 | Loss: 0.010799448005855083\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.9046681478553775 | MAE: 0.10825335892514396 | CSI: 0.11966084931902005 | Loss: 0.010800926014780998\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.9323809801656068 | MAE: 0.11124760076775432 | CSI: 0.11931560502989075 | Loss: 0.010798103176057339\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.9466129584040338 | MAE: 0.11278310940499041 | CSI: 0.1191438469152095 | Loss: 0.010802337899804115\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.9281252628678694 | MAE: 0.1107869481765835 | CSI: 0.11936637500128827 | Loss: 0.010795307345688343\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.9366597046688451 | MAE: 0.11170825335892515 | CSI: 0.11926236689928153 | Loss: 0.010794851928949356\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 0.9374134629679352 | MAE: 0.11178502879078694 | CSI: 0.11924837140265523 | Loss: 0.010798281989991665\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 0.921677368643685 | MAE: 0.11009596928982726 | CSI: 0.11945174421600453 | Loss: 0.010794124566018581\n", + "\n", + "Model saved at ../models/baseline_bs128_epochs50_lr0.0001_adamw.ckpt\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 16;\n", + " var nbb_unformatted_code = \"# AdamW bs50 lr 1e-3\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_formatted_code = \"# AdamW bs50 lr 1e-3\\nfold = 0\\ntrain_fold(df, fold)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# AdamW bs50 lr 1e-3\n", + "fold = 0\n", + "train_fold(df, fold)" ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "98f00f38408b4783ae83e6f95bc1d8d4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034783065319061\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0201195004ce4af48dcccf4728b1f7c1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 0.9381335799455521 | MAE: 0.11247600767754319 | CSI: 0.11989338200976986 | Loss: 0.013658225536346436\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.091616736304388 | MAE: 0.12852207293666026 | CSI: 0.11773552810363963 | Loss: 0.012366656213998795\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 0.8717283066361261 | MAE: 0.10495201535508637 | CSI: 0.12039532794249776 | Loss: 0.012938014231622219\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 0.922680779149949 | MAE: 0.11040307101727448 | CSI: 0.11965467744766979 | Loss: 0.011683200486004353\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.1422260861540856 | MAE: 0.13351247600767754 | CSI: 0.11688795906953622 | Loss: 0.011974487453699112\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 1.0207991020989804 | MAE: 0.12084452975047985 | CSI: 0.11838228452687405 | Loss: 0.011398336850106716\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.3015827350484537 | MAE: 0.14971209213051823 | CSI: 0.11502310848003323 | Loss: 0.012028665281832218\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.0668947256657269 | MAE: 0.1256046065259117 | CSI: 0.11772914749997415 | Loss: 0.011318517848849297\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.8600657788579773 | MAE: 0.1034932821497121 | CSI: 0.1203318219291014 | Loss: 0.011268786154687405\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 1.1294213724123174 | MAE: 0.13213051823416508 | CSI: 0.11698956780923994 | Loss: 0.011137530207633972\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.9699192152207351 | MAE: 0.11531669865642995 | CSI: 0.11889309629690772 | Loss: 0.011074976995587349\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.8592266415081566 | MAE: 0.10333973128598849 | CSI: 0.1202706320927646 | Loss: 0.01094895415008068\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9156201995627933 | MAE: 0.10948176583493283 | CSI: 0.11957115612597288 | Loss: 0.010989376343786716\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.9186473420278555 | MAE: 0.10978886756238004 | CSI: 0.11951144094001237 | Loss: 0.010939965955913067\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.9428760437961236 | MAE: 0.11239923224568138 | CSI: 0.11920891721058764 | Loss: 0.011009103618562222\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.8349164229815289 | MAE: 0.10065259117082534 | CSI: 0.1205540918821011 | Loss: 0.010793941095471382\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.949548565503695 | MAE: 0.11309021113243763 | CSI: 0.11909892262487758 | Loss: 0.010786321945488453\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 0.8586650737446018 | MAE: 0.10326295585412668 | CSI: 0.12025987665125665 | Loss: 0.010778253898024559\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 1.0735752097766498 | MAE: 0.12621880998080615 | CSI: 0.11756867039244652 | Loss: 0.011026301421225071\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.8782398711638773 | MAE: 0.1054126679462572 | CSI: 0.12002719462700097 | Loss: 0.010951431468129158\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.961744436310654 | MAE: 0.1143953934740883 | CSI: 0.11894572939975458 | Loss: 0.010739585384726524\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.7626325761511776 | MAE: 0.0926679462571977 | CSI: 0.12151060570230005 | Loss: 0.011532123200595379\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.8935769142914931 | MAE: 0.10702495201535508 | CSI: 0.11977139326536917 | Loss: 0.010652078315615654\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.8583507216152066 | MAE: 0.10318618042226488 | CSI: 0.12021447390087271 | Loss: 0.010725504718720913\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.8592220571537372 | MAE: 0.10326295585412668 | CSI: 0.12018191920647006 | Loss: 0.010612650774419308\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.910653194743647 | MAE: 0.10886756238003839 | CSI: 0.11954887218045113 | Loss: 0.010592028498649597\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.8931625445119947 | MAE: 0.10694817658349329 | CSI: 0.1197410003808818 | Loss: 0.01058614905923605\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.9508534944090756 | MAE: 0.11316698656429942 | CSI: 0.11901621777567127 | Loss: 0.01062245573848486\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.8332982458772665 | MAE: 0.10042226487523992 | CSI: 0.1205117919919364 | Loss: 0.010614056140184402\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.9256429876075178 | MAE: 0.11047984644913628 | CSI: 0.1193547057853964 | Loss: 0.010645455680787563\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.8837848408830448 | MAE: 0.10595009596928982 | CSI: 0.11988222819317047 | Loss: 0.010543121956288815\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.8787736290945952 | MAE: 0.1054126679462572 | CSI: 0.11995429136168505 | Loss: 0.010546290315687656\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.8894002764619824 | MAE: 0.10656429942418426 | CSI: 0.11981590544046786 | Loss: 0.010572988539934158\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 0.8058613562069132 | MAE: 0.09742802303262955 | CSI: 0.12089923692383636 | Loss: 0.010588102042675018\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.9003947044562759 | MAE: 0.10771593090211132 | CSI: 0.1196319018404908 | Loss: 0.010574890300631523\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.937954560868547 | MAE: 0.11178502879078694 | CSI: 0.11917957804516235 | Loss: 0.010592584498226643\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.8038668596813368 | MAE: 0.09719769673704415 | CSI: 0.1209126804590137 | Loss: 0.010560299269855022\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.9095679654120318 | MAE: 0.10871401151631478 | CSI: 0.1195226917057903 | Loss: 0.010540400631725788\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.8960538365522246 | MAE: 0.1072552783109405 | CSI: 0.11969735961706728 | Loss: 0.010512854903936386\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 0.8620193727041435 | MAE: 0.1035700575815739 | CSI: 0.12014817863757975 | Loss: 0.010515077039599419\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.9139790870740366 | MAE: 0.10917466410748561 | CSI: 0.11944984918208304 | Loss: 0.010519146919250488\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.8697157778705387 | MAE: 0.10441458733205375 | CSI: 0.12005598839064253 | Loss: 0.010522506199777126\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.9010305223722821 | MAE: 0.10779270633397313 | CSI: 0.11963269129803683 | Loss: 0.010514034889638424\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.8783836508380664 | MAE: 0.10533589251439539 | CSI: 0.11992014242639416 | Loss: 0.010503781028091908\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.8798422581432144 | MAE: 0.105489443378119 | CSI: 0.11989585905985017 | Loss: 0.010508840903639793\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.9060126520071495 | MAE: 0.10833013435700575 | CSI: 0.11956801498975632 | Loss: 0.010515524074435234\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.8869249315791313 | MAE: 0.10625719769673704 | CSI: 0.11980404869966141 | Loss: 0.01050545647740364\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.891837017510536 | MAE: 0.10679462571976968 | CSI: 0.11974679635633781 | Loss: 0.010504990816116333\n" + ] + } + ], + "source": [ + "# AdamW bs50 lr 5e-4\n", + "fold = 0\n", + "train_fold(df, fold)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "481e8b8bf2704965b293683cf385071d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4e5af26e63b74c51816da86445c60b6f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 0.9437922424302221 | MAE: 0.11309021113243763 | CSI: 0.1198253238872696 | Loss: 0.01358635164797306\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.1144372062329246 | MAE: 0.1309021113243762 | CSI: 0.1174602845195231 | Loss: 0.01243552751839161\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 0.8626368456547189 | MAE: 0.10395393474088292 | CSI: 0.1205071812822022 | Loss: 0.012849031016230583\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 0.9002768018883869 | MAE: 0.10794625719769674 | CSI: 0.11990340856320238 | Loss: 0.011554539203643799\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.0671439714543172 | MAE: 0.1256813819577735 | CSI: 0.11777359505243346 | Loss: 0.011750211007893085\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 0.9762669221542187 | MAE: 0.11608445297504799 | CSI: 0.11890646947037925 | Loss: 0.011374273337423801\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.268498090098286 | MAE: 0.14641074856046066 | CSI: 0.1154205510454083 | Loss: 0.012063690461218357\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.1031854423828713 | MAE: 0.12944337811900192 | CSI: 0.11733601001686725 | Loss: 0.01138242892920971\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.8217496448119358 | MAE: 0.09927063339731286 | CSI: 0.12080398698463693 | Loss: 0.011296030133962631\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 1.1088743781210597 | MAE: 0.12998080614203456 | CSI: 0.11721869375426691 | Loss: 0.011095265857875347\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.9620862761671343 | MAE: 0.11447216890595009 | CSI: 0.11898326765561493 | Loss: 0.011082764714956284\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.8788553600805569 | MAE: 0.105489443378119 | CSI: 0.12003049440077472 | Loss: 0.010911011137068272\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9072385822422571 | MAE: 0.10856046065259117 | CSI: 0.11966032174621005 | Loss: 0.010981745086610317\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.9236828223175657 | MAE: 0.11032629558541267 | CSI: 0.11944175307674865 | Loss: 0.010909012518823147\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.957795533843998 | MAE: 0.11401151631477927 | CSI: 0.11903533926103191 | Loss: 0.010993537493050098\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.8396774933378116 | MAE: 0.10119001919385796 | CSI: 0.12051057697256679 | Loss: 0.01081531960517168\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.9688093562689909 | MAE: 0.11516314779270634 | CSI: 0.11887080471151268 | Loss: 0.01078347209841013\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 0.8454214584932136 | MAE: 0.1018042226487524 | CSI: 0.12041830926476794 | Loss: 0.010767708532512188\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 1.0531790624176152 | MAE: 0.12406909788867562 | CSI: 0.11780437184424918 | Loss: 0.011013168841600418\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.9021290016177816 | MAE: 0.10802303262955854 | CSI: 0.11974233444988405 | Loss: 0.011004014872014523\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.9261467066757963 | MAE: 0.11055662188099807 | CSI: 0.11937268802357572 | Loss: 0.010737544856965542\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.738513689685624 | MAE: 0.08998080614203455 | CSI: 0.12184040377044293 | Loss: 0.011556853540241718\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.9217058638475384 | MAE: 0.11009596928982726 | CSI: 0.11944805127888043 | Loss: 0.010670513845980167\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.8700357218950803 | MAE: 0.10449136276391555 | CSI: 0.12010008340283569 | Loss: 0.010724274441599846\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.8609963941380611 | MAE: 0.1034932821497121 | CSI: 0.12020176025528849 | Loss: 0.010626938194036484\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.9355479405170692 | MAE: 0.11155470249520154 | CSI: 0.1192399637292886 | Loss: 0.010589144192636013\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.8959708216309327 | MAE: 0.1072552783109405 | CSI: 0.11970845000823588 | Loss: 0.010607089847326279\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.9535970473260313 | MAE: 0.11347408829174664 | CSI: 0.11899584694497975 | Loss: 0.010617554187774658\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.8343800580415958 | MAE: 0.10057581573896353 | CSI: 0.1205395727867638 | Loss: 0.010615648701786995\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.9420075384325132 | MAE: 0.11224568138195777 | CSI: 0.11915582073556538 | Loss: 0.010660631582140923\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.8873040214686102 | MAE: 0.10633397312859885 | CSI: 0.11983939050756717 | Loss: 0.01054992526769638\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.8851448177581197 | MAE: 0.10610364683301343 | CSI: 0.11987151108319863 | Loss: 0.01055346429347992\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.9027612589958749 | MAE: 0.10802303262955854 | CSI: 0.11965847177448426 | Loss: 0.010560257360339165\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 0.8119980124723106 | MAE: 0.0981190019193858 | CSI: 0.1208365050301397 | Loss: 0.010586261749267578\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.9156443179485217 | MAE: 0.10940499040307101 | CSI: 0.11948415804870009 | Loss: 0.010580405592918396\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.9486229827182243 | MAE: 0.11293666026871402 | CSI: 0.1190532617543715 | Loss: 0.010587424039840698\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.8343972284753538 | MAE: 0.10057581573896353 | CSI: 0.12053709229344582 | Loss: 0.01055158395320177\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.93222359442797 | MAE: 0.11117082533589251 | CSI: 0.1192533915676923 | Loss: 0.010545892640948296\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.8951787451674442 | MAE: 0.1071785028790787 | CSI: 0.11972860555143727 | Loss: 0.010513982735574245\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 0.8696352179660404 | MAE: 0.10441458733205375 | CSI: 0.12006710994915289 | Loss: 0.010518943890929222\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.9230175593085749 | MAE: 0.11017274472168906 | CSI: 0.11936148300720906 | Loss: 0.010523991659283638\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.865406931654537 | MAE: 0.10395393474088292 | CSI: 0.12012144915603129 | Loss: 0.010527591221034527\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.9123202405399485 | MAE: 0.109021113243762 | CSI: 0.11949873344728876 | Loss: 0.010516504757106304\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.878891236774245 | MAE: 0.1054126679462572 | CSI: 0.11993823983530623 | Loss: 0.010505661368370056\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.8987529011976216 | MAE: 0.10756238003838772 | CSI: 0.1196795914585178 | Loss: 0.010510049760341644\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.9165924059496204 | MAE: 0.10948176583493283 | CSI: 0.11944433002430284 | Loss: 0.010518706403672695\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.8994440072829465 | MAE: 0.10763915547024952 | CSI: 0.11967299197924282 | Loss: 0.010508017614483833\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.898789915093023 | MAE: 0.10756238003838772 | CSI: 0.1196746628230207 | Loss: 0.010506379418075085\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 0.9051807552372788 | MAE: 0.10825335892514396 | CSI: 0.11959308491469405 | Loss: 0.010508042760193348\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 0.8824160365458403 | MAE: 0.10579654510556621 | CSI: 0.11989417771555629 | Loss: 0.010502020828425884\n", + "\n", + "Model saved at ../models/baseline_fold0_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold1...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ce395b010f24474e8046507208c7fa0b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.05332362651824951\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6d0245aeda304488a763823e82778836", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1.4649042747142322 | MAE: 0.1602899718082964 | CSI: 0.10942009971136185 | Loss: 0.013296589255332947\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.290671138246638 | MAE: 0.14337494965767217 | CSI: 0.1110855781986066 | Loss: 0.0125888055190444\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 1.2783221086827037 | MAE: 0.14208618606524365 | CSI: 0.11115053482911558 | Loss: 0.011991249397397041\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 0.9015394723596947 | MAE: 0.10398711236407572 | CSI: 0.11534393728873313 | Loss: 0.01250691618770361\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.2422832403087438 | MAE: 0.13846153846153847 | CSI: 0.11145730214140571 | Loss: 0.012093799188733101\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 0.7875750740947998 | MAE: 0.0919049536850584 | CSI: 0.11669357843746118 | Loss: 0.01224138680845499\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 1.02749542826656 | MAE: 0.11695529601288764 | CSI: 0.11382561206055195 | Loss: 0.01118035800755024\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.195252290179063 | MAE: 0.13370922271445831 | CSI: 0.11186694542390864 | Loss: 0.011427664197981358\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 1.0641914459521074 | MAE: 0.12057994361659283 | CSI: 0.11330662736877063 | Loss: 0.011057481169700623\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 0.9369481790171504 | MAE: 0.10753121224325413 | CSI: 0.11476751292170329 | Loss: 0.010967950336635113\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.7484682012246785 | MAE: 0.08763592428513894 | CSI: 0.11708703741988837 | Loss: 0.011205273680388927\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.9380351910149244 | MAE: 0.1076923076923077 | CSI: 0.11480625537603764 | Loss: 0.011112474836409092\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9998407717848802 | MAE: 0.11397503020539669 | CSI: 0.11399318113516482 | Loss: 0.010790626518428326\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.8450261420411643 | MAE: 0.09794603302456706 | CSI: 0.11590887920595327 | Loss: 0.01095589343458414\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 1.0267842632014805 | MAE: 0.11671365283930729 | CSI: 0.1136691094917749 | Loss: 0.010724475607275963\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.9396971828684089 | MAE: 0.10777285541683447 | CSI: 0.1146889204104242 | Loss: 0.010688774287700653\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.7939414270633738 | MAE: 0.09246878775674587 | CSI: 0.11646802220407489 | Loss: 0.010742655955255032\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 1.1852240181120506 | MAE: 0.13266210229561015 | CSI: 0.11192998139351164 | Loss: 0.011140896938741207\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 1.0105798988834238 | MAE: 0.11510269834877165 | CSI: 0.11389767248976207 | Loss: 0.010770338587462902\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.8335278197512083 | MAE: 0.09665726943213854 | CSI: 0.11596165975618583 | Loss: 0.010712211951613426\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.8510294880021778 | MAE: 0.09850986709625453 | CSI: 0.11575376468641403 | Loss: 0.010666263289749622\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.8465844434347423 | MAE: 0.09802658074909384 | CSI: 0.11579067098201822 | Loss: 0.01059988234192133\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.9672245166705439 | MAE: 0.11059202577527184 | CSI: 0.11433956012094602 | Loss: 0.010529949329793453\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.997738360733415 | MAE: 0.11373338703181635 | CSI: 0.11399119399119399 | Loss: 0.010567311197519302\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.898391589075928 | MAE: 0.10342327829238825 | CSI: 0.11512048815803082 | Loss: 0.010495680384337902\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.9244896330381611 | MAE: 0.10616190092629883 | CSI: 0.11483298149757855 | Loss: 0.010658398270606995\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.9298419514112347 | MAE: 0.10672573499798631 | CSI: 0.11477836081183179 | Loss: 0.010502760298550129\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.9704039361029636 | MAE: 0.11091421667337897 | CSI: 0.11429695670632578 | Loss: 0.010476275347173214\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 1.0245480664178566 | MAE: 0.11647200966572695 | CSI: 0.11368135227849807 | Loss: 0.010573608800768852\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.9706752428478566 | MAE: 0.11091421667337897 | CSI: 0.11426501035196687 | Loss: 0.010471746325492859\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.9757942295074935 | MAE: 0.11147805074506645 | CSI: 0.11424340027134232 | Loss: 0.010471350513398647\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.895234102893301 | MAE: 0.10310108739428112 | CSI: 0.11516662184804419 | Loss: 0.010457354597747326\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.8317717117439819 | MAE: 0.09641562625855819 | CSI: 0.115915971770152 | Loss: 0.01048274990171194\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 1.022891463602481 | MAE: 0.11631091421667338 | CSI: 0.11370797230628912 | Loss: 0.010477441363036633\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.8913645102599291 | MAE: 0.1026983487716472 | CSI: 0.11521476072769395 | Loss: 0.01042830292135477\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.9030110911423105 | MAE: 0.10390656463954893 | CSI: 0.11506676458115697 | Loss: 0.010428202338516712\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.9846335480455448 | MAE: 0.11236407571486105 | CSI: 0.11411765924176996 | Loss: 0.010422425344586372\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.9260881520169612 | MAE: 0.10632299637535239 | CSI: 0.1148087211167334 | Loss: 0.01040708739310503\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.909877958160892 | MAE: 0.10463149416028997 | CSI: 0.11499508612217452 | Loss: 0.010411178693175316\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 1.0135557133956454 | MAE: 0.115344341522352 | CSI: 0.1138016785825303 | Loss: 0.010445826686918736\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.9502076884845897 | MAE: 0.10881997583568265 | CSI: 0.11452230617948453 | Loss: 0.010401807725429535\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.9053927143266041 | MAE: 0.10414820781312928 | CSI: 0.11503097624292821 | Loss: 0.010389027185738087\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.946331936735179 | MAE: 0.10841723721304873 | CSI: 0.11456575964892307 | Loss: 0.010387437418103218\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.936402672617956 | MAE: 0.10737011679420057 | CSI: 0.11466233484050534 | Loss: 0.010386471636593342\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.9192851882635985 | MAE: 0.10559806685461136 | CSI: 0.11486975772246705 | Loss: 0.010383939370512962\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.9495266410697014 | MAE: 0.10873942811115586 | CSI: 0.11451961788845075 | Loss: 0.010390044189989567\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.9316930161089786 | MAE: 0.10688683044703987 | CSI: 0.11472322814278324 | Loss: 0.010382372885942459\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.9208344140544454 | MAE: 0.10575916230366492 | CSI: 0.11485144417777042 | Loss: 0.010380524210631847\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 0.9222982587823745 | MAE: 0.10592025775271849 | CSI: 0.11484382274735394 | Loss: 0.010379559360444546\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 0.9301114122389976 | MAE: 0.10672573499798631 | CSI: 0.11474510858881105 | Loss: 0.010381213389337063\n", + "\n", + "Model saved at ../models/baseline_fold1_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold2...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3ae15ac210c2427e871a92657b1777b0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.06792119145393372\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "91c734768bbc49fc9502fae5a82efa7d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1.0583238103064436 | MAE: 0.12316623519259434 | CSI: 0.11637859225322782 | Loss: 0.018375571817159653\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.3090439302248436 | MAE: 0.14827018121911037 | CSI: 0.11326600872159744 | Loss: 0.0128381522372365\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 1.0749716990031783 | MAE: 0.124499882325253 | CSI: 0.11581689307693106 | Loss: 0.012292868457734585\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 1.5262270258583122 | MAE: 0.16905938652231897 | CSI: 0.11076948819309364 | Loss: 0.014321698807179928\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 0.8281121601940147 | MAE: 0.09837608849140975 | CSI: 0.11879560912079053 | Loss: 0.013426098972558975\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 0.9191661424525188 | MAE: 0.10802541774535185 | CSI: 0.1175254535118094 | Loss: 0.011718537658452988\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 0.9063162075373175 | MAE: 0.10661332078136032 | CSI: 0.11763369108244068 | Loss: 0.01155879907310009\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 1.0341454411945987 | MAE: 0.12002824193927983 | CSI: 0.11606514631016929 | Loss: 0.01131217647343874\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 1.0352961355132178 | MAE: 0.12018514160194556 | CSI: 0.11608769460086121 | Loss: 0.011417840607464314\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 1.0188106526559253 | MAE: 0.11845924531262257 | CSI: 0.1162720913869459 | Loss: 0.011291706003248692\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.8717325295321008 | MAE: 0.10284772887738292 | CSI: 0.11798083172566053 | Loss: 0.011077452450990677\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.8649155800572065 | MAE: 0.10214168039538715 | CSI: 0.1180943929669604 | Loss: 0.011104998178780079\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9185423571477423 | MAE: 0.10786851808268612 | CSI: 0.11743445170750849 | Loss: 0.011087087914347649\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.937323037184238 | MAE: 0.10982976386600769 | CSI: 0.11717386590113488 | Loss: 0.01098643708974123\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.9917942060852485 | MAE: 0.11555660155330666 | CSI: 0.11651268059775531 | Loss: 0.011159125715494156\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.87468919949303 | MAE: 0.10316152820271436 | CSI: 0.11794078200763437 | Loss: 0.010921032167971134\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.8852680079904909 | MAE: 0.10425982584137444 | CSI: 0.1177720474471377 | Loss: 0.01089795958250761\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 1.0215452563791814 | MAE: 0.1186161449752883 | CSI: 0.11611442981458897 | Loss: 0.010930635966360569\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 0.9635993753446791 | MAE: 0.11257550796265788 | CSI: 0.1168281246772561 | Loss: 0.010946526192128658\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.990630094343072 | MAE: 0.11539970189064093 | CSI: 0.11649121357066851 | Loss: 0.011101718060672283\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.9296621215830078 | MAE: 0.1089668157213462 | CSI: 0.11721120307114405 | Loss: 0.010813501663506031\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 1.053979582174606 | MAE: 0.1219894877226014 | CSI: 0.1157417940391735 | Loss: 0.010878403671085835\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.8713598432778761 | MAE: 0.10276927904605006 | CSI: 0.11794126139504187 | Loss: 0.01081886701285839\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.8094857284627778 | MAE: 0.09610104338275673 | CSI: 0.11871863826981134 | Loss: 0.010966574773192406\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.8459103347898103 | MAE: 0.10002353494939986 | CSI: 0.11824366110080396 | Loss: 0.010708808898925781\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.8446276818383552 | MAE: 0.09986663528673413 | CSI: 0.11823746419076277 | Loss: 0.010703754611313343\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.9895471686168905 | MAE: 0.11524280222797521 | CSI: 0.1164601404378356 | Loss: 0.010782686993479729\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.769290903636 | MAE: 0.09170785282811642 | CSI: 0.11921088939684109 | Loss: 0.01082665380090475\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.8198090405058381 | MAE: 0.09719934102141681 | CSI: 0.11856339247079187 | Loss: 0.010702860541641712\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.9320700526860811 | MAE: 0.10920216521534479 | CSI: 0.11716089890422832 | Loss: 0.010679910890758038\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.885834141218907 | MAE: 0.10425982584137444 | CSI: 0.11769677977982106 | Loss: 0.010661674663424492\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.8572212453409812 | MAE: 0.1012002824193928 | CSI: 0.11805619957340258 | Loss: 0.0106227220967412\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.9336796264837647 | MAE: 0.10935906487801052 | CSI: 0.11712696922489657 | Loss: 0.010662314482033253\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 0.9107966428687078 | MAE: 0.10692712010669177 | CSI: 0.11739955449220361 | Loss: 0.010680872946977615\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.9366501332996858 | MAE: 0.10967286420334196 | CSI: 0.11709053391797782 | Loss: 0.010653090663254261\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.9027136102945182 | MAE: 0.10606417196203029 | CSI: 0.1174948186795624 | Loss: 0.010605936869978905\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.9012293495300664 | MAE: 0.10590727229936456 | CSI: 0.1175142291512002 | Loss: 0.010662487708032131\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.8489388914597604 | MAE: 0.10033733427473131 | CSI: 0.1181914685300271 | Loss: 0.010626079514622688\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.9501680730955672 | MAE: 0.11108496116733349 | CSI: 0.11691085431283532 | Loss: 0.010639404878020287\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 0.850572468240016 | MAE: 0.10049423393739704 | CSI: 0.11814893814337384 | Loss: 0.010601965710520744\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.8908257676194905 | MAE: 0.10480897466070448 | CSI: 0.11765373035839485 | Loss: 0.010582203976809978\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.8185795509506226 | MAE: 0.09704244135875108 | CSI: 0.1185497991554228 | Loss: 0.010606820695102215\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.9011550090739481 | MAE: 0.10590727229936456 | CSI: 0.11752392344497607 | Loss: 0.010577197186648846\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.8879572498856813 | MAE: 0.10449517533537303 | CSI: 0.11768041237113402 | Loss: 0.010582765564322472\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.8653973955432505 | MAE: 0.1020632305640543 | CSI: 0.1179379913653646 | Loss: 0.010572961531579494\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.8799125258902079 | MAE: 0.10363222719071154 | CSI: 0.11777560171107561 | Loss: 0.010572567582130432\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.8857702229511826 | MAE: 0.10425982584137444 | CSI: 0.11770527292407608 | Loss: 0.010572829283773899\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.8616950848384399 | MAE: 0.10167098140738998 | CSI: 0.11798951066964791 | Loss: 0.010569415986537933\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 0.8690044359771177 | MAE: 0.1024554797207186 | CSI: 0.11789983511953833 | Loss: 0.010570191778242588\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 0.8623908562366291 | MAE: 0.10174943123872283 | CSI: 0.11798528532860705 | Loss: 0.01056947186589241\n", + "\n", + "Model saved at ../models/baseline_fold2_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold3...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "160e3a284f5c4f829aaa6b0a3978cd3d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.06584510207176208\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "72d68bdf64be49209a7d07be0f904497", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 0.9761461365590198 | MAE: 0.11572746298736532 | CSI: 0.1185554689529748 | Loss: 0.016777969896793365\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 0.8674826352979895 | MAE: 0.10386791721571971 | CSI: 0.1197348661384703 | Loss: 0.013262536376714706\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 0.984027520372654 | MAE: 0.11627005658476088 | CSI: 0.11815732200228003 | Loss: 0.01209369394928217\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 0.9891841708089898 | MAE: 0.1166576234400434 | CSI: 0.11793316844490899 | Loss: 0.011719717644155025\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.0537721338059343 | MAE: 0.12340128672195953 | CSI: 0.11710433666072982 | Loss: 0.011907537467777729\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 1.2459166432193667 | MAE: 0.1432447097124254 | CSI: 0.11497134297928997 | Loss: 0.012449313886463642\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 0.9462066203023297 | MAE: 0.11192930780559647 | CSI: 0.11829267033545683 | Loss: 0.011184750124812126\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 0.9627440123754404 | MAE: 0.11371211533989613 | CSI: 0.11811251368716814 | Loss: 0.011118467897176743\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.8771097191461886 | MAE: 0.10448802418417177 | CSI: 0.11912765518663641 | Loss: 0.01101650483906269\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 0.9138557533172998 | MAE: 0.10851871947911014 | CSI: 0.11874819311939867 | Loss: 0.011204872280359268\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.8554829748464027 | MAE: 0.10216262305247656 | CSI: 0.11942098914354644 | Loss: 0.010996270924806595\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 1.097735005830804 | MAE: 0.12781954887218044 | CSI: 0.11643934846948278 | Loss: 0.011197962798178196\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 0.9800621222273554 | MAE: 0.1154949228741958 | CSI: 0.11784449194989208 | Loss: 0.010927666909992695\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 0.8232752663309416 | MAE: 0.09859700798387722 | CSI: 0.11976189740579397 | Loss: 0.010764451697468758\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 1.0497969881777247 | MAE: 0.12278117975350748 | CSI: 0.11695706992414377 | Loss: 0.010931246913969517\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 0.8431765659576419 | MAE: 0.10076738237345942 | CSI: 0.11950923026207351 | Loss: 0.010661198757588863\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 1.0157086375232767 | MAE: 0.11921556468490814 | CSI: 0.11737181341156055 | Loss: 0.010978718288242817\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 0.9975301019021314 | MAE: 0.11727773040849547 | CSI: 0.11756811166286404 | Loss: 0.01100278552621603\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 0.8262334319267731 | MAE: 0.09890706146810324 | CSI: 0.11970837495237409 | Loss: 0.010737092234194279\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.790807704258165 | MAE: 0.09503139291527789 | CSI: 0.12017003931901929 | Loss: 0.010724930092692375\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.8398138968407228 | MAE: 0.10037981551817689 | CSI: 0.11952626158599382 | Loss: 0.010706203989684582\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.8178549615586439 | MAE: 0.09797690101542517 | CSI: 0.11979740372044184 | Loss: 0.010585488751530647\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.8513987351271364 | MAE: 0.101620029455081 | CSI: 0.11935656615587412 | Loss: 0.010546199977397919\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.9827907270442443 | MAE: 0.11572746298736532 | CSI: 0.11775392237819983 | Loss: 0.010765822604298592\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 1.0323321113161896 | MAE: 0.1209208588481513 | CSI: 0.11713367967692959 | Loss: 0.01074812188744545\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.8066592825906277 | MAE: 0.09673668707852104 | CSI: 0.11992261065544246 | Loss: 0.010556642897427082\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 0.9649874075775624 | MAE: 0.11378962871095263 | CSI: 0.1179182524211764 | Loss: 0.010568722151219845\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.8237076721618036 | MAE: 0.09859700798387722 | CSI: 0.11969902832674571 | Loss: 0.010539213195443153\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.8324059296905256 | MAE: 0.0995271684365553 | CSI: 0.11956566488266776 | Loss: 0.01049722172319889\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.956361963256247 | MAE: 0.11285946825827456 | CSI: 0.11800915614946796 | Loss: 0.010548410937190056\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.9077731340303097 | MAE: 0.10766607239748857 | CSI: 0.1186046032432878 | Loss: 0.010499398224055767\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.8091561787903138 | MAE: 0.09696922719169057 | CSI: 0.11983993910279489 | Loss: 0.010456260293722153\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.9129124508798632 | MAE: 0.10820866599488412 | CSI: 0.11853126320018957 | Loss: 0.010469174012541771\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 0.9177270800913426 | MAE: 0.10875125959227967 | CSI: 0.11850065444351689 | Loss: 0.010467367246747017\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.8670352892262351 | MAE: 0.10324781024726766 | CSI: 0.11908143939393939 | Loss: 0.010418031364679337\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.8610727589358931 | MAE: 0.1026277032788156 | CSI: 0.1191858669466922 | Loss: 0.01043530460447073\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.7667947876794768 | MAE: 0.09231842492830013 | CSI: 0.12039521709180263 | Loss: 0.010462336242198944\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.8685409212526897 | MAE: 0.10340283698938067 | CSI: 0.11905350048374812 | Loss: 0.010403391905128956\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.7888173193755786 | MAE: 0.09472133943105186 | CSI: 0.12008019740900679 | Loss: 0.010409584268927574\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 0.7951825273769476 | MAE: 0.09541895977056042 | CSI: 0.11999629831470494 | Loss: 0.010402645915746689\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.8022224399679789 | MAE: 0.0961940934811255 | CSI: 0.11990950226244344 | Loss: 0.010389924049377441\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.7697754150328099 | MAE: 0.09262847841252617 | CSI: 0.12033182224689302 | Loss: 0.010405430570244789\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.8248249742946362 | MAE: 0.09867452135493372 | CSI: 0.11963086039979835 | Loss: 0.010377668775618076\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.8526745313404479 | MAE: 0.10169754282613751 | CSI: 0.1192688875852913 | Loss: 0.010379205457866192\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.8434967891742426 | MAE: 0.10068986900240291 | CSI: 0.11937196477076302 | Loss: 0.010375346057116985\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.8384023821780489 | MAE: 0.10014727540500737 | CSI: 0.11945013221932975 | Loss: 0.010371914133429527\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.8541217422244934 | MAE: 0.10185256956825052 | CSI: 0.11924830446550784 | Loss: 0.010385624133050442\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.8484406626048611 | MAE: 0.10123246259979847 | CSI: 0.11931590158367548 | Loss: 0.01037642452865839\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 0.8419922996367154 | MAE: 0.1005348422602899 | CSI: 0.1194011421515666 | Loss: 0.010372198186814785\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 0.8477266269374579 | MAE: 0.10115494922874196 | CSI: 0.11932496398435892 | Loss: 0.010375108569860458\n", + "\n", + "Model saved at ../models/baseline_fold3_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 666 K \n", + "2 | decoder | Decoder | 664 K \n", + "3 | out | Sequential | 577 \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold4...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "280f68c17ce24ae5a2c7cc2dd69b0c1a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.051339007914066315\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f05780bf342f417cbd09bf0187bfff83", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 1.2923509983887724 | MAE: 0.14612412270325684 | CSI: 0.1130684488069754 | Loss: 0.013486770913004875\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 1.2824421687627843 | MAE: 0.14502010882422522 | CSI: 0.11308120736769645 | Loss: 0.013099905103445053\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 1.0337804664341868 | MAE: 0.119785505874931 | CSI: 0.11587131868245944 | Loss: 0.013135443441569805\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 0.8676452444378151 | MAE: 0.10212128381042504 | CSI: 0.11769935289131468 | Loss: 0.012131770141422749\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 1.0104902871287074 | MAE: 0.11710432931156849 | CSI: 0.11588862436600766 | Loss: 0.011633886024355888\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 1.06292039288501 | MAE: 0.12254554057251005 | CSI: 0.11529136273209549 | Loss: 0.011597689241170883\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 0.9304383558461906 | MAE: 0.10866650895039823 | CSI: 0.11679065922711304 | Loss: 0.011226480826735497\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 0.8195911449757846 | MAE: 0.09683778881791656 | CSI: 0.118153776319726 | Loss: 0.011528192088007927\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 0.8785722230291139 | MAE: 0.10314643955524012 | CSI: 0.11740234536295317 | Loss: 0.011167199350893497\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 0.8861688488530748 | MAE: 0.10393502089740557 | CSI: 0.11728579833407306 | Loss: 0.011075042188167572\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 0.9453812543293786 | MAE: 0.11024367163472912 | CSI: 0.11661292322956716 | Loss: 0.011175395920872688\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 0.9979241244851905 | MAE: 0.11576374102988723 | CSI: 0.11600455203807158 | Loss: 0.011165978386998177\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 1.007878343679855 | MAE: 0.11671003864048576 | CSI: 0.11579774421321423 | Loss: 0.011008608154952526\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 1.1125746265559728 | MAE: 0.12751360302815234 | CSI: 0.11461128088258145 | Loss: 0.011036342941224575\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 0.7850852822097438 | MAE: 0.09305259837552243 | CSI: 0.11852546530082239 | Loss: 0.010939818806946278\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 1.0153115494115053 | MAE: 0.11749861998265121 | CSI: 0.11572666542574093 | Loss: 0.011188478209078312\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 0.986025980481453 | MAE: 0.11442315274820598 | CSI: 0.11604476455209614 | Loss: 0.011116521432995796\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 1.0617361140641164 | MAE: 0.12223010803564388 | CSI: 0.11512286943571073 | Loss: 0.010877163149416447\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 0.9572121558345978 | MAE: 0.11142654364797729 | CSI: 0.11640736378850802 | Loss: 0.011018428951501846\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 0.7521329035471026 | MAE: 0.08942512420156139 | CSI: 0.11889537577610972 | Loss: 0.010738796554505825\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 0.9892293528688089 | MAE: 0.11473858528507215 | CSI: 0.11598784948236315 | Loss: 0.01070215180516243\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 0.9864335429133184 | MAE: 0.11442315274820598 | CSI: 0.1159968186090711 | Loss: 0.010646497830748558\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 0.84469233657039 | MAE: 0.09944010724706254 | CSI: 0.117723463255229 | Loss: 0.010614980012178421\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 0.8534076363799895 | MAE: 0.10038640485766107 | CSI: 0.11763007568414757 | Loss: 0.010652757249772549\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 0.8727730631181053 | MAE: 0.10243671634729122 | CSI: 0.11736924599901007 | Loss: 0.010685691609978676\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 0.8287695273573423 | MAE: 0.09770522829429856 | CSI: 0.11789191695430426 | Loss: 0.010555294342339039\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 1.0085243594280549 | MAE: 0.11671003864048576 | CSI: 0.1157235693401246 | Loss: 0.010769852437078953\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 0.861976988570521 | MAE: 0.10125384433404305 | CSI: 0.11746699236263566 | Loss: 0.010523894801735878\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 0.87451386176962 | MAE: 0.10259443261572432 | CSI: 0.11731595930022783 | Loss: 0.010531861335039139\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 0.8629818743300961 | MAE: 0.1013327024682596 | CSI: 0.11742158842682274 | Loss: 0.010494248941540718\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 0.8834115504948823 | MAE: 0.10354073022632285 | CSI: 0.11720554272517321 | Loss: 0.010468833148479462\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 0.8848855029908551 | MAE: 0.10369844649475593 | CSI: 0.11718854715483198 | Loss: 0.01055654976516962\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 0.8179083911242114 | MAE: 0.0965223562810504 | CSI: 0.1180112067899963 | Loss: 0.010491106659173965\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 0.8508487946200639 | MAE: 0.10007097232079488 | CSI: 0.11761310934762445 | Loss: 0.01048244908452034\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 0.8400339456350763 | MAE: 0.09888810030754672 | CSI: 0.11771917173173999 | Loss: 0.010445468127727509\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 0.8725101451576892 | MAE: 0.10235785821307468 | CSI: 0.11731423271153252 | Loss: 0.01043706201016903\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 0.9569020424454356 | MAE: 0.1112688273795442 | CSI: 0.11628026949783844 | Loss: 0.010467946529388428\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 0.8362668945479855 | MAE: 0.09849380963646401 | CSI: 0.11777796093299267 | Loss: 0.010502946563065052\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 0.8416683758276093 | MAE: 0.09904581657597981 | CSI: 0.11767795894403989 | Loss: 0.010454561561346054\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 0.8865609501271944 | MAE: 0.10385616276318903 | CSI: 0.11714497773379515 | Loss: 0.010411867871880531\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 0.8445140682777522 | MAE: 0.099361249112846 | CSI: 0.11765493654194824 | Loss: 0.010427672415971756\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 0.8718152715712825 | MAE: 0.10227900007885814 | CSI: 0.11731728430685519 | Loss: 0.010411477647721767\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 0.8327265628359038 | MAE: 0.09809951896538129 | CSI: 0.11780519962094681 | Loss: 0.010417300276458263\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 0.8829431978104287 | MAE: 0.1034618720921063 | CSI: 0.11717840099770158 | Loss: 0.010404979810118675\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 0.8866157844264348 | MAE: 0.10385616276318903 | CSI: 0.11713773269837344 | Loss: 0.010403498075902462\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 0.870380796129511 | MAE: 0.10212128381042504 | CSI: 0.11732943128303948 | Loss: 0.010401197709143162\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 0.8607582781995061 | MAE: 0.10109612806560997 | CSI: 0.11745007933727615 | Loss: 0.01040069479495287\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 0.8755787668255379 | MAE: 0.10267329074994086 | CSI: 0.11726334013479256 | Loss: 0.010401761159300804\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 0.8866706187256753 | MAE: 0.10385616276318903 | CSI: 0.11713048855905998 | Loss: 0.010412335395812988\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 0.8792732161385125 | MAE: 0.10306758142102358 | CSI: 0.11721906175282382 | Loss: 0.010402982123196125\n", + "\n", + "Model saved at ../models/baseline_fold4_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 15;\n", + " var nbb_unformatted_code = \"# AdamW bs50 lr 5e-4\\nfor fold in range(5):\\n train_fold(df, fold)\";\n", + " var nbb_formatted_code = \"# AdamW bs50 lr 5e-4\\nfor fold in range(5):\\n train_fold(df, fold)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# AdamW bs50 lr 5e-4\n", + "for fold in range(5):\n", + " train_fold(df, fold)" + ] + }, + { + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "model = Baseline.load_from_checkpoint(\"baseline_bs256_epoch10.ckpt\")\n", - "datamodule = NowcastingDataModule(batch_size=256)\n", - "datamodule.setup(\"test\")" + "## Inference" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 21;\n", + " var nbb_unformatted_code = \"def inference(checkpoints):\\n datamodule = NowcastingDataModule()\\n datamodule.setup(\\\"test\\\")\\n\\n test_paths = datamodule.test_dataset.paths\\n test_filenames = [path.name for path in test_paths]\\n final_preds = np.zeros((len(datamodule.test_dataset), 14400))\\n\\n for checkpoint in checkpoints:\\n print(\\\"Inference from\\\", checkpoint)\\n model = Baseline.load_from_checkpoint(str(checkpoint))\\n model.cuda()\\n model.eval()\\n preds = []\\n with torch.no_grad():\\n for batch in tqdm(datamodule.test_dataloader()):\\n batch = batch.cuda()\\n imgs = model(batch)\\n imgs = imgs.detach().cpu().numpy()\\n imgs = imgs[:, 0, 4:124, 4:124]\\n imgs = args[\\\"rng\\\"] * imgs\\n imgs = imgs.clip(0, 255)\\n imgs = imgs.round()\\n preds.append(imgs)\\n\\n preds = np.concatenate(preds)\\n preds = preds.astype(np.uint8)\\n preds = preds.reshape(-1, 14400)\\n final_preds += preds / len(checkpoints)\\n\\n del model\\n gc.collect()\\n torch.cuda.empty_cache()\\n \\n final_preds = final_preds.round()\\n final_preds = final_preds.astype(np.uint8)\\n\\n subm = pd.DataFrame()\\n subm[\\\"file_name\\\"] = test_filenames\\n for i in tqdm(range(14400)):\\n subm[str(i)] = final_preds[:, i]\\n\\n return subm\";\n", + " var nbb_formatted_code = \"def inference(checkpoints):\\n datamodule = NowcastingDataModule()\\n datamodule.setup(\\\"test\\\")\\n\\n test_paths = datamodule.test_dataset.paths\\n test_filenames = [path.name for path in test_paths]\\n final_preds = np.zeros((len(datamodule.test_dataset), 14400))\\n\\n for checkpoint in checkpoints:\\n print(\\\"Inference from\\\", checkpoint)\\n model = Baseline.load_from_checkpoint(str(checkpoint))\\n model.cuda()\\n model.eval()\\n preds = []\\n with torch.no_grad():\\n for batch in tqdm(datamodule.test_dataloader()):\\n batch = batch.cuda()\\n imgs = model(batch)\\n imgs = imgs.detach().cpu().numpy()\\n imgs = imgs[:, 0, 4:124, 4:124]\\n imgs = args[\\\"rng\\\"] * imgs\\n imgs = imgs.clip(0, 255)\\n imgs = imgs.round()\\n preds.append(imgs)\\n\\n preds = np.concatenate(preds)\\n preds = preds.astype(np.uint8)\\n preds = preds.reshape(-1, 14400)\\n final_preds += preds / len(checkpoints)\\n\\n del model\\n gc.collect()\\n torch.cuda.empty_cache()\\n\\n final_preds = final_preds.round()\\n final_preds = final_preds.astype(np.uint8)\\n\\n subm = pd.DataFrame()\\n subm[\\\"file_name\\\"] = test_filenames\\n for i in tqdm(range(14400)):\\n subm[str(i)] = final_preds[:, i]\\n\\n return subm\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "preds = []\n", - "model.to(\"cuda\")\n", - "model.eval()\n", - "with torch.no_grad():\n", - " for batch in tqdm(datamodule.test_dataloader(), total=len(datamodule.test_dataloader())):\n", - " batch = batch.to(\"cuda\")\n", - " imgs = model(batch)\n", - " imgs = imgs.detach().cpu().numpy()\n", - " imgs = np.round(imgs)\n", - " imgs = np.clip(imgs, 0, 255)\n", - " preds.append(imgs)\n", + "def inference(checkpoints):\n", + " datamodule = NowcastingDataModule()\n", + " datamodule.setup(\"test\")\n", + "\n", + " test_paths = datamodule.test_dataset.paths\n", + " test_filenames = [path.name for path in test_paths]\n", + " final_preds = np.zeros((len(datamodule.test_dataset), 14400))\n", + "\n", + " for checkpoint in checkpoints:\n", + " print(\"Inference from\", checkpoint)\n", + " model = Baseline.load_from_checkpoint(str(checkpoint))\n", + " model.cuda()\n", + " model.eval()\n", + " preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(datamodule.test_dataloader()):\n", + " batch = batch.cuda()\n", + " imgs = model(batch)\n", + " imgs = imgs.detach().cpu().numpy()\n", + " imgs = imgs[:, 0, 4:124, 4:124]\n", + " imgs = args[\"rng\"] * imgs\n", + " imgs = imgs.clip(0, 255)\n", + " imgs = imgs.round()\n", + " preds.append(imgs)\n", + "\n", + " preds = np.concatenate(preds)\n", + " preds = preds.astype(np.uint8)\n", + " preds = preds.reshape(-1, 14400)\n", + " final_preds += preds / len(checkpoints)\n", + "\n", + " del model\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + "\n", + " final_preds = final_preds.round()\n", + " final_preds = final_preds.astype(np.uint8)\n", "\n", - "preds = np.concatenate(preds)\n", - "preds = preds.astype(np.uint8)\n", - "preds = preds.reshape(len(preds), -1)" + " subm = pd.DataFrame()\n", + " subm[\"file_name\"] = test_filenames\n", + " for i in tqdm(range(14400)):\n", + " subm[str(i)] = final_preds[:, i]\n", + "\n", + " return subm" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 22;\n", + " var nbb_unformatted_code = \"checkpoints = [\\n args[\\\"model_dir\\\"]\\n / f\\\"baseline_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}.ckpt\\\"\\n for fold in range(5)\\n]\";\n", + " var nbb_formatted_code = \"checkpoints = [\\n args[\\\"model_dir\\\"]\\n / f\\\"baseline_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}.ckpt\\\"\\n for fold in range(5)\\n]\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "test_paths = datamodule.test_dataset.paths\n", - "test_filenames = [path.name for path in test_paths]" + "checkpoints = [\n", + " args[\"model_dir\"]\n", + " / f\"baseline_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}.ckpt\"\n", + " for fold in range(5)\n", + "]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inference from ../models/baseline_fold0_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "651cc21110814329a59e8982b6edfc9b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Inference from ../models/baseline_fold1_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "857c3602b3e440f7975a323a9e442c27", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Inference from ../models/baseline_fold2_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "17c43e63e7884cd5b7545d0ac0fcc843", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Inference from ../models/baseline_fold3_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "22121b1f2aff4d55ba819a2d52ab567c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Inference from ../models/baseline_fold4_bs128_epochs50_lr0.0005_adamw.ckpt\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ea52c66bd6e343c7872133298ea9c5ca", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "66e0ae774de641cc916c66629576986f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14400.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 23;\n", + " var nbb_unformatted_code = \"subm = inference(checkpoints)\";\n", + " var nbb_formatted_code = \"subm = inference(checkpoints)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "subm = pd.DataFrame()\n", - "subm[\"file_name\"] = test_filenames\n", - "for i in tqdm(range(14400)):\n", - " subm[str(i)] = preds[:, i]" + "subm = inference(checkpoints)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
file_name012345678...14390143911439214393143941439514396143971439814399
0test_00000.npy000000000...0000000000
1test_00001.npy000000000...0000000000
2test_00002.npy000000000...0000000000
3test_00003.npy000000000...0000000000
4test_00004.npy000000000...0000000000
\n", + "

5 rows Ă— 14401 columns

\n", + "
" + ], + "text/plain": [ + " file_name 0 1 2 3 4 5 6 7 8 ... 14390 14391 14392 14393 \\\n", + "0 test_00000.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "1 test_00001.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "2 test_00002.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "3 test_00003.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "4 test_00004.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "\n", + " 14394 14395 14396 14397 14398 14399 \n", + "0 0 0 0 0 0 0 \n", + "1 0 0 0 0 0 0 \n", + "2 0 0 0 0 0 0 \n", + "3 0 0 0 0 0 0 \n", + "4 0 0 0 0 0 0 \n", + "\n", + "[5 rows x 14401 columns]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 24;\n", + " var nbb_unformatted_code = \"output_csv = (\\n args[\\\"output_dir\\\"]\\n / f\\\"baseline_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}.csv\\\"\\n)\\nsubm.to_csv(output_csv, index=False)\\nsubm.head()\";\n", + " var nbb_formatted_code = \"output_csv = (\\n args[\\\"output_dir\\\"]\\n / f\\\"baseline_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}.csv\\\"\\n)\\nsubm.to_csv(output_csv, index=False)\\nsubm.head()\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "subm.to_csv(\"baseline_epoch10.csv\", index=False)\n", + "output_csv = (\n", + " args[\"output_dir\"]\n", + " / f\"baseline_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}.csv\"\n", + ")\n", + "subm.to_csv(output_csv, index=False)\n", "subm.head()" ] }, @@ -478,7 +14760,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:torch]", + "display_name": "Python [conda env:torch] *", "language": "python", "name": "conda-env-torch-py" }, diff --git a/notebooks/02-rainnet.ipynb b/notebooks/02-rainnet.ipynb index 21c7e55..9e226c2 100644 --- a/notebooks/02-rainnet.ipynb +++ b/notebooks/02-rainnet.ipynb @@ -60,8 +60,8 @@ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 3;\n", - " var nbb_unformatted_code = \"import gc\\nimport functools\\nfrom pathlib import Path\\nfrom concurrent.futures import ThreadPoolExecutor\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nimport pytorch_lightning as pl\\nfrom torch.utils.data import SequentialSampler, RandomSampler\\n\\nimport transformers\\n\\nimport optim\\nfrom data import NowcastingDataset\\nfrom loss import LogCoshLoss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", - " var nbb_formatted_code = \"import gc\\nimport functools\\nfrom pathlib import Path\\nfrom concurrent.futures import ThreadPoolExecutor\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nimport pytorch_lightning as pl\\nfrom torch.utils.data import SequentialSampler, RandomSampler\\n\\nimport transformers\\n\\nimport optim\\nfrom data import NowcastingDataset\\nfrom loss import LogCoshLoss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_unformatted_code = \"import gc\\nimport warnings\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nfrom torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler\\n\\nimport pytorch_lightning as pl\\n\\nimport torchvision.transforms as T\\nimport albumentations as A\\nfrom albumentations.pytorch import ToTensorV2\\n\\nfrom transformers import AdamW, get_cosine_schedule_with_warmup\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_formatted_code = \"import gc\\nimport warnings\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nfrom torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler\\n\\nimport pytorch_lightning as pl\\n\\nimport torchvision.transforms as T\\nimport albumentations as A\\nfrom albumentations.pytorch import ToTensorV2\\n\\nfrom transformers import AdamW, get_cosine_schedule_with_warmup\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -84,9 +84,8 @@ ], "source": [ "import gc\n", - "import functools\n", + "import warnings\n", "from pathlib import Path\n", - "from concurrent.futures import ThreadPoolExecutor\n", "from tqdm.notebook import tqdm\n", "\n", "import cv2\n", @@ -98,21 +97,24 @@ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", - "import torchvision.transforms as T\n", + "from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler\n", + "\n", "import pytorch_lightning as pl\n", - "from torch.utils.data import SequentialSampler, RandomSampler\n", "\n", - "import transformers\n", + "import torchvision.transforms as T\n", + "import albumentations as A\n", + "from albumentations.pytorch import ToTensorV2\n", + "\n", + "from transformers import AdamW, get_cosine_schedule_with_warmup\n", "\n", "import optim\n", - "from data import NowcastingDataset\n", - "from loss import LogCoshLoss\n", + "import loss\n", "from utils import visualize, radar2precipitation, seed_everything" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -120,9 +122,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 8;\n", - " var nbb_unformatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=64,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=2,\\n gradient_clip_val=5.0,\\n rng=255.0,\\n)\";\n", - " var nbb_formatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=64,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=2,\\n gradient_clip_val=5.0,\\n rng=255.0,\\n)\";\n", + " var nbb_cell_id = 4;\n", + " var nbb_unformatted_code = \"warnings.simplefilter(\\\"ignore\\\")\";\n", + " var nbb_formatted_code = \"warnings.simplefilter(\\\"ignore\\\")\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -144,48 +146,13 @@ } ], "source": [ - "args = dict(\n", - " seed=42,\n", - " dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\n", - " train_folds_csv=Path(\"../input/train_folds.csv\"),\n", - " train_data_path=Path(\"../input/train-128\"),\n", - " test_data_path=Path(\"../input/test-128\"),\n", - " num_workers=4,\n", - " gpus=1,\n", - " lr=1e-4,\n", - " max_epochs=50,\n", - " batch_size=64,\n", - " precision=16,\n", - " optimizer=\"adamw\",\n", - " scheduler=\"cosine\",\n", - " accumulate_grad_batches=2,\n", - " gradient_clip_val=5.0,\n", - " rng=255.0,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 🔥 RainNet ⚡️" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "heading_collapsed": true - }, - "source": [ - "## Resize data" + "warnings.simplefilter(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 5, - "metadata": { - "hidden": true - }, + "metadata": {}, "outputs": [ { "data": { @@ -193,146 +160,8 @@ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 5;\n", - " var nbb_unformatted_code = \"def resize_data(path, folder=\\\"train-128\\\"):\\n data = np.load(path)\\n img1 = data[:, :, :3]\\n img2 = data[:, :, 2:]\\n img1 = cv2.copyMakeBorder(img1, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = cv2.copyMakeBorder(img2, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = img2[:, :, 1:]\\n data = np.concatenate([img1, img2], axis=-1)\\n np.save(PATH / folder / path.name, data)\";\n", - " var nbb_formatted_code = \"def resize_data(path, folder=\\\"train-128\\\"):\\n data = np.load(path)\\n img1 = data[:, :, :3]\\n img2 = data[:, :, 2:]\\n img1 = cv2.copyMakeBorder(img1, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = cv2.copyMakeBorder(img2, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = img2[:, :, 1:]\\n data = np.concatenate([img1, img2], axis=-1)\\n np.save(PATH / folder / path.name, data)\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "def resize_data(path, folder=\"train-128\"):\n", - " data = np.load(path)\n", - " img1 = data[:, :, :3]\n", - " img2 = data[:, :, 2:]\n", - " img1 = cv2.copyMakeBorder(img1, 4, 4, 4, 4, cv2.BORDER_REFLECT)\n", - " img2 = cv2.copyMakeBorder(img2, 4, 4, 4, 4, cv2.BORDER_REFLECT)\n", - " img2 = img2[:, :, 1:]\n", - " data = np.concatenate([img1, img2], axis=-1)\n", - " np.save(PATH / folder / path.name, data)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "hidden": true - }, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'PATH' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0mPATH\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;34m\"train-128\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmkdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexist_ok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'PATH' is not defined" - ] - }, - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 6;\n", - " var nbb_unformatted_code = \"(PATH / \\\"train-128\\\").mkdir(exist_ok=True)\";\n", - " var nbb_formatted_code = \"(PATH / \\\"train-128\\\").mkdir(exist_ok=True)\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "(PATH / \"train-128\").mkdir(exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "hidden": true - }, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 21;\n", - " var nbb_unformatted_code = \"def resize_data(path):\\n data = np.load(path)\\n img1 = data[:, :, :3]\\n img2 = data[:, :, 2:]\\n img1 = cv2.copyMakeBorder(img1, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = cv2.copyMakeBorder(img2, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = img2[:, :, 1:]\\n data = np.concatenate([img1, img2], axis=-1)\\n np.save(PATH / \\\"train-128\\\" / path.name, data)\\n \\nfiles = list((PATH / \\\"train\\\").glob(\\\"*.npy\\\"))\\nwith ThreadPoolExecutor(8) as e: e.map(resize_data, files)\";\n", - " var nbb_formatted_code = \"def resize_data(path):\\n data = np.load(path)\\n img1 = data[:, :, :3]\\n img2 = data[:, :, 2:]\\n img1 = cv2.copyMakeBorder(img1, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = cv2.copyMakeBorder(img2, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = img2[:, :, 1:]\\n data = np.concatenate([img1, img2], axis=-1)\\n np.save(PATH / \\\"train-128\\\" / path.name, data)\\n\\n\\nfiles = list((PATH / \\\"train\\\").glob(\\\"*.npy\\\"))\\nwith ThreadPoolExecutor(8) as e:\\n e.map(resize_data, files)\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "files = list((PATH / \"train\").glob(\"*.npy\"))\n", - "with ThreadPoolExecutor(8) as e:\n", - " e.map(resize_data, files)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "hidden": true - }, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 8;\n", - " var nbb_unformatted_code = \"(PATH / \\\"test-128\\\").mkdir(exist_ok=True)\";\n", - " var nbb_formatted_code = \"(PATH / \\\"test-128\\\").mkdir(exist_ok=True)\";\n", + " var nbb_unformatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train\\\"),\\n test_data_path=Path(\\\"../input/test\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n output_dir=Path(\\\"../output\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-3,\\n max_epochs=50,\\n batch_size=128,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n warmup_epochs=1,\\n)\\n\\nargs[\\\"trn_tfms\\\"] = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\\n\\nargs[\\\"val_tfms\\\"] = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\";\n", + " var nbb_formatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train\\\"),\\n test_data_path=Path(\\\"../input/test\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n output_dir=Path(\\\"../output\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-3,\\n max_epochs=50,\\n batch_size=128,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n warmup_epochs=1,\\n)\\n\\nargs[\\\"trn_tfms\\\"] = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\\n\\nargs[\\\"val_tfms\\\"] = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -354,48 +183,48 @@ } ], "source": [ - "(PATH / \"test-128\").mkdir(exist_ok=True)" + "args = dict(\n", + " seed=42,\n", + " dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\n", + " train_folds_csv=Path(\"../input/train_folds.csv\"),\n", + " train_data_path=Path(\"../input/train\"),\n", + " test_data_path=Path(\"../input/test\"),\n", + " model_dir=Path(\"../models\"),\n", + " output_dir=Path(\"../output\"),\n", + " rng=255.0,\n", + " num_workers=4,\n", + " gpus=1,\n", + " lr=1e-3,\n", + " max_epochs=50,\n", + " batch_size=128,\n", + " precision=16,\n", + " optimizer=\"adamw\",\n", + " scheduler=\"cosine\",\n", + " accumulate_grad_batches=1,\n", + " gradient_clip_val=5.0,\n", + " warmup_epochs=1,\n", + ")\n", + "\n", + "args[\"trn_tfms\"] = A.Compose(\n", + " [\n", + " A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\n", + " ToTensorV2(always_apply=True, p=1),\n", + " ]\n", + ")\n", + "\n", + "args[\"val_tfms\"] = A.Compose(\n", + " [\n", + " A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\n", + " ToTensorV2(always_apply=True, p=1),\n", + " ]\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "hidden": true - }, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 9;\n", - " var nbb_unformatted_code = \"test_files = list((PATH / \\\"test\\\").glob(\\\"*.npy\\\"))\\nwith ThreadPoolExecutor(8) as e:\\n e.map(functools.partial(resize_data, folder=\\\"test-128\\\"), test_files)\";\n", - " var nbb_formatted_code = \"test_files = list((PATH / \\\"test\\\").glob(\\\"*.npy\\\"))\\nwith ThreadPoolExecutor(8) as e:\\n e.map(functools.partial(resize_data, folder=\\\"test-128\\\"), test_files)\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "cell_type": "markdown", + "metadata": {}, "source": [ - "test_files = list((PATH / \"test\").glob(\"*.npy\"))\n", - "with ThreadPoolExecutor(8) as e:\n", - " e.map(functools.partial(resize_data, folder=\"test-128\"), test_files)" + "# 🔥 RainNet ⚡️" ] }, { @@ -407,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -415,9 +244,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 9;\n", - " var nbb_unformatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y\";\n", - " var nbb_formatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y\";\n", + " var nbb_cell_id = 6;\n", + " var nbb_unformatted_code = \"class NowcastingDataset(Dataset):\\n def __init__(self, paths, tfms=None, test=False):\\n self.paths = paths\\n if tfms is not None:\\n self.tfms = tfms\\n else:\\n self.tfms = A.Compose(\\n [\\n A.PadIfNeeded(\\n min_height=128, min_width=128, always_apply=True, p=1\\n ),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n )\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n augmented = self.tfms(image=data)\\n data = augmented[\\\"image\\\"]\\n\\n x = data[:4, :, :]\\n x = x / args[\\\"rng\\\"]\\n if self.test:\\n return x\\n else:\\n y = data[4, :, :]\\n y = y / args[\\\"rng\\\"]\\n y = y.unsqueeze(0)\\n return x, y\";\n", + " var nbb_formatted_code = \"class NowcastingDataset(Dataset):\\n def __init__(self, paths, tfms=None, test=False):\\n self.paths = paths\\n if tfms is not None:\\n self.tfms = tfms\\n else:\\n self.tfms = A.Compose(\\n [\\n A.PadIfNeeded(\\n min_height=128, min_width=128, always_apply=True, p=1\\n ),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n )\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n augmented = self.tfms(image=data)\\n data = augmented[\\\"image\\\"]\\n\\n x = data[:4, :, :]\\n x = x / args[\\\"rng\\\"]\\n if self.test:\\n return x\\n else:\\n y = data[4, :, :]\\n y = y / args[\\\"rng\\\"]\\n y = y.unsqueeze(0)\\n return x, y\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -439,9 +268,20 @@ } ], "source": [ - "class NowcastingDataset(torch.utils.data.Dataset):\n", - " def __init__(self, paths, test=False):\n", + "class NowcastingDataset(Dataset):\n", + " def __init__(self, paths, tfms=None, test=False):\n", " self.paths = paths\n", + " if tfms is not None:\n", + " self.tfms = tfms\n", + " else:\n", + " self.tfms = A.Compose(\n", + " [\n", + " A.PadIfNeeded(\n", + " min_height=128, min_width=128, always_apply=True, p=1\n", + " ),\n", + " ToTensorV2(always_apply=True, p=1),\n", + " ]\n", + " )\n", " self.test = test\n", "\n", " def __len__(self):\n", @@ -450,27 +290,24 @@ " def __getitem__(self, idx):\n", " path = self.paths[idx]\n", " data = np.load(path)\n", - " x = data[:, :, :4]\n", + "\n", + " augmented = self.tfms(image=data)\n", + " data = augmented[\"image\"]\n", + "\n", + " x = data[:4, :, :]\n", " x = x / args[\"rng\"]\n", - " x = x.astype(np.float32)\n", - " x = torch.tensor(x, dtype=torch.float)\n", - " x = x.permute(2, 0, 1)\n", " if self.test:\n", " return x\n", " else:\n", - " y = data[:, :, 4]\n", + " y = data[4, :, :]\n", " y = y / args[\"rng\"]\n", - " y = y.astype(np.float32)\n", - " y = torch.tensor(y, dtype=torch.float)\n", - " y = y.unsqueeze(-1)\n", - " y = y.permute(2, 0, 1)\n", - "\n", + " y = y.unsqueeze(0)\n", " return x, y" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -478,9 +315,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 10;\n", - " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\"))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", - " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\"))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_cell_id = 7;\n", + " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n test=False,\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n self.test = test\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths, tfms=args[\\\"trn_tfms\\\"])\\n self.val_dataset = NowcastingDataset(val_paths, tfms=args[\\\"val_tfms\\\"])\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n test=False,\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n self.test = test\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths, tfms=args[\\\"trn_tfms\\\"])\\n self.val_dataset = NowcastingDataset(val_paths, tfms=args[\\\"val_tfms\\\"])\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -509,12 +346,14 @@ " val_df=None,\n", " batch_size=args[\"batch_size\"],\n", " num_workers=args[\"num_workers\"],\n", + " test=False,\n", " ):\n", " super().__init__()\n", " self.train_df = train_df\n", " self.val_df = val_df\n", " self.batch_size = batch_size\n", " self.num_workers = num_workers\n", + " self.test = test\n", "\n", " def setup(self, stage=\"train\"):\n", " if stage == \"train\":\n", @@ -524,14 +363,14 @@ " val_paths = [\n", " args[\"train_data_path\"] / fn for fn in self.val_df.filename.values\n", " ]\n", - " self.train_dataset = NowcastingDataset(train_paths)\n", - " self.val_dataset = NowcastingDataset(val_paths)\n", + " self.train_dataset = NowcastingDataset(train_paths, tfms=args[\"trn_tfms\"])\n", + " self.val_dataset = NowcastingDataset(val_paths, tfms=args[\"val_tfms\"])\n", " else:\n", - " test_paths = list(args[\"test_data_path\"].glob(\"*.npy\"))\n", + " test_paths = list(sorted(args[\"test_data_path\"].glob(\"*.npy\")))\n", " self.test_dataset = NowcastingDataset(test_paths, test=True)\n", "\n", " def train_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", + " return DataLoader(\n", " self.train_dataset,\n", " batch_size=self.batch_size,\n", " sampler=RandomSampler(self.train_dataset),\n", @@ -541,7 +380,7 @@ " )\n", "\n", " def val_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", + " return DataLoader(\n", " self.val_dataset,\n", " batch_size=2 * self.batch_size,\n", " sampler=SequentialSampler(self.val_dataset),\n", @@ -550,7 +389,7 @@ " )\n", "\n", " def test_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", + " return DataLoader(\n", " self.test_dataset,\n", " batch_size=2 * self.batch_size,\n", " sampler=SequentialSampler(self.test_dataset),\n", @@ -559,55 +398,6 @@ " )" ] }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 12;\n", - " var nbb_unformatted_code = \"# df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n# datamodule = NowcastingDataModule(df, fold=0, batch_size=2)\\n# datamodule.setup()\\n# for batch in datamodule.train_dataloader():\\n# xs, ys = batch\\n# idx = np.random.randint(len(xs))\\n# x, y = xs[idx], ys[idx]\\n# x = x.permute(1, 2, 0).numpy()\\n# y = y.permute(1, 2, 0).numpy()\\n# visualize(x, y)\\n# break\";\n", - " var nbb_formatted_code = \"# df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n# datamodule = NowcastingDataModule(df, fold=0, batch_size=2)\\n# datamodule.setup()\\n# for batch in datamodule.train_dataloader():\\n# xs, ys = batch\\n# idx = np.random.randint(len(xs))\\n# x, y = xs[idx], ys[idx]\\n# x = x.permute(1, 2, 0).numpy()\\n# y = y.permute(1, 2, 0).numpy()\\n# visualize(x, y)\\n# break\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# df = pd.read_csv(args[\"train_folds_csv\"])\n", - "# datamodule = NowcastingDataModule(df, fold=0, batch_size=2)\n", - "# datamodule.setup()\n", - "# for batch in datamodule.train_dataloader():\n", - "# xs, ys = batch\n", - "# idx = np.random.randint(len(xs))\n", - "# x, y = xs[idx], ys[idx]\n", - "# x = x.permute(1, 2, 0).numpy()\n", - "# y = y.permute(1, 2, 0).numpy()\n", - "# visualize(x, y)\n", - "# break" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -624,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -632,9 +422,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 11;\n", - " var nbb_unformatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024], drop_rate=0.5):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n self.dropout = nn.Dropout(p=drop_rate)\\n\\n def forward(self, x):\\n ftrs = []\\n for i, block in enumerate(self.blocks):\\n x = block(x)\\n ftrs.append(x)\\n if i >= 3:\\n x = self.dropout(x)\\n if i < 4:\\n x = self.pool(x)\\n return ftrs\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.chs = chs\\n self.ups = nn.ModuleList(\\n [nn.Upsample(scale_factor=2, mode=\\\"nearest\\\") for i in range(len(chs) - 1)]\\n )\\n self.convs = nn.ModuleList(\\n [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, ftrs):\\n for i in range(len(self.chs) - 1):\\n x = self.ups[i](x)\\n x = torch.cat([ftrs[i], x], dim=1)\\n x = self.convs[i](x)\\n return x\";\n", - " var nbb_formatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024], drop_rate=0.5):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n self.dropout = nn.Dropout(p=drop_rate)\\n\\n def forward(self, x):\\n ftrs = []\\n for i, block in enumerate(self.blocks):\\n x = block(x)\\n ftrs.append(x)\\n if i >= 3:\\n x = self.dropout(x)\\n if i < 4:\\n x = self.pool(x)\\n return ftrs\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.chs = chs\\n self.ups = nn.ModuleList(\\n [nn.Upsample(scale_factor=2, mode=\\\"nearest\\\") for i in range(len(chs) - 1)]\\n )\\n self.convs = nn.ModuleList(\\n [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, ftrs):\\n for i in range(len(self.chs) - 1):\\n x = self.ups[i](x)\\n x = torch.cat([ftrs[i], x], dim=1)\\n x = self.convs[i](x)\\n return x\";\n", + " var nbb_cell_id = 8;\n", + " var nbb_unformatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024], drop_rate=0.5):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n\\n # self.dropout = nn.Dropout(p=drop_rate)\\n\\n def forward(self, x):\\n feats = []\\n for block in self.blocks:\\n x = block(x)\\n feats.append(x)\\n x = self.pool(x)\\n return feats\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64], bilinear=True):\\n super().__init__()\\n if bilinear:\\n self.upsamples = nn.ModuleList(\\n [\\n nn.Upsample(scale_factor=2, mode=\\\"nearest\\\")\\n for i in range(len(chs) - 1)\\n ]\\n )\\n else:\\n self.upsamples = nn.ModuleList(\\n [\\n nn.ConvTranspose2d(chs[i], chs[i], kernel_size=2, stride=2)\\n for i in range(len(chs) - 1)\\n ]\\n )\\n self.blocks = nn.ModuleList(\\n [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, feats):\\n for upsample, block, feat in zip(self.upsamples, self.blocks, feats):\\n # print(\\\"Before upsample:\\\", x.shape)\\n x = upsample(x)\\n # print(\\\"After upsample:\\\", x.shape)\\n # print(\\\"Feat:\\\", feat.shape)\\n x = torch.cat([feat, x], dim=1)\\n # print(\\\"Concat:\\\", x.shape)\\n x = block(x)\\n # print(\\\"After block:\\\", x.shape)\\n return x\";\n", + " var nbb_formatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024], drop_rate=0.5):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n\\n # self.dropout = nn.Dropout(p=drop_rate)\\n\\n def forward(self, x):\\n feats = []\\n for block in self.blocks:\\n x = block(x)\\n feats.append(x)\\n x = self.pool(x)\\n return feats\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64], bilinear=True):\\n super().__init__()\\n if bilinear:\\n self.upsamples = nn.ModuleList(\\n [\\n nn.Upsample(scale_factor=2, mode=\\\"nearest\\\")\\n for i in range(len(chs) - 1)\\n ]\\n )\\n else:\\n self.upsamples = nn.ModuleList(\\n [\\n nn.ConvTranspose2d(chs[i], chs[i], kernel_size=2, stride=2)\\n for i in range(len(chs) - 1)\\n ]\\n )\\n self.blocks = nn.ModuleList(\\n [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, feats):\\n for upsample, block, feat in zip(self.upsamples, self.blocks, feats):\\n # print(\\\"Before upsample:\\\", x.shape)\\n x = upsample(x)\\n # print(\\\"After upsample:\\\", x.shape)\\n # print(\\\"Feat:\\\", feat.shape)\\n x = torch.cat([feat, x], dim=1)\\n # print(\\\"Concat:\\\", x.shape)\\n x = block(x)\\n # print(\\\"After block:\\\", x.shape)\\n return x\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -679,36 +469,49 @@ " [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\n", " )\n", " self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n", - " self.dropout = nn.Dropout(p=drop_rate)\n", + "\n", + " # self.dropout = nn.Dropout(p=drop_rate)\n", "\n", " def forward(self, x):\n", - " ftrs = []\n", - " for i, block in enumerate(self.blocks):\n", + " feats = []\n", + " for block in self.blocks:\n", " x = block(x)\n", - " ftrs.append(x)\n", - " if i >= 3:\n", - " x = self.dropout(x)\n", - " if i < 4:\n", - " x = self.pool(x)\n", - " return ftrs\n", + " feats.append(x)\n", + " x = self.pool(x)\n", + " return feats\n", "\n", "\n", "class Decoder(nn.Module):\n", - " def __init__(self, chs=[1024, 512, 256, 128, 64]):\n", + " def __init__(self, chs=[1024, 512, 256, 128, 64], bilinear=True):\n", " super().__init__()\n", - " self.chs = chs\n", - " self.ups = nn.ModuleList(\n", - " [nn.Upsample(scale_factor=2, mode=\"nearest\") for i in range(len(chs) - 1)]\n", - " )\n", - " self.convs = nn.ModuleList(\n", + " if bilinear:\n", + " self.upsamples = nn.ModuleList(\n", + " [\n", + " nn.Upsample(scale_factor=2, mode=\"nearest\")\n", + " for i in range(len(chs) - 1)\n", + " ]\n", + " )\n", + " else:\n", + " self.upsamples = nn.ModuleList(\n", + " [\n", + " nn.ConvTranspose2d(chs[i], chs[i], kernel_size=2, stride=2)\n", + " for i in range(len(chs) - 1)\n", + " ]\n", + " )\n", + " self.blocks = nn.ModuleList(\n", " [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\n", " )\n", "\n", - " def forward(self, x, ftrs):\n", - " for i in range(len(self.chs) - 1):\n", - " x = self.ups[i](x)\n", - " x = torch.cat([ftrs[i], x], dim=1)\n", - " x = self.convs[i](x)\n", + " def forward(self, x, feats):\n", + " for upsample, block, feat in zip(self.upsamples, self.blocks, feats):\n", + " # print(\"Before upsample:\", x.shape)\n", + " x = upsample(x)\n", + " # print(\"After upsample:\", x.shape)\n", + " # print(\"Feat:\", feat.shape)\n", + " x = torch.cat([feat, x], dim=1)\n", + " # print(\"Concat:\", x.shape)\n", + " x = block(x)\n", + " # print(\"After block:\", x.shape)\n", " return x" ] }, @@ -721,7 +524,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -729,9 +532,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 12;\n", - " var nbb_unformatted_code = \"class RainNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=4e-4,\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n\\n # Parameters\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n\\n # self.criterion = LogCoshLoss()\\n self.criterion = nn.L1Loss()\\n\\n # Layers\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 2, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(2),\\n nn.Conv2d(2, 1, kernel_size=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = ftrs[::-1]\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n tfms = nn.Sequential(\\n T.CenterCrop(120),\\n )\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = tfms(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = tfms(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n rng = args[\\\"rng\\\"]\\n y = rng * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = rng * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n # mae = metrics.mean_absolute_error(y, y_hat)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.reshape(-1), y_pred.reshape(-1)\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n\\n comp_metric = mae / (csi + 1e-12)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n optimizer, T_max=self.num_train_steps\\n )\\n return [optimizer], [{\\\"scheduler\\\": scheduler, \\\"interval\\\": \\\"step\\\"}]\";\n", - " var nbb_formatted_code = \"class RainNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=4e-4,\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n\\n # Parameters\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n\\n # self.criterion = LogCoshLoss()\\n self.criterion = nn.L1Loss()\\n\\n # Layers\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 2, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(2),\\n nn.Conv2d(2, 1, kernel_size=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = ftrs[::-1]\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n tfms = nn.Sequential(\\n T.CenterCrop(120),\\n )\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = tfms(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = tfms(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n rng = args[\\\"rng\\\"]\\n y = rng * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = rng * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n # mae = metrics.mean_absolute_error(y, y_hat)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.reshape(-1), y_pred.reshape(-1)\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n\\n comp_metric = mae / (csi + 1e-12)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n optimizer, T_max=self.num_train_steps\\n )\\n return [optimizer], [{\\\"scheduler\\\": scheduler, \\\"interval\\\": \\\"step\\\"}]\";\n", + " var nbb_cell_id = 9;\n", + " var nbb_unformatted_code = \"class RainNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n bilinear=True,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n # self.criterion = loss.LogCoshLoss()\\n self.criterion = nn.L1Loss()\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs, bilinear=bilinear)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 2, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(2),\\n nn.Conv2d(2, 1, kernel_size=1),\\n nn.Sigmoid(),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = list(reversed(ftrs))\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n\\n crop = T.CenterCrop(120)\\n y = crop(y)\\n y_hat = crop(y_hat)\\n\\n batch_size = len(y)\\n\\n y = y.detach().cpu().numpy()\\n y = y.reshape(batch_size, -1)\\n y = y[:, args[\\\"dams\\\"]]\\n y *= args[\\\"rng\\\"]\\n\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(batch_size, -1)\\n y_hat = y_hat[:, args[\\\"dams\\\"]]\\n y_hat *= args[\\\"rng\\\"]\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_true = y_true.ravel()\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n y_pred = y_pred.ravel()\\n\\n y = y.ravel()\\n y_hat = y_hat.ravel()\\n # mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)\\n\\n err = (y - y_hat) * y_true\\n err = np.abs(err)\\n mae = err.sum() / y_true.sum()\\n\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # optimizer\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(self.optimizer)\\n\\n # scheduler\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = get_cosine_schedule_with_warmup(\\n self.optimizer,\\n num_warmup_steps=self.num_train_steps * args[\\\"warmup_epochs\\\"],\\n num_training_steps=self.num_train_steps * args[\\\"max_epochs\\\"],\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR(\\n self.optimizer, step_size=10, gamma=0.5\\n )\\n return [self.optimizer], [\\n {\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"epoch\\\"}\\n ]\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\\n self.optimizer, mode=\\\"min\\\", factor=0.1, patience=3, verbose=True\\n )\\n return [self.optimizer], [\\n {\\n \\\"scheduler\\\": self.scheduler,\\n \\\"interval\\\": \\\"epoch\\\",\\n \\\"reduce_on_plateau\\\": True,\\n \\\"monitor\\\": \\\"comp_metric\\\",\\n }\\n ]\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", + " var nbb_formatted_code = \"class RainNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n bilinear=True,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n # self.criterion = loss.LogCoshLoss()\\n self.criterion = nn.L1Loss()\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs, bilinear=bilinear)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 2, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(2),\\n nn.Conv2d(2, 1, kernel_size=1),\\n nn.Sigmoid(),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = list(reversed(ftrs))\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n\\n crop = T.CenterCrop(120)\\n y = crop(y)\\n y_hat = crop(y_hat)\\n\\n batch_size = len(y)\\n\\n y = y.detach().cpu().numpy()\\n y = y.reshape(batch_size, -1)\\n y = y[:, args[\\\"dams\\\"]]\\n y *= args[\\\"rng\\\"]\\n\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(batch_size, -1)\\n y_hat = y_hat[:, args[\\\"dams\\\"]]\\n y_hat *= args[\\\"rng\\\"]\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_true = y_true.ravel()\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n y_pred = y_pred.ravel()\\n\\n y = y.ravel()\\n y_hat = y_hat.ravel()\\n # mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)\\n\\n err = (y - y_hat) * y_true\\n err = np.abs(err)\\n mae = err.sum() / y_true.sum()\\n\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # optimizer\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(self.optimizer)\\n\\n # scheduler\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = get_cosine_schedule_with_warmup(\\n self.optimizer,\\n num_warmup_steps=self.num_train_steps * args[\\\"warmup_epochs\\\"],\\n num_training_steps=self.num_train_steps * args[\\\"max_epochs\\\"],\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR(\\n self.optimizer, step_size=10, gamma=0.5\\n )\\n return [self.optimizer], [\\n {\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"epoch\\\"}\\n ]\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\\n self.optimizer, mode=\\\"min\\\", factor=0.1, patience=3, verbose=True\\n )\\n return [self.optimizer], [\\n {\\n \\\"scheduler\\\": self.scheduler,\\n \\\"interval\\\": \\\"epoch\\\",\\n \\\"reduce_on_plateau\\\": True,\\n \\\"monitor\\\": \\\"comp_metric\\\",\\n }\\n ]\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -745,125 +548,3384 @@ " " ], "text/plain": [ - "" + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class RainNet(pl.LightningModule):\n", + " def __init__(\n", + " self,\n", + " lr=args[\"lr\"],\n", + " enc_chs=[4, 64, 128, 256, 512, 1024],\n", + " dec_chs=[1024, 512, 256, 128, 64],\n", + " num_train_steps=None,\n", + " bilinear=True,\n", + " ):\n", + " super().__init__()\n", + " self.lr = lr\n", + " self.num_train_steps = num_train_steps\n", + " # self.criterion = loss.LogCoshLoss()\n", + " self.criterion = nn.L1Loss()\n", + " self.encoder = Encoder(enc_chs)\n", + " self.decoder = Decoder(dec_chs, bilinear=bilinear)\n", + " self.out = nn.Sequential(\n", + " nn.Conv2d(64, 2, kernel_size=3, padding=1),\n", + " nn.ReLU(inplace=True),\n", + " nn.BatchNorm2d(2),\n", + " nn.Conv2d(2, 1, kernel_size=1),\n", + " nn.Sigmoid(),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " ftrs = self.encoder(x)\n", + " ftrs = list(reversed(ftrs))\n", + " x = self.decoder(ftrs[0], ftrs[1:])\n", + " out = self.out(x)\n", + " return out\n", + "\n", + " def shared_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " y_hat = self(x)\n", + " loss = self.criterion(y_hat, y)\n", + " return loss, y, y_hat\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", + " self.log(\"train_loss\", loss)\n", + " for i, param_group in enumerate(self.optimizer.param_groups):\n", + " self.log(f\"lr/lr{i}\", param_group[\"lr\"])\n", + " return {\"loss\": loss}\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", + " return {\"loss\": loss, \"y\": y.detach(), \"y_hat\": y_hat.detach()}\n", + "\n", + " def validation_epoch_end(self, outputs):\n", + " avg_loss = torch.stack([x[\"loss\"] for x in outputs]).mean()\n", + " self.log(\"val_loss\", avg_loss)\n", + "\n", + " y = torch.cat([x[\"y\"] for x in outputs])\n", + " y_hat = torch.cat([x[\"y_hat\"] for x in outputs])\n", + "\n", + " crop = T.CenterCrop(120)\n", + " y = crop(y)\n", + " y_hat = crop(y_hat)\n", + "\n", + " batch_size = len(y)\n", + "\n", + " y = y.detach().cpu().numpy()\n", + " y = y.reshape(batch_size, -1)\n", + " y = y[:, args[\"dams\"]]\n", + " y *= args[\"rng\"]\n", + "\n", + " y_hat = y_hat.detach().cpu().numpy()\n", + " y_hat = y_hat.reshape(batch_size, -1)\n", + " y_hat = y_hat[:, args[\"dams\"]]\n", + " y_hat *= args[\"rng\"]\n", + "\n", + " y_true = radar2precipitation(y)\n", + " y_true = np.where(y_true >= 0.1, 1, 0)\n", + " y_true = y_true.ravel()\n", + " y_pred = radar2precipitation(y_hat)\n", + " y_pred = np.where(y_pred >= 0.1, 1, 0)\n", + " y_pred = y_pred.ravel()\n", + "\n", + " y = y.ravel()\n", + " y_hat = y_hat.ravel()\n", + " # mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)\n", + "\n", + " err = (y - y_hat) * y_true\n", + " err = np.abs(err)\n", + " mae = err.sum() / y_true.sum()\n", + "\n", + " self.log(\"mae\", mae)\n", + "\n", + " tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\n", + " csi = tp / (tp + fn + fp)\n", + " self.log(\"csi\", csi)\n", + "\n", + " comp_metric = mae / (csi + 1e-12)\n", + " self.log(\"comp_metric\", comp_metric)\n", + "\n", + " print(\n", + " f\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\"\n", + " )\n", + "\n", + " def configure_optimizers(self):\n", + " # optimizer\n", + " if args[\"optimizer\"] == \"adam\":\n", + " self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"adamw\":\n", + " self.optimizer = AdamW(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"radam\":\n", + " self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"ranger\":\n", + " self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", + " self.optimizer = optim.Lookahead(self.optimizer)\n", + "\n", + " # scheduler\n", + " if args[\"scheduler\"] == \"cosine\":\n", + " self.scheduler = get_cosine_schedule_with_warmup(\n", + " self.optimizer,\n", + " num_warmup_steps=self.num_train_steps * args[\"warmup_epochs\"],\n", + " num_training_steps=self.num_train_steps * args[\"max_epochs\"],\n", + " )\n", + " return [self.optimizer], [{\"scheduler\": self.scheduler, \"interval\": \"step\"}]\n", + " elif args[\"scheduler\"] == \"step\":\n", + " self.scheduler = torch.optim.lr_scheduler.StepLR(\n", + " self.optimizer, step_size=10, gamma=0.5\n", + " )\n", + " return [self.optimizer], [\n", + " {\"scheduler\": self.scheduler, \"interval\": \"epoch\"}\n", + " ]\n", + " elif args[\"scheduler\"] == \"plateau\":\n", + " self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " self.optimizer, mode=\"min\", factor=0.1, patience=3, verbose=True\n", + " )\n", + " return [self.optimizer], [\n", + " {\n", + " \"scheduler\": self.scheduler,\n", + " \"interval\": \"epoch\",\n", + " \"reduce_on_plateau\": True,\n", + " \"monitor\": \"comp_metric\",\n", + " }\n", + " ]\n", + " else:\n", + " self.scheduler = None\n", + " return [self.optimizer]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "42" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 10;\n", + " var nbb_unformatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\";\n", + " var nbb_formatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "seed_everything(args[\"seed\"])\n", + "pl.seed_everything(args[\"seed\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 11;\n", + " var nbb_unformatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\";\n", + " var nbb_formatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df = pd.read_csv(args[\"train_folds_csv\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 12;\n", + " var nbb_unformatted_code = \"def train_fold(df, fold, lr_find=False, bilinear=False):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(train_df, val_df)\\n datamodule.setup()\\n\\n num_train_steps = np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n model = RainNet(num_train_steps=num_train_steps, bilinear=bilinear)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n benchmark=True,\\n )\\n\\n if lr_find:\\n lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\\n fig = lr_finder.plot(suggest=True)\\n fig.show()\\n return\\n\\n print(f\\\"Training fold {fold}...\\\")\\n trainer.fit(model, datamodule)\\n\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"rainnet_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n trainer.save_checkpoint(checkpoint)\\n print(\\\"Model saved at\\\", checkpoint)\\n\\n del model, trainer, datamodule\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", + " var nbb_formatted_code = \"def train_fold(df, fold, lr_find=False, bilinear=False):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(train_df, val_df)\\n datamodule.setup()\\n\\n num_train_steps = np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n model = RainNet(num_train_steps=num_train_steps, bilinear=bilinear)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n benchmark=True,\\n )\\n\\n if lr_find:\\n lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\\n fig = lr_finder.plot(suggest=True)\\n fig.show()\\n return\\n\\n print(f\\\"Training fold {fold}...\\\")\\n trainer.fit(model, datamodule)\\n\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"rainnet_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n trainer.save_checkpoint(checkpoint)\\n print(\\\"Model saved at\\\", checkpoint)\\n\\n del model, trainer, datamodule\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def train_fold(df, fold, lr_find=False, bilinear=False):\n", + " train_df = df[df.fold != fold]\n", + " val_df = df[df.fold == fold]\n", + "\n", + " datamodule = NowcastingDataModule(train_df, val_df)\n", + " datamodule.setup()\n", + "\n", + " num_train_steps = np.ceil(\n", + " len(train_df) // args[\"batch_size\"] / args[\"accumulate_grad_batches\"]\n", + " )\n", + " model = RainNet(num_train_steps=num_train_steps, bilinear=bilinear)\n", + "\n", + " trainer = pl.Trainer(\n", + " gpus=args[\"gpus\"],\n", + " max_epochs=args[\"max_epochs\"],\n", + " precision=args[\"precision\"],\n", + " progress_bar_refresh_rate=50,\n", + " benchmark=True,\n", + " )\n", + "\n", + " if lr_find:\n", + " lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\n", + " fig = lr_finder.plot(suggest=True)\n", + " fig.show()\n", + " return\n", + "\n", + " print(f\"Training fold {fold}...\")\n", + " trainer.fit(model, datamodule)\n", + "\n", + " checkpoint = (\n", + " args[\"model_dir\"]\n", + " / f\"rainnet_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt\"\n", + " )\n", + " trainer.save_checkpoint(checkpoint)\n", + " print(\"Model saved at\", checkpoint)\n", + "\n", + " del model, trainer, datamodule\n", + " gc.collect()\n", + " torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 18 M \n", + "2 | decoder | Decoder | 12 M \n", + "3 | out | Sequential | 1 K \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 336.4057808053285 | MAE: 53.408172231985944 | CSI: 0.15876116071428573 | Loss: 0.598244309425354\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2e115e1b43eb4041bfab77b2d60dc4ef", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Finding best initial lr'), FloatProgress(value=0.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEKCAYAAAAFJbKyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAApr0lEQVR4nO3deXwc9X3/8ddnV6v7sK3LtnzbAmMIGGrMYUM5QgKkDdC6DTSh5BcS6hLSJk3Tkra//tL20TbpowlpAgmhKQnNRQjhcBMnhBDCYS7L1AYb21g+wLKNDh86rFv7+f2xY2ctZEtrNNpd6f18PPTQzux8dz4ztvft+c7Md8zdERERGalIugsQEZHsouAQEZGUKDhERCQlCg4REUmJgkNERFKi4BARkZTkpLuAsVBRUeFz5sxJdxkiIlll3bp1Le5eOXj+hAiOOXPmUFdXl+4yRESyipm9MdR8dVWJiEhKFBwiIpISBYeIiKREwSEiIilRcIiISEoUHCIikpIJcTnuaOrqHWB7cwf5sSgLqorTXY6IyJhTcJzAtsZ21u8+RH1TB9uaOtjW1E7DwS7cIWLwXzedy6ULq9JdpojImFJwnMB9z+/iuy+8SW40wrzKIs6aMYkV58xkQVUxX/t1PX/2g//loVsvpLa6JN2lnpR43Nnb2sX25sP09sepLMmjsiSPwliU/rgzEHfauvvYc6iLvYe6ONDRS3f/AD19ceIOhblRCnKjlOTnMKUolylFuZQVxIiYETEjGjEKc6MU5kbJj0WJRSNELLHunv44h3v66eobIBpJLGsYXb0DHO7tp7N3AHDMDAP6405ff5y+uBOLGnk5UfJjESJmmIFh9A3E6RuI09sfp7N3gPaeftq7+9h3qJs3D3TScLCTnGgksZ3FeUwuzKW0IIeyghhlBTEmFeYyqTDG1NJ8ivL0T0PkeGwiPAFwyZIlfjJ3ju8+0EnfQJxZUwrJiR57OmjvoS7ef+caivKiPHLrMiYX5QLQ3t3Hc9v38/Trzbyxv5MVvzWD3zlz2tvaj4bO3n4A8nOiRI58Iyfp7hugub2H9bsP8fyO/bywYz+HOvuIRY2cSIQDh3vp6htIaZ3RiJGfE8HM6OztJ34Sf33MYCz/2kUjxvRJ+cyYVMiAOy3tPTS399De03/cNhXFecytKKS6NP9osJQVxCgtiFGSn4M7NLf30NzRQ1fvACX5ORTn5VCSH2NyYSKECnMTAdw/kAiyPYe62HOoi9auPsqLcqksyaOiOI8pRbmUB6HrJEIyHndyoolAjUUi5MUi5OVEiQ7x5ywSFjNb5+5L3jZfwXHy1r1xkBvueYFZ5YVMLoyxr7Wbfa3dDMSdotwo5cV5vHmgk1lTCvnAuTN5q7WbV/a08sb+w8yvLObMGWWcNWMSZ9SUMreimGjE6O4b4OU3D7JxTyszJhdy1sxJTC/Lp6Wjl5d2HmDtrgNsa2qnvqmDxraeo7UU5kbJzYmQE4kQixod3f3HfDEW5+WwdO4Upk/Kp3/A6R2IU1YQY0FVMQsqi8mPRWnpSHyhdvUNkBONkBMxivJyqJlUQM2kAsqLc4klBaC709Mfp6OnnwOHe2np6KGtqx93x4G+gThdvQN09g7Q1TfAQNzpjzvuTmFuDkV5UfJzosT9N/MLcnMoCo5kzCzxWQ6xaGK7cqJG34DT3TdAd1/86LrcISdq5OZEiEUiFOUljoRK8mOUF+UOGdz9A3Hau/tp7erjUFcfhzp7OdTZx97WLt5o6WTn/sM0t/fQ2tVHa1cfA0OkZE7EKIhF6ejtH1EY5kSM0oIYhzp7Typ0Y1GjojiPqtJ8qkrymFwYBFp+ItRKC3IoyYtRXvybYMqPRVNfkQgKjlCCA+Anr+zlzl/VM6kwxrSyAmZOLuDCBRWcM2syORHjl5sbuevJejY0tFKcl8MZNaXMKS+ivqmDjXtb6e6LA5AfizBrSiE7Ww7TN3Dsn0lpfg5t3YkQKIhFOWVqCfMri5hXUURONJL4Yu7tp7c/Tu+A0zcQpzgvJ/jiyOWU6hLeVVMWylHPROHudPYO0NbdR1tX4s+iqiQv0TUXsWPeP3i4j0NdvXT2DJATNWLRCAW5UaaXFVBZkkc0YgzE/WjYHjjcy/7DvbR29RGxRLiYGfG40xccsfT2x+kJuuBaOnpobOumqe03oXaiI8fKkjxmTylkVnnh0RqOhEp5cS4VxXmU5udgpqMZOZaCI42DHLo7ze09VBTnHdOl1D8Qp765g0172ti0t42dLR2cUl3CefOmcNaMSTQc7GL97kNseauN2eVFnDd3CmfUlB3zv34RgN7+OO3dfbR199PW1cf+w4mjx6a2HnYf7OSN/Z28eaCTpvaeIY+c8mMRqkvzqS7NZ3JhjOK8GMV5USqK85g5pZAZkwuYX1l8tEtWJgYFh0bHFSEedw529tLU3kNLRw/7O3qPHsG81dZDY2s3rV19dAQXFhw50j2iZlIBp08vZeaUwqMXPZTm5zA5uDiiqiSPaWUFurhgnDhecOhPV2QCiUSM8uI8yovzRrR8d98ADQe72H2wk22N7Wzc08bGPa2sqW+hs2/guOd1ygpiVJXkUZKfQ2lBjMriPE6fXsoZNWUsml5KYa6+erJZqEccZnYl8B9AFPimu39+0PuXAI8CO4NZD7n7P5rZqcAPkxadB/y9u3/ZzD4HfAxoDt77G3dffaI6dMQhMvrcExdZtHb1ceBwLwc6Ekcye1u7eKu1O3HlWnc/bd197D3URUtHLwC50QgXLijnvadP5fKFVVSV5qd5S+R4xvyIw8yiwF3AFUADsNbMVrn7a4MWfcbdfyd5hrtvBRYnfc4e4OGkRe5w938Pq3YRGZ5Z4n6aqpIoVSUn/vJ3dxrbeti4p5UXduznF6818tmHXgUS3V+LZ03inFmTuWxhFXMrisaifHkHwjxeXArUu/sOADO7H7gGGBwcw7kc2O7uQz6JSkQyn5kxtSyfqWX5vHtRNX/7vtPY2tjOM6+3sH73Ida/eYifvrKPf/rJa8yrLOKK06p57xlTWTxj0pD3KEl6hRkcNcDupOkG4LwhlrvAzDYAe4G/dPdNg96/HvjBoHm3mdkfA3XAp9394OAPNbNbgFsAZs2adXJbICKhMDMWTi1l4dTSo/N2H+jkic2NPLGliXvX7OQbT+9gamk+V71rKr9/zgzOqClLY8WSLLRzHGb2B8B73f2jwfSNwFJ3/0TSMqVA3N07zOxq4D/cvTbp/VwSgXK6uzcG86qBFsCBfwKmuftHTlSLznGIZJfWrj5+taWR1a++xVOvN9PbH+e0aaX84ZIZ/N7ZMygrjKW7xAkhHVdVNQAzk6ZnkAiBo9y9Len1ajP7mplVuHtLMPsq4OUjoREsd/S1mf0n8JMwiheR9CkriHHd2TO47uwZtHb2sWrDHh6oa+Af/uc1vvDzLVxzVg03XjBbRyFpEmZwrAVqzWwuiZPb1wN/lLyAmU0FGt3dzWwpieeD7E9a5AYGdVOZ2TR33xdMXgdsDKl+EckAZYUxbrxgDjdeMIeNe1r53otv8Mj/7uWHdbu58vSpfPbqhcwu1wn1sRT25bhXA18mcTnuve7+z2a2EsDd7zaz24A/BfqBLuAv3P25oG0hiXMk89y9Nekzv0PiiisHdgF/khQkQ1JXlcj40trVx33P7eLup7bTNxDnwxfO4ROX11Kary6s0aQ7xxUcIuNOY1s3X/zFVn60roGK4jz+5uqFXLu4RuNujZLjBYcGPRKRrFVdms+/rTiLRz++jOll+Xzqhxv4wDdeYEdzR7pLG9cUHCKS9c6cMYmHb13G53/vXWxtbOd9X3mWB9buZiL0qKSDgkNExoVIxLh+6Sx+/smLWDxzEn/141f4+PdfprWrL92ljTsKDhEZV6aVFfDdj57H7Vct5BebGrnurjXUN6nrajQpOERk3IlGjJW/PZ/vf+x8Wrv6uO6uNfxqS+PwDWVEFBwiMm4tnTuFVZ9YzqzyQm6+r44frn0z3SWNCwoOERnXaiYV8ODKC7motpLPPvQqj216K90lZT0Fh4iMewW5Ub7+wXN414xJfOIH/8uLO/YP30iOS8EhIhNCUV4O3/rwucyYXMBH/7uOrW+1p7ukrKXgEJEJY0pRLt+5+TzyY1H+9HvrONzTP3wjeRsFh4hMKDWTCviP6xezs+Uw//dRjZF6MhQcIjLhXDi/gj+7rJaHXt7Dg+sa0l1O1lFwiMiE9GeX13L+vCn830c26gbBFCk4RGRCikaM/7j+bPJjET770Csa1yoFCg4RmbCqS/P5qysXsnbXQR5dv3f4BgIoOERkgvvAkpmcNaOMf169mfZuDYg4EgoOEZnQIhHjH645g5aOHr7yxLZ0l5MVFBwiMuEtnjmJDyyZybfW7GJbo24MHI6CQ0QE+Mx7T6UwN8oXfr4l3aVkvFCDw8yuNLOtZlZvZrcP8f4lZtZqZuuDn79Pem+Xmb0azK9Lmj/FzB43s23B78lhboOITAzlxXl87KJ5/HJzE680HEp3ORkttOAwsyhwF3AVsAi4wcwWDbHoM+6+OPj5x0HvXRrMT35Y+u3AE+5eCzwRTIuIvGMfXjaHSYUxvvxLnes4kTCPOJYC9e6+w917gfuBa0bhc68B7gte3wdcOwqfKSJCSX6Mj100j19taWL97kPpLidjhRkcNcDupOmGYN5gF5jZBjP7mZmdnjTfgV+Y2TozuyVpfrW77wMIflcNtXIzu8XM6sysrrm5+Z1tiYhMGDddOIfJhTHuePz1dJeSscIMDhti3uBbM18GZrv7WcBXgUeS3lvm7ueQ6Or6uJldnMrK3f0ed1/i7ksqKytTaSoiE1hxXg63XDyfp15vZt0bB9NdTkYKMzgagJlJ0zOAY27NdPc2d+8IXq8GYmZWEUzvDX43AQ+T6PoCaDSzaQDB76YQt0FEJqA/vmA2U4py+fqv69NdSkYKMzjWArVmNtfMcoHrgVXJC5jZVDOz4PXSoJ79ZlZkZiXB/CLgPcCR8Y9XATcFr28CHg1xG0RkAirKy+FD58/miS1N7Go5nO5yMk5oweHu/cBtwGPAZuABd99kZivNbGWw2Apgo5ltAL4CXO+JkcaqgWeD+S8BP3X3nwdtPg9cYWbbgCuCaRGRUfWh82eREzG+/dyudJeScWwijAi5ZMkSr6urG35BEZEkf/HAen6+8S2e/+zllBXE0l3OmDOzdYNuhwB057iIyHF9ZNlcOnsH+FHd7uEXnkAUHCIix3FGTRlL507hW2t20T8QT3c5GUPBISJyAjcvn8ueQ108/lpjukvJGAoOEZETePdp1cwpL+QLP99CZ29/usvJCAoOEZETiEaMz//+mbxxoJN/Xa2Rc0HBISIyrPPnlXPzsrl854U3eOp1DWGk4BARGYG/fO+p1FYV85kfbeBQZ2+6y0krBYeIyAjkx6Lc8YHFHDjcy7+s3pzuctJKwSEiMkJn1JTxofNn89DLe9jX2pXuctJGwSEikoKbl88l7s631+xKdylpo+AQEUnBzCmFXPWuaXz/xTdp7+5LdzlpoeAQEUnRn1w8j/aefu5/aWIORaLgEBFJ0ZkzJnHe3Cncu2YnfRNwKBIFh4jISfiT357HvtZufvLK3uEXHmcUHCIiJ+GSU6qYX1nE9198M92ljDkFh4jISYhEjPedOZ11bxzkwOGJdUOggkNE5CS9+7Qq4g5PbmlKdyljSsEhInKSzpheRlVJHk9smVhDris4REROUiRiXH5aFU+/3kJP/0C6yxkzoQaHmV1pZlvNrN7Mbh/i/UvMrNXM1gc/fx/Mn2lmT5rZZjPbZGZ/ntTmc2a2J6nN1WFug4jIibz7tGo6evp5cceBdJcyZnLC+mAziwJ3AVcADcBaM1vl7q8NWvQZd/+dQfP6gU+7+8tmVgKsM7PHk9re4e7/HlbtIiIjtWxBBfmxCE9sbuTiUyrTXc6YCPOIYylQ7+473L0XuB+4ZiQN3X2fu78cvG4HNgM1oVUqInKS8mNRli+o4Jebm3D3dJczJsIMjhog+X78Bob+8r/AzDaY2c/M7PTBb5rZHOBs4MWk2beZ2Stmdq+ZTR5q5WZ2i5nVmVldc7MevCIi4bn8tGr2HOpiy1vt6S5lTIQZHDbEvMFx/DIw293PAr4KPHLMB5gVAz8GPunubcHsrwPzgcXAPuCLQ63c3e9x9yXuvqSycmIcPopIely+sAqAJzZPjKurwgyOBmBm0vQM4Jh78929zd07gtergZiZVQCYWYxEaHzP3R9KatPo7gPuHgf+k0SXmIhI2lSV5nPmjDJ+NUHu5wgzONYCtWY218xygeuBVckLmNlUM7Pg9dKgnv3BvP8CNrv7lwa1mZY0eR2wMcRtEBEZkYtrK9nQ0ErbBBhqPbTgcPd+4DbgMRIntx9w901mttLMVgaLrQA2mtkG4CvA9Z44u7QMuBG4bIjLbv/NzF41s1eAS4FPhbUNIiIjtWxBBQNxnxCX5YZ2OS4c7X5aPWje3Umv7wTuHKLdswx9jgR3v3GUyxQRecfOmT2J/FiENfUtXLGoOt3lhEp3jouIjIK8nChL55azpr4l3aWETsEhIjJKli8oZ1tTB41t3ekuJVQKDhGRUXLh/AqAcX/UoeAQERkli6aVMrkwxrMKDhERGYlIxLhwQQVr6lvG9fAjCg4RkVG0fEEFjW09bG8+nO5SQqPgEBEZRcsXjP/zHAoOEZFRNHNKIbOmFI7r8xwKDhGRUbZsQTkv7NhP/0A83aWEQsEhIjLKLpxfQXt3Pxv3tg2/cBZScIiIjLIL5pcD4/c8h4JDRGSUVRTnsXBqCc9v35/uUkKh4BARCcGF8ytYu+sA3X0D6S5l1Ck4RERCsGxBOT39cV5+82C6Sxl1Cg4RkRAsnTuFaMTGZXeVgkNEJAQle97kzme/yZ9ecw5EIlBaCrfeCtu3p7u0d0zBISIy2n72MzjzTN7z3P9Q2N0J7tDeDt/8Jpx5ZuL9LKbgEBEZTdu3w4oV0NlJdKD/2Pf6+qCzM/F+Fh95jCg4zKzIzCLB61PM7P1mFgu3NBGRLPTFLyYC4kT6+uCOO8amnhCM9IjjaSDfzGqAJ4D/A3x7uEZmdqWZbTWzejO7fYj3LzGzVjNbH/z8/XBtzWyKmT1uZtuC35NHuA0iIuH77ndHFhzf+c7Y1BOCkQaHuXsn8HvAV939OmDRCRuYRYG7gKuCZW8ws6HaPOPui4OffxxB29uBJ9y9lkSIvS2QRETSpqNjdJfLQCMODjO7APgg8NNgXs4wbZYC9e6+w917gfuBa0a4vhO1vQa4L3h9H3DtCD9TRCR8xcWju1wGGmlwfBL4LPCwu28ys3nAk8O0qQF2J003BPMGu8DMNpjZz8zs9BG0rXb3fQDB76qhVm5mt5hZnZnVNTc3D1OqiMgo+dCHIDbMKeBYDG68cWzqCcGIgsPdn3L397v7F4KT5C3u/mfDNLOhPmrQ9MvAbHc/C/gq8EgKbYer+R53X+LuSyorK1NpKiJy8j796ZEFx6c+NTb1hGCkV1V938xKzawIeA3YamafGaZZAzAzaXoGsDd5AXdvc/eO4PVqIGZmFcO0bTSzaUFd04CmkWyDiMiYmD8fHnwQCgvfFiAeiyXmP/hgYrksNdKuqkXu3kbifMJqYBYw3HHWWqDWzOaaWS5wPbAqeQEzm2pmFrxeGtSzf5i2q4Cbgtc3AY+OcBtERMbGVVfBK6/ALbdAaSluRntuITuu+6PE/KuuSneF78hIgyMW3LdxLfCou/cxTNeRu/cDtwGPAZuBB4LzIyvNbGWw2Apgo5ltAL4CXO8JQ7YN2nweuMLMtgFXBNMiIpll/ny4805obcX7B7jobx/m7hWfyuojjSOGuzLqiG8Au4ANwNNmNhsY9tFWQffT6kHz7k56fSdw50jbBvP3A5ePsG4RkbSLRIwL5pXz3Pb9uDtBR0vWGunJ8a+4e427Xx0cEbwBXBpybSIi48aF88vZc6iL3Qe60l3KOzbSk+NlZvalI5e3mtkXgaKQaxMRGTcumF8BwHPbs/9xsiM9x3Ev0A78YfDTBnwrrKJERMab+ZVFVJbk8dw4eD7HSM9xzHf330+a/gczWx9CPSIi45KZceH8ctbUt2T9eY6RHnF0mdnyIxNmtgzI/o46EZExtHxBBS0dvWx5qz3dpbwjIz3iWAn8t5mVBdMH+c29FCIiMgIX1SZGsXhmWzOnTStNczUnb6RXVW0IhgU5EzjT3c8GLgu1MhGRcWZqWT6nVBfzzLbsPkGe0hMAgyFCjty/8Rch1CMiMq5dVFvJizsP0N03kO5STto7eXRs9p7ZERFJk4tqK+jtj/PSzgPpLuWkvZPgSGm0WhERgfPmlpMbjfDMtux93MMJT46bWTtDB4QBBaFUJCIyjhXkRjl37uSsPs9xwiMOdy9x99IhfkrcfaRXZImISJKLaivZ8lY7TW3d6S7lpLyTrioRETkJF9Umhh/J1qMOBYeIyBg7bWop5UW5WXueQ8EhIjLGIhFjeW0Fz9a3EI9n33VGCg4RkTQ4MvzI1sbsG35EwSEikgbLg/Mca+qz7zyHgkNEJA2mlRUwv7IoK0+QKzhERNJk+YIKXtp5gJ7+7Bp+JNTgMLMrzWyrmdWb2e0nWO5cMxswsxXB9Klmtj7pp83MPhm89zkz25P03tVhboOISFiWLaigq2+Al984lO5SUhLaTXxmFgXuAq4AGoC1ZrbK3V8bYrkvAI8dmefuW4HFSe/vAR5OanaHu/97WLWLiIyF8+eXE40Ya+pbuGB+ebrLGbEwjziWAvXuvsPde4H7gWuGWO4TwI+BpuN8zuXAdnd/I5wyRUTSozQ/xlkzyng2y06QhxkcNcDupOmGYN5RZlYDXAfcfYLPuR74waB5t5nZK2Z2r5lNHqqRmd1iZnVmVtfcnJ032YjI+Le8tpJXGg7R2tmX7lJGLMzgGGrY9cF3unwZ+Gt3H/LMkJnlAu8HfpQ0++vAfBJdWfuALw7V1t3vcfcl7r6ksrIytcpFRMbI8gUVxB2e37E/3aWMWJjB0QDMTJqeAewdtMwS4H4z2wWsAL5mZtcmvX8V8LK7Nx6Z4e6N7j7g7nHgP0l0iYmIZKXFMydRmBvNqvs5whzhdi1Qa2ZzSZzcvh74o+QF3H3ukddm9m3gJ+7+SNIiNzCom8rMprn7vmDyOmDjqFcuIjJGcnMinD+vPKvOc4R2xOHu/cBtJK6W2gw84O6bzGylma0crr2ZFZK4IuuhQW/9m5m9amavAJcCnxrl0kVExtSyBRXsbDlMw8HOdJcyIqE+U8PdVwOrB80b8kS4u3940HQn8Lbr09z9xlEsUUQk7S5KGn7kA+fOSnM1w9Od4yIiaVZbVUxVSV7WDD+i4BARSTMzY/mCCp7bvj8rhllXcIiIZIDltRUcONzLa/va0l3KsBQcIiIZYPmC7HmcrIJDRCQDVJXmc2p1Cc/WZ/5IFwoOEZEMsby2grW7DtLdl9nDrCs4REQyxPLaCnr746zddSDdpZyQgkNEJEOcN3cKudEIz2b4eQ4Fh4hIhijMzeGc2ZMy/gS5gkNEJINcVFvJa/vaaOnoSXcpx6XgEBHJIEeGH8nk7ioFh4hIBjljehlTinJ5+vXMvSxXwSEikkEikcTwI09va8E9M4cfUXCIiGSYi2oraOnoYfO+9nSXMiQFh4hIhrn4lMTjrp/elpndVQoOEZEMU12az8KpJRl7nkPBISKSgS6qraBu10E6e/vTXcrbKDhERDLQxadU0jsQ58UdmTf8iIJDRCQDnTtnCnk5EZ7KwO6qUIPDzK40s61mVm9mt59guXPNbMDMViTN22Vmr5rZejOrS5o/xcweN7Ntwe/JYW6DiEg65MeinDevPCNPkIcWHGYWBe4CrgIWATeY2aLjLPcF4LEhPuZSd1/s7kuS5t0OPOHutcATwbSIyLhzcW0FO5oPs/tAZ7pLOUaYRxxLgXp33+HuvcD9wDVDLPcJ4MdA0wg/9xrgvuD1fcC177BOEZGMdMmpVQD8OsO6q8IMjhpgd9J0QzDvKDOrAa4D7h6ivQO/MLN1ZnZL0vxqd98HEPyuGmrlZnaLmdWZWV1zc2btdBGRkZhfWcTMKQU8tXWk/68eG2EGhw0xb/D9818G/trdh3rc1TJ3P4dEV9fHzeziVFbu7ve4+xJ3X1JZWZlKUxGRjGBmXHJKFWvq92fUUwHDDI4GYGbS9Axg76BllgD3m9kuYAXwNTO7FsDd9wa/m4CHSXR9ATSa2TSA4HdmRbGIyCi6dGElXX0DGfVUwDCDYy1Qa2ZzzSwXuB5YlbyAu8919znuPgd4ELjV3R8xsyIzKwEwsyLgPcDGoNkq4Kbg9U3AoyFug4hIWl0wr4LcnAhPbsmcLvfQgsPd+4HbSFwttRl4wN03mdlKM1s5TPNq4Fkz2wC8BPzU3X8evPd54Aoz2wZcEUyLiIxLBblRzp9Xzq9fz5zOlZwwP9zdVwOrB80b6kQ47v7hpNc7gLOOs9x+4PLRq1JEJLNdckol//iT13hzfyezygvTXY7uHBcRyXSXLjxyWW5mHHUoOEREMtzciiJmlxfy5BYFh4iIjNClp1bx/I7MuCxXwSEikgUuObWS7r44z+/Yn+5SFBwiItng/HnlFMSi/Gpz+rurFBwiIlkgPxZleW0Fv9rShPvgQTjGloJDRCRLXL6wij2Hutja2J7WOhQcIiJZ4shluU+kubtKwSEikiWqS/M5o6Y07ZflKjhERLLIZQurefnNgxw43Ju2GhQcIiJZ5PKFVcQdnkrjXeQKDhGRLPKumjIqivPSep5DwSEikkUiEeOyhZU89XozfQPx9NSQlrWKiMhJu/y0atq7+3lxR3oe7qTgEBHJMr99SiWFuVF++uq+tKxfwSEikmXyY1EuW1jFLza9RX8auqsUHCIiWejqd01j/+FeXto59t1VCg4RkSx06alVFMTS012l4BARyUIFuYnuqsc2vcVAfGwHPQw1OMzsSjPbamb1Znb7CZY718wGzGxFMD3TzJ40s81mtsnM/jxp2c+Z2R4zWx/8XB3mNoiIZKqr3zWNlo6x764KLTjMLArcBVwFLAJuMLNFx1nuC8BjSbP7gU+7+2nA+cDHB7W9w90XBz+rw9oGEZFMdunCSvJjEVaPcXdVmEccS4F6d9/h7r3A/cA1Qyz3CeDHwNHbIN19n7u/HLxuBzYDNSHWKiKSdQpzc7j01Cp+tnFsu6vCDI4aYHfSdAODvvzNrAa4Drj7eB9iZnOAs4EXk2bfZmavmNm9Zjb5OO1uMbM6M6trbm4+yU0QEclsie6qHta9cXDM1hlmcNgQ8wZH4peBv3b3IZ++bmbFJI5GPunubcHsrwPzgcXAPuCLQ7V193vcfYm7L6msrEy9ehGRLHDJqZXEosbjr701ZusMMzgagJlJ0zOAvYOWWQLcb2a7gBXA18zsWgAzi5EIje+5+0NHGrh7o7sPuHsc+E8SXWIiIhNSSX6MC+ZX8PhrjWP2SNkwg2MtUGtmc80sF7geWJW8gLvPdfc57j4HeBC41d0fMTMD/gvY7O5fSm5jZtOSJq8DNoa4DSIiGe+KRdXs2t/J9uaOMVlfaMHh7v3AbSSultoMPODum8xspZmtHKb5MuBG4LIhLrv9NzN71cxeAS4FPhXWNoiIZIN3n5Z4pOwvXmsck/XZWB3apNOSJUu8rq4u3WWIiITmd7/6LDlR4+Fbl43aZ5rZOndfMni+7hwXERkHrlhUzfrdh2hq7w59XQoOEZFx4IpF1bgzJk8GVHCIiIwDC6eWMGNyAY+PwXkOBYeIyDhgZlyxqJpn61s43NMf6roUHCIi48R7Fk2ltz/Ok1vD7a5ScIiIjBNL506hqiSP/9kw+F7r0aXgEBEZJ6IR431nTuPJrc20dfeFth4Fh4jIOPL+s6bT2x/nsY3hjV2l4BARGUcWz5zEzCkFrAqxu0rBISIyjpgZv3vmdJ7bvp+Wjp5Q1qHgEBEZZ96/eDoDcednIT0ZUMEhIjLOLJxayinVxaF1Vyk4RETGofefNZ21uw6y51DXqH+2gkNEZBz63bOmU1Gcx87mw6P+2Tmj/okiIpJ2s8uLePFvLicaGeop3u+MjjhERMapMEIDFBwiIpIiBYeIiKREwSEiIikJNTjM7Eoz22pm9WZ2+wmWO9fMBsxsxXBtzWyKmT1uZtuC35PD3AYRETlWaMFhZlHgLuAqYBFwg5ktOs5yXwAeG2Hb24En3L0WeCKYFhGRMRLmEcdSoN7dd7h7L3A/cM0Qy30C+DHQNMK21wD3Ba/vA64NoXYRETmOMIOjBtidNN0QzDvKzGqA64C7U2hb7e77AILfVUOt3MxuMbM6M6trbm4+6Y0QEZFjhXkD4FAXEPug6S8Df+3uA2bHLD6Stifk7vcA9wCYWbOZvZFK+yxQAbSku4gsov2VGu2v1IzX/TV7qJlhBkcDMDNpegYweMStJcD9QWhUAFebWf8wbRvNbJq77zOzaRzbxTUkd688uU3IXGZW5+5L0l1HttD+So32V2om2v4Ks6tqLVBrZnPNLBe4HliVvIC7z3X3Oe4+B3gQuNXdHxmm7SrgpuD1TcCjIW6DiIgMEtoRh7v3m9ltJK6WigL3uvsmM1sZvD/4vMawbYO3Pw88YGY3A28CfxDWNoiIyNuZe0qnDiRDmNktwXkcGQHtr9Rof6Vmou0vBYeIiKREQ46IiEhKFBwiIpISBYeIiKREwTEOmdlFZna3mX3TzJ5Ldz2ZzswuMbNngn12SbrryXRmdlqwrx40sz9Ndz2Zzszmmdl/mdmD6a5ltCg4MoyZ3WtmTWa2cdD8EY00DODuz7j7SuAn/GZcr3FpNPYXiVEJOoB8Ejefjluj9Pdrc/D36w9J3MQ7bo3S/trh7jeHW+nY0lVVGcbMLibxJfbf7n5GMC8KvA5cQeKLbS1wA4l7XP510Ed8xN2bgnYPAB9197YxKn/Mjcb+AlrcPW5m1cCX3P2DY1X/WButv19m9n4SI1Pf6e7fH6v6x9oo/3t80N1XMA6EOeSInAR3f9rM5gyafXS0YAAzux+4xt3/FfidoT7HzGYBreM5NGD09lfgIJAXSqEZYrT2l7uvAlaZ2U+BcRsco/z3a9xQV1V2GHak4SHcDHwrtIoyW0r7y8x+z8y+AXwHuDPk2jJRqvvrEjP7SrDPVoddXAZKdX+Vm9ndwNlm9tmwixsLOuLIDimPFuzu/y+kWrJBSvvL3R8CHgqvnIyX6v76NfDrsIrJAqnur/3AyvDKGXs64sgOIxlpWH5D+ys12l+pmfD7S8GRHYYdaViOof2VGu2v1Ez4/aXgyDBm9gPgeeBUM2sws5vdvR84MlrwZuCBpNGCJzTtr9Rof6VG+2touhxXRERSoiMOERFJiYJDRERSouAQEZGUKDhERCQlCg4REUmJgkNERFKi4JAJzcw6xnh9Y/p8FDObZGa3juU6ZfxTcIiMIjM74fhv7n7hGK9zEqDgkFGlQQ5FBjGz+cBdQCXQCXzM3beY2e8CfwfkAvuBD7p7o5l9DpgOzAFazOx1YBYwL/j9ZXf/SvDZHe5eHDxp8HNAC3AGsA74kLu7mV0NfCl472VgnrsfM1y3mX0YeB+Jh08VBc/HeBSYDMSAv3P3R4HPA/PNbD3wuLt/xsw+Q+IhTHnAwxN8QEw5CQoOkbe7B1jp7tvM7Dzga8BlwLPA+cGX+0eBvwI+HbT5LWC5u3cFQbIQuBQoAbaa2dfdvW/Qes4GTicxQN4aYJmZ1QHfAC52953BkBfHcwFwprsfCI46rnP3NjOrAF4ws1UkHrZ0hrsvBjCz9wC1JJ4pYSSeqXGxuz99sjtLJh4Fh0gSMysGLgR+ZHZ09OwjD3eaAfzQzKaROOrYmdR0lbt3JU3/1N17gB4zawKqeftjaV9y94ZgvetJHLF0ADvc/chn/wC45TjlPu7uB46UDvxL8MS6OInnQ1QP0eY9wc//BtPFJIJEwSEjpuAQOVYEOHTkf+iDfJXEo2VXJXU1HXF40LI9Sa8HGPrf2lDLDPWsh+NJXucHSXSt/Za795nZLhLdWIMZ8K/u/o0U1iNyDJ0cF0kSPGp3p5n9AYAlnBW8XQbsCV7fFFIJW4B5SY8r/cAI25UBTUFoXArMDua3k+guO+Ix4CPBkRVmVmNmVe+8bJlIdMQhE12hmSV3IX2JxP/ev25mf0fiRPP9wAYSRxg/MrM9wAvA3NEuJjhHcivwczNrAV4aYdPvAf8TnCNZTyKAcPf9ZrbGzDYCPwtOjp8GPB90xXUAHwKaRnlTZBzTsOoiGcbMit29wxLf7HcB29z9jnTXJXKEuqpEMs/HgpPlm0h0Qel8hGQUHXGIiEhKdMQhIiIpUXCIiEhKFBwiIpISBYeIiKREwSEiIilRcIiISEr+PwEqG1Y98U/WAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 13;\n", + " var nbb_unformatted_code = \"fold = 0\\ntrain_fold(df, fold, lr_find=True, bilinear=True)\";\n", + " var nbb_formatted_code = \"fold = 0\\ntrain_fold(df, fold, lr_find=True, bilinear=True)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fold = 0\n", + "train_fold(df, fold, lr_find=True, bilinear=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 18 M \n", + "2 | decoder | Decoder | 12 M \n", + "3 | out | Sequential | 1 K \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3e59ec39a19d45fba2a7e4ccd4493bfe", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 336.4057808053285 | MAE: 53.408172231985944 | CSI: 0.15876116071428573 | Loss: 0.598244309425354\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "25d0dfa2ddfe4a868855d729ec3a98c5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 851.2272794534249 | MAE: 140.2393953934741 | CSI: 0.16474964886307558 | Loss: 0.40799224376678467\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 109.83946068802479 | MAE: 81.70424544145874 | CSI: 0.7438514804202483 | Loss: 0.1272396445274353\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 75.52758898412003 | MAE: 58.209938219769676 | CSI: 0.770710928319624 | Loss: 0.0523519441485405\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 23.392370284004784 | MAE: 18.79694412787908 | CSI: 0.8035502131525614 | Loss: 0.02864903211593628\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 32.37117212343894 | MAE: 25.93985927003359 | CSI: 0.8013259195893926 | Loss: 0.023721566423773766\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 26.157452161642333 | MAE: 20.989396349418186 | CSI: 0.8024251069900142 | Loss: 0.019010227173566818\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 25.917166931877333 | MAE: 20.769555992082534 | CSI: 0.8013821899071342 | Loss: 0.017507778480648994\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 38.251320729303224 | MAE: 30.751167638855566 | CSI: 0.8039243365330322 | Loss: 0.021157938987016678\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 58.86530233760733 | MAE: 38.65860817819548 | CSI: 0.6567299689793448 | Loss: 0.021269183605909348\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 41.746526641365925 | MAE: 31.07450567849538 | CSI: 0.7443614637789395 | Loss: 0.01815567910671234\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 56.12451792683211 | MAE: 37.91385808728707 | CSI: 0.6755311134548742 | Loss: 0.020581135526299477\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 25.1557488867044 | MAE: 20.33929562383787 | CSI: 0.8085346898403278 | Loss: 0.014991799369454384\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 26.920995640834708 | MAE: 21.757567042608866 | CSI: 0.8082006822057988 | Loss: 0.015221461653709412\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 23.155391641423765 | MAE: 18.816632869669206 | CSI: 0.812624254473161 | Loss: 0.013913542032241821\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 21.637483453544107 | MAE: 17.401303753111506 | CSI: 0.8042203147353362 | Loss: 0.015407077968120575\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 22.040701014232845 | MAE: 17.96442628007475 | CSI: 0.8150569380008436 | Loss: 0.01368167344480753\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 22.008607869385187 | MAE: 17.85556759122237 | CSI: 0.8112992742279778 | Loss: 0.012915906496345997\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 21.194299734786483 | MAE: 17.151350926492587 | CSI: 0.809243576862351 | Loss: 0.012440902180969715\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 20.337474020254813 | MAE: 16.553419196482125 | CSI: 0.8139368330587983 | Loss: 0.012346466071903706\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 20.196481001123846 | MAE: 16.48095607120565 | CSI: 0.8160310734463276 | Loss: 0.012524261139333248\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 19.579376551711263 | MAE: 16.00924872048757 | CSI: 0.817658758346356 | Loss: 0.012212013825774193\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 19.434638965829198 | MAE: 15.903576101936443 | CSI: 0.8183108587650816 | Loss: 0.012107722461223602\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 19.315971947920712 | MAE: 15.801728971924296 | CSI: 0.8180654338549075 | Loss: 0.012064236216247082\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 19.64415372044433 | MAE: 16.077096355533417 | CSI: 0.8184163382300251 | Loss: 0.011871688067913055\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 18.46500519762341 | MAE: 15.164289529694248 | CSI: 0.8212448015789102 | Loss: 0.011801979504525661\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 18.73391489637479 | MAE: 15.357639994264336 | CSI: 0.819777397260274 | Loss: 0.011744298972189426\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 20.360283202755173 | MAE: 16.612902010809666 | CSI: 0.8159465094543101 | Loss: 0.011868265457451344\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 18.691859697627716 | MAE: 15.343228950134394 | CSI: 0.8208508515641701 | Loss: 0.011681349016726017\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 18.99759392298845 | MAE: 15.576239135507887 | CSI: 0.8199058890631684 | Loss: 0.011633138172328472\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 18.369188522504103 | MAE: 15.073867430000524 | CSI: 0.8206060606060606 | Loss: 0.011501042172312737\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 18.922362349209088 | MAE: 15.535369447206582 | CSI: 0.8210058110337901 | Loss: 0.011564943939447403\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 19.756675427061698 | MAE: 16.16920061371377 | CSI: 0.8184170800086262 | Loss: 0.01166921854019165\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 19.05487915631408 | MAE: 15.649650092243919 | CSI: 0.8212935891036155 | Loss: 0.011614903807640076\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 22.18995794837799 | MAE: 18.11920549447614 | CSI: 0.8165497896213184 | Loss: 0.01295209489762783\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 19.10433483456079 | MAE: 15.671032509831228 | CSI: 0.8202867383512544 | Loss: nan\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 18.393898565887127 | MAE: 15.130253461079954 | CSI: 0.8225691474194469 | Loss: 0.011438152752816677\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 18.77037865542701 | MAE: 15.43640507088513 | CSI: 0.8223811226313907 | Loss: nan\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 18.248775363370108 | MAE: 15.022372152141783 | CSI: 0.823198918995804 | Loss: 0.011384990066289902\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 18.63588369634031 | MAE: 15.310743075204217 | CSI: 0.8215732253250906 | Loss: nan\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 18.42697808474478 | MAE: 15.136916292743354 | CSI: 0.8214540779888587 | Loss: 0.01135366503149271\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 18.3887366130173 | MAE: 15.124588150346577 | CSI: 0.8224919671545876 | Loss: 0.01135934330523014\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 18.447671622670043 | MAE: 15.157825857444367 | CSI: 0.8216660707901323 | Loss: 0.011356715112924576\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 18.691355424587854 | MAE: 15.361878315052465 | CSI: 0.8218707507335575 | Loss: 0.011380273848772049\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 18.04757378042738 | MAE: 14.87064560375836 | CSI: 0.8239692373424482 | Loss: 0.011337735690176487\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 18.266029168429 | MAE: 15.038464900802056 | CSI: 0.8233023588277341 | Loss: 0.011326964013278484\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 18.479300920181704 | MAE: 15.201695430997237 | CSI: 0.8226336860165856 | Loss: 0.011349148117005825\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 18.199572864776194 | MAE: 14.979478572325606 | CSI: 0.82306758975091 | Loss: 0.011300484649837017\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 18.323702636929 | MAE: 15.067870550759695 | CSI: 0.8223158195316962 | Loss: 0.011328000575304031\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 18.407541280349236 | MAE: 15.14070985233944 | CSI: 0.8225275511664519 | Loss: 0.011321539990603924\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 18.377727199363537 | MAE: 15.119684204570163 | CSI: 0.8227178497390807 | Loss: 0.011315426789224148\n", + "\n", + "Model saved at ../models/rainnet_fold0_bs128_epochs50_lr0.001_adamw_cosine.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 18 M \n", + "2 | decoder | Decoder | 12 M \n", + "3 | out | Sequential | 1 K \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 1...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "386082da55bd4ec4849b68d96b9d6c71", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 379.5017715866807 | MAE: 48.814262472885034 | CSI: 0.12862723214285715 | Loss: 0.5731568932533264\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d736a01a792340a48352cccc39d03b8c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 830.5491628719715 | MAE: 117.40163109142166 | CSI: 0.14135422240945475 | Loss: 0.34613731503486633\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 91.81839230116013 | MAE: 73.22176802255336 | CSI: 0.797462972149387 | Loss: 0.11212663352489471\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 27.0362165882638 | MAE: 21.741607933950867 | CSI: 0.8041660660227764 | Loss: 0.04682685807347298\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 32.60717592544919 | MAE: 25.086674071184053 | CSI: 0.769360527526453 | Loss: 0.049210671335458755\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 32.87824652390067 | MAE: 25.566604315596052 | CSI: 0.7776145938007266 | Loss: 0.028726404532790184\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 190.37916632869667 | MAE: 106.61150056713778 | CSI: 0.5599956267424698 | Loss: 0.10083737224340439\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 33.49095542174912 | MAE: 25.808580073012106 | CSI: 0.7706134312375708 | Loss: 0.02540106698870659\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 27.259353606420433 | MAE: 21.691783022256423 | CSI: 0.7957555903717419 | Loss: 0.02527276985347271\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 59.00217704973967 | MAE: 38.52305011603317 | CSI: 0.652908961028653 | Loss: 0.028482506051659584\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 49.340782296089905 | MAE: 34.691828055656394 | CSI: 0.7031065670467951 | Loss: nan\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 21.547635635406653 | MAE: 17.372861825157162 | CSI: 0.8062537402752843 | Loss: nan\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 29.0514092954408 | MAE: 22.3223084614539 | CSI: 0.7683726539534182 | Loss: 0.023590411990880966\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 37.202140917942096 | MAE: 28.288918404919073 | CSI: 0.7604110329908058 | Loss: 0.02394930273294449\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 22.98714207860777 | MAE: 18.434432383002267 | CSI: 0.8019453797231575 | Loss: 0.020558180287480354\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 23.26957075976557 | MAE: 18.73982572889693 | CSI: 0.8053361156655239 | Loss: 0.020575497299432755\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 23.58595108606626 | MAE: 19.151916746031215 | CSI: 0.8120052770448549 | Loss: 0.021250629797577858\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 20.393563470743203 | MAE: 16.54136147040107 | CSI: 0.8111069698099143 | Loss: 0.019684718921780586\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 20.042915389054823 | MAE: 16.268009843645775 | CSI: 0.8116588593947506 | Loss: 0.019605526700615883\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 22.412473170121277 | MAE: 18.109453748379607 | CSI: 0.8080078271995184 | Loss: 0.01997843012213707\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 19.61756871954229 | MAE: 16.031871469031405 | CSI: 0.8172200999118425 | Loss: 0.019728684797883034\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 19.732200014472237 | MAE: 16.09550881111185 | CSI: 0.8156976312467513 | Loss: 0.019471045583486557\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 19.71513026073653 | MAE: 16.072754084463806 | CSI: 0.8152497027348394 | Loss: 0.01939563825726509\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 19.210671393148324 | MAE: 15.676308675605446 | CSI: 0.8160208643815201 | Loss: 0.01918698474764824\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 21.348405895212498 | MAE: 17.274073813627 | CSI: 0.8091505238561845 | Loss: 0.019472893327474594\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 19.706948234670737 | MAE: 16.079432886548453 | CSI: 0.8159270880023906 | Loss: 0.01924729160964489\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 34.29962118580521 | MAE: 27.4630366398738 | CSI: 0.800680464984406 | Loss: 0.02530735358595848\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 19.79136074088518 | MAE: 16.115654821315218 | CSI: 0.8142772511848341 | Loss: 0.019455544650554657\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 21.48409797192093 | MAE: 17.433473830543793 | CSI: 0.8114594270286486 | Loss: nan\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 23.09111606533413 | MAE: 18.612775743463757 | CSI: 0.8060578661844484 | Loss: 0.019915243610739708\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 19.490008098332357 | MAE: 15.914866592093816 | CSI: 0.8165654170988295 | Loss: 0.019278205931186676\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 19.71195478626922 | MAE: 16.108600758104327 | CSI: 0.817199558985667 | Loss: 0.01950940676033497\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 | MAE/CSI: 19.366664139296766 | MAE: 15.780887725744044 | CSI: 0.8148480095600866 | Loss: 0.018996911123394966\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 | MAE/CSI: 20.004943414738534 | MAE: 16.282534976241763 | CSI: 0.8139255702280912 | Loss: 0.019085006788372993\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 | MAE/CSI: 19.02802758919082 | MAE: 15.560641321557883 | CSI: 0.817774792925901 | Loss: 0.01893746294081211\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 | MAE/CSI: 18.453699099554825 | MAE: 15.120410611002473 | CSI: 0.8193701723113488 | Loss: 0.018903637304902077\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 | MAE/CSI: 18.984316961791034 | MAE: 15.549896907921621 | CSI: 0.8190917239318793 | Loss: 0.01895761489868164\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 36 | MAE/CSI: 19.094590752430758 | MAE: 15.609456116165118 | CSI: 0.8174805272618334 | Loss: 0.01885439082980156\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 37 | MAE/CSI: 18.80774036261592 | MAE: 15.40562411771875 | CSI: 0.8191108458899 | Loss: 0.01885846070945263\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 38 | MAE/CSI: 18.277189251160205 | MAE: 14.992392518085383 | CSI: 0.8202788903723484 | Loss: 0.01883639022707939\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 39 | MAE/CSI: 19.76735549674251 | MAE: 16.17491163422012 | CSI: 0.8182638105975197 | Loss: 0.018992852419614792\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 40 | MAE/CSI: 18.058117511205 | MAE: 14.836280011266554 | CSI: 0.8215850850479091 | Loss: 0.018756549805402756\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 41 | MAE/CSI: 18.178082741466547 | MAE: 14.930691238261979 | CSI: 0.821356765209621 | Loss: 0.018752722069621086\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 42 | MAE/CSI: 18.276723836123214 | MAE: 15.012342504469848 | CSI: 0.8213913302547297 | Loss: 0.018736498430371284\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 43 | MAE/CSI: 18.148915772713714 | MAE: 14.90182541084059 | CSI: 0.8210862619808307 | Loss: 0.0187591053545475\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 44 | MAE/CSI: 18.212364304487497 | MAE: 14.959138487540471 | CSI: 0.8213726805276101 | Loss: 0.01873149909079075\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 45 | MAE/CSI: 18.16989653471521 | MAE: 14.925030720065882 | CSI: 0.8214152838752884 | Loss: 0.018718460574746132\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 46 | MAE/CSI: 18.20102030473038 | MAE: 14.954469055040585 | CSI: 0.8216280628584196 | Loss: 0.018718944862484932\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 47 | MAE/CSI: 18.24522623158656 | MAE: 14.997007626100649 | CSI: 0.8219688501378642 | Loss: 0.018726889044046402\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 48 | MAE/CSI: 18.17261528656933 | MAE: 14.937805814892833 | CSI: 0.8219953803740406 | Loss: 0.018716935068368912\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49 | MAE/CSI: 18.298108030165828 | MAE: 15.021140255814736 | CSI: 0.8209122074824862 | Loss: 0.018727868795394897\n", + "\n", + "Model saved at ../models/rainnet_fold1_bs128_epochs50_lr0.001_adamw_cosine.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | encoder | Encoder | 18 M \n", + "2 | decoder | Decoder | 12 M \n", + "3 | out | Sequential | 1 K \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 2...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7e2f9c7c1381433c99139a0f32438818", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 270.7111357323698 | MAE: 43.05394736842105 | CSI: 0.15904017857142858 | Loss: 0.547532320022583\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bc28e50ca8dc49bfab87fc1d05a987bf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 824.3692924596616 | MAE: 120.68287881462305 | CSI: 0.14639419483192842 | Loss: 0.3203072249889374\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 135.31623162772803 | MAE: 89.66539859868989 | CSI: 0.6626359418967221 | Loss: 0.08254968374967575\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 45.54290021901838 | MAE: 35.04702791588217 | CSI: 0.7695387809580305 | Loss: 0.042056649923324585\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 211.1032118978899 | MAE: 79.25169823965933 | CSI: 0.3754168282279953 | Loss: 0.041645802557468414\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 36.437417082485744 | MAE: 28.170058166259707 | CSI: 0.7731079868381693 | Loss: 0.02194306254386902\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 70.58472629909586 | MAE: 37.75884085164647 | CSI: 0.5349435045136606 | Loss: 0.03173857554793358\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 27.881159411684454 | MAE: 22.233580963596825 | CSI: 0.7974410473852562 | Loss: 0.01828513666987419\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 19.897511344503304 | MAE: 16.197337333048953 | CSI: 0.8140383514596451 | Loss: 0.01550195924937725\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 27.829000981871314 | MAE: 21.86874469850749 | CSI: 0.7858257187430354 | Loss: 0.015892019495368004\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 21.11183108889901 | MAE: 17.0924955879631 | CSI: 0.8096169165037295 | Loss: 0.013916689902544022\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 24.223921772304422 | MAE: 19.547626823652134 | CSI: 0.806955496610669 | Loss: 0.014188835397362709\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 20.32213817647545 | MAE: 16.51127657430578 | CSI: 0.8124773304316286 | Loss: 0.013155419379472733\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 23.737533199970937 | MAE: 18.58459746541159 | CSI: 0.782920335859103 | Loss: 0.019955884665250778\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 20.86632972141281 | MAE: 16.925578594319987 | CSI: 0.8111430625449317 | Loss: 0.012959692627191544\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 19.70590377919541 | MAE: 16.090719803441804 | CSI: 0.8165431021950865 | Loss: 0.012357279658317566\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 18.776920136624494 | MAE: 15.375098501007782 | CSI: 0.8188296264305766 | Loss: 0.01216345839202404\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 18.964919271119133 | MAE: 15.55531227959926 | CSI: 0.8202150537634408 | Loss: 0.012243330478668213\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 19.23431540066987 | MAE: 15.746667512549674 | CSI: 0.8186757461605332 | Loss: 0.011927510611712933\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 18.291899073685013 | MAE: 15.020821213806396 | CSI: 0.8211734141589089 | Loss: 0.011770730838179588\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 19.015175475346958 | MAE: 15.562383641408847 | CSI: 0.8184191443074692 | Loss: 0.011834865435957909\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 20.120202830802185 | MAE: 16.424505128274667 | CSI: 0.8163190633004025 | Loss: 0.011988245882093906\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 18.569269316572885 | MAE: 15.228230593765055 | CSI: 0.8200769957143895 | Loss: 0.01158822514116764\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 18.248224809794063 | MAE: 15.025846310594067 | CSI: 0.8234141384816374 | Loss: 0.011551343835890293\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 18.163714680331232 | MAE: 14.963916520001776 | CSI: 0.8238356956899046 | Loss: 0.011585965752601624\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 18.350711779249217 | MAE: 15.083799712768469 | CSI: 0.8219735503560529 | Loss: 0.01150451134890318\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 19.28657518492255 | MAE: 15.8863000370731 | CSI: 0.823697306791569 | Loss: 0.011653022840619087\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 17.706335978747877 | MAE: 14.587489906120107 | CSI: 0.8238570601851852 | Loss: 0.011293087154626846\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 18.59037858319951 | MAE: 15.264862740770212 | CSI: 0.8211162926260566 | Loss: 0.011432108469307423\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 17.77012812064616 | MAE: 14.689610925741622 | CSI: 0.8266463148713475 | Loss: 0.011268140748143196\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 18.248919756227128 | MAE: 15.08150958753038 | CSI: 0.8264330047462578 | Loss: 0.011347295716404915\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 | MAE/CSI: 18.29934510727998 | MAE: 15.116966286163438 | CSI: 0.8260932944606414 | Loss: 0.011362237855792046\n" + ] } ], "source": [ - "class RainNet(pl.LightningModule):\n", - " def __init__(\n", - " self,\n", - " lr=3e-4,\n", - " enc_chs=[4, 64, 128, 256, 512, 1024],\n", - " dec_chs=[1024, 512, 256, 128, 64],\n", - " num_train_steps=None,\n", - " ):\n", - " super().__init__()\n", - "\n", - " # Parameters\n", - " self.lr = lr\n", - " self.num_train_steps = num_train_steps\n", - "\n", - " # self.criterion = LogCoshLoss()\n", - "# self.criterion = nn.L1Loss()\n", - " self.criterion = nn.SmoothL1Loss()\n", - "\n", - " # Layers\n", - " self.encoder = Encoder(enc_chs)\n", - " self.decoder = Decoder(dec_chs)\n", - " self.out = nn.Sequential(\n", - " nn.Conv2d(64, 2, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.BatchNorm2d(2),\n", - " nn.Conv2d(2, 1, kernel_size=1),\n", - " nn.ReLU(inplace=True),\n", - " )\n", - "\n", - " def forward(self, x):\n", - " ftrs = self.encoder(x)\n", - " ftrs = ftrs[::-1]\n", - " x = self.decoder(ftrs[0], ftrs[1:])\n", - " out = self.out(x)\n", - " return out\n", - "\n", - " def shared_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " y_hat = self(x)\n", - " loss = self.criterion(y_hat, y)\n", - " return loss, y, y_hat\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", - " self.log(\"train_loss\", loss)\n", - " return {\"loss\": loss}\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", - " return {\"loss\": loss, \"y\": y.detach(), \"y_hat\": y_hat.detach()}\n", - "\n", - " def validation_epoch_end(self, outputs):\n", - " avg_loss = torch.stack([x[\"loss\"] for x in outputs]).mean()\n", - " self.log(\"val_loss\", avg_loss)\n", - "\n", - " tfms = nn.Sequential(\n", - " T.CenterCrop(120),\n", - " )\n", - "\n", - " y = torch.cat([x[\"y\"] for x in outputs])\n", - " y = tfms(y)\n", - " y = y.detach().cpu().numpy()\n", - " y = y.reshape(-1, 120 * 120)\n", - "\n", - " y_hat = torch.cat([x[\"y_hat\"] for x in outputs])\n", - " y_hat = tfms(y_hat)\n", - " y_hat = y_hat.detach().cpu().numpy()\n", - " y_hat = y_hat.reshape(-1, 120 * 120)\n", - "\n", - " rng = args[\"rng\"]\n", - " y = rng * y[:, args[\"dams\"]]\n", - " y = y.clip(0, 255)\n", - " y_hat = rng * y_hat[:, args[\"dams\"]]\n", - " y_hat = y_hat.clip(0, 255)\n", - " # mae = metrics.mean_absolute_error(y, y_hat)\n", - "\n", - " y_true = radar2precipitation(y)\n", - " y_true = np.where(y_true >= 0.1, 1, 0)\n", - " y_pred = radar2precipitation(y_hat)\n", - " y_pred = np.where(y_pred >= 0.1, 1, 0)\n", - "\n", - " y *= y_true\n", - " y_hat *= y_true\n", - " mae = metrics.mean_absolute_error(y, y_hat)\n", - "\n", - " tn, fp, fn, tp = metrics.confusion_matrix(\n", - " y_true.reshape(-1), y_pred.reshape(-1)\n", - " ).ravel()\n", - " csi = tp / (tp + fn + fp)\n", - "\n", - " comp_metric = mae / (csi + 1e-12)\n", - "\n", - " print(\n", - " f\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\"\n", - " )\n", - "\n", - " def configure_optimizers(self):\n", - " # optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", - " optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\n", - " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", - " optimizer, T_max=self.num_train_steps\n", - " )\n", - " return [optimizer], [{\"scheduler\": scheduler, \"interval\": \"step\"}]" + "# AdamW bs128 lr 1e-3\n", + "for fold in range(5):\n", + " train_fold(df, fold, bilinear=True)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "## Train" - ] + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] }, { "cell_type": "code", @@ -1198,8 +4260,6 @@ " max_epochs=args[\"max_epochs\"],\n", " precision=args[\"precision\"],\n", " progress_bar_refresh_rate=50,\n", - "# accumulate_grad_batches=args[\"accumulate_grad_batches\"],\n", - " gradient_clip_val=args[\"gradient_clip_val\"],\n", " # auto_lr_find=True,\n", "# benchmark=True,\n", " )\n", @@ -1227,397 +4287,64 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 14;\n", - " var nbb_unformatted_code = \"preds = []\\nmodel.eval()\\nwith torch.no_grad():\\n for batch in datamodule.test_dataloader():\\n batch = batch.to(\\\"cuda\\\")\\n imgs = model(batch)\\n imgs = imgs.detach().cpu().numpy()\\n imgs = imgs[:, 0, 4:124, 4:124]\\n imgs = 255.0 * imgs\\n imgs = np.round(imgs)\\n imgs = np.clip(imgs, 0, 255)\\n preds.append(imgs)\\n\\npreds = np.concatenate(preds)\\npreds = preds.astype(np.uint8)\\npreds = preds.reshape(len(preds), -1)\";\n", - " var nbb_formatted_code = \"preds = []\\nmodel.eval()\\nwith torch.no_grad():\\n for batch in datamodule.test_dataloader():\\n batch = batch.to(\\\"cuda\\\")\\n imgs = model(batch)\\n imgs = imgs.detach().cpu().numpy()\\n imgs = imgs[:, 0, 4:124, 4:124]\\n imgs = 255.0 * imgs\\n imgs = np.round(imgs)\\n imgs = np.clip(imgs, 0, 255)\\n preds.append(imgs)\\n\\npreds = np.concatenate(preds)\\npreds = preds.astype(np.uint8)\\npreds = preds.reshape(len(preds), -1)\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "datamodule = NowcastingDataModule()\n", - "datamodule.setup(\"test\")\n", - "\n", - "final_preds = np.zeros((len(datamodule.test_dataset), 120, 120))\n", - "\n", - "for fold in range(5):\n", - " model = RainNet.load_from_checkpoint(f\"rainnet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}.ckpt\")\n", - " model.to(\"cuda\")\n", - "\n", - " preds = []\n", - " model.eval()\n", - " with torch.no_grad():\n", - " for batch in tqdm(datamodule.test_dataloader()):\n", - " batch = batch.to(\"cuda\")\n", - " imgs = model(batch)\n", - " imgs = imgs.detach().cpu().numpy()\n", - " imgs = imgs[:, 0, 4:124, 4:124]\n", - " imgs = args[\"rng\"] * imgs\n", - " imgs = imgs.clip(0, 255)\n", - " imgs = imgs.round()\n", - " preds.append(imgs)\n", - "\n", - " preds = np.concatenate(preds)\n", - " preds = preds.astype(np.uint8)\n", - " final_preds += preds\n", + "def inference(checkpoints):\n", + " datamodule = NowcastingDataModule()\n", + " datamodule.setup(\"test\")\n", " \n", - " del model\n", - " gc.collect()\n", - " torch.cuda.empty_cache()\n", - " break\n", + " test_paths = datamodule.test_dataset.paths\n", + " test_filenames = [path.name for path in test_paths]\n", + " final_preds = np.zeros((len(datamodule.test_dataset), 14400))\n", " \n", - "final_preds = final_preds.reshape(-1, 14400)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 15;\n", - " var nbb_unformatted_code = \"test_paths = datamodule.test_dataset.paths\\ntest_filenames = [path.name for path in test_paths]\";\n", - " var nbb_formatted_code = \"test_paths = datamodule.test_dataset.paths\\ntest_filenames = [path.name for path in test_paths]\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "test_paths = datamodule.test_dataset.paths\n", - "test_filenames = [path.name for path in test_paths]" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "67aa02094dbe4d8f8ddc835f0c523658", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14400.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 16;\n", - " var nbb_unformatted_code = \"subm = pd.DataFrame()\\nsubm[\\\"file_name\\\"] = test_filenames\\nfor i in tqdm(range(14400)):\\n subm[str(i)] = preds[:, i]\";\n", - " var nbb_formatted_code = \"subm = pd.DataFrame()\\nsubm[\\\"file_name\\\"] = test_filenames\\nfor i in tqdm(range(14400)):\\n subm[str(i)] = preds[:, i]\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "subm = pd.DataFrame({\"file_name\": test_filenames})\n", - "for i in tqdm(range(14400)):\n", - " subm[str(i)] = final_preds[:, i]" + " for checkpoint in checkpoints:\n", + " print(f\"Inference from {checkpoint}\")\n", + " model = RainNet.load_from_checkpoint(str(checkpoint))\n", + " model.cuda()\n", + " model.eval()\n", + " preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(datamodule.test_dataloader()):\n", + " batch = batch.cuda()\n", + " imgs = model(batch)\n", + " imgs = imgs.detach().cpu().numpy()\n", + " imgs = imgs[:, 0, 4:124, 4:124]\n", + " imgs = args[\"rng\"] * imgs\n", + " imgs = imgs.clip(0, 255)\n", + " imgs = imgs.round()\n", + " preds.append(imgs)\n", + " \n", + " preds = np.concatenate(preds)\n", + " preds = preds.astype(np.uint8)\n", + " preds = preds.reshape(-1, 14400)\n", + " final_preds += preds / len(checkpoint)\n", + " \n", + " del model\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + " \n", + " final_preds = final_preds.round()\n", + " final_preds = final_preds.astype(np.uint8)\n", + " \n", + " subm = pd.DataFrame()\n", + " subm[\"file_name\"] = test_filename\n", + " for i in tqdm(range(14400)):\n", + " subm[str(i)] = final_preds[:, i]\n", + " \n", + " return subm" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
file_name012345678...14390143911439214393143941439514396143971439814399
0test_00402.npy000000008...0000000000
1test_00365.npy000000000...0000000000
2test_00122.npy000000000...0000000000
3test_01822.npy000000000...0000000000
4test_01769.npy000000000...0000000000
\n", - "

5 rows Ă— 14401 columns

\n", - "
" - ], - "text/plain": [ - " file_name 0 1 2 3 4 5 6 7 8 ... 14390 14391 14392 14393 \\\n", - "0 test_00402.npy 0 0 0 0 0 0 0 0 8 ... 0 0 0 0 \n", - "1 test_00365.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "2 test_00122.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "3 test_01822.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "4 test_01769.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "\n", - " 14394 14395 14396 14397 14398 14399 \n", - "0 0 0 0 0 0 0 \n", - "1 0 0 0 0 0 0 \n", - "2 0 0 0 0 0 0 \n", - "3 0 0 0 0 0 0 \n", - "4 0 0 0 0 0 0 \n", - "\n", - "[5 rows x 14401 columns]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 17;\n", - " var nbb_unformatted_code = \"subm.to_csv(\\\"rainnet_fold0_epoch50.csv\\\", index=False)\\nsubm.head()\";\n", - " var nbb_formatted_code = \"subm.to_csv(\\\"rainnet_fold0_epoch50.csv\\\", index=False)\\nsubm.head()\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "subm.to_csv(f\"rainnet_epoch{args['max_epochs']}_lr{args['lr']}.csv\", index=False)\n", + "checkpoints = [args[\"model_dir\"] / f\"rainnet_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt\" for fold in range(5)]\n", + "output_path = args[\"output_dir\"] / f\"rainnet_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.csv\"\n", + "subm.to_csv(output_path, index=False)\n", "subm.head()" ] }, @@ -1680,7 +4407,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:torch]", + "display_name": "Python [conda env:torch] *", "language": "python", "name": "conda-env-torch-py" }, diff --git a/notebooks/03-unet.ipynb b/notebooks/03-unet.ipynb index 941b625..b6ee2c4 100644 --- a/notebooks/03-unet.ipynb +++ b/notebooks/03-unet.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -20,7 +20,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 4;\n", + " var nbb_cell_id = 2;\n", " var nbb_unformatted_code = \"%reload_ext autoreload\\n%autoreload 2\\n%matplotlib inline\\n%reload_ext nb_black\";\n", " var nbb_formatted_code = \"%reload_ext autoreload\\n%autoreload 2\\n%matplotlib inline\\n%reload_ext nb_black\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -60,9 +60,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 5;\n", - " var nbb_unformatted_code = \"import gc\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nfrom torch.utils.data import RandomSampler, SequentialSampler\\nimport pytorch_lightning as pl\\n\\nimport transformers\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", - " var nbb_formatted_code = \"import gc\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nfrom torch.utils.data import RandomSampler, SequentialSampler\\nimport pytorch_lightning as pl\\n\\nimport transformers\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_cell_id = 3;\n", + " var nbb_unformatted_code = \"import gc\\nimport warnings\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\\n\\nimport pytorch_lightning as pl\\n\\nimport torchvision.transforms as T\\nimport albumentations as A\\nfrom albumentations.pytorch import ToTensorV2\\n\\nfrom transformers import AdamW, get_cosine_schedule_with_warmup\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_formatted_code = \"import gc\\nimport warnings\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\\n\\nimport pytorch_lightning as pl\\n\\nimport torchvision.transforms as T\\nimport albumentations as A\\nfrom albumentations.pytorch import ToTensorV2\\n\\nfrom transformers import AdamW, get_cosine_schedule_with_warmup\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -85,6 +85,7 @@ ], "source": [ "import gc\n", + "import warnings\n", "from pathlib import Path\n", "from tqdm.notebook import tqdm\n", "\n", @@ -97,17 +98,58 @@ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", - "import torchvision.transforms as T\n", - "from torch.utils.data import RandomSampler, SequentialSampler\n", + "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n", + "\n", "import pytorch_lightning as pl\n", "\n", - "import transformers\n", + "import torchvision.transforms as T\n", + "import albumentations as A\n", + "from albumentations.pytorch import ToTensorV2\n", + "\n", + "from transformers import AdamW, get_cosine_schedule_with_warmup\n", "\n", "import optim\n", "import loss\n", "from utils import visualize, radar2precipitation, seed_everything" ] }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 4;\n", + " var nbb_unformatted_code = \"warnings.simplefilter(\\\"ignore\\\")\";\n", + " var nbb_formatted_code = \"warnings.simplefilter(\\\"ignore\\\")\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "warnings.simplefilter(\"ignore\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -124,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -132,9 +174,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 24;\n", - " var nbb_unformatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=256,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n)\";\n", - " var nbb_formatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=256,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n)\";\n", + " var nbb_cell_id = 5;\n", + " var nbb_unformatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train\\\"),\\n test_data_path=Path(\\\"../input/test\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n output_dir=Path(\\\"../output\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-3,\\n max_epochs=30,\\n batch_size=256,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n warmup_epochs=1,\\n)\\n\\nargs[\\\"trn_tfms\\\"] = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\\n\\nargs[\\\"val_tfms\\\"] = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\";\n", + " var nbb_formatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train\\\"),\\n test_data_path=Path(\\\"../input/test\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n output_dir=Path(\\\"../output\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-3,\\n max_epochs=30,\\n batch_size=256,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n warmup_epochs=1,\\n)\\n\\nargs[\\\"trn_tfms\\\"] = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\\n\\nargs[\\\"val_tfms\\\"] = A.Compose(\\n [\\n A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -160,20 +202,36 @@ " seed=42,\n", " dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\n", " train_folds_csv=Path(\"../input/train_folds.csv\"),\n", - " train_data_path=Path(\"../input/train-128\"),\n", - " test_data_path=Path(\"../input/test-128\"),\n", + " train_data_path=Path(\"../input/train\"),\n", + " test_data_path=Path(\"../input/test\"),\n", " model_dir=Path(\"../models\"),\n", + " output_dir=Path(\"../output\"),\n", " rng=255.0,\n", " num_workers=4,\n", " gpus=1,\n", - " lr=1e-4,\n", - " max_epochs=50,\n", + " lr=1e-3,\n", + " max_epochs=30,\n", " batch_size=256,\n", " precision=16,\n", " optimizer=\"adamw\",\n", " scheduler=\"cosine\",\n", " accumulate_grad_batches=1,\n", " gradient_clip_val=5.0,\n", + " warmup_epochs=1,\n", + ")\n", + "\n", + "args[\"trn_tfms\"] = A.Compose(\n", + " [\n", + " A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\n", + " ToTensorV2(always_apply=True, p=1),\n", + " ]\n", + ")\n", + "\n", + "args[\"val_tfms\"] = A.Compose(\n", + " [\n", + " A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),\n", + " ToTensorV2(always_apply=True, p=1),\n", + " ]\n", ")" ] }, @@ -181,21 +239,79 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Layers" + "## Dataset" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 6, "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 6;\n", + " var nbb_unformatted_code = \"class NowcastingDataset(Dataset):\\n def __init__(self, paths, tfms=None, test=False):\\n self.paths = paths\\n if tfms is not None:\\n self.tfms = tfms\\n else:\\n self.tfms = A.Compose(\\n [\\n A.PadIfNeeded(\\n min_height=128, min_width=128, always_apply=True, p=1\\n ),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n )\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n augmented = self.tfms(image=data)\\n data = augmented[\\\"image\\\"]\\n\\n x = data[:4, :, :]\\n x = x / args[\\\"rng\\\"]\\n if self.test:\\n return x\\n else:\\n y = data[4, :, :]\\n y = y / args[\\\"rng\\\"]\\n y = y.unsqueeze(0)\\n\\n return x, y\";\n", + " var nbb_formatted_code = \"class NowcastingDataset(Dataset):\\n def __init__(self, paths, tfms=None, test=False):\\n self.paths = paths\\n if tfms is not None:\\n self.tfms = tfms\\n else:\\n self.tfms = A.Compose(\\n [\\n A.PadIfNeeded(\\n min_height=128, min_width=128, always_apply=True, p=1\\n ),\\n ToTensorV2(always_apply=True, p=1),\\n ]\\n )\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n augmented = self.tfms(image=data)\\n data = augmented[\\\"image\\\"]\\n\\n x = data[:4, :, :]\\n x = x / args[\\\"rng\\\"]\\n if self.test:\\n return x\\n else:\\n y = data[4, :, :]\\n y = y / args[\\\"rng\\\"]\\n y = y.unsqueeze(0)\\n\\n return x, y\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#### Basic" + "class NowcastingDataset(Dataset):\n", + " def __init__(self, paths, tfms=None, test=False):\n", + " self.paths = paths\n", + " if tfms is not None:\n", + " self.tfms = tfms\n", + " else:\n", + " self.tfms = A.Compose(\n", + " [\n", + " A.PadIfNeeded(\n", + " min_height=128, min_width=128, always_apply=True, p=1\n", + " ),\n", + " ToTensorV2(always_apply=True, p=1),\n", + " ]\n", + " )\n", + " self.test = test\n", + "\n", + " def __len__(self):\n", + " return len(self.paths)\n", + "\n", + " def __getitem__(self, idx):\n", + " path = self.paths[idx]\n", + " data = np.load(path)\n", + "\n", + " augmented = self.tfms(image=data)\n", + " data = augmented[\"image\"]\n", + "\n", + " x = data[:4, :, :]\n", + " x = x / args[\"rng\"]\n", + " if self.test:\n", + " return x\n", + " else:\n", + " y = data[4, :, :]\n", + " y = y / args[\"rng\"]\n", + " y = y.unsqueeze(0)\n", + "\n", + " return x, y" ] }, { @@ -209,8 +325,8 @@ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 7;\n", - " var nbb_unformatted_code = \"class BasicBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n assert in_ch == out_ch\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n return x + self.net(x)\";\n", - " var nbb_formatted_code = \"class BasicBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n assert in_ch == out_ch\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n return x + self.net(x)\";\n", + " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n test=False,\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n self.test = test\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths, tfms=args[\\\"trn_tfms\\\"])\\n self.val_dataset = NowcastingDataset(val_paths, tfms=args[\\\"val_tfms\\\"])\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n test=False,\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n self.test = test\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths, tfms=args[\\\"trn_tfms\\\"])\\n self.val_dataset = NowcastingDataset(val_paths, tfms=args[\\\"val_tfms\\\"])\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -232,26 +348,77 @@ } ], "source": [ - "class BasicBlock(nn.Module):\n", - " def __init__(self, in_ch, out_ch):\n", - " assert in_ch == out_ch\n", + "class NowcastingDataModule(pl.LightningDataModule):\n", + " def __init__(\n", + " self,\n", + " train_df=None,\n", + " val_df=None,\n", + " batch_size=args[\"batch_size\"],\n", + " num_workers=args[\"num_workers\"],\n", + " test=False,\n", + " ):\n", " super().__init__()\n", - " self.net = nn.Sequential(\n", - " nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\n", - " nn.BatchNorm2d(out_ch),\n", - " nn.LeakyReLU(inplace=True),\n", - " nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n", + " self.train_df = train_df\n", + " self.val_df = val_df\n", + " self.batch_size = batch_size\n", + " self.num_workers = num_workers\n", + " self.test = test\n", + "\n", + " def setup(self, stage=\"train\"):\n", + " if stage == \"train\":\n", + " train_paths = [\n", + " args[\"train_data_path\"] / fn for fn in self.train_df.filename.values\n", + " ]\n", + " val_paths = [\n", + " args[\"train_data_path\"] / fn for fn in self.val_df.filename.values\n", + " ]\n", + " self.train_dataset = NowcastingDataset(train_paths, tfms=args[\"trn_tfms\"])\n", + " self.val_dataset = NowcastingDataset(val_paths, tfms=args[\"val_tfms\"])\n", + " else:\n", + " test_paths = list(sorted(args[\"test_data_path\"].glob(\"*.npy\")))\n", + " self.test_dataset = NowcastingDataset(test_paths, test=True)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(\n", + " self.train_dataset,\n", + " batch_size=self.batch_size,\n", + " sampler=RandomSampler(self.train_dataset),\n", + " pin_memory=True,\n", + " num_workers=self.num_workers,\n", + " drop_last=True,\n", " )\n", "\n", - " def forward(self, x):\n", - " return x + self.net(x)" + " def val_dataloader(self):\n", + " return DataLoader(\n", + " self.val_dataset,\n", + " batch_size=2 * self.batch_size,\n", + " sampler=SequentialSampler(self.val_dataset),\n", + " pin_memory=True,\n", + " num_workers=self.num_workers,\n", + " )\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(\n", + " self.test_dataset,\n", + " batch_size=2 * self.batch_size,\n", + " sampler=SequentialSampler(self.test_dataset),\n", + " pin_memory=True,\n", + " num_workers=self.num_workers,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Encoder" + "### Basic" ] }, { @@ -265,8 +432,8 @@ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 8;\n", - " var nbb_unformatted_code = \"class DownBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)\\n self.net = nn.Sequential(\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.MaxPool2d(2),\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.net(x)\\n return residual + x, x\";\n", - " var nbb_formatted_code = \"class DownBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)\\n self.net = nn.Sequential(\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.MaxPool2d(2),\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.net(x)\\n return residual + x, x\";\n", + " var nbb_unformatted_code = \"class BasicBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n assert in_ch == out_ch\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n return x + self.net(x)\";\n", + " var nbb_formatted_code = \"class BasicBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n assert in_ch == out_ch\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n return x + self.net(x)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -288,24 +455,26 @@ } ], "source": [ - "class DownBlock(nn.Module):\n", + "class BasicBlock(nn.Module):\n", " def __init__(self, in_ch, out_ch):\n", + " assert in_ch == out_ch\n", " super().__init__()\n", - " self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)\n", " self.net = nn.Sequential(\n", - " nn.BatchNorm2d(in_ch),\n", - " nn.LeakyReLU(inplace=True),\n", - " nn.MaxPool2d(2),\n", - " nn.BatchNorm2d(in_ch),\n", + " nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\n", + " nn.BatchNorm2d(out_ch),\n", " nn.LeakyReLU(inplace=True),\n", - " nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n", + " nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n", " )\n", "\n", " def forward(self, x):\n", - " residual = x\n", - " residual = self.id_conv(residual)\n", - " x = self.net(x)\n", - " return residual + x, x" + " return x + self.net(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Encoder" ] }, { @@ -319,8 +488,8 @@ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 9;\n", - " var nbb_unformatted_code = \"class Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [DownBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.basic = BasicBlock(chs[-1], chs[-1])\\n\\n def forward(self, x):\\n feats = []\\n for block in self.blocks:\\n x, feat = block(x)\\n feats.append(feat)\\n x = self.basic(x)\\n feats.append(x)\\n return feats\";\n", - " var nbb_formatted_code = \"class Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [DownBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.basic = BasicBlock(chs[-1], chs[-1])\\n\\n def forward(self, x):\\n feats = []\\n for block in self.blocks:\\n x, feat = block(x)\\n feats.append(feat)\\n x = self.basic(x)\\n feats.append(x)\\n return feats\";\n", + " var nbb_unformatted_code = \"class DownBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)\\n self.net = nn.Sequential(\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.MaxPool2d(2),\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.net(x)\\n return residual + x, x\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [DownBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.basic = BasicBlock(chs[-1], chs[-1])\\n\\n def forward(self, x):\\n feats = []\\n for block in self.blocks:\\n x, feat = block(x)\\n feats.append(feat)\\n x = self.basic(x)\\n feats.append(x)\\n return feats\";\n", + " var nbb_formatted_code = \"class DownBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)\\n self.net = nn.Sequential(\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.MaxPool2d(2),\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.net(x)\\n return residual + x, x\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [DownBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.basic = BasicBlock(chs[-1], chs[-1])\\n\\n def forward(self, x):\\n feats = []\\n for block in self.blocks:\\n x, feat = block(x)\\n feats.append(feat)\\n x = self.basic(x)\\n feats.append(x)\\n return feats\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -342,6 +511,26 @@ } ], "source": [ + "class DownBlock(nn.Module):\n", + " def __init__(self, in_ch, out_ch):\n", + " super().__init__()\n", + " self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)\n", + " self.net = nn.Sequential(\n", + " nn.BatchNorm2d(in_ch),\n", + " nn.LeakyReLU(inplace=True),\n", + " nn.MaxPool2d(2),\n", + " nn.BatchNorm2d(in_ch),\n", + " nn.LeakyReLU(inplace=True),\n", + " nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " residual = x\n", + " residual = self.id_conv(residual)\n", + " x = self.net(x)\n", + " return residual + x, x\n", + "\n", + "\n", "class Encoder(nn.Module):\n", " def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):\n", " super().__init__()\n", @@ -364,7 +553,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Decoder" + "### Decoder" ] }, { @@ -378,8 +567,8 @@ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 10;\n", - " var nbb_unformatted_code = \"class UpBlock(nn.Module):\\n def __init__(self, in_ch, out_ch, bilinear=False):\\n super().__init__()\\n self.id_conv = nn.ConvTranspose2d(\\n in_ch + in_ch, out_ch, kernel_size=2, stride=2\\n )\\n layers = []\\n if bilinear:\\n layers.append(nn.Upsample(scale_factor=2, mode=\\\"nearest\\\"))\\n else:\\n layers.append(\\n nn.ConvTranspose2d(in_ch + in_ch, out_ch, kernel_size=2, stride=2)\\n )\\n layers.extend(\\n [\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n ]\\n )\\n self.block = nn.Sequential(*layers)\\n\\n def forward(self, x, feat):\\n x = torch.cat([x, feat], dim=1)\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.block(x)\\n return x + residual\";\n", - " var nbb_formatted_code = \"class UpBlock(nn.Module):\\n def __init__(self, in_ch, out_ch, bilinear=False):\\n super().__init__()\\n self.id_conv = nn.ConvTranspose2d(\\n in_ch + in_ch, out_ch, kernel_size=2, stride=2\\n )\\n layers = []\\n if bilinear:\\n layers.append(nn.Upsample(scale_factor=2, mode=\\\"nearest\\\"))\\n else:\\n layers.append(\\n nn.ConvTranspose2d(in_ch + in_ch, out_ch, kernel_size=2, stride=2)\\n )\\n layers.extend(\\n [\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n ]\\n )\\n self.block = nn.Sequential(*layers)\\n\\n def forward(self, x, feat):\\n x = torch.cat([x, feat], dim=1)\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.block(x)\\n return x + residual\";\n", + " var nbb_unformatted_code = \"class UpBlock(nn.Module):\\n def __init__(self, in_ch, out_ch, bilinear=False):\\n super().__init__()\\n self.id_conv = nn.ConvTranspose2d(\\n in_ch + in_ch, out_ch, kernel_size=2, stride=2\\n )\\n layers = []\\n if bilinear:\\n layers.append(nn.Upsample(scale_factor=2, mode=\\\"nearest\\\"))\\n else:\\n layers.append(\\n nn.ConvTranspose2d(in_ch + in_ch, out_ch, kernel_size=2, stride=2)\\n )\\n layers.extend(\\n [\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n ]\\n )\\n self.block = nn.Sequential(*layers)\\n\\n def forward(self, x, feat):\\n x = torch.cat([x, feat], dim=1)\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.block(x)\\n return x + residual\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, feats):\\n for block, feat in zip(self.blocks, feats):\\n x = block(x, feat)\\n return x\";\n", + " var nbb_formatted_code = \"class UpBlock(nn.Module):\\n def __init__(self, in_ch, out_ch, bilinear=False):\\n super().__init__()\\n self.id_conv = nn.ConvTranspose2d(\\n in_ch + in_ch, out_ch, kernel_size=2, stride=2\\n )\\n layers = []\\n if bilinear:\\n layers.append(nn.Upsample(scale_factor=2, mode=\\\"nearest\\\"))\\n else:\\n layers.append(\\n nn.ConvTranspose2d(in_ch + in_ch, out_ch, kernel_size=2, stride=2)\\n )\\n layers.extend(\\n [\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n ]\\n )\\n self.block = nn.Sequential(*layers)\\n\\n def forward(self, x, feat):\\n x = torch.cat([x, feat], dim=1)\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.block(x)\\n return x + residual\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, feats):\\n for block, feat in zip(self.blocks, feats):\\n x = block(x, feat)\\n return x\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -431,12 +620,32 @@ " residual = x\n", " residual = self.id_conv(residual)\n", " x = self.block(x)\n", - " return x + residual" + " return x + residual\n", + "\n", + "\n", + "class Decoder(nn.Module):\n", + " def __init__(self, chs=[1024, 512, 256, 128, 64]):\n", + " super().__init__()\n", + " self.blocks = nn.ModuleList(\n", + " [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\n", + " )\n", + "\n", + " def forward(self, x, feats):\n", + " for block, feat in zip(self.blocks, feats):\n", + " x = block(x, feat)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### U-Net" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -444,9 +653,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 11;\n", - " var nbb_unformatted_code = \"class Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, feats):\\n for block, feat in zip(self.blocks, feats):\\n x = block(x, feat)\\n return x\";\n", - " var nbb_formatted_code = \"class Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, feats):\\n for block, feat in zip(self.blocks, feats):\\n x = block(x, feat)\\n return x\";\n", + " var nbb_cell_id = 16;\n", + " var nbb_unformatted_code = \"class UNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n self.criterion = nn.L1Loss()\\n\\n self.tail = BasicBlock(4, enc_chs[0])\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.head = nn.Sequential(\\n nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),\\n nn.BatchNorm2d(32),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(32, 1, kernel_size=3, padding=1),\\n # nn.ReLU(inplace=True),\\n nn.Sigmoid(),\\n )\\n\\n def forward(self, x):\\n x = self.tail(x)\\n feats = self.encoder(x)\\n feats = feats[::-1]\\n x = self.decoder(feats[0], feats[1:])\\n x = self.head(x)\\n return x\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n\\n crop = T.CenterCrop(120)\\n y = crop(y)\\n y_hat = crop(y_hat)\\n\\n batch_size = len(y)\\n y = y.detach().cpu().numpy()\\n y *= args[\\\"rng\\\"]\\n y = y.reshape(batch_size, -1)\\n y = y[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat *= args[\\\"rng\\\"]\\n y_hat = y_hat.reshape(batch_size, -1)\\n y_hat = y_hat[:, args[\\\"dams\\\"]]\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_true = y_true.ravel()\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n y_pred = y_pred.ravel()\\n\\n y = y.ravel()\\n y_hat = y_hat.ravel()\\n mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # optimizer\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(self.optimizer)\\n\\n # scheduler\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = get_cosine_schedule_with_warmup(\\n self.optimizer,\\n num_warmup_steps=self.num_train_steps * args[\\\"warmup_epochs\\\"],\\n num_training_steps=self.num_train_steps * args[\\\"max_epochs\\\"],\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR(\\n self.optimizer, step_size=10, gamma=0.5\\n )\\n return [self.optimizer], [\\n {\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"epoch\\\"}\\n ]\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\\n self.optimizer, mode=\\\"min\\\", factor=0.1, patience=3, verbose=True\\n )\\n return [self.optimizer], [\\n {\\n \\\"scheduler\\\": self.scheduler,\\n \\\"interval\\\": \\\"epoch\\\",\\n \\\"reduce_on_plateau\\\": True,\\n \\\"monitor\\\": \\\"comp_metric\\\",\\n }\\n ]\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", + " var nbb_formatted_code = \"class UNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n self.criterion = nn.L1Loss()\\n\\n self.tail = BasicBlock(4, enc_chs[0])\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.head = nn.Sequential(\\n nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),\\n nn.BatchNorm2d(32),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(32, 1, kernel_size=3, padding=1),\\n # nn.ReLU(inplace=True),\\n nn.Sigmoid(),\\n )\\n\\n def forward(self, x):\\n x = self.tail(x)\\n feats = self.encoder(x)\\n feats = feats[::-1]\\n x = self.decoder(feats[0], feats[1:])\\n x = self.head(x)\\n return x\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n\\n crop = T.CenterCrop(120)\\n y = crop(y)\\n y_hat = crop(y_hat)\\n\\n batch_size = len(y)\\n y = y.detach().cpu().numpy()\\n y *= args[\\\"rng\\\"]\\n y = y.reshape(batch_size, -1)\\n y = y[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat *= args[\\\"rng\\\"]\\n y_hat = y_hat.reshape(batch_size, -1)\\n y_hat = y_hat[:, args[\\\"dams\\\"]]\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_true = y_true.ravel()\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n y_pred = y_pred.ravel()\\n\\n y = y.ravel()\\n y_hat = y_hat.ravel()\\n mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # optimizer\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(self.optimizer)\\n\\n # scheduler\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = get_cosine_schedule_with_warmup(\\n self.optimizer,\\n num_warmup_steps=self.num_train_steps * args[\\\"warmup_epochs\\\"],\\n num_training_steps=self.num_train_steps * args[\\\"max_epochs\\\"],\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR(\\n self.optimizer, step_size=10, gamma=0.5\\n )\\n return [self.optimizer], [\\n {\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"epoch\\\"}\\n ]\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\\n self.optimizer, mode=\\\"min\\\", factor=0.1, patience=3, verbose=True\\n )\\n return [self.optimizer], [\\n {\\n \\\"scheduler\\\": self.scheduler,\\n \\\"interval\\\": \\\"epoch\\\",\\n \\\"reduce_on_plateau\\\": True,\\n \\\"monitor\\\": \\\"comp_metric\\\",\\n }\\n ]\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -468,147 +677,18 @@ } ], "source": [ - "class Decoder(nn.Module):\n", - " def __init__(self, chs=[1024, 512, 256, 128, 64]):\n", + "class UNet(pl.LightningModule):\n", + " def __init__(\n", + " self,\n", + " lr=args[\"lr\"],\n", + " enc_chs=[4, 64, 128, 256, 512, 1024],\n", + " dec_chs=[1024, 512, 256, 128, 64],\n", + " num_train_steps=None,\n", + " ):\n", " super().__init__()\n", - " self.blocks = nn.ModuleList(\n", - " [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\n", - " )\n", - "\n", - " def forward(self, x, feats):\n", - " for block, feat in zip(self.blocks, feats):\n", - " x = block(x, feat)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 12;\n", - " var nbb_unformatted_code = \"# x = torch.randn(3, 4, 128, 128)\\n# encoder = Encoder()\\n# feats = encoder(x)\\n# for feat in feats:\\n# print(feat.shape)\";\n", - " var nbb_formatted_code = \"# x = torch.randn(3, 4, 128, 128)\\n# encoder = Encoder()\\n# feats = encoder(x)\\n# for feat in feats:\\n# print(feat.shape)\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# x = torch.randn(3, 4, 128, 128)\n", - "# encoder = Encoder()\n", - "# feats = encoder(x)\n", - "# for feat in feats:\n", - "# print(feat.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 13;\n", - " var nbb_unformatted_code = \"# decoder = Decoder()\\n# x = torch.randn(3, 1024, 4, 4)\\n# feats = list(reversed(feats))[1:]\\n# decoder(x, feats).shape\";\n", - " var nbb_formatted_code = \"# decoder = Decoder()\\n# x = torch.randn(3, 1024, 4, 4)\\n# feats = list(reversed(feats))[1:]\\n# decoder(x, feats).shape\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# decoder = Decoder()\n", - "# x = torch.randn(3, 1024, 4, 4)\n", - "# feats = list(reversed(feats))[1:]\n", - "# decoder(x, feats).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 14;\n", - " var nbb_unformatted_code = \"class UNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n # self.criterion = nn.SmoothL1Loss()\\n self.criterion = nn.L1Loss()\\n\\n self.tail = BasicBlock(4, enc_chs[0])\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.head = nn.Sequential(\\n nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),\\n nn.BatchNorm2d(32),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(32, 1, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n x = self.tail(x)\\n feats = self.encoder(x)\\n feats = feats[::-1]\\n x = self.decoder(feats[0], feats[1:])\\n x = self.head(x)\\n\\n return x\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n crop = T.CenterCrop(120)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = crop(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = crop(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n y = args[\\\"rng\\\"] * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = args[\\\"rng\\\"] * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.ravel(), y_pred.ravel()\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n \\n # Optimizer\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adagrad\\\":\\n self.optimizer = torch.optim.Adagrad(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(optimizer)\\n \\n # Scheduler\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n self.optimizer, T_max=self.num_train_steps\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR()\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLR()\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", - " var nbb_formatted_code = \"class UNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n # self.criterion = nn.SmoothL1Loss()\\n self.criterion = nn.L1Loss()\\n\\n self.tail = BasicBlock(4, enc_chs[0])\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.head = nn.Sequential(\\n nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),\\n nn.BatchNorm2d(32),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(32, 1, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n x = self.tail(x)\\n feats = self.encoder(x)\\n feats = feats[::-1]\\n x = self.decoder(feats[0], feats[1:])\\n x = self.head(x)\\n\\n return x\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n crop = T.CenterCrop(120)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = crop(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = crop(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n y = args[\\\"rng\\\"] * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = args[\\\"rng\\\"] * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.ravel(), y_pred.ravel()\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n\\n # Optimizer\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adagrad\\\":\\n self.optimizer = torch.optim.Adagrad(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(optimizer)\\n\\n # Scheduler\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n self.optimizer, T_max=self.num_train_steps\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR()\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLR()\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "class UNet(pl.LightningModule):\n", - " def __init__(\n", - " self,\n", - " lr=args[\"lr\"],\n", - " enc_chs=[4, 64, 128, 256, 512, 1024],\n", - " dec_chs=[1024, 512, 256, 128, 64],\n", - " num_train_steps=None,\n", - " ):\n", - " super().__init__()\n", - " self.lr = lr\n", - " self.num_train_steps = num_train_steps\n", - " # self.criterion = nn.SmoothL1Loss()\n", - " self.criterion = nn.L1Loss()\n", + " self.lr = lr\n", + " self.num_train_steps = num_train_steps\n", + " self.criterion = nn.L1Loss()\n", "\n", " self.tail = BasicBlock(4, enc_chs[0])\n", " self.encoder = Encoder(enc_chs)\n", @@ -618,7 +698,8 @@ " nn.BatchNorm2d(32),\n", " nn.LeakyReLU(inplace=True),\n", " nn.Conv2d(32, 1, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", + " # nn.ReLU(inplace=True),\n", + " nn.Sigmoid(),\n", " )\n", "\n", " def forward(self, x):\n", @@ -627,14 +708,12 @@ " feats = feats[::-1]\n", " x = self.decoder(feats[0], feats[1:])\n", " x = self.head(x)\n", - "\n", " return x\n", "\n", " def shared_step(self, batch, batch_idx):\n", " x, y = batch\n", " y_hat = self(x)\n", " loss = self.criterion(y_hat, y)\n", - "\n", " return loss, y, y_hat\n", "\n", " def training_step(self, batch, batch_idx):\n", @@ -642,48 +721,46 @@ " self.log(\"train_loss\", loss)\n", " for i, param_group in enumerate(self.optimizer.param_groups):\n", " self.log(f\"lr/lr{i}\", param_group[\"lr\"])\n", - "\n", " return {\"loss\": loss}\n", "\n", " def validation_step(self, batch, batch_idx):\n", " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", - "\n", " return {\"loss\": loss, \"y\": y.detach(), \"y_hat\": y_hat.detach()}\n", "\n", " def validation_epoch_end(self, outputs):\n", " avg_loss = torch.stack([x[\"loss\"] for x in outputs]).mean()\n", " self.log(\"val_loss\", avg_loss)\n", "\n", - " crop = T.CenterCrop(120)\n", - "\n", " y = torch.cat([x[\"y\"] for x in outputs])\n", - " y = crop(y)\n", - " y = y.detach().cpu().numpy()\n", - " y = y.reshape(-1, 120 * 120)\n", - "\n", " y_hat = torch.cat([x[\"y_hat\"] for x in outputs])\n", + "\n", + " crop = T.CenterCrop(120)\n", + " y = crop(y)\n", " y_hat = crop(y_hat)\n", - " y_hat = y_hat.detach().cpu().numpy()\n", - " y_hat = y_hat.reshape(-1, 120 * 120)\n", "\n", - " y = args[\"rng\"] * y[:, args[\"dams\"]]\n", - " y = y.clip(0, 255)\n", - " y_hat = args[\"rng\"] * y_hat[:, args[\"dams\"]]\n", - " y_hat = y_hat.clip(0, 255)\n", + " batch_size = len(y)\n", + " y = y.detach().cpu().numpy()\n", + " y *= args[\"rng\"]\n", + " y = y.reshape(batch_size, -1)\n", + " y = y[:, args[\"dams\"]]\n", + " y_hat = y_hat.detach().cpu().numpy()\n", + " y_hat *= args[\"rng\"]\n", + " y_hat = y_hat.reshape(batch_size, -1)\n", + " y_hat = y_hat[:, args[\"dams\"]]\n", "\n", " y_true = radar2precipitation(y)\n", " y_true = np.where(y_true >= 0.1, 1, 0)\n", + " y_true = y_true.ravel()\n", " y_pred = radar2precipitation(y_hat)\n", " y_pred = np.where(y_pred >= 0.1, 1, 0)\n", + " y_pred = y_pred.ravel()\n", "\n", - " y *= y_true\n", - " y_hat *= y_true\n", - " mae = metrics.mean_absolute_error(y, y_hat)\n", + " y = y.ravel()\n", + " y_hat = y_hat.ravel()\n", + " mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)\n", " self.log(\"mae\", mae)\n", "\n", - " tn, fp, fn, tp = metrics.confusion_matrix(\n", - " y_true.ravel(), y_pred.ravel()\n", - " ).ravel()\n", + " tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\n", " csi = tp / (tp + fn + fp)\n", " self.log(\"csi\", csi)\n", "\n", @@ -695,94 +772,79 @@ " )\n", "\n", " def configure_optimizers(self):\n", - "\n", - " # Optimizer\n", + " # optimizer\n", " if args[\"optimizer\"] == \"adam\":\n", " self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", " elif args[\"optimizer\"] == \"adamw\":\n", - " self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\n", - " elif args[\"optimizer\"] == \"adagrad\":\n", - " self.optimizer = torch.optim.Adagrad(self.parameters(), lr=self.lr)\n", + " self.optimizer = AdamW(self.parameters(), lr=self.lr)\n", " elif args[\"optimizer\"] == \"radam\":\n", " self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", " elif args[\"optimizer\"] == \"ranger\":\n", - " optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", - " self.optimizer = optim.Lookahead(optimizer)\n", + " self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", + " self.optimizer = optim.Lookahead(self.optimizer)\n", "\n", - " # Scheduler\n", + " # scheduler\n", " if args[\"scheduler\"] == \"cosine\":\n", - " self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", - " self.optimizer, T_max=self.num_train_steps\n", + " self.scheduler = get_cosine_schedule_with_warmup(\n", + " self.optimizer,\n", + " num_warmup_steps=self.num_train_steps * args[\"warmup_epochs\"],\n", + " num_training_steps=self.num_train_steps * args[\"max_epochs\"],\n", " )\n", " return [self.optimizer], [{\"scheduler\": self.scheduler, \"interval\": \"step\"}]\n", " elif args[\"scheduler\"] == \"step\":\n", - " self.scheduler = torch.optim.lr_scheduler.StepLR()\n", + " self.scheduler = torch.optim.lr_scheduler.StepLR(\n", + " self.optimizer, step_size=10, gamma=0.5\n", + " )\n", + " return [self.optimizer], [\n", + " {\"scheduler\": self.scheduler, \"interval\": \"epoch\"}\n", + " ]\n", " elif args[\"scheduler\"] == \"plateau\":\n", - " self.scheduler = torch.optim.lr_scheduler.ReduceLR()\n", + " self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " self.optimizer, mode=\"min\", factor=0.1, patience=3, verbose=True\n", + " )\n", + " return [self.optimizer], [\n", + " {\n", + " \"scheduler\": self.scheduler,\n", + " \"interval\": \"epoch\",\n", + " \"reduce_on_plateau\": True,\n", + " \"monitor\": \"comp_metric\",\n", + " }\n", + " ]\n", " else:\n", " self.scheduler = None\n", " return [self.optimizer]" ] }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 18;\n", - " var nbb_unformatted_code = \"# m = UNet()\\n# x = torch.randn(3, 4, 128, 128)\\n# m(x).shape\";\n", - " var nbb_formatted_code = \"# m = UNet()\\n# x = torch.randn(3, 4, 128, 128)\\n# m(x).shape\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# m = UNet()\n", - "# x = torch.randn(3, 4, 128, 128)\n", - "# m(x).shape" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Dataset" + "## Train" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 12, "metadata": {}, "outputs": [ + { + "data": { + "text/plain": [ + "42" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 19;\n", - " var nbb_unformatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n\\n precipitation = radar2precipitation(y)\\n\\n label = np.zeros(y.shape)\\n label[precipitation >= 0.1] += 1\\n label[precipitation >= 1.0] += 1\\n label[precipitation >= 2.5] += 1\\n label = torch.tensor(label, dtype=torch.long)\\n label = label.unsqueeze(0)\\n\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y, label\";\n", - " var nbb_formatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n\\n precipitation = radar2precipitation(y)\\n\\n label = np.zeros(y.shape)\\n label[precipitation >= 0.1] += 1\\n label[precipitation >= 1.0] += 1\\n label[precipitation >= 2.5] += 1\\n label = torch.tensor(label, dtype=torch.long)\\n label = label.unsqueeze(0)\\n\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y, label\";\n", + " var nbb_cell_id = 12;\n", + " var nbb_unformatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\";\n", + " var nbb_formatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -804,70 +866,23 @@ } ], "source": [ - "class NowcastingDataset(torch.utils.data.Dataset):\n", - " def __init__(self, paths, test=False):\n", - " self.paths = paths\n", - " self.test = test\n", - "\n", - " def __len__(self):\n", - " return len(self.paths)\n", - "\n", - " def __getitem__(self, idx):\n", - " path = self.paths[idx]\n", - " data = np.load(path)\n", - "\n", - " x = data[:, :, :4]\n", - " x = x / args[\"rng\"]\n", - " x = x.astype(np.float32)\n", - " x = torch.tensor(x, dtype=torch.float)\n", - " x = x.permute(2, 0, 1)\n", - " if self.test:\n", - " return x\n", - " else:\n", - " y = data[:, :, 4]\n", - "\n", - " precipitation = radar2precipitation(y)\n", - "\n", - " label = np.zeros(y.shape)\n", - " label[precipitation >= 0.1] += 1\n", - " label[precipitation >= 1.0] += 1\n", - " label[precipitation >= 2.5] += 1\n", - " label = torch.tensor(label, dtype=torch.long)\n", - " label = label.unsqueeze(0)\n", - "\n", - " y = y / args[\"rng\"]\n", - " y = y.astype(np.float32)\n", - " y = torch.tensor(y, dtype=torch.float)\n", - " y = y.unsqueeze(-1)\n", - " y = y.permute(2, 0, 1)\n", - "\n", - " return x, y, label" + "seed_everything(args[\"seed\"])\n", + "pl.seed_everything(args[\"seed\"])" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 13, "metadata": {}, "outputs": [ - { - "ename": "ValueError", - "evalue": "too many values to unpack (expected 2)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNowcastingDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_paths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)" - ] - }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 20;\n", - " var nbb_unformatted_code = \"fold = 3\\ndf = pd.read_csv(args[\\\"train_folds_csv\\\"])\\ntrain_df = df[df.fold != fold]\\ntrain_paths = [args[\\\"train_data_path\\\"] / fn for fn in train_df.filename.values]\\ndataset = NowcastingDataset(train_paths)\\nidx = np.random.randint(len(dataset))\\nx, y = dataset[idx]\";\n", - " var nbb_formatted_code = \"fold = 3\\ndf = pd.read_csv(args[\\\"train_folds_csv\\\"])\\ntrain_df = df[df.fold != fold]\\ntrain_paths = [args[\\\"train_data_path\\\"] / fn for fn in train_df.filename.values]\\ndataset = NowcastingDataset(train_paths)\\nidx = np.random.randint(len(dataset))\\nx, y = dataset[idx]\";\n", + " var nbb_cell_id = 13;\n", + " var nbb_unformatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\";\n", + " var nbb_formatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -889,18 +904,12 @@ } ], "source": [ - "fold = 3\n", - "df = pd.read_csv(args[\"train_folds_csv\"])\n", - "train_df = df[df.fold != fold]\n", - "train_paths = [args[\"train_data_path\"] / fn for fn in train_df.filename.values]\n", - "dataset = NowcastingDataset(train_paths)\n", - "idx = np.random.randint(len(dataset))\n", - "x, y = dataset[idx]" + "df = pd.read_csv(args[\"train_folds_csv\"])" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -908,9 +917,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 21;\n", - " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", - " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_cell_id = 14;\n", + " var nbb_unformatted_code = \"def train_fold(df, fold, lr_find=False):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(train_df, val_df)\\n datamodule.setup()\\n\\n num_train_steps = np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n model = UNet(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n benchmark=True,\\n )\\n\\n if lr_find:\\n lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\\n fig = lr_finder.plot(suggest=True)\\n fig.show()\\n return\\n\\n print(f\\\"Training fold {fold}...\\\")\\n trainer.fit(model, datamodule)\\n\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_sigmoid_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n trainer.save_checkpoint(checkpoint)\\n print(\\\"Model saved at\\\", checkpoint)\\n\\n del model, trainer, datamodule\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", + " var nbb_formatted_code = \"def train_fold(df, fold, lr_find=False):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(train_df, val_df)\\n datamodule.setup()\\n\\n num_train_steps = np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n model = UNet(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n benchmark=True,\\n )\\n\\n if lr_find:\\n lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\\n fig = lr_finder.plot(suggest=True)\\n fig.show()\\n return\\n\\n print(f\\\"Training fold {fold}...\\\")\\n trainer.fit(model, datamodule)\\n\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_sigmoid_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n trainer.save_checkpoint(checkpoint)\\n print(\\\"Model saved at\\\", checkpoint)\\n\\n del model, trainer, datamodule\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -932,73 +941,50 @@ } ], "source": [ - "class NowcastingDataModule(pl.LightningDataModule):\n", - " def __init__(\n", - " self,\n", - " train_df=None,\n", - " val_df=None,\n", - " batch_size=args[\"batch_size\"],\n", - " num_workers=args[\"num_workers\"],\n", - " ):\n", - " super().__init__()\n", - " self.train_df = train_df\n", - " self.val_df = val_df\n", - " self.batch_size = batch_size\n", - " self.num_workers = num_workers\n", + "def train_fold(df, fold, lr_find=False):\n", + " train_df = df[df.fold != fold]\n", + " val_df = df[df.fold == fold]\n", "\n", - " def setup(self, stage=\"train\"):\n", - " if stage == \"train\":\n", - " train_paths = [\n", - " args[\"train_data_path\"] / fn for fn in self.train_df.filename.values\n", - " ]\n", - " val_paths = [\n", - " args[\"train_data_path\"] / fn for fn in self.val_df.filename.values\n", - " ]\n", - " self.train_dataset = NowcastingDataset(train_paths)\n", - " self.val_dataset = NowcastingDataset(val_paths)\n", - " else:\n", - " test_paths = list(sorted(args[\"test_data_path\"].glob(\"*.npy\")))\n", - " self.test_dataset = NowcastingDataset(test_paths, test=True)\n", + " datamodule = NowcastingDataModule(train_df, val_df)\n", + " datamodule.setup()\n", "\n", - " def train_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", - " self.train_dataset,\n", - " batch_size=self.batch_size,\n", - " sampler=RandomSampler(self.train_dataset),\n", - " pin_memory=True,\n", - " num_workers=self.num_workers,\n", - " drop_last=True,\n", - " )\n", + " num_train_steps = np.ceil(\n", + " len(train_df) // args[\"batch_size\"] / args[\"accumulate_grad_batches\"]\n", + " )\n", + " model = UNet(num_train_steps=num_train_steps)\n", "\n", - " def val_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", - " self.val_dataset,\n", - " batch_size=2 * self.batch_size,\n", - " sampler=SequentialSampler(self.val_dataset),\n", - " pin_memory=True,\n", - " num_workers=self.num_workers,\n", - " )\n", + " trainer = pl.Trainer(\n", + " gpus=args[\"gpus\"],\n", + " max_epochs=args[\"max_epochs\"],\n", + " precision=args[\"precision\"],\n", + " progress_bar_refresh_rate=50,\n", + " benchmark=True,\n", + " )\n", "\n", - " def test_dataloader(self):\n", - " return torch.utils.data.DataLoader(\n", - " self.test_dataset,\n", - " batch_size=2 * self.batch_size,\n", - " sampler=SequentialSampler(self.test_dataset),\n", - " pin_memory=True,\n", - " num_workers=self.num_workers,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train" + " if lr_find:\n", + " lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\n", + " fig = lr_finder.plot(suggest=True)\n", + " fig.show()\n", + " return\n", + "\n", + " print(f\"Training fold {fold}...\")\n", + " trainer.fit(model, datamodule)\n", + "\n", + " checkpoint = (\n", + " args[\"model_dir\"]\n", + " / f\"unet_sigmoid_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt\"\n", + " )\n", + " trainer.save_checkpoint(checkpoint)\n", + " print(\"Model saved at\", checkpoint)\n", + "\n", + " del model, trainer, datamodule\n", + " gc.collect()\n", + " torch.cuda.empty_cache()" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 16, "metadata": { "scrolled": true }, @@ -1010,7 +996,20 @@ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "Using native 16bit precision.\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ "\n", " | Name | Type | Params\n", "-----------------------------------------\n", @@ -1024,7 +1023,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ed004db6ca9f4443b88355a974d0aff0", + "model_id": "e340beec3a3749aaaf7fcef74b3c6726", "version_major": 2, "version_minor": 0 }, @@ -1036,219 +1035,56 @@ "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "ERROR:root:Internal Python error in the inspect module.\n", - "Below is the traceback from this internal error.\n", - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 3418, in run_code\n", - " exec(code_obj, self.user_global_ns, self.user_ns)\n", - " File \"\", line 41, in \n", - " trainer.fit(model, datamodule)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\", line 440, in fit\n", - " results = self.accelerator_backend.train()\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 54, in train\n", - " results = self.train_or_test()\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\", line 68, in train_or_test\n", - " results = self.trainer.train()\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\", line 462, in train\n", - " self.run_sanity_check(self.get_model())\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\", line 650, in run_sanity_check\n", - " _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\", line 570, in run_evaluation\n", - " output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py\", line 171, in evaluation_step\n", - " output = self.trainer.accelerator_backend.validation_step(args)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 76, in validation_step\n", - " output = self.__validation_step(args)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 86, in __validation_step\n", - " output = self.trainer.model.validation_step(*args)\n", - " File \"\", line 51, in validation_step\n", - " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", - " File \"\", line 36, in shared_step\n", - " x, y = batch\n", - "ValueError: too many values to unpack (expected 2)\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 2045, in showtraceback\n", - " stb = value._render_traceback_()\n", - "AttributeError: 'ValueError' object has no attribute '_render_traceback_'\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 1170, in get_records\n", - " return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 316, in wrapped\n", - " return f(*args, **kwargs)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 350, in _fixed_getinnerframes\n", - " records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/inspect.py\", line 1503, in getinnerframes\n", - " frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/inspect.py\", line 1461, in getframeinfo\n", - " filename = getsourcefile(frame) or getfile(frame)\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/inspect.py\", line 708, in getsourcefile\n", - " if getattr(getmodule(object, filename), '__loader__', None) is not None:\n", - " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/inspect.py\", line 745, in getmodule\n", - " if ismodule(module) and hasattr(module, '__file__'):\n", - "KeyboardInterrupt\n" - ] - }, - { - "ename": "TypeError", - "evalue": "object of type 'NoneType' has no len()", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m trainer.save_checkpoint(\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloader, val_dataloaders, datamodule)\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 440\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 441\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mteardown\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;31m# train or test\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_or_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mtrain_or_test\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 462\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_sanity_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 463\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_sanity_check\u001b[0;34m(self, ref_model)\u001b[0m\n\u001b[1;32m 649\u001b[0m \u001b[0;31m# run eval step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 650\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_evaluation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_batches\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_sanity_val_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 651\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_evaluation\u001b[0;34m(self, test_mode, max_batches)\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0;31m# lightning module methods\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_mode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 571\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_step_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py\u001b[0m in \u001b[0;36mevaluation_step\u001b[0;34m(self, test_mode, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 171\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 172\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36mvalidation_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 76\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__validation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 77\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36m__validation_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mvalidation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshared_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mshared_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mshared_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m 2044\u001b[0m \u001b[0;31m# in the engines. This should return a list of strings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2045\u001b[0;31m \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2046\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'ValueError' object has no attribute '_render_traceback_'", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m 2045\u001b[0m \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2046\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2047\u001b[0;31m stb = self.InteractiveTB.structured_traceback(etype,\n\u001b[0m\u001b[1;32m 2048\u001b[0m value, tb, tb_offset=tb_offset)\n\u001b[1;32m 2049\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1434\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1435\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1436\u001b[0;31m return FormattedTB.structured_traceback(\n\u001b[0m\u001b[1;32m 1437\u001b[0m self, etype, value, tb, tb_offset, number_of_lines_of_context)\n\u001b[1;32m 1438\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1334\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1335\u001b[0m \u001b[0;31m# Verbose modes need a full traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1336\u001b[0;31m return VerboseTB.structured_traceback(\n\u001b[0m\u001b[1;32m 1337\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb_offset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumber_of_lines_of_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1338\u001b[0m )\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, evalue, etb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1191\u001b[0m \u001b[0;34m\"\"\"Return a nice text document describing the traceback.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1193\u001b[0;31m formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,\n\u001b[0m\u001b[1;32m 1194\u001b[0m tb_offset)\n\u001b[1;32m 1195\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mformat_exception_as_a_whole\u001b[0;34m(self, etype, evalue, etb, number_of_lines_of_context, tb_offset)\u001b[0m\n\u001b[1;32m 1149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1151\u001b[0;31m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_recursion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_etype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1152\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[0mframes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_records\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mfind_recursion\u001b[0;34m(etype, value, records)\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[0;31m# first frame (from in to out) that looks different.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 450\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_recursion_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 451\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 452\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[0;31m# Select filename, lineno, func_name to track frames with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mTypeError\u001b[0m: object of type 'NoneType' has no len()" + "Epoch 0 | MAE/CSI: 110793664383561.66 | MAE: 110.79366438356165 | CSI: 0.0 | Loss: 0.05267500877380371\n" ] }, { "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 22;\n", - " var nbb_unformatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\\n\\ndf = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n\\nfor fold in range(3, 5):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(\\n train_df, val_df, batch_size=args[\\\"batch_size\\\"], num_workers=args[\\\"num_workers\\\"]\\n )\\n datamodule.setup()\\n\\n num_train_steps = (\\n int(\\n np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n )\\n * args[\\\"max_epochs\\\"]\\n )\\n\\n model = UNet(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n # accumulate_grad_batches=args[\\\"accumulate_grad_batches\\\"],\\n # gradient_clip_val=args[\\\"gradient_clip_val\\\"],\\n auto_lr_find=True,\\n )\\n\\n # learning rate finder\\n # lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\\n # fig = lr_finder.plot(suggest=True)\\n # fig.show()\\n\\n trainer.fit(model, datamodule)\\n trainer.save_checkpoint(\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n\\n del datamodule, model, trainer\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", - " var nbb_formatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\\n\\ndf = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n\\nfor fold in range(3, 5):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(\\n train_df, val_df, batch_size=args[\\\"batch_size\\\"], num_workers=args[\\\"num_workers\\\"]\\n )\\n datamodule.setup()\\n\\n num_train_steps = (\\n int(\\n np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n )\\n * args[\\\"max_epochs\\\"]\\n )\\n\\n model = UNet(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n # accumulate_grad_batches=args[\\\"accumulate_grad_batches\\\"],\\n # gradient_clip_val=args[\\\"gradient_clip_val\\\"],\\n auto_lr_find=True,\\n )\\n\\n # learning rate finder\\n # lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\\n # fig = lr_finder.plot(suggest=True)\\n # fig.show()\\n\\n trainer.fit(model, datamodule)\\n trainer.save_checkpoint(\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n\\n del datamodule, model, trainer\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "1f8163856b804c03a1356bc70986b5b1", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "" + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "seed_everything(args[\"seed\"])\n", - "pl.seed_everything(args[\"seed\"])\n", - "\n", - "df = pd.read_csv(args[\"train_folds_csv\"])\n", - "\n", - "for fold in range(3, 5):\n", - " train_df = df[df.fold != fold]\n", - " val_df = df[df.fold == fold]\n", - "\n", - " datamodule = NowcastingDataModule(\n", - " train_df, val_df, batch_size=args[\"batch_size\"], num_workers=args[\"num_workers\"]\n", - " )\n", - " datamodule.setup()\n", - "\n", - " num_train_steps = (\n", - " int(\n", - " np.ceil(\n", - " len(train_df) // args[\"batch_size\"] / args[\"accumulate_grad_batches\"]\n", - " )\n", - " )\n", - " * args[\"max_epochs\"]\n", - " )\n", - "\n", - " model = UNet(num_train_steps=num_train_steps)\n", - "\n", - " trainer = pl.Trainer(\n", - " gpus=args[\"gpus\"],\n", - " max_epochs=args[\"max_epochs\"],\n", - " precision=args[\"precision\"],\n", - " progress_bar_refresh_rate=50,\n", - " # accumulate_grad_batches=args[\"accumulate_grad_batches\"],\n", - " # gradient_clip_val=args[\"gradient_clip_val\"],\n", - " auto_lr_find=True,\n", - " )\n", - "\n", - " # learning rate finder\n", - " # lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\n", - " # fig = lr_finder.plot(suggest=True)\n", - " # fig.show()\n", - "\n", - " trainer.fit(model, datamodule)\n", - " trainer.save_checkpoint(\n", - " args[\"model_dir\"]\n", - " / f\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\"\n", - " )\n", - "\n", - " del datamodule, model, trainer\n", - " gc.collect()\n", - " torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Inference" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ab4a02f18c184218ab3ae7a8e2ee813c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ - "../models/unet_fold0_bs256_epoch50_adamw_cosine.ckpt\n" + "Epoch 0 | MAE/CSI: 33.186943419342214 | MAE: 25.55001192118522 | CSI: 0.7698814439856134 | Loss: 0.017848094925284386\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "209aa235815347a9b2522aba0e43f37c", + "model_id": "67061e885fb545a3bed1513ea4741a74", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, @@ -1258,19 +1094,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "../models/unet_fold1_bs256_epoch50_adamw_cosine.ckpt\n" + "Epoch 1 | MAE/CSI: 25.185173461181805 | MAE: 19.69605205396773 | CSI: 0.7820494897245911 | Loss: 0.01364449504762888\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "10755cf46d21468e8109a77e1f278596", + "model_id": "7c1220abc1e14797a75ed1160d130621", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, @@ -1280,19 +1115,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "../models/unet_fold2_bs256_epoch50_adamw_cosine.ckpt\n" + "Epoch 2 | MAE/CSI: 25.54220371042292 | MAE: 20.054110962616964 | CSI: 0.7851362862010222 | Loss: 0.013360547833144665\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bf62d681150647e5b1397731d1484a54", + "model_id": "eda949cdabdf43ff8c57f0c7a987893e", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, @@ -1302,19 +1136,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "../models/unet_fold3_bs256_epoch50_adamw_cosine.ckpt\n" + "Epoch 3 | MAE/CSI: 27.026754305478295 | MAE: 21.24946493606421 | CSI: 0.7862381363244176 | Loss: 0.013475954532623291\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "687733fd26c745d8a01ff94af632ffeb", + "model_id": "0383e854b13e4225b7ec1e8fdc0c59b7", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, @@ -1324,19 +1157,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "../models/unet_fold4_bs256_epoch50_adamw_cosine.ckpt\n" + "Epoch 4 | MAE/CSI: 21.443058880888415 | MAE: 17.222089883581003 | CSI: 0.8031545302946081 | Loss: 0.012746231630444527\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "350aa5098fd3423aa7487492fcf1ec48", + "model_id": "cc3d623fe7bd44bf910368fae52c2b4f", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, @@ -1346,128 +1178,60 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "Epoch 5 | MAE/CSI: 21.810521936074984 | MAE: 17.524096834325515 | CSI: 0.8034698521046644 | Loss: 0.012218066491186619\n" ] }, { "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 25;\n", - " var nbb_unformatted_code = \"datamodule = NowcastingDataModule()\\ndatamodule.setup(\\\"test\\\")\\n\\nfinal_preds = np.zeros((len(datamodule.test_dataset), 120, 120))\\n\\nfor fold in range(5):\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n print(checkpoint)\\n model = UNet.load_from_checkpoint(str(checkpoint))\\n model.cuda()\\n model.eval()\\n preds = []\\n with torch.no_grad():\\n for batch in tqdm(datamodule.test_dataloader()):\\n batch = batch.cuda()\\n imgs = model(batch)\\n imgs = imgs.detach().cpu().numpy()\\n imgs = imgs[:, 0, 4:124, 4:124]\\n imgs = args[\\\"rng\\\"] * imgs\\n imgs = imgs.clip(0, 255)\\n imgs = imgs.round()\\n preds.append(imgs)\\n\\n preds = np.concatenate(preds)\\n preds = preds.astype(np.uint8)\\n final_preds += preds\\n\\n del model\\n gc.collect()\\n torch.cuda.empty_cache()\\n\\nfinal_preds = final_preds.astype(np.uint8)\\nfinal_preds = final_preds.reshape(-1, 14400)\";\n", - " var nbb_formatted_code = \"datamodule = NowcastingDataModule()\\ndatamodule.setup(\\\"test\\\")\\n\\nfinal_preds = np.zeros((len(datamodule.test_dataset), 120, 120))\\n\\nfor fold in range(5):\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n print(checkpoint)\\n model = UNet.load_from_checkpoint(str(checkpoint))\\n model.cuda()\\n model.eval()\\n preds = []\\n with torch.no_grad():\\n for batch in tqdm(datamodule.test_dataloader()):\\n batch = batch.cuda()\\n imgs = model(batch)\\n imgs = imgs.detach().cpu().numpy()\\n imgs = imgs[:, 0, 4:124, 4:124]\\n imgs = args[\\\"rng\\\"] * imgs\\n imgs = imgs.clip(0, 255)\\n imgs = imgs.round()\\n preds.append(imgs)\\n\\n preds = np.concatenate(preds)\\n preds = preds.astype(np.uint8)\\n final_preds += preds\\n\\n del model\\n gc.collect()\\n torch.cuda.empty_cache()\\n\\nfinal_preds = final_preds.astype(np.uint8)\\nfinal_preds = final_preds.reshape(-1, 14400)\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "c94d6cf2a99448b6b811fddcf8d2565a", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "datamodule = NowcastingDataModule()\n", - "datamodule.setup(\"test\")\n", - "\n", - "final_preds = np.zeros((len(datamodule.test_dataset), 120, 120))\n", - "\n", - "for fold in range(5):\n", - " checkpoint = (\n", - " args[\"model_dir\"]\n", - " / f\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\"\n", - " )\n", - " print(checkpoint)\n", - " model = UNet.load_from_checkpoint(str(checkpoint))\n", - " model.cuda()\n", - " model.eval()\n", - " preds = []\n", - " with torch.no_grad():\n", - " for batch in tqdm(datamodule.test_dataloader()):\n", - " batch = batch.cuda()\n", - " imgs = model(batch)\n", - " imgs = imgs.detach().cpu().numpy()\n", - " imgs = imgs[:, 0, 4:124, 4:124]\n", - " imgs = args[\"rng\"] * imgs\n", - " imgs = imgs.clip(0, 255)\n", - " imgs = imgs.round()\n", - " preds.append(imgs)\n", - "\n", - " preds = np.concatenate(preds)\n", - " preds = preds.astype(np.uint8)\n", - " final_preds += preds\n", - "\n", - " del model\n", - " gc.collect()\n", - " torch.cuda.empty_cache()\n", - "\n", - "final_preds = final_preds.astype(np.uint8)\n", - "final_preds = final_preds.reshape(-1, 14400)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 19.459475168673396 | MAE: 15.815077867269974 | CSI: 0.8127186231985448 | Loss: 0.011925801634788513\n" + ] + }, { "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 26;\n", - " var nbb_unformatted_code = \"test_paths = datamodule.test_dataset.paths\\ntest_filenames = [path.name for path in test_paths]\";\n", - " var nbb_formatted_code = \"test_paths = datamodule.test_dataset.paths\\ntest_filenames = [path.name for path in test_paths]\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "a91b1778b4374d0c94ddfd7c8be2056b", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "test_paths = datamodule.test_dataset.paths\n", - "test_filenames = [path.name for path in test_paths]" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 21.077778418376518 | MAE: 17.037970858700216 | CSI: 0.8083380762663631 | Loss: 0.011935080401599407\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e2360942148d490c8fd59244a4e174f3", + "model_id": "b2aab07fe413456eb4bd186360a0b2f3", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14400.0), HTML(value='')))" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, @@ -1477,248 +1241,3228 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" + "Epoch 8 | MAE/CSI: 20.73141515057428 | MAE: 16.773482073236565 | CSI: 0.8090852434041964 | Loss: 0.012599549256265163\n" ] }, { "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 27;\n", - " var nbb_unformatted_code = \"subm = pd.DataFrame({\\\"file_name\\\": test_filenames})\\nfor i in tqdm(range(14400)):\\n subm[str(i)] = final_preds[:, i]\";\n", - " var nbb_formatted_code = \"subm = pd.DataFrame({\\\"file_name\\\": test_filenames})\\nfor i in tqdm(range(14400)):\\n subm[str(i)] = final_preds[:, i]\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "02e2db2eda824017a74f5cb273a98dc0", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "subm = pd.DataFrame({\"file_name\": test_filenames})\n", - "for i in tqdm(range(14400)):\n", - " subm[str(i)] = final_preds[:, i]" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 20.89835070821418 | MAE: 17.01616066608685 | CSI: 0.8142346208869814 | Loss: 0.011802570894360542\n" + ] + }, { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
file_name012345678...14390143911439214393143941439514396143971439814399
0test_00000.npy000000000...0000000000
1test_00001.npy000000000...0000000000
2test_00002.npy000000000...0000000000
3test_00003.npy000000000...0000000000
4test_00004.npy000000000...0000000000
\n", - "

5 rows Ă— 14401 columns

\n", - "
" - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "4a39513ea57849868f4459c2df69aa53", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - " file_name 0 1 2 3 4 5 6 7 8 ... 14390 14391 14392 14393 \\\n", - "0 test_00000.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "1 test_00001.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "2 test_00002.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "3 test_00003.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "4 test_00004.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", - "\n", - " 14394 14395 14396 14397 14398 14399 \n", - "0 0 0 0 0 0 0 \n", - "1 0 0 0 0 0 0 \n", - "2 0 0 0 0 0 0 \n", - "3 0 0 0 0 0 0 \n", - "4 0 0 0 0 0 0 \n", - "\n", - "[5 rows x 14401 columns]" + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, - "execution_count": 28, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" }, { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 28;\n", - " var nbb_unformatted_code = \"subm.to_csv(\\n f\\\"unet_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.csv\\\",\\n index=False,\\n)\\nsubm.head()\";\n", - " var nbb_formatted_code = \"subm.to_csv(\\n f\\\"unet_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.csv\\\",\\n index=False,\\n)\\nsubm.head()\";\n", + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 20.213408750832656 | MAE: 16.462153962294565 | CSI: 0.8144175069727526 | Loss: 0.011618967168033123\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "34c0998b4b6c4b8dbfda754b1911cd02", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 24.208863723571998 | MAE: 19.503725665599507 | CSI: 0.805643994211288 | Loss: 0.012568147853016853\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "54b41b9e8d894bd4bdd4c9aef70f31e6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 20.601691043830602 | MAE: 16.781879929672506 | CSI: 0.81458749643163 | Loss: 0.01178012229502201\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8341b57fc5294c0aa6531493b313718e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 19.976155204575853 | MAE: 16.334156712039647 | CSI: 0.8176827094474153 | Loss: 0.011617383919656277\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ee62e7c1f038428da229a1ca55d9ee54", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 18.39588910524068 | MAE: 15.092483317838292 | CSI: 0.8204269568857262 | Loss: 0.01175840012729168\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "13c26c69dd2547c2954e49585ddc6682", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 19.09062582033396 | MAE: 15.576906377459213 | CSI: 0.8159452981813056 | Loss: 0.011697824113070965\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0791fdbd0ae64936839efb7189a1e5ee", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 18.493273251072264 | MAE: 15.189089688174784 | CSI: 0.8213305174234424 | Loss: 0.011621751822531223\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "40dedef37d074264bec6de3792f9b399", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 18.927566425968354 | MAE: 15.492399944517754 | CSI: 0.818509870515814 | Loss: 0.011545676738023758\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7a1dfe2df2e540f8a456b693d6224aa8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 19.334057504131742 | MAE: 15.774070578664828 | CSI: 0.8158696422245838 | Loss: 0.011600039899349213\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "09324ff636fd4c909de601b3dab5fc64", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 19.14047650865582 | MAE: 15.601810012202344 | CSI: 0.8151212957069099 | Loss: 0.011548931710422039\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2ebe0a8060ed4d01ba46d6a03fcefea2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 20.027654187947494 | MAE: 16.23505613828575 | CSI: 0.8106319385140905 | Loss: 0.011748154647648335\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0b84e7d7820f43e2ba540e52081e90c6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 19.016997137613103 | MAE: 15.494138983728332 | CSI: 0.814752133135886 | Loss: 0.011649723164737225\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ca046f3f353f49cf944135f32bc6cbac", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 19.549548452653777 | MAE: 15.862644750554823 | CSI: 0.8114072194021432 | Loss: 0.01184056606143713\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e0407bfcb9514a509f0f55da0e4f3a56", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 19.62327869559657 | MAE: 15.935058043144943 | CSI: 0.8120487045164944 | Loss: 0.011686836369335651\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "50c6661120e34e21a5f563af56eaddb5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 19.20136738616643 | MAE: 15.653670722243882 | CSI: 0.8152372905224616 | Loss: 0.011736424639821053\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bde6ebc6396143b39f6e898fb8c7abc7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 19.269590556557628 | MAE: 15.65051845909309 | CSI: 0.8121873899260303 | Loss: 0.011749816127121449\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a72177c3bb1840b8a725a8bc53f59335", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 19.560944262355147 | MAE: 15.85134794678203 | CSI: 0.8103569916748977 | Loss: 0.011810777708888054\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "58d578baa0094079bb079ea6d745f2fb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 19.463234309122434 | MAE: 15.779038946990088 | CSI: 0.8107100133699247 | Loss: 0.011808572337031364\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0bb48e14af5743798aca97d9b6e1fb6a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 19.463016053011348 | MAE: 15.769864921350168 | CSI: 0.8102477477477478 | Loss: 0.01181867253035307\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "652bf4b1300d4c3d8c70706ad9a5fffd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 19.419930651638133 | MAE: 15.728437331303983 | CSI: 0.8099121265377855 | Loss: 0.011822505854070187\n", + "\n", + "Model saved at ../models/unet_fold0_bs256_epochs30_lr0.001_adamw_cosine.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | tail | BasicBlock | 300 \n", + "2 | encoder | Encoder | 25 M \n", + "3 | decoder | Decoder | 17 M \n", + "4 | head | Sequential | 8 K \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 1...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d1ff6d803a8541fda99aeb6d32bd7bc2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 109838714384134.84 | MAE: 109.83871438413485 | CSI: 0.0 | Loss: 0.050759207457304\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3fd59c9f003548f58f39e4e6373646e4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "46cc5e9d0f9f49558c36f81a6d0688e1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 30.438610328033292 | MAE: 22.877340869081632 | CSI: 0.7515895312731037 | Loss: 0.01447451300919056\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a902d952874a4a78a3cd2bd7283a1d1e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 27.81257920736288 | MAE: 21.42626403857756 | CSI: 0.7703803332586117 | Loss: 0.013589969836175442\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a4c7fec056a43a3bb14b68e607b9861", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 23.791732404542913 | MAE: 18.607595477371124 | CSI: 0.7821034282393957 | Loss: 0.013189456425607204\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7cef39366f3d4379baddfb28bd5cf6f1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 24.67189650396396 | MAE: 19.398163979907117 | CSI: 0.7862453531598513 | Loss: 0.01269851066172123\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "69ff40148b304ff3901476ec9e9e839e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 22.37970634801444 | MAE: 17.825067209957396 | CSI: 0.7964835164835165 | Loss: 0.01245537493377924\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "561ef88774334a06b954a0239811fcec", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 21.38870789855043 | MAE: 17.124827474286082 | CSI: 0.8006480595036454 | Loss: 0.012010117061436176\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "20583f22ca784c0aafc31ab4d9ea1d2f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 23.121047769023082 | MAE: 18.387481908225936 | CSI: 0.7952702702702703 | Loss: 0.01195940189063549\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "950d2184f7884e85a64d8b1a99f5a151", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 21.380010649024133 | MAE: 17.249816201340995 | CSI: 0.8068198133524767 | Loss: 0.012651579454541206\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c34d43b63339466da512b9ca8fcc28e7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 19.884120596258796 | MAE: 16.060755238895975 | CSI: 0.8077176539503551 | Loss: 0.011617396026849747\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e8c6f8c1bddd425f8c1fa188a7bbb532", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 19.95838481240207 | MAE: 16.128404994784848 | CSI: 0.8081017149615612 | Loss: 0.011446814052760601\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "89385dfbb8de4e0fb18e2bf5bad2c8c7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 21.23726514737613 | MAE: 17.11664908782062 | CSI: 0.8059723777528929 | Loss: 0.011620122008025646\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "00e352a8062240b5a5743916fc96a69c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 19.230807556137442 | MAE: 15.613997821907722 | CSI: 0.8119262686347948 | Loss: 0.01138119213283062\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1d6c1804cef14876ae4e43dd9b324c3e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 19.26334314043933 | MAE: 15.615153630625755 | CSI: 0.810614934114202 | Loss: 0.011543345637619495\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7ce763edb2fe4a07a766106f37696d98", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 19.264658003703257 | MAE: 15.678122278859526 | CSI: 0.8138282172373081 | Loss: 0.011346523649990559\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "79c599c142f541bdb0a2ec25d4745d0a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 19.5614310438213 | MAE: 15.858158574006369 | CSI: 0.8106849922411882 | Loss: 0.011353997513651848\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d99419b31ccd42a695a7f7d64b342efb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 20.325657045944222 | MAE: 16.421114476880287 | CSI: 0.8079007945347887 | Loss: 0.011452319100499153\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "71964a988c884548818bbe9245ae6576", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 19.160887430973464 | MAE: 15.560784978478654 | CSI: 0.8121119146760173 | Loss: 0.011521076783537865\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "87bdbf52bbb34e9790ec466ae817032c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 18.913051188853814 | MAE: 15.386936254420686 | CSI: 0.8135618151062735 | Loss: 0.01132373046129942\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "93c2f47a6bbe4ffea2204ac56ac38d7f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 20.261271536845946 | MAE: 16.398104336047624 | CSI: 0.8093324402768475 | Loss: 0.011397392489016056\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b7b1c439f59c474b8d5af5c5b42a82e3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 18.88648488743959 | MAE: 15.356261966135346 | CSI: 0.813082056170712 | Loss: 0.011438331566751003\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "62d0d3a89b6e4d7abd0d1f9fa6cba609", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 18.90888501053057 | MAE: 15.309875574315502 | CSI: 0.8096656976744186 | Loss: 0.011600780300796032\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9a9f131a35ef4f0294a57718a026597d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 19.057435755200128 | MAE: 15.442373928848985 | CSI: 0.8103070175438597 | Loss: 0.011565647087991238\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f16275d93b70448eb23b32e982976b0d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 19.747915922078203 | MAE: 15.972321902462133 | CSI: 0.8088105076741441 | Loss: 0.01151563972234726\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c7e469afd97e422f80a5fa2eafbe7e2d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 18.986097731680978 | MAE: 15.399541466936103 | CSI: 0.8110956598111688 | Loss: 0.01157230231910944\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7a00e6b480f54b0bb22892842fcc17ee", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 18.99107928711418 | MAE: 15.39025627194611 | CSI: 0.8103939770484614 | Loss: 0.011596854776144028\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a130b024b6a6472e86c29bf913da82e0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 19.7236355264424 | MAE: 15.961558913109142 | CSI: 0.8092604880926049 | Loss: 0.011602360755205154\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fa868a2ea33948d6931841708c47faff", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 19.66751364177057 | MAE: 15.879075442501195 | CSI: 0.8073758448427858 | Loss: 0.01169213280081749\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e2f5842c9ac34e9f833d477de063b938", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 19.442847596632387 | MAE: 15.730006844393438 | CSI: 0.8090382217005355 | Loss: 0.011666052974760532\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "616cdce2484f41d4bef4fb6a2e282aec", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 19.545673056666146 | MAE: 15.805726503701734 | CSI: 0.8086560364464692 | Loss: 0.011664430610835552\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "307618286fe141c797ea990e8ff558ca", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 19.661343826279275 | MAE: 15.894221094064639 | CSI: 0.8083995292733157 | Loss: 0.01167358923703432\n", + "\n", + "Model saved at ../models/unet_fold1_bs256_epochs30_lr0.001_adamw_cosine.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | tail | BasicBlock | 300 \n", + "2 | encoder | Encoder | 25 M \n", + "3 | decoder | Decoder | 17 M \n", + "4 | head | Sequential | 8 K \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 2...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "405f5c85ef084bce8b871f410f30d9f6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 98624364269755.4 | MAE: 98.6243642697554 | CSI: 0.0 | Loss: 0.08764511346817017\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b6c1c69682bc46f6af0cd0ffdb37f6f2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "25d0948609e64e97b71a4345eeaa62c1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 28.118997965134586 | MAE: 21.552152280986306 | CSI: 0.7664623151821133 | Loss: 0.01427386049181223\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "790146027622489ea4f5c7f7aa93cad4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 27.128638812721636 | MAE: 21.151175594950804 | CSI: 0.7796622506915126 | Loss: 0.013792337849736214\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "54c89488e7d349bf8fb01b31cbe88e08", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 24.618840091334018 | MAE: 19.40673530777922 | CSI: 0.7882879630298216 | Loss: 0.01312659028917551\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9689e4bd0e5543c8826899daba921033", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 23.996545896105665 | MAE: 19.048209403667897 | CSI: 0.793789634813033 | Loss: 0.012730807065963745\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1f68408fc1b241f1ab3cc6f33434c4f9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 22.49951004667157 | MAE: 18.03575311320954 | CSI: 0.8016064827978391 | Loss: 0.012839207425713539\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "936cafc4880d4ef896ded07a9d1fb4e5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 21.361131217621754 | MAE: 17.257766442326197 | CSI: 0.8079050807977871 | Loss: 0.012023642659187317\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b96e1401f7fb48279141083e6341b988", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 20.236411459874972 | MAE: 16.415489828220917 | CSI: 0.811185810327564 | Loss: 0.012115568853914738\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "61756ecb35544951bb8b588df639509d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 24.470472404637587 | MAE: 19.564411922558676 | CSI: 0.7995110024449877 | Loss: 0.012515711598098278\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "55e6ce89340e4a35a699ea62ac0d03e7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 19.02272187040231 | MAE: 15.550396087375951 | CSI: 0.817464303652149 | Loss: 0.011686594225466251\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bd2dfacd4e544b9499aab1e79445a0a2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 20.080601901201696 | MAE: 16.387970643444902 | CSI: 0.816109533173112 | Loss: 0.011634819209575653\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b2a5005184fd465d9aa296ce26441527", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 19.03831631311525 | MAE: 15.581505716571792 | CSI: 0.8184287654585746 | Loss: 0.011825410649180412\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "53c512a1f821453cbde93599832b96d8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 19.32006970412333 | MAE: 15.80681914816354 | CSI: 0.8181553891997965 | Loss: 0.01143832691013813\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8da809ca9ee9444f87344ee092618358", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 18.683346398615083 | MAE: 15.279302381427371 | CSI: 0.8178033022254128 | Loss: 0.011568550951778889\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f850472f38c14a5db4959df141dbe859", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 19.403732510254446 | MAE: 15.866799598526859 | CSI: 0.8177189409368636 | Loss: 0.011433429084718227\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "57a156dfab894249b5f6e6d972afc0f0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 18.861351107256343 | MAE: 15.467400527873899 | CSI: 0.8200579290369298 | Loss: 0.011417574249207973\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "275cbf5881ac4340913d46a69d6a3cd7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 20.403804237083232 | MAE: 16.587705867603038 | CSI: 0.8129712319742333 | Loss: 0.011530996300280094\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9ee8f92c20014609943c833a22917bf9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 18.508092655544406 | MAE: 15.191096051399349 | CSI: 0.8207812838472092 | Loss: 0.011424791999161243\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0896606bc2d34a11afb711583c9d93e1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 18.44665647834475 | MAE: 15.127479769290078 | CSI: 0.8200662156326471 | Loss: 0.011457541026175022\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d2cacea574c14483914d2348191e01a7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 18.96570753645215 | MAE: 15.53061988275734 | CSI: 0.8188790137614679 | Loss: 0.011679432354867458\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e4f80b1db2e1480b846cc0d1a16d2145", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 18.822043787124592 | MAE: 15.420485982148985 | CSI: 0.8192779783393502 | Loss: 0.011392543092370033\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb9947c1fcb64350b6d0b60ef911121b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 18.606130005732833 | MAE: 15.197414212351436 | CSI: 0.8167960885821111 | Loss: 0.01142452098429203\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a96d1ab45ff4d289137cec22137db89", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 18.9835783172211 | MAE: 15.499320162112285 | CSI: 0.8164593578247035 | Loss: 0.011478066444396973\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "50d37fa7501b41ddaa7db99ad4bf9333", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 19.555789183906093 | MAE: 15.923755880672903 | CSI: 0.8142732431252728 | Loss: 0.01156239677220583\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "077ca7638bbd458690edb7ed9a90cd97", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 18.89104865410489 | MAE: 15.40757357276467 | CSI: 0.8156018151696319 | Loss: 0.011537961661815643\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c290c6b11d734153ac8a1a9f48004f93", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 18.59823612059139 | MAE: 15.182996494158246 | CSI: 0.816367552045944 | Loss: 0.011565683409571648\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5b52ca41066c41aaa98c9f2d1bc789dc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 19.234773276351596 | MAE: 15.645243910983037 | CSI: 0.8133833285261033 | Loss: 0.011604744009673595\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8dc8dee56c1a4d83a52653d57b7d466d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 18.789052965351058 | MAE: 15.296363151240856 | CSI: 0.8141103854159191 | Loss: 0.01161937415599823\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "475a0e05c6a142e0bd4792de41e7ba2e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 18.882758069128442 | MAE: 15.359764472001991 | CSI: 0.8134280180761781 | Loss: 0.011645403690636158\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5d0dcfc251c642f9b1c59dd535cb9984", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 18.913308360419204 | MAE: 15.388432096423116 | CSI: 0.8136298421807747 | Loss: 0.011661795899271965\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e6ddf2f7232741f9a3e1d076aff25238", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 18.950830060861367 | MAE: 15.392199249279548 | CSI: 0.8122176812217681 | Loss: 0.011679957620799541\n", + "\n", + "Model saved at ../models/unet_fold2_bs256_epochs30_lr0.001_adamw_cosine.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | tail | BasicBlock | 300 \n", + "2 | encoder | Encoder | 25 M \n", + "3 | decoder | Decoder | 17 M \n", + "4 | head | Sequential | 8 K \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 3...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ad83a026b5e544bf8e4d9fa29eb2b718", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 110557260156424.86 | MAE: 110.55726015642486 | CSI: 0.0 | Loss: 0.05172676593065262\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8af8c9128df944fb9cf91293421f8331", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "95ca18fc0bac4fee85fe76284d969c53", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 53.56887701828267 | MAE: 35.32899085851282 | CSI: 0.6595059076262084 | Loss: 0.01728362962603569\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9e4cded4871444eebff9948e284b2efe", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 30.639784592662963 | MAE: 23.430682562051953 | CSI: 0.7647143370463528 | Loss: 0.014653380028903484\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "553f0ab5c3924ae49cd47cbc05afb482", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 24.957249398984533 | MAE: 19.493277948112066 | CSI: 0.7810667608618863 | Loss: 0.013412871398031712\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b66782cf56f540179789b8338ab5add8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 22.953791365850908 | MAE: 18.058455785756223 | CSI: 0.7867308497279196 | Loss: 0.013156878761947155\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4338523873ab430b8af44a2c03af89b2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 23.78838192093069 | MAE: 18.79893090212847 | CSI: 0.7902568137921168 | Loss: 0.012770530767738819\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "112cfcb535c247c8a802d637cd4ea78e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 24.134156968826158 | MAE: 19.308000326101173 | CSI: 0.8000279583420703 | Loss: 0.013565506786108017\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "459e7710db3f4715a22d2109736243d5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 21.513528786478364 | MAE: 17.24819258159191 | CSI: 0.8017370256994376 | Loss: 0.012290913611650467\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "162357f4824543e993f6c37e3a001450", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 21.362476504765034 | MAE: 17.063856162282722 | CSI: 0.7987770593196514 | Loss: 0.01256471686065197\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f2335250002d46c79c67a2917e2752f6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 24.453781933161846 | MAE: 19.512796349196865 | CSI: 0.7979459538203802 | Loss: 0.012635443359613419\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "166c60d14bf04ec5a3a2886cf4736daf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 20.900274204023304 | MAE: 16.887015755273936 | CSI: 0.8079805839103433 | Loss: 0.0120097566395998\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "81464fc773bf4bc6891968689cce104a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 20.906133797377368 | MAE: 16.826452392597048 | CSI: 0.8048572039028559 | Loss: 0.012089049443602562\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a10d960e744b41c7bb5eb130d315edef", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 19.764185822147727 | MAE: 16.106189117183263 | CSI: 0.8149179157744468 | Loss: 0.011762270703911781\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7d0369708b3b438893095d9962042ef8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 19.216470309349777 | MAE: 15.633324322908592 | CSI: 0.8135377658446973 | Loss: 0.011789782904088497\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7f3c1002588b4ca3836d07279147d918", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 21.358521733793662 | MAE: 17.035517334854887 | CSI: 0.7975981459372147 | Loss: 0.012352569960057735\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "75d2d7bf01e04d9599bb623b20280d43", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 18.955754078822256 | MAE: 15.441341389797334 | CSI: 0.8145991621103458 | Loss: 0.011553915217518806\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5a51217b899446be8c22bfb4d1f8b382", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 18.41964258573334 | MAE: 15.062246454066061 | CSI: 0.8177274007321881 | Loss: 0.011606993153691292\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e3d6b58fa0594567bd4a44fbd7425187", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 18.945060050135833 | MAE: 15.475827797763376 | CSI: 0.8168793214056597 | Loss: 0.011503880843520164\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "800287eb1fd34ea3a96cd66956da367c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 19.05648451538784 | MAE: 15.521529688907957 | CSI: 0.814501209620037 | Loss: 0.011547097936272621\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5fcd702c53704783b869023d0f0c6070", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 19.16718569318381 | MAE: 15.615780001194492 | CSI: 0.8147142857142857 | Loss: 0.011518046259880066\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "281f546726974b9abdd5277c2c427677", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 18.73081169605486 | MAE: 15.293072094853354 | CSI: 0.8164660636706788 | Loss: 0.01148569118231535\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bb25d9a424b6461bb8f7c6c29e3a854c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 19.141530793836786 | MAE: 15.653774430905004 | CSI: 0.817791147400086 | Loss: 0.011520149186253548\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "168fd5a8a9684f0fac0a04a354d50728", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 18.813476282761126 | MAE: 15.291689353583903 | CSI: 0.8128050937389458 | Loss: 0.011663438752293587\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f50168362a874e738ae34ce5281ce9d6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 18.42512713407604 | MAE: 14.991494951336136 | CSI: 0.813644043931287 | Loss: 0.011668611317873001\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "36336bf5edf842cea7e3506996a143a8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 18.5111411504437 | MAE: 15.066295467950644 | CSI: 0.8139041966935142 | Loss: 0.011603855527937412\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8baab25eb41943aa9ca865724504293a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 18.97355917639315 | MAE: 15.4504381284336 | CSI: 0.814314171883893 | Loss: 0.011588823981583118\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "86f8c11ff5724d3aa431311aa061b2e2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 18.991307442213472 | MAE: 15.43785156363545 | CSI: 0.8128904031800114 | Loss: 0.011663117446005344\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9dfb8cf47bc74b599cccede10f6d90d1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 18.63918385797683 | MAE: 15.163947710759158 | CSI: 0.8135521290140048 | Loss: 0.01164956297725439\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "213217b29be249778d043162380e16dd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 18.906251519195056 | MAE: 15.356309645176488 | CSI: 0.8122344944774851 | Loss: 0.01169898733496666\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fc96ce3b143d4abca23182942be61809", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 19.010790761517644 | MAE: 15.43363629861466 | CSI: 0.8118355776045358 | Loss: 0.011704476550221443\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "133a44e7b2344fbdb2a4365a21382447", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 18.91723245279753 | MAE: 15.368827583863563 | CSI: 0.8124247361337394 | Loss: 0.011702695861458778\n", + "\n", + "Model saved at ../models/unet_fold3_bs256_epochs30_lr0.001_adamw_cosine.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | tail | BasicBlock | 300 \n", + "2 | encoder | Encoder | 25 M \n", + "3 | decoder | Decoder | 17 M \n", + "4 | head | Sequential | 8 K \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 4...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6f1f1cf42cf41348e7a58e78981023d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 110804066395808.14 | MAE: 110.80406639580814 | CSI: 0.0 | Loss: 0.046886492520570755\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "722c1610ece8437b94239fb0e476cb00", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "837aadf0ed7544b2b2497fa1263209f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 29.46492178703207 | MAE: 22.387287602854418 | CSI: 0.7597945707997066 | Loss: 0.014288208447396755\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5bb9bc593d75424d8b392255d6f2209d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 24.204093119843044 | MAE: 19.084029852953837 | CSI: 0.7884629165173773 | Loss: 0.013465666212141514\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6132ce6528d74d7cb47eebfce33b4439", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 23.3908121436228 | MAE: 18.574004691365896 | CSI: 0.7940726716667881 | Loss: 0.0127839595079422\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "baa6e45053cb4ce8bf20c56c8760d998", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 21.65168614373402 | MAE: 17.37842964056247 | CSI: 0.8026363177987467 | Loss: 0.012596615590155125\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0c5949bc6ac049439eaa8a630a88e230", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 21.677014943572697 | MAE: 17.44700715090798 | CSI: 0.8048620714753622 | Loss: 0.012286808341741562\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ff6ad758b4a64e0caa31f86c05b23690", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 21.918280019408275 | MAE: 17.60216383659147 | CSI: 0.803081438004402 | Loss: 0.012095422483980656\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2927897dd5ee475da777d5d252c03fd3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 20.329710330708902 | MAE: 16.422185408844385 | CSI: 0.8077923955472025 | Loss: 0.012070290744304657\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "af8baae7f3924bee9614d53f2b210df0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 21.331304906024766 | MAE: 17.233019202263723 | CSI: 0.8078745898651112 | Loss: 0.011999445036053658\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bbbaaf899dca4ddcb7e6a6600889a11c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 20.6731214146294 | MAE: 16.683493317061913 | CSI: 0.8070137538705264 | Loss: 0.012145860120654106\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "404c5111ad7c45cca38624ae968dcb3b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 19.466089391431066 | MAE: 15.881236795690588 | CSI: 0.8158411520837879 | Loss: 0.011607903987169266\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b18785cc75f44ad5ac35439416075c2c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 19.81252870088382 | MAE: 16.179648879536806 | CSI: 0.8166372462488968 | Loss: 0.011471263132989407\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "09c94f8fff7c448e87cfd9896cc46913", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 | MAE/CSI: 19.550455453315504 | MAE: 15.932699609748498 | CSI: 0.8149528612146459 | Loss: 0.011466726660728455\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "11b1a04b097245c791ac9639a8c30b67", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 | MAE/CSI: 18.289922608388945 | MAE: 15.044766262064371 | CSI: 0.8225713462092822 | Loss: 0.01155020110309124\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6989146b42de453c96c75c646a6f88ce", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 | MAE/CSI: 19.916873297085072 | MAE: 16.239037439054368 | CSI: 0.8153407011637268 | Loss: 0.011535759083926678\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a00c71eebe2742078f40673906189ff6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 | MAE/CSI: 19.381075628919593 | MAE: 15.88246965193991 | CSI: 0.8194833948339484 | Loss: 0.011257984675467014\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "019efaa949604264b6090e3970c44484", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 | MAE/CSI: 20.13362559236226 | MAE: 16.44119687934336 | CSI: 0.8166038850727528 | Loss: 0.01147684920579195\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2a79d9eeae8b43d59d3275b07882e4a0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 | MAE/CSI: 17.862983125481282 | MAE: 14.748776933633795 | CSI: 0.8256614715476622 | Loss: 0.01122660469263792\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f9cf658173894379b090c7cd764b42c3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 | MAE/CSI: 17.659409365863613 | MAE: 14.61047819083576 | CSI: 0.8273480662983426 | Loss: 0.011134890839457512\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a8a4100fad5b46c599684dbdb0455b15", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 | MAE/CSI: 17.875598933953043 | MAE: 14.754400889379596 | CSI: 0.8253933724893047 | Loss: 0.011305585503578186\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d7d2720c11be4081b3bcde968f9dc88c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 | MAE/CSI: 18.201254790410943 | MAE: 15.010341899683212 | CSI: 0.8246872027511524 | Loss: 0.011105424724519253\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "92dda2ccd1904847bc25407145721f94", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 | MAE/CSI: 18.577237816975604 | MAE: 15.28320825097125 | CSI: 0.8226846424384525 | Loss: 0.011204416863620281\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d0ed5dd921164c1790dc539914aaecc7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 | MAE/CSI: 18.38351714846262 | MAE: 15.161192424176441 | CSI: 0.8247166361974406 | Loss: 0.011202105320990086\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8bbf85b4cda94a999f5f272b69d77c8b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 | MAE/CSI: 18.325339605405098 | MAE: 15.064050353614043 | CSI: 0.8220338983050848 | Loss: 0.011181995272636414\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fb1d2d4e188c4e8d85bec6f98d921b8e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 | MAE/CSI: 17.875720986784252 | MAE: 14.731736891221365 | CSI: 0.8241198719813791 | Loss: 0.011183995753526688\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ae0e0a326ff846ebad65bef6cbc2ee95", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 | MAE/CSI: 18.44947563426877 | MAE: 15.130209495701461 | CSI: 0.8200888629907495 | Loss: 0.011271136812865734\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "715961f00ba642a6beb071797fe3a3ed", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 | MAE/CSI: 18.53381123761257 | MAE: 15.19639837323524 | CSI: 0.8199284096719994 | Loss: 0.011240239255130291\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9f3d5de73dcf41f7813996af558f10dd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 | MAE/CSI: 18.0624360317926 | MAE: 14.856623087439964 | CSI: 0.8225149177703391 | Loss: 0.01123881060630083\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1b33f4a82abb4b8389120ff2f915bc46", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 | MAE/CSI: 18.33757410027865 | MAE: 15.043335066826108 | CSI: 0.8203557888597258 | Loss: 0.0112459110096097\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "08e80f7d937847a392118aa90a2bdfdd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 | MAE/CSI: 18.386441478880272 | MAE: 15.091226520452597 | CSI: 0.8207801676995989 | Loss: 0.011267893016338348\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6d3c00256761455285174527ddd21387", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 | MAE/CSI: 18.40432508550994 | MAE: 15.094416588039563 | CSI: 0.8201559425781535 | Loss: 0.011272534728050232\n", + "\n", + "Model saved at ../models/unet_fold4_bs256_epochs30_lr0.001_adamw_cosine.ckpt\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 16;\n", + " var nbb_unformatted_code = \"# AdamW bs256 lr 1e-3\\nfor fold in range(5):\\n train_fold(df, fold)\";\n", + " var nbb_formatted_code = \"# AdamW bs256 lr 1e-3\\nfor fold in range(5):\\n train_fold(df, fold)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -1732,18 +4476,395 @@ " " ], "text/plain": [ - "" + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# AdamW bs256 lr 1e-3\n", + "for fold in range(5):\n", + " train_fold(df, fold)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Using native 16bit precision.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training fold 0...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | tail | BasicBlock | 300 \n", + "2 | encoder | Encoder | 25 M \n", + "3 | decoder | Decoder | 17 M \n", + "4 | head | Sequential | 8 K \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "885a4ed50e7d41ab823130597b98959e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 167.45782510687246 | MAE: 27.286654537671232 | CSI: 0.16294642857142858 | Loss: 0.4579598009586334\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "33c71bafc60247cbbb413df289864066", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b810e2381d1c4e6095e01e694b20886e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 28.362450125306825 | MAE: 22.10657577075336 | CSI: 0.7794311024984428 | Loss: 0.020207127556204796\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a48d157eba4049c09d92b2b1a7c7c85d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 23.82231910706382 | MAE: 18.83105605581214 | CSI: 0.7904795486600846 | Loss: 0.014244887046515942\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "26f724d76a2d4523aad75bdf1784305c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 23.14310842963703 | MAE: 18.453300962246768 | CSI: 0.7973561986423723 | Loss: 0.013162474147975445\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d72b4eb2a7d14be4b0d829b07c1fa356", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 22.18791287315471 | MAE: 17.807568912908806 | CSI: 0.8025797205302759 | Loss: 0.012457935139536858\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "06588808b6a44f629791f023be14b748", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 23.763760209997628 | MAE: 19.19425320030403 | CSI: 0.8077111126632674 | Loss: 0.013415777124464512\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "701177f951c548a2957209fe134f763a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 21.892119289753165 | MAE: 17.683018462607592 | CSI: 0.8077344284736482 | Loss: 0.013465486466884613\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4edfbe5a786b42d89ea66ab63c6a132e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 22.04263144863797 | MAE: 17.724101042939697 | CSI: 0.8040828103585083 | Loss: 0.01200629211962223\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6abc8d9668884244a071939dcc15c42f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 21.346540554610538 | MAE: 17.230123555188904 | CSI: 0.8071623367303 | Loss: 0.011888713575899601\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "29dd3c71466e4c44933785c50306f362", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 19.32465185410075 | MAE: 15.690600147796456 | CSI: 0.8119473647566664 | Loss: 0.01206459105014801\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b66ddc67851047fdbcaf7d6df8d7ce3d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 20.710088985666317 | MAE: 16.81065647000772 | CSI: 0.8117133867276888 | Loss: 0.011814854107797146\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "46727dee487c46c79c0b4448d9487023", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 | MAE/CSI: 24.743819129015346 | MAE: 19.77880616991625 | CSI: 0.7993433053630062 | Loss: 0.012558660469949245\n" + ] } ], "source": [ - "subm.to_csv(\n", - " f\"unet_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.csv\",\n", - " index=False,\n", - ")\n", + "# AdamW bs256 lr 1e-3 sigmoid\n", + "for fold in range(5):\n", + " train_fold(df, fold)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def inference(checkpoints):\n", + " datamodule = NowcastingDataModule()\n", + " datamodule.setup(\"test\")\n", + " \n", + " test_paths = datamodule.test_dataset.paths\n", + " test_filenames = [path.name for path in test_paths]\n", + " final_preds = np.zeros((len(datamodule.test_dataset), 14400))\n", + " \n", + " for checkpoint in checkpoints:\n", + " print(f\"Inference from {checkpoint}\")\n", + " model = UNet.load_from_checkpoint(str(checkpoint))\n", + " model.cuda()\n", + " model.eval()\n", + " preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(datamodule.test_dataloader()):\n", + " batch = batch.cuda()\n", + " imgs = model(batch)\n", + " imgs = imgs.detach().cpu().numpy()\n", + " imgs = imgs[:, 0, 4:124, 4:124]\n", + " imgs = args[\"rng\"] * imgs\n", + " imgs = imgs.clip(0, 255)\n", + " imgs = imgs.round()\n", + " preds.append(imgs)\n", + " \n", + " preds = np.concatenate(preds)\n", + " preds = preds.astype(np.uint8)\n", + " preds = preds.reshape(-1, 14400)\n", + " final_preds += preds / len(checkpoint)\n", + " \n", + " del model\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + " \n", + " final_preds = final_preds.round()\n", + " final_preds = final_preds.astype(np.uint8)\n", + " \n", + " subm = pd.DataFrame()\n", + " subm[\"file_name\"] = test_filename\n", + " for i in tqdm(range(14400)):\n", + " subm[str(i)] = final_preds[:, i]\n", + " \n", + " return subm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "checkpoints = [args[\"model_dir\"] / f\"unet_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt\" for fold in range(5)]\n", + "output_path = args[\"output_dir\"] / f\"unet_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.csv\"\n", + "subm.to_csv(output_path, index=False)\n", "subm.head()" ] }, @@ -1803,6 +4924,20 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -1813,9 +4948,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:torch2] *", + "display_name": "Python [conda env:torch] *", "language": "python", - "name": "conda-env-torch2-py" + "name": "conda-env-torch-py" }, "language_info": { "codemirror_mode": { @@ -1827,7 +4962,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.6" + "version": "3.7.8" } }, "nbformat": 4,