diff --git a/examples/avsr/README.md b/examples/avsr/README.md
index 88e26d042c..ea18c9175a 100644
--- a/examples/avsr/README.md
+++ b/examples/avsr/README.md
@@ -1,70 +1,75 @@
-
-RNN-T ASR/VSR/AV-ASR Examples
+
+# Real-time ASR/VSR/AV-ASR Examples
-This repository contains sample implementations of training and evaluation pipelines for RNNT based automatic, visual, and audio-visual (ASR, VSR, AV-ASR) models on LRS3. This repository includes both streaming/non-streaming modes. We follow the same training pipeline as [AutoAVSR](https://arxiv.org/abs/2303.14307).
+
+
+[📘Introduction](#introduction) |
+[📊Training](#training) |
+[🔮Evaluation](#evaluation)
+
+
+## Introduction
+
+This directory contains the training recipe for real-time audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models, which is an extension of [Auto-AVSR](https://arxiv.org/abs/2303.14307).
+
+Please refer to [this tutorial]() for real-time AV-ASR inference from microphone and camera.
## Preparation
-1. Setup the environment.
-```
-conda create -y -n autoavsr python=3.8
-conda activate autoavsr
-```
-2. Install PyTorch nightly version (Pytorch, Torchvision, Torchaudio) from [source](https://pytorch.org/get-started/), along with all necessary packages:
+1. Install PyTorch (torch, torchvision, torchaudio), following the [installation instructions](https://pytorch.org/get-started/), along with all other necessary packages:
```Shell
-pip install pytorch-lightning sentencepiece
+pip install torch torchvision torchaudio pytorch-lightning sentencepiece
```
-3. Preprocess LRS3 to a cropped-face dataset from the [data_prep](./data_prep) folder.
+2. Preprocess LRS3. See the instructions in the [data_prep](./data_prep) folder.
-4. `[sp_model_path]` is a sentencepiece model to encode targets, which can be generated using `train_spm.py`.
-
-### Training ASR or VSR model
-
-- `[root_dir]` is the root directory for the LRS3 cropped-face dataset.
-- `[modality]` is the input modality type, including `v`, `a`, and `av`.
-- `[mode]` is the model type, including `online` and `offline`.
+## Usage
+### Training
```Shell
-
-python train.py --root-dir [root_dir] \
- --sp-model-path ./spm_unigram_1023.model
- --exp-dir ./exp \
- --num-nodes 8 \
- --gpus 8 \
- --md [modality] \
- --mode [mode]
+python train.py --exp-dir=[exp_dir] \
+ --exp-name=[exp_name] \
+ --modality=[modality] \
+ --mode=[mode] \
+ --root-dir=[root-dir] \
+ --sp-model-path=[sp_model_path] \
+ --num-nodes=[num_nodes] \
+ --gpus=[gpus]
```
-### Training AV-ASR model
+- `exp-dir` and `exp-name`: The directory where the checkpoints will be saved is `[exp_dir]/[exp_name]`.
+- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
+- `mode`: Training mode of the model. Valid values are: `online` (streaming) and `offline` (non-streaming).
+- `root-dir`: Path to the root directory of the preprocessed dataset.
+- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`, which can be produced using `train_spm.py`.
+- `num-nodes`: Number of machines used. Default: 4.
+- `gpus`: Number of GPUs on each machine. Default: 8.
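+
+For example, a hypothetical single-machine run on preprocessed LRS3 data could look like the following sketch, where the paths and the experiment name are placeholders:
+
+```Shell
+python train.py --exp-dir=./exp \
+                --exp-name=avsr_online \
+                --modality=audiovisual \
+                --mode=online \
+                --root-dir=/path/to/preprocessed/lrs3 \
+                --sp-model-path=./spm_unigram_1023.model \
+                --num-nodes=1 \
+                --gpus=8
+```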
+
+### Evaluation
```Shell
-python train.py --root-dir [root-dir] \
- --sp-model-path ./spm_unigram_1023.model
- --exp-dir ./exp \
- --num-nodes 8 \
- --gpus 8 \
- --md av \
- --mode [mode]
+python eval.py --modality=[modality] \
+ --mode=[mode] \
+ --root-dir=[dataset_path] \
+ --sp-model-path=[sp_model_path] \
+ --checkpoint-path=[checkpoint_path]
```
-### Evaluating models
+- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
+- `mode`: Evaluation mode of the model. Valid values are: `online` (streaming) and `offline` (non-streaming).
+- `root-dir`: Path to the root directory of the preprocessed dataset.
+- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`.
+- `checkpoint-path`: Path to a pretrained model checkpoint.
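+
+For example, a hypothetical offline evaluation of an audio-visual checkpoint could look like this, where the paths are placeholders:
+
+```Shell
+python eval.py --modality=audiovisual \
+               --mode=offline \
+               --root-dir=/path/to/preprocessed/lrs3 \
+               --sp-model-path=./spm_unigram_1023.model \
+               --checkpoint-path=/path/to/checkpoint.pth
+```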
-```Shell
-python eval.py --dataset-path [dataset_path] \
- --sp-model-path ./spm_unigram_1023.model
- --md [modality] \
- --mode [mode] \
- --checkpoint-path [checkpoint_path]
-```
+## Results
-The table below contains WER for AV-ASR models [offline evaluation].
+The table below contains the WER of AV-ASR models trained from scratch on LRS3 (offline evaluation).
-| Model | WER [%] | Params (M) |
-|:-----------:|:------------:|:--------------:|
-| Non-streaming models | |
-| AV-ASR | 4.0 | 50 |
-| Streaming models | |
-| AV-ASR | 4.3 | 40 |
+| Model | Training dataset (hours) | WER [%] | Params (M) |
+|:--------------------:|:------------------------:|:-------:|:----------:|
+| Non-streaming models | | | |
+| AV-ASR | LRS3 (438) | 3.9 | 50 |
+| Streaming models | | | |
+| AV-ASR | LRS3 (438) | 3.9 | 40 |
diff --git a/examples/avsr/average_checkpoints.py b/examples/avsr/average_checkpoints.py
index 74cf20f959..9b0831b347 100644
--- a/examples/avsr/average_checkpoints.py
+++ b/examples/avsr/average_checkpoints.py
@@ -23,9 +23,6 @@ def average_checkpoints(last):
def ensemble(args):
- last = [
- os.path.join(args.exp_dir, args.experiment_name, f"epoch={n}.ckpt")
- for n in range(args.epochs - 10, args.epochs)
- ]
- model_path = os.path.join(args.exp_dir, args.experiment_name, "model_avg_10.pth")
+ last = [os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt") for n in range(args.epochs - 10, args.epochs)]
+ model_path = os.path.join(args.exp_dir, args.exp_name, "model_avg_10.pth")
torch.save({"state_dict": average_checkpoints(last)}, model_path)
diff --git a/examples/avsr/data_module.py b/examples/avsr/data_module.py
index 060fde60f0..c50d0fa26c 100644
--- a/examples/avsr/data_module.py
+++ b/examples/avsr/data_module.py
@@ -110,52 +110,19 @@ def __init__(
self.num_workers = num_workers
def train_dataloader(self):
- datasets = [LRS3(self.args, subset="train")]
-
- if not self.train_dataset_lengths:
- self.train_dataset_lengths = [dataset._lengthlist for dataset in datasets]
-
- dataset = torch.utils.data.ConcatDataset(
- [
- CustomBucketDataset(
- dataset,
- lengths,
- self.max_frames,
- self.train_num_buckets,
- batch_size=self.batch_size,
- )
- for dataset, lengths in zip(datasets, self.train_dataset_lengths)
- ]
+ dataset = LRS3(self.args, subset="train")
+ dataset = CustomBucketDataset(
+ dataset, dataset.lengths, self.max_frames, self.train_num_buckets, batch_size=self.batch_size
)
-
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
- dataset,
- num_workers=self.num_workers,
- batch_size=None,
- shuffle=self.train_shuffle,
+ dataset, num_workers=self.num_workers, batch_size=None, shuffle=self.train_shuffle
)
return dataloader
def val_dataloader(self):
- datasets = [LRS3(self.args, subset="val")]
-
- if not self.val_dataset_lengths:
- self.val_dataset_lengths = [dataset._lengthlist for dataset in datasets]
-
- dataset = torch.utils.data.ConcatDataset(
- [
- CustomBucketDataset(
- dataset,
- lengths,
- self.max_frames,
- 1,
- batch_size=self.batch_size,
- )
- for dataset, lengths in zip(datasets, self.val_dataset_lengths)
- ]
- )
-
+ dataset = LRS3(self.args, subset="val")
+ dataset = CustomBucketDataset(dataset, dataset.lengths, self.max_frames, 1, batch_size=self.batch_size)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader
diff --git a/examples/avsr/data_prep/README.md b/examples/avsr/data_prep/README.md
index 995f9ec5be..4d613140e9 100644
--- a/examples/avsr/data_prep/README.md
+++ b/examples/avsr/data_prep/README.md
@@ -1,48 +1,72 @@
-# Preprocessing LRS3
+# Pre-process LRS3
-We provide a pre-processing pipeline to detect and crop full-face images in this repository.
+We provide a pre-processing pipeline in this repository for detecting and cropping full-face regions of interest (ROIs) and extracting the corresponding audio waveforms from LRS3.
-## Prerequisites
+## Introduction
-Install all dependency-packages.
+Before feeding the raw stream into our model, each video sequence has to undergo a specific pre-processing procedure, which involves three critical steps. The first step is to perform face detection. Following that, each individual frame is aligned to a reference frame, commonly known as the mean face, in order to normalize rotation and size differences across frames. The final step is to crop the face region from the aligned face image.
+
+
+
+## Preparation
+
+1. Install all required packages.
```Shell
pip install -r requirements.txt
```
-Install [RetinaFace](./tools) tracker.
+2. Install the [retinaface](./tools) or [mediapipe](https://pypi.org/project/mediapipe/) face tracker. If you have already installed a tracker, skip this step.
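+If you choose the mediapipe tracker, it can be installed from PyPI, for example:
+```Shell
+pip install mediapipe
+```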
+
+## Preprocessing LRS3
-## Preprocessing
+To pre-process the LRS3 dataset, please follow these steps:
-### Step 1. Pre-process the LRS3 dataset.
-Please run the following script to pre-process the LRS3 dataset:
+1. Download the LRS3 dataset from the official website.
+
+2. Run the following command to preprocess the dataset:
```Shell
-python main.py \
+python preprocess_lrs3.py \
--data-dir=[data_dir] \
+ --detector=[detector] \
--dataset=[dataset] \
- --root=[root] \
- --folder=[folder] \
- --groups=[num_groups] \
- --job-index=[job_index]
+ --root-dir=[root] \
+ --subset=[subset] \
+ --seg-duration=[seg_duration] \
+ --groups=[n] \
+ --job-index=[j]
```
-- `[data_dir]` and `[landmarks_dir]` are the directories for original dataset and corresponding landmarks.
-
-- `[root]` is the directory for saved cropped-face dataset.
-
-- `[folder]` can be set to `train` or `test`.
+- `data-dir`: Path to the directory containing video files.
+- `detector`: Type of face detector. Valid values are: `mediapipe` and `retinaface`. Default: `retinaface`.
+- `dataset`: Name of the dataset. Valid value is: `lrs3`.
+- `root-dir`: Path to the root directory where all preprocessed files will be stored.
+- `subset`: Name of the subset. Valid values are: `train` and `test`.
+- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
+- `groups`: Number of groups to split the dataset into.
+- `job-index`: Job index for the current group. Valid values are integers in the range `[0, n)`.
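+
+For example, a hypothetical run that processes the training set as one of four parallel jobs could look like this, where the paths are placeholders:
+
+```Shell
+python preprocess_lrs3.py \
+    --data-dir=/path/to/LRS3 \
+    --detector=retinaface \
+    --dataset=lrs3 \
+    --root-dir=/path/to/preprocessed \
+    --subset=train \
+    --seg-duration=16 \
+    --groups=4 \
+    --job-index=0
+```
+
+Running the command once for each `job-index` from `0` to `3` (for example, on four machines or in four shells) covers the whole subset.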
-- `[num_groups]` and `[job-index]` are used to split the dataset into multiple threads, where `[job-index]` is an integer in [0, `[num_groups]`).
-
-### Step 2. Merge the label list.
-After completing Step 2, run the following script to merge all labels.
+3. Run the following command to merge all labels:
```Shell
python merge.py \
+ --root-dir=[root_dir] \
--dataset=[dataset] \
- --root=[root] \
- --folder=[folder] \
- --groups=[num_groups] \
+ --subset=[subset] \
+ --seg-duration=[seg_duration] \
+ --groups=[n]
```
+
+- `root-dir`: Path to the root directory where all preprocessed files will be stored.
+- `dataset`: Name of the dataset. Valid values are: `lrs2` and `lrs3`.
+- `subset`: The subset name of the dataset. For LRS2, valid values are `train`, `val`, and `test`. For LRS3, valid values are `train` and `test`.
+- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
+- `groups`: Number of groups to split the dataset into.
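+
+For example, merging the labels produced by the four preprocessing jobs above could look like this, where the root directory is a placeholder:
+
+```Shell
+python merge.py \
+    --root-dir=/path/to/preprocessed \
+    --dataset=lrs3 \
+    --subset=train \
+    --seg-duration=16 \
+    --groups=4
+```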
diff --git a/examples/avsr/data_prep/data/data_module.py b/examples/avsr/data_prep/data/data_module.py
index 556a1bf369..72bc2e69e7 100644
--- a/examples/avsr/data_prep/data/data_module.py
+++ b/examples/avsr/data_prep/data/data_module.py
@@ -19,6 +19,12 @@ def __init__(self, modality, detector="retinaface", resize=None):
self.landmarks_detector = LandmarksDetector(device="cuda:0")
self.video_process = VideoProcess(resize=resize)
+ if detector == "mediapipe":
+ from detectors.mediapipe.detector import LandmarksDetector
+ from detectors.mediapipe.video_process import VideoProcess
+
+ self.landmarks_detector = LandmarksDetector()
+ self.video_process = VideoProcess(resize=resize)
def load_data(self, data_filename, transform=True):
if self.modality == "audio":
diff --git a/examples/avsr/data_prep/detectors/mediapipe/detector.py b/examples/avsr/data_prep/detectors/mediapipe/detector.py
new file mode 100644
index 0000000000..9971dde2b5
--- /dev/null
+++ b/examples/avsr/data_prep/detectors/mediapipe/detector.py
@@ -0,0 +1,52 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Imperial College London (Pingchuan Ma)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import warnings
+
+import mediapipe as mp
+
+import numpy as np
+import torchvision
+
+warnings.filterwarnings("ignore")
+
+
+class LandmarksDetector:
+ def __init__(self):
+ self.mp_face_detection = mp.solutions.face_detection
+ self.short_range_detector = self.mp_face_detection.FaceDetection(
+ min_detection_confidence=0.5, model_selection=0
+ )
+ self.full_range_detector = self.mp_face_detection.FaceDetection(min_detection_confidence=0.5, model_selection=1)
+
+ def __call__(self, video_frames):
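+        # Try the full-range detector first; fall back to the short-range detector if it finds no faces in any frame.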
+ landmarks = self.detect(video_frames, self.full_range_detector)
+ if all(element is None for element in landmarks):
+ landmarks = self.detect(video_frames, self.short_range_detector)
+        assert any(lm is not None for lm in landmarks), "Cannot detect any face in the video"
+ return landmarks
+
+ def detect(self, filename, detector):
+ video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy()
+ landmarks = []
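+        # For each frame, keep only the largest detected face; append None for frames without a detection.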
+ for frame in video_frames:
+ results = detector.process(frame)
+ if not results.detections:
+ landmarks.append(None)
+ continue
+ face_points = []
+            max_id, max_size = 0, 0
+            for idx, detected_faces in enumerate(results.detections):
+                bboxC = detected_faces.location_data.relative_bounding_box
+                ih, iw, ic = frame.shape
+                # bbox is (xmin, ymin, width, height) in absolute pixel coordinates
+                bbox = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
+                bbox_size = bbox[2] + bbox[3]
+                if bbox_size > max_size:
+                    max_id, max_size = idx, bbox_size
+ lmx = [[int(bboxC.xmin * iw), int(bboxC.ymin * ih)], [int(bboxC.width * iw), int(bboxC.height * ih)]]
+ face_points.append(lmx)
+ landmarks.append(np.reshape(np.array(face_points[max_id]), (2, 2)))
+ return landmarks
diff --git a/examples/avsr/data_prep/detectors/mediapipe/video_process.py b/examples/avsr/data_prep/detectors/mediapipe/video_process.py
new file mode 100644
index 0000000000..375fd4c428
--- /dev/null
+++ b/examples/avsr/data_prep/detectors/mediapipe/video_process.py
@@ -0,0 +1,158 @@
+import cv2
+import numpy as np
+from skimage import transform as tf
+
+
+def linear_interpolate(landmarks, start_idx, stop_idx):
+ start_landmarks = landmarks[start_idx]
+ stop_landmarks = landmarks[stop_idx]
+ delta = stop_landmarks - start_landmarks
+ for idx in range(1, stop_idx - start_idx):
+ landmarks[start_idx + idx] = start_landmarks + idx / float(stop_idx - start_idx) * delta
+ return landmarks
+
+
+def warp_img(src, dst, img, std_size):
+ tform = tf.estimate_transform("similarity", src, dst)
+ warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size)
+ warped = (warped * 255).astype("uint8")
+ return warped, tform
+
+
+def apply_transform(transform, img, std_size):
+ warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
+ warped = (warped * 255).astype("uint8")
+ return warped
+
+
+def cut_patch(img, landmarks, height, width, threshold=5):
+ center_x, center_y = np.mean(landmarks, axis=0)
+ # Check for too much bias in height and width
+ if abs(center_y - img.shape[0] / 2) > height + threshold:
+ raise Exception("too much bias in height")
+ if abs(center_x - img.shape[1] / 2) > width + threshold:
+ raise Exception("too much bias in width")
+ # Calculate bounding box coordinates
+ y_min = int(round(np.clip(center_y - height, 0, img.shape[0])))
+ y_max = int(round(np.clip(center_y + height, 0, img.shape[0])))
+ x_min = int(round(np.clip(center_x - width, 0, img.shape[1])))
+ x_max = int(round(np.clip(center_x + width, 0, img.shape[1])))
+ # Cut the image
+ cutted_img = np.copy(img[y_min:y_max, x_min:x_max])
+ return cutted_img
+
+
+class VideoProcess:
+ def __init__(
+ self,
+ crop_width=128,
+ crop_height=128,
+ target_size=(224, 224),
+ reference_size=(224, 224),
+ stable_points=(0, 1),
+ start_idx=0,
+ stop_idx=2,
+ resize=(96, 96),
+ ):
+ self.reference = np.array(([[51.64568, 0.70204943], [171.95107, 159.59505]]))
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.start_idx = start_idx
+ self.stop_idx = stop_idx
+ self.resize = resize
+
+ def __call__(self, video, landmarks):
+ # Pre-process landmarks: interpolate frames that are not detected
+ preprocessed_landmarks = self.interpolate_landmarks(landmarks)
+ # Exclude corner cases: no landmark in all frames or number of frames is less than window length
+ if not preprocessed_landmarks:
+ return
+ # Affine transformation and crop patch
+ sequence = self.crop_patch(video, preprocessed_landmarks)
+ assert sequence is not None, "crop an empty patch."
+ return sequence
+
+ def crop_patch(self, video, landmarks):
+ sequence = []
+ for frame_idx, frame in enumerate(video):
+ transformed_frame, transformed_landmarks = self.affine_transform(
+ frame, landmarks[frame_idx], self.reference
+ )
+ patch = cut_patch(
+ transformed_frame,
+ transformed_landmarks[self.start_idx : self.stop_idx],
+ self.crop_height // 2,
+ self.crop_width // 2,
+ )
+ if self.resize:
+ patch = cv2.resize(patch, self.resize)
+ sequence.append(patch)
+ return np.array(sequence)
+
+ def interpolate_landmarks(self, landmarks):
+ valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
+
+ if not valid_frames_idx:
+ return None
+
+ for idx in range(1, len(valid_frames_idx)):
+ if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1:
+ landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx])
+
+ valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None]
+
+ # Handle corner case: keep frames at the beginning or at the end that failed to be detected
+ if valid_frames_idx:
+ landmarks[: valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
+ landmarks[valid_frames_idx[-1] :] = [landmarks[valid_frames_idx[-1]]] * (
+ len(landmarks) - valid_frames_idx[-1]
+ )
+
+ assert all(lm is not None for lm in landmarks), "not every frame has landmark"
+
+ return landmarks
+
+ def affine_transform(
+ self,
+ frame,
+ landmarks,
+ reference,
+ target_size=(224, 224),
+ reference_size=(224, 224),
+ stable_points=(0, 1),
+ interpolation=cv2.INTER_LINEAR,
+ border_mode=cv2.BORDER_CONSTANT,
+ border_value=0,
+ ):
+ stable_reference = self.get_stable_reference(reference, stable_points, reference_size, target_size)
+ transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference)
+ transformed_frame, transformed_landmarks = self.apply_affine_transform(
+ frame, landmarks, transform, target_size, interpolation, border_mode, border_value
+ )
+
+ return transformed_frame, transformed_landmarks
+
+ def get_stable_reference(self, reference, stable_points, reference_size, target_size):
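+        # Gather the chosen stable points from the reference face and shift them to account for the size difference between the reference and target frames.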
+ stable_reference = np.vstack([reference[x] for x in stable_points])
+ stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0
+ stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0
+ return stable_reference
+
+ def estimate_affine_transform(self, landmarks, stable_points, stable_reference):
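+        # Robustly estimate (via LMedS) a partial affine transform (rotation, uniform scale, translation) mapping the detected stable points onto the reference.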
+ return cv2.estimateAffinePartial2D(
+ np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS
+ )[0]
+
+ def apply_affine_transform(
+ self, frame, landmarks, transform, target_size, interpolation, border_mode, border_value
+ ):
+ transformed_frame = cv2.warpAffine(
+ frame,
+ transform,
+ dsize=(target_size[0], target_size[1]),
+ flags=interpolation,
+ borderMode=border_mode,
+ borderValue=border_value,
+ )
+ transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose()
+ return transformed_frame, transformed_landmarks
diff --git a/examples/avsr/data_prep/merge.py b/examples/avsr/data_prep/merge.py
index 132a3cce7d..0b7a33dd24 100644
--- a/examples/avsr/data_prep/merge.py
+++ b/examples/avsr/data_prep/merge.py
@@ -1,40 +1,38 @@
+import argparse
import os
-from argparse import ArgumentParser
-
-def load_args(default_config=None):
- parser = ArgumentParser()
- parser.add_argument(
- "--dataset",
- type=str,
- help="Specify the dataset name used in the experiment",
- )
- parser.add_argument(
- "--subset",
- type=str,
- help="Specify the set used in the experiment",
- )
- parser.add_argument(
- "--root-dir",
- type=str,
- help="The root directory of saved mouth patches or embeddings.",
- )
- parser.add_argument(
- "--groups",
- type=int,
- help="Specify the number of threads to be used",
- )
- parser.add_argument(
- "--seg-duration",
- type=int,
- default=16,
- help="Specify the segment length",
- )
- args = parser.parse_args()
- return args
-
-
-args = load_args()
+parser = argparse.ArgumentParser(description="Merge labels")
+parser.add_argument(
+ "--dataset",
+ type=str,
+ required=True,
+ help="Specify the dataset used in the experiment",
+)
+parser.add_argument(
+ "--subset",
+ type=str,
+ required=True,
+ help="Specify the subset of the dataset used in the experiment",
+)
+parser.add_argument(
+ "--root-dir",
+ type=str,
+ required=True,
+ help="Directory of saved mouth patches or embeddings",
+)
+parser.add_argument(
+ "--groups",
+ type=int,
+ required=True,
+ help="Number of threads for parallel processing",
+)
+parser.add_argument(
+ "--seg-duration",
+ type=int,
+ default=16,
+ help="Length of the segments",
+)
+args = parser.parse_args()
dataset = args.dataset
subset = args.subset
@@ -45,7 +43,9 @@ def load_args(default_config=None):
# Create the filename template for label files
label_template = os.path.join(
- args.root_dir, "labels", f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.{args.groups}"
+ args.root_dir,
+ "labels",
+ f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.{args.groups}",
)
lines = []
@@ -58,15 +58,17 @@ def load_args(default_config=None):
# Write the merged labels to a new file
dst_label_filename = os.path.join(
- args.root_dir, dataset, f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.csv"
+ args.root_dir,
+ "labels",
+ f"{dataset}_{subset}_transcript_lengths_seg{seg_duration}s.csv",
)
with open(dst_label_filename, "w") as file:
file.write("\n".join(lines))
# Print the number of files and total duration in hours
-total_duration = sum(int(line.split(",")[2]) for line in lines) / 3600.0 / 25.0
-print(f"The completed set has {len(lines)} files with a total of {total_duration} hours.")
+total_duration = sum(int(line.split(",")[2]) for line in lines) / 3600.0 / 25.0  # frame counts at 25 fps, converted to hours
+print(f"The completed set has {len(lines)} files with a total of {total_duration:.2f} hours.")
# Remove the label files for each job index
print("** Remove the temporary label files **")
diff --git a/examples/avsr/data_prep/main.py b/examples/avsr/data_prep/preprocess_lrs3.py
similarity index 79%
rename from examples/avsr/data_prep/main.py
rename to examples/avsr/data_prep/preprocess_lrs3.py
index c9826f447b..3d5bde7b23 100644
--- a/examples/avsr/data_prep/main.py
+++ b/examples/avsr/data_prep/preprocess_lrs3.py
@@ -1,8 +1,8 @@
+import argparse
import glob
import math
import os
import shutil
-
import warnings
import ffmpeg
@@ -12,52 +12,60 @@
warnings.filterwarnings("ignore")
-from argparse import ArgumentParser
-
-
-def load_args(default_config=None):
- parser = ArgumentParser(description="Preprocess LRS3 to crop full-face images")
- # -- for benchmark evaluation
- parser.add_argument(
- "--data-dir",
- type=str,
- help="The directory for sequence.",
- )
- parser.add_argument(
- "--dataset",
- type=str,
- help="Specify the dataset name used in the experiment",
- )
- parser.add_argument(
- "--root-dir",
- type=str,
- help="The root directory of cropped-face dataset.",
- )
- parser.add_argument("--job-index", type=int, default=0, help="job index")
- parser.add_argument(
- "--groups",
- type=int,
- default=1,
- help="specify the number of threads to be used",
- )
- parser.add_argument(
- "--folder",
- type=str,
- default="test",
- help="specify the set used in the experiment",
- )
- args = parser.parse_args()
- return args
-
-
-args = load_args()
-
-seg_duration = 16
-detector = "retinaface"
+# Argument Parsing
+parser = argparse.ArgumentParser(description="LRS3 Preprocessing")
+parser.add_argument(
+ "--data-dir",
+ type=str,
+    help="The directory of the original dataset containing video files.",
+)
+parser.add_argument(
+ "--detector",
+ type=str,
+ default="retinaface",
+ help="Face detector used in the experiment.",
+)
+parser.add_argument(
+ "--dataset",
+ type=str,
+ help="Specify the dataset name used in the experiment",
+)
+parser.add_argument(
+ "--root-dir",
+ type=str,
+ help="The root directory of cropped-face dataset.",
+)
+parser.add_argument(
+ "--subset",
+ type=str,
+ required=True,
+ help="Subset of the dataset used in the experiment.",
+)
+parser.add_argument(
+ "--seg-duration",
+ type=int,
+ default=16,
+ help="Length of the segment in seconds.",
+)
+parser.add_argument(
+ "--groups",
+ type=int,
+ default=1,
+ help="Number of threads to be used in parallel.",
+)
+parser.add_argument(
+ "--job-index",
+ type=int,
+ default=0,
+ help="Index to identify separate jobs (useful for parallel processing).",
+)
+args = parser.parse_args()
+
+seg_duration = args.seg_duration
dataset = args.dataset
args.data_dir = os.path.normpath(args.data_dir)
-vid_dataloader = AVSRDataLoader(modality="video", detector=detector, resize=(96, 96))
+vid_dataloader = AVSRDataLoader(modality="video", detector=args.detector, resize=(96, 96))
aud_dataloader = AVSRDataLoader(modality="audio")
# Step 2, extract mouth patches from segments.
seg_vid_len = seg_duration * 25
@@ -66,9 +74,9 @@ def load_args(default_config=None):
label_filename = os.path.join(
args.root_dir,
"labels",
- f"{dataset}_{args.folder}_transcript_lengths_seg{seg_duration}s.csv"
+ f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.csv"
if args.groups <= 1
- else f"{dataset}_{args.folder}_transcript_lengths_seg{seg_duration}s.{args.groups}.{args.job_index}.csv",
+ else f"{dataset}_{args.subset}_transcript_lengths_seg{seg_duration}s.{args.groups}.{args.job_index}.csv",
)
os.makedirs(os.path.dirname(label_filename), exist_ok=True)
print(f"Directory {os.path.dirname(label_filename)} created")
@@ -77,9 +85,9 @@ def load_args(default_config=None):
# Step 2, extract mouth patches from segments.
dst_vid_dir = os.path.join(args.root_dir, dataset, dataset + f"_video_seg{seg_duration}s")
dst_txt_dir = os.path.join(args.root_dir, dataset, dataset + f"_text_seg{seg_duration}s")
-if args.folder == "test":
- filenames = glob.glob(os.path.join(args.data_dir, args.folder, "**", "*.mp4"), recursive=True)
-elif args.folder == "train":
+if args.subset == "test":
+ filenames = glob.glob(os.path.join(args.data_dir, args.subset, "**", "*.mp4"), recursive=True)
+elif args.subset == "train":
filenames = glob.glob(os.path.join(args.data_dir, "trainval", "**", "*.mp4"), recursive=True)
filenames.extend(glob.glob(os.path.join(args.data_dir, "pretrain", "**", "*.mp4"), recursive=True))
filenames.sort()
@@ -96,7 +104,7 @@ def load_args(default_config=None):
except UnboundLocalError:
continue
- if os.path.normpath(data_filename).split(os.sep)[-3] in ["trainval", "test", "main"]:
+ if os.path.normpath(data_filename).split(os.sep)[-3] in ["trainval", "test"]:
dst_vid_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.mp4"
dst_aud_filename = f"{data_filename.replace(args.data_dir, dst_vid_dir)[:-4]}.wav"
dst_txt_filename = f"{data_filename.replace(args.data_dir, dst_txt_dir)[:-4]}.txt"
diff --git a/examples/avsr/data_prep/requirements.txt b/examples/avsr/data_prep/requirements.txt
index 3c1dd8f6dd..2fc5d4a54b 100644
--- a/examples/avsr/data_prep/requirements.txt
+++ b/examples/avsr/data_prep/requirements.txt
@@ -1,3 +1,4 @@
+tqdm
scikit-image
opencv-python
ffmpeg-python
diff --git a/examples/avsr/data_prep/tools/README.md b/examples/avsr/data_prep/tools/README.md
index 9d30695ae2..adabd24507 100644
--- a/examples/avsr/data_prep/tools/README.md
+++ b/examples/avsr/data_prep/tools/README.md
@@ -1,14 +1,12 @@
## Face Recognition
-We provide [ibug.face_detection](https://github.com/hhj1897/face_detection) in this repository.
+We provide [ibug.face_detection](https://github.com/hhj1897/face_detection) in this repository. You can install it directly from the GitHub repository or from a compressed file.
+
+### Option 1. Install from the GitHub repository
-### Prerequisites
* [Git LFS](https://git-lfs.github.com/), needed for downloading the pretrained weights that are larger than 100 MB.
You could install *`Homebrew`* and then install *`git-lfs`* without sudo priviledges.
-### From source
-
-1. Install *`ibug.face_detection`*
```Shell
git clone https://github.com/hhj1897/face_detection.git
cd face_detection
@@ -16,3 +14,15 @@ git lfs pull
pip install -e .
cd ..
```
+
+### Option 2. Install from a compressed file
+
+If you are experiencing over-quota issues with the repository above, you can download the [ibug.face_detection](https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip) package as a zip file, unzip it, and then run `pip install -e .` to install it:
+
+```Shell
+wget https://www.doc.ic.ac.uk/~pm4115/tracker/face_detection.zip -O ./face_detection.zip
+unzip -o ./face_detection.zip -d ./
+cd face_detection
+pip install -e .
+cd ..
+```
diff --git a/examples/avsr/doc/lip_white.png b/examples/avsr/doc/lip_white.png
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/examples/avsr/eval.py b/examples/avsr/eval.py
index 174c348f00..e26bb6bf9f 100644
--- a/examples/avsr/eval.py
+++ b/examples/avsr/eval.py
@@ -16,7 +16,7 @@ def compute_word_level_distance(seq1, seq2):
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
- if args.md == "av":
+ if args.modality == "audiovisual":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
@@ -49,7 +49,7 @@ def run_eval(model, data_module):
def parse_args():
parser = ArgumentParser()
parser.add_argument(
- "--md",
+ "--modality",
type=str,
help="Modality",
required=True,
@@ -69,20 +69,15 @@ def parse_args():
parser.add_argument(
"--sp-model-path",
type=str,
- help="Path to SentencePiece model.",
+ help="Path to sentencepiece model.",
required=True,
)
parser.add_argument(
"--checkpoint-path",
type=str,
- help="Path to checkpoint model.",
+        help="Path to a model checkpoint.",
required=True,
)
- parser.add_argument(
- "--pretrained-model-path",
- type=str,
- help="Path to Pretraned model.",
- )
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
diff --git a/examples/avsr/lightning.py b/examples/avsr/lightning.py
index ed2c24edbc..cbca0d976a 100644
--- a/examples/avsr/lightning.py
+++ b/examples/avsr/lightning.py
@@ -57,9 +57,9 @@ def __init__(self, args=None, sp_model=None, pretrained_model_path=None):
)
self.blank_idx = spm_vocab_size
- if args.md == "v":
+ if args.modality == "video":
self.frontend = video_resnet()
- if args.md == "a":
+ if args.modality == "audio":
self.frontend = audio_resnet()
if args.mode == "online":
@@ -116,33 +116,13 @@ def configure_optimizers(self):
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
- def forward(self, batch: Batch):
+ def forward(self, batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
x = self.frontend(batch.inputs.to(self.device))
hypotheses = decoder(x, batch.input_lengths.to(self.device), beam_width=20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
- def training_step(self, batch: Batch, batch_idx):
- """Custom training step.
-
- By default, DDP does the following on each train step:
- - For each GPU, compute loss and gradient on shard of training data.
- - Sync and average gradients across all GPUs. The final gradient
- is (sum of gradients across all GPUs) / N, where N is the world
- size (total number of GPUs).
- - Update parameters on each GPU.
-
- Here, we do the following:
- - For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
- the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- - Sync and average gradients across all GPUs. The final gradient
- is (sum of gradients across all GPUs) / B_total.
- - Update parameters on each GPU.
-
- Doing so allows us to account for the variability in batch sizes that
- variable-length sequential data commonly yields.
- """
-
+ def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
@@ -157,7 +137,7 @@ def training_step(self, batch: Batch, batch_idx):
sch = self.lr_schedulers()
sch.step()
- self.log("monitoring_step", self.global_step)
+ self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
diff --git a/examples/avsr/lightning_av.py b/examples/avsr/lightning_av.py
index b0730d43e2..94529c0229 100644
--- a/examples/avsr/lightning_av.py
+++ b/examples/avsr/lightning_av.py
@@ -116,7 +116,7 @@ def configure_optimizers(self):
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
- def forward(self, batch: AVBatch):
+ def forward(self, batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
video_features = self.video_frontend(batch.videos.to(self.device))
audio_features = self.audio_frontend(batch.audios.to(self.device))
@@ -127,27 +127,7 @@ def forward(self, batch: AVBatch):
)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
- def training_step(self, batch: AVBatch, batch_idx):
- """Custom training step.
-
- By default, DDP does the following on each train step:
- - For each GPU, compute loss and gradient on shard of training data.
- - Sync and average gradients across all GPUs. The final gradient
- is (sum of gradients across all GPUs) / N, where N is the world
- size (total number of GPUs).
- - Update parameters on each GPU.
-
- Here, we do the following:
- - For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
- the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- - Sync and average gradients across all GPUs. The final gradient
- is (sum of gradients across all GPUs) / B_total.
- - Update parameters on each GPU.
-
- Doing so allows us to account for the variability in batch sizes that
- variable-length sequential data commonly yields.
- """
-
+ def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
@@ -162,7 +142,7 @@ def training_step(self, batch: AVBatch, batch_idx):
sch = self.lr_schedulers()
sch.step()
- self.log("monitoring_step", self.global_step)
+ self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
diff --git a/examples/avsr/lrs3.py b/examples/avsr/lrs3.py
index 5077f67c70..b58d96a061 100644
--- a/examples/avsr/lrs3.py
+++ b/examples/avsr/lrs3.py
@@ -40,12 +40,12 @@ def load_transcript(path):
return open(transcript_path).read().splitlines()[0]
-def load_item(path, md):
- if md == "v":
+def load_item(path, modality):
+ if modality == "video":
return (load_video(path), load_transcript(path))
- if md == "a":
+ if modality == "audio":
return (load_audio(path), load_transcript(path))
- if md == "av":
+ if modality == "audiovisual":
return (load_audio(path), load_video(path), load_transcript(path))
@@ -62,15 +62,15 @@ def __init__(
self.args = args
if subset == "train":
- self._filelist, self._lengthlist = _load_list(self.args, "lrs3_train_transcript_lengths_seg16s.csv")
+ self.files, self.lengths = _load_list(self.args, "lrs3_train_transcript_lengths_seg16s.csv")
if subset == "val":
- self._filelist, self._lengthlist = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
+ self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
if subset == "test":
- self._filelist, self._lengthlist = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
+ self.files, self.lengths = _load_list(self.args, "lrs3_test_transcript_lengths_seg16s.csv")
def __getitem__(self, n):
- path = self._filelist[n]
- return load_item(path, self.args.md)
+ path = self.files[n]
+ return load_item(path, self.args.modality)
def __len__(self) -> int:
- return len(self._filelist)
+ return len(self.files)
diff --git a/examples/avsr/models/fusion.py b/examples/avsr/models/fusion.py
index 8d2fd1f5e5..9c5abdda9b 100644
--- a/examples/avsr/models/fusion.py
+++ b/examples/avsr/models/fusion.py
@@ -32,5 +32,5 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.sequential(input)
-def fusion_module():
- return FeedForwardModule(1024, 3072, 512, 0.1)
+def fusion_module(input_dim=1024, hidden_dim=3072, output_dim=512, dropout=0.1):
+ return FeedForwardModule(input_dim, hidden_dim, output_dim, dropout)
diff --git a/examples/avsr/train.py b/examples/avsr/train.py
index e609c43611..cf4c07c9de 100644
--- a/examples/avsr/train.py
+++ b/examples/avsr/train.py
@@ -14,7 +14,7 @@ def get_trainer(args):
seed_everything(1)
checkpoint = ModelCheckpoint(
- dirpath=os.path.join(args.exp_dir, args.experiment_name) if args.exp_dir else None,
+ dirpath=os.path.join(args.exp_dir, args.exp_name) if args.exp_dir else None,
monitor="monitoring_step",
mode="max",
save_last=True,
@@ -36,13 +36,12 @@ def get_trainer(args):
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
- resume_from_checkpoint=args.resume_from_checkpoint,
)
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
- if args.md == "av":
+ if args.modality == "audiovisual":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
@@ -56,7 +55,7 @@ def get_lightning_module(args):
def parse_args():
parser = ArgumentParser()
parser.add_argument(
- "--md",
+ "--modality",
type=str,
help="Modality",
required=True,
@@ -86,19 +85,20 @@ def parse_args():
)
parser.add_argument(
"--exp-dir",
+ default="./exp",
type=str,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
- "--experiment-name",
+ "--exp-name",
type=str,
help="Experiment name",
)
parser.add_argument(
"--num-nodes",
- default=8,
+ default=4,
type=int,
- help="Number of nodes to use for training. (Default: 8)",
+ help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
@@ -113,9 +113,16 @@ def parse_args():
help="Number of epochs to train for. (Default: 55)",
)
parser.add_argument(
- "--resume-from-checkpoint", default=None, type=str, help="Path to the checkpoint to resume from"
+ "--resume-from-checkpoint",
+ default=None,
+ type=str,
+ help="Path to the checkpoint to resume from",
+ )
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ help="Whether to use debug level for logging",
)
- parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
diff --git a/examples/avsr/transforms.py b/examples/avsr/transforms.py
index d17c8307ab..888c5962f2 100644
--- a/examples/avsr/transforms.py
+++ b/examples/avsr/transforms.py
@@ -55,28 +55,28 @@ def _extract_features(video_pipeline, audio_pipeline, samples, args):
raw_videos = []
raw_audios = []
for sample in samples:
- if args.md == "v":
+        if args.modality == "video":
raw_videos.append(sample[0])
- if args.md == "a":
+ if args.modality == "audio":
raw_audios.append(sample[0])
- if args.md == "av":
+ if args.modality == "audiovisual":
length = min(len(sample[0]) // 640, len(sample[1]))
raw_audios.append(sample[0][: length * 640])
raw_videos.append(sample[1][:length])
- if args.md == "v" or args.md == "av":
+    if args.modality == "video" or args.modality == "audiovisual":
videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True)
videos = video_pipeline(videos)
video_lengths = torch.tensor([elem.shape[0] for elem in videos], dtype=torch.int32)
- if args.md == "a" or args.md == "av":
+ if args.modality == "audio" or args.modality == "audiovisual":
audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True)
audios = audio_pipeline(audios)
audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in audios], dtype=torch.int32)
- if args.md == "v":
+    if args.modality == "video":
return videos, video_lengths
- if args.md == "a":
+ if args.modality == "audio":
return audios, audio_lengths
- if args.md == "av":
+ if args.modality == "audiovisual":
return audios, videos, audio_lengths, video_lengths
@@ -100,17 +100,17 @@ def __init__(self, sp_model_path: str, args):
def __call__(self, samples: List):
targets, target_lengths = _extract_labels(self.sp_model, samples)
- if self.args.md == "a":
+ if self.args.modality == "audio":
audios, audio_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
- if self.args.md == "v":
+        if self.args.modality == "video":
videos, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(videos, video_lengths, targets, target_lengths)
- if self.args.md == "av":
+ if self.args.modality == "audiovisual":
audios, videos, audio_lengths, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
@@ -135,17 +135,17 @@ def __init__(self, sp_model_path: str, args):
def __call__(self, samples: List):
targets, target_lengths = _extract_labels(self.sp_model, samples)
- if self.args.md == "a":
+ if self.args.modality == "audio":
audios, audio_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
- if self.args.md == "v":
+        if self.args.modality == "video":
videos, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(videos, video_lengths, targets, target_lengths)
- if self.args.md == "av":
+ if self.args.modality == "audiovisual":
audios, videos, audio_lengths, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)