-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
702 additions
and
503 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from elegy import types | ||
from elegy import utils | ||
import typing as tp | ||
|
||
import jax.numpy as jnp | ||
|
||
from elegy.metrics.mean import Mean | ||
from elegy.metrics.accuracy import accuracy | ||
|
||
|
||
def binary_accuracy(y_true, y_pred, threshold=0.5): | ||
"""Calculates how often predictions matches binary labels. | ||
Standalone usage: | ||
>>> y_true = [[1], [1], [0], [0]] | ||
>>> y_pred = [[1], [1], [0], [0]] | ||
>>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred) | ||
>>> assert m.shape == (4,) | ||
>>> m.numpy() | ||
array([1., 1., 1., 1.], dtype=float32) | ||
Args: | ||
y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. | ||
y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. | ||
threshold: (Optional) Float representing the threshold for deciding whether | ||
prediction values are 1 or 0. | ||
Returns: | ||
Binary accuracy values. shape = `[batch_size, d0, .. dN-1]` | ||
""" | ||
y_pred = y_pred > threshold | ||
return jnp.mean(y_true == y_pred, axis=-1) | ||
|
||
|
||
class BinaryAccuracy(Mean): | ||
""" | ||
Calculates how often predictions matches binary labels. | ||
This metric creates two local variables, `total` and `count` that are used to | ||
compute the frequency with which `y_pred` matches `y_true`. This frequency is | ||
ultimately returned as `binary accuracy`: an idempotent operation that simply | ||
divides `total` by `count`. | ||
If `sample_weight` is `None`, weights default to 1. | ||
Use `sample_weight` of 0 to mask values. | ||
Args: | ||
name: (Optional) string name of the metric instance. | ||
dtype: (Optional) data type of the metric result. | ||
threshold: (Optional) Float representing the threshold for deciding | ||
whether prediction values are 1 or 0. | ||
Standalone usage: | ||
>>> m = tf.keras.metrics.BinaryAccuracy() | ||
>>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]]) | ||
>>> m.result().numpy() | ||
0.75 | ||
>>> m.reset_states() | ||
>>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]], | ||
... sample_weight=[1, 0, 0, 1]) | ||
>>> m.result().numpy() | ||
0.5 | ||
Usage with `compile()` API: | ||
```python | ||
model.compile(optimizer='sgd', | ||
loss='mse', | ||
metrics=[tf.keras.metrics.BinaryAccuracy()]) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, threshold: float = 0.5, on: tp.Optional[types.IndexLike] = None, **kwargs | ||
): | ||
""" | ||
Creates a `CategoricalAccuracy` instance. | ||
Arguments: | ||
threshold: | ||
on: A string or integer, or iterable of string or integers, that | ||
indicate how to index/filter the `y_true` and `y_pred` | ||
arguments before passing them to `call`. For example if `on = "a"` then | ||
`y_true = y_true["a"]`. If `on` is an iterable | ||
the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]` | ||
then `y_true = y_true["a"][0]["b"]`, same for `y_pred`. For more information | ||
check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior). | ||
kwargs: Additional keyword arguments passed to Module. | ||
""" | ||
super().__init__(on=on, **kwargs) | ||
self.threshold = threshold | ||
|
||
def call( | ||
self, | ||
y_true: jnp.ndarray, | ||
y_pred: jnp.ndarray, | ||
sample_weight: tp.Optional[jnp.ndarray] = None, | ||
) -> jnp.ndarray: | ||
""" | ||
Accumulates metric statistics. `y_true` and `y_pred` should have the same shape. | ||
Arguments: | ||
y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. | ||
y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. | ||
sample_weight: Optional `sample_weight` acts as a | ||
coefficient for the metric. If a scalar is provided, then the metric is | ||
simply scaled by the given value. If `sample_weight` is a tensor of size | ||
`[batch_size]`, then the metric for each sample of the batch is rescaled | ||
by the corresponding element in the `sample_weight` vector. If the shape | ||
of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted | ||
to this shape), then each metric element of `y_pred` is scaled by the | ||
corresponding value of `sample_weight`. (Note on `dN-1`: all metric | ||
functions reduce by 1 dimension, usually the last axis (-1)). | ||
Returns: | ||
Array with the cumulative accuracy. | ||
""" | ||
|
||
return super().call( | ||
values=binary_accuracy( | ||
y_true=y_true, y_pred=y_pred, threshold=self.threshold | ||
), | ||
sample_weight=sample_weight, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from unittest import TestCase | ||
|
||
import jax.numpy as jnp | ||
import tensorflow.keras as tfk | ||
import numpy as np | ||
|
||
import elegy | ||
from elegy.testing_utils import transform_and_run | ||
|
||
|
||
class BinaryCrossentropyTest(TestCase): | ||
@transform_and_run | ||
def test_compatibility(self): | ||
|
||
y_true = (np.random.uniform(0, 1, size=(5, 6, 7)) > 0.5).astype(np.float32) | ||
y_pred = np.random.uniform(0, 1, size=(5, 6, 7)) | ||
sample_weight = np.random.uniform(0, 1, size=(5, 6)) | ||
|
||
assert np.allclose( | ||
tfk.metrics.BinaryAccuracy()(y_true, y_pred), | ||
elegy.metrics.BinaryAccuracy()(y_true, y_pred), | ||
) | ||
|
||
assert np.allclose( | ||
tfk.metrics.BinaryAccuracy(threshold=0.3)(y_true, y_pred), | ||
elegy.metrics.BinaryAccuracy(threshold=0.3)(y_true, y_pred), | ||
) | ||
|
||
assert np.allclose( | ||
tfk.metrics.BinaryAccuracy(threshold=0.3)( | ||
y_true, y_pred, sample_weight=sample_weight | ||
), | ||
elegy.metrics.BinaryAccuracy(threshold=0.3)( | ||
y_true, y_pred, sample_weight=sample_weight | ||
), | ||
) | ||
|
||
@transform_and_run | ||
def test_cummulative(self): | ||
|
||
tm = tfk.metrics.BinaryAccuracy(threshold=0.3) | ||
em = elegy.metrics.BinaryAccuracy(threshold=0.3) | ||
|
||
# 1st run | ||
y_true = (np.random.uniform(0, 1, size=(5, 6, 7)) > 0.5).astype(np.float32) | ||
y_pred = np.random.uniform(0, 1, size=(5, 6, 7)) | ||
sample_weight = np.random.uniform(0, 1, size=(5, 6)) | ||
|
||
assert np.allclose( | ||
tm(y_true, y_pred, sample_weight=sample_weight), | ||
em(y_true, y_pred, sample_weight=sample_weight), | ||
) | ||
|
||
# 2nd run | ||
y_true = (np.random.uniform(0, 1, size=(5, 6, 7)) > 0.5).astype(np.float32) | ||
y_pred = np.random.uniform(0, 1, size=(5, 6, 7)) | ||
sample_weight = np.random.uniform(0, 1, size=(5, 6)) | ||
|
||
assert np.allclose( | ||
tm(y_true, y_pred, sample_weight=sample_weight), | ||
em(y_true, y_pred, sample_weight=sample_weight), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.