diff --git a/.readthedocs.yml b/.readthedocs.yml index 6eeec2eb41..bb9ffdbdd8 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -20,4 +20,4 @@ python: version: 3.7 install: - requirements: requirements/docs.txt - #- requirements: requirements.txt \ No newline at end of file + #- requirements: requirements.txt diff --git a/Makefile b/Makefile index 9bcddf4426..6fcee001e6 100644 --- a/Makefile +++ b/Makefile @@ -25,4 +25,4 @@ clean: rm -rf .pytest_cache rm -rf ./docs/build rm -rf ./docs/source/**/generated - rm -rf ./docs/source/api \ No newline at end of file + rm -rf ./docs/source/api diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg index ad94333a72..2c3e330bbf 100644 --- a/docs/source/_static/images/logo.svg +++ b/docs/source/_static/images/logo.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/flash/vision/classification/transforms.py b/flash/vision/classification/transforms.py index 3eff2f4c2c..2406ec9b24 100644 --- a/flash/vision/classification/transforms.py +++ b/flash/vision/classification/transforms.py @@ -37,8 +37,7 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] ), "post_tensor_transform": ApplyToKeys( DefaultDataKeys.INPUT, - # TODO (Edgar): replace with resize once kornia is fixed - K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), + K.geometry.Resize(image_size), K.augmentation.RandomHorizontalFlip(), ), "per_batch_transform_on_device": ApplyToKeys( @@ -70,8 +69,7 @@ def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: ), "post_tensor_transform": ApplyToKeys( DefaultDataKeys.INPUT, - # TODO (Edgar): replace with resize once kornia is fixed - K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), + K.geometry.Resize(image_size), ), "per_batch_transform_on_device": ApplyToKeys( DefaultDataKeys.INPUT, diff --git a/flash_notebooks/custom_task_tutorial.ipynb b/flash_notebooks/custom_task_tutorial.ipynb index 93b214ff62..10c698573e 100644 --- a/flash_notebooks/custom_task_tutorial.ipynb +++ b/flash_notebooks/custom_task_tutorial.ipynb @@ -40,18 +40,21 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Any, List, Tuple\n", + "from typing import Any, List, Tuple, Dict\n", "\n", "import numpy as np\n", "import torch\n", "from pytorch_lightning import seed_everything\n", "from sklearn import datasets\n", "from sklearn.model_selection import train_test_split\n", - "from torch import nn\n", + "from torch import nn, Tensor\n", "\n", "import flash\n", "from flash.data.auto_dataset import AutoDataset\n", - "from flash.data.process import Postprocess, Preprocess" + "from flash.data.data_source import DataSource\n", + "from flash.data.process import Preprocess\n", + "\n", + "ND = np.ndarray" ] }, { @@ -152,24 +155,43 @@ "metadata": {}, "outputs": [], "source": [ - "class NumpyRegressionPreprocess(Preprocess):\n", + "class NumpyDataSource(DataSource):\n", "\n", - " def load_data(self, data: Tuple[np.ndarray, np.ndarray], dataset: AutoDataset) -> List[Tuple[np.ndarray, float]]:\n", + " def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]:\n", " if self.training:\n", " dataset.num_inputs = data[0].shape[1]\n", " return [(x, y) for x, y in zip(*data)]\n", "\n", - " def to_tensor_transform(self, sample: Any) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " def predict_load_data(self, data: ND) -> ND:\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class NumpyPreprocess(Preprocess):\n", + "\n", + " def __init__(self):\n", + " super().__init__(data_sources={\"numpy\": NumpyDataSource()}, default_data_source=\"numpy\")\n", + "\n", + " def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]:\n", " x, y = sample\n", " x = torch.from_numpy(x).float()\n", " y = torch.tensor(y, dtype=torch.float)\n", " return x, y\n", "\n", - " def predict_load_data(self, data: np.ndarray) -> np.ndarray:\n", - " return data\n", + " def predict_to_tensor_transform(self, sample: ND) -> ND:\n", + " return torch.from_numpy(sample).float()\n", + "\n", + " def get_state_dict(self) -> Dict[str, Any]:\n", + " return {}\n", "\n", - " def predict_to_tensor_transform(self, sample: np.ndarray) -> np.ndarray:\n", - " return torch.from_numpy(sample).float()\n" + " @classmethod\n", + " def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):\n", + " return cls()" ] }, { @@ -181,15 +203,16 @@ "class SklearnDataModule(flash.DataModule):\n", "\n", " @classmethod\n", - " def from_dataset(cls, x: np.ndarray, y: np.ndarray, batch_size: int = 64, num_workers: int = 0):\n", + " def from_dataset(cls, x: np.ndarray, y: np.ndarray, preprocess: Preprocess, batch_size: int = 64, num_workers: int = 0):\n", "\n", - " preprocess = NumpyRegressionPreprocess()\n", + " preprocess = preprocess\n", "\n", " x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)\n", "\n", - " dm = cls.from_load_data_inputs(\n", - " train_load_data_input=(x_train, y_train),\n", - " test_load_data_input=(x_test, y_test),\n", + " dm = cls.from_data_source(\n", + " \"numpy\",\n", + " train_data=(x_train, y_train),\n", + " test_data=(x_test, y_test),\n", " preprocess=preprocess,\n", " batch_size=batch_size,\n", " num_workers=num_workers\n", @@ -204,7 +227,8 @@ "metadata": {}, "outputs": [], "source": [ - "datamodule = SklearnDataModule.from_dataset(*datasets.load_diabetes(return_X_y=True))" + "x, y = datasets.load_diabetes(return_X_y=True)\n", + "datamodule = SklearnDataModule.from_dataset(x, y, NumpyPreprocess())" ] }, { @@ -350,25 +374,7 @@ ] } ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, + "metadata": {}, "nbformat": 4, "nbformat_minor": 4 } diff --git a/flash_notebooks/generic_task.ipynb b/flash_notebooks/generic_task.ipynb index 4e7b25e465..a3ddd6a86b 100644 --- a/flash_notebooks/generic_task.ipynb +++ b/flash_notebooks/generic_task.ipynb @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "straight-vision", "metadata": {}, "outputs": [], @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "foreign-design", "metadata": {}, "outputs": [], @@ -55,12 +55,12 @@ "from torch.utils.data import DataLoader, random_split\n", "from torchvision import datasets, transforms\n", "\n", - "from flash import ClassificationTask" + "from flash.core.classification import ClassificationTask" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "mathematical-barbados", "metadata": {}, "outputs": [], @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "selective-request", "metadata": {}, "outputs": [], @@ -106,41 +106,10 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "stunning-anime", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" - ] - }, - { - "ename": "HTTPError", - "evalue": "HTTP Error 503: Service Unavailable", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMNIST\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'./data'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mToTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, train, transform, target_transform, download)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_exists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36mdownload\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresources\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrpartition\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 146\u001b[0;31m \u001b[0mdownload_and_extract_archive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload_root\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_folder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 147\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;31m# process and save as torch files\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_and_extract_archive\u001b[0;34m(url, download_root, extract_root, filename, md5, remove_finished)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbasename\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0mdownload_url\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[0marchive\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_url\u001b[0;34m(url, root, filename, md5, max_redirect_hops)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m \u001b[0;31m# check integrity of downloaded file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_integrity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_url\u001b[0;34m(url, root, filename, md5, max_redirect_hops)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Downloading '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0murl\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' to '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mURLError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIOError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'https'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36m_urlretrieve\u001b[0;34m(url, filename, chunk_size)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunk_size\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1024\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"wb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfh\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0murlopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"User-Agent\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUSER_AGENT\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlength\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpbar\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchunk_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36murlopen\u001b[0;34m(url, data, timeout, cafile, capath, cadefault, context)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0mopener\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_opener\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mopener\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 223\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minstall_opener\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mopen\u001b[0;34m(self, fullurl, data, timeout)\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mprocessor\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_response\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprotocol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mmeth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprocessor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmeth_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 531\u001b[0;31m \u001b[0mresponse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmeth\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 532\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 533\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mhttp_response\u001b[0;34m(self, request, response)\u001b[0m\n\u001b[1;32m 638\u001b[0m \u001b[0;31m# request was successfully received, understood, and accepted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m200\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mcode\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m300\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 640\u001b[0;31m response = self.parent.error(\n\u001b[0m\u001b[1;32m 641\u001b[0m 'http', request, response, code, msg, hdrs)\n\u001b[1;32m 642\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, proto, *args)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhttp_err\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 568\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'default'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'http_error_default'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0morig_args\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 569\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_chain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 570\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;31m# XXX probably also want an abstract factory that knows when it makes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36m_call_chain\u001b[0;34m(self, chain, kind, meth_name, *args)\u001b[0m\n\u001b[1;32m 500\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhandler\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mhandlers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 501\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmeth_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 502\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 503\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 504\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mhttp_error_default\u001b[0;34m(self, req, fp, code, msg, hdrs)\u001b[0m\n\u001b[1;32m 647\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mHTTPDefaultErrorHandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mBaseHandler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 648\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mhttp_error_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhdrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 649\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mHTTPError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_url\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhdrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 650\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 651\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mHTTPRedirectHandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mBaseHandler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mHTTPError\u001b[0m: HTTP Error 503: Service Unavailable" - ] - } - ], + "outputs": [], "source": [ "dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor())" ] @@ -239,24 +208,6 @@ "results = trainer.test(classifier, test_dataloaders=DataLoader(test))" ] }, - { - "cell_type": "markdown", - "id": "classified-cholesterol", - "metadata": {}, - "source": [ - "# Predicting" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "academic-preference", - "metadata": {}, - "outputs": [], - "source": [ - "classifier.predict(predict)" - ] - }, { "cell_type": "markdown", "id": "traditional-faculty", @@ -299,25 +250,7 @@ ] } ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, + "metadata": {}, "nbformat": 4, "nbformat_minor": 5 } diff --git a/flash_notebooks/image_classification.ipynb b/flash_notebooks/image_classification.ipynb index 403d5e2908..238596f327 100644 --- a/flash_notebooks/image_classification.ipynb +++ b/flash_notebooks/image_classification.ipynb @@ -198,7 +198,11 @@ "cell_type": "code", "execution_count": null, "id": "apart-arrangement", - "metadata": {}, + "metadata": { + "tags": [ + "outputPrepend" + ] + }, "outputs": [], "source": [ "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze_unfreeze\")" @@ -351,25 +355,7 @@ ] } ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, + "metadata": {}, "nbformat": 4, "nbformat_minor": 5 } diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index 8cbb8470a4..f493f5c57f 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -95,12 +95,12 @@ "outputs": [], "source": [ "datamodule = TabularData.from_csv(\n", - " train_csv=\"./data/titanic/titanic.csv\",\n", - " test_csv=\"./data/titanic/test.csv\",\n", - " categorical_cols=[\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", - " numerical_cols=[\"Fare\"],\n", - " target_col=\"Survived\",\n", - " val_size=0.25,\n", + " [\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", + " [\"Fare\"],\n", + " target_field=\"Survived\",\n", + " train_file=\"./data/titanic/titanic.csv\",\n", + " test_file=\"./data/titanic/test.csv\",\n", + " val_split=0.25,\n", ")\n" ] }, diff --git a/flash_notebooks/text_classification.ipynb b/flash_notebooks/text_classification.ipynb index f238c09b60..81e0e850e1 100644 --- a/flash_notebooks/text_classification.ipynb +++ b/flash_notebooks/text_classification.ipynb @@ -109,12 +109,12 @@ "metadata": {}, "outputs": [], "source": [ - "datamodule = TextClassificationData.from_files(\n", + "datamodule = TextClassificationData.from_csv(\n", " train_file=\"data/imdb/train.csv\",\n", " val_file=\"data/imdb/valid.csv\",\n", " test_file=\"data/imdb/test.csv\",\n", - " input=\"review\",\n", - " target=\"sentiment\",\n", + " input_fields=\"review\",\n", + " target_fields=\"sentiment\",\n", " batch_size=512\n", ")" ] @@ -188,7 +188,7 @@ "metadata": {}, "outputs": [], "source": [ - "trainer.finetune(model, datamodule=datamodule, strategy=\"no_freeze\")" + "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze\")" ] }, { @@ -301,9 +301,9 @@ "metadata": {}, "outputs": [], "source": [ - "datamodule = TextClassificationData.from_file(\n", + "datamodule = TextClassificationData.from_csv(\n", " predict_file=\"data/imdb/predict.csv\",\n", - " input=\"review\",\n", + " input_fields=\"review\",\n", ")\n", "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", "print(predictions)" @@ -351,25 +351,7 @@ ] } ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, + "metadata": {}, "nbformat": 4, "nbformat_minor": 5 } diff --git a/requirements.txt b/requirements.txt index bce2efd675..41601b3f3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,6 @@ rouge-score>=0.0.4 sentencepiece>=0.1.95 filelock # comes with 3rd-party dependency pycocotools>=2.0.2 ; python_version >= "3.7" -kornia==0.5.0 +kornia>=0.5.1 pytorchvideo matplotlib # used by the visualisation callback diff --git a/requirements/extras.txt b/requirements/extras.txt index 78882bf8f3..ccb0a4be9b 100644 --- a/requirements/extras.txt +++ b/requirements/extras.txt @@ -1 +1 @@ -timm>=0.4.5 \ No newline at end of file +timm>=0.4.5