docs: adding PESQ example to gallery (#2755)
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Oct 15, 2024
1 parent a8cbae7 commit 6377aa5
Showing 4 changed files with 130 additions and 5 deletions.
117 changes: 117 additions & 0 deletions examples/audio/pesq.py
@@ -0,0 +1,117 @@
"""
Evaluating Speech Quality with the PESQ Metric
===============================================
This notebook guides you through computing the Perceptual Evaluation of Speech Quality (PESQ) score,
a key metric for assessing how effective noise reduction and enhancement techniques are at improving speech quality.
PESQ is widely adopted in industries such as telecommunications, VoIP, and audio processing, and it provides
an objective way to measure the perceived quality of speech signals from a human listener's perspective.

Imagine being on a noisy street, trying to have a phone call. The technology behind the scenes aims
to clean up your voice and make it sound clearer on the other end. But how do engineers measure that improvement?
This is where PESQ comes in. In this notebook, we simulate a similar scenario: we apply a simple noise reduction
technique and use the PESQ score to evaluate how much the speech quality improves.
"""

# %%
# Import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
from torchmetrics.audio import PerceptualEvaluationSpeechQuality

# %%
# Generate Synthetic Clean and Noisy Audio Signals
# We'll generate a clean sine wave (representing a clean speech signal) and add white noise to simulate the noisy version.


def generate_sine_wave(frequency: float, duration: float, sample_rate: int, amplitude: float = 0.5) -> torch.Tensor:
"""Generate a clean sine wave at a given frequency."""
t = torch.linspace(0, duration, int(sample_rate * duration))
return amplitude * torch.sin(2 * np.pi * frequency * t)


def add_noise(waveform: torch.Tensor, noise_factor: float = 0.05) -> torch.Tensor:
"""Add white noise to a waveform."""
noise = noise_factor * torch.randn(waveform.size())
return waveform + noise


# Parameters for the synthetic audio
sample_rate = 16000 # 16 kHz typical for speech
duration = 3 # 3 seconds of audio
frequency = 440 # A4 note, can represent a simple speech-like tone

# Generate the clean sine wave
clean_waveform = generate_sine_wave(frequency, duration, sample_rate)

# Generate the noisy waveform by adding white noise
noisy_waveform = add_noise(clean_waveform)
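
# %%
# Optional sanity check (an addition to the original example): the ``SignalNoiseRatio`` metric from
# ``torchmetrics.audio`` can quantify how much noise was injected. This is a small illustrative sketch;
# the variable name ``snr`` is ours, and the call follows torchmetrics' (preds, target) convention.
from torchmetrics.audio import SignalNoiseRatio

snr = SignalNoiseRatio()
print(f"SNR of noisy vs. clean audio: {snr(noisy_waveform, clean_waveform).item():.2f} dB")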


# %%
# Apply Basic Noise Reduction Technique
# In this step, we apply a simple spectral gating method for noise reduction using torchaudio's
# ``Spectrogram`` transform and Griffin-Lim reconstruction. This simulates the enhancement of noisy speech.


def reduce_noise(noisy_signal: torch.Tensor, threshold: float = 0.2) -> torch.Tensor:
"""Basic noise reduction using spectral gating."""
# Compute the spectrogram
spec = torchaudio.transforms.Spectrogram()(noisy_signal)

# Apply threshold-based gating: values below the threshold will be zeroed out
spec_denoised = spec * (spec > threshold)

# Convert back to the waveform
return torchaudio.transforms.GriffinLim()(spec_denoised)


# Apply noise reduction to the noisy waveform
enhanced_waveform = reduce_noise(noisy_waveform)
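
# %%
# Note (added as a precaution, not part of the original example): Griffin-Lim reconstruction can return a
# waveform whose length differs slightly from the input, depending on the STFT window and hop settings.
# The ``min_len`` trim below is our own addition; it keeps the PESQ comparison well defined and is a no-op
# when the lengths already match.
min_len = min(clean_waveform.shape[-1], noisy_waveform.shape[-1], enhanced_waveform.shape[-1])
clean_waveform = clean_waveform[..., :min_len]
noisy_waveform = noisy_waveform[..., :min_len]
enhanced_waveform = enhanced_waveform[..., :min_len]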

# %%
# Initialize the PESQ Metric
# PESQ can be computed in two modes: 'wb' (wideband) or 'nb' (narrowband).
# Here, we are using 'wb' mode for wideband speech quality evaluation.
pesq_metric = PerceptualEvaluationSpeechQuality(fs=sample_rate, mode="wb")
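
# %%
# For comparison, a narrowband configuration would target telephone-band audio sampled at 8 kHz.
# The instance below (``nb_pesq_metric`` is an illustrative name of ours) is not used further in this example.
nb_pesq_metric = PerceptualEvaluationSpeechQuality(fs=8000, mode="nb")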

# %%
# Compute PESQ Scores
# We will calculate the PESQ scores for both the noisy and enhanced versions compared to the clean signal.
# The PESQ scores give us a numerical evaluation of how well the enhanced speech
# compares to the clean speech. Higher scores indicate better quality.

# torchmetrics audio metrics follow the (preds, target) convention, so the degraded signal is passed
# first and the clean reference second.
pesq_noisy = pesq_metric(noisy_waveform, clean_waveform)
pesq_enhanced = pesq_metric(enhanced_waveform, clean_waveform)

print(f"PESQ Score for Noisy Audio: {pesq_noisy.item():.4f}")
print(f"PESQ Score for Enhanced Audio: {pesq_enhanced.item():.4f}")

# %%
# Visualize the waveforms
# We can visualize the waveforms of the clean, noisy, and enhanced audio to see the differences.
fig, axs = plt.subplots(3, 1, figsize=(12, 9))

# Plot clean waveform
axs[0].plot(clean_waveform.numpy())
axs[0].set_title("Clean Audio Waveform (Sine Wave)")
axs[0].set_xlabel("Time")
axs[0].set_ylabel("Amplitude")

# Plot noisy waveform
axs[1].plot(noisy_waveform.numpy(), color="orange")
axs[1].set_title(f"Noisy Audio Waveform (PESQ: {pesq_noisy.item():.4f})")
axs[1].set_xlabel("Time")
axs[1].set_ylabel("Amplitude")

# Plot enhanced waveform
axs[2].plot(enhanced_waveform.numpy(), color="green")
axs[2].set_title(f"Enhanced Audio Waveform (PESQ: {pesq_enhanced.item():.4f})")
axs[2].set_xlabel("Time")
axs[2].set_ylabel("Amplitude")

# Adjust layout for better visualization
fig.tight_layout()
plt.show()
12 changes: 7 additions & 5 deletions examples/audio/signal_to_noise_ratio.py
@@ -16,13 +16,10 @@
import torch
from torchmetrics.audio import SignalNoiseRatio

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)


# %%
# Generate a clean signal (simulating a high-quality recording)


def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]:
"""Generate a clean signal (sine wave)"""
t = np.linspace(0, 1, length)
@@ -32,6 +29,8 @@ def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]:

# %%
# Add Gaussian noise to the signal to simulate the noisy environment


def add_noise(signal: np.ndarray, noise_level: float = 0.5) -> np.ndarray:
"""Add Gaussian noise to the signal."""
noise = noise_level * np.random.randn(signal.shape[0])
@@ -40,6 +39,8 @@ def add_noise(signal: np.ndarray, noise_level: float = 0.5) -> np.ndarray:

# %%
# Apply FFT to filter out the noise


def fft_denoise(noisy_signal: np.ndarray, threshold: float) -> np.ndarray:
"""Denoise the signal using FFT."""
freq_domain = np.fft.fft(noisy_signal) # Filter frequencies using FFT
@@ -50,6 +51,7 @@ def fft_denoise(noisy_signal: np.ndarray, threshold: float) -> np.ndarray:

# %%
# Generate and plot clean, noisy, and denoised signals to visualize the reconstruction

length = 1000
t, clean_signal = generate_clean_signal(length)
noisy_signal = add_noise(clean_signal, noise_level=0.5)
5 changes: 5 additions & 0 deletions examples/image/clip_score.py
@@ -19,6 +19,7 @@

# %%
# Get sample images

images = {
"astronaut": astronaut(),
"cat": cat(),
@@ -27,6 +28,7 @@

# %%
# Define hypothetical captions for the images

captions = [
"A photo of an astronaut.",
"A photo of a cat.",
@@ -35,6 +37,7 @@

# %%
# Define the models for CLIPScore

models = [
"openai/clip-vit-base-patch16",
# "openai/clip-vit-base-patch32",
@@ -44,6 +47,7 @@

# %%
# Collect scores for each image-caption pair

score_results = []
for model in models:
clip_score = CLIPScore(model_name_or_path=model)
@@ -54,6 +58,7 @@

# %%
# Create an animation to display the scores

fig, (ax_img, ax_table) = plt.subplots(1, 2, figsize=(10, 5))


1 change: 1 addition & 0 deletions pyproject.toml
@@ -55,6 +55,7 @@ lint.per-file-ignores."docs/source/conf.py" = [
"D103",
]
lint.per-file-ignores."examples/*" = [
"ANN", # any annotaions
"D205", # 1 blank line required between summary line and description
"D212", # [*] Multi-line docstring summary should start at the first line
"D415", # First line should end with a period, question mark, or exclamation point
