diff --git a/scenic/model_lib/base_models/encoder_decoder_model.py b/scenic/model_lib/base_models/encoder_decoder_model.py index a3af269c..252e9650 100644 --- a/scenic/model_lib/base_models/encoder_decoder_model.py +++ b/scenic/model_lib/base_models/encoder_decoder_model.py @@ -119,7 +119,7 @@ def encoder_decoder_metrics_function( # Calculate (clipped) perplexity after averaging log-perplexities: evaluated_metrics['perplexity'] = (jnp.clip( jnp.exp(evaluated_metrics['loss'][0] / evaluated_metrics['loss'][1]), - a_max=_MAX_PERPLEXITY), 1) + max=_MAX_PERPLEXITY), 1) return evaluated_metrics # pytype: disable=bad-return-type # jax-types diff --git a/scenic/projects/adversarialtraining/attacks/attack_methods.py b/scenic/projects/adversarialtraining/attacks/attack_methods.py index 23d1343c..d47c44b2 100644 --- a/scenic/projects/adversarialtraining/attacks/attack_methods.py +++ b/scenic/projects/adversarialtraining/attacks/attack_methods.py @@ -42,7 +42,7 @@ def project_perturbation_pyramid_inf(aug_params, epsilon, input_image, # The idea is to ensure that the sum of the perturbations over the pyramid # levels can't be more than epsilon. clipped_perturbation_pyramid = jax.tree_util.tree_map( - functools.partial(jnp.clip, a_min=-epsilon, a_max=epsilon), aug_params) + functools.partial(jnp.clip, min=-epsilon, max=epsilon), aug_params) return clipped_perturbation_pyramid @@ -62,8 +62,8 @@ def project_perturbation_pyramid_l2(aug_params, epsilon, input_image, clipped_perturbation_pyramid = jax.tree_util.tree_map( functools.partial( jnp.clip, - a_min=-epsilon / pyramid_levels, - a_max=epsilon / pyramid_levels), aug_params) + min=-epsilon / pyramid_levels, + max=epsilon / pyramid_levels), aug_params) return clipped_perturbation_pyramid diff --git a/scenic/projects/adversarialtraining/attacks/attack_metrics.py b/scenic/projects/adversarialtraining/attacks/attack_metrics.py index 442c1cd1..1524b6db 100644 --- a/scenic/projects/adversarialtraining/attacks/attack_metrics.py +++ b/scenic/projects/adversarialtraining/attacks/attack_metrics.py @@ -70,7 +70,7 @@ def get_metrics(misc_artifacts, batch, metrics, images_to_log, metrics_fn): metrics['|adv_logits-logits|_l2'] = batchwise_scalar_to_metric(logit_norms) metrics[ '|adv_logits-logits|_l2 / |adv-orig|_l2'] = batchwise_scalar_to_metric( - logit_norms / (jnp.clip(l2_norms, a_min=1e-5))) + logit_norms / (jnp.clip(l2_norms, min=1e-5))) # network performance on adversarial logits_are_correct = jnp.argmax( diff --git a/scenic/projects/adversarialtraining/attacks/train_utils.py b/scenic/projects/adversarialtraining/attacks/train_utils.py index 0164ac49..61763478 100644 --- a/scenic/projects/adversarialtraining/attacks/train_utils.py +++ b/scenic/projects/adversarialtraining/attacks/train_utils.py @@ -41,7 +41,7 @@ def unnormalize_imgnet(input_tensors): def normalize_minmax(tensor): mn = jnp.min(tensor, axis=(-1, -2, -3), keepdims=True) mx = jnp.max(tensor, axis=(-1, -2, -3), keepdims=True) - return (tensor - mn) / jnp.clip(mx - mn, a_min=1e-5) + return (tensor - mn) / jnp.clip(mx - mn, min=1e-5) def psum_metric_normalizer(metrics: Tuple[jnp.ndarray, jnp.ndarray] diff --git a/scenic/projects/av_mae/train_utils.py b/scenic/projects/av_mae/train_utils.py index 404dd9ad..dcb6fd2e 100644 --- a/scenic/projects/av_mae/train_utils.py +++ b/scenic/projects/av_mae/train_utils.py @@ -78,7 +78,7 @@ def generate_image_grid(target: jnp.ndarray, if modality == 'spectrogram': prediction_clipped = prediction else: - prediction_clipped = jnp.clip(prediction, a_min=-1, a_max=1) + prediction_clipped = jnp.clip(prediction, min=-1, max=1) # Normalise to uint8 in range [0, 255] for summary-writing. def normalise(tensor: jnp.ndarray, offset: float = 127.5) -> jnp.ndarray: @@ -221,7 +221,7 @@ def generate_image_grid_from_video(target: jnp.ndarray, # mean and std-dev instead of [-1, 1]. # if jnp.max(target) > 1 or jnp.min(target) < -1: # raise ValueError('Invalid ranges in target.') - prediction_clipped = jnp.clip(prediction, a_min=-1, a_max=1) + prediction_clipped = jnp.clip(prediction, min=-1, max=1) # Normalise to uint8 in range [0, 255] for summary-writing. def normalise(tensor: jnp.ndarray, offset: float = 127.5) -> jnp.ndarray: diff --git a/scenic/projects/baselines/centernet/modeling/roi_heads.py b/scenic/projects/baselines/centernet/modeling/roi_heads.py index 2b39650b..57c9299b 100644 --- a/scenic/projects/baselines/centernet/modeling/roi_heads.py +++ b/scenic/projects/baselines/centernet/modeling/roi_heads.py @@ -238,7 +238,7 @@ def roi_align(self, # make these numbers configurable if needed. level_assignment = jnp.floor(4 + jnp.log2(sqrt_area / 224 + 1e-8)) level_assignment = jnp.clip( - level_assignment, a_min=min_level, a_max=max_level) + level_assignment, min=min_level, max=max_level) scale = jnp.float_power(2.0, level_assignment)[:, :, None] return roi_align.multilevel_roi_align( diff --git a/scenic/projects/baselines/deformable_detr/model.py b/scenic/projects/baselines/deformable_detr/model.py index 40914e74..d4ba71d4 100644 --- a/scenic/projects/baselines/deformable_detr/model.py +++ b/scenic/projects/baselines/deformable_detr/model.py @@ -74,7 +74,7 @@ def compute_cost( All pairs cost matrix [bs, nout, ntargets]. """ # Calculate cost using pred_prob [bs, npreds]. - logfn = lambda x: jnp.log(jnp.clip(x, a_min=1e-8)) + logfn = lambda x: jnp.log(jnp.clip(x, min=1e-8)) neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-logfn(1 - out_prob)) pos_cost_class = alpha * ((1 - out_prob)**gamma) * (-logfn(out_prob)) cost_class = pos_cost_class - neg_cost_class diff --git a/scenic/projects/unloc/heads.py b/scenic/projects/unloc/heads.py index 183f3e7a..e3f97dc7 100644 --- a/scenic/projects/unloc/heads.py +++ b/scenic/projects/unloc/heads.py @@ -136,7 +136,7 @@ def _normalize_distance(self, distances: jnp.ndarray) -> jnp.ndarray: if self.distance_normalizer == 'relu_clip': distances = nn.relu(distances) # We normalize the distances to be 0 and 1. - return jnp.clip(distances, a_max=1.0) + return jnp.clip(distances, max=1.0) elif self.distance_normalizer == 'sigmoid': return nn.sigmoid(distances) elif self.distance_normalizer == 'relu':