diff --git a/04_pytorch_custom_datasets.ipynb b/04_pytorch_custom_datasets.ipynb new file mode 100644 index 0000000..2b6385f --- /dev/null +++ b/04_pytorch_custom_datasets.ipynb @@ -0,0 +1,243 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4", + "authorship_tag": "ABX9TyOj1u8Pwdu+m8kZy1VQC8bj", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Importing PyTorch and setting up device-agnostic code" + ], + "metadata": { + "id": "9OSpxyEwSfAI" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torch import nn\n", + "\n", + "# Note: this notebook requires torch >= 1.10.0\n", + "torch.__version__" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "id": "XHVVGzhSS3XX", + "outputId": "39b8c0c0-6403-430a-b2a4-5766337ff800" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'2.1.0+cu121'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 1 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Setup device-agnostic code\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "device" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "id": "mrpx2f3JS3U6", + "outputId": "c15cb9ae-35de-41a8-8c0b-3c0ba381746a" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'cuda'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 2 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Returning Dataset\n", + "\n", + "* Original Food101 dataset and paper website.\n", + "* `torchvision.datasets.Food101` - the version of the data I downloaded for this notebook.\n", + "* `extras/04_custom_data_creation.ipynb` - a notebook I used to format the Food101 dataset to use for this notebook.\n", + "* `data/pizza_steak_sushi.zip` - the zip archive of pizza, steak and sushi images from Food101, created with the notebook linked above.\n", + "\n" + ], + "metadata": { + "id": "vVurMo0LUHC2" + } + }, + { + "cell_type": "code", + "source": [ + "import requests\n", + "import zipfile\n", + "from pathlib import Path\n", + "\n", + "# Setup path to data folder\n", + "data_path = Path(\"data/\")\n", + "image_path = data_path / \"pizza_steak_sushi\"\n", + "\n", + "# If the image folder doesn't exist, download it and prepare it...\n", + "if image_path.is_dir():\n", + " print(f\"{image_path} directory exists.\")\n", + "else:\n", + " print(f\"Did not find {image_path} directory, creating one...\")\n", + " image_path.mkdir(parents=True, exist_ok=True)\n", + "\n", + " # Download pizza, steak, sushi data\n", + " with open(data_path / \"pizza_steak_sushi.zip\", \"wb\") as f:\n", + " request = requests.get(\"https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip\")\n", + " print(\"Downloading pizza, steak, sushi data...\")\n", + " f.write(request.content)\n", + "\n", + " # Unzip pizza, steak, sushi data\n", + " with zipfile.ZipFile(data_path / \"pizza_steak_sushi.zip\", \"r\") as zip_ref:\n", + " print(\"Unzipping pizza, steak, sushi data...\")\n", + " zip_ref.extractall(image_path)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xW1aHd0_S3Sr", + "outputId": "7ac0cd74-e9ee-4176-a51b-5d7c2cb6bf38" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Did not find data/pizza_steak_sushi directory, creating one...\n", + "Downloading pizza, steak, sushi data...\n", + "Unzipping pizza, steak, sushi data...\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "RxuGixriS3Qk" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "sd3BoVdBS3OT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "5Y1cD1FoS3L7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "TFlZtM0kS3Jq" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "qUwEz4N9S3Hr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "EDbW3ZrSS3Fr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "r4JRci5uS3Dp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "F7uTEDrgS3Bj" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file