Update avsr recipe (#3493)
Summary:
This PR includes a few changes to the AV-ASR recipe: improved results, a faster face detector (MediaPipe), renamed variables, a streamlined dataloader, and a few illustrated examples. These changes improve the usability of the recipe.

Pull Request resolved: #3493

Reviewed By: mthrok

Differential Revision: D47758072

Pulled By: mpc001

fbshipit-source-id: 4533587776f3a7a74f3f11b0ece773a0934bacdc
Pingchuan Ma authored and facebook-github-bot committed Jul 25, 2023
1 parent 56e2266 commit d464479
Showing 19 changed files with 496 additions and 304 deletions.
103 changes: 54 additions & 49 deletions examples/avsr/README.md
@@ -1,70 +1,75 @@
<p align="center"><img width="160" src="doc/lip_white.png" alt="logo"></p>
<h1 align="center">RNN-T ASR/VSR/AV-ASR Examples</h1>
<p align="center"><img width="160" src="https://download.pytorch.org/torchaudio/doc-assets/avsr/lip_white.png" alt="logo"></p>
<h1 align="center">Real-time ASR/VSR/AV-ASR Examples</h1>

This repository contains sample implementations of training and evaluation pipelines for RNNT based automatic, visual, and audio-visual (ASR, VSR, AV-ASR) models on LRS3. This repository includes both streaming/non-streaming modes. We follow the same training pipeline as [AutoAVSR](https://arxiv.org/abs/2303.14307).
<div align="center">

[📘Introduction](#introduction) |
[📊Training](#Training) |
[🔮Evaluation](#Evaluation)
</div>

## Introduction

This directory contains the training recipe for real-time audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models, which is an extension of [Auto-AVSR](https://arxiv.org/abs/2303.14307).

Please refer to [this tutorial]() for real-time AV-ASR inference from microphone and camera.

## Preparation
1. Install PyTorch (pytorch, torchvision, torchaudio) from [source](https://pytorch.org/get-started/), along with all necessary packages:

```Shell
pip install torch torchvision torchaudio pytorch-lightning sentencepiece
```

2. Preprocess LRS3. See the instructions in the [data_prep](./data_prep) folder.

## Usage

### Training

```Shell

python train.py --exp-dir=[exp_dir] \
--exp-name=[exp_name] \
--modality=[modality] \
--mode=[mode] \
--root-dir=[root-dir] \
--sp-model-path=[sp_model_path] \
--num-nodes=[num_nodes] \
--gpus=[gpus]
```

- `exp-dir` and `exp-name`: Checkpoints will be stored at `[exp_dir]/[exp_name]`.
- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Type of the mode. Valid values are: `online` (streaming) and `offline` (non-streaming).
- `root-dir`: Path to the root directory where all preprocessed files are stored.
- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`, which can be produced using `train_spm.py`.
- `num-nodes`: Number of machines used. Default: 4.
- `gpus`: Number of GPUs per machine. Default: 8.
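
For example, a single-machine run that trains a streaming audio-visual model might look like the sketch below; the experiment name, data path, and node/GPU counts are placeholders, not values prescribed by the recipe.

```Shell
python train.py --exp-dir=./exp \
--exp-name=avsr_online \
--modality=audiovisual \
--mode=online \
--root-dir=/path/to/lrs3_preprocessed \
--sp-model-path=./spm_unigram_1023.model \
--num-nodes=1 \
--gpus=8
```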

### Evaluation

```Shell
python eval.py --modality=[modality] \
--mode=[mode] \
--root-dir=[dataset_path] \
--sp-model-path=[sp_model_path] \
--checkpoint-path=[checkpoint_path]
```

- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Type of the mode. Valid values are: `online` (streaming) and `offline` (non-streaming).
- `root-dir`: Path to the root directory where all preprocessed files are stored.
- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`.
- `checkpoint-path`: Path to a pretrained model.
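
For example, to evaluate a streaming audio-visual checkpoint (the paths below are placeholders; `model_avg_10.pth` is the averaged checkpoint produced by `average_checkpoints.py`):

```Shell
python eval.py --modality=audiovisual \
--mode=online \
--root-dir=/path/to/lrs3_preprocessed \
--sp-model-path=./spm_unigram_1023.model \
--checkpoint-path=./exp/avsr_online/model_avg_10.pth
```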

## Results

The table below contains the WER for AV-ASR models trained from scratch [offline evaluation].

| Model | Training dataset (hours) | WER [%] | Params (M) |
|:--------------------:|:------------------------:|:-------:|:----------:|
| Non-streaming models | | | |
| AV-ASR | LRS3 (438) | 3.9 | 50 |
| Streaming models | | | |
| AV-ASR | LRS3 (438) | 3.9 | 40 |
7 changes: 2 additions & 5 deletions examples/avsr/average_checkpoints.py
@@ -23,9 +23,6 @@ def average_checkpoints(last):


def ensemble(args):
    # Average the weights of the last 10 epoch checkpoints and save the result.
    last = [os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt") for n in range(args.epochs - 10, args.epochs)]
    model_path = os.path.join(args.exp_dir, args.exp_name, "model_avg_10.pth")
    torch.save({"state_dict": average_checkpoints(last)}, model_path)
45 changes: 6 additions & 39 deletions examples/avsr/data_module.py
@@ -110,52 +110,19 @@ def __init__(
self.num_workers = num_workers

    def train_dataloader(self):
        dataset = LRS3(self.args, subset="train")
        # Bucket samples by length so that each batch respects the max_frames limit.
        dataset = CustomBucketDataset(
            dataset, dataset.lengths, self.max_frames, self.train_num_buckets, batch_size=self.batch_size
        )
        dataset = TransformDataset(dataset, self.train_transform)
        dataloader = torch.utils.data.DataLoader(
            dataset, num_workers=self.num_workers, batch_size=None, shuffle=self.train_shuffle
        )
        return dataloader

    def val_dataloader(self):
        dataset = LRS3(self.args, subset="val")
        dataset = CustomBucketDataset(dataset, dataset.lengths, self.max_frames, 1, batch_size=self.batch_size)
        dataset = TransformDataset(dataset, self.val_transform)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
        return dataloader
74 changes: 49 additions & 25 deletions examples/avsr/data_prep/README.md
@@ -1,48 +1,72 @@

# Pre-process LRS3

We provide a pre-processing pipeline in this repository for detecting and cropping full-face regions of interest (ROIs) as well as the corresponding audio waveforms for LRS3.

## Introduction

Before being fed to the model, each video sequence has to undergo a specific pre-processing procedure, which involves three steps. The first step is to perform face detection. Next, each individual frame is aligned to a reference frame, commonly known as the mean face, in order to normalize rotation and size differences across frames. The final step is to crop the face region from the aligned face image.

<div align="center">

<table style="display: inline-table;">
<tr><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/original.gif", width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/detected.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/cropped.gif" width="144"></td></tr>
<tr><td>0. Original</td> <td>1. Detection</td> <td>2. Transformation</td> <td>3. Face ROIs</td> </tr>
</table>
</div>

## Preparation

1. Install all dependency packages.

```Shell
pip install -r requirements.txt
```

2. Install the [retinaface](./tools) or [mediapipe](https://pypi.org/project/mediapipe/) tracker. If you have already installed a tracker, you can skip this step.
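
For the MediaPipe tracker, a plain pip install is typically sufficient, as in the example below; for the RetinaFace tracker, follow the instructions in [tools](./tools).

```Shell
pip install mediapipe
```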

## Preprocessing LRS3

To pre-process the LRS3 dataset, please follow these steps:
1. Download the LRS3 dataset from the official website.

2. Run the following command to preprocess the dataset:

```Shell
python preprocess_lrs3.py \
--data-dir=[data_dir] \
--detector=[detector] \
--dataset=[dataset] \
--root-dir=[root] \
--subset=[subset] \
--seg-duration=[seg_duration] \
--groups=[n] \
--job-index=[j]
```

- `data-dir`: Path to the directory containing video files.
- `detector`: Type of face detector. Valid values are: `mediapipe` and `retinaface`. Default: `retinaface`.
- `dataset`: Name of the dataset. Valid value is: `lrs3`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `subset`: Name of the subset. Valid values are: `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
- `job-index`: Job index for the current group. Valid values are integers in the range `[0, n)`.
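
As an illustration, the following hypothetical command pre-processes the `train` subset on a single machine using the MediaPipe detector (all paths are placeholders):

```Shell
python preprocess_lrs3.py \
--data-dir=/path/to/LRS3 \
--detector=mediapipe \
--dataset=lrs3 \
--root-dir=/path/to/lrs3_preprocessed \
--subset=train \
--seg-duration=16 \
--groups=1 \
--job-index=0
```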

3. Run the following command to merge all labels:

```Shell
python merge.py \
--root-dir=[root_dir] \
--dataset=[dataset] \
--subset=[subset] \
--seg-duration=[seg_duration] \
--groups=[n]
```

- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `dataset`: Name of the dataset. Valid values are: `lrs2` and `lrs3`.
- `subset`: The subset name of the dataset. For LRS2, valid values are `train`, `val`, and `test`. For LRS3, valid values are `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
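
As a sketch of the full workflow, the commands below split the `train` subset into four groups, pre-process them in parallel, and then merge the resulting labels; the paths and group count are placeholders.

```Shell
for i in 0 1 2 3; do
  python preprocess_lrs3.py --data-dir=/path/to/LRS3 --detector=mediapipe --dataset=lrs3 \
    --root-dir=/path/to/lrs3_preprocessed --subset=train --seg-duration=16 --groups=4 --job-index=$i &
done
wait

python merge.py --root-dir=/path/to/lrs3_preprocessed --dataset=lrs3 --subset=train --seg-duration=16 --groups=4
```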
6 changes: 6 additions & 0 deletions examples/avsr/data_prep/data/data_module.py
@@ -19,6 +19,12 @@ def __init__(self, modality, detector="retinaface", resize=None):

self.landmarks_detector = LandmarksDetector(device="cuda:0")
self.video_process = VideoProcess(resize=resize)
if detector == "mediapipe":
from detectors.mediapipe.detector import LandmarksDetector
from detectors.mediapipe.video_process import VideoProcess

self.landmarks_detector = LandmarksDetector()
self.video_process = VideoProcess(resize=resize)

def load_data(self, data_filename, transform=True):
if self.modality == "audio":
52 changes: 52 additions & 0 deletions examples/avsr/data_prep/detectors/mediapipe/detector.py
@@ -0,0 +1,52 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import warnings

import mediapipe as mp

import numpy as np
import torchvision

warnings.filterwarnings("ignore")


class LandmarksDetector:
    def __init__(self):
        self.mp_face_detection = mp.solutions.face_detection
        self.short_range_detector = self.mp_face_detection.FaceDetection(
            min_detection_confidence=0.5, model_selection=0
        )
        self.full_range_detector = self.mp_face_detection.FaceDetection(min_detection_confidence=0.5, model_selection=1)

    def __call__(self, video_frames):
        # Run the full-range model first; fall back to the short-range model if no face is found in any frame.
        landmarks = self.detect(video_frames, self.full_range_detector)
        if all(element is None for element in landmarks):
            landmarks = self.detect(video_frames, self.short_range_detector)
            assert any(l is not None for l in landmarks), "Cannot detect any frames in the video"
        return landmarks

    def detect(self, filename, detector):
        video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
        landmarks = []
        for frame in video_frames:
            results = detector.process(frame)
            if not results.detections:
                landmarks.append(None)
                continue
            face_points = []
            # Convert each detection's relative bounding box to pixel coordinates as [[x_min, y_min], [width, height]].
            for idx, detected_faces in enumerate(results.detections):
                max_id, max_size = 0, 0
                bboxC = detected_faces.location_data.relative_bounding_box
                ih, iw, ic = frame.shape
                bbox = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
                bbox_size = (bbox[2] - bbox[0]) + (bbox[3] - bbox[1])
                if bbox_size > max_size:
                    max_id, max_size = idx, bbox_size
                lmx = [[int(bboxC.xmin * iw), int(bboxC.ymin * ih)], [int(bboxC.width * iw), int(bboxC.height * ih)]]
                face_points.append(lmx)
            landmarks.append(np.reshape(np.array(face_points[max_id]), (2, 2)))
        return landmarks