-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(preprocessing): add 2 more preprocessing commands (#123)
- Loading branch information
Showing
4 changed files
with
279 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from collections import defaultdict | ||
from logging import getLogger | ||
from pathlib import Path | ||
|
||
import soundfile as sf | ||
import torch | ||
from joblib import Parallel, delayed | ||
from pyannote.audio import Pipeline | ||
from tqdm import tqdm | ||
from tqdm_joblib import tqdm_joblib | ||
|
||
LOG = getLogger(__name__) | ||
|
||
|
||
def _process_one(
    input_path: Path,
    output_dir: Path,
    *,
    min_speakers: int = 1,
    max_speakers: int = 1,
    huggingface_token: str | None = None,
) -> None:
    """Diarize one audio file and write each speaker turn as a separate WAV.

    The file is read with soundfile; if reading fails, a warning is logged
    and the file is skipped. The ``pyannote/speaker-diarization`` pipeline is
    then loaded and run on the file path. Every resulting segment of at least
    one second is cut out of the decoded audio and written to
    ``output_dir / "{speaker}_{n}.wav"``, where ``n`` is a per-speaker running
    counter starting at 1.

    Args:
        input_path: Audio file to diarize.
        output_dir: Directory receiving the clips (created if missing).
        min_speakers: Lower bound on speakers, forwarded to the pipeline.
        max_speakers: Upper bound on speakers, forwarded to the pipeline.
        huggingface_token: Passed as ``use_auth_token`` when loading the
            pretrained pipeline.

    Raises:
        ValueError: If ``Pipeline.from_pretrained`` returns ``None``.
    """
    try:
        samples, sample_rate = sf.read(input_path)
    except Exception as e:
        # Best effort: an unreadable file is reported and skipped.
        LOG.warning(f"Failed to read {input_path}: {e}")
        return

    diarizer = Pipeline.from_pretrained(
        "pyannote/speaker-diarization", use_auth_token=huggingface_token
    )
    if diarizer is None:
        raise ValueError("Failed to load pipeline")

    LOG.info(f"Processing {input_path}. This may take a while...")
    diarization = diarizer(
        input_path, min_speakers=min_speakers, max_speakers=max_speakers
    )
    LOG.info(f"Found {len(diarization)} tracks, writing to {output_dir}")

    output_dir.mkdir(parents=True, exist_ok=True)
    speaker_count = defaultdict(int)
    # Materialized up front so the progress bar knows its total.
    turns = list(diarization.itertracks(yield_label=True))
    for segment, _track, speaker in tqdm(turns, desc=f"Writing {input_path}"):
        duration = segment.end - segment.start
        if duration < 1:
            # Segments under one second are not written.
            continue
        speaker_count[speaker] += 1
        clip_path = output_dir / f"{speaker}_{speaker_count[speaker]}.wav"
        # Segment bounds are in seconds; scale by the sample rate to index.
        first_sample = int(segment.start * sample_rate)
        last_sample = int(segment.end * sample_rate)
        sf.write(clip_path, samples[first_sample:last_sample], sample_rate)

    LOG.info(f"Speaker count: {speaker_count}")
|
||
|
||
def preprocess_speaker_diarization(
    input_dir: Path | str,
    output_dir: Path | str,
    *,
    min_speakers: int = 1,
    max_speakers: int = 1,
    huggingface_token: str | None = None,
    n_jobs: int = -1,
) -> None:
    """Run speaker diarization on every file under ``input_dir`` in parallel.

    Each input file ``<input_dir>/<rel>/<stem>.<ext>`` is handed to
    ``_process_one`` with the output directory ``<output_dir>/<rel>/<stem>``.
    Both directories are created if they do not exist.

    Args:
        input_dir: Root directory searched recursively for ``*.*`` files.
        output_dir: Root directory under which per-file folders are written.
        min_speakers: Forwarded to ``_process_one``.
        max_speakers: Forwarded to ``_process_one``.
        huggingface_token: Forwarded to ``_process_one``; a warning is logged
            if it is given but does not start with ``hf_``.
        n_jobs: Number of joblib workers (``-1`` is passed through to joblib).
    """
    if huggingface_token is not None and not huggingface_token.startswith("hf_"):
        LOG.warning("Huggingface token probably should start with hf_")
    if not torch.cuda.is_available():
        LOG.warning("CUDA is not available. This will be extremely slow.")
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    input_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)
    # "*.*" also matches directories whose names contain a dot; keep only
    # regular files so they are not dispatched as jobs, counted in the
    # progress total, or reported as unreadable audio.
    input_paths = [path for path in input_dir.rglob("*.*") if path.is_file()]
    with tqdm_joblib(desc="Preprocessing speaker diarization", total=len(input_paths)):
        Parallel(n_jobs=n_jobs)(
            delayed(_process_one)(
                input_path,
                # Mirror the input tree, one folder per source file.
                output_dir / input_path.relative_to(input_dir).parent / input_path.stem,
                max_speakers=max_speakers,
                min_speakers=min_speakers,
                huggingface_token=huggingface_token,
            )
            for input_path in input_paths
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from logging import getLogger | ||
from pathlib import Path | ||
|
||
import librosa | ||
import soundfile as sf | ||
from joblib import Parallel, delayed | ||
from tqdm import tqdm | ||
from tqdm_joblib import tqdm_joblib | ||
|
||
LOG = getLogger(__name__) | ||
|
||
|
||
def _process_one(
    input_path: Path,
    output_dir: Path,
    *,
    top_db: int = 30,
    frame_seconds: float = 0.5,
    hop_seconds: float = 0.1,
) -> None:
    """Split one audio file on silence and write each non-silent span as WAV.

    The file is loaded with librosa at its native sample rate; if loading
    fails, a warning is logged and the file is skipped. Non-silent intervals
    are found with ``librosa.effects.split`` and each one is written to
    ``output_dir / "{stem}_{start:.3f}_{end:.3f}.wav"``, where start and end
    are the interval bounds in seconds.

    Args:
        input_path: Audio file to split.
        output_dir: Directory receiving the clips (created if missing).
        top_db: Forwarded to ``librosa.effects.split`` as ``top_db``.
        frame_seconds: Analysis frame length in seconds; converted to samples.
        hop_seconds: Hop between frames in seconds; converted to samples.
    """
    try:
        # sr=None keeps the file's native sample rate. Without it librosa
        # resamples everything to its 22050 Hz default, so the clips would be
        # written at a degraded rate. Mono downmixing (the default) is kept
        # because the slicing below assumes 1-D audio.
        audio, sr = librosa.load(input_path, sr=None)
    except Exception as e:
        # Best effort: an unreadable file is reported and skipped.
        LOG.warning(f"Failed to read {input_path}: {e}")
        return
    intervals = librosa.effects.split(
        audio,
        top_db=top_db,
        frame_length=int(sr * frame_seconds),
        hop_length=int(sr * hop_seconds),
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    for start, end in tqdm(intervals, desc=f"Writing {input_path}"):
        # start/end are sample indices; the file name carries them in seconds.
        audio_cut = audio[start:end]
        sf.write(
            (output_dir / f"{input_path.stem}_{start / sr:.3f}_{end / sr:.3f}.wav"),
            audio_cut,
            sr,
        )
|
||
|
||
def preprocess_split(
    input_dir: Path | str,
    output_dir: Path | str,
    *,
    top_db: int = 30,
    frame_seconds: float = 0.5,
    hop_seconds: float = 0.1,
    n_jobs: int = -1,
) -> None:
    """Split every file under ``input_dir`` on silence, in parallel.

    Each input file ``<input_dir>/<rel>/<name>`` is handed to ``_process_one``
    with the output directory ``<output_dir>/<rel>``, so the input tree is
    mirrored under ``output_dir``.

    Args:
        input_dir: Root directory searched recursively for ``*.*`` files.
        output_dir: Root directory for the clips (created if missing).
        top_db: Forwarded to ``_process_one``.
        frame_seconds: Forwarded to ``_process_one``.
        hop_seconds: Forwarded to ``_process_one``.
        n_jobs: Number of joblib workers (``-1`` is passed through to joblib).
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # "*.*" also matches directories whose names contain a dot; keep only
    # regular files so they are not dispatched as jobs, counted in the
    # progress total, or reported as unreadable audio.
    input_paths = [path for path in input_dir.rglob("*.*") if path.is_file()]
    with tqdm_joblib(desc="Splitting", total=len(input_paths)):
        Parallel(n_jobs=n_jobs)(
            delayed(_process_one)(
                input_path,
                output_dir / input_path.relative_to(input_dir).parent,
                top_db=top_db,
                frame_seconds=frame_seconds,
                hop_seconds=hop_seconds,
            )
            for input_path in input_paths
        )