Skip to content

Commit

Permalink
Fixed type errors to unblock an internal type annotations refactoring…
Browse files Browse the repository at this point in the history
… in JAX

Some JAX internal used Any instead of Array or in their type annotations.
jax-ml/jax#17760 changed these to alias jax.Array and
uncovered type errors fixed here.

PiperOrigin-RevId: 570090425
  • Loading branch information
superbobry authored and Scenic Authors committed Oct 2, 2023
1 parent 67bde34 commit a7fea37
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion scenic/projects/baselines/detr/detr_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def get_label(label):
if batch_weights is not None:
batch_num_inputs = batch_weights.sum()
else:
batch_num_inputs = tgt_labels_onehot.shape[0]
batch_num_inputs = jnp.asarray(tgt_labels_onehot.shape[0])
max_logits = jnp.max(orig_src_logits, axis=-2)
tgt_labels_multihot = jnp.max(orig_tgt_labels_onehot, axis=-2)
prec_at_one = model_utils.weighted_top_one_correctly_classified(
Expand Down

0 comments on commit a7fea37

Please sign in to comment.