Implement Orca Pytorch metrics (#3545)
* SparseCategoricalAccuracy

* SparseCategoricalAccuracy

* CategoricalAccuracy

* fix

* BinaryAccuracy

* Top5Accuracy

* UT

* license
leonardozcm authored Mar 5, 2021
1 parent 4892245 commit bbc04e6
Showing 3 changed files with 249 additions and 3 deletions.
83 changes: 83 additions & 0 deletions pyzoo/test/zoo/orca/learn/test_metrics.py
@@ -0,0 +1,83 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch


def test_torch_Accuracy():
    from zoo.orca.learn.pytorch.pytorch_metrics import Accuracy
    pred = torch.tensor([0, 2, 3, 4])
    target = torch.tensor([1, 2, 3, 4])
    acc = Accuracy()
    acc(pred, target)
    assert acc.compute() == 0.75
    pred = torch.tensor([0, 2, 3, 4])
    target = torch.tensor([1, 1, 2, 4])
    acc(pred, target)
    assert acc.compute() == 0.5


def test_torch_BinaryAccuracy():
    from zoo.orca.learn.pytorch.pytorch_metrics import BinaryAccuracy
    target = torch.tensor([1, 1, 0, 0])
    pred = torch.tensor([0.98, 1, 0, 0.6])
    bac = BinaryAccuracy()
    bac(pred, target)
    assert bac.compute() == 0.75
    target = torch.tensor([1, 1, 0, 0])
    pred = torch.tensor([0.98, 1, 0, 0.6])
    bac(pred, target, threshold=0.7)
    assert bac.compute() == 0.875


def test_torch_CategoricalAccuracy():
    from zoo.orca.learn.pytorch.pytorch_metrics import CategoricalAccuracy
    pred = torch.tensor([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    target = torch.tensor([[0, 0, 1], [0, 1, 0]])
    cacc = CategoricalAccuracy()
    cacc(pred, target)
    assert cacc.compute() == 0.5
    pred = torch.tensor([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    target = torch.tensor([[0, 1, 0], [0, 1, 0]])
    cacc(pred, target)
    assert cacc.compute() == 0.75


def test_torch_SparseCategoricalAccuracy():
    from zoo.orca.learn.pytorch.pytorch_metrics import SparseCategoricalAccuracy
    pred = torch.tensor([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    target = torch.tensor([[2], [1]])
    scacc = SparseCategoricalAccuracy()
    scacc(pred, target)
    assert scacc.compute() == 0.5
    pred = torch.tensor([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    target = torch.tensor([2, 0])
    scacc(pred, target)
    assert scacc.compute() == 0.25


def test_torch_Top5Accuracy():
    from zoo.orca.learn.pytorch.pytorch_metrics import Top5Accuracy
    pred = torch.tensor([[0.1, 0.9, 0.8, 0.4, 0.5, 0.2],
                         [0.05, 0.95, 0, 0.4, 0.5, 0.2]])
    target = torch.tensor([2, 2])
    top5acc = Top5Accuracy()
    top5acc(pred, target)
    assert top5acc.compute() == 0.5
    pred = torch.tensor([[0.1, 0.9, 0.8, 0.4, 0.5, 0.2],
                         [0.05, 0.95, 0, 0.4, 0.5, 0.2]])
    target = torch.tensor([[2], [1]])
    top5acc(pred, target)
    assert top5acc.compute() == 0.75
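
One thing the assertions above rely on: each metric keeps running `correct`/`total` counters, so `compute()` always returns the accuracy accumulated over every batch seen so far, not the last batch alone. A plain-arithmetic check of `test_torch_Accuracy`'s second assertion:

```python
# Batch 1: pred [0, 2, 3, 4] vs target [1, 2, 3, 4] -> 3 of 4 match
# Batch 2: pred [0, 2, 3, 4] vs target [1, 1, 2, 4] -> 1 of 4 matches
correct, total = 3 + 1, 4 + 4
assert correct / total == 0.5  # cumulative accuracy after both batches
```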
16 changes: 16 additions & 0 deletions pyzoo/zoo/orca/learn/metrics.py
@@ -168,6 +168,10 @@ def get_bigdl_metric(self):
            SparseCategoricalAccuracy as KerasSparseCategoricalAccuracy
        return KerasSparseCategoricalAccuracy()

    def get_pytorch_metric(self):
        from zoo.orca.learn.pytorch import pytorch_metrics
        return pytorch_metrics.SparseCategoricalAccuracy()

    def get_name(self):
        return "SparseCategoricalAccuracy"

@@ -183,6 +187,10 @@ def get_bigdl_metric(self):
        from zoo.pipeline.api.keras.metrics import CategoricalAccuracy as KerasCategoricalAccuracy
        return KerasCategoricalAccuracy()

    def get_pytorch_metric(self):
        from zoo.orca.learn.pytorch import pytorch_metrics
        return pytorch_metrics.CategoricalAccuracy()

    def get_name(self):
        return "CategoricalAccuracy"

@@ -198,6 +206,10 @@ def get_bigdl_metric(self):
        from zoo.pipeline.api.keras.metrics import BinaryAccuracy as KerasBinaryAccuracy
        return KerasBinaryAccuracy()

    def get_pytorch_metric(self):
        from zoo.orca.learn.pytorch import pytorch_metrics
        return pytorch_metrics.BinaryAccuracy()

    def get_name(self):
        return "BinaryAccuracy"

@@ -217,5 +229,9 @@ def get_bigdl_metric(self):
        from zoo.pipeline.api.keras.metrics import Top5Accuracy as KerasTop5Accuracy
        return KerasTop5Accuracy()

    def get_pytorch_metric(self):
        from zoo.orca.learn.pytorch import pytorch_metrics
        return pytorch_metrics.Top5Accuracy()

    def get_name(self):
        return "Top5Accuracy"
153 changes: 150 additions & 3 deletions pyzoo/zoo/orca/learn/pytorch/pytorch_metrics.py
@@ -21,24 +21,171 @@ def _unify_input_formats(preds, target):
        raise ValueError("preds must have the same number of dimensions as targets, or one more")

    # class scores carry one extra dimension; reduce them to class indices
    if preds.ndim == target.ndim + 1:
        preds = torch.argmax(preds, dim=-1)

    # same-shape floating-point predictions are treated as probabilities
    if preds.ndim == target.ndim and preds.is_floating_point():
        preds = (preds >= 0.5).long()
    return preds, target


class Accuracy:
    """Calculates how often predictions match labels.
    For example, if `y_true` is tensor([1, 2, 3, 4]) and `y_pred` is tensor([0, 2, 3, 4])
    then the accuracy is 3/4 or .75.
    Usage:
    ```python
    acc = Accuracy()
    acc(torch.tensor([0, 2, 3, 4]), torch.tensor([1, 2, 3, 4]))
    assert acc.compute() == 0.75
    ```
    """

    def __init__(self):
        self.correct = torch.tensor(0)
        self.total = torch.tensor(0)

    def __call__(self, preds, targets):
        preds, targets = _unify_input_formats(preds, targets)
        self.correct += torch.sum(torch.eq(preds, targets))
        self.total += targets.numel()

    def compute(self):
        return self.correct.float() / self.total


class SparseCategoricalAccuracy:
    """Calculates how often predictions match integer labels.
    For example, if `y_true` is tensor([[2], [1]]) and `y_pred` is
    tensor([[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) then the categorical accuracy is 1/2 or .5.
    You can provide logits of classes as `y_pred`, since the argmax of
    logits and of probabilities is the same.
    Usage:
    ```python
    acc = SparseCategoricalAccuracy()
    acc(torch.tensor([[0.1, 0.9, 0.8], [0.05, 0.95, 0]]), torch.tensor([[2], [1]]))
    assert acc.compute() == 0.5
    ```
    """

    def __init__(self):
        self.total = torch.tensor(0)
        self.correct = torch.tensor(0)

    def __call__(self, preds, targets):
        batch_size = targets.size(0)
        if preds.ndim == targets.ndim:
            # targets may arrive as a column vector, e.g. [[2], [1]]
            targets = torch.squeeze(targets, dim=-1)
        preds = torch.argmax(preds, dim=-1)
        preds = preds.type_as(targets)
        self.correct += torch.sum(torch.eq(preds, targets))
        self.total += batch_size

    def compute(self):
        return self.correct.float() / self.total


class CategoricalAccuracy:
    """Calculates how often predictions match one-hot labels.
    For example, if `y_true` is tensor([[0, 0, 1], [0, 1, 0]]) and `y_pred` is
    tensor([[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) then the categorical accuracy is 1/2 or .5.
    You can provide logits of classes as `y_pred`, since the argmax of
    logits and of probabilities is the same.
    Usage:
    ```python
    pred = torch.tensor([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    target = torch.tensor([[0, 0, 1], [0, 1, 0]])
    cacc = CategoricalAccuracy()
    cacc(pred, target)
    assert cacc.compute() == 0.5
    ```
    """

    def __init__(self):
        self.total = torch.tensor(0)
        self.correct = torch.tensor(0)

    def __call__(self, preds, targets):
        batch_size = targets.size(0)
        # compare predicted class indices against the one-hot targets' indices
        self.correct += torch.sum(
            torch.eq(
                torch.argmax(preds, dim=-1), torch.argmax(targets, dim=-1)))
        self.total += batch_size

    def compute(self):
        return self.correct.float() / self.total


class BinaryAccuracy:
    """Calculates how often predictions match binary labels, comparing
    predictions against a `threshold` (default 0.5).
    For example, if `y_true` is tensor([1, 1, 0, 0]) and `y_pred` is tensor([0.98, 1, 0, 0.6])
    then the binary accuracy is 3/4 or .75.
    Usage:
    ```python
    target = torch.tensor([1, 1, 0, 0])
    pred = torch.tensor([0.98, 1, 0, 0.6])
    bac = BinaryAccuracy()
    bac(pred, target)
    assert bac.compute() == 0.75
    ```
    """

    def __init__(self):
        self.total = torch.tensor(0)
        self.correct = torch.tensor(0)

    def __call__(self, preds, targets, threshold=0.5):
        batch_size = targets.size(0)
        threshold = torch.tensor(threshold)
        # a prediction counts as positive only when strictly above the threshold
        self.correct += torch.sum(
            torch.eq(
                torch.gt(preds, threshold), targets))
        self.total += batch_size

    def compute(self):
        return self.correct.float() / self.total


class Top5Accuracy:
    """Computes how often the integer target is among the 5 highest-scored predictions.
    Usage:
    ```python
    pred = torch.tensor([[0.1, 0.9, 0.8, 0.4, 0.5, 0.2],
                         [0.05, 0.95, 0, 0.4, 0.5, 0.2]])
    target = torch.tensor([2, 2])
    top5acc = Top5Accuracy()
    top5acc(pred, target)
    assert top5acc.compute() == 0.5
    ```
    """

    def __init__(self):
        self.total = torch.tensor(0)
        self.correct = torch.tensor(0)

    def __call__(self, preds, targets):
        batch_size = targets.size(0)
        # indices of the 5 highest scores per sample, shape (batch, 5);
        # transpose to (5, batch) so each row holds one ranked guess per sample
        _, preds = preds.topk(5, dim=-1, largest=True, sorted=True)
        preds = preds.type_as(targets).t()
        # broadcast targets to (5, batch) and count samples whose target
        # appears anywhere in the top 5
        targets = targets.view(1, -1).expand_as(preds)

        self.correct += preds.eq(targets).view(-1).sum()
        self.total += batch_size

    def compute(self):
        return self.correct.float() / self.total
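
A design point worth flagging: none of these classes expose a reset method, so state accumulates for the lifetime of the instance. Creating a fresh metric per evaluation pass keeps results scoped to that pass; a minimal sketch, where `model` and `data_loader` are stand-ins rather than part of this commit:

```python
import torch
from zoo.orca.learn.pytorch.pytorch_metrics import Accuracy

def evaluate(model, data_loader):
    acc = Accuracy()  # fresh instance per pass, since there is no reset()
    with torch.no_grad():
        for x, y in data_loader:
            acc(model(x), y)  # accumulates correct/total counters
    return acc.compute()      # cumulative accuracy for this pass only
```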
