diff --git a/exercise.ipynb b/exercise.ipynb index 5b23448..a8fee33 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -960,7 +960,9 @@ "cell_type": "code", "execution_count": null, "id": "1d678c03", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 0 + }, "outputs": [], "source": [ "# Check if GPU is available\n", @@ -986,7 +988,11 @@ " ),\n", ")\n", "# Launch training and check that loss and images are being logged on tensorboard.\n", - "trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)" + "trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)\n", + "\n", + "# Move the model to the GPU.\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "phase2fluor_model.to(device)" ] }, { @@ -1322,8 +1328,11 @@ "pretrained_phase2fluor = VSUNet.load_from_checkpoint(\n", " pretrained_model_ckpt,\n", " architecture=...,\n", - " module_config=phase2fluor_config,\n", - ")" + " model_config=phase2fluor_config,\n", + " accelerator='gpu'\n", + ")\n", + "# TODO: Setup the dataloader in evaluation/predict mode\n", + "#" ] }, { diff --git a/solution.ipynb b/solution.ipynb index 4238582..56a6908 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -944,7 +944,9 @@ "cell_type": "code", "execution_count": null, "id": "b5539a48", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 0 + }, "outputs": [], "source": [ "# Check if GPU is available\n", @@ -970,7 +972,11 @@ " ),\n", ")\n", "# Launch training and check that loss and images are being logged on tensorboard.\n", - "trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)" + "trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)\n", + "\n", + "# Move the model to the GPU.\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "phase2fluor_model.to(device)" ] }, { @@ -1285,7 +1291,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cb1fcdf", + "id": "a1c5193f", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1317,6 +1323,7 @@ " pretrained_model_ckpt,\n", " architecture=\"UNeXt2_2D\",\n", " model_config = phase2fluor_config,\n", + " accelerator='gpu'\n", ")\n", "pretrained_phase2fluor.eval()\n", "\n",