Release Audiovisual Masked Autoencoders, ICCV 2023
PiperOrigin-RevId: 570066803
anuragarnab authored and Scenic Authors committed Oct 2, 2023
1 parent de635e1 commit 67bde34
Showing 31 changed files with 11,208 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -67,6 +67,7 @@ Projects that were developed in Scenic or used it for their experiments:
* [Verbs in Action: Improving verb understanding in video-language models](https://arxiv.org/abs/2304.06708)
* [Unified Visual Relationship Detection with Vision and Language Models](https://arxiv.org/abs/2303.08998)
* [REVEAL: Retrieval-Augmented Visual-Language Pre-Training with Multi-Source Multimodal Knowledge Memory](https://arxiv.org/abs/2212.05221)
* [Audiovisual Masked Autoencoders](https://arxiv.org/abs/2212.05922)

More information can be found in [projects](https://github.com/google-research/scenic/tree/main/scenic/projects#list-of-projects-hosted-in-scenic).

@@ -85,6 +86,7 @@ Baselines that were reproduced in Scenic:
* [PCT: Point Cloud Transformer](https://arxiv.org/abs/2012.09688)
* [Universal Transformers](https://arxiv.org/abs/1807.03819)
* [PonderNet](https://arxiv.org/abs/2107.05407)
* [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)


More information can be found in [baseline models](https://github.com/google-research/scenic/tree/main/scenic/projects/baselines#scenic-baseline-models).
7 changes: 7 additions & 0 deletions scenic/projects/README.md
@@ -23,6 +23,13 @@
> trained end-to-end from spectrograms and full-frame RGB for the task of
> audiovisual speech recognition (AV-ASR).
* [Audiovisual Masked Autoencoders](av_mae)

> Audiovisual Masked Autoencoders performs self-supervised learning on
> multiple modalities (audio and video) to improve representation learning
> for both unimodal and multimodal downstream tasks. Details can be found
> in the [paper](https://arxiv.org/abs/2212.05922).
* [ViViT](vivit)

> ViViT is a family of pure-transformer based models for video
60 changes: 60 additions & 0 deletions scenic/projects/av_mae/README.md
@@ -0,0 +1,60 @@
### Audiovisual Masked Autoencoders

This directory contains the JAX implementation of our ICCV 2023 paper,
[Audiovisual Masked Autoencoders](https://arxiv.org/abs/2212.05922).

Audiovisual Masked Autoencoders (AV-MAE) pretrains models jointly on video and
audio data, improving performance on both unimodal and multimodal downstream
tasks.

#### Getting Started

This project, like others in Scenic, uses [configuration files](configs).

To pretrain a model on AudioSet, run the following command:

```shell
$ python -m scenic.projects.av_mae.main \
--config=scenic/projects/av_mae/configs/audioset/pretrain.py \
--workdir=av_mae/
```

Then, to finetune this model, run:

```shell
$ python -m scenic.projects.av_mae.main \
--config=scenic/projects/av_mae/configs/audioset/finetune.py \
--workdir=av_mae/
```

When finetuning, make sure to set `config.init_from.checkpoint_path` to the
path of the pretrained checkpoint; one way to do this is sketched below.
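
The snippet below is a minimal sketch of setting this field directly in the
finetuning config. It assumes the config is an `ml_collections.ConfigDict` (as
is standard for Scenic configs) and that the field layout is
`config.init_from.checkpoint_path`; the checkpoint path itself is a
placeholder:

```python
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  """Sketch of a finetuning config pointing at a pretrained AV-MAE checkpoint."""
  config = ml_collections.ConfigDict()
  # ... other finetuning settings ...
  config.init_from = ml_collections.ConfigDict()
  config.init_from.checkpoint_path = '/path/to/pretrained/av_mae/checkpoint'
  return config
```

With `ml_collections` config flags, the same field can typically also be
overridden on the command line, e.g.
`--config.init_from.checkpoint_path=/path/to/checkpoint`.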

#### Model Zoo

The following table contains AV-MAE checkpoints trained on various datasets.
Checkpoints are provided as Scenic checkpoints compatible with
[Flax](https://github.com/google/flax).

| Dataset | Model size | Pretraining modalities | Pretrained model | Finetuning modalities | Finetuned model | mAP / Accuracy |
|----------|------------|------------------------|-------------------|-----------------------|-------------------|----------------|
| AudioSet | Large | audio, video | [config](configs/audioset/pretrain.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/audioset/as2m_selfsup_audiovisual/checkpoint) | audio, video | [config](configs/audioset/finetune.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/audioset/as2m_selfsup_audiovisual_finetuned_audiovisual/checkpoint) | 51.8 |
| | | | | audio | [config](configs/audioset/finetune.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/audioset/as2m_selfsup_audiovisual_finetuned_audio/checkpoint) | 46.6 |
| | | | | video | [config](configs/audioset/finetune.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/audioset/as2m_selfsup_audiovisual_finetuned_video/checkpoint) | 31.1 |
| | | audio | [config](configs/audioset/pretrain.py#L144) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/audioset/as2m_selfsup_audio/checkpoint) | audio | [config](configs/audioset/finetune.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/audioset/as2m_selfsup_audio_finetuned_audio/checkpoint) | 46.4 |
| VGGSound | Large | audio, video | [config](configs/vggsound/pretrain.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/vggsound/vggsound_selfsup_audiovisual/checkpoint) | audio, video | [config](configs/vggsound/finetune.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/vggsound/vggsound_selfsup_audiovisual_finetuned_audiovisual/checkpoint) | 65.0 |
| | | | | audio | [config](configs/vggsound/finetune.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/vggsound/vggsound_selfsup_audiovisual_finetuned_audio/checkpoint) | 57.2 |
| | | | | video | [config](configs/vggsound/finetune.py) [checkpoint](https://storage.googleapis.com/scenic-bucket/av_mae/vggsound/vggsound_selfsup_audiovisual_finetuned_video/checkpoint) | 50.3 |
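
As a quick sanity check, a downloaded checkpoint can be inspected with standard
Flax utilities. The snippet below is only a sketch: it assumes the file is an
ordinary Flax/Scenic checkpoint, and the local path is a placeholder.

```python
from flax.training import checkpoints

# Local copy of a checkpoint downloaded from the model zoo (placeholder path).
ckpt_path = '/tmp/av_mae_audioset_pretrained/checkpoint'

# With target=None, restore_checkpoint returns the raw nested dict stored in
# the file, which can be inspected or used to initialise a model.
state = checkpoints.restore_checkpoint(ckpt_path, target=None)
print(list(state.keys()))
```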

#### Reference

If you use this project, please cite the following BibTeX entry:

```bibtex
@inproceedings{georgescu2023audiovisual,
title={Audiovisual Masked Autoencoders},
author={Georgescu, Mariana-Iuliana and Fonseca, Eduardo and Ionescu, Radu Tudor and Lucic, Mario and Schmid, Cordelia and Arnab, Anurag},
booktitle={International Conference on Computer Vision (ICCV)},
year={2023}
}
```
304 changes: 304 additions & 0 deletions scenic/projects/av_mae/base_model.py
@@ -0,0 +1,304 @@
"""Base model definition."""

import functools
from typing import Dict, Optional, Tuple, Union

from immutabledict import immutabledict
import jax
import jax.numpy as jnp
import numpy as np

from scenic.model_lib.base_models import base_model
from scenic.model_lib.base_models import model_utils
from scenic.model_lib.base_models import regression_model
from scenic.model_lib.layers import nn_ops


# TODO(aarnab): Compute validation metrics.
_REGRESSION_METRICS = immutabledict({
'mean_squared_error':
(functools.partial(model_utils.weighted_squared_error, axis=-1),
model_utils.num_examples)
})

Patch = Union[Tuple[int, int], Tuple[int, int, int]]


class FeatureTargets():
RGB = 'rgb'
SPECTROGRAM = 'spectrogram'


def get_output_shapes(feature_target: str,
patch_size: Patch,
select_central_frame: Optional[bool] = None,
channels: int = 3):
"""Returns the output shape, depending on the feature regression target."""
if feature_target == FeatureTargets.RGB:
if len(patch_size) == 3 and select_central_frame:
output_elements = patch_size[:2] + (1, channels)
else:
output_elements = patch_size + (channels,)
return np.prod(np.array(output_elements))
elif feature_target == FeatureTargets.SPECTROGRAM:
output_elements = patch_size + (channels,)
return np.prod(np.array(output_elements))
else:
raise NotImplementedError('Other feature targets not implemented yet.')


def extract_tubelets_from_video(
x: jnp.ndarray,
tubelet_size: Tuple[int, int, int],
select_central_frame: bool) -> jnp.ndarray:
"""Extracts tubelets from videos for use as regression targets.
Args:
x: Input tensor of shape [batch, time, height, width, channels]
tubelet_size: Tuple containing tubelet/patch size parameterised as
[ph, pw, pt].
select_central_frame: If True, select the central frame as the feature
regression target.
Returns:
Tensor of shape [n, gt * gh * gw, pt * ph * pw * c] where
gt = t // pt, gh = h // ph, gw = w // pw.
"""
ph, pw, pt = tubelet_size
n, t, h, w, c = x.shape
gt, gh, gw = t // pt, h // ph, w // pw
x = x.reshape([n, gt, pt, gh, ph, gw, pw, c])
# Shape is then [n, gt, gh, gw, pt, ph, pw, c].
x = jnp.transpose(x, axes=[0, 1, 3, 5, 2, 4, 6, 7])
if select_central_frame:
x = x[:, :, :, :, pt // 2, :, :, :]
pt = 1
return x.reshape([n, gt * gh * gw, pt * ph * pw * c])
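
# Example (sketch): for a batch of 2 videos with 4 frames at 8x8 resolution and
# 3 (RGB) channels, tubelets of size (ph, pw, pt) = (2, 2, 2) give
# gt * gh * gw = 2 * 4 * 4 = 32 tokens, each of dimension
# pt * ph * pw * c = 2 * 2 * 2 * 3 = 24 (or 12 if only the central frame of
# each tubelet is kept as the regression target):
#   x = jnp.zeros([2, 4, 8, 8, 3])
#   extract_tubelets_from_video(x, (2, 2, 2), select_central_frame=False).shape
#   # -> (2, 32, 24)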


def get_rgb_targets(inputs: jnp.ndarray,
patch_size: Patch,
select_central_frame: Optional[bool] = None,
reconstruct_grayscale: bool = False,
standardise_per_patch: bool = False,
standardise_per_patch_channels: bool = False
) -> jnp.ndarray:
"""Get RGB targets to use for feature regression.
Here, the targets are the raw rgb patches of the image.
Args:
inputs: Tensor of shape [b, h, w, c] or [b, t, h, w, c]. The former are
images, and the later video.
patch_size: The shape of the patch, defined as [ph, pw] for images, and
[ph, pw, pt] for video.
select_central_frame: If video and True, select the central frame as the
feature regression target.
reconstruct_grayscale: If True, the target patch is in grayscale rather
than rgb.
standardise_per_patch: If true, standardise each patch by subtracting the
mean and dividing by the standard deviation of that patch.
standardise_per_patch_channels: If true, standardise each patch by
subtracting the mean and dividing by the standard deviation of that patch
per channels.
Returns:
Patched inputs. For images, shape is [b, gh * gw, ph * pw * c] where
gh = h // ph and gw = w // pw.
For video, shape is [b, gt * gh * gw, pt * ph * pw * c].
"""
if not (inputs.ndim == 4 or inputs.ndim == 5):
raise ValueError('Inputs should be 4D (images) or 5D (video).')

if reconstruct_grayscale:
# Reference for converting between RGB and grayscale.
# https://en.wikipedia.org/wiki/Luma_%28video%29
# Also used in tf.image.rgb_to_grayscale
rgb_weights = jnp.tile(jnp.array([[0.2989, 0.5870, 0.1140]]), (3, 1)).T
inputs = jnp.matmul(inputs, rgb_weights)

if inputs.ndim == 4:
batch = inputs.shape[0]
# Shape is [batch, ht, wt, hp, wp, c]
patched_image = nn_ops.patch_image(inputs, inputs_shape=None,
patch_size=patch_size)
num_tokens = patched_image.shape[1] * patched_image.shape[2]
patched_input = jnp.reshape(patched_image, (batch, num_tokens, -1))
elif inputs.ndim == 5:
if select_central_frame is None:
raise ValueError('`select_central_frame` must be defined.')
patched_input = extract_tubelets_from_video(
inputs,
patch_size,
select_central_frame)

if standardise_per_patch:
patched_input = jax.nn.standardize(patched_input, axis=-1, epsilon=1e-6)
elif standardise_per_patch_channels:
old_shape = patched_input.shape
batch, num_tokens = patched_input.shape[:2]
num_channels = inputs.shape[-1]
patched_input = jnp.reshape(patched_input,
(batch, num_tokens, -1, num_channels))
patched_input = jax.nn.standardize(patched_input, axis=2, epsilon=1e-6)
patched_input = jnp.reshape(patched_input, old_shape)

return patched_input


def get_spectogram_targets(inputs: jnp.ndarray,
patch_size: Patch,
standardise_per_patch: bool = False
) -> jnp.ndarray:
"""Get spectogram targets to use for feature regression.
Here, the targets are the raw spectogram patches of the image.
Args:
inputs: Tensor of shape [b, h, w, c].
patch_size: The shape of the patch, defined as [ph, pw].
standardise_per_patch: If true, standardise each patch by subtracting the
mean and dividing by the standard deviation of that patch.
Returns:
Patched inputs. Shape is [b, gh * gw, ph * pw * c].
"""
if inputs.ndim != 4:
raise ValueError('Inputs should be 4D.')

if inputs.ndim == 4:
batch = inputs.shape[0]
# Shape is [batch, ht, wt, hp, wp, c]
patched_image = nn_ops.patch_image(inputs, inputs_shape=None,
patch_size=patch_size)
num_tokens = patched_image.shape[1] * patched_image.shape[2]
patched_input = jnp.reshape(patched_image, (batch, num_tokens, -1))

if standardise_per_patch:
patched_input = jax.nn.standardize(patched_input, axis=-1, epsilon=1e-6)

return patched_input


def feature_regression_metrics_function(
predictions: jnp.ndarray,
prediction_masks: jnp.ndarray,
batch: base_model.Batch,
feature_target: str,
metrics: base_model.MetricNormalizerFnDict = _REGRESSION_METRICS,
) -> Dict[str, Tuple[float, int]]:
"""Calculate metrics for the feature regression task.
Currently we assume each metric_fn has the API:
```metric_fn(predictions, targets, weights)```
and returns an array of shape [batch,]. We also assume that to compute
the aggregate metric, one should sum across all batches, then divide by the
  total samples seen. In this way we currently only support metrics of the form
  (1/N) * sum_i f(inputs_i, targets_i). Note, the caller is responsible for dividing by
the normalizer when computing the mean of each metric.
Args:
predictions: Output of model in shape [batch, length].
prediction_masks: Which of the predictions are valid.
batch: Batch (dict) with keys 'targets' and optionally 'batch_mask'.
feature_target: The feature target used for feature regression.
metrics: The regression metrics to evaluate. The key is the
name of the metric, and the value is the metrics function.
Returns:
A dict of metrics, in which keys are metrics name and values are tuples of
(metric, normalizer).
"""
if feature_target == FeatureTargets.RGB:
targets = batch['target_rgb']
else:
raise NotImplementedError(
f'Feature target {feature_target} not implemented')

batch_mask = batch.get('batch_mask')
if batch_mask is None:
batch_mask = jnp.ones(prediction_masks.shape)
if batch_mask.ndim == 1:
n_batch = predictions.shape[0]
batch_mask = jnp.reshape(batch_mask, (n_batch, 1))
weights = batch_mask * prediction_masks

evaluated_metrics = {}
for key, val in metrics.items():
evaluated_metrics[key] = model_utils.psum_metric_normalizer(
(val[0](predictions, targets, weights), val[1](predictions, targets, # pytype: disable=wrong-arg-types # jax-ndarray
weights)))
return evaluated_metrics # pytype: disable=bad-return-type # jax-ndarray


class MaskedFeatureRegressionModel(regression_model.RegressionModel):
"""Defines commonalities between all masked self-supervised models."""

def loss_function(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
predictions: jnp.ndarray,
prediction_masks: jnp.ndarray,
batch: base_model.Batch,
model_params: Optional[jnp.ndarray] = None) -> float:
"""Returns the (weighted) mean squared error.
Args:
predictions: Output of model in shape [batch, num_tokens, channels].
prediction_masks: The tokens to compute the loss on. Shape is
[batch, num_tokens]
batch: Batch (dict) with keys 'targets' and optionally 'batch_mask'.
model_params: Parameters of the model, for optionally applying
L2 regularization.
Returns:
The (weighted) mean squared error.
"""
batch_mask = batch.get('batch_mask')
if batch_mask is None:
batch_mask = jnp.ones(prediction_masks.shape)
if batch_mask.ndim == 1:
batch_mask = jnp.expand_dims(batch_mask, axis=-1)
if self.config.masked_feature_loss.get('loss_unmasked_tokens', False):
loss_weights = batch_mask
else:
loss_weights = batch_mask * prediction_masks

feature_target = self.config.masked_feature_loss.target
if feature_target == FeatureTargets.RGB:
targets = batch[f'target_{feature_target}']
else:
raise NotImplementedError(
f'Feature target {feature_target} not implemented.')

total_loss = model_utils.weighted_mean_squared_error(
predictions, targets, loss_weights, axis=-1)

# Mean squared error is normalised by the number of tokens.
# If this option is enabled, we normalise further by the number of features
# we are regressing to.
if self.config.masked_feature_loss.get('normalise_by_output_dimension',
False):
output_dimension = predictions.shape[-1]
total_loss = total_loss / output_dimension

if self.config.get('l2_decay_factor'):
l2_loss = model_utils.l2_regularization(model_params)
total_loss += 0.5 * self.config.l2_decay_factor * l2_loss
return total_loss # pytype: disable=bad-return-type # jax-ndarray

def get_metrics_fn(self, split: Optional[str] = None) -> base_model.MetricFn:
"""Returns a callable metric function for the model.
By default, we return the same metric for each split.
Args:
split: The split for which we calculate the metrics. It should be one of
the ['train', 'validation', 'test'].
Returns: A metric function with the following API:
```metrics_fn(predictions, batch)```
"""

del split # Same function for all splits.
return functools.partial(
feature_regression_metrics_function,
feature_target=self.config.masked_feature_loss.target,
metrics=_REGRESSION_METRICS)