-
Notifications
You must be signed in to change notification settings - Fork 77
/
calc_metrics.py
67 lines (57 loc) · 3.08 KB
/
calc_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from os.path import join
from glob import glob
from argparse import ArgumentParser
from soundfile import read
from tqdm import tqdm
from pesq import pesq
import pandas as pd
import librosa
from pystoi import stoi
from sgmse.util.other import energy_ratios, mean_std
# Format string for each reported metric, keyed by DataFrame column. The same
# strings feed both the console summary and _avg_results.txt, so the two
# outputs always agree on precision.
SUMMARY_FORMATS = (
    ("pesq", "PESQ: {:.2f} ± {:.2f}"),
    ("estoi", "ESTOI: {:.2f} ± {:.2f}"),
    ("si_sdr", "SI-SDR: {:.1f} ± {:.1f}"),
    ("si_sir", "SI-SIR: {:.1f} ± {:.1f}"),
    ("si_sar", "SI-SAR: {:.1f} ± {:.1f}"),
)


def clean_filename_for(noisy_filename):
    """Return the clean-reference file name matching a noisy file name.

    Names containing ``'dB'`` (presumably an SNR tag such as ``..._5dB.wav`` --
    confirm against the dataset layout) map to the text before the first
    underscore plus ``.wav``. Any other name is returned unchanged, i.e. the
    clean file is expected to share the noisy file's name.
    """
    if 'dB' in noisy_filename:
        return noisy_filename.split("_")[0] + ".wav"
    return noisy_filename


def main():
    """Score enhanced audio with PESQ, ESTOI and SI-SDR/SI-SIR/SI-SAR.

    Each ``.wav`` file directly inside ``--noisy_dir``, or exactly one
    subdirectory below it, is paired with two other files. Its enhanced version
    is the file at the same relative path in ``--enhanced_dir``. Its clean
    reference is found in ``--clean_dir`` via ``clean_filename_for``.

    Outputs, all inside ``--enhanced_dir``:

    - per-file scores in ``_results.csv``;
    - the mean ± std summary (as returned by ``mean_std``), which is also
      printed, in ``_avg_results.txt``.

    Raises:
        ValueError: if the clean, noisy and enhanced files of a pair do not
            share one sample rate.
    """
    # Local import: the file-level import from os.path only brings in ``join``.
    from os.path import relpath

    parser = ArgumentParser()
    parser.add_argument("--clean_dir", type=str, required=True, help='Directory containing the clean data')
    parser.add_argument("--noisy_dir", type=str, required=True, help='Directory containing the noisy data')
    parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
    args = parser.parse_args()

    data = {"filename": [], "pesq": [], "estoi": [], "si_sdr": [], "si_sir": [], "si_sar": []}

    # Evaluate standard metrics.
    # Without recursive=True, '**' behaves like '*', so the second pattern matches
    # files exactly one subdirectory deep and never overlaps the first.
    noisy_files = sorted(glob(join(args.noisy_dir, '*.wav')))
    noisy_files += sorted(glob(join(args.noisy_dir, '**', '*.wav')))
    if not noisy_files:
        parser.error("no .wav files found in --noisy_dir {}".format(args.noisy_dir))

    for noisy_file in tqdm(noisy_files):
        # relpath instead of string slicing: a trailing separator on --noisy_dir
        # must not cut the first character off the file name.
        filename = relpath(noisy_file, args.noisy_dir)
        x, sr_x = read(join(args.clean_dir, clean_filename_for(filename)))
        y, sr_y = read(noisy_file)
        x_hat, sr_x_hat = read(join(args.enhanced_dir, filename))
        # Explicit check rather than assert, which is stripped under `python -O`.
        if not sr_x == sr_y == sr_x_hat:
            raise ValueError(
                "Sample rate mismatch for {}: clean={}, noisy={}, enhanced={}".format(
                    filename, sr_x, sr_y, sr_x_hat))
        n = y - x  # noise component (noisy minus clean), passed to energy_ratios
        # PESQ is computed in wideband ('wb') mode at 16 kHz, so resample if needed.
        x_hat_16k = librosa.resample(x_hat, orig_sr=sr_x_hat, target_sr=16000) if sr_x_hat != 16000 else x_hat
        x_16k = librosa.resample(x, orig_sr=sr_x, target_sr=16000) if sr_x != 16000 else x
        # Indices 0/1/2 are SI-SDR/SI-SIR/SI-SAR; compute the ratios once per file.
        ratios = energy_ratios(x_hat, x, n)
        data["filename"].append(filename)
        data["pesq"].append(pesq(16000, x_16k, x_hat_16k, 'wb'))
        data["estoi"].append(stoi(x, x_hat, sr_x, extended=True))
        data["si_sdr"].append(ratios[0])
        data["si_sir"].append(ratios[1])
        data["si_sar"].append(ratios[2])

    df = pd.DataFrame(data)

    # Build the summary once so console output and file output cannot diverge.
    summary = [fmt.format(*mean_std(df[column].to_numpy())) for column, fmt in SUMMARY_FORMATS]
    for line in summary:
        print(line)

    # Written as UTF-8 so the '±' character does not depend on the platform's
    # default encoding; the context manager guarantees the file is closed.
    with open(join(args.enhanced_dir, "_avg_results.txt"), "w", encoding="utf-8") as log:
        log.write("".join(line + "\n" for line in summary))

    # Per-file scores.
    df.to_csv(join(args.enhanced_dir, "_results.csv"), index=False)


if __name__ == '__main__':
    main()