Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

switch to torchmetrics #169

Merged
merged 9 commits into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ To illustrate, say we want to build a model to predict if a passenger survived o

```python
# import our libraries
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
from torchmetrics.classification import Accuracy, Precision, Recall
import flash
from flash import download_data
from flash.tabular import TabularClassifier, TabularData
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/tabular_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da
import flash
from flash import download_data
from flash.tabular import TabularClassifier, TabularData
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
from torchmetrics.classification import Accuracy, Precision, Recall
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
Expand Down
5 changes: 3 additions & 2 deletions flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
from pytorch_lightning.metrics import Metric
from pytorch_tabnet.tab_network import TabNet
from torch.nn import functional as F
from torch.nn.functional import softmax
from torchmetrics import Metric

from flash.core.classification import ClassificationTask
from flash.core.data import DataPipeline
Expand Down Expand Up @@ -85,7 +86,7 @@ def predict(
def forward(self, x_in):
# TabNet takes single input, x_in is composed of (categorical, numerical)
x = torch.cat([x for x in x_in if x.numel()], dim=1)
return self.model(x)[0]
return softmax(self.model(x)[0])

@classmethod
def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier':
Expand Down
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Callable, Mapping, Sequence, Type, Union

import torch
from pytorch_lightning.metrics.classification import Accuracy
from torchmetrics.classification import Accuracy
from transformers import BertForSequenceClassification

from flash.core.classification import ClassificationDataPipeline, ClassificationTask
Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/summarization/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

import numpy as np
import torch
from pytorch_lightning.metrics import Metric
from rouge_score import rouge_scorer, scoring
from rouge_score.scoring import AggregateScore, Score
from torchmetrics import Metric

from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence

Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/translation/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import List

import torch
from pytorch_lightning.metrics import Metric
from torchmetrics import Metric


def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
Expand Down
7 changes: 4 additions & 3 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from typing import Any, Callable, Mapping, Sequence, Type, Union

import torch
from pytorch_lightning.metrics import Accuracy
from torch import nn
from torch.nn import functional as F
from torch.nn.functional import softmax
from torchmetrics import Accuracy

from flash.core.classification import ClassificationTask
from flash.vision.backbones import backbone_and_num_features
Expand All @@ -33,7 +34,7 @@ class ImageClassifier(ClassificationTask):
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`.
metrics: Metrics to compute for training and evaluation,
defaults to :class:`pytorch_lightning.metrics.Accuracy`.
defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
"""

Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(

def forward(self, x) -> Any:
x = self.backbone(x)
return self.head(x)
return softmax(self.head(x))

@staticmethod
def default_pipeline() -> ImageClassificationDataPipeline:
Expand Down
2 changes: 1 addition & 1 deletion flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy

from flash.core import Task
from flash.core.data import TaskDataPipeline
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/finetuning/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
from torchmetrics.classification import Accuracy, Precision, Recall

import flash
from flash.core.data import download_data
Expand Down
12 changes: 11 additions & 1 deletion flash_examples/generic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import urllib

import pytorch_lightning as pl
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from flash import ClassificationTask

_PATH_ROOT = os.path.dirname(os.path.dirname(__file__))
# TorchVision hotfix https://github.com/pytorch/vision/issues/1938
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

# 1. Load a basic backbone
model = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 128),
nn.ReLU(),
nn.Linear(128, 10),
nn.Softmax(),
)

# 2. Load a dataset
dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor())
dataset = datasets.MNIST(os.path.join(_PATH_ROOT, 'data'), download=True, transform=transforms.ToTensor())

# 3. Split the data randomly
train, val, test = random_split(dataset, [50000, 5000, 5000]) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion flash_notebooks/tabular_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall\n",
"from torchmetrics.classification import Accuracy, Precision, Recall\n",
"\n",
"import flash\n",
"from flash.core.data import download_data\n",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pytorch-lightning==1.2.0rc0 # todo: we shall align with real 1.2
torch>=1.7 # TODO: regenerate weights with lewer PT version
PyYAML>=5.1
Pillow>=7.2
torchmetrics>=0.2.0
torchvision>=0.8 # lower to 0.7 after PT 1.6
transformers>=4.0
pytorch-tabnet==3.1
Expand Down
6 changes: 6 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import urllib

# TorchVision hotfix https://github.com/pytorch/vision/issues/1938
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)