-
Notifications
You must be signed in to change notification settings - Fork 158
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds CapPa model (https://arxiv.org/abs/2306.07915). (#82)
* Updates to `jax.local_devices(backend='tpu')`. * Adds preprocess_ops.coco_captions * Adds CapPa model code. * Optional `jax.distributed.initialize()` (with warning). The API call `jax.distributed.initialize()` will fail if the program is executed outside a multihost environment (e.g. local testing). * Applies Lucas's comment.
- Loading branch information
Showing
9 changed files
with
1,354 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Image Captioners Are Scalable Vision Learners Too | ||
|
||
*by Michael Tschannen, Manoj Kumar, Andreas Steiner, Xiaohua Zhai, Neil Houlsby, Lucas Beyer* [[arxiv]](https://arxiv.org/abs/2306.07915) | ||
|
||
![CapPa Architecture](./cappa_architecture.png) | ||
|
||
This directory contains a config for training a CapPa model from scratch. | ||
Note that most models in the paper were trained on a proprietary dataset | ||
(WebLI), but similar results can be obtained by training on [LAION](https://laion.ai/). | ||
|
||
By default, this config trains on COCO captions as this data set is readily | ||
available in [TFDS](https://www.tensorflow.org/datasets) without manual steps. | ||
This is not meant to produce a meaningful model, but | ||
provides a way for the user to run the config out of the box. Please update the | ||
config with with a TFDS-wrapped variant of your favorite image/text data set to | ||
train capable models. | ||
|
||
After setting up `big_vision` as described in the [main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup), training can be launched as follows | ||
|
||
``` | ||
python big_vision.trainers.proj.cappa.generative \ | ||
--config big_vision/configs/proj/cappa/pretrain.py \ | ||
--workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'` | ||
``` | ||
|
||
To run the Cap baseline (autoregressive captioning without parallel prediction), | ||
set `config.model.masked_pred_prob = 0.0`. | ||
|
||
### Citation | ||
``` | ||
@inproceedings{tschannen2023image, | ||
title={Image Captioners Are Scalable Vision Learners Too}, | ||
author={Tschannen, Michael and Kumar, Manoj and Steiner, Andreas and Zhai, Xiaohua and Houlsby, Neil and Beyer, Lucas}, | ||
booktitle={Neural Information Processing Systems (NeurIPS)}, | ||
year={2023} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Copyright 2023 Big Vision Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# pylint: disable=line-too-long | ||
r"""Trains a CapPa model (https://arxiv.org/abs/2306.07915) on coco_captions. | ||
This config is for reference, we never ran a full training on a large | ||
image/text data set on public infrastructure. | ||
big_vision.trainers.proj.cappa.generative \ | ||
--config big_vision/configs/proj/cappa/pretrain.py \ | ||
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` | ||
""" | ||
|
||
|
||
from big_vision.configs import common_fewshot | ||
import big_vision.configs.common as bvcc | ||
import ml_collections | ||
|
||
|
||
def get_config(arg=None): | ||
"""Returns the base config.""" | ||
config = bvcc.parse_arg(arg, | ||
runlocal=False, | ||
total_steps=366_500, | ||
batch_size=8*1024, | ||
warmup_steps=10_000, | ||
) | ||
|
||
config.evals = {} | ||
config.input = {} | ||
config.input.batch_size = config.batch_size if not config.runlocal else 8 | ||
shuffle_buffer_size = 50_000 if not config.runlocal else 50 | ||
|
||
res = 224 | ||
patch_size = 16 | ||
max_text_tokens = 64 | ||
|
||
pp_image = (f'resize({res})|value_range(-1,1)') | ||
|
||
def tokenizer(inkey, outkey): | ||
return (f'tokenize(max_len={max_text_tokens}, model="c4_en", ' | ||
f'eos="sticky", inkey="{inkey}", outkey="{outkey}")') | ||
|
||
pp_coco = (f'decode|{pp_image}|' | ||
'coco_captions("captions")|choice(inkey="captions", outkey="text")|' | ||
f'{tokenizer("text", "labels")}|keep("image", "labels")') | ||
config.input.pp = pp_coco | ||
|
||
# NOTE: "coco_captions" is way too small a dataset to train on. It's simply | ||
# used here to serve as a smoke test that the implementation works correctly. | ||
config.input.data = dict(name='coco_captions', split='train') # num_examples=82_783 | ||
config.input.shuffle_buffer_size = shuffle_buffer_size | ||
|
||
config.evals.val_coco = { | ||
'type': 'proj.cappa.perplexity', | ||
'pred': 'perplexity', | ||
'log_steps': 1000, | ||
'data': dict(name='coco_captions', split='val'), # num_examples=5_000 | ||
'pp_fn': pp_coco, | ||
} | ||
|
||
# Few-shot metrics | ||
config.evals.fewshot = common_fewshot.get_fewshot_lsr( | ||
target_resolution=res, resize_resolution=int(256 / 224 * res)) | ||
config.evals.fewshot.type = 'fewshot_lsr' | ||
config.evals.fewshot.log_steps = 5_000 if not config.runlocal else 5 | ||
config.evals.fewshot.representation_layer = 'pre_logits' | ||
config.evals.fewshot.pred = 'enc_rep' | ||
config.evals.fewshot.pp_eval = config.evals.fewshot.pp_train | ||
|
||
# NOTE: Scoring of the entire imagenet validation set is rather slow: | ||
# ~100 secs / 1k classes / host. | ||
config.evals['imagenet/scoring'] = dict( | ||
type='proj.cappa.scoring_classifier', | ||
pred='score', | ||
log_percent=0.1, | ||
data=dict(name='imagenet2012', split='validation'), | ||
pp_fn=f'decode|{pp_image}|keep("image", "label")', | ||
pp_txt=tokenizer('label', 'labels'), | ||
) | ||
|
||
for e in config.evals.values(): | ||
e.skip_first = True | ||
|
||
config.log_training_steps = 50 | ||
config.ckpt_steps = 1000 | ||
config.keep_ckpt_steps = None # 10_000 | ||
|
||
# Model section | ||
config.model_name = 'proj.cappa.cappa' | ||
config.model = ml_collections.ConfigDict() | ||
config.model.num_layers = 12 | ||
config.model.num_heads = 12 | ||
config.model.mlp_dim = 3072 | ||
config.model.emb_dim = 768 | ||
config.model.vocab_size = 32_000 | ||
config.model.patches = (patch_size, patch_size) | ||
config.model.seq_len = max_text_tokens | ||
config.model.posemb_type = 'learn' | ||
|
||
# Decoder | ||
config.model.decoder_num_layers = 6 | ||
# 0 values here mean to use the same value as for the encoder | ||
config.model.decoder_num_heads = 0 | ||
config.model.decoder_mlp_dim = 0 | ||
config.model.decoder_emb_dim = 0 | ||
config.model.dec_dropout_rate = 0.0 | ||
config.model.masked_pred_prob = 0.75 | ||
config.model.masking_ratio = 1.0 | ||
config.model.decoder_bias = False | ||
|
||
config.optax_name = 'big_vision.scale_by_adafactor' | ||
config.optax = dict(beta2_cap=0.999) | ||
config.grad_clip_norm = 1.0 | ||
config.label_smoothing = 0.0 | ||
|
||
schedule = dict(decay_type='cosine', | ||
warmup_steps=config.warmup_steps | ||
if not config.runlocal else 5) | ||
|
||
# Standard schedule | ||
config.lr = 0.001 | ||
config.wd = 0.0001 | ||
config.schedule = schedule | ||
|
||
config.seed = 0 | ||
|
||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright 2023 Big Vision Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Evaluator for perplexity of a model.""" | ||
from big_vision.evaluators import mean | ||
import big_vision.utils as u | ||
import jax.numpy as jnp | ||
|
||
|
||
# Temporary global flag to facilitate backwards compatability. Will be removed | ||
# by the end of year 2023. | ||
API = 'jit' | ||
|
||
|
||
def perplexity(predict_fn, normalize_by_seqlen): | ||
"""Returns a function that computes perplexity.""" | ||
|
||
def _perplexity_fn(train_state, batch, pad_token=0, **kw): | ||
logits, _ = predict_fn(train_state, batch, **kw) | ||
|
||
# Ignore perplexity on the padding label. | ||
weights = jnp.where(batch['labels'] != pad_token, 1, 0).astype(jnp.float32) | ||
if batch.get('label_masks') is not None: | ||
weights = weights * batch['label_masks'] | ||
|
||
losses = u.weighted_softmax_xent( | ||
logits=logits, labels=batch['labels'], | ||
weights=weights, label_smoothing=0.0, | ||
reduction=False, normalize=normalize_by_seqlen) | ||
|
||
return {'perplexity': losses} | ||
return _perplexity_fn | ||
|
||
|
||
class Evaluator(mean.Evaluator): | ||
"""Perplexity evaluator.""" | ||
|
||
def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw): | ||
super().__init__(perplexity(predict_fn, normalize_by_seqlen), *a, **kw) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright 2023 Big Vision Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Scoring classifier. | ||
This one is based on a generative perspective for image classification. | ||
Here we input the image as well as all the tokenized labels to compute their | ||
perplexity and select the one with minimum loss as the prediction. | ||
""" | ||
import functools | ||
from big_vision.datasets.imagenet import class_names as imagenet_class_names | ||
from big_vision.evaluators import mean | ||
from big_vision.pp import builder as pp_builder | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
# Temporary global flag to facilitate backwards compatability. Will be removed | ||
# by the end of year 2023. | ||
API = "jit" | ||
|
||
|
||
CLASS_NAMES = { | ||
"imagenet2012": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, | ||
} | ||
|
||
|
||
# As a separate function to cache result across instances. | ||
@functools.lru_cache(maxsize=None) | ||
def get_classes(dataset_name, pp_txt): | ||
"""Load the class label strings and tokenize them using pp_txt.""" | ||
pp_fn = pp_builder.get_preprocess_fn(pp_txt, log_data=False) | ||
return np.array([pp_fn({"label": name})["labels"] | ||
for name in CLASS_NAMES[dataset_name]]) | ||
|
||
|
||
def scoring(predict_fn, tokenized_labels): | ||
|
||
def _scoring_fn(train_state, batch, *a, **kw): | ||
batch = {"_label_tokens": tokenized_labels, **batch} | ||
scores = predict_fn(train_state, batch, *a, **kw) | ||
predictions = jnp.argmax(scores, axis=-1) | ||
return {"prec@1": predictions == batch["label"]} | ||
|
||
return _scoring_fn | ||
|
||
|
||
class Evaluator(mean.Evaluator): | ||
"""Evaluator for classification accuracy based on scoring all classes.""" | ||
|
||
def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw): | ||
cls_tokens = get_classes(data["name"], pp_txt) | ||
super().__init__(scoring(predict_fn, cls_tokens), data, pp_fn, *a, **kw) |
Oops, something went wrong.