Skip to content

Commit

Permalink
Adds CapPa model (https://arxiv.org/abs/2306.07915). (#82)
Browse files Browse the repository at this point in the history
* Updates to `jax.local_devices(backend='tpu')`.

* Adds preprocess_ops.coco_captions

* Adds CapPa model code.

* Optional `jax.distributed.initialize()` (with warning).

The API call `jax.distributed.initialize()` will fail if the program is
executed outside a multihost environment (e.g. local testing).

* Applies Lucas's comment.
  • Loading branch information
andsteing authored Dec 12, 2023
1 parent 3b8e5ab commit 7ace659
Show file tree
Hide file tree
Showing 9 changed files with 1,354 additions and 4 deletions.
37 changes: 37 additions & 0 deletions big_vision/configs/proj/cappa/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Image Captioners Are Scalable Vision Learners Too

*by Michael Tschannen, Manoj Kumar, Andreas Steiner, Xiaohua Zhai, Neil Houlsby, Lucas Beyer* [[arxiv]](https://arxiv.org/abs/2306.07915)

![CapPa Architecture](./cappa_architecture.png)

This directory contains a config for training a CapPa model from scratch.
Note that most models in the paper were trained on a proprietary dataset
(WebLI), but similar results can be obtained by training on [LAION](https://laion.ai/).

By default, this config trains on COCO captions as this data set is readily
available in [TFDS](https://www.tensorflow.org/datasets) without manual steps.
This is not meant to produce a meaningful model, but
provides a way for the user to run the config out of the box. Please update the
config with with a TFDS-wrapped variant of your favorite image/text data set to
train capable models.

After setting up `big_vision` as described in the [main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup), training can be launched as follows

```
python big_vision.trainers.proj.cappa.generative \
--config big_vision/configs/proj/cappa/pretrain.py \
--workdir gs://$GS_BUCKET_NAME/big_vision/`date '+%m-%d_%H%M'`
```

To run the Cap baseline (autoregressive captioning without parallel prediction),
set `config.model.masked_pred_prob = 0.0`.

### Citation
```
@inproceedings{tschannen2023image,
title={Image Captioners Are Scalable Vision Learners Too},
author={Tschannen, Michael and Kumar, Manoj and Steiner, Andreas and Zhai, Xiaohua and Houlsby, Neil and Beyer, Lucas},
booktitle={Neural Information Processing Systems (NeurIPS)},
year={2023}
}
```
140 changes: 140 additions & 0 deletions big_vision/configs/proj/cappa/pretrain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2023 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Trains a CapPa model (https://arxiv.org/abs/2306.07915) on coco_captions.
This config is for reference, we never ran a full training on a large
image/text data set on public infrastructure.
big_vision.trainers.proj.cappa.generative \
--config big_vision/configs/proj/cappa/pretrain.py \
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
"""


from big_vision.configs import common_fewshot
import big_vision.configs.common as bvcc
import ml_collections


def get_config(arg=None):
"""Returns the base config."""
config = bvcc.parse_arg(arg,
runlocal=False,
total_steps=366_500,
batch_size=8*1024,
warmup_steps=10_000,
)

config.evals = {}
config.input = {}
config.input.batch_size = config.batch_size if not config.runlocal else 8
shuffle_buffer_size = 50_000 if not config.runlocal else 50

res = 224
patch_size = 16
max_text_tokens = 64

pp_image = (f'resize({res})|value_range(-1,1)')

def tokenizer(inkey, outkey):
return (f'tokenize(max_len={max_text_tokens}, model="c4_en", '
f'eos="sticky", inkey="{inkey}", outkey="{outkey}")')

pp_coco = (f'decode|{pp_image}|'
'coco_captions("captions")|choice(inkey="captions", outkey="text")|'
f'{tokenizer("text", "labels")}|keep("image", "labels")')
config.input.pp = pp_coco

# NOTE: "coco_captions" is way too small a dataset to train on. It's simply
# used here to serve as a smoke test that the implementation works correctly.
config.input.data = dict(name='coco_captions', split='train') # num_examples=82_783
config.input.shuffle_buffer_size = shuffle_buffer_size

config.evals.val_coco = {
'type': 'proj.cappa.perplexity',
'pred': 'perplexity',
'log_steps': 1000,
'data': dict(name='coco_captions', split='val'), # num_examples=5_000
'pp_fn': pp_coco,
}

# Few-shot metrics
config.evals.fewshot = common_fewshot.get_fewshot_lsr(
target_resolution=res, resize_resolution=int(256 / 224 * res))
config.evals.fewshot.type = 'fewshot_lsr'
config.evals.fewshot.log_steps = 5_000 if not config.runlocal else 5
config.evals.fewshot.representation_layer = 'pre_logits'
config.evals.fewshot.pred = 'enc_rep'
config.evals.fewshot.pp_eval = config.evals.fewshot.pp_train

# NOTE: Scoring of the entire imagenet validation set is rather slow:
# ~100 secs / 1k classes / host.
config.evals['imagenet/scoring'] = dict(
type='proj.cappa.scoring_classifier',
pred='score',
log_percent=0.1,
data=dict(name='imagenet2012', split='validation'),
pp_fn=f'decode|{pp_image}|keep("image", "label")',
pp_txt=tokenizer('label', 'labels'),
)

for e in config.evals.values():
e.skip_first = True

config.log_training_steps = 50
config.ckpt_steps = 1000
config.keep_ckpt_steps = None # 10_000

# Model section
config.model_name = 'proj.cappa.cappa'
config.model = ml_collections.ConfigDict()
config.model.num_layers = 12
config.model.num_heads = 12
config.model.mlp_dim = 3072
config.model.emb_dim = 768
config.model.vocab_size = 32_000
config.model.patches = (patch_size, patch_size)
config.model.seq_len = max_text_tokens
config.model.posemb_type = 'learn'

# Decoder
config.model.decoder_num_layers = 6
# 0 values here mean to use the same value as for the encoder
config.model.decoder_num_heads = 0
config.model.decoder_mlp_dim = 0
config.model.decoder_emb_dim = 0
config.model.dec_dropout_rate = 0.0
config.model.masked_pred_prob = 0.75
config.model.masking_ratio = 1.0
config.model.decoder_bias = False

config.optax_name = 'big_vision.scale_by_adafactor'
config.optax = dict(beta2_cap=0.999)
config.grad_clip_norm = 1.0
config.label_smoothing = 0.0

schedule = dict(decay_type='cosine',
warmup_steps=config.warmup_steps
if not config.runlocal else 5)

# Standard schedule
config.lr = 0.001
config.wd = 0.0001
config.schedule = schedule

config.seed = 0

return config
50 changes: 50 additions & 0 deletions big_vision/evaluators/proj/cappa/perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2023 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluator for perplexity of a model."""
from big_vision.evaluators import mean
import big_vision.utils as u
import jax.numpy as jnp


# Temporary global flag to facilitate backwards compatability. Will be removed
# by the end of year 2023.
API = 'jit'


def perplexity(predict_fn, normalize_by_seqlen):
"""Returns a function that computes perplexity."""

def _perplexity_fn(train_state, batch, pad_token=0, **kw):
logits, _ = predict_fn(train_state, batch, **kw)

# Ignore perplexity on the padding label.
weights = jnp.where(batch['labels'] != pad_token, 1, 0).astype(jnp.float32)
if batch.get('label_masks') is not None:
weights = weights * batch['label_masks']

losses = u.weighted_softmax_xent(
logits=logits, labels=batch['labels'],
weights=weights, label_smoothing=0.0,
reduction=False, normalize=normalize_by_seqlen)

return {'perplexity': losses}
return _perplexity_fn


class Evaluator(mean.Evaluator):
"""Perplexity evaluator."""

def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw):
super().__init__(perplexity(predict_fn, normalize_by_seqlen), *a, **kw)
63 changes: 63 additions & 0 deletions big_vision/evaluators/proj/cappa/scoring_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Scoring classifier.
This one is based on a generative perspective for image classification.
Here we input the image as well as all the tokenized labels to compute their
perplexity and select the one with minimum loss as the prediction.
"""
import functools
from big_vision.datasets.imagenet import class_names as imagenet_class_names
from big_vision.evaluators import mean
from big_vision.pp import builder as pp_builder
import jax.numpy as jnp
import numpy as np

# Temporary global flag to facilitate backwards compatability. Will be removed
# by the end of year 2023.
API = "jit"


CLASS_NAMES = {
"imagenet2012": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
}


# As a separate function to cache result across instances.
@functools.lru_cache(maxsize=None)
def get_classes(dataset_name, pp_txt):
"""Load the class label strings and tokenize them using pp_txt."""
pp_fn = pp_builder.get_preprocess_fn(pp_txt, log_data=False)
return np.array([pp_fn({"label": name})["labels"]
for name in CLASS_NAMES[dataset_name]])


def scoring(predict_fn, tokenized_labels):

def _scoring_fn(train_state, batch, *a, **kw):
batch = {"_label_tokens": tokenized_labels, **batch}
scores = predict_fn(train_state, batch, *a, **kw)
predictions = jnp.argmax(scores, axis=-1)
return {"prec@1": predictions == batch["label"]}

return _scoring_fn


class Evaluator(mean.Evaluator):
"""Evaluator for classification accuracy based on scoring all classes."""

def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw):
cls_tokens = get_classes(data["name"], pp_txt)
super().__init__(scoring(predict_fn, cls_tokens), data, pp_fn, *a, **kw)
Loading

0 comments on commit 7ace659

Please sign in to comment.