From 7b5f8ea0f8346818a4c7e79e2a684dfebf410c69 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 11 Mar 2021 13:00:46 +0100 Subject: [PATCH 1/9] switch to torchmetrics --- flash/tabular/classification/model.py | 2 +- flash/text/classification/model.py | 2 +- flash/text/seq2seq/summarization/metric.py | 2 +- flash/text/seq2seq/translation/metric.py | 2 +- flash/vision/classification/model.py | 4 ++-- flash/vision/embedding/image_embedder_model.py | 2 +- flash_examples/finetuning/tabular_classification.py | 2 +- requirements.txt | 1 + 8 files changed, 9 insertions(+), 8 deletions(-) diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 166a35a1d5..40b61cc808 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -14,9 +14,9 @@ from typing import Any, Callable, List, Optional, Tuple, Type import torch -from pytorch_lightning.metrics import Metric from pytorch_tabnet.tab_network import TabNet from torch.nn import functional as F +from torchmetrics import Metric from flash.core.classification import ClassificationTask from flash.core.data import DataPipeline diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index eff2bfa050..44d12241d8 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -16,7 +16,7 @@ from typing import Callable, Mapping, Sequence, Type, Union import torch -from pytorch_lightning.metrics.classification import Accuracy +from torchmetrics.classification import Accuracy from transformers import BertForSequenceClassification from flash.core.classification import ClassificationDataPipeline, ClassificationTask diff --git a/flash/text/seq2seq/summarization/metric.py b/flash/text/seq2seq/summarization/metric.py index 7736a60dd6..90dc0c2487 100644 --- a/flash/text/seq2seq/summarization/metric.py +++ b/flash/text/seq2seq/summarization/metric.py @@ -15,9 +15,9 @@ import numpy as np import torch -from pytorch_lightning.metrics import Metric from rouge_score import rouge_scorer, scoring from rouge_score.scoring import AggregateScore, Score +from torchmetrics import Metric from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence diff --git a/flash/text/seq2seq/translation/metric.py b/flash/text/seq2seq/translation/metric.py index 5b29cdc90b..878e2ba1e3 100644 --- a/flash/text/seq2seq/translation/metric.py +++ b/flash/text/seq2seq/translation/metric.py @@ -20,7 +20,7 @@ from typing import List import torch -from pytorch_lightning.metrics import Metric +from torchmetrics import Metric def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 114175b90b..7c9c5f17d5 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Mapping, Sequence, Type, Union import torch -from pytorch_lightning.metrics import Accuracy from torch import nn from torch.nn import functional as F +from torchmetrics import Accuracy from flash.core.classification import ClassificationTask from flash.vision.backbones import backbone_and_num_features @@ -33,7 +33,7 @@ class ImageClassifier(ClassificationTask): loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. metrics: Metrics to compute for training and evaluation, - defaults to :class:`pytorch_lightning.metrics.Accuracy`. + defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. """ diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py index 0e0884d5c8..a10e40da81 100644 --- a/flash/vision/embedding/image_embedder_model.py +++ b/flash/vision/embedding/image_embedder_model.py @@ -14,11 +14,11 @@ from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union import torch -from pytorch_lightning.metrics import Accuracy from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.nn import functional as F +from torchmetrics import Accuracy from flash.core import Task from flash.core.data import TaskDataPipeline diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index e9769296d3..265e27f390 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall +from torchmetrics.classification import Accuracy, Precision, Recall import flash from flash.core.data import download_data diff --git a/requirements.txt b/requirements.txt index 32361349b8..6e66c19d13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ pytorch-lightning==1.2.0rc0 # todo: we shall align with real 1.2 torch>=1.7 # TODO: regenerate weights with lewer PT version PyYAML>=5.1 Pillow>=7.2 +torchmetrics>=0.2 torchvision>=0.8 # lower to 0.7 after PT 1.6 transformers>=4.0 pytorch-tabnet==3.1 From a601d333eda1e70fe63ca9c01780ada24c588891 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 12 Mar 2021 14:49:49 +0100 Subject: [PATCH 2/9] more --- README.md | 2 +- docs/source/reference/tabular_classification.rst | 2 +- flash_notebooks/tabular_classification.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8ef3bd30cc..19c66898ed 100644 --- a/README.md +++ b/README.md @@ -244,7 +244,7 @@ To illustrate, say we want to build a model to predict if a passenger survived o ```python # import our libraries -from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall +from torchmetrics.classification import Accuracy, Precision, Recall import flash from flash import download_data from flash.tabular import TabularClassifier, TabularData diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index ebb1fc6a1b..1aab5296ed 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -47,7 +47,7 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da import flash from flash import download_data from flash.tabular import TabularClassifier, TabularData - from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall + from torchmetrics.classification import Accuracy, Precision, Recall # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index da50cac2fa..460e996a14 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -50,7 +50,7 @@ "metadata": {}, "outputs": [], "source": [ - "from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall\n", + "from torchmetrics.classification import Accuracy, Precision, Recall\n", "\n", "import flash\n", "from flash.core.data import download_data\n", From f42074151461de0fac1297bb9d1e9f0cd08bedc6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 12 Mar 2021 14:59:08 +0100 Subject: [PATCH 3/9] 0.2.0 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6e66c19d13..a727cff477 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ pytorch-lightning==1.2.0rc0 # todo: we shall align with real 1.2 torch>=1.7 # TODO: regenerate weights with lewer PT version PyYAML>=5.1 Pillow>=7.2 -torchmetrics>=0.2 +torchmetrics>=0.2.0 torchvision>=0.8 # lower to 0.7 after PT 1.6 transformers>=4.0 pytorch-tabnet==3.1 From d51bb5ea7477e0e30a92312d14c3e07a988f6de7 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 12 Mar 2021 15:19:13 +0100 Subject: [PATCH 4/9] softmax --- flash/vision/classification/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 7c9c5f17d5..af300102f1 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -16,6 +16,7 @@ import torch from torch import nn from torch.nn import functional as F +from torch.nn.functional import softmax from torchmetrics import Accuracy from flash.core.classification import ClassificationTask @@ -67,7 +68,7 @@ def __init__( def forward(self, x) -> Any: x = self.backbone(x) - return self.head(x) + return softmax(self.head(x)) @staticmethod def default_pipeline() -> ImageClassificationDataPipeline: From 6da498c03a2f6b689c0d01ebe662649ad6e6ae25 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 12 Mar 2021 15:30:07 +0100 Subject: [PATCH 5/9] softmax --- flash/tabular/classification/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 40b61cc808..7c9738b328 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -16,6 +16,7 @@ import torch from pytorch_tabnet.tab_network import TabNet from torch.nn import functional as F +from torch.nn.functional import softmax from torchmetrics import Metric from flash.core.classification import ClassificationTask @@ -85,7 +86,7 @@ def predict( def forward(self, x_in): # TabNet takes single input, x_in is composed of (categorical, numerical) x = torch.cat([x for x in x_in if x.numel()], dim=1) - return self.model(x)[0] + return softmax(self.model(x)[0]) @classmethod def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier': From d20aa3c082d509d975fe8ae514e7a0d51f87af1f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 12 Mar 2021 16:11:43 +0100 Subject: [PATCH 6/9] softmax --- flash_examples/generic_task.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_examples/generic_task.py b/flash_examples/generic_task.py index 397bd04ea4..de1439d783 100644 --- a/flash_examples/generic_task.py +++ b/flash_examples/generic_task.py @@ -24,6 +24,7 @@ nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10), + nn.Softmax(), ) # 2. Load a dataset From 82b3835a2145d296dbbc5b4b9e9d557286e8a421 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 12 Mar 2021 16:29:31 +0100 Subject: [PATCH 7/9] mnist --- tests/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb2..b499bb5f7f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,6 @@ +import urllib + +# TorchVision hotfix https://github.com/pytorch/vision/issues/1938 +opener = urllib.request.build_opener() +opener.addheaders = [('User-agent', 'Mozilla/5.0')] +urllib.request.install_opener(opener) From 5f04e691d78c727c92f2ca1bc4cc1d3eac9c6b2e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 12 Mar 2021 16:46:01 +0100 Subject: [PATCH 8/9] cache --- flash_examples/generic_task.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flash_examples/generic_task.py b/flash_examples/generic_task.py index de1439d783..2b07034b04 100644 --- a/flash_examples/generic_task.py +++ b/flash_examples/generic_task.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytorch_lightning as pl from torch import nn, optim from torch.utils.data import DataLoader, random_split @@ -18,6 +20,8 @@ from flash import ClassificationTask +_PATH_ROOT = os.path.dirname(os.path.dirname(__file__)) + # 1. Load a basic backbone model = nn.Sequential( nn.Flatten(), @@ -28,7 +32,7 @@ ) # 2. Load a dataset -dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor()) +dataset = datasets.MNIST(os.path.join(_PATH_ROOT, 'data'), download=True, transform=transforms.ToTensor()) # 3. Split the data randomly train, val, test = random_split(dataset, [50000, 5000, 5000]) # type: ignore From 51006ed7fa79148ce2d56c52a21ef30c79f27113 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 12 Mar 2021 16:57:49 +0100 Subject: [PATCH 9/9] . --- flash_examples/generic_task.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flash_examples/generic_task.py b/flash_examples/generic_task.py index 2b07034b04..e1bfbbc652 100644 --- a/flash_examples/generic_task.py +++ b/flash_examples/generic_task.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import urllib import pytorch_lightning as pl from torch import nn, optim @@ -21,6 +22,10 @@ from flash import ClassificationTask _PATH_ROOT = os.path.dirname(os.path.dirname(__file__)) +# TorchVision hotfix https://github.com/pytorch/vision/issues/1938 +opener = urllib.request.build_opener() +opener.addheaders = [('User-agent', 'Mozilla/5.0')] +urllib.request.install_opener(opener) # 1. Load a basic backbone model = nn.Sequential(