From 7ace659452dee4b68547575352c022a2eef587a5 Mon Sep 17 00:00:00 2001
From: Andreas Steiner
Date: Tue, 12 Dec 2023 23:35:24 +0100
Subject: [PATCH] Adds CapPa model (https://arxiv.org/abs/2306.07915). (#82)

* Updates to `jax.local_devices(backend='tpu')`.

* Adds preprocess_ops.coco_captions

* Adds CapPa model code.

* Optional `jax.distributed.initialize()` (with warning).

The API call `jax.distributed.initialize()` will fail if the program is
executed outside a multihost environment (e.g. local testing).

* Applies Lucas's comment.
---
 big_vision/configs/proj/cappa/README.md       |  37 ++
 big_vision/configs/proj/cappa/pretrain.py     | 140 +++++
 .../evaluators/proj/cappa/perplexity.py       |  50 ++
 .../proj/cappa/scoring_classifier.py          |  63 +++
 big_vision/models/proj/cappa/cappa.py         | 428 +++++++++++++++
 big_vision/pp/ops_text.py                     |  11 +
 big_vision/trainers/proj/cappa/generative.py  | 498 ++++++++++++++++++
 big_vision/trainers/proj/cappa/predict_fns.py | 118 +++++
 big_vision/utils.py                           |  13 +-
 9 files changed, 1354 insertions(+), 4 deletions(-)
 create mode 100644 big_vision/configs/proj/cappa/README.md
 create mode 100644 big_vision/configs/proj/cappa/pretrain.py
 create mode 100644 big_vision/evaluators/proj/cappa/perplexity.py
 create mode 100644 big_vision/evaluators/proj/cappa/scoring_classifier.py
 create mode 100644 big_vision/models/proj/cappa/cappa.py
 create mode 100644 big_vision/trainers/proj/cappa/generative.py
 create mode 100644 big_vision/trainers/proj/cappa/predict_fns.py

diff --git a/big_vision/configs/proj/cappa/README.md b/big_vision/configs/proj/cappa/README.md
new file mode 100644
index 0000000..ee13d8e
--- /dev/null
+++ b/big_vision/configs/proj/cappa/README.md
@@ -0,0 +1,37 @@
+# Image Captioners Are Scalable Vision Learners Too
+
+*by Michael Tschannen, Manoj Kumar, Andreas Steiner, Xiaohua Zhai, Neil Houlsby, Lucas Beyer* [[arxiv]](https://arxiv.org/abs/2306.07915)
+
+![CapPa Architecture](./cappa_architecture.png)
+
+This directory contains a config for training a CapPa model from scratch.
+Note that most models in the paper were trained on a proprietary dataset
+(WebLI), but similar results can be obtained by training on [LAION](https://laion.ai/).
+
+By default, this config trains on COCO captions as this dataset is readily
+available in [TFDS](https://www.tensorflow.org/datasets) without manual steps.
+This is not meant to produce a meaningful model, but
+provides a way for the user to run the config out of the box. Please update the
+config with a TFDS-wrapped variant of your favorite image/text dataset to
+train capable models.
+
+After setting up `big_vision` as described in the [main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup), training can be launched as follows:
+
+```
+python -m big_vision.trainers.proj.cappa.generative \
+    --config big_vision/configs/proj/cappa/pretrain.py \
+    --workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'`
+```
+
+To run the Cap baseline (autoregressive captioning without parallel prediction),
+set `config.model.masked_pred_prob = 0.0`.
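+
+For a quick local smoke test, the config exposes a `runlocal` argument via
+`bvcc.parse_arg` (see `pretrain.py`), which shrinks the batch size and shuffle
+buffer. Assuming the usual `big_vision` convention of passing config arguments
+after a colon in the config path, such a run would look roughly like this
+(the workdir below is just an illustrative local path):
+
+```
+python -m big_vision.trainers.proj.cappa.generative \
+    --config big_vision/configs/proj/cappa/pretrain.py:runlocal=True \
+    --workdir /tmp/cappa_smoke_test
+```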
+ +### Citation +``` +@inproceedings{tschannen2023image, + title={Image Captioners Are Scalable Vision Learners Too}, + author={Tschannen, Michael and Kumar, Manoj and Steiner, Andreas and Zhai, Xiaohua and Houlsby, Neil and Beyer, Lucas}, + booktitle={Neural Information Processing Systems (NeurIPS)}, + year={2023} +} +``` \ No newline at end of file diff --git a/big_vision/configs/proj/cappa/pretrain.py b/big_vision/configs/proj/cappa/pretrain.py new file mode 100644 index 0000000..8b0df3c --- /dev/null +++ b/big_vision/configs/proj/cappa/pretrain.py @@ -0,0 +1,140 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Trains a CapPa model (https://arxiv.org/abs/2306.07915) on coco_captions. + +This config is for reference, we never ran a full training on a large +image/text data set on public infrastructure. + +big_vision.trainers.proj.cappa.generative \ + --config big_vision/configs/proj/cappa/pretrain.py \ + --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` +""" + + +from big_vision.configs import common_fewshot +import big_vision.configs.common as bvcc +import ml_collections + + +def get_config(arg=None): + """Returns the base config.""" + config = bvcc.parse_arg(arg, + runlocal=False, + total_steps=366_500, + batch_size=8*1024, + warmup_steps=10_000, + ) + + config.evals = {} + config.input = {} + config.input.batch_size = config.batch_size if not config.runlocal else 8 + shuffle_buffer_size = 50_000 if not config.runlocal else 50 + + res = 224 + patch_size = 16 + max_text_tokens = 64 + + pp_image = (f'resize({res})|value_range(-1,1)') + + def tokenizer(inkey, outkey): + return (f'tokenize(max_len={max_text_tokens}, model="c4_en", ' + f'eos="sticky", inkey="{inkey}", outkey="{outkey}")') + + pp_coco = (f'decode|{pp_image}|' + 'coco_captions("captions")|choice(inkey="captions", outkey="text")|' + f'{tokenizer("text", "labels")}|keep("image", "labels")') + config.input.pp = pp_coco + + # NOTE: "coco_captions" is way too small a dataset to train on. It's simply + # used here to serve as a smoke test that the implementation works correctly. + config.input.data = dict(name='coco_captions', split='train') # num_examples=82_783 + config.input.shuffle_buffer_size = shuffle_buffer_size + + config.evals.val_coco = { + 'type': 'proj.cappa.perplexity', + 'pred': 'perplexity', + 'log_steps': 1000, + 'data': dict(name='coco_captions', split='val'), # num_examples=5_000 + 'pp_fn': pp_coco, + } + + # Few-shot metrics + config.evals.fewshot = common_fewshot.get_fewshot_lsr( + target_resolution=res, resize_resolution=int(256 / 224 * res)) + config.evals.fewshot.type = 'fewshot_lsr' + config.evals.fewshot.log_steps = 5_000 if not config.runlocal else 5 + config.evals.fewshot.representation_layer = 'pre_logits' + config.evals.fewshot.pred = 'enc_rep' + config.evals.fewshot.pp_eval = config.evals.fewshot.pp_train + + # NOTE: Scoring of the entire imagenet validation set is rather slow: + # ~100 secs / 1k classes / host. 
+ config.evals['imagenet/scoring'] = dict( + type='proj.cappa.scoring_classifier', + pred='score', + log_percent=0.1, + data=dict(name='imagenet2012', split='validation'), + pp_fn=f'decode|{pp_image}|keep("image", "label")', + pp_txt=tokenizer('label', 'labels'), + ) + + for e in config.evals.values(): + e.skip_first = True + + config.log_training_steps = 50 + config.ckpt_steps = 1000 + config.keep_ckpt_steps = None # 10_000 + + # Model section + config.model_name = 'proj.cappa.cappa' + config.model = ml_collections.ConfigDict() + config.model.num_layers = 12 + config.model.num_heads = 12 + config.model.mlp_dim = 3072 + config.model.emb_dim = 768 + config.model.vocab_size = 32_000 + config.model.patches = (patch_size, patch_size) + config.model.seq_len = max_text_tokens + config.model.posemb_type = 'learn' + + # Decoder + config.model.decoder_num_layers = 6 + # 0 values here mean to use the same value as for the encoder + config.model.decoder_num_heads = 0 + config.model.decoder_mlp_dim = 0 + config.model.decoder_emb_dim = 0 + config.model.dec_dropout_rate = 0.0 + config.model.masked_pred_prob = 0.75 + config.model.masking_ratio = 1.0 + config.model.decoder_bias = False + + config.optax_name = 'big_vision.scale_by_adafactor' + config.optax = dict(beta2_cap=0.999) + config.grad_clip_norm = 1.0 + config.label_smoothing = 0.0 + + schedule = dict(decay_type='cosine', + warmup_steps=config.warmup_steps + if not config.runlocal else 5) + + # Standard schedule + config.lr = 0.001 + config.wd = 0.0001 + config.schedule = schedule + + config.seed = 0 + + return config \ No newline at end of file diff --git a/big_vision/evaluators/proj/cappa/perplexity.py b/big_vision/evaluators/proj/cappa/perplexity.py new file mode 100644 index 0000000..2ce6939 --- /dev/null +++ b/big_vision/evaluators/proj/cappa/perplexity.py @@ -0,0 +1,50 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluator for perplexity of a model.""" +from big_vision.evaluators import mean +import big_vision.utils as u +import jax.numpy as jnp + + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = 'jit' + + +def perplexity(predict_fn, normalize_by_seqlen): + """Returns a function that computes perplexity.""" + + def _perplexity_fn(train_state, batch, pad_token=0, **kw): + logits, _ = predict_fn(train_state, batch, **kw) + + # Ignore perplexity on the padding label. 
+ weights = jnp.where(batch['labels'] != pad_token, 1, 0).astype(jnp.float32) + if batch.get('label_masks') is not None: + weights = weights * batch['label_masks'] + + losses = u.weighted_softmax_xent( + logits=logits, labels=batch['labels'], + weights=weights, label_smoothing=0.0, + reduction=False, normalize=normalize_by_seqlen) + + return {'perplexity': losses} + return _perplexity_fn + + +class Evaluator(mean.Evaluator): + """Perplexity evaluator.""" + + def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw): + super().__init__(perplexity(predict_fn, normalize_by_seqlen), *a, **kw) diff --git a/big_vision/evaluators/proj/cappa/scoring_classifier.py b/big_vision/evaluators/proj/cappa/scoring_classifier.py new file mode 100644 index 0000000..60906ba --- /dev/null +++ b/big_vision/evaluators/proj/cappa/scoring_classifier.py @@ -0,0 +1,63 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Scoring classifier. + +This one is based on a generative perspective for image classification. +Here we input the image as well as all the tokenized labels to compute their +perplexity and select the one with minimum loss as the prediction. +""" +import functools +from big_vision.datasets.imagenet import class_names as imagenet_class_names +from big_vision.evaluators import mean +from big_vision.pp import builder as pp_builder +import jax.numpy as jnp +import numpy as np + +# Temporary global flag to facilitate backwards compatability. Will be removed +# by the end of year 2023. +API = "jit" + + +CLASS_NAMES = { + "imagenet2012": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, +} + + +# As a separate function to cache result across instances. +@functools.lru_cache(maxsize=None) +def get_classes(dataset_name, pp_txt): + """Load the class label strings and tokenize them using pp_txt.""" + pp_fn = pp_builder.get_preprocess_fn(pp_txt, log_data=False) + return np.array([pp_fn({"label": name})["labels"] + for name in CLASS_NAMES[dataset_name]]) + + +def scoring(predict_fn, tokenized_labels): + + def _scoring_fn(train_state, batch, *a, **kw): + batch = {"_label_tokens": tokenized_labels, **batch} + scores = predict_fn(train_state, batch, *a, **kw) + predictions = jnp.argmax(scores, axis=-1) + return {"prec@1": predictions == batch["label"]} + + return _scoring_fn + + +class Evaluator(mean.Evaluator): + """Evaluator for classification accuracy based on scoring all classes.""" + + def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw): + cls_tokens = get_classes(data["name"], pp_txt) + super().__init__(scoring(predict_fn, cls_tokens), data, pp_fn, *a, **kw) diff --git a/big_vision/models/proj/cappa/cappa.py b/big_vision/models/proj/cappa/cappa.py new file mode 100644 index 0000000..8c20b1b --- /dev/null +++ b/big_vision/models/proj/cappa/cappa.py @@ -0,0 +1,428 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model definitions for CapPa (https://arxiv.org/abs/2306.07915). + +Used abbreviations for dimension annotations: + B: batch size. + H: image height. + W: image width. + P: number of patches (PH/PW: number of patches in height/width dimensions). + E: embedding size. + L: sequence length of text tokens. + V: vocab size. +""" + +from collections.abc import Sequence + +from big_vision import utils +from big_vision.models import common +from big_vision.models import vit +import flax +import flax.linen as nn +from flax.linen import partitioning +import jax +import jax.numpy as jnp + + +def shift_right(x, axis=1, constant_values=0): + """Shift to the right on given axis with padding value 0.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = (1, 0) + padded = jnp.pad(x, pad_widths, constant_values=constant_values) + # Cuts off the rightmost slice of size along the `axis` dimension. + # Note that `list[:-1]`` is the same as `list[slice(-1)]`. + return padded[tuple(slice(-1 if i == axis else None) for i in range(x.ndim))] + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block with option to deactivate bias.""" + mlp_dim: int | None = None # Defaults to 4x input dim + dropout: float = 0.0 + use_bias: bool = True + + @nn.compact + def __call__(self, x, deterministic=True): + """Applies Transformer MlpBlock module.""" + inits = dict( + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.normal(stddev=1e-6), + ) + + n, l, d = x.shape # pylint: disable=unused-variable + x = nn.Dense(self.mlp_dim or 4 * d, use_bias=self.use_bias, **inits)(x) + x = nn.gelu(x) + x = nn.Dropout(rate=self.dropout)(x, deterministic) + x = nn.Dense(d, use_bias=self.use_bias, **inits)(x) + return x + + +class EncoderDecoderBlock(nn.Module): + """Transformer encoder-decoder layer.""" + mlp_dim: int + num_heads: int + dropout_rate: float = 0. + decode: bool = False + use_bias: bool = True + + @nn.compact + def __call__(self, targets, encoded, decoder_mask=None, deterministic=True): + """Applies EncoderDecoder1DBlock module. + + Args: + targets: target text embeddings [B, L, E]. + encoded: encoded image patches from encoder [B, P, E]. + decoder_mask: decoder self-attention mask. + deterministic: bool, deterministic or not (to apply dropout). + + Returns: + output after transformer encoder-decoder block [B, L, E]. + """ + def wlc(f): + dim_names = ("act_batch", "act_len", "act_emb") + return nn.with_logical_constraint(f, dim_names) + + # Decoder block. + x = wlc(nn.LayerNorm(name="LayerNorm1", use_bias=self.use_bias)(targets)) + x = wlc(nn.SelfAttention( + num_heads=self.num_heads, use_bias=False, broadcast_dropout=False, + dropout_rate=self.dropout_rate, decode=self.decode, name="SelfAttn")( + x, decoder_mask, deterministic=deterministic)) + x = wlc(nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)) + x = wlc(x + targets) + + if encoded is not None: + # Encoder-Decoder block. 
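+      # Cross-attention: queries come from the text stream (x), while keys and
+      # values come from the encoded image patches [B, P, E]. When `encoded`
+      # is None, cross-attention is skipped (see the else branch below).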
+ y = wlc(nn.LayerNorm(name="LayerNorm2", use_bias=self.use_bias)(x)) + y = wlc(nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, use_bias=False, broadcast_dropout=False, + dropout_rate=self.dropout_rate, name="CrossAttn")( + y, encoded, deterministic=deterministic)) + y = wlc( + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)) + y = wlc(y + x) + else: + y = x + + # MLP block. + z = wlc(nn.LayerNorm(name="LayerNorm3", use_bias=self.use_bias)(y)) + z = wlc(MlpBlock( + mlp_dim=self.mlp_dim, dropout=self.dropout_rate, use_bias=self.use_bias, + name="MLP")(z, deterministic=deterministic)) + + return wlc(y + z), None + + +class Decoder(nn.Module): + """Transformer decoder with parallel prediction.""" + emb_dim: int + mlp_dim: int + num_heads: int + num_layers: int + dropout_rate: float = 0. + output_vocab_size: int = 32_000 + + # Masked prediction training mode + masked_pred_prob: float = 0. + masking_ratio: float = 0. + + # Whether to use bias in MLP blocks and LN + use_bias: bool = True + + scan: bool = False + remat_policy: str = "nothing_saveable" + + @nn.compact + def __call__(self, + encoded, + targets, + pos_emb, + decoder_mask=None, + decode=False, + deterministic=True, + max_decode_length=None): + """Applies Transformer model on the inputs. + + Args: + encoded: encoded image patches from encoder [B, P, E]. + targets: target text tokens [B, L]. + pos_emb: positional embeddings. + decoder_mask: decoder self-attention mask. + decode: bool, whether to perform fast autoregressive decoding with cache. + deterministic: bool, deterministic or not (to apply dropout). + max_decode_length: optional max length for positional embeddings. + + Returns: + output of a transformer decoder [B, L, V]. + """ + y = targets.astype("int32") + if not decode: + if self.masked_pred_prob > 0.0 and not deterministic: + # Binary random variable indicating whether to do masked prediction + + def _add_random_masks(a): + # Generate random mask + n_masked = int(self.masking_ratio * a.shape[1]) + mask_locations = jnp.zeros(a.shape[:2], dtype=jnp.int32) + mask_locations = mask_locations.at[:, :n_masked].set(1) + mask_locations = jax.random.permutation( + self.make_rng("dropout"), mask_locations, axis=1, independent=True + ) + # Replace mask locations with mask token index (=vocab_size) + a_masked = jnp.where(mask_locations, self.output_vocab_size, a) + return a_masked + + def where(mask, x, y): + mask = mask.reshape((-1,) + (1,) * (x.ndim - 1)) + return jnp.where(mask, x, y) + + do_masked_pred = ( + jax.random.uniform(self.make_rng("dropout"), (len(y),)) + < self.masked_pred_prob + ) + y = where(do_masked_pred, _add_random_masks(y), shift_right(y)) + decoder_mask = where( + do_masked_pred, jnp.ones_like(decoder_mask), decoder_mask + ) + + else: + y = shift_right(y) + + embed = nn.Embed( + self.output_vocab_size + (1 if self.masked_pred_prob > 0.0 else 0), + self.emb_dim, + name="EmbedTargets", + embedding_init=nn.initializers.normal(stddev=1.0), + ) + y = embed(y) + + y = common.AddPositionEmbs( + decode=decode, name="PosEmbedTargets")(y, pos_emb) + # NOTE: One could apply dropout on the decoder's inputs here. Whether to do + # it or not, and if so, what is the best/common way, is to be determined. + # y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic) + + if self.scan: + # Mostly followed + # https://github.com/google/maxtext/blob/4d99e30b3e0e0cb1d1aa11c7db7fffe18e301498/MaxText/layers.py#L1126 + # for the scanned version. + # 1. 
remat + enc_dec_block_remat = nn.remat( + EncoderDecoderBlock, + prevent_cse=False, + static_argnums=(-1,), + policy=getattr(jax.checkpoint_policies, self.remat_policy, None)) + # 2. scan + initializing = self.is_mutable_collection("params") + param_scan_axis = 1 + params_spec = (param_scan_axis if initializing + else partitioning.ScanIn(param_scan_axis)) + dec_scanned = nn.scan(enc_dec_block_remat, + variable_axes={ + "params": params_spec, + "cache": 0, + }, + split_rngs={"params": True, "dropout": True}, + in_axes=nn.broadcast, + length=self.num_layers) + # 3. fprop + y, _ = dec_scanned(num_heads=self.num_heads, mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, decode=decode, + use_bias=self.use_bias, name="EncDecBlock")( + y, encoded, decoder_mask, deterministic) + else: + for lyr in range(self.num_layers): + y, _ = EncoderDecoderBlock( + num_heads=self.num_heads, mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, decode=decode, + use_bias=self.use_bias, name=f"EncDecBlock{lyr}")( + y, encoded, decoder_mask=decoder_mask, + deterministic=deterministic) + + y = nn.LayerNorm(name="LayerNorm")(y) + + logits = nn.Dense( + self.output_vocab_size, + kernel_init=nn.initializers.zeros, + name="LogitsDense", + )(y) + return logits + + +class Model(nn.Module): + """Transformer Model for sequence to sequence translation.""" + # Encoder/decoder: + num_heads: int = 8 + num_layers: int = 6 + mlp_dim: int = 2048 + emb_dim: int = 512 + enc_dropout_rate: float = 0. + vocab_size: int = 32_000 + seq_len: int = 256 + + # Encoder: + patches: Sequence[int] = (16, 16) + input_seq_len: int = 768 + posemb_type: str = "learn" + patch_dropout: float = 0. + + # Decoder: + decoder_num_heads: int = 0 + decoder_num_layers: int = 0 + decoder_mlp_dim: int = 0 + decoder_emb_dim: int = 0 + dec_dropout_rate: float = 0. + # Probability of masked prediction rather than autoregressive prediciton. + masked_pred_prob: float = 0. + # Masking ratio for masked prediction. + masking_ratio: float = 0. + # Whether to use bias in decoder MLP blocks and LN. + decoder_bias: bool = True + + scan: bool = False + remat_policy: str = "nothing_saveable" + + def setup(self): + + self.encoder = vit.Model( + patch_size=self.patches, + width=self.emb_dim, + depth=self.num_layers, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.enc_dropout_rate, + posemb=self.posemb_type, + scan=self.scan, + remat_policy=self.remat_policy, + ) + + self.pos_emb_for_decoder = vit.get_posemb( + self, + self.posemb_type, + (1, self.seq_len), + self.decoder_emb_dim or self.emb_dim, + "pos_embedding_decoder", + ) + self.decoder = Decoder( + num_layers=self.decoder_num_layers or self.num_layers, + mlp_dim=self.decoder_mlp_dim or self.mlp_dim, + num_heads=self.decoder_num_heads or self.num_heads, + dropout_rate=self.dec_dropout_rate, + emb_dim=self.decoder_emb_dim or self.emb_dim, + output_vocab_size=self.vocab_size, + masked_pred_prob=self.masked_pred_prob, + masking_ratio=self.masking_ratio, + use_bias=self.decoder_bias, + scan=self.scan, + remat_policy=self.remat_policy, + ) + + def encode(self, image, train=False, return_enc_features=False): + """Encodes input image or embeddings.""" + + _, out = self.encoder(image, train=train) + encoded = out["encoded"] + + # Return intermediate features if required + if return_enc_features: + return encoded, out + + return encoded + + def decode(self, encoded, targets, decode=False, train=False, + max_decode_length=None): + """Applies Transformer decoder-branch on encoded-input and target. 
+ + Args: + encoded: encoded image patches from encoder [B, P, E]. + targets: target text tokens [B, L]. + decode: whether to prepare and use an autoregressive cache. + train: whether it is training. + max_decode_length: optional max length for positional embeddings. + + Returns: + logits array from transformer decoder [B, L, V]. + """ + decoder_mask = None if decode else nn.make_causal_mask(targets) + logits = self.decoder( + encoded, + targets, + pos_emb=self.pos_emb_for_decoder, + decoder_mask=decoder_mask, + decode=decode, + deterministic=not train, + max_decode_length=max_decode_length) + return logits + + def __call__(self, image, text, *, decode=False, + train=False, return_enc_features=False): + """Applies Transformer model on the inputs. + + Args: + image: batch of images [B, H, W, 3]. + text: batch of tokenized texts [B, L]. + decode: whether to prepare and use an autoregressive cache. + train: whether it is training. + return_enc_features: whether to return the encoder features. + + Returns: + logits array from full transformer [B, L, V]. + """ + if return_enc_features: + encoded, out = self.encode(image, train=train, return_enc_features=True) + return encoded, out + + encoded = self.encode(image, train=train) + + decoded = self.decode(encoded, text, decode=decode, train=train) + return decoded + + +def load(init_params, init_files, model_params=None, + dont_load=("head/kernel", "head/bias", "cls")): + """Loads params from init checkpoint and merges into init_params.""" + + if isinstance(init_files, str): + # A shortcut for a single file checkpoint of a vtt model. + ckpt_params = utils.load_params(init_files) + ckpt_params = flax.training.checkpoints.convert_pre_linen(ckpt_params) + ckpt_params = common.merge_params(ckpt_params, init_params, dont_load) + + # Detect attempts to load non-scan checkpoint into scan model if possible. 
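+    # (A scanned encoder stacks all layers into a single "encoderblock" entry,
+    # whereas a python-loop encoder stores per-layer entries instead; the
+    # membership checks below rely on this naming difference.)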
+ if (model_params.get("scan") and + "encoderblock" not in ckpt_params["encoder"]["Transformer"]): + raise NotImplementedError("Loading a non-scan checkpoint into a " + "scan model is not supported yet!") + if (not model_params.get("scan") + and "encoderblock" in ckpt_params["encoder"]["Transformer"]): + assert "decoder.*" in dont_load or "decoder/.*" in dont_load, ( + "Converting scan decoder to a non-scan one is not supported yet!") + ckpt_params["encoder"] = utils.jit_cpu()( + vit.scan_to_pyloop)(ckpt_params["encoder"]) + + else: + assert set(init_files) == {"encoder"}, "Only encoder init supported" + enc_init = init_files["encoder"] + ckpt_params = flax.core.freeze(init_params).unfreeze() + vit_params = ckpt_params["encoder"] + encoder_params = vit.load( + vit_params, enc_init, model_cfg={}, + dont_load=dont_load) + ckpt_params["encoder"] = encoder_params + + ckpt_params["encoder"]["pos_embedding"] = vit.resample_posemb( + old=ckpt_params["encoder"]["pos_embedding"], + new=init_params["encoder"]["pos_embedding"]) + + return ckpt_params diff --git a/big_vision/pp/ops_text.py b/big_vision/pp/ops_text.py index 8ba5477..607415c 100644 --- a/big_vision/pp/ops_text.py +++ b/big_vision/pp/ops_text.py @@ -176,6 +176,17 @@ def _pp_tokenize(txt): return _pp_tokenize +@Registry.register("preprocess_ops.coco_captions") +def get_coco_captions(outkey="captions"): + """Extracts coco's captions from nested dict.""" + + def _pp_coco_captions(data): + data[outkey] = data["captions"]["text"] + return data + + return _pp_coco_captions + + @Registry.register("preprocess_ops.clip_i1k_label_names") @utils.InKeyOutKey(indefault="label", outdefault="labels") def get_pp_clip_i1k_label_names(): diff --git a/big_vision/trainers/proj/cappa/generative.py b/big_vision/trainers/proj/cappa/generative.py new file mode 100644 index 0000000..a4b8873 --- /dev/null +++ b/big_vision/trainers/proj/cappa/generative.py @@ -0,0 +1,498 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Training loop for CapPa (https://arxiv.org/abs/2306.07915).""" +# pylint: disable=consider-using-from-import +import functools +import importlib +import multiprocessing.pool +import os + +from absl import app +from absl import flags +from absl import logging +import big_vision.datasets.core as ds_core +import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline +import big_vision.optax as bv_optax +import big_vision.sharding as bv_sharding +import big_vision.trainers.proj.cappa.predict_fns as predict_fns +import big_vision.utils as u +from clu import parameter_overview +import flax +import flax.linen as nn +import jax +from jax.experimental import mesh_utils +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import serialization as array_serial +import jax.numpy as jnp +from ml_collections import config_flags +import numpy as np +import optax +import tensorflow as tf + +from tensorflow.io import gfile + +# pylint: disable=logging-fstring-interpolation + + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=True) + +flags.DEFINE_string("workdir", default=None, help="Work unit directory.") +flags.DEFINE_boolean("cleanup", default=False, + help="Delete workdir (only) after successful completion.") + +# Adds jax flags to the program. +jax.config.parse_flags_with_absl() +# Transfer guard will fail the program whenever that data between a host and +# a device is transferred implicitly. This often catches subtle bugs that +# cause slowdowns and memory fragmentation. Explicit transfers are done +# with jax.device_put and jax.device_get. +jax.config.update("jax_transfer_guard", "disallow") +# Fixes design flaw in jax.random that may cause unnecessary d2d comms. +jax.config.update("jax_threefry_partitionable", True) + + +def main(argv): + del argv + + try: + jax.distributed.initialize() + except ValueError as e: + logging.warning('Could not initialize distributed environment: %s', e) + + # Make sure TF does not touch GPUs. + tf.config.set_visible_devices([], "GPU") + + # Consistent device order is important to ensure correctness of various train + # loop components, such as input pipeline, update step, evaluators. We use + # jax utils to infer device order that will be used throughout the program. + config = flags.FLAGS.config + num_sharded_replicas = config.get("num_sharded_replicas", 1) + assert jax.device_count() % num_sharded_replicas == 0, ( + num_sharded_replicas, jax.device_count()) + devices = mesh_utils.create_device_mesh( + (num_sharded_replicas, jax.device_count() // num_sharded_replicas)) + devices_flat = devices.reshape(-1) + +################################################################################ +# # +# Set up logging # +# # +################################################################################ + + # Set up work directory and print welcome message. + config = flags.FLAGS.config + workdir = flags.FLAGS.workdir + logging.info( + f"\u001b[33mHello from process {jax.process_index()} holding " + f"{jax.local_device_count()}/{jax.device_count()} devices and " + f"writing to workdir {workdir}.\u001b[0m") + + save_ckpt_path = None + if workdir: # Always create if requested, even if we may not write into it. + gfile.makedirs(workdir) + save_ckpt_path = os.path.join(workdir, "checkpoint.bv") + + # The pool is used to perform misc operations such as logging in async way. 
+ pool = multiprocessing.pool.ThreadPool() + + # Here we register preprocessing ops from modules listed on `pp_modules`. + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): + importlib.import_module(f"big_vision.pp.{m}") + + # Setup up logging and experiment manager. + xid, wid = -1, -1 + fillin = lambda s: s + def info(s, *a): + logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) + def write_note(note): + if jax.process_index() == 0: + info("%s", note) + + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + +################################################################################ +# # +# Input Pipeline # +# # +################################################################################ + + write_note("Initializing train dataset...") + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) + + train_ds, ntrain_img = input_pipeline.training(config.input) + + total_steps = u.steps("total", config, ntrain_img, batch_size) + def get_steps(name, default=ValueError, cfg=config): + return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) + + u.chrono.inform(total_steps=total_steps, global_bs=batch_size, + steps_per_epoch=ntrain_img / batch_size, + measure=mw.measure, write_note=write_note) + + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) + + # Start input pipeline as early as possible. + n_prefetch = config.get("prefetch_to_device", 1) + train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch) + + # For mixed data, add per-dataset epoch and examples seen measurements. + if isinstance(config.input.data.get("name"), str): + measure_per_dataset_times = lambda step: None # No-op + else: + nexamples = { + name: ds_core.get(**config.input[name].data).total_examples + for name in config.input.data + } + def measure_per_dataset_times(step): + total = sum(config.input.data.values()) + for name, w in config.input.data.items(): + w = w / total + mw.measure(f"examples_seen_{name}", u.chrono.accum_examples_seen * w) + mw.measure(f"epoch_{name}", step * batch_size * w / nexamples[name]) + +################################################################################ +# # +# Create Model & Optimizer # +# # +################################################################################ + + write_note(f"Initializing {config.model_name} model...") + model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") + model = model_mod.Model(**config.get("model", {})) + + def init(rng): + bs = batch_size // jax.device_count() + img_shape = (bs,) + tuple(train_ds.element_spec["image"].shape[1:]) + txt_shape = (bs,) + tuple(train_ds.element_spec["labels"].shape[1:]) + dummy_img = jnp.zeros(img_shape, jnp.float32) + dummy_txt = jnp.zeros(txt_shape, jnp.int64) + variables = model.init(rng, dummy_img, dummy_txt) + return flax.core.unfreeze(variables["params"]) + + # This seed makes the Jax part of things (like model init) deterministic. 
+ # However, full training still won't be deterministic, for example due to the + # tf.data pipeline not being deterministic even if we would set TF seed. + # See (internal link) for a fun read on what it takes. + rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0))) + + rng, rng_init = jax.random.split(rng) + with u.chrono.log_timing("z/secs/init"): + write_note("Inferring parameter shapes...") + params_shape = jax.eval_shape(init, rng_init) + + write_note("Inferring optimizer state shapes...") + tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + opt_shape = jax.eval_shape(tx.init, params_shape) + # We jit this, such that the arrays are created on the CPU, not device[0]. + sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns] + + num_params = sum(np.prod(p.shape) for p in jax.tree_leaves(params_shape)) + mw.measure("num_params", num_params) + +################################################################################ +# # +# Shard & Transfer # +# # +################################################################################ + + # Using a 2D mesh where the model/optimizer parameters are replicated along + # the "replica" axis and sharded along the "fsdp" axis if the sharding is set + # to "fully_sharded" for the model/optimizer. + write_note("Creating device mesh...") + mesh = jax.sharding.Mesh(devices, ("replica", "fsdp")) + repl_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + write_note("Inferring shardings...") + params_sharding = bv_sharding.infer_sharding( + params_shape, mesh, axis_name="fsdp", + # TODO: implement scan for parameter sharding. + strategy=config.get("param_sharding", "replicated"), + extra_strategy_args=config.get("param_sharding_args", {})) + opt_sharding = bv_sharding.infer_sharding( + opt_shape, mesh, axis_name="fsdp", + strategy=config.get("optim_sharding", "replicated"), + extra_strategy_args=config.get("optim_sharding_args", {})) + + write_note("Transferring train_state to devices...") + # RNG is always replicated + rng_init = u.reshard(rng_init, repl_sharding) + + # Parameters and the optimizer are now global (distributed) jax arrays. + params = jax.jit(init, out_shardings=params_sharding)(rng_init) + opt = jax.jit(tx.init, out_shardings=opt_sharding)(params) + + rng, rng_loop = jax.random.split(rng, 2) + rng_loop = u.reshard(rng_loop, repl_sharding) + del rng # not used anymore, so delete it. + + # At this point we have everything we need to form a train state. It contains + # all the parameters that are passed and updated by the main training step. + train_state_sharding = { + "params": params_sharding, "opt": opt_sharding, "rng": repl_sharding} + train_state = { + "params": params, "opt": opt, "rng": rng_loop} + del params, opt, rng_loop # Delete to avoid memory leak or accidental reuse. 
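+
+  # For illustration: with 8 devices and num_sharded_replicas=2, the mesh built
+  # above is (replica=2, fsdp=4). Under the default "replicated" strategy every
+  # array in train_state lives in full on all devices; with "fully_sharded" the
+  # params/opt arrays would instead be sharded along the 4-way "fsdp" axis and
+  # replicated across the 2 replicas, while the RNG stays fully replicated.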
+ + write_note("Logging parameter overview...") + parameter_overview.log_parameter_overview( + train_state["params"], msg="Init params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Update Step # +# # +################################################################################ + + @functools.partial( + jax.jit, + donate_argnums=(0,), + out_shardings=(train_state_sharding, repl_sharding)) + def update_fn(train_state, batch): + """Update step.""" + + images, labels, label_masks = ( + batch["image"], batch["labels"], batch.get("label_masks")) + + # Get device-specific loss rng. + rng = train_state["rng"] + rng, rng_model = jax.random.split(rng, 2) + + def loss_fn(params): + logits = model.apply( + {"params": params}, images, labels, + train=True, rngs={"dropout": rng_model}) + + weights = jnp.where(labels != config.get("pad_token", 0), 1.0, 0.0) + if label_masks is not None: + weights = weights * label_masks + + loss = u.weighted_softmax_xent( + logits=logits, labels=labels, + weights=weights, label_smoothing=config.get("label_smoothing", 0.0), + reduction=True, normalize=True) + + return loss + + params, opt = train_state["params"], train_state["opt"] + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, opt = tx.update(grads, opt, params) + params = optax.apply_updates(params, updates) + + measurements = {"training_loss": loss} + gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) + measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs])) + ps = jax.tree_leaves(params) + measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps])) + us = jax.tree_leaves(updates) + measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us])) + + return {"params": params, "opt": opt, "rng": rng}, measurements + +################################################################################ +# # +# Load Checkpoint # +# # +################################################################################ + + # Decide how to initialize training. The order is important. + # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. + # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. + # 3. Initialize model from something, e,g, start a fine-tuning job. + # 4. Train from scratch. + resume_ckpt_path = None + if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"): + resume_ckpt_path = save_ckpt_path + elif config.get("resume"): + resume_ckpt_path = fillin(config.resume) + + ckpt_mngr = None + if save_ckpt_path or resume_ckpt_path: + ckpt_mngr = array_serial.GlobalAsyncCheckpointManager() + + if resume_ckpt_path: + write_note(f"Resuming training from checkpoint {resume_ckpt_path}...") + jax.tree_map(lambda x: x.delete(), train_state) + del train_state + shardings = { + **train_state_sharding, + "chrono": jax.tree_map(lambda _: repl_sharding, + u.chrono.save()), + } + loaded = u.load_checkpoint_ts( + resume_ckpt_path, tree=shardings, shardings=shardings) + train_state = {key: loaded[key] for key in train_state_sharding.keys()} + + u.chrono.load(jax.device_get(loaded["chrono"])) + del loaded + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + train_state["params"] = model_mod.load( + train_state["params"], config.model_init, config.get("model"), + **config.get("model_load", {})) + + # load has the freedom to return params not correctly sharded. 
Think of for + # example ViT resampling position embedings on CPU as numpy arrays. + train_state["params"] = u.reshard( + train_state["params"], train_state_sharding["params"]) + + parameter_overview.log_parameter_overview( + train_state["params"], msg="restored params", + include_stats="global", jax_logging_process=0) + +################################################################################ +# # +# Setup Evals # +# # +################################################################################ + + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, + predict_fns.get_predict_fns(model), + lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), + lambda key, cfg: get_steps(key, default=None, cfg=cfg), + devices_flat, + ) + + # At this point we need to know the current step to see whether to run evals. + write_note("Inferring the first step number...") + first_step_device = bv_optax.get_count(train_state["opt"], jittable=True) + first_step = int(jax.device_get(first_step_device)) + u.chrono.inform(first_step=first_step) + + # Note that training can be pre-empted during the final evaluation (i.e. + # just after the final checkpoint has been written to disc), in which case we + # want to run the evals. + if first_step in (total_steps, 0): + write_note("Running initial or final evals...") + mw.step_start(first_step) + for (name, evaluator, _, prefix) in evaluators(): + if config.evals[name].get("skip_first") and first_step != total_steps: + continue + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules([("act_batch", ("replica", "fsdp"))]): # pytype: disable=wrong-arg-types + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", value) + +################################################################################ +# # +# Train Loop # +# # +################################################################################ + + prof = None # Keeps track of start/stop of profiler state. + + write_note("Starting training loop, compiling the first step...") + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): + mw.step_start(step) + + with jax.profiler.StepTraceAnnotation("train_step", step_num=step): + with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): + with mesh, nn.logical_axis_rules([("act_batch", ("replica", "fsdp"))]): # pytype: disable=wrong-arg-types + train_state, measurements = update_fn(train_state, batch) + + # On the first host, let's always profile a handful of early steps. 
+ if jax.process_index() == 0: + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or u.chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", + sched_fn_cpu(u.put_cpu(step - 1))) + measurements = jax.device_get(measurements) + for name, value in measurements.items(): + mw.measure(name, value) + u.chrono.tick(step) + measure_per_dataset_times(step) + + if not np.isfinite(measurements["training_loss"]): + raise RuntimeError(f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + + # Checkpoint saving + keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps + if save_ckpt_path and ( + (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False)) + or u.itstime(step, get_steps("ckpt", None), total_steps, first=True) + ): + u.chrono.pause(wait_for=train_state) + + # Copy because we add extra stuff to the checkpoint. + ckpt = {**train_state} + + # To save chrono state correctly and safely in a multihost setup, we + # broadcast the state to all hosts and convert it to a global array. + with jax.transfer_guard("allow"): + chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save()) + chrono_shardings = jax.tree_map(lambda _: repl_sharding, chrono_ckpt) + ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)} + + u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep) + u.chrono.resume() + + for (name, evaluator, log_steps, prefix) in evaluators(): + if u.itstime(step, log_steps, total_steps, first=False, last=True): + u.chrono.pause(wait_for=train_state) + u.chrono.tick(step) # Record things like epoch number, core hours etc. + write_note(f"{name} evaluation...\n{u.chrono.note}") + with u.chrono.log_timing(f"z/secs/eval/{name}"): + with mesh, nn.logical_axis_rules( + [("act_batch", ("replica", "fsdp"))]): # pytype: disable=wrong-arg-types + for key, value in evaluator.run(train_state): + mw.measure(f"{prefix}{key}", jax.device_get(value)) + u.chrono.resume() + mw.step_end() + + # Always give a chance to stop the profiler, no matter how things ended. + # TODO: can we also do this when dying of an exception like OOM? + if jax.process_index() == 0 and prof is not None: + u.startstop_prof(prof) + + # Last note needs to happen before the pool's closed =) + write_note(f"Done!\n{u.chrono.note}") + + pool.close() + pool.join() + mw.close() + + if ckpt_mngr: + ckpt_mngr.wait_until_finished() + + # Make sure all hosts stay up until the end of main. + u.sync() + + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) + + +if __name__ == "__main__": + app.run(main) diff --git a/big_vision/trainers/proj/cappa/predict_fns.py b/big_vision/trainers/proj/cappa/predict_fns.py new file mode 100644 index 0000000..4e369f2 --- /dev/null +++ b/big_vision/trainers/proj/cappa/predict_fns.py @@ -0,0 +1,118 @@ +# Copyright 2023 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Prediction functions for clippo/generative.py.""" + +import functools + +import big_vision.pp.ops_text as pp_ops_text +import big_vision.utils as u +import jax +import jax.numpy as jnp +import numpy as np + +# pylint: disable=missing-function-docstring + + +# We do not jit/pmap this function, because it is passed to evaluator that +# does it later. We output as many intermediate tensors as possible for +# maximal flexibility. Later `jit` will prune out things that are not needed. +def predict_fn_perplexity( + train_state, batch, *, model): + logits = model.apply( + {"params": train_state["params"]}, + batch["image"], + batch["labels"], + train=False, + ) + return logits, {"logits": logits} + + +def predict_fn_enc_rep(train_state, batch, *, model): + logits, out = model.apply( + {"params": train_state["params"]}, + batch["image"], + None, + train=False, + return_enc_features=True, + ) + return logits, out + + +def predict_fn_score( + train_state, batch, *, model, prompt="", prompt_tokenizer=""): + """For a batch of images, return score (LL) for each image-label pair.""" + encoded = model.apply( + {"params": train_state["params"]}, + batch["image"], + train=False, + method=model.encode, + ) + + # This needs to be added by the evaluator. It is the pre-computed tokenized + # list of all available labels. For ImageNet-1k, that's (1000, 13). + all_labels = batch["_label_tokens"] + + # Optionally prefix a single prompt to all labels: + if prompt and prompt_tokenizer: + prompt = make_prompt(prompt, prompt_tokenizer) # Note: this is cached. + prompts = jnp.tile(prompt, (all_labels.shape[0], 1)) + all_labels = jnp.concatenate([prompts, all_labels], axis=-1) + # For ImageNet-1k and a prompt of length 2, we now have (1000, 15). + + def score_label(label): + """Score (LogLik) each minibatch example (image) with a single `label`.""" + label_rep = jnp.tile(label, (encoded.shape[0], 1)) + logits = model.apply( + {"params": train_state["params"]}, + encoded, + label_rep, + train=False, + decode=False, + method=model.decode, + ) + # The returned value is (batch,) scalars, the score each image has with + # this label. We turn the softmax_xent's NLL into LL so higher = better. + return -u.weighted_softmax_xent( + logits=logits, + labels=label_rep, + weights=(label_rep > 0).astype(jnp.float32), # Ignore (=0). + reduction=False, + normalize=False, + ) + + # Use lax.map() instead of vmap() to conserve memory. + nlls = jax.lax.map(score_label, all_labels) # -> (nlabel, batch) + return nlls.T # -> (batch, nlabel) array of scores. 
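+
+# Shape walk-through for the scoring path above, using the ImageNet-1k example
+# from the comments (1000 classes, 13 tokens per tokenized label, no prompt):
+# `encoded` is (B, P, E), each label passed to `score_label` is (13,), the
+# per-label call returns (B,) log-likelihoods, `jax.lax.map` stacks them into
+# nlls of shape (1000, B), and the final transpose yields the (B, 1000) score
+# matrix that the scoring_classifier evaluator argmaxes over.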
+ + +@functools.cache +def make_prompt(prompt, tokenizer_path, seq_len=None): + """Tokenizes `prompt` with specified tokenizer, with optional padding.""" + tokenizer = pp_ops_text.create_tokenizer(tokenizer_path, add_eos=False) + + prompt = tokenizer.tokenize(prompt).numpy() + if seq_len: + prompt = np.pad(prompt, (0, seq_len - len(prompt))).astype(np.int32) + return prompt + + +def get_predict_fns(model): + """Returns `predict_fns` for evaluators.""" + fns = { + "perplexity": predict_fn_perplexity, + "score": predict_fn_score, + "enc_rep": predict_fn_enc_rep, + } + return {name: functools.partial(fn, model=model) for name, fn in fns.items()} diff --git a/big_vision/utils.py b/big_vision/utils.py index 23aabc3..10a0a2f 100644 --- a/big_vision/utils.py +++ b/big_vision/utils.py @@ -210,7 +210,7 @@ def load_params(ckpt, **kw): else: # Here we're now loading new-style tensorstore checkpoints. # We can be a more efficient and load params and `key` only right away. - regex = f"params/{key}/.*" if key else "params/.*" + regex = f"params/{key}($|/.*)" if key else "params/.*" checkpoint = load_checkpoint_ts(ckpt, regex=regex) params = checkpoint["params"] @@ -571,6 +571,7 @@ def log_timing_avg(self, name, *, noop=False): logging.flush() def flush_timings(self): + assert self._measure is not None for name, times in self._timing_history.items(): self._measure(name, np.mean(times)) self._timing_history.clear() @@ -939,7 +940,9 @@ def tsload(path, *, tree=None, shardings=None, regex=None): names_to_load, _ = zip(*names_and_vals) if shardings is None: - shardings = jax.sharding.SingleDeviceSharding(jax.devices("cpu")[0]) + shardings = jax.sharding.SingleDeviceSharding( + jax.local_devices(backend="cpu")[0] + ) shardings = list(jax.tree_leaves(tree_broadcast(shardings, tree))) names_to_load = [os.path.join(path, name.replace("/", "~")) @@ -1319,7 +1322,7 @@ def _make_global_arr(x, shard, shape): def put_cpu(x): """Places array/pytree on a CPU device.""" - return jax.device_put(x, jax.devices("cpu")[0]) + return jax.device_put(x, jax.local_devices(backend="cpu")[0]) # TODO: remove this logic when the @@ -1327,7 +1330,9 @@ def put_cpu(x): def jit_cpu(**extra_kwargs): def _decorator(fun): def _wrapped(*args, **kwargs): - sh = jax.sharding.SingleDeviceSharding(jax.devices("cpu")[0]) + sh = jax.sharding.SingleDeviceSharding( + jax.local_devices(backend="cpu")[0] + ) return jax.jit(fun, **extra_kwargs, out_shardings=sh)(*args, **kwargs) return _wrapped return _decorator
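
Note on the `big_vision/utils.py` hunk above: `load_params` now builds the
regex `f"params/{key}($|/.*)"` instead of `f"params/{key}/.*"`, presumably so
that a `key` naming a leaf array (one with no children in the checkpoint tree)
is matched as well, while subtrees keep matching as before. A minimal sketch of
the difference, assuming full-match semantics and a hypothetical leaf key
`head/bias`:

```python
import re

old = "params/head/bias/.*"          # old pattern: requires a trailing "/..."
new = "params/head/bias($|/.*)"      # new pattern: also matches the leaf itself

name = "params/head/bias"            # a leaf entry in the checkpoint
print(bool(re.fullmatch(old, name)))   # False
print(bool(re.fullmatch(new, name)))   # True

child = "params/head/bias/momentum"  # children still match under both patterns
print(bool(re.fullmatch(old, child)))  # True
print(bool(re.fullmatch(new, child)))  # True
```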