Summary: This PR includes a few changes to the AV-ASR recipe aimed at improving its usability: better results, a faster face detector (MediaPipe), renamed variables, a streamlined dataloader, and a few illustrated examples.

Pull Request resolved: #3493
Reviewed By: mthrok
Differential Revision: D47758072
Pulled By: mpc001
fbshipit-source-id: 4533587776f3a7a74f3f11b0ece773a0934bacdc
1 parent 56e2266 · commit d464479
Showing 19 changed files with 496 additions and 304 deletions.
@@ -1,70 +1,75 @@
<p align="center"><img width="160" src="doc/lip_white.png" alt="logo"></p>
<h1 align="center">RNN-T ASR/VSR/AV-ASR Examples</h1>
<p align="center"><img width="160" src="https://download.pytorch.org/torchaudio/doc-assets/avsr/lip_white.png" alt="logo"></p>
<h1 align="center">Real-time ASR/VSR/AV-ASR Examples</h1>

This repository contains sample implementations of training and evaluation pipelines for RNN-T-based automatic, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models on LRS3, covering both streaming and non-streaming modes. We follow the same training pipeline as [AutoAVSR](https://arxiv.org/abs/2303.14307).
<div align="center">

[📘Introduction](#introduction) |
[📊Training](#Training) |
[🔮Evaluation](#Evaluation)
</div>

## Introduction

This directory contains the training recipe for real-time audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models, which is an extension of [Auto-AVSR](https://arxiv.org/abs/2303.14307).

Please refer to [this tutorial]() for real-time AV-ASR inference from microphone and camera.

## Preparation

1. Set up the environment.

```
conda create -y -n autoavsr python=3.8
conda activate autoavsr
```

2. Install the PyTorch nightly version (PyTorch, Torchvision, Torchaudio) from [source](https://pytorch.org/get-started/), along with all necessary packages:
1. Install PyTorch (pytorch, torchvision, torchaudio) from [source](https://pytorch.org/get-started/), along with all necessary packages:

```Shell
pip install pytorch-lightning sentencepiece
pip install torch torchvision torchaudio pytorch-lightning sentencepiece
```

3. Preprocess LRS3 into a cropped-face dataset using the [data_prep](./data_prep) folder.
2. Preprocess LRS3. See the instructions in the [data_prep](./data_prep) folder.

4. `[sp_model_path]` is a SentencePiece model used to encode targets, which can be generated using `train_spm.py` (see the sketch below).

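The exact options of `train_spm.py` are not shown in this diff, so for orientation only, here is a rough sketch of producing an equivalent model with the stock `spm_train` CLI from the sentencepiece package. The transcript file name, unigram model type, and vocabulary size of 1023 are assumptions inferred from the default model name `spm_unigram_1023.model`:

```Shell
# Assumed: lrs3_transcripts.txt is a plain-text file with one target transcript per line.
spm_train --input=lrs3_transcripts.txt \
          --model_prefix=spm_unigram_1023 \
          --model_type=unigram \
          --vocab_size=1023 \
          --character_coverage=1.0
```

This writes `spm_unigram_1023.model` (and a matching `.vocab` file) in the current directory, which can then be passed via `--sp-model-path`.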
### Training ASR or VSR model

- `[root_dir]` is the root directory for the LRS3 cropped-face dataset.
- `[modality]` is the input modality type, including `v`, `a`, and `av`.
- `[mode]` is the model type, including `online` and `offline`.
## Usage

### Training

```Shell
python train.py --root-dir [root_dir] \
    --sp-model-path ./spm_unigram_1023.model \
    --exp-dir ./exp \
    --num-nodes 8 \
    --gpus 8 \
    --md [modality] \
    --mode [mode]
python train.py --exp-dir=[exp_dir] \
    --exp-name=[exp_name] \
    --modality=[modality] \
    --mode=[mode] \
    --root-dir=[root-dir] \
    --sp-model-path=[sp_model_path] \
    --num-nodes=[num_nodes] \
    --gpus=[gpus]
```

### Training AV-ASR model
- `exp-dir` and `exp-name`: Checkpoints are saved under `[exp_dir]/[exp_name]`.
- `modality`: Type of input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Type of mode. Valid values are: `online` and `offline`.
- `root-dir`: Path to the root directory where all preprocessed files are stored.
- `sp-model-path`: Path to the SentencePiece model. Default: `./spm_unigram_1023.model`, which can be produced using `train_spm.py`.
- `num-nodes`: Number of machines used. Default: 4.
- `gpus`: Number of GPUs in each machine. Default: 8.

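For concreteness, a hypothetical single-machine run (the experiment name and all paths below are placeholders) could look like this:

```Shell
python train.py --exp-dir=./exp \
                --exp-name=avsr_online_example \
                --modality=audiovisual \
                --mode=online \
                --root-dir=/path/to/lrs3_preprocessed \
                --sp-model-path=./spm_unigram_1023.model \
                --num-nodes=1 \
                --gpus=8
```

Checkpoints from this run would then be written under `./exp/avsr_online_example`.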
### Evaluation

```Shell
python train.py --root-dir [root-dir] \
    --sp-model-path ./spm_unigram_1023.model \
    --exp-dir ./exp \
    --num-nodes 8 \
    --gpus 8 \
    --md av \
    --mode [mode]
python eval.py --modality=[modality] \
    --mode=[mode] \
    --root-dir=[dataset_path] \
    --sp-model-path=[sp_model_path] \
    --checkpoint-path=[checkpoint_path]
```

### Evaluating models
- `modality`: Type of input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Type of mode. Valid values are: `online` and `offline`.
- `root-dir`: Path to the root directory where all preprocessed files are stored.
- `sp-model-path`: Path to the SentencePiece model. Default: `./spm_unigram_1023.model`.
- `checkpoint-path`: Path to a pretrained model.

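Likewise, a filled-in evaluation command (placeholder paths again) might be:

```Shell
python eval.py --modality=audiovisual \
               --mode=online \
               --root-dir=/path/to/lrs3_preprocessed \
               --sp-model-path=./spm_unigram_1023.model \
               --checkpoint-path=/path/to/checkpoint.ckpt
```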
```Shell
python eval.py --dataset-path [dataset_path] \
    --sp-model-path ./spm_unigram_1023.model \
    --md [modality] \
    --mode [mode] \
    --checkpoint-path [checkpoint_path]
```
## Results

The table below contains WER for AV-ASR models [offline evaluation].
The table below contains WER for AV-ASR models that were trained from scratch [offline evaluation].

| Model                | WER [%] | Params (M) |
|:--------------------:|:-------:|:----------:|
| Non-streaming models |         |            |
| AV-ASR               | 4.0     | 50         |
| Streaming models     |         |            |
| AV-ASR               | 4.3     | 40         |

| Model                | Training dataset (hours) | WER [%] | Params (M) |
|:--------------------:|:------------------------:|:-------:|:----------:|
| Non-streaming models |                          |         |            |
| AV-ASR               | LRS3 (438)               | 3.9     | 50         |
| Streaming models     |                          |         |            |
| AV-ASR               | LRS3 (438)               | 3.9     | 40         |
@@ -1,48 +1,72 @@

# Preprocessing LRS3
# Pre-process LRS3

We provide a pre-processing pipeline to detect and crop full-face images in this repository.
We provide a pre-processing pipeline in this repository for detecting and cropping full-face regions of interest (ROIs), as well as the corresponding audio waveforms, for LRS3.

## Prerequisites
## Introduction

Install all dependency packages.
Before feeding the raw stream into our model, each video sequence has to undergo a specific pre-processing procedure involving three critical steps. The first step is face detection. Next, each individual frame is aligned to a reference frame, commonly known as the mean face, in order to normalize rotation and size differences across frames. The final step is to crop the face region from the aligned face image.

<div align="center">

<table style="display: inline-table;">
<tr><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/original.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/detected.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/cropped.gif" width="144"></td></tr>
<tr><td>0. Original</td> <td>1. Detection</td> <td>2. Transformation</td> <td>3. Face ROIs</td></tr>
</table>
</div>

## Preparation

1. Install all dependency packages.

```Shell
pip install -r requirements.txt
```

Install the [RetinaFace](./tools) tracker.
2. Install the [retinaface](./tools) or [mediapipe](https://pypi.org/project/mediapipe/) tracker. If you have already installed a tracker, you can skip this step.

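For example, the mediapipe tracker can be installed straight from PyPI (no specific version is pinned here; pin one if you need reproducibility):

```Shell
pip install mediapipe
```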
## Preprocessing LRS3

## Preprocessing
To pre-process the LRS3 dataset, please follow these steps:

### Step 1. Pre-process the LRS3 dataset.
Please run the following script to pre-process the LRS3 dataset:
1. Download the LRS3 dataset from the official website.

2. Run the following command to preprocess the dataset:

```Shell
python main.py \
python preprocess_lrs3.py \
    --data-dir=[data_dir] \
    --detector=[detector] \
    --dataset=[dataset] \
    --root=[root] \
    --folder=[folder] \
    --groups=[num_groups] \
    --job-index=[job_index]
    --root-dir=[root] \
    --subset=[subset] \
    --seg-duration=[seg_duration] \
    --groups=[n] \
    --job-index=[j]
```

- `[data_dir]` and `[landmarks_dir]` are the directories for the original dataset and the corresponding landmarks.

- `[root]` is the directory for the saved cropped-face dataset.

- `[folder]` can be set to `train` or `test`.
- `data-dir`: Path to the directory containing video files.
- `detector`: Type of face detector. Valid values are: `mediapipe` and `retinaface`. Default: `retinaface`.
- `dataset`: Name of the dataset. Valid value: `lrs3`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `subset`: Name of the subset. Valid values are: `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
- `job-index`: Job index for the current group. Valid values are integers in the range `[0, n)`.

- `[num_groups]` and `[job-index]` are used to split the dataset across multiple jobs, where `[job-index]` is an integer in `[0, [num_groups])`.

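As an illustration, a hypothetical run that splits the train subset into 8 groups and processes the first group with the mediapipe detector could look like this (all paths are placeholders):

```Shell
python preprocess_lrs3.py --data-dir=/path/to/LRS3 \
                          --detector=mediapipe \
                          --dataset=lrs3 \
                          --root-dir=/path/to/lrs3_preprocessed \
                          --subset=train \
                          --seg-duration=16 \
                          --groups=8 \
                          --job-index=0
```

The same command would be launched once per `job-index` from `0` to `7`, either sequentially or as parallel jobs.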
### Step 2. Merge the label list.
After completing Step 1, run the following script to merge all labels.
3. Run the following command to merge all labels:

```Shell
python merge.py \
    --root-dir=[root_dir] \
    --dataset=[dataset] \
    --root=[root] \
    --folder=[folder] \
    --groups=[num_groups] \
    --subset=[subset] \
    --seg-duration=[seg_duration] \
    --groups=[n]
```

- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `dataset`: Name of the dataset. Valid values are: `lrs2` and `lrs3`.
- `subset`: Name of the subset. For LRS2, valid values are `train`, `val`, and `test`. For LRS3, valid values are `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups the dataset was split into.
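Continuing the hypothetical example above, the corresponding merge step would be:

```Shell
python merge.py --root-dir=/path/to/lrs3_preprocessed \
                --dataset=lrs3 \
                --subset=train \
                --seg-duration=16 \
                --groups=8
```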
@@ -0,0 +1,52 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import warnings

import mediapipe as mp
import numpy as np
import torchvision

warnings.filterwarnings("ignore")


class LandmarksDetector:
    def __init__(self):
        self.mp_face_detection = mp.solutions.face_detection
        # model_selection=0 is MediaPipe's short-range model (faces within ~2 metres);
        # model_selection=1 is the full-range model for faces further from the camera.
        self.short_range_detector = self.mp_face_detection.FaceDetection(
            min_detection_confidence=0.5, model_selection=0
        )
        self.full_range_detector = self.mp_face_detection.FaceDetection(
            min_detection_confidence=0.5, model_selection=1
        )

    def __call__(self, filename):
        # Try the full-range detector first; fall back to the short-range detector
        # if no face is found in any frame of the video.
        landmarks = self.detect(filename, self.full_range_detector)
        if all(element is None for element in landmarks):
            landmarks = self.detect(filename, self.short_range_detector)
            assert any(l is not None for l in landmarks), "Cannot detect any frames in the video"
        return landmarks

    def detect(self, filename, detector):
        # Decode the video into a (T, H, W, C) uint8 array of frames.
        video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
        landmarks = []
        for frame in video_frames:
            results = detector.process(frame)
            if not results.detections:
                landmarks.append(None)
                continue
            face_points = []
            max_id, max_size = 0, 0  # index and size of the largest detected face in this frame
            for idx, detected_faces in enumerate(results.detections):
                bboxC = detected_faces.location_data.relative_bounding_box
                ih, iw, ic = frame.shape
                # Convert the relative bounding box to pixel coordinates: (xmin, ymin, width, height).
                bbox = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
                bbox_size = bbox[2] + bbox[3]  # width + height, used to pick the largest face
                if bbox_size > max_size:
                    max_id, max_size = idx, bbox_size
                # Store the top-left corner and the (width, height) of the bounding box.
                lmx = [[int(bboxC.xmin * iw), int(bboxC.ymin * ih)], [int(bboxC.width * iw), int(bboxC.height * ih)]]
                face_points.append(lmx)
            landmarks.append(np.reshape(np.array(face_points[max_id]), (2, 2)))
        return landmarks