Skip to content

Commit

Permalink
suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rijobro committed Dec 10, 2020
1 parent e932a88 commit 0b5096e
Showing 1 changed file with 60 additions and 24 deletions.
84 changes: 60 additions & 24 deletions modules/CAM/class_lung_lesion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,28 +46,50 @@
"id": "nZYowi8fUVp5"
},
"source": [
"import os\n",
"import glob\n",
"import random\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import monai\n",
"from monai.transforms import (\n",
" AddChanneld, Compose, LoadImaged, RandRotate90d, \n",
" Resized, ScaleIntensityRanged, ToTensord, \n",
" RandFlipd, RandSpatialCropd,\n",
")\n",
"from monai.metrics import compute_occlusion_sensitivity\n",
"\n",
"monai.config.print_config()\n",
"monai.utils.set_determinism(42)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"random_seed = 42\n",
"monai.utils.set_determinism(random_seed)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"use_patch_dataset=True # switch this to use partial dataset or whole thing\n",
"\n",
"directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
"root_dir = tempfile.mkdtemp() if directory is None else os.path.expanduser(directory)\n",
"data_path=os.path.join(root_dir, \"patch\")\n",
"if use_patch_dataset:\n",
" !cd {root_dir} && [ -f \"lung_lesion_patches.tar.gz\" ] || gdown --id 1Jte6L7B_5q9XMgOCAq1Ldn29F1aaIxjW \\\n",
" && mkdir -p {data_path} && tar -xvf \"lung_lesion_patches.tar.gz\" -C {data_path} > /dev/null\n",
"else:\n",
" # TODO: took too long to download data to fix this.\n",
" monai.apps.DecathlonDataset(root_dir=root_dir, task=\"Task06_Lung\", section=\"training\", download=True)\n",
" %run -i './bbox_gen.py {root_dir}'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YcfTvuxyy9jX"
},
Expand All @@ -85,8 +107,8 @@
"outputId": "773bfa78-4a34-4692-9540-f6bb0babd7d9"
},
"source": [
"lesion = glob.glob(\"./patch/lesion_*\")\n",
"non_lesion = glob.glob(\"./patch/norm_*\")\n",
"lesion = glob.glob(os.path.join(data_path,\"lesion_*\"))\n",
"non_lesion = glob.glob(os.path.join(data_path,\"norm_*\"))\n",
"labels = np.asarray([[0., 1.]] * len(lesion) + [[1., 0.]] * len(non_lesion))\n",
"\n",
"all_files = [{\"image\": img, \"label\": label} for img, label in zip(lesion + non_lesion, labels)]\n",
Expand Down Expand Up @@ -235,12 +257,16 @@
"id": "rA_cp54ebxRv"
},
"source": [
"from IPython.display import clear_output\n",
"\n",
"# start training\n",
"val_interval, total_epochs = 1, 30\n",
"best_metric = best_metric_epoch = -1\n",
"epoch_loss_values = metric_values = []\n",
"epoch_loss_values = []\n",
"metric_values = []\n",
"scaler = torch.cuda.amp.GradScaler()\n",
"for epoch in range(total_epochs):\n",
" clear_output()\n",
" print(\"-\" * 10)\n",
" print(f\"epoch {epoch + 1}/{total_epochs}\")\n",
" model.train()\n",
Expand Down Expand Up @@ -284,15 +310,28 @@
" print(\n",
" \"current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}\".format(\n",
" epoch + 1, metric, best_metric, best_metric_epoch\n",
" )\n",
" )\n",
"print(f\"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
" )\n",
" )\n",
"print(f\"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.plot(epoch_loss_values,label='train')\n",
"val_epochs=np.linspace(1, total_epochs,np.floor(total_epochs/val_interval).astype(np.int32))\n",
"plt.plot(val_epochs, metric_values,label='validation')\n",
"plt.legend()\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iHE0RMEFIera"
},
Expand Down Expand Up @@ -341,10 +380,7 @@
"outputId": "ffcda223-4c1a-42be-a0b6-7f393d06e25a"
},
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"monai.utils.set_determinism(42)\n",
"train_transforms.set_random_state(42)\n",
"train_transforms.set_random_state(random_seed)\n",
"items = [41, 45, 80, 78, 10]\n",
"plt.subplots(2, len(items))\n",
"for idx, item in enumerate(items):\n",
Expand Down

0 comments on commit 0b5096e

Please sign in to comment.