Commit: Add CLIPPO model, pp_ops, config, and readme. Also update proj/image_text trainer and evaluators. (#27)
Showing 13 changed files with 702 additions and 123 deletions.
`big_vision/configs/proj/clippo/README.md` (new file, +44 lines):

## Image-and-Language Understanding from Pixels Only

*by Michael Tschannen, Basil Mustafa, Neil Houlsby* [[arxiv]](https://arxiv.org/abs/2212.08045)

We provide code to train CLIP with Pixels Only (CLIPPO) models on image/alt-text data sets.

To train your own CLIPPO model, please follow the setup instructions in the [`big_vision` main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup). In the following, we provide the CLIPPO-specific commands required in addition to that setup, assuming you are using the Google Cloud TPU setup (potentially with an adapted TPU configuration; see the table below). If you are using GPUs, please set up your machine directly and execute only the `--command` portions of the commands below from the `big_vision` repository root.

The text rendering preprocessing function requires a manual download of the Unifont .hex files from [Unifoundry](https://unifoundry.com/unifont/) (please follow the link for license terms):
```bash
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
  --command "bash big_vision/pp/proj/clippo/download_unifont.sh"
```

Launch the training by running

```bash
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
  --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.contrastive --config big_vision/configs/proj/clippo/train_clippo.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
```

*Important note:* The input pipeline relies on [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets), which does not provide automatic integration with large image/alt-text datasets out of the box. The above config therefore trains by default on MS-COCO Captions, which can be downloaded automatically via TFDS, and additionally initializes the CLIPPO ViT backbone with weights pretrained on ImageNet21k. This setup is not meant to produce good accuracy, but to provide the user with a way to sanity-check their setup. If you want to train on a large data set such as [`LAION-400M`](https://arxiv.org/abs/2111.02114) or [`YFCC100M`](https://arxiv.org/abs/1503.01817), please follow [these instructions](https://www.tensorflow.org/datasets/add_dataset) to wrap your data set using TFDS (see the sketch below) and update the dataset in the config accordingly. Also note that the ImageNet1k evaluations require a manual download of the data; see [these instructions](https://github.com/google-research/big_vision#preparing-tfds-data). To train with your own data set and with ImageNet1k-based evaluations, use `--config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True` in the command above.
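A minimal sketch of such a TFDS wrapper follows. It is not part of this commit; the builder name (`MyAltText`) and the CSV layout (`image_path` and `caption` columns) are illustrative assumptions. The feature keys `image` and `text` match what the default (non-COCO) preprocessing in `train_clippo.py` expects:

```python
import csv
import tensorflow_datasets as tfds


class MyAltText(tfds.core.GeneratorBasedBuilder):
  """Hypothetical image/alt-text dataset stored as CSVs of (image_path, caption)."""
  VERSION = tfds.core.Version('1.0.0')

  def _info(self):
    # CLIPPO's default preprocessing string reads the 'image' and 'text' keys.
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(),
            'text': tfds.features.Text(),
        }),
    )

  def _split_generators(self, dl_manager):
    return {
        'train': self._generate_examples('/data/train.csv'),
        'val': self._generate_examples('/data/val.csv'),
    }

  def _generate_examples(self, csv_path):
    with open(csv_path) as f:
      for i, row in enumerate(csv.DictReader(f)):
        yield i, {'image': row['image_path'], 'text': row['caption']}
```

Once built with `tfds build`, the dataset can be referenced in the config roughly as `config.input.data = dict(name='my_alt_text', split='train')`.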
#### Expected results

| train dataset | batch size | #steps | TPU chips | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | Config `arg` |
| :--- | ---: | ---: | ---: | :---: | :---: | :---: | :--- |
| *MS-COCO (sanity check)* | 4000 | 400 | 32 v3 | 4.2 | 12.6 | 8.6 | `i1k_eval=True` |
| LAION-400M | 8192 | 100k | 128 v2 | 51.5 | 44.8 | 29.3 | `test_with_coco=False,i1k_eval=True` |
| LAION-400M | 10240\* | 100k | 128 v3 | 53.6 | 46.7 | 30.3 | `test_with_coco=False,i1k_eval=True` |

\* The experiments in the paper use a batch size of 10240, which requires either a memory-optimized ViT implementation on 128 TPU v2 chips, or 128 TPU v3 chips (whose memory capacity even allows increasing the batch size beyond 10240).
#### Citation

```
@article{tschannen2022image,
  title={Image-and-Language Understanding from Pixels Only},
  author={Tschannen, Michael and Mustafa, Basil and Houlsby, Neil},
  journal={arXiv preprint arXiv:2212.08045},
  year={2022}
}
```
`big_vision/configs/proj/clippo/train_clippo.py` (new file, +199 lines):

# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045

IMPORTANT NOTE: This config uses coco_captions by default for demonstration
purposes since the TFDS catalog does not provide any large image/alt-text data
set; the training will not produce a model with useful accuracy. Please
replace the data set below (marked by a comment) with an appropriate image/
alt-text data set wrapped in TFDS (for example LAION-400M) and run the config
with the suffix `:test_with_coco=False` to train on your data set. Refer to
the following guide to build a TFDS wrapper for your favorite image/alt-text
data set:
https://www.tensorflow.org/datasets/add_dataset

Also note that evaluation on ImageNet requires manual TFDS setup, see
https://github.com/google-research/big_vision#preparing-tfds-data

Example training:

big_vision.trainers.proj.image_text.contrastive \
    --config big_vision/configs/proj/clippo/train_clippo.py \
    --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`
"""

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
from big_vision.configs.proj.image_text import common
from ml_collections import ConfigDict


def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg, res=224, runlocal=False, variant='B/16',
      test_with_coco=True, i1k_eval=False)
  config = ConfigDict()
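  # The defaults above can be overridden via the --config suffix, e.g.
  # big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True
  # (see the README for the full launch command).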

  config.input = {}
  if arg.test_with_coco:
    # Use COCO Captions for sanity-checking
    config.input.data = dict(name='coco_captions', split='train')
    val_data = dict(config.input.data)
    val_data['split'] = 'val'
    config.input.batch_size = 4000 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
    config.total_steps = 400 if not arg.runlocal else 10
  else:
    # Please add your favorite image/alt-text dataset here
    config.input.data = None
    val_data = None
    assert config.input.data is not None and val_data is not None, (
        config.input.data, val_data)

    # The value in the paper is 10 * 1024, which requires 128 TPUv3 cores or a
    # memory optimized ViT implementation when running on 128 TPUv2 cores.
    config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
    config.total_steps = 100_000 if not arg.runlocal else 10
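
  # The assert above fails on purpose until a real dataset is plugged in,
  # e.g. (hypothetical TFDS builder name; see the README for a wrapper sketch):
  #   config.input.data = dict(name='my_alt_text', split='train')
  #   val_data = dict(name='my_alt_text', split='val')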

  def tokenizer(inkey, outkey='labels'):
    return (f'render_unifont('
            f'inkey="{inkey}", '
            f'outkey="{outkey}", '
            f'image_size={arg.res}, '
            f'lower=True, '
            f'font_size=16, '
            f'text_brightness=0, '
            f'background_brightness=127)|'
            f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')
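  # Note that tokenizer() emits no token ids: it renders the text into an
  # image (using the Unifont bitmaps downloaded in the setup step), so a
  # single shared tower can encode images and text alike.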

  pp_image = f'decode|resize({arg.res})|value_range(-1,1)'
  if arg.test_with_coco:
    # Train with augmentation when sanity-checking
    pp_image_aug = (
        f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
    config.input.pp = pp_eval = (
        f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
        f'keep("image", "labels")')
  else:
    config.input.pp = pp_eval = (
        f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')

  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000

  config.loss_use_global_batch = True

  # Define the model
  config.model_name = 'proj.clippo.one_tower'
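  # 'one_tower' encodes both images and rendered text with one shared ViT, in
  # contrast to the separate image and text towers of standard CLIP/LiT models.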

  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.image = ConfigDict({
      'variant': arg.variant,
      'pool_type': 'map',
      'head_zeroinit': False,
  })

  if arg.test_with_coco:
    # Initialize with ImageNet21k pretrained checkpoint for sanity-checking
    assert arg.variant == 'B/16', arg.variant
    config.model_init = {'image': 'howto-i21k-B/16'}
    config.model_load = {}
    config.model_load['img_load_kw'] = {
        'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}

  config.model.temperature_init = 10.0
  config.model.out_dim = 768
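  # temperature_init sets the initial value of the learnable contrastive-loss
  # temperature; out_dim is the dimension of the shared embedding space.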

  # Define the optimizer
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0

  if arg.test_with_coco:
    # Short schedule for sanity-checking
    config.lr = 0.0001
    config.wd = 0.0003
    config.schedule = dict(decay_type='rsqrt',
                           timescale=100,
                           warmup_steps=100 if not arg.runlocal else 5,
                           cooldown_steps=100 if not arg.runlocal else 5)
  else:
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = dict(decay_type='rsqrt',
                           timescale=10_000,
                           warmup_steps=10_000 if not arg.runlocal else 5,
                           cooldown_steps=10_000 if not arg.runlocal else 5)
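  # 'rsqrt' denotes reciprocal square-root decay of the learning rate; in both
  # branches, warmup, decay timescale, and cooldown share the same length.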

  # Eval section (Both few-shot and zero-shot)
  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=config.loss_use_global_batch,
      log_steps=1000 if not arg.runlocal else 5,
  )
  config.evals = {}
  sub = '[:4]' if arg.runlocal else ''
  config.evals.val = {
      **eval_common,
      'data': val_data,
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'{pp_image}|flatten|{tokenizer("captions/text")}|'
          f'keep("image", "labels")'),
  }
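  # 'val' and 'coco' report the contrastive loss on held-out data; retrieval
  # and zero-shot classification evaluators are configured below.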

  if arg.i1k_eval:
    # Requires manual download, see
    # https://github.com/google-research/big_vision#preparing-tfds-data
    config.evals.imagenet = {
        **eval_common,
        'data': dict(name='imagenet2012', split=f'validation{sub}'),
        'pp_fn': (
            f'{pp_image}|clip_i1k_label_names|'
            f'{tokenizer("labels")}|keep("image", "labels")'),
    }
    config.evals.disclf = dict(
        type='proj.image_text.discriminative_classifier',
        pp_txt=tokenizer('texts', 'labels'),
        prefix='z/0shot/',
        log_steps=5_000 if not arg.runlocal else 5)

  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=5_000 if not arg.runlocal else 5,
  )

  # Few-shot metrics
  config.evals.fewshot = get_fewshot_lsr()
  config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
  config.evals.fewshot.representation_layer = 'img/pre_logits'

  config.seed = 0

  return config
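
As a quick smoke test of the config logic (not part of the commit; assumes `bvcc.parse_arg` accepts the same `key=value` suffix string shown in the docstring), the config can be constructed directly in Python:

```python
# Hypothetical smoke test: build the config with the runlocal overrides.
from big_vision.configs.proj.clippo import train_clippo

cfg = train_clippo.get_config('test_with_coco=True,runlocal=True')
print(cfg.input.batch_size)  # 32 when runlocal=True
print(cfg.total_steps)       # 10 when runlocal=True
```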