
Commit

Add CLIPPO model, pp_ops, config, and readme. Also update proj/image_text trainer and evaluators. (#27)
mitscha authored Dec 30, 2022
1 parent b00544b commit fd2d3bd
Showing 13 changed files with 702 additions and 123 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -58,6 +58,9 @@ codebase:
Xiaohua Zhai*, Xiao Wang*, Basil Mustafa*, Andreas Steiner*, Daniel Keysers,
Alexander Kolesnikov, and Lucas Beyer*\
Resources: [trainer](big_vision/trainers/proj/image_text/contrastive.py), [config](big_vision/configs/proj/image_text/lit_coco.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb).
- [Image-and-Language Understanding from Pixels Only](https://arxiv.org/abs/2212.08045), by
Michael Tschannen, Basil Mustafa, and Neil Houlsby\
Resources: [readme](big_vision/configs/proj/clippo/README.md), [config](big_vision/configs/proj/clippo/train_clippo.py).

### Knowledge distillation

44 changes: 44 additions & 0 deletions big_vision/configs/proj/clippo/README.md
@@ -0,0 +1,44 @@
## Image-and-Language Understanding from Pixels Only

*by Michael Tschannen, Basil Mustafa, and Neil Houlsby* [[arxiv]](https://arxiv.org/abs/2212.08045)

We provide code to train CLIP with Pixels Only (CLIPPO) models on image/alt-text data sets.

To train your own CLIPPO model, please follow the setup instructions in the [`big_vision` main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup). Below, we list the CLIPPO-specific commands required in addition to that setup, assuming you are using the Google Cloud TPU setup (potentially with an adapted TPU configuration; see the table below). If you are using GPUs, please set up your machine directly and run only the `--command` portions of the commands below from the `big_vision` repository root.
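
For concreteness, the commands below assume the shell variables introduced in the main README's TPU setup are defined; the values here are placeholders for illustration only:

```bash
export NAME=clippo-tpu                      # name of your TPU VM
export ZONE=europe-west4-a                  # zone the TPU VM runs in
export GS_BUCKET_NAME=my-big-vision-bucket  # GCS bucket for TFDS data and workdirs
```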

The text rendering preprocessing function requires a manual download of the Unifont .hex files from [Unifoundry](https://unifoundry.com/unifont/) (please follow the link for license details):

```bash
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
--command "bash big_vision/pp/proj/clippo/download_unifont.sh"
```

Launch the training by running:

```bash
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
--command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.contrastive --config big_vision/configs/proj/clippo/train_clippo.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
```

*Important note:* The input pipeline relies on [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets), which does not provide automatic integration with large image/alt-text datasets out of the box. The above config therefore trains by default on MS-COCO Captions, which can be downloaded automatically via TFDS, and additionally initializes the CLIPPO ViT backbone with weights pretrained on ImageNet21k. This setup is not meant to produce good accuracy, but to provide a way to sanity-check your setup. If you want to train on a large dataset such as [`LAION-400M`](https://arxiv.org/abs/2111.02114) or [`YFCC100M`](https://arxiv.org/abs/1503.01817), please follow [these instructions](https://www.tensorflow.org/datasets/add_dataset) to wrap your dataset using TFDS and update the dataset in the config accordingly. Also note that the ImageNet1k evaluations require a manual download of the data; see [these instructions](https://github.com/google-research/big_vision#preparing-tfds-data). To train with your own dataset and with ImageNet1k-based evaluations, use `--config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True` in the command above.
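
To illustrate the TFDS wrapping step mentioned above, here is a minimal sketch of a `tfds.core.GeneratorBasedBuilder` for a hypothetical image/alt-text dataset. The class name, TSV layout, and path are made up; the feature keys (`image`, `text`) are chosen to match the keys consumed by the default preprocessing string in `train_clippo.py`:

```python
"""Minimal sketch of a TFDS builder for a hypothetical image/alt-text dataset."""
import tensorflow_datasets as tfds


class MyAltTextDataset(tfds.core.GeneratorBasedBuilder):
  """Hypothetical (image, caption) pairs stored as a TSV of `path<TAB>caption`."""

  VERSION = tfds.core.Version('1.0.0')

  def _info(self) -> tfds.core.DatasetInfo:
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(),  # consumed by the `decode` pp op
            'text': tfds.features.Text(),    # consumed by `render_unifont(inkey="text", ...)`
        }),
    )

  def _split_generators(self, dl_manager):
    # Point this at wherever your image/caption pairs actually live.
    return {'train': self._generate_examples('/path/to/train.tsv')}

  def _generate_examples(self, path):
    with open(path) as f:
      for i, line in enumerate(f):
        image_path, caption = line.rstrip('\n').split('\t', 1)
        yield i, {'image': image_path, 'text': caption}
```

After building the dataset (e.g. with the `tfds build` CLI), set `config.input.data` in `train_clippo.py` (the branch marked "Please add your favorite image/alt-text dataset here") to the new dataset name and run the config with `:test_with_coco=False`.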

#### Expected results

| train dataset | batch size | #steps | TPU chips | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | Config `arg` |
| :--- | ---: | ---: | ---: | :---: | :---: | :---: | :--- |
| *MS-COCO (sanity check)* | 4000 | 400 | 32 v3 | 4.2 | 12.6 | 8.6 | `i1k_eval=True` |
| LAION-400M | 8192 | 100k | 128 v2 | 51.5 | 44.8 | 29.3 | `test_with_coco=False,i1k_eval=True` |
| LAION-400M | 10240\* | 100k | 128 v3 | 53.6 | 46.7 | 30.3 | `test_with_coco=False,i1k_eval=True` |

\* The experiments in the paper use a batch size of 10240, which requires either a memory-optimized ViT implementation running on 128 TPU v2 chips, or 128 TPU v3 chips (in which case the TPU memory capacity allows increasing the batch size beyond 10240).
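
As a concrete example, the LAION-400M rows above correspond to appending the `Config arg` column to the config path in the training command (assuming your TFDS-wrapped dataset has been set in the config), along the lines of:

```bash
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
--command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.contrastive --config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
```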

#### Citation

```
@article{tschannen2022image,
title={Image-and-Language Understanding from Pixels Only},
author={Tschannen, Michael and Mustafa, Basil and Houlsby, Neil},
journal={arXiv preprint arXiv:2212.08045},
year={2022}
}
```
199 changes: 199 additions & 0 deletions big_vision/configs/proj/clippo/train_clippo.py
@@ -0,0 +1,199 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045
IMPORTANT NOTE: This config uses coco_captions by default for demonstration
purposes since the TFDS catalog does not provide any large image/alt-text data
set; the training will not produce a model with useful accuracy. Please
replace the data set below (marked by a comment) with an appropriate image/
alt-text data set wrapped in TFDS (for example LAION-400M) and run the config
with the suffix `:test_with_coco=False` to train on your data set. Refer to
the following guide to build a TFDS wrapper for your favorite image/alt-text
data set:
https://www.tensorflow.org/datasets/add_dataset
Also note that evaluation on ImageNet requires manual TFDS setup, see
https://github.com/google-research/big_vision#preparing-tfds-data
Example training:
big_vision.trainers.proj.image_text.contrastive \
--config big_vision/configs/proj/clippo/train_clippo.py \
--workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`
"""

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
from big_vision.configs.proj.image_text import common
from ml_collections import ConfigDict


def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg, res=224, runlocal=False, variant='B/16',
      test_with_coco=True, i1k_eval=False)
  config = ConfigDict()

  config.input = {}
  if arg.test_with_coco:
    # Use COCO Captions for sanity-checking
    config.input.data = dict(name='coco_captions', split='train')
    val_data = dict(config.input.data)
    val_data['split'] = 'val'
    config.input.batch_size = 4000 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
    config.total_steps = 400 if not arg.runlocal else 10
  else:
    # Please add your favorite image/alt-text dataset here
    config.input.data = None
    val_data = None
    assert config.input.data is not None and val_data is not None, (
        config.input.data, val_data)

    # The value in the paper is 10 * 1024, which requires 128 TPUv3 cores or a
    # memory optimized ViT implementation when running on 128 TPUv2 cores.
    config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
    config.total_steps = 100_000 if not arg.runlocal else 10

  def tokenizer(inkey, outkey='labels'):
    return (f'render_unifont('
            f'inkey="{inkey}", '
            f'outkey="{outkey}", '
            f'image_size={arg.res}, '
            f'lower=True, '
            f'font_size=16, '
            f'text_brightness=0, '
            f'background_brightness=127)|'
            f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')

  pp_image = f'decode|resize({arg.res})|value_range(-1,1)'
  if arg.test_with_coco:
    # Train with augmentation when sanity-checking
    pp_image_aug = (
        f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
    config.input.pp = pp_eval = (
        f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
        f'keep("image", "labels")')
  else:
    config.input.pp = pp_eval = (
        f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')

  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000

  config.loss_use_global_batch = True

  # Define the model
  config.model_name = 'proj.clippo.one_tower'

  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.image = ConfigDict({
      'variant': arg.variant,
      'pool_type': 'map',
      'head_zeroinit': False,
  })

  if arg.test_with_coco:
    # Initialize with ImageNet21k pretrained checkpoint for sanity-checking
    assert arg.variant == 'B/16', arg.variant
    config.model_init = {'image': 'howto-i21k-B/16'}
    config.model_load = {}
    config.model_load['img_load_kw'] = {
        'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}

  config.model.temperature_init = 10.0
  config.model.out_dim = 768

  # Define the optimizer
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0

  if arg.test_with_coco:
    # Short schedule for sanity-checking
    config.lr = 0.0001
    config.wd = 0.0003
    config.schedule = dict(decay_type='rsqrt',
                           timescale=100,
                           warmup_steps=100 if not arg.runlocal else 5,
                           cooldown_steps=100 if not arg.runlocal else 5)
  else:
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = dict(decay_type='rsqrt',
                           timescale=10_000,
                           warmup_steps=10_000 if not arg.runlocal else 5,
                           cooldown_steps=10_000 if not arg.runlocal else 5)

  # Eval section (Both few-shot and zero-shot)
  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=config.loss_use_global_batch,
      log_steps=1000 if not arg.runlocal else 5,
  )
  config.evals = {}
  sub = '[:4]' if arg.runlocal else ''
  config.evals.val = {
      **eval_common,
      'data': val_data,
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'{pp_image}|flatten|{tokenizer("captions/text")}|'
          f'keep("image", "labels")'),
  }

  if arg.i1k_eval:
    # Requires manual download, see
    # https://github.com/google-research/big_vision#preparing-tfds-data
    config.evals.imagenet = {
        **eval_common,
        'data': dict(name='imagenet2012', split=f'validation{sub}'),
        'pp_fn': (
            f'{pp_image}|clip_i1k_label_names|'
            f'{tokenizer("labels")}|keep("image", "labels")'),
    }
    config.evals.disclf = dict(
        type='proj.image_text.discriminative_classifier',
        pp_txt=tokenizer('texts', 'labels'),
        prefix='z/0shot/',
        log_steps=5_000 if not arg.runlocal else 5)

  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=5_000 if not arg.runlocal else 5,
  )

  # Few-shot metrics
  config.evals.fewshot = get_fewshot_lsr()
  config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
  config.evals.fewshot.representation_layer = 'img/pre_logits'

  config.seed = 0

  return config
74 changes: 70 additions & 4 deletions big_vision/configs/proj/image_text/common.py
@@ -32,30 +32,96 @@
# pylint: enable=line-too-long


def _square875(sz):
  return f'resize({int(sz/0.875)})|central_crop({sz})|value_range(-1,1)'


def _aspect75(sz):
  return f'resize_small({int(sz/0.75)})|central_crop({sz})|value_range(-1,1)'


def _drop_no_real_label(f):
  return len(f['real_label']) > 0


def _drop_no_imagenet(f):
  return len(f['labels_imagenet']) > 0


DISCLF_DATASET_OVERRIDES = {
    'imagenet2012': {'class_names': 'clip', 'split': 'validation'},
    'imagenet2012_minival': {
        'dataset_name': 'imagenet2012',
        'class_names': 'clip',
        'split': 'train[99%:]',
    },
    'imagenet2012_real': {
        'split': 'validation',
        'class_names': 'clip',
        'class_names_dataset_name': 'imagenet2012',
        'pp_img': lambda sz: (
            _square875(sz) + '|pad_to_shape(inkey="real_label", outkey="label", shape=[10], pad_value=-1)|keep("label", "image")'),  # pylint: disable=line-too-long
        'filter_fn': _drop_no_real_label,
    },
    'imagenet_v2': {'class_names': 'clip'},
    'imagenet_a': {
        'class_names': 'clip',
        'pp_img': lambda sz: _aspect75(sz) + '|map("i1k_i1ka")',
    },
    'imagenet_r': {
        'class_names': 'clip',
        'pp_img': lambda sz: _square875(sz) + '|map("i1k_i1kr")',
    },
}


def get_disclf(sz, *, log_steps, pp_txt=None, dataset_names=('imagenet2012',)):
  """Returns config for discriminative_classifier of specified datasets."""
  config = ml_collections.ConfigDict(dict(
      dataset_names=list(dataset_names),
      type='proj.image_text.discriminative_classifier',
      prefix='z/0shot/',
      pp_img=_square875(sz),
      dataset_overrides={},
      log_steps=log_steps,
      cache_final=True,
  ))
  if pp_txt:
    config.pp_txt = pp_txt
  for name in dataset_names:
    if name in DISCLF_DATASET_OVERRIDES:
      config.dataset_overrides[name] = {**DISCLF_DATASET_OVERRIDES[name]}
      d = config.dataset_overrides[name]
      if 'pp_img' in d and callable(d['pp_img']):
        with d.ignore_type():
          d['pp_img'] = d['pp_img'](sz)
  return config


def get_coco(
    *,
    log_steps,
    pp_img='resize(224)|value_range(-1, 1)',
    pp_txt='tokenize(max_len=16, inkey="texts", eos="sticky", pad_value=1)',
    prefix='z/retr/coco_'):
  """Returns config for mscoco retrieval zero-shot.

  Args:
    log_steps: How often the evaluators should be run.
    pp_img: Pre-processing string for "image" feature.
    pp_txt: Pre-processing string for texts (expected to tokenize "texts" to
      "labels").
    prefix: Prefix to use for metrics.

  Returns:
    `ConfigDict` that can be used as a retrieval evaluator configuration.
  """
  return ml_collections.ConfigDict({
      'type': 'proj.image_text.retrieval',
      'log_steps': log_steps,
      'pp_txt': pp_txt,
      'pp_img': pp_img,
      'prefix': prefix,
      'dataset': 'coco_captions',
      'txt_name': ('captions', 'text'),
  })
16 changes: 12 additions & 4 deletions big_vision/evaluators/proj/image_text/contrastive.py
@@ -12,7 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluator for the contrastive task."""
"""Evaluator for the contrastive task.
DON'T COMPARE ACROSS RUNS, use for training health monitoring only.
Note that this evaluator's `ncorrect_minibatch` is only a rough proxy for
training progress and does not report the actual `ncorrect`: when the same
labels found multiple times in a batch, then the reported value is biased
towards lower values.
Also note that the `ncorrect_minibatch` is a function of batch size (it's a lot
easier to find correct values in small batches).
"""
import functools

from big_vision import input_pipeline
@@ -38,9 +49,6 @@ def get_eval_fn(predict_fn, use_global_batch):

  @functools.partial(jax.pmap, axis_name="batch")
  def _eval_fn(params, images, labels, mask):

    # Ignore the entries with all zero labels for evaluation.
    mask *= jnp.clip(labels.max(axis=1), 0, 1)
    zimg, ztxt, extras = predict_fn(params, images, labels)

    if use_global_batch: