Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
switch to torchmetrics (#169)
Browse files Browse the repository at this point in the history
* switch to torchmetrics

* 0.2.0

* mnist

* cache
  • Loading branch information
Borda authored Mar 12, 2021
1 parent dbb768c commit 1f93ce1
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ To illustrate, say we want to build a model to predict if a passenger survived o

```python
# import our libraries
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
from torchmetrics.classification import Accuracy, Precision, Recall
import flash
from flash import download_data
from flash.tabular import TabularClassifier, TabularData
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/tabular_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da
import flash
from flash import download_data
from flash.tabular import TabularClassifier, TabularData
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
from torchmetrics.classification import Accuracy, Precision, Recall
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
Expand Down
5 changes: 3 additions & 2 deletions flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
from pytorch_lightning.metrics import Metric
from pytorch_tabnet.tab_network import TabNet
from torch.nn import functional as F
from torch.nn.functional import softmax
from torchmetrics import Metric

from flash.core.classification import ClassificationTask
from flash.core.data import DataPipeline
Expand Down Expand Up @@ -85,7 +86,7 @@ def predict(
def forward(self, x_in):
# TabNet takes single input, x_in is composed of (categorical, numerical)
x = torch.cat([x for x in x_in if x.numel()], dim=1)
return self.model(x)[0]
return softmax(self.model(x)[0])

@classmethod
def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier':
Expand Down
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Callable, Mapping, Sequence, Type, Union

import torch
from pytorch_lightning.metrics.classification import Accuracy
from torchmetrics.classification import Accuracy
from transformers import BertForSequenceClassification

from flash.core.classification import ClassificationDataPipeline, ClassificationTask
Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/summarization/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

import numpy as np
import torch
from pytorch_lightning.metrics import Metric
from rouge_score import rouge_scorer, scoring
from rouge_score.scoring import AggregateScore, Score
from torchmetrics import Metric

from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence

Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/translation/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import List

import torch
from pytorch_lightning.metrics import Metric
from torchmetrics import Metric


def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
Expand Down
7 changes: 4 additions & 3 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from typing import Any, Callable, Mapping, Sequence, Type, Union

import torch
from pytorch_lightning.metrics import Accuracy
from torch import nn
from torch.nn import functional as F
from torch.nn.functional import softmax
from torchmetrics import Accuracy

from flash.core.classification import ClassificationTask
from flash.vision.backbones import backbone_and_num_features
Expand All @@ -33,7 +34,7 @@ class ImageClassifier(ClassificationTask):
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`.
metrics: Metrics to compute for training and evaluation,
defaults to :class:`pytorch_lightning.metrics.Accuracy`.
defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
"""

Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(

def forward(self, x) -> Any:
x = self.backbone(x)
return self.head(x)
return softmax(self.head(x))

@staticmethod
def default_pipeline() -> ImageClassificationDataPipeline:
Expand Down
2 changes: 1 addition & 1 deletion flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy

from flash.core import Task
from flash.core.data import TaskDataPipeline
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/finetuning/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
from torchmetrics.classification import Accuracy, Precision, Recall

import flash
from flash.core.data import download_data
Expand Down
12 changes: 11 additions & 1 deletion flash_examples/generic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import urllib

import pytorch_lightning as pl
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from flash import ClassificationTask

_PATH_ROOT = os.path.dirname(os.path.dirname(__file__))
# TorchVision hotfix https://github.com/pytorch/vision/issues/1938
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

# 1. Load a basic backbone
model = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 128),
nn.ReLU(),
nn.Linear(128, 10),
nn.Softmax(),
)

# 2. Load a dataset
dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor())
dataset = datasets.MNIST(os.path.join(_PATH_ROOT, 'data'), download=True, transform=transforms.ToTensor())

# 3. Split the data randomly
train, val, test = random_split(dataset, [50000, 5000, 5000]) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion flash_notebooks/tabular_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall\n",
"from torchmetrics.classification import Accuracy, Precision, Recall\n",
"\n",
"import flash\n",
"from flash.core.data import download_data\n",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pytorch-lightning==1.2.0rc0 # todo: we shall align with real 1.2
torch>=1.7 # TODO: regenerate weights with lewer PT version
PyYAML>=5.1
Pillow>=7.2
torchmetrics>=0.2.0
torchvision>=0.8 # lower to 0.7 after PT 1.6
transformers>=4.0
pytorch-tabnet==3.1
Expand Down
6 changes: 6 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import urllib

# TorchVision hotfix https://github.com/pytorch/vision/issues/1938
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

0 comments on commit 1f93ce1

Please sign in to comment.