
Commit 1cce538
also fix deprecation warning of using a_min, a_max in jnp.clip
fehiepsi committed Jun 25, 2024
1 parent 81bd46f commit 1cce538
Showing 17 changed files with 34 additions and 41 deletions.
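The change applies one mechanical pattern everywhere: recent JAX releases deprecate the `a_min`/`a_max` keyword names in `jnp.clip` in favor of positional bounds (following the Array API style `clip(x, min, max)` signature), so each call site drops the keywords. A minimal sketch of the migration, assuming a recent JAX version:

```python
import jax.numpy as jnp

x = jnp.array([-0.5, 0.3, 1.7])

# Before (emits a DeprecationWarning on recent JAX):
#   jnp.clip(x, a_min=0.0, a_max=1.0)

# After: pass bounds positionally; None leaves that side unbounded.
jnp.clip(x, 0.0, 1.0)   # clip to [0, 1]
jnp.clip(x, 0.0)        # lower bound only
jnp.clip(x, None, 1.0)  # upper bound only
```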
2 changes: 1 addition & 1 deletion examples/gp.py
@@ -116,7 +116,7 @@ def predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True):
K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))

- sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
+ sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.0)) * jax.random.normal(
rng_key, X_test.shape[:1]
)

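For context on the `gp.py` guard: round-off can leave tiny negative entries on the diagonal of the posterior covariance, and `jnp.sqrt` of a negative number is NaN, so clipping at 0.0 first keeps the predictive noise finite. A small illustration with made-up values:

```python
import jax.numpy as jnp

diag_K = jnp.array([1e-3, -1e-9, 2e-2])  # tiny negative entry from round-off
jnp.sqrt(diag_K)                 # contains nan
jnp.sqrt(jnp.clip(diag_K, 0.0))  # finite; the guard this hunk preserves
```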
4 changes: 2 additions & 2 deletions notebooks/source/time_series_forecasting.ipynb
@@ -206,7 +206,7 @@
" level, s, moving_sum = carry\n",
" season = s[0] * level**pow_season\n",
" exp_val = level + coef_trend * level**pow_trend + season\n",
" exp_val = jnp.clip(exp_val, a_min=0)\n",
" exp_val = jnp.clip(exp_val, 0)\n",
" # use expected vale when forecasting\n",
" y_t = jnp.where(t >= N, exp_val, y[t])\n",
"\n",
@@ -215,7 +215,7 @@
" )\n",
" level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)\n",
" level = level_sm * level_p + (1 - level_sm) * level\n",
" level = jnp.clip(level, a_min=0)\n",
" level = jnp.clip(level, 0)\n",
"\n",
" new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]\n",
" # repeat s when forecasting\n",
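In the forecasting model, `exp_val` and `level` feed multiplicative seasonality terms, so the lower-bound-only clip keeps them nonnegative. With a single positional bound, only the minimum is applied, e.g.:

```python
import jax.numpy as jnp

jnp.clip(jnp.array([-0.2, 1.5]), 0)  # -> [0., 1.5]; no upper bound applied
```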
2 changes: 1 addition & 1 deletion numpyro/diagnostics.py
@@ -182,7 +182,7 @@ def effective_sample_size(x):
Rho_k = np.concatenate(
[
Rho_init,
- np.minimum.accumulate(np.clip(Rho_k[1:, ...], a_min=0, a_max=None), axis=0),
+ np.minimum.accumulate(np.clip(Rho_k[1:, ...], 0, None), axis=0),
],
axis=0,
)
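Note that `diagnostics.py` calls NumPy's `np.clip`, whose positional `np.clip(a, a_min, a_max)` form has always been supported, so this hunk is purely for stylistic consistency with the JAX call sites. The clip itself zeroes out negative autocorrelation estimates before the monotone minimum-accumulate pass, roughly:

```python
import numpy as np

rho = np.array([0.9, 0.5, -0.1, 0.2, -0.3])    # made-up autocorrelations
clipped = np.clip(rho, 0, None)                # keep only the nonnegative part
mono = np.minimum.accumulate(clipped, axis=0)  # enforce a non-increasing sequence
```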
12 changes: 5 additions & 7 deletions numpyro/distributions/continuous.py
@@ -281,9 +281,7 @@ def sample(self, key, sample_shape=()):
assert is_prng_key(key)
shape = sample_shape + self.batch_shape
samples = random.dirichlet(key, self.concentration, shape=shape)
- return jnp.clip(
-     samples, a_min=jnp.finfo(samples).tiny, a_max=1 - jnp.finfo(samples).eps
- )
+ return jnp.clip(samples, jnp.finfo(samples).tiny, 1 - jnp.finfo(samples).eps)

@validate_sample
def log_prob(self, value):
@@ -840,15 +838,15 @@ def sample(self, key, sample_shape=()):
u = random.uniform(
key, shape=sample_shape + self.batch_shape, minval=finfo.tiny
)
- u_con0 = jnp.clip(u ** (1 / self.concentration0), a_max=1 - finfo.eps)
+ u_con0 = jnp.clip(u ** (1 / self.concentration0), None, 1 - finfo.eps)
log_sample = jnp.log1p(-u_con0) / self.concentration1
- return jnp.clip(jnp.exp(log_sample), a_min=finfo.tiny, a_max=1 - finfo.eps)
+ return jnp.clip(jnp.exp(log_sample), finfo.tiny, 1 - finfo.eps)

@validate_sample
def log_prob(self, value):
finfo = jnp.finfo(jnp.result_type(float))
normalize_term = jnp.log(self.concentration0) + jnp.log(self.concentration1)
- value_con1 = jnp.clip(value**self.concentration1, a_max=1 - finfo.eps)
+ value_con1 = jnp.clip(value**self.concentration1, None, 1 - finfo.eps)
return (
xlogy(self.concentration1 - 1, value)
+ xlog1py(self.concentration0 - 1, -value_con1)
@@ -2363,7 +2361,7 @@ def log_prob(self, value):

def cdf(self, value):
cdf = (value - self.low) / (self.high - self.low)
- return jnp.clip(cdf, a_min=0.0, a_max=1.0)
+ return jnp.clip(cdf, 0.0, 1.0)

def icdf(self, value):
return self.low + value * (self.high - self.low)
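The clips in `continuous.py` push samples into the open interval (0, 1): `finfo.tiny` is the smallest positive normal float and `1 - eps` the largest float strictly below 1, so downstream `log_prob` evaluations never hit `log(0)`. The pattern in isolation:

```python
import jax.numpy as jnp

finfo = jnp.finfo(jnp.result_type(float))
u = jnp.array([0.0, 0.5, 1.0])                 # endpoints give -inf log-densities
safe = jnp.clip(u, finfo.tiny, 1 - finfo.eps)  # strictly inside (0, 1)
jnp.log(safe)                                  # finite everywhere
```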
2 changes: 1 addition & 1 deletion numpyro/distributions/directional.py
@@ -401,7 +401,7 @@ def norm_const(self):
lbinoms = num - 2 * den

fs = lbinoms.reshape(-1, 1) + m * (
- jnp.log(jnp.clip(corr**2, a_min=jnp.finfo(jnp.result_type(float)).tiny))
+ jnp.log(jnp.clip(corr**2, jnp.finfo(jnp.result_type(float)).tiny))
- jnp.log(4 * jnp.prod(conc, axis=-1))
)
fs += log_I1(49, conc, terms=51).sum(-1)
4 changes: 2 additions & 2 deletions numpyro/distributions/discrete.py
@@ -65,7 +65,7 @@ def _to_probs_multinom(logits):

def _to_logits_multinom(probs):
minval = jnp.finfo(jnp.result_type(probs)).min
- return jnp.clip(jnp.log(probs), a_min=minval)
+ return jnp.clip(jnp.log(probs), minval)


class BernoulliProbs(Distribution):
@@ -443,7 +443,7 @@ def log_prob(self, value):

def cdf(self, value):
cdf = (jnp.floor(value) + 1 - self.low) / (self.high - self.low + 1)
- return jnp.clip(cdf, a_min=0.0, a_max=1.0)
+ return jnp.clip(cdf, 0.0, 1.0)

def icdf(self, value):
return self.low + value * (self.high - self.low + 1) - 1
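`_to_logits_multinom` uses the same idea in log space: `log(0)` is `-inf`, and clipping at the most negative representable float keeps logits finite without materially changing the distribution. For example:

```python
import jax.numpy as jnp

probs = jnp.array([0.0, 0.25, 0.75])
minval = jnp.finfo(probs.dtype).min
logits = jnp.clip(jnp.log(probs), minval)  # -inf -> finite minval
```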
2 changes: 1 addition & 1 deletion numpyro/distributions/flows.py
@@ -10,7 +10,7 @@


def _clamp_preserve_gradients(x, min, max):
- return x + lax.stop_gradient(jnp.clip(x, a_min=min, a_max=max) - x)
+ return x + lax.stop_gradient(jnp.clip(x, min, max) - x)


# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py
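`_clamp_preserve_gradients` is a straight-through clamp: the forward value is clipped, but because the clipping term sits under `lax.stop_gradient`, the backward pass sees the identity. A quick check of both behaviors:

```python
import jax
import jax.numpy as jnp
from jax import lax

def _clamp_preserve_gradients(x, min, max):
    # forward: clamp to [min, max]; backward: identity
    return x + lax.stop_gradient(jnp.clip(x, min, max) - x)

f = lambda x: _clamp_preserve_gradients(x, -1.0, 1.0)
f(5.0)            # -> 1.0 (value is clamped)
jax.grad(f)(5.0)  # -> 1.0 (a plain jnp.clip would give gradient 0.0 here)
```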
15 changes: 6 additions & 9 deletions numpyro/distributions/transforms.py
@@ -6,7 +6,6 @@
import weakref

import numpy as np
- from numpy.core.numeric import normalize_axis_tuple

import jax
from jax import lax, vmap
@@ -58,7 +57,7 @@

def _clipped_expit(x):
finfo = jnp.finfo(jnp.result_type(x))
- return jnp.clip(expit(x), a_min=finfo.tiny, a_max=1.0 - finfo.eps)
+ return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps)


class Transform(object):
@@ -651,11 +650,11 @@ def _inverse(self, y):
pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
finfo = jnp.finfo(y.dtype)
- remainder = jnp.clip(remainder, a_min=finfo.tiny)
+ remainder = jnp.clip(remainder, finfo.tiny)
t = y / remainder

# inverse of tanh
- t = jnp.clip(t, a_min=-1 + finfo.eps, a_max=1 - finfo.eps)
+ t = jnp.clip(t, -1 + finfo.eps, 1 - finfo.eps)
return jnp.arctanh(t)

def log_abs_det_jacobian(self, x, y, intermediates=None):
@@ -667,7 +666,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
# of the diagonal part of the jacobian
one_minus_remainder = jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
eps = jnp.finfo(y.dtype).eps
- one_minus_remainder = jnp.clip(one_minus_remainder, a_max=1 - eps)
+ one_minus_remainder = jnp.clip(one_minus_remainder, None, 1 - eps)
# log(remainder) = log1p(remainder - 1)
stick_breaking_logdet = jnp.sum(jnp.log1p(-one_minus_remainder), axis=-1)

@@ -1075,9 +1074,7 @@ def __call__(self, x):

def _inverse(self, y):
y_crop = y[..., :-1]
- z1m_cumprod = jnp.clip(
-     1 - jnp.cumsum(y_crop, axis=-1), a_min=jnp.finfo(y.dtype).tiny
- )
+ z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny)
# hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod
x = jnp.log(y_crop / z1m_cumprod)
return x + jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))
@@ -1418,7 +1415,7 @@ def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
return y

def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
- normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]
+ normalized_axis = axis if axis >= 0 else jnp.ndim(array) + axis

n = array.shape[normalized_axis]
last = jnp.take(array, jnp.array([-1]), axis=normalized_axis)
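Besides the clip migration, this file drops its import from `numpy.core.numeric`: `numpy.core` is a private namespace whose layout changed in NumPy 2.0, so `normalize_axis_tuple` is replaced by an inline normalization of a single, possibly negative, axis. The equivalence, with a hypothetical helper name:

```python
def normalize_axis(axis: int, ndim: int) -> int:
    # map a negative axis index onto the [0, ndim) range
    return axis if axis >= 0 else ndim + axis

normalize_axis(-1, 3)  # -> 2, same as normalize_axis_tuple(-1, 3)[0]
normalize_axis(1, 3)   # -> 1
```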
2 changes: 1 addition & 1 deletion numpyro/distributions/truncated.py
@@ -348,7 +348,7 @@ def sample(self, key, sample_shape=()):
key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,))
)
x = jnp.sum(x / denom, axis=-1)
- return jnp.clip(x * (0.5 / jnp.pi**2), a_max=self.truncation_point)
+ return jnp.clip(x * (0.5 / jnp.pi**2), None, self.truncation_point)

@validate_sample
def log_prob(self, value):
6 changes: 3 additions & 3 deletions numpyro/distributions/util.py
@@ -386,7 +386,7 @@ def scan_fn(carry, val):
def signed_stick_breaking_tril(t):
# make sure that t in (-1, 1)
eps = jnp.finfo(t.dtype).eps
- t = jnp.clip(t, a_min=(-1 + eps), a_max=(1 - eps))
+ t = jnp.clip(t, -1 + eps, 1 - eps)
# transform t to tril matrix with identity diagonal
r = vec_to_tril_matrix(t, diagonal=-1)

@@ -417,7 +417,7 @@ def logmatmulexp(x, y):

def clamp_probs(probs):
finfo = jnp.finfo(jnp.result_type(probs, float))
- return jnp.clip(probs, a_min=finfo.tiny, a_max=1.0 - finfo.eps)
+ return jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps)


def betainc(a, b, x):
@@ -607,7 +607,7 @@ def safe_normalize(x, *, p=2):
assert isinstance(p, (float, int))
assert p >= 0
norm = jnp.linalg.norm(x, p, axis=-1, keepdims=True)
- x = x / jnp.clip(norm, a_min=jnp.finfo(x).tiny)
+ x = x / jnp.clip(norm, jnp.finfo(x).tiny)
# Avoid the singularity.
mask = jnp.all(x == 0, axis=-1, keepdims=True)
x = jnp.where(mask, x.shape[-1] ** (-1 / p), x)
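In `safe_normalize`, clipping the norm at `finfo.tiny` avoids a 0/0 when a row of `x` is all zeros; the follow-up `where` then substitutes a uniform unit vector for that row. A small check of the division guard, assuming p=2:

```python
import jax.numpy as jnp

x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
norm = jnp.linalg.norm(x, 2, axis=-1, keepdims=True)
x / norm                                     # second row is nan (0/0)
x / jnp.clip(norm, jnp.finfo(x.dtype).tiny)  # second row is 0.0, no nan
```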
6 changes: 3 additions & 3 deletions numpyro/infer/autoguide.py
@@ -967,7 +967,7 @@ def log_density(x):
def scan_body(carry, eps_beta):
eps, beta = eps_beta
eta = eta0 + eta_coeff * beta
- eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max)
+ eta = jnp.clip(eta, 0.0, self.eta_max)
z_prev, v_prev, log_factor = carry
z_half = z_prev + v_prev * eta * inv_mass_matrix
q_grad = (1.0 - beta) * grad(base_z_dist.log_prob)(z_half)
@@ -1186,7 +1186,7 @@ def blocked_surrogate_model(x):
def scan_body(carry, eps_beta):
eps, beta = eps_beta
eta = eta0 + eta_coeff * beta
- eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max)
+ eta = jnp.clip(eta, 0.0, self.eta_max)
z_prev, v_prev, log_factor = carry
z_half = z_prev + v_prev * eta * inv_mass_matrix
q_grad = (1.0 - beta) * grad(base_z_dist_log_prob)(z_half)
@@ -1641,7 +1641,7 @@ def base_z_dist_log_prob(x):
def scan_body(carry, eps_beta):
eps, beta = eps_beta
eta = eta0 + eta_coeff * beta
- eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max)
+ eta = jnp.clip(eta, 0.0, self.eta_max)
assert eps.shape == (subsample_size, D)
assert eta.shape == beta.shape == (subsample_size,)
z_prev, v_prev, log_factor = carry
2 changes: 1 addition & 1 deletion numpyro/infer/barker.py
@@ -260,7 +260,7 @@ def sample(self, state, model_args, model_kwargs):
- softplus(-dx_flat * y_grad_flat_scaled)
)
)
- accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.0)
+ accept_prob = jnp.clip(jnp.exp(log_accept_ratio), None, 1.0)

x, x_flat, pe, x_grad = jax.lax.cond(
random.bernoulli(key_accept, accept_prob),
2 changes: 1 addition & 1 deletion numpyro/infer/hmc.py
@@ -400,7 +400,7 @@ def _hmc_next(
)
delta_energy = energy_new - energy_old
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
- accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
+ accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0)
diverging = delta_energy > max_delta_energy
transition = random.bernoulli(rng_key, accept_prob)
vv_state, energy = cond(
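The `accept_prob` lines in `barker.py`, `hmc.py`, and the `hmc_gibbs.py`, `hmc_util.py`, and `mixed_hmc.py` hunks below all compute the Metropolis correction min(1, exp(-ΔE)); `jnp.clip(value, None, 1.0)` is the keyword-free spelling of that upper bound:

```python
import jax.numpy as jnp

delta_energy = jnp.array([-0.5, 0.0, 2.0])
accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0)
# -> [1.0, 1.0, 0.135...]; a probability never exceeds 1
```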
2 changes: 1 addition & 1 deletion numpyro/infer/hmc_gibbs.py
@@ -657,7 +657,7 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc):
# given a fixed hmc_sites, pe_new - pe_curr = loglik_new - loglik_curr
pe = state.hmc_state.potential_energy
pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z)
- accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0)
+ accept_prob = jnp.clip(jnp.exp(pe - pe_new), None, 1.0)
transition = random.bernoulli(rng_key, accept_prob)
grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
z_gibbs, gibbs_state, pe, z_grad = cond(
6 changes: 3 additions & 3 deletions numpyro/infer/hmc_util.py
@@ -669,7 +669,7 @@ def update_fn(t, accept_prob, z_info, state):
)
# account for the case where log_step_size is an extreme number
finfo = jnp.finfo(jnp.result_type(step_size))
- step_size = jnp.clip(step_size, a_min=finfo.tiny, a_max=finfo.max)
+ step_size = jnp.clip(step_size, finfo.tiny, finfo.max)

# update mass matrix state
is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1))
@@ -759,7 +759,7 @@ def _biased_transition_kernel(current_tree, new_tree):
# If new tree is turning or diverging, we won't move the proposal
# to the new tree.
transition_prob = jnp.where(
- new_tree.turning | new_tree.diverging, 0.0, jnp.clip(transition_prob, a_max=1.0)
+ new_tree.turning | new_tree.diverging, 0.0, jnp.clip(transition_prob, None, 1.0)
)
return transition_prob

@@ -872,7 +872,7 @@ def _build_basetree(
tree_weight = -delta_energy

diverging = delta_energy > max_delta_energy
- accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
+ accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0)
return TreeInfo(
z_new,
r_new,
2 changes: 1 addition & 1 deletion numpyro/infer/mixed_hmc.py
@@ -263,7 +263,7 @@ def body_fn(i, vals):
# Algo 1, line 11: perform MH correction
delta_energy = energy_new - energy_old - delta_pe_sum
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
- accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
+ accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0)

# record the correct new num_steps
hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps)
4 changes: 1 addition & 3 deletions numpyro/optim.py
@@ -177,9 +177,7 @@ def __init__(self, *args, clip_norm=10.0, **kwargs):
def update(self, g, state):
i, opt_state = state
# clip norm
- g = jax.tree.map(
-     lambda g_: jnp.clip(g_, a_min=-self.clip_norm, a_max=self.clip_norm), g
- )
+ g = jax.tree.map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g)
opt_state = self.update_fn(i, g, opt_state)
return i + 1, opt_state

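Note that despite the `clip_norm` name, `ClippedAdam.update` clips each gradient leaf elementwise to [-clip_norm, clip_norm] rather than rescaling by the global norm. The tree-wide pattern, with made-up gradients:

```python
import jax
import jax.numpy as jnp

clip_norm = 10.0
grads = {"w": jnp.array([-25.0, 3.0]), "b": jnp.array([12.0])}
clipped = jax.tree.map(lambda g: jnp.clip(g, -clip_norm, clip_norm), grads)
# {"b": [10.], "w": [-10., 3.]}
```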
