feature/binary-accuracy (#87)
cgarciae authored Aug 24, 2020
1 parent 67d9f2b commit 5742d16
Showing 8 changed files with 702 additions and 503 deletions.
24 changes: 12 additions & 12 deletions elegy/callbacks/callback_list.py
@@ -141,18 +141,18 @@ def _call_batch_hook(self, mode, hook, batch, logs=None):
batch_hook(batch, logs)
self._delta_ts[hook_name].append(time.time() - t_before_callbacks)

delta_t_median = np.median(self._delta_ts[hook_name])
if (
self._delta_t_batch > 0.0
and delta_t_median > 0.95 * self._delta_t_batch
and delta_t_median > 0.1
):
logging.warning(
"Method (%s) is slow compared "
"to the batch update (%f). Check your callbacks.",
hook_name,
delta_t_median,
)
# delta_t_median = np.median(self._delta_ts[hook_name])
# if (
# self._delta_t_batch > 0.0
# and delta_t_median > 0.95 * self._delta_t_batch
# and delta_t_median > 0.1
# ):
# logging.warning(
# "Method (%s) is slow compared "
# "to the batch update (%f). Check your callbacks.",
# hook_name,
# delta_t_median,
# )

def _call_begin_hook(self, mode):
"""Helper function for on_{train|test|predict}_begin methods."""
3 changes: 3 additions & 0 deletions elegy/metrics/__init__.py
@@ -1,4 +1,5 @@
from .accuracy import Accuracy, accuracy
from .binary_accuracy import BinaryAccuracy, binary_accuracy
from .binary_crossentropy import BinaryCrossentropy, binary_crossentropy
from .categorical_accuracy import CategoricalAccuracy, categorical_accuracy
from .mean import Mean
@@ -30,4 +31,6 @@
"SparseCategoricalAccuracy",
"sparse_categorical_accuracy",
"Sum",
"BinaryAccuracy",
"binary_accuracy",
]
115 changes: 115 additions & 0 deletions elegy/metrics/binary_accuracy.py
@@ -0,0 +1,115 @@
from elegy import types
from elegy import utils
import typing as tp

import jax.numpy as jnp

from elegy.metrics.mean import Mean
from elegy.metrics.accuracy import accuracy


def binary_accuracy(y_true, y_pred, threshold=0.5):
"""Calculates how often predictions matches binary labels.
Standalone usage:
>>> y_true = [[1], [1], [0], [0]]
>>> y_pred = [[1], [1], [0], [0]]
>>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred)
>>> assert m.shape == (4,)
>>> m.numpy()
array([1., 1., 1., 1.], dtype=float32)
Args:
y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
threshold: (Optional) Float representing the threshold for deciding whether
prediction values are 1 or 0.
Returns:
Binary accuracy values. shape = `[batch_size, d0, .. dN-1]`
"""
y_pred = y_pred > threshold  # binarize predictions at `threshold`
return jnp.mean(y_true == y_pred, axis=-1)  # match rate over the last axis


class BinaryAccuracy(Mean):
"""
Calculates how often predictions matches binary labels.
This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `binary accuracy`: an idempotent operation that simply
divides `total` by `count`.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
Args:
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
threshold: (Optional) Float representing the threshold for deciding
whether prediction values are 1 or 0.
Standalone usage:
>>> m = tf.keras.metrics.BinaryAccuracy()
>>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])
>>> m.result().numpy()
0.75
>>> m.reset_states()
>>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]],
... sample_weight=[1, 0, 0, 1])
>>> m.result().numpy()
0.5
Usage with `compile()` API:
```python
model.compile(optimizer='sgd',
loss='mse',
metrics=[tf.keras.metrics.BinaryAccuracy()])
```
"""

def __init__(
self, threshold: float = 0.5, on: tp.Optional[types.IndexLike] = None, **kwargs
):
"""
Creates a `BinaryAccuracy` instance.
Arguments:
threshold: (Optional) Float representing the threshold for deciding whether
prediction values are 1 or 0.
on: A string or integer, or iterable of string or integers, that
indicate how to index/filter the `y_true` and `y_pred`
arguments before passing them to `call`. For example if `on = "a"` then
`y_true = y_true["a"]`. If `on` is an iterable
the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
then `y_true = y_true["a"][0]["b"]`, same for `y_pred`. For more information
check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
kwargs: Additional keyword arguments passed to Module.
"""
super().__init__(on=on, **kwargs)
self.threshold = threshold

def call(
self,
y_true: jnp.ndarray,
y_pred: jnp.ndarray,
sample_weight: tp.Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""
Accumulates metric statistics. `y_true` and `y_pred` should have the same shape.
Arguments:
y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
sample_weight: Optional `sample_weight` acts as a
coefficient for the metric. If a scalar is provided, then the metric is
simply scaled by the given value. If `sample_weight` is a tensor of size
`[batch_size]`, then the metric for each sample of the batch is rescaled
by the corresponding element in the `sample_weight` vector. If the shape
of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
to this shape), then each metric element of `y_pred` is scaled by the
corresponding value of `sample_weight`. (Note on `dN-1`: all metric
functions reduce by 1 dimension, usually the last axis (-1)).
Returns:
Array with the cumulative accuracy.
"""

return super().call(
values=binary_accuracy(
y_true=y_true, y_pred=y_pred, threshold=self.threshold
),
sample_weight=sample_weight,
)
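
For context, here is a minimal standalone sketch of how the new `binary_accuracy` function behaves; the arrays and printed values below are illustrative, not part of the diff:

```python
import jax.numpy as jnp

from elegy.metrics import binary_accuracy

y_true = jnp.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
y_pred = jnp.array([[0.9, 0.2, 0.4], [0.1, 0.8, 0.7]])  # probabilities

# With the default threshold of 0.5 the 0.4 prediction is counted as 0,
# so the first row scores 2/3 and the second row 3/3.
print(binary_accuracy(y_true, y_pred))  # [0.6667 1.0] (approximately)

# Lowering the threshold flips the 0.4 prediction to a 1.
print(binary_accuracy(y_true, y_pred, threshold=0.3))  # [1.0 1.0]
```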
62 changes: 62 additions & 0 deletions elegy/metrics/binary_accuracy_test.py
@@ -0,0 +1,62 @@
from unittest import TestCase

import jax.numpy as jnp
import tensorflow.keras as tfk
import numpy as np

import elegy
from elegy.testing_utils import transform_and_run


class BinaryAccuracyTest(TestCase):
@transform_and_run
def test_compatibility(self):

y_true = (np.random.uniform(0, 1, size=(5, 6, 7)) > 0.5).astype(np.float32)
y_pred = np.random.uniform(0, 1, size=(5, 6, 7))
sample_weight = np.random.uniform(0, 1, size=(5, 6))

assert np.allclose(
tfk.metrics.BinaryAccuracy()(y_true, y_pred),
elegy.metrics.BinaryAccuracy()(y_true, y_pred),
)

assert np.allclose(
tfk.metrics.BinaryAccuracy(threshold=0.3)(y_true, y_pred),
elegy.metrics.BinaryAccuracy(threshold=0.3)(y_true, y_pred),
)

assert np.allclose(
tfk.metrics.BinaryAccuracy(threshold=0.3)(
y_true, y_pred, sample_weight=sample_weight
),
elegy.metrics.BinaryAccuracy(threshold=0.3)(
y_true, y_pred, sample_weight=sample_weight
),
)

@transform_and_run
def test_cumulative(self):

tm = tfk.metrics.BinaryAccuracy(threshold=0.3)
em = elegy.metrics.BinaryAccuracy(threshold=0.3)

# 1st run
y_true = (np.random.uniform(0, 1, size=(5, 6, 7)) > 0.5).astype(np.float32)
y_pred = np.random.uniform(0, 1, size=(5, 6, 7))
sample_weight = np.random.uniform(0, 1, size=(5, 6))

assert np.allclose(
tm(y_true, y_pred, sample_weight=sample_weight),
em(y_true, y_pred, sample_weight=sample_weight),
)

# 2nd run
y_true = (np.random.uniform(0, 1, size=(5, 6, 7)) > 0.5).astype(np.float32)
y_pred = np.random.uniform(0, 1, size=(5, 6, 7))
sample_weight = np.random.uniform(0, 1, size=(5, 6))

assert np.allclose(
tm(y_true, y_pred, sample_weight=sample_weight),
em(y_true, y_pred, sample_weight=sample_weight),
)
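
The cumulative test above exercises the `total`/`count` bookkeeping described in the `BinaryAccuracy` docstring: calling the metric twice must return the running weighted mean over both batches, matching Keras' stateful behavior. A toy model of that accumulation (an assumed sketch of the semantics, not elegy's actual `Mean` implementation):

```python
import numpy as np

class RunningMean:
    """Toy stand-in for the total/count state a Mean-based metric carries."""

    def __init__(self):
        self.total = 0.0
        self.count = 0.0

    def __call__(self, values, weights):
        self.total += np.sum(values * weights)  # weighted sum of new values
        self.count += np.sum(weights)           # total weight seen so far
        return self.total / self.count          # cumulative result

m = RunningMean()
print(m(np.array([1.0, 1.0, 1.0, 0.0]), np.ones(4)))  # 0.75
print(m(np.array([1.0, 1.0, 1.0, 0.0]), np.array([1.0, 0.0, 0.0, 1.0])))  # (3+1)/(4+2) ≈ 0.667
```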
32 changes: 14 additions & 18 deletions elegy/metrics/reduce.py
@@ -2,6 +2,7 @@
from enum import Enum

import jax.numpy as jnp
import numpy as np

from elegy import initializers, module, types, utils, hooks
from elegy.metrics.metric import Metric
@@ -18,7 +19,7 @@ def reduce(
count: tp.Optional[jnp.ndarray],
values: jnp.ndarray,
reduction: Reduction,
sample_weight: tp.Optional[jnp.ndarray],
sample_weight: tp.Optional[np.ndarray],
dtype: jnp.dtype,
) -> tp.Tuple[jnp.ndarray, jnp.ndarray, tp.Optional[jnp.ndarray]]:

@@ -30,23 +31,18 @@
# values, sample_weight=sample_weight
# )

# try:
# # Broadcast weights if possible.
# sample_weight = weights_broadcast_ops.broadcast_weights(
# sample_weight, values
# )
# except ValueError:
# # Reduce values to same ndim as weight array
# ndim = K.ndim(values)
# weight_ndim = K.ndim(sample_weight)
# if reduction == metrics_utils.Reduction.SUM:
# values = math_ops.reduce_sum(
# values, axis=list(range(weight_ndim, ndim))
# )
# else:
# values = math_ops.reduce_mean(
# values, axis=list(range(weight_ndim, ndim))
# )
try:
# Broadcast weights if possible.
sample_weight = jnp.broadcast_to(sample_weight, values.shape)
except ValueError:
# Reduce values to same ndim as weight array
ndim = values.ndim
weight_ndim = sample_weight.ndim
if reduction == Reduction.SUM:
values = jnp.sum(values, axis=list(range(weight_ndim, ndim)))
else:
values = jnp.mean(values, axis=list(range(weight_ndim, ndim)))

values = values * sample_weight

value_sum = jnp.sum(values)
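
The new code replaces the commented-out TensorFlow logic with a jax.numpy equivalent: weights are broadcast to the values' shape when possible, and otherwise the values are reduced over their trailing axes until the shapes line up. A standalone sketch of the fallback path, using the same shapes as the tests above (illustrative, not part of the diff):

```python
import jax.numpy as jnp

values = jnp.ones((5, 6, 7))           # per-element metric values
sample_weight = jnp.full((5, 6), 0.5)  # one weight per (batch, step)

try:
    # Broadcast weights if possible.
    sample_weight = jnp.broadcast_to(sample_weight, values.shape)
except ValueError:
    # (5, 6) does not broadcast to (5, 6, 7) because broadcasting aligns
    # trailing axes, so reduce `values` down to the weights' ndim instead.
    values = jnp.mean(values, axis=tuple(range(sample_weight.ndim, values.ndim)))

print(values.shape)  # (5, 6) -- now `values * sample_weight` is well defined
```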
4 changes: 4 additions & 0 deletions elegy/model.py
@@ -372,6 +372,8 @@ class during training. This can be useful to tell the model to "pay
if metrics_states is not None:
self.metrics_states = metrics_states

# logs = jax.tree_map(np.asarray, logs)

return logs

def _update(
@@ -1079,6 +1081,8 @@ def test_on_batch(
if metrics_states is not None:
self.metrics_states = metrics_states

# logs = jax.tree_map(np.asarray, logs)

return logs

def _test(