Showing 45 changed files with 7,181 additions and 21 deletions.
@@ -0,0 +1,6 @@
*.tar
*.json
*.pyc
*.DS_Store
checkpoints/*
Binary file not shown.
@@ -0,0 +1,96 @@
###
# Author: Kai Li
# Date: 2021-06-19 23:35:14
# LastEditors: Please set LastEditors
# LastEditTime: 2022-10-04 15:28:54
###

import argparse
import json
import os

import soundfile as sf
from tqdm import tqdm


def get_mouth_path(in_mouth_dir, wav_file, out_filename, data_type):
    # The mixture filename is split on "_": fields 0-1 name speaker 1's source
    # clip and fields 3-4 name speaker 2's, matching the mouth .npz filenames.
    wav_file = wav_file.split("_")
    # if out_filename == "s1":
    #     file_path = os.path.join(
    #         in_mouth_dir, data_type, "{}_{}.npz".format(wav_file[0], wav_file[1])
    #     )
    # else:
    #     file_path = os.path.join(
    #         in_mouth_dir, data_type, "{}_{}.npz".format(wav_file[3], wav_file[4])
    #     )
    if out_filename == "s1":
        file_path = os.path.join(
            in_mouth_dir, "{}_{}.npz".format(wav_file[0], wav_file[1])
        )
    if out_filename == "s2":
        file_path = os.path.join(
            in_mouth_dir, "{}_{}.npz".format(wav_file[3], wav_file[4])
        )
    return file_path


def preprocess_one_dir(in_audio_dir, in_video_dir, out_dir, out_filename, data_type):
    """Create .json file for one condition."""
    file_infos = []
    in_dir = os.path.abspath(in_audio_dir)
    wav_list = os.listdir(in_dir)
    wav_list.sort()
    for wav_file in tqdm(wav_list):
        if not wav_file.endswith(".wav"):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # len() on a SoundFile gives the number of frames without loading the audio.
        samples = sf.SoundFile(wav_path)
        if out_filename == "mix":
            file_infos.append((wav_path, len(samples)))
        else:
            file_infos.append(
                (
                    wav_path,
                    get_mouth_path(in_video_dir, wav_file, out_filename, data_type),
                    len(samples),
                )
            )
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, out_filename + ".json"), "w") as f:
        json.dump(file_infos, f, indent=4)


def preprocess(inp_args):
    """Create .json files for all conditions."""
    speaker_list = ["mix", "s1", "s2"]
    for data_type in ["tr", "cv", "tt"]:
        for spk in speaker_list:
            preprocess_one_dir(
                os.path.join(inp_args.in_audio_dir, data_type, spk),
                inp_args.in_mouth_dir,
                os.path.join(inp_args.out_dir, data_type),
                spk,
                data_type,
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser("WHAM data preprocessing")
    parser.add_argument(
        "--in_audio_dir",
        type=str,
        default=None,
        help="Directory path of audio including tr, cv and tt",
    )
    parser.add_argument(
        "--in_mouth_dir",
        type=str,
        default=None,
        help="Directory path of video including tr, cv and tt",
    )
    parser.add_argument(
        "--out_dir", type=str, default=None, help="Directory path to put output files"
    )
    args = parser.parse_args()
    print(args)
    preprocess(args)
@@ -0,0 +1,97 @@
###
# Author: Kai Li
# Date: 2021-06-19 23:35:14
# LastEditors: Kai Li
# LastEditTime: 2021-08-01 11:46:24
###

import argparse
import json
import os
import re

import soundfile as sf
from tqdm import tqdm


def get_mouth_path(in_mouth_dir, wav_file, out_filename, data_type):
    # wav_file = wav_file.split("_")
    # Mixture filenames are expected to contain two clip IDs of the form
    # id<5 digits>_<11-char video ID>_<5 digits> (VoxCeleb2-style naming).
    p = re.compile(r"id\d{5}_.{11}_\d{5}")
    res = p.findall(wav_file)
    assert len(res) == 2, f"matching failed for case: {wav_file}"
    if out_filename == "s1":
        file_path = os.path.join(
            in_mouth_dir, "{}.npz".format(res[0])
        )
    else:
        file_path = os.path.join(
            in_mouth_dir, "{}.npz".format(res[1])
        )
    # file_path = os.path.join(
    #     in_mouth_dir, "{}.npz".format(wav_file[:25])
    # )
    return file_path


def preprocess_one_dir(in_audio_dir, in_video_dir, out_dir, out_filename, data_type):
    """Create .json file for one condition."""
    file_infos = []
    in_dir = os.path.abspath(in_audio_dir)
    wav_list = os.listdir(in_dir)
    wav_list.sort()
    for wav_file in tqdm(wav_list):
        if not wav_file.endswith(".wav"):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # len() on a SoundFile gives the number of frames without loading the audio.
        samples = sf.SoundFile(wav_path)
        if out_filename == "mix":
            file_infos.append((wav_path, len(samples)))
        else:
            file_infos.append(
                (
                    wav_path,
                    get_mouth_path(in_video_dir, wav_file, out_filename, data_type),
                    len(samples),
                )
            )
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, out_filename + ".json"), "w") as f:
        json.dump(file_infos, f, indent=4)


def preprocess(inp_args):
    """Create .json files for all conditions."""
    speaker_list = ["mix", "s1", "s2"]
    for data_type in ["tr", "cv", "tt"]:
        for spk in speaker_list:
            preprocess_one_dir(
                os.path.join(inp_args.in_audio_dir, data_type, spk),
                inp_args.in_mouth_dir,
                os.path.join(inp_args.out_dir, data_type),
                spk,
                data_type,
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser("WHAM data preprocessing")
    parser.add_argument(
        "--in_audio_dir",
        type=str,
        default=None,
        help="Directory path of audio including tr, cv and tt",
    )
    parser.add_argument(
        "--in_mouth_dir",
        type=str,
        default=None,
        help="Directory path of video including tr, cv and tt",
    )
    parser.add_argument(
        "--out_dir", type=str, default=None, help="Directory path to put output files"
    )
    args = parser.parse_args()
    print(args)
    preprocess(args)
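For reference, the filename convention that the regex in `get_mouth_path` assumes can be checked in isolation. This is only a minimal sketch; the mixture filename below is made up purely for illustration:

```python
import re

# Hypothetical VoxCeleb2-style mixture filename (illustrative only).
wav_file = "id00017_abcdefghijk_00001_0.5_id00025_lmnopqrstuv_00002_-0.5.wav"

p = re.compile(r"id\d{5}_.{11}_\d{5}")
print(p.findall(wav_file))
# -> ['id00017_abcdefghijk_00001', 'id00025_lmnopqrstuv_00002']
```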
@@ -0,0 +1,141 @@
[简体中文](README_zh-CN.md) | English

# <font color=E7595C>I</font><font color=F6C446>I</font><font color=00C7EE>A</font><font color=00D465>Net</font>: An <font color=E7595C>I</font>ntra- and <font color=F6C446>I</font>nter-Modality <font color=00C7EE>A</font>ttention <font color=00D465>Net</font>work for Audio-Visual Speech Separation

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/scanet-a-self-and-cross-attention-network-for/speech-separation-on-lrs2)](https://paperswithcode.com/sota/speech-separation-on-lrs2?p=scanet-a-self-and-cross-attention-network-for)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/scanet-a-self-and-cross-attention-network-for/speech-separation-on-lrs3)](https://paperswithcode.com/sota/speech-separation-on-lrs3?p=scanet-a-self-and-cross-attention-network-for)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/scanet-a-self-and-cross-attention-network-for/speech-separation-on-voxceleb2)](https://paperswithcode.com/sota/speech-separation-on-voxceleb2?p=scanet-a-self-and-cross-attention-network-for)
[![arXiv](https://img.shields.io/badge/arXiv-2308.08143-b31b1b.svg)](https://arxiv.org/abs/2308.08143)
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://makeapullrequest.com)
[![GitHub license](https://img.shields.io/github/license/JusperLee/IIANet.svg?color=blue)](https://github.com/JusperLee/IIANet/blob/master/LICENSE)
![GitHub stars](https://img.shields.io/github/stars/JusperLee/IIANet)
![GitHub forks](https://img.shields.io/github/forks/JusperLee/IIANet)
![Website](https://img.shields.io/website?url=https%3A%2F%2Fcslikai.cn%2FIIANet%2F&up_message=Demo%20Page&down_message=Demo%20Page&logo=webmin)

* [Kai Li](https://cslikai.cn)[1], Runxuan Yang[1], [Fuchun Sun](https://scholar.google.com/citations?user=DbviELoAAAAJ&hl=en)[1], [Xiaolin Hu](https://www.xlhu.cn/)[1,2]
* [1] Tsinghua University, [2] Chinese Institute for Brain Research

This repository is the official implementation of IIANet, accepted at **ICML 2024** (**Poster**).

## ✨Key Highlights

1. We propose an attention-based cross-modal speech separation network called IIANet, which extensively uses intra-attention (IntraA) and inter-attention (InterA) mechanisms within and across the speech and video modalities (a toy sketch of the idea follows this list).

2. Compared with existing CNN and Transformer methods, IIANet achieves significantly better separation quality on three audio-visual speech separation datasets while greatly reducing computational complexity and memory usage.

3. A faster version, IIANet-fast, surpasses CTCNet by 1.1 dB on the challenging LRS2 dataset with only 11% of CTCNet's MACs.

4. Qualitative evaluations on real-world YouTube scenarios show that IIANet generates higher-quality separated speech than other separation models.
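To make the IntraA/InterA terminology concrete, here is a deliberately minimal PyTorch sketch: intra-attention lets one modality's feature sequence attend to itself, while inter-attention lets audio features attend to video features. It only illustrates the general idea under assumed shapes and module names; it is not the actual IIANet implementation (see the paper and this repository's source for that).

```python
import torch
import torch.nn as nn


class ToyIntraInterAttention(nn.Module):
    """Illustrative single-head intra- and inter-modality attention (not IIANet itself)."""

    def __init__(self, dim: int = 64):
        super().__init__()
        # IntraA: the audio sequence attends to itself.
        self.intra = nn.MultiheadAttention(dim, num_heads=1, batch_first=True)
        # InterA: audio queries attend to video keys/values.
        self.inter = nn.MultiheadAttention(dim, num_heads=1, batch_first=True)

    def forward(self, audio: torch.Tensor, video: torch.Tensor) -> torch.Tensor:
        # audio: (batch, T_audio, dim), video: (batch, T_video, dim)
        a_intra, _ = self.intra(audio, audio, audio)
        a_inter, _ = self.inter(a_intra, video, video)
        return audio + a_inter  # residual connection


# Quick shape check with random features.
fused = ToyIntraInterAttention()(torch.randn(2, 100, 64), torch.randn(2, 25, 64))
print(fused.shape)  # torch.Size([2, 100, 64])
```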
## 🚀Overall Pipeline

<video playsinline="" autoplay="" loop="" preload="" muted="" width="900">
<source src="figures/overall.mp4" type="video/mp4">
</video>

## 🪢IIANet Architecture

<video playsinline="" autoplay="" loop="" preload="" muted="" width="900">
<source src="figures/separation.mp4" type="video/mp4">
</video>

## 🔧Installation

1. Clone the repository:

```shell
git clone https://github.com/JusperLee/IIANet.git
cd IIANet/
```

2. Create and activate the conda environment:

```shell
conda create -n iianet python=3.8
conda activate iianet
```

3. Install PyTorch and torchvision following the [official instructions](https://pytorch.org). The code requires `python>=3.8`, `pytorch>=1.11`, `torchvision>=0.13`.

4. Install other dependencies:

```shell
pip install -r requirements.txt
```

## 📊Model Performance

We evaluate IIANet and its fast version, IIANet-fast, on three datasets: LRS2, LRS3, and VoxCeleb2. The results show that IIANet achieves significantly better speech separation quality than existing methods while maintaining high efficiency.

| Method | Dataset | SI-SNRi (dB) | SDRi (dB) | PESQ | Params (M) | MACs (G) | GPU Infer Time | Download |
|:---:|:-----:|:------:|:----:|:----:|:------:|:-----:|:-----------:|:----:|
| IIANet | LRS2 | 16.0 | 16.2 | 3.23 | 3.1 | 18.6 | 110.11 ms | [Config](configs/LRS2-IIANet.yml)/[Model](https://github.com/JusperLee/IIANet/releases/download/v1.0.0/lrs2.zip) |
| IIANet | LRS3 | 18.3 | 18.5 | 3.28 | 3.1 | 18.6 | 110.11 ms | [Config](configs/LRS3-IIANet.yml)/[Model](https://github.com/JusperLee/IIANet/releases/download/v1.0.0/lrs3.zip) |
| IIANet | VoxCeleb2 | 13.6 | 14.3 | 3.12 | 3.1 | 18.6 | 110.11 ms | [Config](configs/Vox2-IIANet.yml)/[Model](https://github.com/JusperLee/IIANet/releases/download/v1.0.0/vox2.zip) |

## 💥Real-world Evaluation

For single video inference, please refer to [`inference.py`](inference.py).

```shell
# Inference on a single video
# You can modify the video path in inference.py
python inference.py
```

## 📚Training

Before starting training, please modify the parameter configurations in [`configs`](configs).

A simple example of a training configuration:

```yaml
data_config:
  train_dir: DataPreProcess/LRS2/tr
  valid_dir: DataPreProcess/LRS2/cv
  test_dir: DataPreProcess/LRS2/tt
  n_src: 1
  sample_rate: 16000
  segment: 2.0
  normalize_audio: false
  batch_size: 3
  num_workers: 24
  pin_memory: true
  persistent_workers: false
```
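As a quick sanity check, such a config can be parsed with PyYAML before launching a run. This is only a minimal sketch: the actual loading logic in `train.py` may differ, and the path simply points at the example config shipped with this repository.

```python
import yaml

# Parse the experiment configuration (path assumes the LRS2 example config).
with open("configs/LRS2-IIANet.yml") as f:
    conf = yaml.safe_load(f)

data_conf = conf["data_config"]
print(data_conf["train_dir"], data_conf["sample_rate"], data_conf["batch_size"])
```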

Use the following commands to start training:

```shell
python train.py --conf_dir configs/LRS2-IIANet.yml
python train.py --conf_dir configs/LRS3-IIANet.yml
python train.py --conf_dir configs/Vox2-IIANet.yml
```

## 📈Testing/Inference

To evaluate a model on one or more GPUs, select the devices with `CUDA_VISIBLE_DEVICES` and pass `test.py` the `conf.yml` saved with the corresponding checkpoint, which specifies the dataset, model, and checkpoint paths:

```shell
# Select GPU(s) with CUDA_VISIBLE_DEVICES, e.g. a single GPU:
CUDA_VISIBLE_DEVICES=0 python test.py --conf_dir checkpoints/lrs2/conf.yml
CUDA_VISIBLE_DEVICES=0 python test.py --conf_dir checkpoints/lrs3/conf.yml
CUDA_VISIBLE_DEVICES=0 python test.py --conf_dir checkpoints/vox2/conf.yml
```

## 💡Future Work

1. Validate the effectiveness and robustness of IIANet on larger-scale datasets such as AVSpeech.
2. Further optimize IIANet's architecture and training strategies to improve speech separation quality while reducing computational cost.
3. Explore applications of IIANet in other multimodal tasks, such as speech enhancement and speaker recognition.

## 📜Citation

If you find our work helpful, please consider citing:

```bibtex
@inproceedings{lee2024iianet,
  title={IIANet: An Intra- and Inter-Modality Attention Network for Audio-Visual Speech Separation},
  author={Kai Li and Runxuan Yang and Fuchun Sun and Xiaolin Hu},
  booktitle={International Conference on Machine Learning},
  year={2024}
}
```