Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor edits to moving averages. #83

Merged
merged 1 commit into from
Jun 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 29 additions & 31 deletions rlax/_src/moving_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,38 @@
import jax.numpy as jnp


@chex.dataclass(frozen=True)
class EmaState:
# The tree of first moments.
first_moment: chex.ArrayTree
# The tree of second moments.
second_moment: chex.ArrayTree
# The product of the all decays from the start of accumulating.
decay_product: float


@chex.dataclass(frozen=True)
class EmaMoments:
"""data-class holding the latest mean and variance estimates."""
# The tree of means.
mean: chex.ArrayTree
# The tree of variances.
variance: chex.ArrayTree


@chex.dataclass(frozen=True)
class EmaState:
"""data-class holding the exponential moving average state."""
# The tree of exponential moving averages of the values
mu: chex.ArrayTree
# The tree of exponential moving averages of the squared values
nu: chex.ArrayTree
# The product of the all decays from the start of accumulating.
decay_product: float

def debiased_moments(self):
"""Returns debiased moments as in Adam."""
tiny = jnp.finfo(self.decay_product).tiny
debias = 1.0 / jnp.maximum(1 - self.decay_product, tiny)
mean = jax.tree_map(lambda m1: m1 * debias, self.mu)
# This computation of the variance may lose some numerical precision, if
# the mean is not approximately zero.
variance = jax.tree_map(
lambda m2, m: jnp.maximum(0.0, m2 * debias - jnp.square(m)),
self.nu, mean)
return EmaMoments(mean=mean, variance=variance)


def create_ema(decay=0.999, pmean_axis_name=None):
"""An updater of moments.

Expand All @@ -53,8 +67,7 @@ def create_ema(decay=0.999, pmean_axis_name=None):
def init_state(template_tree):
zeros = jax.tree_map(lambda x: jnp.zeros_like(jnp.mean(x)), template_tree)
scalar_zero = jnp.ones([], dtype=jnp.float32)
return EmaState(
first_moment=zeros, second_moment=zeros, decay_product=scalar_zero)
return EmaState(mu=zeros, nu=zeros, decay_product=scalar_zero)

def _update(moment, value):
mean = jnp.mean(value)
Expand All @@ -65,25 +78,10 @@ def _update(moment, value):

def update_moments(tree, state):
squared_tree = jax.tree_map(jnp.square, tree)
first_moment = jax.tree_map(_update, state.first_moment, tree)
second_moment = jax.tree_map(_update, state.second_moment, squared_tree)
mu = jax.tree_map(_update, state.mu, tree)
nu = jax.tree_map(_update, state.nu, squared_tree)
state = EmaState(
first_moment=first_moment, second_moment=second_moment,
decay_product=state.decay_product * decay)
moments = compute_moments(state)
return moments, state
mu=mu, nu=nu, decay_product=state.decay_product * decay)
return state.debiased_moments(), state

return init_state, update_moments


def compute_moments(state):
"""Returns debiased moments as in Adam."""
tiny = jnp.finfo(state.decay_product).tiny
debias = 1.0 / jnp.maximum(1 - state.decay_product, tiny)
mean = jax.tree_map(lambda m1: m1 * debias, state.first_moment)
# This computation of the variance may lose some numerical precision, if
# the mean is not approximately zero.
variance = jax.tree_map(
lambda m2, m: jnp.maximum(0.0, m2 * debias - jnp.square(m)),
state.second_moment, mean)
return EmaMoments(mean=mean, variance=variance)