From edee80bc642d60304f43c12a26c782681f6838ca Mon Sep 17 00:00:00 2001 From: Giwan Date: Sat, 7 Nov 2020 09:31:30 +0900 Subject: [PATCH] Modify unet notebooks --- notebooks/01-baseline-pytorch.ipynb | 42 +- notebooks/02-rainnet.ipynb | 781 ++++--- notebooks/03-unet.ipynb | 2923 +++++++-------------------- src/utils.py | 5 + 4 files changed, 1150 insertions(+), 2601 deletions(-) diff --git a/notebooks/01-baseline-pytorch.ipynb b/notebooks/01-baseline-pytorch.ipynb index a4b8dab..2d33622 100644 --- a/notebooks/01-baseline-pytorch.ipynb +++ b/notebooks/01-baseline-pytorch.ipynb @@ -6,6 +6,18 @@ "metadata": {}, "outputs": [], "source": [ + "import sys\n", + "sys.path.insert(0, \"../src\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2\n", "%reload_ext nb_black" ] }, @@ -141,30 +153,6 @@ " )" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def visualize(x, y=None, test=False):\n", - " cmap = plt.cm.get_cmap(\"RdBu\")\n", - " cmap = cmap.reversed()\n", - " if test:\n", - " fig, axes = plt.subplots(1, 4, figsize=(10, 10))\n", - " for i, ax in enumerate(axes):\n", - " img = x[:, :, i]\n", - " ax.imshow(img, cmap=cmap)\n", - " else:\n", - " fig, axes = plt.subplots(1, 5, figsize=(10, 10))\n", - " for i, ax in enumerate(axes[:-1]):\n", - " img = x[:, :, i]\n", - " ax.imshow(img, cmap=cmap)\n", - " axes[-1].imshow(y[:, :, 0], cmap=cmap)\n", - " # plt.tight_layout()\n", - " plt.show()" - ] - }, { "cell_type": "code", "execution_count": null, @@ -400,7 +388,7 @@ "outputs": [], "source": [ "model = Baseline.load_from_checkpoint(\"baseline_bs256_epoch10.ckpt\")\n", - "datamodule = NowcastingDataModule(batch_size=128)\n", + "datamodule = NowcastingDataModule(batch_size=256)\n", "datamodule.setup(\"test\")" ] }, @@ -490,7 +478,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:torch] *", + "display_name": "Python [conda env:torch]", "language": "python", "name": "conda-env-torch-py" }, @@ -504,7 +492,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.7.8" } }, "nbformat": 4, diff --git a/notebooks/02-rainnet.ipynb b/notebooks/02-rainnet.ipynb index 7db15ee..21c7e55 100644 --- a/notebooks/02-rainnet.ipynb +++ b/notebooks/02-rainnet.ipynb @@ -60,8 +60,8 @@ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 3;\n", - " var nbb_unformatted_code = \"import gc\\nimport functools\\nfrom pathlib import Path\\nfrom concurrent.futures import ThreadPoolExecutor\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nimport pytorch_lightning as pl\\nfrom torch.utils.data import SequentialSampler, RandomSampler\\n\\nimport optim\\nfrom data import NowcastingDataset\\nfrom loss import LogCoshLoss\\nfrom utils import visualize, radar2precipitation\";\n", - " var nbb_formatted_code = \"import gc\\nimport functools\\nfrom pathlib import Path\\nfrom concurrent.futures import ThreadPoolExecutor\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nimport pytorch_lightning as pl\\nfrom torch.utils.data import SequentialSampler, RandomSampler\\n\\nimport optim\\nfrom data import NowcastingDataset\\nfrom loss import LogCoshLoss\\nfrom utils import visualize, radar2precipitation\";\n", + " var nbb_unformatted_code = \"import gc\\nimport functools\\nfrom pathlib import Path\\nfrom concurrent.futures import ThreadPoolExecutor\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nimport pytorch_lightning as pl\\nfrom torch.utils.data import SequentialSampler, RandomSampler\\n\\nimport transformers\\n\\nimport optim\\nfrom data import NowcastingDataset\\nfrom loss import LogCoshLoss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_formatted_code = \"import gc\\nimport functools\\nfrom pathlib import Path\\nfrom concurrent.futures import ThreadPoolExecutor\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nimport pytorch_lightning as pl\\nfrom torch.utils.data import SequentialSampler, RandomSampler\\n\\nimport transformers\\n\\nimport optim\\nfrom data import NowcastingDataset\\nfrom loss import LogCoshLoss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -102,15 +102,17 @@ "import pytorch_lightning as pl\n", "from torch.utils.data import SequentialSampler, RandomSampler\n", "\n", + "import transformers\n", + "\n", "import optim\n", "from data import NowcastingDataset\n", "from loss import LogCoshLoss\n", - "from utils import visualize, radar2precipitation" + "from utils import visualize, radar2precipitation, seed_everything" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -118,9 +120,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 4;\n", - " var nbb_unformatted_code = \"args = dict(\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=64,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n gradient_accumulation_steps=1,\\n)\";\n", - " var nbb_formatted_code = \"args = dict(\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=64,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n gradient_accumulation_steps=1,\\n)\";\n", + " var nbb_cell_id = 8;\n", + " var nbb_unformatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=64,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=2,\\n gradient_clip_val=5.0,\\n rng=255.0,\\n)\";\n", + " var nbb_formatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=64,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=2,\\n gradient_clip_val=5.0,\\n rng=255.0,\\n)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -143,6 +145,7 @@ ], "source": [ "args = dict(\n", + " seed=42,\n", " dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\n", " train_folds_csv=Path(\"../input/train_folds.csv\"),\n", " train_data_path=Path(\"../input/train-128\"),\n", @@ -155,7 +158,9 @@ " precision=16,\n", " optimizer=\"adamw\",\n", " scheduler=\"cosine\",\n", - " gradient_accumulation_steps=1,\n", + " accumulate_grad_batches=2,\n", + " gradient_clip_val=5.0,\n", + " rng=255.0,\n", ")" ] }, @@ -177,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": { "hidden": true }, @@ -187,7 +192,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 7;\n", + " var nbb_cell_id = 5;\n", " var nbb_unformatted_code = \"def resize_data(path, folder=\\\"train-128\\\"):\\n data = np.load(path)\\n img1 = data[:, :, :3]\\n img2 = data[:, :, 2:]\\n img1 = cv2.copyMakeBorder(img1, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = cv2.copyMakeBorder(img2, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = img2[:, :, 1:]\\n data = np.concatenate([img1, img2], axis=-1)\\n np.save(PATH / folder / path.name, data)\";\n", " var nbb_formatted_code = \"def resize_data(path, folder=\\\"train-128\\\"):\\n data = np.load(path)\\n img1 = data[:, :, :3]\\n img2 = data[:, :, 2:]\\n img1 = cv2.copyMakeBorder(img1, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = cv2.copyMakeBorder(img2, 4, 4, 4, 4, cv2.BORDER_REFLECT)\\n img2 = img2[:, :, 1:]\\n data = np.concatenate([img1, img2], axis=-1)\\n np.save(PATH / folder / path.name, data)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -224,17 +229,28 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "hidden": true }, "outputs": [ + { + "ename": "NameError", + "evalue": "name 'PATH' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0mPATH\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;34m\"train-128\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmkdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexist_ok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'PATH' is not defined" + ] + }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 5;\n", + " var nbb_cell_id = 6;\n", " var nbb_unformatted_code = \"(PATH / \\\"train-128\\\").mkdir(exist_ok=True)\";\n", " var nbb_formatted_code = \"(PATH / \\\"train-128\\\").mkdir(exist_ok=True)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -391,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -399,9 +415,72 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 5;\n", - " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self, train_df, val_df, batch_size=args[\\\"batch_size\\\"], num_workers=args[\\\"num_workers\\\"]\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\"))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", - " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df,\\n val_df,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\"))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_cell_id = 9;\n", + " var nbb_unformatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y\";\n", + " var nbb_formatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class NowcastingDataset(torch.utils.data.Dataset):\n", + " def __init__(self, paths, test=False):\n", + " self.paths = paths\n", + " self.test = test\n", + "\n", + " def __len__(self):\n", + " return len(self.paths)\n", + "\n", + " def __getitem__(self, idx):\n", + " path = self.paths[idx]\n", + " data = np.load(path)\n", + " x = data[:, :, :4]\n", + " x = x / args[\"rng\"]\n", + " x = x.astype(np.float32)\n", + " x = torch.tensor(x, dtype=torch.float)\n", + " x = x.permute(2, 0, 1)\n", + " if self.test:\n", + " return x\n", + " else:\n", + " y = data[:, :, 4]\n", + " y = y / args[\"rng\"]\n", + " y = y.astype(np.float32)\n", + " y = torch.tensor(y, dtype=torch.float)\n", + " y = y.unsqueeze(-1)\n", + " y = y.permute(2, 0, 1)\n", + "\n", + " return x, y" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 10;\n", + " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\"))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\"))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -425,7 +504,11 @@ "source": [ "class NowcastingDataModule(pl.LightningDataModule):\n", " def __init__(\n", - " self, train_df, val_df, batch_size=args[\"batch_size\"], num_workers=args[\"num_workers\"]\n", + " self,\n", + " train_df=None,\n", + " val_df=None,\n", + " batch_size=args[\"batch_size\"],\n", + " num_workers=args[\"num_workers\"],\n", " ):\n", " super().__init__()\n", " self.train_df = train_df\n", @@ -438,7 +521,9 @@ " train_paths = [\n", " args[\"train_data_path\"] / fn for fn in self.train_df.filename.values\n", " ]\n", - " val_paths = [args[\"train_data_path\"] / fn for fn in self.val_df.filename.values]\n", + " val_paths = [\n", + " args[\"train_data_path\"] / fn for fn in self.val_df.filename.values\n", + " ]\n", " self.train_dataset = NowcastingDataset(train_paths)\n", " self.val_dataset = NowcastingDataset(val_paths)\n", " else:\n", @@ -476,7 +561,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 12, "metadata": { "scrolled": false }, @@ -486,7 +571,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 6;\n", + " var nbb_cell_id = 12;\n", " var nbb_unformatted_code = \"# df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n# datamodule = NowcastingDataModule(df, fold=0, batch_size=2)\\n# datamodule.setup()\\n# for batch in datamodule.train_dataloader():\\n# xs, ys = batch\\n# idx = np.random.randint(len(xs))\\n# x, y = xs[idx], ys[idx]\\n# x = x.permute(1, 2, 0).numpy()\\n# y = y.permute(1, 2, 0).numpy()\\n# visualize(x, y)\\n# break\";\n", " var nbb_formatted_code = \"# df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n# datamodule = NowcastingDataModule(df, fold=0, batch_size=2)\\n# datamodule.setup()\\n# for batch in datamodule.train_dataloader():\\n# xs, ys = batch\\n# idx = np.random.randint(len(xs))\\n# x, y = xs[idx], ys[idx]\\n# x = x.permute(1, 2, 0).numpy()\\n# y = y.permute(1, 2, 0).numpy()\\n# visualize(x, y)\\n# break\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -539,7 +624,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -547,9 +632,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 10;\n", - " var nbb_unformatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024], drop_rate=0.5):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n self.dropout = nn.Dropout(p=drop_rate)\\n\\n def forward(self, x):\\n ftrs = []\\n for i, block in enumerate(self.blocks):\\n x = block(x)\\n ftrs.append(x)\\n if i >= 3:\\n x = self.dropout(x)\\n if i < 4:\\n x = self.pool(x)\\n return ftrs\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.chs = chs\\n# self.ups = nn.ModuleList(\\n# [nn.Upsample(scale_factor=2, mode=\\\"nearest\\\") for i in range(len(chs) - 1)]\\n# )\\n self.ups = nn.ModuleList(\\n [nn.ConvTranspose2d(chs[i], chs[i+1], kernel_size=2, stride=2) for i in range(len(chs) - 1)]\\n )\\n self.convs = nn.ModuleList(\\n [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, ftrs):\\n for i in range(len(self.chs) - 1):\\n x = self.ups[i](x)\\n x = torch.cat([ftrs[i], x], dim=1)\\n x = self.convs[i](x)\\n return x\";\n", - " var nbb_formatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024], drop_rate=0.5):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n self.dropout = nn.Dropout(p=drop_rate)\\n\\n def forward(self, x):\\n ftrs = []\\n for i, block in enumerate(self.blocks):\\n x = block(x)\\n ftrs.append(x)\\n if i >= 3:\\n x = self.dropout(x)\\n if i < 4:\\n x = self.pool(x)\\n return ftrs\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.chs = chs\\n # self.ups = nn.ModuleList(\\n # [nn.Upsample(scale_factor=2, mode=\\\"nearest\\\") for i in range(len(chs) - 1)]\\n # )\\n self.ups = nn.ModuleList(\\n [\\n nn.ConvTranspose2d(chs[i], chs[i + 1], kernel_size=2, stride=2)\\n for i in range(len(chs) - 1)\\n ]\\n )\\n self.convs = nn.ModuleList(\\n [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, ftrs):\\n for i in range(len(self.chs) - 1):\\n x = self.ups[i](x)\\n x = torch.cat([ftrs[i], x], dim=1)\\n x = self.convs[i](x)\\n return x\";\n", + " var nbb_cell_id = 11;\n", + " var nbb_unformatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024], drop_rate=0.5):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n self.dropout = nn.Dropout(p=drop_rate)\\n\\n def forward(self, x):\\n ftrs = []\\n for i, block in enumerate(self.blocks):\\n x = block(x)\\n ftrs.append(x)\\n if i >= 3:\\n x = self.dropout(x)\\n if i < 4:\\n x = self.pool(x)\\n return ftrs\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.chs = chs\\n self.ups = nn.ModuleList(\\n [nn.Upsample(scale_factor=2, mode=\\\"nearest\\\") for i in range(len(chs) - 1)]\\n )\\n self.convs = nn.ModuleList(\\n [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, ftrs):\\n for i in range(len(self.chs) - 1):\\n x = self.ups[i](x)\\n x = torch.cat([ftrs[i], x], dim=1)\\n x = self.convs[i](x)\\n return x\";\n", + " var nbb_formatted_code = \"class Block(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(out_ch),\\n )\\n\\n def forward(self, x):\\n return self.net(x)\\n\\n\\nclass Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024], drop_rate=0.5):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\\n self.dropout = nn.Dropout(p=drop_rate)\\n\\n def forward(self, x):\\n ftrs = []\\n for i, block in enumerate(self.blocks):\\n x = block(x)\\n ftrs.append(x)\\n if i >= 3:\\n x = self.dropout(x)\\n if i < 4:\\n x = self.pool(x)\\n return ftrs\\n\\n\\nclass Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.chs = chs\\n self.ups = nn.ModuleList(\\n [nn.Upsample(scale_factor=2, mode=\\\"nearest\\\") for i in range(len(chs) - 1)]\\n )\\n self.convs = nn.ModuleList(\\n [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, ftrs):\\n for i in range(len(self.chs) - 1):\\n x = self.ups[i](x)\\n x = torch.cat([ftrs[i], x], dim=1)\\n x = self.convs[i](x)\\n return x\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -612,14 +697,8 @@ " def __init__(self, chs=[1024, 512, 256, 128, 64]):\n", " super().__init__()\n", " self.chs = chs\n", - " # self.ups = nn.ModuleList(\n", - " # [nn.Upsample(scale_factor=2, mode=\"nearest\") for i in range(len(chs) - 1)]\n", - " # )\n", " self.ups = nn.ModuleList(\n", - " [\n", - " nn.ConvTranspose2d(chs[i], chs[i + 1], kernel_size=2, stride=2)\n", - " for i in range(len(chs) - 1)\n", - " ]\n", + " [nn.Upsample(scale_factor=2, mode=\"nearest\") for i in range(len(chs) - 1)]\n", " )\n", " self.convs = nn.ModuleList(\n", " [Block(chs[i] + chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]\n", @@ -642,7 +721,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -650,9 +729,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 11;\n", - " var nbb_unformatted_code = \"class RainNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=1e-4,\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n\\n # Parameters\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n\\n # self.criterion = LogCoshLoss()\\n self.criterion = nn.L1Loss()\\n\\n # Layers\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 2, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(2),\\n nn.Conv2d(2, 1, kernel_size=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = ftrs[::-1]\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n tfms = nn.Sequential(\\n T.CenterCrop(120),\\n )\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = tfms(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = tfms(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n y = 255.0 * y[:, args[\\\"dams\\\"]]\\n y = np.round(y).clip(0, 255)\\n y_hat = 255.0 * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = np.round(y_hat).clip(0, 255)\\n # mae = metrics.mean_absolute_error(y, y_hat)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y = y * y_true\\n y_hat = y_hat * y_true\\n# mae = np.abs(y - y_hat).sum() / y_true.sum()\\n mae = np.abs(y - y_hat).mean()\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.reshape(-1), y_pred.reshape(-1)\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n\\n comp_metric = mae / (csi + 1e-12)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n optimizer, T_max=self.num_train_steps\\n )\\n return [optimizer], [{\\\"scheduler\\\": scheduler, \\\"interval\\\": \\\"step\\\"}]\";\n", - " var nbb_formatted_code = \"class RainNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=1e-4,\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n\\n # Parameters\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n\\n # self.criterion = LogCoshLoss()\\n self.criterion = nn.L1Loss()\\n\\n # Layers\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 2, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(2),\\n nn.Conv2d(2, 1, kernel_size=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = ftrs[::-1]\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n tfms = nn.Sequential(\\n T.CenterCrop(120),\\n )\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = tfms(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = tfms(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n y = 255.0 * y[:, args[\\\"dams\\\"]]\\n y = np.round(y).clip(0, 255)\\n y_hat = 255.0 * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = np.round(y_hat).clip(0, 255)\\n # mae = metrics.mean_absolute_error(y, y_hat)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y = y * y_true\\n y_hat = y_hat * y_true\\n # mae = np.abs(y - y_hat).sum() / y_true.sum()\\n mae = np.abs(y - y_hat).mean()\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.reshape(-1), y_pred.reshape(-1)\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n\\n comp_metric = mae / (csi + 1e-12)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n optimizer, T_max=self.num_train_steps\\n )\\n return [optimizer], [{\\\"scheduler\\\": scheduler, \\\"interval\\\": \\\"step\\\"}]\";\n", + " var nbb_cell_id = 12;\n", + " var nbb_unformatted_code = \"class RainNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=4e-4,\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n\\n # Parameters\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n\\n # self.criterion = LogCoshLoss()\\n self.criterion = nn.L1Loss()\\n\\n # Layers\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 2, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(2),\\n nn.Conv2d(2, 1, kernel_size=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = ftrs[::-1]\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n tfms = nn.Sequential(\\n T.CenterCrop(120),\\n )\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = tfms(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = tfms(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n rng = args[\\\"rng\\\"]\\n y = rng * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = rng * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n # mae = metrics.mean_absolute_error(y, y_hat)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.reshape(-1), y_pred.reshape(-1)\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n\\n comp_metric = mae / (csi + 1e-12)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n optimizer, T_max=self.num_train_steps\\n )\\n return [optimizer], [{\\\"scheduler\\\": scheduler, \\\"interval\\\": \\\"step\\\"}]\";\n", + " var nbb_formatted_code = \"class RainNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=4e-4,\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n\\n # Parameters\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n\\n # self.criterion = LogCoshLoss()\\n self.criterion = nn.L1Loss()\\n\\n # Layers\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.out = nn.Sequential(\\n nn.Conv2d(64, 2, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n nn.BatchNorm2d(2),\\n nn.Conv2d(2, 1, kernel_size=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n ftrs = self.encoder(x)\\n ftrs = ftrs[::-1]\\n x = self.decoder(ftrs[0], ftrs[1:])\\n out = self.out(x)\\n return out\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n tfms = nn.Sequential(\\n T.CenterCrop(120),\\n )\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = tfms(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = tfms(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n rng = args[\\\"rng\\\"]\\n y = rng * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = rng * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n # mae = metrics.mean_absolute_error(y, y_hat)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.reshape(-1), y_pred.reshape(-1)\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n\\n comp_metric = mae / (csi + 1e-12)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n optimizer, T_max=self.num_train_steps\\n )\\n return [optimizer], [{\\\"scheduler\\\": scheduler, \\\"interval\\\": \\\"step\\\"}]\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -677,7 +756,7 @@ "class RainNet(pl.LightningModule):\n", " def __init__(\n", " self,\n", - " lr=1e-4,\n", + " lr=3e-4,\n", " enc_chs=[4, 64, 128, 256, 512, 1024],\n", " dec_chs=[1024, 512, 256, 128, 64],\n", " num_train_steps=None,\n", @@ -689,7 +768,8 @@ " self.num_train_steps = num_train_steps\n", "\n", " # self.criterion = LogCoshLoss()\n", - " self.criterion = nn.L1Loss()\n", + "# self.criterion = nn.L1Loss()\n", + " self.criterion = nn.SmoothL1Loss()\n", "\n", " # Layers\n", " self.encoder = Encoder(enc_chs)\n", @@ -742,10 +822,11 @@ " y_hat = y_hat.detach().cpu().numpy()\n", " y_hat = y_hat.reshape(-1, 120 * 120)\n", "\n", - " y = 255.0 * y[:, args[\"dams\"]]\n", - " y = np.round(y).clip(0, 255)\n", - " y_hat = 255.0 * y_hat[:, args[\"dams\"]]\n", - " y_hat = np.round(y_hat).clip(0, 255)\n", + " rng = args[\"rng\"]\n", + " y = rng * y[:, args[\"dams\"]]\n", + " y = y.clip(0, 255)\n", + " y_hat = rng * y_hat[:, args[\"dams\"]]\n", + " y_hat = y_hat.clip(0, 255)\n", " # mae = metrics.mean_absolute_error(y, y_hat)\n", "\n", " y_true = radar2precipitation(y)\n", @@ -753,10 +834,9 @@ " y_pred = radar2precipitation(y_hat)\n", " y_pred = np.where(y_pred >= 0.1, 1, 0)\n", "\n", - " y = y * y_true\n", - " y_hat = y_hat * y_true\n", - " # mae = np.abs(y - y_hat).sum() / y_true.sum()\n", - " mae = np.abs(y - y_hat).mean()\n", + " y *= y_true\n", + " y_hat *= y_true\n", + " mae = metrics.mean_absolute_error(y, y_hat)\n", "\n", " tn, fp, fn, tp = metrics.confusion_matrix(\n", " y_true.reshape(-1), y_pred.reshape(-1)\n", @@ -770,7 +850,8 @@ " )\n", "\n", " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", + " # optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", + " optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\n", " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", " optimizer, T_max=self.num_train_steps\n", " )\n", @@ -786,7 +867,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "scrolled": false }, @@ -804,14 +885,14 @@ "-----------------------------------------\n", "0 | criterion | L1Loss | 0 \n", "1 | encoder | Encoder | 18 M \n", - "2 | decoder | Decoder | 15 M \n", + "2 | decoder | Decoder | 12 M \n", "3 | out | Sequential | 1 K \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "101c9fa38ccc4e998ca991d319623ae6", + "model_id": "9f51fcc5cca242d39eb647651880f9d1", "version_major": 2, "version_minor": 0 }, @@ -823,67 +904,273 @@ "output_type": "display_data" }, { - "ename": "RuntimeError", - "evalue": "Given groups=1, weight of size [512, 1536, 3, 3], expected input[128, 1024, 16, 16] to have 1536 channels, but got 1024 channels instead", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 32\u001b[0m )\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 35\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_checkpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"rainnet_fold{fold}_bs64_epoch50.ckpt\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloader, val_dataloaders, datamodule)\u001b[0m\n\u001b[1;32m 438\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_fit_start'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 440\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 441\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mteardown\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;31m# train or test\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_or_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mtrain_or_test\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 460\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 462\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_sanity_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 463\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 464\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheckpoint_connector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhas_trained\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_sanity_check\u001b[0;34m(self, ref_model)\u001b[0m\n\u001b[1;32m 648\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 649\u001b[0m \u001b[0;31m# run eval step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 650\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_evaluation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_batches\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_sanity_val_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 651\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 652\u001b[0m \u001b[0;31m# allow no returns from eval\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_evaluation\u001b[0;34m(self, test_mode, max_batches)\u001b[0m\n\u001b[1;32m 568\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0;31m# lightning module methods\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_mode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 571\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_step_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py\u001b[0m in \u001b[0;36mevaluation_step\u001b[0;34m(self, test_mode, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 171\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 172\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[0;31m# track batch size for weighted average\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36mvalidation_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp_backend\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mAMPType\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNATIVE\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 76\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__validation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 77\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__validation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36m__validation_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mvalidation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshared_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"loss\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"y\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"y_hat\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0my_hat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mshared_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mshared_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_hat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_hat\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mftrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mftrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mftrs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mftrs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mftrs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, ftrs)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mups\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mftrs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 422\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 423\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 424\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 425\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[0;34m(self, input, weight)\u001b[0m\n\u001b[1;32m 418\u001b[0m _pair(0), self.dilation, self.groups)\n\u001b[1;32m 419\u001b[0m return F.conv2d(input, weight, self.bias, self.stride,\n\u001b[0;32m--> 420\u001b[0;31m self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[1;32m 421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 422\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: Given groups=1, weight of size [512, 1536, 3, 3], expected input[128, 1024, 16, 16] to have 1536 channels, but got 1024 channels instead" + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 44.72930992361908 | MAE: 7.238560199737549 | CSI: 0.16183035714285715 | Loss: 0.5618522763252258\n" ] }, { "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 12;\n", - " var nbb_unformatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n\\nfor fold in range(5):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n \\n datamodule = NowcastingDataModule(\\n train_df, val_df, batch_size=args[\\\"batch_size\\\"], num_workers=args[\\\"num_workers\\\"]\\n )\\n datamodule.setup()\\n\\n num_train_steps = (\\n int(\\n np.ceil(\\n len(train_df)\\n // args[\\\"batch_size\\\"]\\n / args[\\\"gradient_accumulation_steps\\\"]\\n )\\n )\\n * args[\\\"max_epochs\\\"]\\n )\\n \\n model = RainNet(num_train_steps=num_train_steps)\\n \\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n benchmark=True,\\n auto_lr_find=True,\\n )\\n\\n trainer.fit(model, datamodule)\\n trainer.save_checkpoint(f\\\"rainnet_fold{fold}_bs64_epoch50.ckpt\\\")\\n\\n del datamodule, model, trainer\\n gc.collect()\\n torch.cuda.empty_cache()\\n break\";\n", - " var nbb_formatted_code = \"df = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n\\nfor fold in range(5):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(\\n train_df, val_df, batch_size=args[\\\"batch_size\\\"], num_workers=args[\\\"num_workers\\\"]\\n )\\n datamodule.setup()\\n\\n num_train_steps = (\\n int(\\n np.ceil(\\n len(train_df)\\n // args[\\\"batch_size\\\"]\\n / args[\\\"gradient_accumulation_steps\\\"]\\n )\\n )\\n * args[\\\"max_epochs\\\"]\\n )\\n\\n model = RainNet(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n benchmark=True,\\n auto_lr_find=True,\\n )\\n\\n trainer.fit(model, datamodule)\\n trainer.save_checkpoint(f\\\"rainnet_fold{fold}_bs64_epoch50.ckpt\\\")\\n\\n del datamodule, model, trainer\\n gc.collect()\\n torch.cuda.empty_cache()\\n break\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "950af769e0964c6d97efa2d2285e87db", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "" + "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8ea8ec474d2e4cc78c9308a5465bde52", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", + " z = np.power(10.0, dbz / 10.0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 | MAE/CSI: 3.380227101944409 | MAE: 2.691118001937866 | CSI: 0.7961352657004831 | Loss: 0.013766583986580372\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a46b071e8fa444e685db515f6149ec09", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | MAE/CSI: 3.3351390941183854 | MAE: 2.668205976486206 | CSI: 0.8000283949740896 | Loss: 0.012434015050530434\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4e164dc28c284b098404eb60a35e8fe6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 | MAE/CSI: 3.254171461657984 | MAE: 2.6064536571502686 | CSI: 0.8009576901086335 | Loss: 0.012150284834206104\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9446cbd1103f44e294c890bd5f9aae8d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 | MAE/CSI: 2.9932598900167617 | MAE: 2.4145052433013916 | CSI: 0.8066473784489451 | Loss: 0.012062947265803814\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b10580e41a054d638d692a4e2c62bd02", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 | MAE/CSI: 3.2017156048262843 | MAE: 2.5828773975372314 | CSI: 0.8067166845301893 | Loss: 0.012147068046033382\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", + " z = np.power(10.0, dbz / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2725bee46ad64f2a9019c4b30323edba", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 | MAE/CSI: 3.101806781350983 | MAE: 2.5002763271331787 | CSI: 0.8060709461861093 | Loss: 0.012087756767868996\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", + " z = np.power(10.0, dbz / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d63deb8a7634444886789fcc1aecde22", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 | MAE/CSI: 3.7159665055459423 | MAE: 2.9285545349121094 | CSI: 0.788100358422939 | Loss: 0.014243747107684612\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "61685426cf3e4789aee9aed6fda48dbf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 | MAE/CSI: 3.5832275825526168 | MAE: 2.8243930339813232 | CSI: 0.7882259691598213 | Loss: 0.012260997667908669\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fe79cfb1ce4f492192f5206e04b7caec", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 | MAE/CSI: 4.340023416819759 | MAE: 3.3040218353271484 | CSI: 0.761291246152745 | Loss: 0.01332629844546318\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", + " z = np.power(10.0, dbz / 10.0)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb8daa0e90124bdc8cf0dd380e249734", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 | MAE/CSI: 8.132752788816978 | MAE: 5.300608158111572 | CSI: 0.6517606394469648 | Loss: 0.017829107120633125\n" + ] } ], "source": [ + "seed_everything(args[\"seed\"])\n", + "pl.seed_everything(args[\"seed\"])\n", + "\n", "df = pd.read_csv(args[\"train_folds_csv\"])\n", "\n", "for fold in range(5):\n", @@ -898,9 +1185,7 @@ " num_train_steps = (\n", " int(\n", " np.ceil(\n", - " len(train_df)\n", - " // args[\"batch_size\"]\n", - " / args[\"gradient_accumulation_steps\"]\n", + " len(train_df) // args[\"batch_size\"] / args[\"accumulate_grad_batches\"]\n", " )\n", " )\n", " * args[\"max_epochs\"]\n", @@ -913,10 +1198,17 @@ " max_epochs=args[\"max_epochs\"],\n", " precision=args[\"precision\"],\n", " progress_bar_refresh_rate=50,\n", - " benchmark=True,\n", - " auto_lr_find=True,\n", + "# accumulate_grad_batches=args[\"accumulate_grad_batches\"],\n", + " gradient_clip_val=args[\"gradient_clip_val\"],\n", + " # auto_lr_find=True,\n", + "# benchmark=True,\n", " )\n", "\n", + " # learning rate finder\n", + " # lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\n", + " # fig = lr_finder.plot(suggest=True)\n", + " # fig.show()\n", + "\n", " trainer.fit(model, datamodule)\n", " trainer.save_checkpoint(f\"rainnet_fold{fold}_bs64_epoch50.ckpt\")\n", "\n", @@ -933,209 +1225,6 @@ "## Inference" ] }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "RainNet(\n", - " (criterion): L1Loss()\n", - " (encoder): Encoder(\n", - " (blocks): ModuleList(\n", - " (0): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (1): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (2): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (3): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (4): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " )\n", - " (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", - " (dropout): Dropout(p=0.5, inplace=False)\n", - " )\n", - " (decoder): Decoder(\n", - " (ups): ModuleList(\n", - " (0): Upsample(scale_factor=2.0, mode=nearest)\n", - " (1): Upsample(scale_factor=2.0, mode=nearest)\n", - " (2): Upsample(scale_factor=2.0, mode=nearest)\n", - " (3): Upsample(scale_factor=2.0, mode=nearest)\n", - " )\n", - " (convs): ModuleList(\n", - " (0): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(1536, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (1): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (2): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (3): Block(\n", - " (net): Sequential(\n", - " (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (4): ReLU(inplace=True)\n", - " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (out): Sequential(\n", - " (0): Conv2d(64, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): ReLU(inplace=True)\n", - " (2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (3): Conv2d(2, 1, kernel_size=(1, 1), stride=(1, 1))\n", - " (4): ReLU(inplace=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 11;\n", - " var nbb_unformatted_code = \"model = RainNet.load_from_checkpoint(\\\"rainnet_fold0_bs64_epoch50.ckpt\\\")\\nmodel.to(\\\"cuda\\\")\";\n", - " var nbb_formatted_code = \"model = RainNet.load_from_checkpoint(\\\"rainnet_fold0_bs64_epoch50.ckpt\\\")\\nmodel.to(\\\"cuda\\\")\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "model = RainNet.load_from_checkpoint(\"rainnet_fold0_bs64_epoch50.ckpt\")\n", - "model.to(\"cuda\")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 13;\n", - " var nbb_unformatted_code = \"datamodule = NowcastingDataModule(train_df, val_df, batch_size=2 * args[\\\"batch_size\\\"])\\ndatamodule.setup(\\\"test\\\")\";\n", - " var nbb_formatted_code = \"datamodule = NowcastingDataModule(train_df, val_df, batch_size=2 * args[\\\"batch_size\\\"])\\ndatamodule.setup(\\\"test\\\")\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "datamodule = NowcastingDataModule(train_df, val_df, batch_size=2 * args[\"batch_size\"])\n", - "datamodule.setup(\"test\")" - ] - }, { "cell_type": "code", "execution_count": 14, @@ -1170,22 +1259,38 @@ } ], "source": [ - "preds = []\n", - "model.eval()\n", - "with torch.no_grad():\n", - " for batch in tqdm(datamodule.test_dataloader()):\n", - " batch = batch.to(\"cuda\")\n", - " imgs = model(batch)\n", - " imgs = imgs.detach().cpu().numpy()\n", - " imgs = imgs[:, 0, 4:124, 4:124]\n", - " imgs = 255.0 * imgs\n", - " imgs = np.round(imgs)\n", - " imgs = np.clip(imgs, 0, 255)\n", - " preds.append(imgs)\n", + "datamodule = NowcastingDataModule()\n", + "datamodule.setup(\"test\")\n", + "\n", + "final_preds = np.zeros((len(datamodule.test_dataset), 120, 120))\n", + "\n", + "for fold in range(5):\n", + " model = RainNet.load_from_checkpoint(f\"rainnet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}.ckpt\")\n", + " model.to(\"cuda\")\n", + "\n", + " preds = []\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for batch in tqdm(datamodule.test_dataloader()):\n", + " batch = batch.to(\"cuda\")\n", + " imgs = model(batch)\n", + " imgs = imgs.detach().cpu().numpy()\n", + " imgs = imgs[:, 0, 4:124, 4:124]\n", + " imgs = args[\"rng\"] * imgs\n", + " imgs = imgs.clip(0, 255)\n", + " imgs = imgs.round()\n", + " preds.append(imgs)\n", "\n", - "preds = np.concatenate(preds)\n", - "preds = preds.astype(np.uint8)\n", - "preds = preds.reshape(len(preds), -1)" + " preds = np.concatenate(preds)\n", + " preds = preds.astype(np.uint8)\n", + " final_preds += preds\n", + " \n", + " del model\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + " break\n", + " \n", + "final_preds = final_preds.reshape(-1, 14400)" ] }, { @@ -1281,10 +1386,9 @@ } ], "source": [ - "subm = pd.DataFrame()\n", - "subm[\"file_name\"] = test_filenames\n", + "subm = pd.DataFrame({\"file_name\": test_filenames})\n", "for i in tqdm(range(14400)):\n", - " subm[str(i)] = preds[:, i]" + " subm[str(i)] = final_preds[:, i]" ] }, { @@ -1513,7 +1617,7 @@ } ], "source": [ - "subm.to_csv(\"rainnet_fold0_epoch50.csv\", index=False)\n", + "subm.to_csv(f\"rainnet_epoch{args['max_epochs']}_lr{args['lr']}.csv\", index=False)\n", "subm.head()" ] }, @@ -1522,9 +1626,28 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "del model" - ] + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] }, { "cell_type": "code", @@ -1557,7 +1680,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:torch] *", + "display_name": "Python [conda env:torch]", "language": "python", "name": "conda-env-torch-py" }, diff --git a/notebooks/03-unet.ipynb b/notebooks/03-unet.ipynb index 0e21e74..941b625 100644 --- a/notebooks/03-unet.ipynb +++ b/notebooks/03-unet.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -20,7 +20,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 2;\n", + " var nbb_cell_id = 4;\n", " var nbb_unformatted_code = \"%reload_ext autoreload\\n%autoreload 2\\n%matplotlib inline\\n%reload_ext nb_black\";\n", " var nbb_formatted_code = \"%reload_ext autoreload\\n%autoreload 2\\n%matplotlib inline\\n%reload_ext nb_black\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -60,9 +60,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 3;\n", - " var nbb_unformatted_code = \"import gc\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nfrom torch.utils.data import RandomSampler, SequentialSampler\\nimport pytorch_lightning as pl\\n\\nimport transformers\\n\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", - " var nbb_formatted_code = \"import gc\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nfrom torch.utils.data import RandomSampler, SequentialSampler\\nimport pytorch_lightning as pl\\n\\nimport transformers\\n\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_cell_id = 5;\n", + " var nbb_unformatted_code = \"import gc\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nfrom torch.utils.data import RandomSampler, SequentialSampler\\nimport pytorch_lightning as pl\\n\\nimport transformers\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", + " var nbb_formatted_code = \"import gc\\nfrom pathlib import Path\\nfrom tqdm.notebook import tqdm\\n\\nimport cv2\\nimport numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn import metrics\\n\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport torchvision.transforms as T\\nfrom torch.utils.data import RandomSampler, SequentialSampler\\nimport pytorch_lightning as pl\\n\\nimport transformers\\n\\nimport optim\\nimport loss\\nfrom utils import visualize, radar2precipitation, seed_everything\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -103,6 +103,8 @@ "\n", "import transformers\n", "\n", + "import optim\n", + "import loss\n", "from utils import visualize, radar2precipitation, seed_everything" ] }, @@ -122,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -130,9 +132,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 4;\n", - " var nbb_unformatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=256,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n)\";\n", - " var nbb_formatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=256,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n)\";\n", + " var nbb_cell_id = 24;\n", + " var nbb_unformatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=256,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n)\";\n", + " var nbb_formatted_code = \"args = dict(\\n seed=42,\\n dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),\\n train_folds_csv=Path(\\\"../input/train_folds.csv\\\"),\\n train_data_path=Path(\\\"../input/train-128\\\"),\\n test_data_path=Path(\\\"../input/test-128\\\"),\\n model_dir=Path(\\\"../models\\\"),\\n rng=255.0,\\n num_workers=4,\\n gpus=1,\\n lr=1e-4,\\n max_epochs=50,\\n batch_size=256,\\n precision=16,\\n optimizer=\\\"adamw\\\",\\n scheduler=\\\"cosine\\\",\\n accumulate_grad_batches=1,\\n gradient_clip_val=5.0,\\n)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -160,6 +162,7 @@ " train_folds_csv=Path(\"../input/train_folds.csv\"),\n", " train_data_path=Path(\"../input/train-128\"),\n", " test_data_path=Path(\"../input/test-128\"),\n", + " model_dir=Path(\"../models\"),\n", " rng=255.0,\n", " num_workers=4,\n", " gpus=1,\n", @@ -197,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -205,7 +208,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 5;\n", + " var nbb_cell_id = 7;\n", " var nbb_unformatted_code = \"class BasicBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n assert in_ch == out_ch\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n return x + self.net(x)\";\n", " var nbb_formatted_code = \"class BasicBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n assert in_ch == out_ch\\n super().__init__()\\n self.net = nn.Sequential(\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\\n nn.BatchNorm2d(out_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n return x + self.net(x)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -244,45 +247,6 @@ " return x + self.net(x)" ] }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 6;\n", - " var nbb_unformatted_code = \"# x = torch.randn(3, 4, 128, 128)\\n# block = BasicBlock(4, 4)\\n# block(x).shape\";\n", - " var nbb_formatted_code = \"# x = torch.randn(3, 4, 128, 128)\\n# block = BasicBlock(4, 4)\\n# block(x).shape\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# x = torch.randn(3, 4, 128, 128)\n", - "# block = BasicBlock(4, 4)\n", - "# block(x).shape" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -292,7 +256,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -300,7 +264,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 6;\n", + " var nbb_cell_id = 8;\n", " var nbb_unformatted_code = \"class DownBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)\\n self.net = nn.Sequential(\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.MaxPool2d(2),\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.net(x)\\n return residual + x, x\";\n", " var nbb_formatted_code = \"class DownBlock(nn.Module):\\n def __init__(self, in_ch, out_ch):\\n super().__init__()\\n self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)\\n self.net = nn.Sequential(\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.MaxPool2d(2),\\n nn.BatchNorm2d(in_ch),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\\n )\\n\\n def forward(self, x):\\n residual = x\\n residual = self.id_conv(residual)\\n x = self.net(x)\\n return residual + x, x\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -346,46 +310,7 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 7;\n", - " var nbb_unformatted_code = \"# block = DownBlock(4, 64)\\n# down, across = block(x)\\n# down.shape, across.shape\";\n", - " var nbb_formatted_code = \"# block = DownBlock(4, 64)\\n# down, across = block(x)\\n# down.shape, across.shape\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# block = DownBlock(4, 64)\n", - "# down, across = block(x)\n", - "# down.shape, across.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -393,7 +318,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 8;\n", + " var nbb_cell_id = 9;\n", " var nbb_unformatted_code = \"class Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [DownBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.basic = BasicBlock(chs[-1], chs[-1])\\n\\n def forward(self, x):\\n feats = []\\n for block in self.blocks:\\n x, feat = block(x)\\n feats.append(feat)\\n x = self.basic(x)\\n feats.append(x)\\n return feats\";\n", " var nbb_formatted_code = \"class Encoder(nn.Module):\\n def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [DownBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n self.basic = BasicBlock(chs[-1], chs[-1])\\n\\n def forward(self, x):\\n feats = []\\n for block in self.blocks:\\n x, feat = block(x)\\n feats.append(feat)\\n x = self.basic(x)\\n feats.append(x)\\n return feats\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -435,47 +360,6 @@ " return feats" ] }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 9;\n", - " var nbb_unformatted_code = \"# x = torch.randn(3, 4, 128, 128)\\n# encoder = Encoder()\\n# feats = encoder(x)\\n# for feat in feats:\\n# print(feat.shape)\";\n", - " var nbb_formatted_code = \"# x = torch.randn(3, 4, 128, 128)\\n# encoder = Encoder()\\n# feats = encoder(x)\\n# for feat in feats:\\n# print(feat.shape)\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# x = torch.randn(3, 4, 128, 128)\n", - "# encoder = Encoder()\n", - "# feats = encoder(x)\n", - "# for feat in feats:\n", - "# print(feat.shape)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -561,46 +445,6 @@ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 11;\n", - " var nbb_unformatted_code = \"# x = torch.randn(3, 1024, 4, 4)\\n# feat = torch.randn(3, 1024, 4, 4)\\n# block = UpBlock(1024, 512)\\n# block(x, feat).shape\";\n", - " var nbb_formatted_code = \"# x = torch.randn(3, 1024, 4, 4)\\n# feat = torch.randn(3, 1024, 4, 4)\\n# block = UpBlock(1024, 512)\\n# block(x, feat).shape\";\n", - " var nbb_cells = Jupyter.notebook.get_cells();\n", - " for (var i = 0; i < nbb_cells.length; ++i) {\n", - " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", - " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", - " nbb_cells[i].set_text(nbb_formatted_code);\n", - " }\n", - " break;\n", - " }\n", - " }\n", - " }, 500);\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# x = torch.randn(3, 1024, 4, 4)\n", - "# feat = torch.randn(3, 1024, 4, 4)\n", - "# block = UpBlock(1024, 512)\n", - "# block(x, feat).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "application/javascript": [ - "\n", - " setTimeout(function() {\n", - " var nbb_cell_id = 12;\n", " var nbb_unformatted_code = \"class Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, feats):\\n for block, feat in zip(self.blocks, feats):\\n x = block(x, feat)\\n return x\";\n", " var nbb_formatted_code = \"class Decoder(nn.Module):\\n def __init__(self, chs=[1024, 512, 256, 128, 64]):\\n super().__init__()\\n self.blocks = nn.ModuleList(\\n [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]\\n )\\n\\n def forward(self, x, feats):\\n for block, feat in zip(self.blocks, feats):\\n x = block(x, feat)\\n return x\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -639,7 +483,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -647,7 +491,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 13;\n", + " var nbb_cell_id = 12;\n", " var nbb_unformatted_code = \"# x = torch.randn(3, 4, 128, 128)\\n# encoder = Encoder()\\n# feats = encoder(x)\\n# for feat in feats:\\n# print(feat.shape)\";\n", " var nbb_formatted_code = \"# x = torch.randn(3, 4, 128, 128)\\n# encoder = Encoder()\\n# feats = encoder(x)\\n# for feat in feats:\\n# print(feat.shape)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -680,7 +524,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -688,7 +532,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 14;\n", + " var nbb_cell_id = 13;\n", " var nbb_unformatted_code = \"# decoder = Decoder()\\n# x = torch.randn(3, 1024, 4, 4)\\n# feats = list(reversed(feats))[1:]\\n# decoder(x, feats).shape\";\n", " var nbb_formatted_code = \"# decoder = Decoder()\\n# x = torch.randn(3, 1024, 4, 4)\\n# feats = list(reversed(feats))[1:]\\n# decoder(x, feats).shape\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -720,7 +564,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -728,9 +572,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 15;\n", - " var nbb_unformatted_code = \"class UNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n self.criterion = nn.SmoothL1Loss()\\n\\n self.tail = BasicBlock(4, enc_chs[0])\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.head = nn.Sequential(\\n nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),\\n nn.BatchNorm2d(32),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(32, 1, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n x = self.tail(x)\\n # print(\\\"after tail:\\\", x.shape)\\n feats = self.encoder(x)\\n feats = feats[::-1]\\n x = self.decoder(feats[0], feats[1:])\\n # print(\\\"after decoder:\\\", x.shape)\\n x = self.head(x)\\n\\n return x\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n crop = T.CenterCrop(120)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = crop(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = crop(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n y = args[\\\"rng\\\"] * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = args[\\\"rng\\\"] * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.ravel(), y_pred.ravel()\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n self.optimizer, T_max=self.num_train_steps\\n )\\n\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\";\n", - " var nbb_formatted_code = \"class UNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n self.criterion = nn.SmoothL1Loss()\\n\\n self.tail = BasicBlock(4, enc_chs[0])\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.head = nn.Sequential(\\n nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),\\n nn.BatchNorm2d(32),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(32, 1, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n x = self.tail(x)\\n # print(\\\"after tail:\\\", x.shape)\\n feats = self.encoder(x)\\n feats = feats[::-1]\\n x = self.decoder(feats[0], feats[1:])\\n # print(\\\"after decoder:\\\", x.shape)\\n x = self.head(x)\\n\\n return x\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n crop = T.CenterCrop(120)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = crop(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = crop(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n y = args[\\\"rng\\\"] * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = args[\\\"rng\\\"] * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.ravel(), y_pred.ravel()\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n # self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n self.optimizer, T_max=self.num_train_steps\\n )\\n\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\";\n", + " var nbb_cell_id = 14;\n", + " var nbb_unformatted_code = \"class UNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n # self.criterion = nn.SmoothL1Loss()\\n self.criterion = nn.L1Loss()\\n\\n self.tail = BasicBlock(4, enc_chs[0])\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.head = nn.Sequential(\\n nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),\\n nn.BatchNorm2d(32),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(32, 1, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n x = self.tail(x)\\n feats = self.encoder(x)\\n feats = feats[::-1]\\n x = self.decoder(feats[0], feats[1:])\\n x = self.head(x)\\n\\n return x\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n crop = T.CenterCrop(120)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = crop(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = crop(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n y = args[\\\"rng\\\"] * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = args[\\\"rng\\\"] * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.ravel(), y_pred.ravel()\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n \\n # Optimizer\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adagrad\\\":\\n self.optimizer = torch.optim.Adagrad(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(optimizer)\\n \\n # Scheduler\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n self.optimizer, T_max=self.num_train_steps\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR()\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLR()\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", + " var nbb_formatted_code = \"class UNet(pl.LightningModule):\\n def __init__(\\n self,\\n lr=args[\\\"lr\\\"],\\n enc_chs=[4, 64, 128, 256, 512, 1024],\\n dec_chs=[1024, 512, 256, 128, 64],\\n num_train_steps=None,\\n ):\\n super().__init__()\\n self.lr = lr\\n self.num_train_steps = num_train_steps\\n # self.criterion = nn.SmoothL1Loss()\\n self.criterion = nn.L1Loss()\\n\\n self.tail = BasicBlock(4, enc_chs[0])\\n self.encoder = Encoder(enc_chs)\\n self.decoder = Decoder(dec_chs)\\n self.head = nn.Sequential(\\n nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),\\n nn.BatchNorm2d(32),\\n nn.LeakyReLU(inplace=True),\\n nn.Conv2d(32, 1, kernel_size=3, padding=1),\\n nn.ReLU(inplace=True),\\n )\\n\\n def forward(self, x):\\n x = self.tail(x)\\n feats = self.encoder(x)\\n feats = feats[::-1]\\n x = self.decoder(feats[0], feats[1:])\\n x = self.head(x)\\n\\n return x\\n\\n def shared_step(self, batch, batch_idx):\\n x, y = batch\\n y_hat = self(x)\\n loss = self.criterion(y_hat, y)\\n\\n return loss, y, y_hat\\n\\n def training_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n self.log(\\\"train_loss\\\", loss)\\n for i, param_group in enumerate(self.optimizer.param_groups):\\n self.log(f\\\"lr/lr{i}\\\", param_group[\\\"lr\\\"])\\n\\n return {\\\"loss\\\": loss}\\n\\n def validation_step(self, batch, batch_idx):\\n loss, y, y_hat = self.shared_step(batch, batch_idx)\\n\\n return {\\\"loss\\\": loss, \\\"y\\\": y.detach(), \\\"y_hat\\\": y_hat.detach()}\\n\\n def validation_epoch_end(self, outputs):\\n avg_loss = torch.stack([x[\\\"loss\\\"] for x in outputs]).mean()\\n self.log(\\\"val_loss\\\", avg_loss)\\n\\n crop = T.CenterCrop(120)\\n\\n y = torch.cat([x[\\\"y\\\"] for x in outputs])\\n y = crop(y)\\n y = y.detach().cpu().numpy()\\n y = y.reshape(-1, 120 * 120)\\n\\n y_hat = torch.cat([x[\\\"y_hat\\\"] for x in outputs])\\n y_hat = crop(y_hat)\\n y_hat = y_hat.detach().cpu().numpy()\\n y_hat = y_hat.reshape(-1, 120 * 120)\\n\\n y = args[\\\"rng\\\"] * y[:, args[\\\"dams\\\"]]\\n y = y.clip(0, 255)\\n y_hat = args[\\\"rng\\\"] * y_hat[:, args[\\\"dams\\\"]]\\n y_hat = y_hat.clip(0, 255)\\n\\n y_true = radar2precipitation(y)\\n y_true = np.where(y_true >= 0.1, 1, 0)\\n y_pred = radar2precipitation(y_hat)\\n y_pred = np.where(y_pred >= 0.1, 1, 0)\\n\\n y *= y_true\\n y_hat *= y_true\\n mae = metrics.mean_absolute_error(y, y_hat)\\n self.log(\\\"mae\\\", mae)\\n\\n tn, fp, fn, tp = metrics.confusion_matrix(\\n y_true.ravel(), y_pred.ravel()\\n ).ravel()\\n csi = tp / (tp + fn + fp)\\n self.log(\\\"csi\\\", csi)\\n\\n comp_metric = mae / (csi + 1e-12)\\n self.log(\\\"comp_metric\\\", comp_metric)\\n\\n print(\\n f\\\"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}\\\"\\n )\\n\\n def configure_optimizers(self):\\n\\n # Optimizer\\n if args[\\\"optimizer\\\"] == \\\"adam\\\":\\n self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adamw\\\":\\n self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"adagrad\\\":\\n self.optimizer = torch.optim.Adagrad(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"radam\\\":\\n self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n elif args[\\\"optimizer\\\"] == \\\"ranger\\\":\\n optimizer = optim.RAdam(self.parameters(), lr=self.lr)\\n self.optimizer = optim.Lookahead(optimizer)\\n\\n # Scheduler\\n if args[\\\"scheduler\\\"] == \\\"cosine\\\":\\n self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\\n self.optimizer, T_max=self.num_train_steps\\n )\\n return [self.optimizer], [{\\\"scheduler\\\": self.scheduler, \\\"interval\\\": \\\"step\\\"}]\\n elif args[\\\"scheduler\\\"] == \\\"step\\\":\\n self.scheduler = torch.optim.lr_scheduler.StepLR()\\n elif args[\\\"scheduler\\\"] == \\\"plateau\\\":\\n self.scheduler = torch.optim.lr_scheduler.ReduceLR()\\n else:\\n self.scheduler = None\\n return [self.optimizer]\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -763,7 +607,8 @@ " super().__init__()\n", " self.lr = lr\n", " self.num_train_steps = num_train_steps\n", - " self.criterion = nn.SmoothL1Loss()\n", + " # self.criterion = nn.SmoothL1Loss()\n", + " self.criterion = nn.L1Loss()\n", "\n", " self.tail = BasicBlock(4, enc_chs[0])\n", " self.encoder = Encoder(enc_chs)\n", @@ -778,11 +623,9 @@ "\n", " def forward(self, x):\n", " x = self.tail(x)\n", - " # print(\"after tail:\", x.shape)\n", " feats = self.encoder(x)\n", " feats = feats[::-1]\n", " x = self.decoder(feats[0], feats[1:])\n", - " # print(\"after decoder:\", x.shape)\n", " x = self.head(x)\n", "\n", " return x\n", @@ -852,18 +695,38 @@ " )\n", "\n", " def configure_optimizers(self):\n", - " # self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", - " self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\n", - " self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", - " self.optimizer, T_max=self.num_train_steps\n", - " )\n", "\n", - " return [self.optimizer], [{\"scheduler\": self.scheduler, \"interval\": \"step\"}]" + " # Optimizer\n", + " if args[\"optimizer\"] == \"adam\":\n", + " self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"adamw\":\n", + " self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"adagrad\":\n", + " self.optimizer = torch.optim.Adagrad(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"radam\":\n", + " self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", + " elif args[\"optimizer\"] == \"ranger\":\n", + " optimizer = optim.RAdam(self.parameters(), lr=self.lr)\n", + " self.optimizer = optim.Lookahead(optimizer)\n", + "\n", + " # Scheduler\n", + " if args[\"scheduler\"] == \"cosine\":\n", + " self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", + " self.optimizer, T_max=self.num_train_steps\n", + " )\n", + " return [self.optimizer], [{\"scheduler\": self.scheduler, \"interval\": \"step\"}]\n", + " elif args[\"scheduler\"] == \"step\":\n", + " self.scheduler = torch.optim.lr_scheduler.StepLR()\n", + " elif args[\"scheduler\"] == \"plateau\":\n", + " self.scheduler = torch.optim.lr_scheduler.ReduceLR()\n", + " else:\n", + " self.scheduler = None\n", + " return [self.optimizer]" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -871,7 +734,7 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 16;\n", + " var nbb_cell_id = 18;\n", " var nbb_unformatted_code = \"# m = UNet()\\n# x = torch.randn(3, 4, 128, 128)\\n# m(x).shape\";\n", " var nbb_formatted_code = \"# m = UNet()\\n# x = torch.randn(3, 4, 128, 128)\\n# m(x).shape\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", @@ -909,7 +772,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -917,9 +780,9 @@ "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 17;\n", - " var nbb_unformatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y\";\n", - " var nbb_formatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y\";\n", + " var nbb_cell_id = 19;\n", + " var nbb_unformatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n\\n precipitation = radar2precipitation(y)\\n\\n label = np.zeros(y.shape)\\n label[precipitation >= 0.1] += 1\\n label[precipitation >= 1.0] += 1\\n label[precipitation >= 2.5] += 1\\n label = torch.tensor(label, dtype=torch.long)\\n label = label.unsqueeze(0)\\n\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y, label\";\n", + " var nbb_formatted_code = \"class NowcastingDataset(torch.utils.data.Dataset):\\n def __init__(self, paths, test=False):\\n self.paths = paths\\n self.test = test\\n\\n def __len__(self):\\n return len(self.paths)\\n\\n def __getitem__(self, idx):\\n path = self.paths[idx]\\n data = np.load(path)\\n\\n x = data[:, :, :4]\\n x = x / args[\\\"rng\\\"]\\n x = x.astype(np.float32)\\n x = torch.tensor(x, dtype=torch.float)\\n x = x.permute(2, 0, 1)\\n if self.test:\\n return x\\n else:\\n y = data[:, :, 4]\\n\\n precipitation = radar2precipitation(y)\\n\\n label = np.zeros(y.shape)\\n label[precipitation >= 0.1] += 1\\n label[precipitation >= 1.0] += 1\\n label[precipitation >= 2.5] += 1\\n label = torch.tensor(label, dtype=torch.long)\\n label = label.unsqueeze(0)\\n\\n y = y / args[\\\"rng\\\"]\\n y = y.astype(np.float32)\\n y = torch.tensor(y, dtype=torch.float)\\n y = y.unsqueeze(-1)\\n y = y.permute(2, 0, 1)\\n\\n return x, y, label\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -952,6 +815,7 @@ " def __getitem__(self, idx):\n", " path = self.paths[idx]\n", " data = np.load(path)\n", + "\n", " x = data[:, :, :4]\n", " x = x / args[\"rng\"]\n", " x = x.astype(np.float32)\n", @@ -961,28 +825,92 @@ " return x\n", " else:\n", " y = data[:, :, 4]\n", + "\n", + " precipitation = radar2precipitation(y)\n", + "\n", + " label = np.zeros(y.shape)\n", + " label[precipitation >= 0.1] += 1\n", + " label[precipitation >= 1.0] += 1\n", + " label[precipitation >= 2.5] += 1\n", + " label = torch.tensor(label, dtype=torch.long)\n", + " label = label.unsqueeze(0)\n", + "\n", " y = y / args[\"rng\"]\n", " y = y.astype(np.float32)\n", " y = torch.tensor(y, dtype=torch.float)\n", " y = y.unsqueeze(-1)\n", " y = y.permute(2, 0, 1)\n", "\n", - " return x, y" + " return x, y, label" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ + { + "ename": "ValueError", + "evalue": "too many values to unpack (expected 2)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNowcastingDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_paths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)" + ] + }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", - " var nbb_cell_id = 18;\n", - " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\"))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", - " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\"))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_cell_id = 20;\n", + " var nbb_unformatted_code = \"fold = 3\\ndf = pd.read_csv(args[\\\"train_folds_csv\\\"])\\ntrain_df = df[df.fold != fold]\\ntrain_paths = [args[\\\"train_data_path\\\"] / fn for fn in train_df.filename.values]\\ndataset = NowcastingDataset(train_paths)\\nidx = np.random.randint(len(dataset))\\nx, y = dataset[idx]\";\n", + " var nbb_formatted_code = \"fold = 3\\ndf = pd.read_csv(args[\\\"train_folds_csv\\\"])\\ntrain_df = df[df.fold != fold]\\ntrain_paths = [args[\\\"train_data_path\\\"] / fn for fn in train_df.filename.values]\\ndataset = NowcastingDataset(train_paths)\\nidx = np.random.randint(len(dataset))\\nx, y = dataset[idx]\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fold = 3\n", + "df = pd.read_csv(args[\"train_folds_csv\"])\n", + "train_df = df[df.fold != fold]\n", + "train_paths = [args[\"train_data_path\"] / fn for fn in train_df.filename.values]\n", + "dataset = NowcastingDataset(train_paths)\n", + "idx = np.random.randint(len(dataset))\n", + "x, y = dataset[idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 21;\n", + " var nbb_unformatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", + " var nbb_formatted_code = \"class NowcastingDataModule(pl.LightningDataModule):\\n def __init__(\\n self,\\n train_df=None,\\n val_df=None,\\n batch_size=args[\\\"batch_size\\\"],\\n num_workers=args[\\\"num_workers\\\"],\\n ):\\n super().__init__()\\n self.train_df = train_df\\n self.val_df = val_df\\n self.batch_size = batch_size\\n self.num_workers = num_workers\\n\\n def setup(self, stage=\\\"train\\\"):\\n if stage == \\\"train\\\":\\n train_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.train_df.filename.values\\n ]\\n val_paths = [\\n args[\\\"train_data_path\\\"] / fn for fn in self.val_df.filename.values\\n ]\\n self.train_dataset = NowcastingDataset(train_paths)\\n self.val_dataset = NowcastingDataset(val_paths)\\n else:\\n test_paths = list(sorted(args[\\\"test_data_path\\\"].glob(\\\"*.npy\\\")))\\n self.test_dataset = NowcastingDataset(test_paths, test=True)\\n\\n def train_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.train_dataset,\\n batch_size=self.batch_size,\\n sampler=RandomSampler(self.train_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n drop_last=True,\\n )\\n\\n def val_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.val_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.val_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\\n\\n def test_dataloader(self):\\n return torch.utils.data.DataLoader(\\n self.test_dataset,\\n batch_size=2 * self.batch_size,\\n sampler=SequentialSampler(self.test_dataset),\\n pin_memory=True,\\n num_workers=self.num_workers,\\n )\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", @@ -1029,7 +957,7 @@ " self.train_dataset = NowcastingDataset(train_paths)\n", " self.val_dataset = NowcastingDataset(val_paths)\n", " else:\n", - " test_paths = list(args[\"test_data_path\"].glob(\"*.npy\"))\n", + " test_paths = list(sorted(args[\"test_data_path\"].glob(\"*.npy\")))\n", " self.test_dataset = NowcastingDataset(test_paths, test=True)\n", "\n", " def train_dataloader(self):\n", @@ -1070,9 +998,9 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [ { @@ -1084,19 +1012,19 @@ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "Using native 16bit precision.\n", "\n", - " | Name | Type | Params\n", - "-------------------------------------------\n", - "0 | criterion | SmoothL1Loss | 0 \n", - "1 | tail | BasicBlock | 300 \n", - "2 | encoder | Encoder | 25 M \n", - "3 | decoder | Decoder | 17 M \n", - "4 | head | Sequential | 8 K \n" + " | Name | Type | Params\n", + "-----------------------------------------\n", + "0 | criterion | L1Loss | 0 \n", + "1 | tail | BasicBlock | 300 \n", + "2 | encoder | Encoder | 25 M \n", + "3 | decoder | Decoder | 17 M \n", + "4 | head | Sequential | 8 K \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f6cbf2f4afb64d01b7584ee121a1da6e", + "model_id": "ed004db6ca9f4443b88355a974d0aff0", "version_major": 2, "version_minor": 0 }, @@ -1107,1923 +1035,220 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0 | MAE/CSI: 18053432464599.61 | MAE: 18.05343246459961 | CSI: 0.0 | Loss: 0.011555060744285583\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e02e2c5a7c694153a5ac69e51294884a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4403ddfd3cb8421d997ebf14219a7e88", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, { "name": "stderr", "output_type": "stream", "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" + "ERROR:root:Internal Python error in the inspect module.\n", + "Below is the traceback from this internal error.\n", + "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0 | MAE/CSI: 4.579871016522517 | MAE: 3.356309652328491 | CSI: 0.7328393398450657 | Loss: 0.0015933009563013911\n" + "Traceback (most recent call last):\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 3418, in run_code\n", + " exec(code_obj, self.user_global_ns, self.user_ns)\n", + " File \"\", line 41, in \n", + " trainer.fit(model, datamodule)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\", line 440, in fit\n", + " results = self.accelerator_backend.train()\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 54, in train\n", + " results = self.train_or_test()\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\", line 68, in train_or_test\n", + " results = self.trainer.train()\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\", line 462, in train\n", + " self.run_sanity_check(self.get_model())\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\", line 650, in run_sanity_check\n", + " _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\", line 570, in run_evaluation\n", + " output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py\", line 171, in evaluation_step\n", + " output = self.trainer.accelerator_backend.validation_step(args)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 76, in validation_step\n", + " output = self.__validation_step(args)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 86, in __validation_step\n", + " output = self.trainer.model.validation_step(*args)\n", + " File \"\", line 51, in validation_step\n", + " loss, y, y_hat = self.shared_step(batch, batch_idx)\n", + " File \"\", line 36, in shared_step\n", + " x, y = batch\n", + "ValueError: too many values to unpack (expected 2)\n", + "\n", + "During handling of the above exception, another exception occurred:\n", + "\n", + "Traceback (most recent call last):\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 2045, in showtraceback\n", + " stb = value._render_traceback_()\n", + "AttributeError: 'ValueError' object has no attribute '_render_traceback_'\n", + "\n", + "During handling of the above exception, another exception occurred:\n", + "\n", + "Traceback (most recent call last):\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 1170, in get_records\n", + " return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 316, in wrapped\n", + " return f(*args, **kwargs)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 350, in _fixed_getinnerframes\n", + " records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/inspect.py\", line 1503, in getinnerframes\n", + " frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/inspect.py\", line 1461, in getframeinfo\n", + " filename = getsourcefile(frame) or getfile(frame)\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/inspect.py\", line 708, in getsourcefile\n", + " if getattr(getmodule(object, filename), '__loader__', None) is not None:\n", + " File \"/home/isleof/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/inspect.py\", line 745, in getmodule\n", + " if ismodule(module) and hasattr(module, '__file__'):\n", + "KeyboardInterrupt\n" + ] + }, + { + "ename": "TypeError", + "evalue": "object of type 'NoneType' has no len()", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m trainer.save_checkpoint(\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloader, val_dataloaders, datamodule)\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 440\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 441\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mteardown\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;31m# train or test\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_or_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mtrain_or_test\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 462\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_sanity_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 463\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_sanity_check\u001b[0;34m(self, ref_model)\u001b[0m\n\u001b[1;32m 649\u001b[0m \u001b[0;31m# run eval step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 650\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_evaluation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_batches\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_sanity_val_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 651\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_evaluation\u001b[0;34m(self, test_mode, max_batches)\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0;31m# lightning module methods\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_mode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 571\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_step_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py\u001b[0m in \u001b[0;36mevaluation_step\u001b[0;34m(self, test_mode, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 171\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 172\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36mvalidation_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 76\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__validation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 77\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\u001b[0m in \u001b[0;36m__validation_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mvalidation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshared_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mshared_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mshared_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m 2044\u001b[0m \u001b[0;31m# in the engines. This should return a list of strings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2045\u001b[0;31m \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2046\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'ValueError' object has no attribute '_render_traceback_'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m 2045\u001b[0m \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2046\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2047\u001b[0;31m stb = self.InteractiveTB.structured_traceback(etype,\n\u001b[0m\u001b[1;32m 2048\u001b[0m value, tb, tb_offset=tb_offset)\n\u001b[1;32m 2049\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1434\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1435\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1436\u001b[0;31m return FormattedTB.structured_traceback(\n\u001b[0m\u001b[1;32m 1437\u001b[0m self, etype, value, tb, tb_offset, number_of_lines_of_context)\n\u001b[1;32m 1438\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1334\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1335\u001b[0m \u001b[0;31m# Verbose modes need a full traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1336\u001b[0;31m return VerboseTB.structured_traceback(\n\u001b[0m\u001b[1;32m 1337\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb_offset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumber_of_lines_of_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1338\u001b[0m )\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, evalue, etb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1191\u001b[0m \u001b[0;34m\"\"\"Return a nice text document describing the traceback.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1193\u001b[0;31m formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,\n\u001b[0m\u001b[1;32m 1194\u001b[0m tb_offset)\n\u001b[1;32m 1195\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mformat_exception_as_a_whole\u001b[0;34m(self, etype, evalue, etb, number_of_lines_of_context, tb_offset)\u001b[0m\n\u001b[1;32m 1149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1151\u001b[0;31m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_recursion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_etype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1152\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[0mframes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_records\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/miniconda3-latest/envs/torch2/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mfind_recursion\u001b[0;34m(etype, value, records)\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[0;31m# first frame (from in to out) that looks different.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 450\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_recursion_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 451\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 452\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[0;31m# Select filename, lineno, func_name to track frames with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: object of type 'NoneType' has no len()" ] }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "28814e59112a4e41aecf3c9f55868797", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 | MAE/CSI: 4.3457863606748175 | MAE: 3.272869348526001 | CSI: 0.7531132634908084 | Loss: 0.0015471124788746238\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7dcafedcbc8f4c388e4afb2fd60656d8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 2 | MAE/CSI: 4.294591652152133 | MAE: 3.2539639472961426 | CSI: 0.7576887887958337 | Loss: 0.0012963797198608518\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0ccf9265046e47aba7705ead37cbfd59", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 3 | MAE/CSI: 4.154932101723256 | MAE: 3.154283046722412 | CSI: 0.7591659669745312 | Loss: 0.0012544452911242843\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "59323aa0540e439497bc330e2dcb9853", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 4 | MAE/CSI: 3.7860146000298136 | MAE: 2.913952112197876 | CSI: 0.7696621434505677 | Loss: 0.00123176712077111\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f4b933dd0b7d449f92ac36598480e89d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 5 | MAE/CSI: 4.065141708011735 | MAE: 3.1184253692626953 | CSI: 0.7671135702631766 | Loss: 0.0012020657304674387\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "88b7d3e1305d433a82c59d40c1c44aaf", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 6 | MAE/CSI: 3.61714756904823 | MAE: 2.808803081512451 | CSI: 0.7765243269430411 | Loss: 0.0011883970582857728\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c6dabd0d340242c7a0a5db8148f42146", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 7 | MAE/CSI: 3.7737966527452103 | MAE: 2.9232699871063232 | CSI: 0.7746230801747217 | Loss: 0.0011574724921956658\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d916876fa47c4325a4657159d63a7208", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 8 | MAE/CSI: 3.5785213933996576 | MAE: 2.792635440826416 | CSI: 0.7803880803880804 | Loss: 0.0011816049227491021\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1791743d12bb4347a3cce079cd3f3087", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 9 | MAE/CSI: 3.7593860248772217 | MAE: 2.9188804626464844 | CSI: 0.776424778761062 | Loss: 0.0011313384165987372\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "715c2eea0fe445eab910394a3b404f1e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 10 | MAE/CSI: 3.537799221788761 | MAE: 2.764934539794922 | CSI: 0.7815408298929396 | Loss: 0.0011271697003394365\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "760658f5ad6041ddad9f48459af02785", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 11 | MAE/CSI: 4.176114563603725 | MAE: 3.2004687786102295 | CSI: 0.7663747557356879 | Loss: 0.0011378307826817036\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f29d82d17c8d4fa6a024244d1ed3e6bc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 12 | MAE/CSI: 4.033352344612596 | MAE: 3.1121833324432373 | CSI: 0.7716120652330783 | Loss: 0.001107610878534615\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a5cb37e8678f4ff8b5aace45880d6f6d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 13 | MAE/CSI: 3.6702073925815 | MAE: 2.86871600151062 | CSI: 0.781622315759435 | Loss: 0.0010806667851284146\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8b64e55a82ee4ea7a9ed872f7b380b40", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 14 | MAE/CSI: 3.359812260798782 | MAE: 2.65893816947937 | CSI: 0.7913948646773075 | Loss: 0.001116615254431963\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ccbe2341ad044ae4bfc6145827a6299b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 15 | MAE/CSI: 3.777346978087415 | MAE: 2.942171096801758 | CSI: 0.7788988181031997 | Loss: 0.0010497045004740357\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e21184f2cd1e41bab275803fb27a24b0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 16 | MAE/CSI: 3.730972666019791 | MAE: 2.9134128093719482 | CSI: 0.7808721934369602 | Loss: 0.0010338622378185391\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c0625e1c518f44c494e68a8396678b85", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 17 | MAE/CSI: 3.216006483334444 | MAE: 2.5546412467956543 | CSI: 0.7943520201314134 | Loss: 0.0010403376072645187\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "eba978a0c17a43db92d23b8b9ae5b7e2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 18 | MAE/CSI: 3.6045317048942183 | MAE: 2.8333561420440674 | CSI: 0.786053882725832 | Loss: 0.0010198085801675916\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4a3e4bf564d8412f9df3a8deed8cf7bf", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 19 | MAE/CSI: 3.2532478057212075 | MAE: 2.5799641609191895 | CSI: 0.7930426192492238 | Loss: 0.0010090331779792905\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2dd891eb85c74f7da253fdcbb93d7274", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 20 | MAE/CSI: 3.2199748341813246 | MAE: 2.5563158988952637 | CSI: 0.7938931297709924 | Loss: 0.0010016037849709392\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1284e89674bd47c681dd0cd8238f35f0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 21 | MAE/CSI: 3.856811757635469 | MAE: 3.0003974437713623 | CSI: 0.7779475982532751 | Loss: 0.0010271386709064245\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "aa73c2927a764e3ea499fa4100beb4e0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 22 | MAE/CSI: 3.1347424155752437 | MAE: 2.4959778785705566 | CSI: 0.7962306140899923 | Loss: 0.0010296452092006803\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6fa65efe55b349348c5cd3d08b467a7c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 23 | MAE/CSI: 3.1336774350690413 | MAE: 2.504286289215088 | CSI: 0.7991525423728814 | Loss: 0.000995357520878315\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2f4c9935ab584c90a7392f816af6f22a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 24 | MAE/CSI: 3.155359118626466 | MAE: 2.518603563308716 | CSI: 0.7981987053194484 | Loss: 0.0009966425132006407\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c4f0ad90a664479795c0d0e62c41d0ba", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 25 | MAE/CSI: 3.2987531262931515 | MAE: 2.6152637004852295 | CSI: 0.7928037050231564 | Loss: 0.0009848386980593204\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7015f2336b984321bf0a58cec1c4bb6a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 26 | MAE/CSI: 3.3408334380625853 | MAE: 2.6483819484710693 | CSI: 0.7927309150747657 | Loss: 0.0009853563969954848\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2164a72dd4b143d7a446d2f06a9cbd2a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 27 | MAE/CSI: 3.151360169059713 | MAE: 2.5201938152313232 | CSI: 0.7997162114224903 | Loss: 0.0009801144478842616\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "12b9d6e90f2141439f7fc488135cc960", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 28 | MAE/CSI: 3.383971888869794 | MAE: 2.6815693378448486 | CSI: 0.7924325100516945 | Loss: 0.0009856465039774776\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "68ba2d0300124d629e54bf417cede96e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 29 | MAE/CSI: 3.3458963516517897 | MAE: 2.650679349899292 | CSI: 0.7922180101566412 | Loss: 0.0009886184707283974\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fa96e02187c74c6b90607139f484a227", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 30 | MAE/CSI: 3.540172589897291 | MAE: 2.7885289192199707 | CSI: 0.7876816308826718 | Loss: 0.00101062364410609\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "34ece3d4e7b84afab3d7df74e8739769", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 31 | MAE/CSI: 3.1403178747400826 | MAE: 2.5151991844177246 | CSI: 0.8009377664109122 | Loss: 0.0009796030353754759\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2b09f7630d0145ed925db402091ef5bb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 32 | MAE/CSI: 3.137336144450593 | MAE: 2.5105366706848145 | CSI: 0.8002128414331323 | Loss: 0.0009821136482059956\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8e92d0242e3b4f948b997fd189404b39", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 33 | MAE/CSI: 3.22167231289145 | MAE: 2.5630111694335938 | CSI: 0.7955530297648646 | Loss: 0.0009937712457031012\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "59783284006f494f8795590cc98023d8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 34 | MAE/CSI: 3.107360140259587 | MAE: 2.48958420753479 | CSI: 0.8011894647408666 | Loss: 0.000986725091934204\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0a71c0dd88884b45a12eaa2171496784", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 35 | MAE/CSI: 3.3972837724475897 | MAE: 2.691316843032837 | CSI: 0.7921966557095899 | Loss: 0.0010110668372362852\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7d2929d4f8714b9b8b499cd012651aae", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 36 | MAE/CSI: 3.2956298785612326 | MAE: 2.6264376640319824 | CSI: 0.7969455796945579 | Loss: 0.000985836493782699\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0b89cff3e84348d286a28c8709925c5f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 37 | MAE/CSI: 3.276042547174778 | MAE: 2.6103878021240234 | CSI: 0.7968113248016014 | Loss: 0.0009991289116442204\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "60df573ae9ba4dd9a1220dfbb3f8867d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 38 | MAE/CSI: 3.282018963884671 | MAE: 2.615131139755249 | CSI: 0.7968056152413694 | Loss: 0.0009931615786626935\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1c59c50b13f84856924548b4f314dffa", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 39 | MAE/CSI: 3.1812205203792945 | MAE: 2.543889284133911 | CSI: 0.7996582656984195 | Loss: 0.000992935849353671\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "18c250afe9b04510ad40667227516315", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 40 | MAE/CSI: 3.307319243115345 | MAE: 2.6357016563415527 | CSI: 0.7969299189441217 | Loss: 0.0009866819018498063\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "308babf388cb45ef9d49af0814051083", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 41 | MAE/CSI: 3.143233505972508 | MAE: 2.513828754425049 | CSI: 0.7997588310398638 | Loss: 0.0009999492904171348\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2bcfdb91f482422a890515f3da45b664", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 42 | MAE/CSI: 3.215217112553536 | MAE: 2.5696022510528564 | CSI: 0.7992002285061411 | Loss: 0.0009870363865047693\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "93459b4781074b5ba8b42980278ce58b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 43 | MAE/CSI: 3.21669262308394 | MAE: 2.570596933364868 | CSI: 0.7991428571428572 | Loss: 0.000990581582300365\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3adc190f0616417d86275da9b809f646", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 44 | MAE/CSI: 3.2109438338309406 | MAE: 2.5652711391448975 | CSI: 0.7989149832250696 | Loss: 0.0009919735603034496\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6945e2154e664154b7068118a5bf35d8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 45 | MAE/CSI: 3.2196984066052785 | MAE: 2.5702412128448486 | CSI: 0.7982863263120314 | Loss: 0.000992590212263167\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "257d69cb8de6424c8ecc42bf46d3c3d1", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 46 | MAE/CSI: 3.206927726907789 | MAE: 2.562476634979248 | CSI: 0.799044086174918 | Loss: 0.000992824207060039\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5c2fcaabbb7f49c5979b4a849786956d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 47 | MAE/CSI: 3.2243089758598296 | MAE: 2.5741519927978516 | CSI: 0.7983577293823635 | Loss: 0.0009943352779373527\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e5dd68f551ae468b9a99f98d15a6bb34", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 48 | MAE/CSI: 3.167185464203879 | MAE: 2.532801866531372 | CSI: 0.7997011526967411 | Loss: 0.0009948504157364368\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "197e1de0e4e14b36bd3db43ff9fd0127", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 49 | MAE/CSI: 3.2167388425628727 | MAE: 2.569028615951538 | CSI: 0.7986438258386866 | Loss: 0.0009935392299667\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True, used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "Using native 16bit precision.\n", - "\n", - " | Name | Type | Params\n", - "-------------------------------------------\n", - "0 | criterion | SmoothL1Loss | 0 \n", - "1 | tail | BasicBlock | 300 \n", - "2 | encoder | Encoder | 25 M \n", - "3 | decoder | Decoder | 17 M \n", - "4 | head | Sequential | 8 K \n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cba4f62f6f904b18ba1da5a3be427e92", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0 | MAE/CSI: 14487452507019.043 | MAE: 14.487452507019043 | CSI: 0.0 | Loss: 0.009541328065097332\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cd67491c697d48e2af8b00afab7187ed", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "13880b180a4842cab1576d07ae699be4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0 | MAE/CSI: 5.392028101629733 | MAE: 3.834322214126587 | CSI: 0.7111094641666049 | Loss: 0.0021661436185240746\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "457a326ec0d3427bb8254ccf00da3dc2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 | MAE/CSI: 4.158585763428692 | MAE: 3.119844675064087 | CSI: 0.7502177068214804 | Loss: 0.0013823203044012189\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "dac1a9d6360c4a10a0872f8e6adb6547", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 2 | MAE/CSI: 4.3142486810508425 | MAE: 3.232961654663086 | CSI: 0.749368405409422 | Loss: 0.0013024784857407212\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "46f1a7778a614768b4ecea9512ea2b58", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 3 | MAE/CSI: 3.7116308049028257 | MAE: 2.8598763942718506 | CSI: 0.770517474553349 | Loss: 0.0013511937577277422\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "949428eee1d84a8385ba5183ceafc8f1", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 4 | MAE/CSI: 3.539873240049237 | MAE: 2.7341103553771973 | CSI: 0.7723752151462995 | Loss: 0.001252486719749868\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1c078b167e534b7488e0075077bca609", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 5 | MAE/CSI: 4.01579733783149 | MAE: 3.0554652214050293 | CSI: 0.7608614091693554 | Loss: 0.0011789751006290317\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d0e33790c0fa48318eb793df1cfda842", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 6 | MAE/CSI: 3.7320073356218004 | MAE: 2.864356517791748 | CSI: 0.7675109559533536 | Loss: 0.0011457751970738173\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e5c2cb15737a46159792932e62f99ed3", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 7 | MAE/CSI: 3.401637135420141 | MAE: 2.644503355026245 | CSI: 0.7774207682196582 | Loss: 0.0011405773693695664\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "242db299e4c44b75917ba09747ca309b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 8 | MAE/CSI: 4.595668000327307 | MAE: 3.438323736190796 | CSI: 0.7481662591687042 | Loss: 0.0012028244091197848\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "26a303f530c14d31977eb3f5f5bcd121", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 9 | MAE/CSI: 3.4759753314680486 | MAE: 2.6975338459014893 | CSI: 0.7760509177027827 | Loss: 0.0010866763768717647\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7999e78050774bf884f692794676a56f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 10 | MAE/CSI: 3.2317734726219225 | MAE: 2.534553289413452 | CSI: 0.7842608124863248 | Loss: 0.0011001526145264506\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e2f5b41d276640caae430f70230123cb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 11 | MAE/CSI: 3.2738988791916976 | MAE: 2.5705955028533936 | CSI: 0.7851786501985002 | Loss: 0.0010762200690805912\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "91a8b1a8cdab4d17b7ef8b29bfc06168", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 12 | MAE/CSI: 3.2065835705106345 | MAE: 2.5200467109680176 | CSI: 0.785897718101108 | Loss: 0.001055945991538465\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d9216bc069204c84b59672bfd1435417", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 13 | MAE/CSI: 3.6612566362170322 | MAE: 2.834601640701294 | CSI: 0.7742155009451795 | Loss: 0.00104250549338758\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3476cd7419c54969972471d9e7cad7e5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 14 | MAE/CSI: 3.0992773782456213 | MAE: 2.4434335231781006 | CSI: 0.7883881385789783 | Loss: 0.0010305154137313366\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6d679fa8b6b04f328bed50f267c8bac8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 15 | MAE/CSI: 3.639291737971471 | MAE: 2.8226284980773926 | CSI: 0.7755983035443805 | Loss: 0.0010273642838001251\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c19f4005ca2040ce8cfa369ac57a741a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 16 | MAE/CSI: 3.2597624066956175 | MAE: 2.5623531341552734 | CSI: 0.786055182699478 | Loss: 0.0010038955369964242\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "669a4a35a01f486ba3e9db002921c06f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 17 | MAE/CSI: 3.2570115885248803 | MAE: 2.5611789226531982 | CSI: 0.7863585538576221 | Loss: 0.0010010426631197333\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d66ba08e376046c691404001699f113b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 18 | MAE/CSI: 3.406129191555971 | MAE: 2.6659343242645264 | CSI: 0.7826873774694616 | Loss: 0.000992716639302671\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "756646755ab844ae8afad07eb4c7bffa", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 19 | MAE/CSI: 3.144901228592783 | MAE: 2.488124370574951 | CSI: 0.791161371921732 | Loss: 0.0009813575306907296\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0a0a28cb8de04035925733f5aaae64f5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 20 | MAE/CSI: 3.0428135418560394 | MAE: 2.419461727142334 | CSI: 0.795139660665333 | Loss: 0.0009844209998846054\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3bd7d95a09464e2c93e4bbc7b127268f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 21 | MAE/CSI: 3.4147496972918714 | MAE: 2.6745433807373047 | CSI: 0.7832326283987915 | Loss: 0.0009928239742293954\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "bae234623e734a51a4de2d9c169d7de3", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 22 | MAE/CSI: 3.465787560107332 | MAE: 2.704352617263794 | CSI: 0.7802995914661824 | Loss: 0.0009944615885615349\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b68df5562ca34f53b3a5caab0e14109f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 23 | MAE/CSI: 3.4667651828205974 | MAE: 2.7036044597625732 | CSI: 0.7798637395912188 | Loss: 0.0010142156388610601\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "bff9e27a8ee0441a89a9d15d62fc6a79", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 24 | MAE/CSI: 3.3635030555243413 | MAE: 2.640925884246826 | CSI: 0.7851712457659014 | Loss: 0.0009805120062083006\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f3ea25fda8674b48839b4671423dd49e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 25 | MAE/CSI: 3.0347875782549485 | MAE: 2.413892984390259 | CSI: 0.7954075605434141 | Loss: 0.0009923680918291211\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "84d3e7cdc49a4b8090b60914bd297972", - "version_major": 2, - "version_minor": 0 - }, + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 22;\n", + " var nbb_unformatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\\n\\ndf = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n\\nfor fold in range(3, 5):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(\\n train_df, val_df, batch_size=args[\\\"batch_size\\\"], num_workers=args[\\\"num_workers\\\"]\\n )\\n datamodule.setup()\\n\\n num_train_steps = (\\n int(\\n np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n )\\n * args[\\\"max_epochs\\\"]\\n )\\n\\n model = UNet(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n # accumulate_grad_batches=args[\\\"accumulate_grad_batches\\\"],\\n # gradient_clip_val=args[\\\"gradient_clip_val\\\"],\\n auto_lr_find=True,\\n )\\n\\n # learning rate finder\\n # lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\\n # fig = lr_finder.plot(suggest=True)\\n # fig.show()\\n\\n trainer.fit(model, datamodule)\\n trainer.save_checkpoint(\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n\\n del datamodule, model, trainer\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", + " var nbb_formatted_code = \"seed_everything(args[\\\"seed\\\"])\\npl.seed_everything(args[\\\"seed\\\"])\\n\\ndf = pd.read_csv(args[\\\"train_folds_csv\\\"])\\n\\nfor fold in range(3, 5):\\n train_df = df[df.fold != fold]\\n val_df = df[df.fold == fold]\\n\\n datamodule = NowcastingDataModule(\\n train_df, val_df, batch_size=args[\\\"batch_size\\\"], num_workers=args[\\\"num_workers\\\"]\\n )\\n datamodule.setup()\\n\\n num_train_steps = (\\n int(\\n np.ceil(\\n len(train_df) // args[\\\"batch_size\\\"] / args[\\\"accumulate_grad_batches\\\"]\\n )\\n )\\n * args[\\\"max_epochs\\\"]\\n )\\n\\n model = UNet(num_train_steps=num_train_steps)\\n\\n trainer = pl.Trainer(\\n gpus=args[\\\"gpus\\\"],\\n max_epochs=args[\\\"max_epochs\\\"],\\n precision=args[\\\"precision\\\"],\\n progress_bar_refresh_rate=50,\\n # accumulate_grad_batches=args[\\\"accumulate_grad_batches\\\"],\\n # gradient_clip_val=args[\\\"gradient_clip_val\\\"],\\n auto_lr_find=True,\\n )\\n\\n # learning rate finder\\n # lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\\n # fig = lr_finder.plot(suggest=True)\\n # fig.show()\\n\\n trainer.fit(model, datamodule)\\n trainer.save_checkpoint(\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n\\n del datamodule, model, trainer\\n gc.collect()\\n torch.cuda.empty_cache()\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + "" ] }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, + } + ], + "source": [ + "seed_everything(args[\"seed\"])\n", + "pl.seed_everything(args[\"seed\"])\n", + "\n", + "df = pd.read_csv(args[\"train_folds_csv\"])\n", + "\n", + "for fold in range(3, 5):\n", + " train_df = df[df.fold != fold]\n", + " val_df = df[df.fold == fold]\n", + "\n", + " datamodule = NowcastingDataModule(\n", + " train_df, val_df, batch_size=args[\"batch_size\"], num_workers=args[\"num_workers\"]\n", + " )\n", + " datamodule.setup()\n", + "\n", + " num_train_steps = (\n", + " int(\n", + " np.ceil(\n", + " len(train_df) // args[\"batch_size\"] / args[\"accumulate_grad_batches\"]\n", + " )\n", + " )\n", + " * args[\"max_epochs\"]\n", + " )\n", + "\n", + " model = UNet(num_train_steps=num_train_steps)\n", + "\n", + " trainer = pl.Trainer(\n", + " gpus=args[\"gpus\"],\n", + " max_epochs=args[\"max_epochs\"],\n", + " precision=args[\"precision\"],\n", + " progress_bar_refresh_rate=50,\n", + " # accumulate_grad_batches=args[\"accumulate_grad_batches\"],\n", + " # gradient_clip_val=args[\"gradient_clip_val\"],\n", + " auto_lr_find=True,\n", + " )\n", + "\n", + " # learning rate finder\n", + " # lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\n", + " # fig = lr_finder.plot(suggest=True)\n", + " # fig.show()\n", + "\n", + " trainer.fit(model, datamodule)\n", + " trainer.save_checkpoint(\n", + " args[\"model_dir\"]\n", + " / f\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\"\n", + " )\n", + "\n", + " del datamodule, model, trainer\n", + " gc.collect()\n", + " torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 26 | MAE/CSI: 3.1692318757620845 | MAE: 2.50390887260437 | CSI: 0.7900680577368933 | Loss: 0.0009772846242412925\n" + "../models/unet_fold0_bs256_epoch50_adamw_cosine.ckpt\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9f43d086e804488a87511eccdb538a0a", + "model_id": "209aa235815347a9b2522aba0e43f37c", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" ] }, "metadata": {}, @@ -3033,55 +1258,41 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 27 | MAE/CSI: 3.2020519076581904 | MAE: 2.5177001953125 | CSI: 0.7862771335117454 | Loss: 0.000999476877041161\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" + "\n", + "../models/unet_fold1_bs256_epoch50_adamw_cosine.ckpt\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c96c437af5804b3fb3fb7be380b7eaa5", + "model_id": "10755cf46d21468e8109a77e1f278596", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 28 | MAE/CSI: 3.0848241435457164 | MAE: 2.443657398223877 | CSI: 0.7921545230815125 | Loss: 0.0009700111113488674\n" + "\n", + "../models/unet_fold2_bs256_epoch50_adamw_cosine.ckpt\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "82fc8df75f8740fd810e194ca9aa5f43", + "model_id": "bf62d681150647e5b1397731d1484a54", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" ] }, "metadata": {}, @@ -3091,26 +1302,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 29 | MAE/CSI: 3.1461110961318504 | MAE: 2.488356351852417 | CSI: 0.7909308590242442 | Loss: 0.0009753701160661876\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" + "\n", + "../models/unet_fold3_bs256_epoch50_adamw_cosine.ckpt\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "32e8d8235b2a4cceb680173996d837d6", + "model_id": "687733fd26c745d8a01ff94af632ffeb", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" ] }, "metadata": {}, @@ -3120,18 +1324,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 30 | MAE/CSI: 3.1462543540290406 | MAE: 2.4851224422454834 | CSI: 0.7898669855029143 | Loss: 0.0009774373611435294\n" + "\n", + "../models/unet_fold4_bs256_epoch50_adamw_cosine.ckpt\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d51d498f3f8d4c2c9ce24c210d226c1b", + "model_id": "350aa5098fd3423aa7487492fcf1ec48", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))" ] }, "metadata": {}, @@ -3141,108 +1346,37 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 31 | MAE/CSI: 3.0840243109285304 | MAE: 2.4406659603118896 | CSI: 0.7913899873162725 | Loss: 0.0009708466241136193\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" + "\n" ] }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "760cb02f21294ee6aba5baaa19e8337b", - "version_major": 2, - "version_minor": 0 - }, + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 25;\n", + " var nbb_unformatted_code = \"datamodule = NowcastingDataModule()\\ndatamodule.setup(\\\"test\\\")\\n\\nfinal_preds = np.zeros((len(datamodule.test_dataset), 120, 120))\\n\\nfor fold in range(5):\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n print(checkpoint)\\n model = UNet.load_from_checkpoint(str(checkpoint))\\n model.cuda()\\n model.eval()\\n preds = []\\n with torch.no_grad():\\n for batch in tqdm(datamodule.test_dataloader()):\\n batch = batch.cuda()\\n imgs = model(batch)\\n imgs = imgs.detach().cpu().numpy()\\n imgs = imgs[:, 0, 4:124, 4:124]\\n imgs = args[\\\"rng\\\"] * imgs\\n imgs = imgs.clip(0, 255)\\n imgs = imgs.round()\\n preds.append(imgs)\\n\\n preds = np.concatenate(preds)\\n preds = preds.astype(np.uint8)\\n final_preds += preds\\n\\n del model\\n gc.collect()\\n torch.cuda.empty_cache()\\n\\nfinal_preds = final_preds.astype(np.uint8)\\nfinal_preds = final_preds.reshape(-1, 14400)\";\n", + " var nbb_formatted_code = \"datamodule = NowcastingDataModule()\\ndatamodule.setup(\\\"test\\\")\\n\\nfinal_preds = np.zeros((len(datamodule.test_dataset), 120, 120))\\n\\nfor fold in range(5):\\n checkpoint = (\\n args[\\\"model_dir\\\"]\\n / f\\\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\\\"\\n )\\n print(checkpoint)\\n model = UNet.load_from_checkpoint(str(checkpoint))\\n model.cuda()\\n model.eval()\\n preds = []\\n with torch.no_grad():\\n for batch in tqdm(datamodule.test_dataloader()):\\n batch = batch.cuda()\\n imgs = model(batch)\\n imgs = imgs.detach().cpu().numpy()\\n imgs = imgs[:, 0, 4:124, 4:124]\\n imgs = args[\\\"rng\\\"] * imgs\\n imgs = imgs.clip(0, 255)\\n imgs = imgs.round()\\n preds.append(imgs)\\n\\n preds = np.concatenate(preds)\\n preds = preds.astype(np.uint8)\\n final_preds += preds\\n\\n del model\\n gc.collect()\\n torch.cuda.empty_cache()\\n\\nfinal_preds = final_preds.astype(np.uint8)\\nfinal_preds = final_preds.reshape(-1, 14400)\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], "text/plain": [ - "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" + "" ] }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 32 | MAE/CSI: 3.122035965640138 | MAE: 2.4696598052978516 | CSI: 0.7910414333706607 | Loss: 0.0009853055234998465\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../src/utils.py:36: RuntimeWarning: overflow encountered in power\n", - " z = np.power(10.0, dbz / 10.0)\n" - ] } ], - "source": [ - "seed_everything(args[\"seed\"])\n", - "pl.seed_everything(args[\"seed\"])\n", - "\n", - "df = pd.read_csv(args[\"train_folds_csv\"])\n", - "\n", - "for fold in range(5):\n", - " train_df = df[df.fold != fold]\n", - " val_df = df[df.fold == fold]\n", - "\n", - " datamodule = NowcastingDataModule(\n", - " train_df, val_df, batch_size=args[\"batch_size\"], num_workers=args[\"num_workers\"]\n", - " )\n", - " datamodule.setup()\n", - "\n", - " num_train_steps = (\n", - " int(\n", - " np.ceil(\n", - " len(train_df) // args[\"batch_size\"] / args[\"accumulate_grad_batches\"]\n", - " )\n", - " )\n", - " * args[\"max_epochs\"]\n", - " )\n", - "\n", - " model = UNet(num_train_steps=num_train_steps)\n", - "\n", - " trainer = pl.Trainer(\n", - " gpus=args[\"gpus\"],\n", - " max_epochs=args[\"max_epochs\"],\n", - " precision=args[\"precision\"],\n", - " progress_bar_refresh_rate=50,\n", - " # accumulate_grad_batches=args[\"accumulate_grad_batches\"],\n", - " # gradient_clip_val=args[\"gradient_clip_val\"],\n", - " auto_lr_find=True,\n", - " )\n", - "\n", - " # learning rate finder\n", - " # lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)\n", - " # fig = lr_finder.plot(suggest=True)\n", - " # fig.show()\n", - "\n", - " trainer.fit(model, datamodule)\n", - " trainer.save_checkpoint(f\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}.ckpt\")\n", - "\n", - " del datamodule, model, trainer\n", - " gc.collect()\n", - " torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Inference" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "datamodule = NowcastingDataModule()\n", "datamodule.setup(\"test\")\n", @@ -3250,7 +1384,12 @@ "final_preds = np.zeros((len(datamodule.test_dataset), 120, 120))\n", "\n", "for fold in range(5):\n", - " model = UNet.load_from_checkpoint(f\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}\")\n", + " checkpoint = (\n", + " args[\"model_dir\"]\n", + " / f\"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.ckpt\"\n", + " )\n", + " print(checkpoint)\n", + " model = UNet.load_from_checkpoint(str(checkpoint))\n", " model.cuda()\n", " model.eval()\n", " preds = []\n", @@ -3268,19 +1407,48 @@ " preds = np.concatenate(preds)\n", " preds = preds.astype(np.uint8)\n", " final_preds += preds\n", - " \n", + "\n", " del model\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", - " \n", + "\n", + "final_preds = final_preds.astype(np.uint8)\n", "final_preds = final_preds.reshape(-1, 14400)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 26;\n", + " var nbb_unformatted_code = \"test_paths = datamodule.test_dataset.paths\\ntest_filenames = [path.name for path in test_paths]\";\n", + " var nbb_formatted_code = \"test_paths = datamodule.test_dataset.paths\\ntest_filenames = [path.name for path in test_paths]\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "test_paths = datamodule.test_dataset.paths\n", "test_filenames = [path.name for path in test_paths]" @@ -3288,22 +1456,294 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e2360942148d490c8fd59244a4e174f3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14400.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 27;\n", + " var nbb_unformatted_code = \"subm = pd.DataFrame({\\\"file_name\\\": test_filenames})\\nfor i in tqdm(range(14400)):\\n subm[str(i)] = final_preds[:, i]\";\n", + " var nbb_formatted_code = \"subm = pd.DataFrame({\\\"file_name\\\": test_filenames})\\nfor i in tqdm(range(14400)):\\n subm[str(i)] = final_preds[:, i]\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "subm = pd.DataFrame({\"file_name\"}: test_filenames)\n", + "subm = pd.DataFrame({\"file_name\": test_filenames})\n", "for i in tqdm(range(14400)):\n", " subm[str(i)] = final_preds[:, i]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
file_name012345678...14390143911439214393143941439514396143971439814399
0test_00000.npy000000000...0000000000
1test_00001.npy000000000...0000000000
2test_00002.npy000000000...0000000000
3test_00003.npy000000000...0000000000
4test_00004.npy000000000...0000000000
\n", + "

5 rows × 14401 columns

\n", + "
" + ], + "text/plain": [ + " file_name 0 1 2 3 4 5 6 7 8 ... 14390 14391 14392 14393 \\\n", + "0 test_00000.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "1 test_00001.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "2 test_00002.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "3 test_00003.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "4 test_00004.npy 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 \n", + "\n", + " 14394 14395 14396 14397 14398 14399 \n", + "0 0 0 0 0 0 0 \n", + "1 0 0 0 0 0 0 \n", + "2 0 0 0 0 0 0 \n", + "3 0 0 0 0 0 0 \n", + "4 0 0 0 0 0 0 \n", + "\n", + "[5 rows x 14401 columns]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/javascript": [ + "\n", + " setTimeout(function() {\n", + " var nbb_cell_id = 28;\n", + " var nbb_unformatted_code = \"subm.to_csv(\\n f\\\"unet_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.csv\\\",\\n index=False,\\n)\\nsubm.head()\";\n", + " var nbb_formatted_code = \"subm.to_csv(\\n f\\\"unet_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.csv\\\",\\n index=False,\\n)\\nsubm.head()\";\n", + " var nbb_cells = Jupyter.notebook.get_cells();\n", + " for (var i = 0; i < nbb_cells.length; ++i) {\n", + " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", + " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", + " nbb_cells[i].set_text(nbb_formatted_code);\n", + " }\n", + " break;\n", + " }\n", + " }\n", + " }, 500);\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "subm.to_csv(f\"unet_bs{args['batch_size']}_epoch{args['max_epochs']}_lr{model.lr}.csv\", index=False)\n", + "subm.to_csv(\n", + " f\"unet_bs{args['batch_size']}_epoch{args['max_epochs']}_{args['optimizer']}_{args['scheduler']}.csv\",\n", + " index=False,\n", + ")\n", "subm.head()" ] }, @@ -3363,13 +1803,6 @@ "outputs": [], "source": [] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, diff --git a/src/utils.py b/src/utils.py index 2a776cd..4baad5b 100644 --- a/src/utils.py +++ b/src/utils.py @@ -33,6 +33,11 @@ def visualize(x, y=None, test=False): def radar2precipitation(radar): """Convert radar to precipitation.""" dbz = ((radar - 0.5) / 255.0) * 70 - 10 + + # Numerically stable + dbz_max = np.max(dbz) + dbz -= dbz_max z = np.power(10.0, dbz / 10.0) + z *= np.power(10.0, dbz_max / 10.0) r = np.power(z / 200.0, 1.0 / 1.6) return r