jax.numpy.clip: update use of deprecated arguments.
- a is now positional-only
- a_min is now min
- a_max is now max

The old argument names have been deprecated since JAX v0.4.27.

PiperOrigin-RevId: 714414303
Jake VanderPlas authored and Scenic Authors committed Jan 11, 2025
1 parent 8ef50d5 commit 794bb40
Showing 8 changed files with 11 additions and 11 deletions.
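
For reference, a minimal sketch of the rename applied throughout this commit (illustrative only, not taken from the changed files; it assumes JAX v0.4.27 or newer, where the new keyword names are accepted):

import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 5)

# Deprecated spelling (deprecated since JAX v0.4.27):
#   jnp.clip(x, a_min=-1.0, a_max=1.0)
# The array argument is also positional-only now, so jnp.clip(a=x, ...) is rejected.

# Current spelling: bounds are passed as `min` / `max`.
y = jnp.clip(x, min=-1.0, max=1.0)  # [-1., -1., 0., 1., 1.]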
2 changes: 1 addition & 1 deletion scenic/model_lib/base_models/encoder_decoder_model.py
@@ -119,7 +119,7 @@ def encoder_decoder_metrics_function(
   # Calculate (clipped) perplexity after averaging log-perplexities:
   evaluated_metrics['perplexity'] = (jnp.clip(
       jnp.exp(evaluated_metrics['loss'][0] / evaluated_metrics['loss'][1]),
-      a_max=_MAX_PERPLEXITY), 1)
+      max=_MAX_PERPLEXITY), 1)
   return evaluated_metrics  # pytype: disable=bad-return-type  # jax-types


6 changes: 3 additions & 3 deletions scenic/projects/adversarialtraining/attacks/attack_methods.py
@@ -42,7 +42,7 @@ def project_perturbation_pyramid_inf(aug_params, epsilon, input_image,
   # The idea is to ensure that the sum of the perturbations over the pyramid
   # levels can't be more than epsilon.
   clipped_perturbation_pyramid = jax.tree_util.tree_map(
-      functools.partial(jnp.clip, a_min=-epsilon, a_max=epsilon), aug_params)
+      functools.partial(jnp.clip, min=-epsilon, max=epsilon), aug_params)

   return clipped_perturbation_pyramid

@@ -62,8 +62,8 @@ def project_perturbation_pyramid_l2(aug_params, epsilon, input_image,
   clipped_perturbation_pyramid = jax.tree_util.tree_map(
       functools.partial(
           jnp.clip,
-          a_min=-epsilon / pyramid_levels,
-          a_max=epsilon / pyramid_levels), aug_params)
+          min=-epsilon / pyramid_levels,
+          max=epsilon / pyramid_levels), aug_params)

   return clipped_perturbation_pyramid

@@ -70,7 +70,7 @@ def get_metrics(misc_artifacts, batch, metrics, images_to_log, metrics_fn):
   metrics['|adv_logits-logits|_l2'] = batchwise_scalar_to_metric(logit_norms)
   metrics[
       '|adv_logits-logits|_l2 / |adv-orig|_l2'] = batchwise_scalar_to_metric(
-          logit_norms / (jnp.clip(l2_norms, a_min=1e-5)))
+          logit_norms / (jnp.clip(l2_norms, min=1e-5)))

   # network performance on adversarial
   logits_are_correct = jnp.argmax(
2 changes: 1 addition & 1 deletion scenic/projects/adversarialtraining/attacks/train_utils.py
@@ -41,7 +41,7 @@ def unnormalize_imgnet(input_tensors):
 def normalize_minmax(tensor):
   mn = jnp.min(tensor, axis=(-1, -2, -3), keepdims=True)
   mx = jnp.max(tensor, axis=(-1, -2, -3), keepdims=True)
-  return (tensor - mn) / jnp.clip(mx - mn, a_min=1e-5)
+  return (tensor - mn) / jnp.clip(mx - mn, min=1e-5)


 def psum_metric_normalizer(metrics: Tuple[jnp.ndarray, jnp.ndarray]
4 changes: 2 additions & 2 deletions scenic/projects/av_mae/train_utils.py
@@ -78,7 +78,7 @@ def generate_image_grid(target: jnp.ndarray,
   if modality == 'spectrogram':
     prediction_clipped = prediction
   else:
-    prediction_clipped = jnp.clip(prediction, a_min=-1, a_max=1)
+    prediction_clipped = jnp.clip(prediction, min=-1, max=1)

   # Normalise to uint8 in range [0, 255] for summary-writing.
   def normalise(tensor: jnp.ndarray, offset: float = 127.5) -> jnp.ndarray:
@@ -221,7 +221,7 @@ def generate_image_grid_from_video(target: jnp.ndarray,
   # mean and std-dev instead of [-1, 1].
   # if jnp.max(target) > 1 or jnp.min(target) < -1:
   #   raise ValueError('Invalid ranges in target.')
-  prediction_clipped = jnp.clip(prediction, a_min=-1, a_max=1)
+  prediction_clipped = jnp.clip(prediction, min=-1, max=1)

   # Normalise to uint8 in range [0, 255] for summary-writing.
   def normalise(tensor: jnp.ndarray, offset: float = 127.5) -> jnp.ndarray:
2 changes: 1 addition & 1 deletion scenic/projects/baselines/centernet/modeling/roi_heads.py
@@ -238,7 +238,7 @@ def roi_align(self,
     # make these numbers configurable if needed.
     level_assignment = jnp.floor(4 + jnp.log2(sqrt_area / 224 + 1e-8))
     level_assignment = jnp.clip(
-        level_assignment, a_min=min_level, a_max=max_level)
+        level_assignment, min=min_level, max=max_level)
     scale = jnp.float_power(2.0, level_assignment)[:, :, None]

     return roi_align.multilevel_roi_align(
2 changes: 1 addition & 1 deletion scenic/projects/baselines/deformable_detr/model.py
@@ -74,7 +74,7 @@ def compute_cost(
     All pairs cost matrix [bs, nout, ntargets].
   """
   # Calculate cost using pred_prob [bs, npreds].
-  logfn = lambda x: jnp.log(jnp.clip(x, a_min=1e-8))
+  logfn = lambda x: jnp.log(jnp.clip(x, min=1e-8))
   neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-logfn(1 - out_prob))
   pos_cost_class = alpha * ((1 - out_prob)**gamma) * (-logfn(out_prob))
   cost_class = pos_cost_class - neg_cost_class
2 changes: 1 addition & 1 deletion scenic/projects/unloc/heads.py
@@ -136,7 +136,7 @@ def _normalize_distance(self, distances: jnp.ndarray) -> jnp.ndarray:
     if self.distance_normalizer == 'relu_clip':
       distances = nn.relu(distances)
       # We normalize the distances to be 0 and 1.
-      return jnp.clip(distances, a_max=1.0)
+      return jnp.clip(distances, max=1.0)
     elif self.distance_normalizer == 'sigmoid':
       return nn.sigmoid(distances)
     elif self.distance_normalizer == 'relu':
