diff --git a/.readthedocs.yml b/.readthedocs.yml
index 6eeec2eb41..bb9ffdbdd8 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -20,4 +20,4 @@ python:
version: 3.7
install:
- requirements: requirements/docs.txt
- #- requirements: requirements.txt
\ No newline at end of file
+ #- requirements: requirements.txt
diff --git a/Makefile b/Makefile
index 9bcddf4426..6fcee001e6 100644
--- a/Makefile
+++ b/Makefile
@@ -25,4 +25,4 @@ clean:
rm -rf .pytest_cache
rm -rf ./docs/build
rm -rf ./docs/source/**/generated
- rm -rf ./docs/source/api
\ No newline at end of file
+ rm -rf ./docs/source/api
diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg
index ad94333a72..2c3e330bbf 100644
--- a/docs/source/_static/images/logo.svg
+++ b/docs/source/_static/images/logo.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/flash/vision/classification/transforms.py b/flash/vision/classification/transforms.py
index 3eff2f4c2c..2406ec9b24 100644
--- a/flash/vision/classification/transforms.py
+++ b/flash/vision/classification/transforms.py
@@ -37,8 +37,7 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
- # TODO (Edgar): replace with resize once kornia is fixed
- K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
+ K.geometry.Resize(image_size),
K.augmentation.RandomHorizontalFlip(),
),
"per_batch_transform_on_device": ApplyToKeys(
@@ -70,8 +69,7 @@ def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
- # TODO (Edgar): replace with resize once kornia is fixed
- K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
+ K.geometry.Resize(image_size),
),
"per_batch_transform_on_device": ApplyToKeys(
DefaultDataKeys.INPUT,
diff --git a/flash_notebooks/custom_task_tutorial.ipynb b/flash_notebooks/custom_task_tutorial.ipynb
index 93b214ff62..10c698573e 100644
--- a/flash_notebooks/custom_task_tutorial.ipynb
+++ b/flash_notebooks/custom_task_tutorial.ipynb
@@ -40,18 +40,21 @@
"metadata": {},
"outputs": [],
"source": [
- "from typing import Any, List, Tuple\n",
+ "from typing import Any, List, Tuple, Dict\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from pytorch_lightning import seed_everything\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import train_test_split\n",
- "from torch import nn\n",
+ "from torch import nn, Tensor\n",
"\n",
"import flash\n",
"from flash.data.auto_dataset import AutoDataset\n",
- "from flash.data.process import Postprocess, Preprocess"
+ "from flash.data.data_source import DataSource\n",
+ "from flash.data.process import Preprocess\n",
+ "\n",
+ "ND = np.ndarray"
]
},
{
@@ -152,24 +155,43 @@
"metadata": {},
"outputs": [],
"source": [
- "class NumpyRegressionPreprocess(Preprocess):\n",
+ "class NumpyDataSource(DataSource):\n",
"\n",
- " def load_data(self, data: Tuple[np.ndarray, np.ndarray], dataset: AutoDataset) -> List[Tuple[np.ndarray, float]]:\n",
+ " def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]:\n",
" if self.training:\n",
" dataset.num_inputs = data[0].shape[1]\n",
" return [(x, y) for x, y in zip(*data)]\n",
"\n",
- " def to_tensor_transform(self, sample: Any) -> Tuple[torch.Tensor, torch.Tensor]:\n",
+ " def predict_load_data(self, data: ND) -> ND:\n",
+ " return data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class NumpyPreprocess(Preprocess):\n",
+ "\n",
+ " def __init__(self):\n",
+ " super().__init__(data_sources={\"numpy\": NumpyDataSource()}, default_data_source=\"numpy\")\n",
+ "\n",
+ " def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]:\n",
" x, y = sample\n",
" x = torch.from_numpy(x).float()\n",
" y = torch.tensor(y, dtype=torch.float)\n",
" return x, y\n",
"\n",
- " def predict_load_data(self, data: np.ndarray) -> np.ndarray:\n",
- " return data\n",
+ " def predict_to_tensor_transform(self, sample: ND) -> ND:\n",
+ " return torch.from_numpy(sample).float()\n",
+ "\n",
+ " def get_state_dict(self) -> Dict[str, Any]:\n",
+ " return {}\n",
"\n",
- " def predict_to_tensor_transform(self, sample: np.ndarray) -> np.ndarray:\n",
- " return torch.from_numpy(sample).float()\n"
+ " @classmethod\n",
+ " def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):\n",
+ " return cls()"
]
},
{
@@ -181,15 +203,16 @@
"class SklearnDataModule(flash.DataModule):\n",
"\n",
" @classmethod\n",
- " def from_dataset(cls, x: np.ndarray, y: np.ndarray, batch_size: int = 64, num_workers: int = 0):\n",
+ " def from_dataset(cls, x: np.ndarray, y: np.ndarray, preprocess: Preprocess, batch_size: int = 64, num_workers: int = 0):\n",
"\n",
- " preprocess = NumpyRegressionPreprocess()\n",
+ " preprocess = preprocess\n",
"\n",
" x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)\n",
"\n",
- " dm = cls.from_load_data_inputs(\n",
- " train_load_data_input=(x_train, y_train),\n",
- " test_load_data_input=(x_test, y_test),\n",
+ " dm = cls.from_data_source(\n",
+ " \"numpy\",\n",
+ " train_data=(x_train, y_train),\n",
+ " test_data=(x_test, y_test),\n",
" preprocess=preprocess,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers\n",
@@ -204,7 +227,8 @@
"metadata": {},
"outputs": [],
"source": [
- "datamodule = SklearnDataModule.from_dataset(*datasets.load_diabetes(return_X_y=True))"
+ "x, y = datasets.load_diabetes(return_X_y=True)\n",
+ "datamodule = SklearnDataModule.from_dataset(x, y, NumpyPreprocess())"
]
},
{
@@ -350,25 +374,7 @@
]
}
],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.5"
- }
- },
+ "metadata": {},
"nbformat": 4,
"nbformat_minor": 4
}
diff --git a/flash_notebooks/generic_task.ipynb b/flash_notebooks/generic_task.ipynb
index 4e7b25e465..a3ddd6a86b 100644
--- a/flash_notebooks/generic_task.ipynb
+++ b/flash_notebooks/generic_task.ipynb
@@ -34,7 +34,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "straight-vision",
"metadata": {},
"outputs": [],
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"id": "foreign-design",
"metadata": {},
"outputs": [],
@@ -55,12 +55,12 @@
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision import datasets, transforms\n",
"\n",
- "from flash import ClassificationTask"
+ "from flash.core.classification import ClassificationTask"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"id": "mathematical-barbados",
"metadata": {},
"outputs": [],
@@ -83,7 +83,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "selective-request",
"metadata": {},
"outputs": [],
@@ -106,41 +106,10 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"id": "stunning-anime",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n"
- ]
- },
- {
- "ename": "HTTPError",
- "evalue": "HTTP Error 503: Service Unavailable",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMNIST\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'./data'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mToTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, train, transform, target_transform, download)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_exists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36mdownload\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresources\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrpartition\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 146\u001b[0;31m \u001b[0mdownload_and_extract_archive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload_root\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_folder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 147\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;31m# process and save as torch files\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_and_extract_archive\u001b[0;34m(url, download_root, extract_root, filename, md5, remove_finished)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbasename\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0mdownload_url\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[0marchive\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_url\u001b[0;34m(url, root, filename, md5, max_redirect_hops)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m \u001b[0;31m# check integrity of downloaded file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_integrity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_url\u001b[0;34m(url, root, filename, md5, max_redirect_hops)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Downloading '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0murl\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' to '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mURLError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIOError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'https'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36m_urlretrieve\u001b[0;34m(url, filename, chunk_size)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunk_size\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1024\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"wb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfh\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0murlopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"User-Agent\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUSER_AGENT\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlength\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpbar\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchunk_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36murlopen\u001b[0;34m(url, data, timeout, cafile, capath, cadefault, context)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0mopener\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_opener\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mopener\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 223\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minstall_opener\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mopen\u001b[0;34m(self, fullurl, data, timeout)\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mprocessor\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_response\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprotocol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mmeth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprocessor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmeth_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 531\u001b[0;31m \u001b[0mresponse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmeth\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 532\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 533\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mhttp_response\u001b[0;34m(self, request, response)\u001b[0m\n\u001b[1;32m 638\u001b[0m \u001b[0;31m# request was successfully received, understood, and accepted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m200\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mcode\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m300\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 640\u001b[0;31m response = self.parent.error(\n\u001b[0m\u001b[1;32m 641\u001b[0m 'http', request, response, code, msg, hdrs)\n\u001b[1;32m 642\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, proto, *args)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhttp_err\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 568\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'default'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'http_error_default'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0morig_args\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 569\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_chain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 570\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;31m# XXX probably also want an abstract factory that knows when it makes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36m_call_chain\u001b[0;34m(self, chain, kind, meth_name, *args)\u001b[0m\n\u001b[1;32m 500\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhandler\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mhandlers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 501\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmeth_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 502\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 503\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 504\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mhttp_error_default\u001b[0;34m(self, req, fp, code, msg, hdrs)\u001b[0m\n\u001b[1;32m 647\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mHTTPDefaultErrorHandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mBaseHandler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 648\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mhttp_error_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhdrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 649\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mHTTPError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_url\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhdrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 650\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 651\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mHTTPRedirectHandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mBaseHandler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mHTTPError\u001b[0m: HTTP Error 503: Service Unavailable"
- ]
- }
- ],
+ "outputs": [],
"source": [
"dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor())"
]
@@ -239,24 +208,6 @@
"results = trainer.test(classifier, test_dataloaders=DataLoader(test))"
]
},
- {
- "cell_type": "markdown",
- "id": "classified-cholesterol",
- "metadata": {},
- "source": [
- "# Predicting"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "academic-preference",
- "metadata": {},
- "outputs": [],
- "source": [
- "classifier.predict(predict)"
- ]
- },
{
"cell_type": "markdown",
"id": "traditional-faculty",
@@ -299,25 +250,7 @@
]
}
],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.5"
- }
- },
+ "metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
diff --git a/flash_notebooks/image_classification.ipynb b/flash_notebooks/image_classification.ipynb
index 403d5e2908..238596f327 100644
--- a/flash_notebooks/image_classification.ipynb
+++ b/flash_notebooks/image_classification.ipynb
@@ -198,7 +198,11 @@
"cell_type": "code",
"execution_count": null,
"id": "apart-arrangement",
- "metadata": {},
+ "metadata": {
+ "tags": [
+ "outputPrepend"
+ ]
+ },
"outputs": [],
"source": [
"trainer.finetune(model, datamodule=datamodule, strategy=\"freeze_unfreeze\")"
@@ -351,25 +355,7 @@
]
}
],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.9"
- }
- },
+ "metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb
index 8cbb8470a4..f493f5c57f 100644
--- a/flash_notebooks/tabular_classification.ipynb
+++ b/flash_notebooks/tabular_classification.ipynb
@@ -95,12 +95,12 @@
"outputs": [],
"source": [
"datamodule = TabularData.from_csv(\n",
- " train_csv=\"./data/titanic/titanic.csv\",\n",
- " test_csv=\"./data/titanic/test.csv\",\n",
- " categorical_cols=[\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n",
- " numerical_cols=[\"Fare\"],\n",
- " target_col=\"Survived\",\n",
- " val_size=0.25,\n",
+ " [\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n",
+ " [\"Fare\"],\n",
+ " target_field=\"Survived\",\n",
+ " train_file=\"./data/titanic/titanic.csv\",\n",
+ " test_file=\"./data/titanic/test.csv\",\n",
+ " val_split=0.25,\n",
")\n"
]
},
diff --git a/flash_notebooks/text_classification.ipynb b/flash_notebooks/text_classification.ipynb
index f238c09b60..81e0e850e1 100644
--- a/flash_notebooks/text_classification.ipynb
+++ b/flash_notebooks/text_classification.ipynb
@@ -109,12 +109,12 @@
"metadata": {},
"outputs": [],
"source": [
- "datamodule = TextClassificationData.from_files(\n",
+ "datamodule = TextClassificationData.from_csv(\n",
" train_file=\"data/imdb/train.csv\",\n",
" val_file=\"data/imdb/valid.csv\",\n",
" test_file=\"data/imdb/test.csv\",\n",
- " input=\"review\",\n",
- " target=\"sentiment\",\n",
+ " input_fields=\"review\",\n",
+ " target_fields=\"sentiment\",\n",
" batch_size=512\n",
")"
]
@@ -188,7 +188,7 @@
"metadata": {},
"outputs": [],
"source": [
- "trainer.finetune(model, datamodule=datamodule, strategy=\"no_freeze\")"
+ "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze\")"
]
},
{
@@ -301,9 +301,9 @@
"metadata": {},
"outputs": [],
"source": [
- "datamodule = TextClassificationData.from_file(\n",
+ "datamodule = TextClassificationData.from_csv(\n",
" predict_file=\"data/imdb/predict.csv\",\n",
- " input=\"review\",\n",
+ " input_fields=\"review\",\n",
")\n",
"predictions = flash.Trainer().predict(model, datamodule=datamodule)\n",
"print(predictions)"
@@ -351,25 +351,7 @@
]
}
],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.5"
- }
- },
+ "metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
diff --git a/requirements.txt b/requirements.txt
index bce2efd675..41601b3f3a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,6 @@ rouge-score>=0.0.4
sentencepiece>=0.1.95
filelock # comes with 3rd-party dependency
pycocotools>=2.0.2 ; python_version >= "3.7"
-kornia==0.5.0
+kornia>=0.5.1
pytorchvideo
matplotlib # used by the visualisation callback
diff --git a/requirements/extras.txt b/requirements/extras.txt
index 78882bf8f3..ccb0a4be9b 100644
--- a/requirements/extras.txt
+++ b/requirements/extras.txt
@@ -1 +1 @@
-timm>=0.4.5
\ No newline at end of file
+timm>=0.4.5