diff --git a/README.md b/README.md index 8ef3bd30cc..19c66898ed 100644 --- a/README.md +++ b/README.md @@ -244,7 +244,7 @@ To illustrate, say we want to build a model to predict if a passenger survived o ```python # import our libraries -from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall +from torchmetrics.classification import Accuracy, Precision, Recall import flash from flash import download_data from flash.tabular import TabularClassifier, TabularData diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index ebb1fc6a1b..1aab5296ed 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -47,7 +47,7 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da import flash from flash import download_data from flash.tabular import TabularClassifier, TabularData - from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall + from torchmetrics.classification import Accuracy, Precision, Recall # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 166a35a1d5..7c9738b328 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -14,9 +14,10 @@ from typing import Any, Callable, List, Optional, Tuple, Type import torch -from pytorch_lightning.metrics import Metric from pytorch_tabnet.tab_network import TabNet from torch.nn import functional as F +from torch.nn.functional import softmax +from torchmetrics import Metric from flash.core.classification import ClassificationTask from flash.core.data import DataPipeline @@ -85,7 +86,7 @@ def predict( def forward(self, x_in): # TabNet takes single input, x_in is composed of (categorical, numerical) x = torch.cat([x for x in x_in if x.numel()], dim=1) - return self.model(x)[0] + return softmax(self.model(x)[0]) @classmethod def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier': diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index eff2bfa050..44d12241d8 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -16,7 +16,7 @@ from typing import Callable, Mapping, Sequence, Type, Union import torch -from pytorch_lightning.metrics.classification import Accuracy +from torchmetrics.classification import Accuracy from transformers import BertForSequenceClassification from flash.core.classification import ClassificationDataPipeline, ClassificationTask diff --git a/flash/text/seq2seq/summarization/metric.py b/flash/text/seq2seq/summarization/metric.py index 7736a60dd6..90dc0c2487 100644 --- a/flash/text/seq2seq/summarization/metric.py +++ b/flash/text/seq2seq/summarization/metric.py @@ -15,9 +15,9 @@ import numpy as np import torch -from pytorch_lightning.metrics import Metric from rouge_score import rouge_scorer, scoring from rouge_score.scoring import AggregateScore, Score +from torchmetrics import Metric from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence diff --git a/flash/text/seq2seq/translation/metric.py b/flash/text/seq2seq/translation/metric.py index 5b29cdc90b..878e2ba1e3 100644 --- a/flash/text/seq2seq/translation/metric.py +++ b/flash/text/seq2seq/translation/metric.py @@ -20,7 +20,7 @@ from typing import List import torch -from pytorch_lightning.metrics import Metric +from torchmetrics import Metric def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 114175b90b..af300102f1 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -14,9 +14,10 @@ from typing import Any, Callable, Mapping, Sequence, Type, Union import torch -from pytorch_lightning.metrics import Accuracy from torch import nn from torch.nn import functional as F +from torch.nn.functional import softmax +from torchmetrics import Accuracy from flash.core.classification import ClassificationTask from flash.vision.backbones import backbone_and_num_features @@ -33,7 +34,7 @@ class ImageClassifier(ClassificationTask): loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. metrics: Metrics to compute for training and evaluation, - defaults to :class:`pytorch_lightning.metrics.Accuracy`. + defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. """ @@ -67,7 +68,7 @@ def __init__( def forward(self, x) -> Any: x = self.backbone(x) - return self.head(x) + return softmax(self.head(x)) @staticmethod def default_pipeline() -> ImageClassificationDataPipeline: diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py index 0e0884d5c8..a10e40da81 100644 --- a/flash/vision/embedding/image_embedder_model.py +++ b/flash/vision/embedding/image_embedder_model.py @@ -14,11 +14,11 @@ from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union import torch -from pytorch_lightning.metrics import Accuracy from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.nn import functional as F +from torchmetrics import Accuracy from flash.core import Task from flash.core.data import TaskDataPipeline diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index e9769296d3..265e27f390 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall +from torchmetrics.classification import Accuracy, Precision, Recall import flash from flash.core.data import download_data diff --git a/flash_examples/generic_task.py b/flash_examples/generic_task.py index 397bd04ea4..e1bfbbc652 100644 --- a/flash_examples/generic_task.py +++ b/flash_examples/generic_task.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import urllib + import pytorch_lightning as pl from torch import nn, optim from torch.utils.data import DataLoader, random_split @@ -18,16 +21,23 @@ from flash import ClassificationTask +_PATH_ROOT = os.path.dirname(os.path.dirname(__file__)) +# TorchVision hotfix https://github.com/pytorch/vision/issues/1938 +opener = urllib.request.build_opener() +opener.addheaders = [('User-agent', 'Mozilla/5.0')] +urllib.request.install_opener(opener) + # 1. Load a basic backbone model = nn.Sequential( nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10), + nn.Softmax(), ) # 2. Load a dataset -dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor()) +dataset = datasets.MNIST(os.path.join(_PATH_ROOT, 'data'), download=True, transform=transforms.ToTensor()) # 3. Split the data randomly train, val, test = random_split(dataset, [50000, 5000, 5000]) # type: ignore diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index da50cac2fa..460e996a14 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -50,7 +50,7 @@ "metadata": {}, "outputs": [], "source": [ - "from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall\n", + "from torchmetrics.classification import Accuracy, Precision, Recall\n", "\n", "import flash\n", "from flash.core.data import download_data\n", diff --git a/requirements.txt b/requirements.txt index 32361349b8..a727cff477 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ pytorch-lightning==1.2.0rc0 # todo: we shall align with real 1.2 torch>=1.7 # TODO: regenerate weights with lewer PT version PyYAML>=5.1 Pillow>=7.2 +torchmetrics>=0.2.0 torchvision>=0.8 # lower to 0.7 after PT 1.6 transformers>=4.0 pytorch-tabnet==3.1 diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb2..b499bb5f7f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,6 @@ +import urllib + +# TorchVision hotfix https://github.com/pytorch/vision/issues/1938 +opener = urllib.request.build_opener() +opener.addheaders = [('User-agent', 'Mozilla/5.0')] +urllib.request.install_opener(opener)