-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
79 lines (65 loc) · 2.37 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 24 13:05:28 2019
@author: Hung
"""
#%%
#IMPORT LIBRARIES
from utils import save_features, get_features, read_features, feature_extraction, reconstruct, split
from model import get_model
from data_prepare import stereo_to_mono, compress
import os
from sklearn.model_selection import train_test_split
import soundfile as sf
from keras.models import load_model
from scipy import signal
from keras.callbacks import ModelCheckpoint,EarlyStopping
import datetime
import librosa
import matplotlib.pyplot as plt
import scipy
import numpy as np
#%%
# PREPARE DATA
# Optional one-off preprocessing, kept commented out because it only needs to
# run once. Argument order suggests (source dir, destination dir):
#   stereo_to_mono: 'rawdata' -> 'groundtruth'
#   compress:       'groundtruth' -> 'training_samples'
# NOTE(review): directions inferred from the argument names — confirm.
#stereo_to_mono('rawdata','groundtruth')
#compress('groundtruth','training_samples')
# The compressed input and the ground truth must be the same track for the
# comparisons further down to make sense, so the filename lives in one place.
song_file = 'Triviul-Dorothy.wav'
test_data, test_fs = sf.read('training_samples/' + song_file)
orig_data, orig_fs = sf.read('groundtruth/' + song_file)
#%%
# LOAD MODEL
# Trained SRCNN checkpoint, loaded through Keras' load_model (.h5 file).
model_checkpoint = 'SRCNN_2019-05-03 22_09_28_bestMix.h5'
model = load_model(model_checkpoint)
#%%
# Reconstruct
# Run the loaded model over the compressed test signal and its sample rate.
# NOTE(review): `predict` is not referenced later in this script; the analysis
# cell reads 'output_with_phase.wav' instead, which reconstruct() presumably
# writes as a side effect — confirm.
predict = reconstruct(
    test_data,
    test_fs,
    model,
)
#%%
# SPECTROGRAM ANALYSIS
# NOTE(review): 'output_with_phase.wav' is presumably produced by reconstruct()
# in the previous cell — confirm.
output_data, output_fs = sf.read('output_with_phase.wav')


def _plot_spectrogram(fig_num, data, fs, title):
    """Plot the STFT magnitude of *data* (in dB) in matplotlib figure *fig_num*.

    Returns the ``(freqs, times, spec)`` triple from ``scipy.signal.stft`` so
    the caller can keep inspecting the transform afterwards.
    """
    plt.figure(fig_num)
    freqs, times, spec = scipy.signal.stft(data, fs)
    # The small offset keeps log10 finite on all-zero (silent) bins.
    plt.pcolormesh(times, freqs, 20 * np.log10(np.abs(spec) + 0.0001))
    plt.title(title)
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    return freqs, times, spec


# Plot spectrograms of the original, the compressed test input and the output.
orig_f, orig_t, orig_spec = _plot_spectrogram(
    0, orig_data, orig_fs, 'Spectrogram of original high-quality data')
test_f, test_t, test_spec = _plot_spectrogram(
    1, test_data, test_fs, 'Spectrogram of test data (compressed)')
predict_f, predict_t, predict_spec = _plot_spectrogram(
    2, output_data, output_fs, 'Spectrogram of output')

# Calculate MAE sample-by-sample against the original.
# A sample-wise comparison only lines up when all signals share a sample rate.
if not (orig_fs == test_fs == output_fs):
    print('Warning: sample rates differ (original, test, output):',
          orig_fs, test_fs, output_fs)
# Compare over the common length so a shorter signal cannot cause a shape
# mismatch. When the output is the shortest, this equals len(output_data),
# as before. (Swap np.abs(...) for (...)**2 to get MSE instead.)
n_common = min(len(orig_data), len(test_data), len(output_data))
orig_ref = orig_data[:n_common]
orig_vs_test = np.mean(np.abs(orig_ref - test_data[:n_common]))
orig_vs_output = np.mean(np.abs(orig_ref - output_data[:n_common]))
print('MAE of original vs downsampled files: ',orig_vs_test)
print('MAE of original vs output files: ',orig_vs_output)