# transcription.py (forked from lifefeel/Vocal-Percussion-to-Drum)
import argparse
import pickle
from pathlib import Path

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio

from model_zoo import RNNModel_onset, RNNModel_velocity


class SpecConverter:
    """Converts a waveform tensor to a log-scaled (dB) mel spectrogram."""

    def __init__(self, sr=44100, n_fft=2048, hop_length=1024, n_mels=128, fmin=0, fmax=None):
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax
        # Build the transforms once instead of re-creating them on every call,
        # and actually pass fmin/fmax through to the mel filterbank.
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sr, n_fft=self.n_fft, hop_length=self.hop_length,
            n_mels=self.n_mels, f_min=self.fmin, f_max=self.fmax)
        self.to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, wav):
        # wav: (..., num_samples) waveform tensor -> log-mel spectrogram in dB.
        return self.to_db(self.mel_spec(wav))
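
# Minimal usage sketch (not part of the original file): converting a clip to a log-mel
# spectrogram with the same settings used in the main block below. 'example.wav' is a
# hypothetical path.
#
#   converter = SpecConverter(sr=44100, n_fft=512, hop_length=128, n_mels=128)
#   wav, sr = torchaudio.load('example.wav')        # wav: (channels, num_samples)
#   log_mel = converter.forward(wav.unsqueeze(0))   # (1, channels, n_mels, num_frames)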


def high_freq_content(spectrogram_dB):
    """High-frequency content (HFC) per frame: frequency-weighted sum of band energies."""
    # Convert the decibel values back to linear power values.
    spectrogram = librosa.db_to_power(spectrogram_dB)
    # Center frequency of each mel band, used as the weighting.
    freqs = librosa.mel_frequencies(n_mels=spectrogram.shape[0])
    # Frequency-weighted sum of the energy in each frame.
    hfc_values = freqs @ spectrogram
    return hfc_values
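
# Toy example (illustration only): two frames with equal total energy, but the second
# frame concentrates its energy in a higher mel band, so its HFC value is larger.
#
#   toy_db = librosa.power_to_db(np.array([[1.0, 0.0],
#                                          [0.0, 1.0]]))  # 2 mel bands x 2 frames
#   hfc = high_freq_content(toy_db)                       # hfc[1] > hfc[0]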


def reducing_time_resolution(mel_spec, aggregate_factor=4, len_quantized=16):
    """Downsamples the time axis by averaging every `aggregate_factor` frames into one step."""
    db_mel_spec_cnvtd = []
    for idx in range(len_quantized):
        # Average a window of `aggregate_factor` consecutive frames into a single column.
        spec_for_agg = mel_spec[:, idx * aggregate_factor:(idx + 1) * aggregate_factor]
        aggregated_spec = torch.mean(spec_for_agg, dim=1, keepdim=True)
        db_mel_spec_cnvtd.append(aggregated_spec)
    db_mel_spec_cnvtd = torch.cat(db_mel_spec_cnvtd, dim=1)
    return db_mel_spec_cnvtd
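
# Toy example (illustration only): averaging every 2 frames of an 8-frame input down
# to 4 quantised steps.
#
#   x = torch.arange(8.0).repeat(3, 1)                       # 3 channels x 8 frames
#   y = reducing_time_resolution(x, aggregate_factor=2, len_quantized=4)
#   # y.shape == (3, 4); y[0] == tensor([0.5, 2.5, 4.5, 6.5])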


def denoise(drum_roll):
    """Trims leading silence: drops all columns before the earliest onset found in any row."""
    first_onset_count = 90  # upper bound on where the first onset can appear
    for row in drum_roll:
        for idx, val in enumerate(row):
            if val > 0:
                if first_onset_count - idx > 0:
                    # Found an onset earlier than the current earliest one.
                    first_onset_count = idx
                break
            elif first_onset_count - idx < 0:
                # Already past the earliest known onset; stop scanning this row.
                break
    return drum_roll[:, first_onset_count:]
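
# Toy example (illustration only): the earliest onset across both rows is at column 1,
# so the leading column of silence is dropped.
#
#   roll = torch.tensor([[0.0, 0.8, 0.0, 0.0],
#                        [0.0, 0.0, 0.5, 0.0]])
#   denoise(roll)   # -> tensor([[0.8, 0.0, 0.0], [0.0, 0.5, 0.0]])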


def Dense_onsets(tensor):
    """Collapses each run of consecutive non-zero values into a single onset.

    The run's average value is written at the index where the run starts;
    all other positions become zero.
    """
    result = np.zeros_like(tensor)
    avg = 0
    count = 0
    first_non_zero_index = None
    for i, value in enumerate(tensor):
        if value != 0:
            if first_non_zero_index is None:
                first_non_zero_index = i
            avg += value
            count += 1
        else:
            if count != 0:
                result[first_non_zero_index] = avg / count
            avg, count, first_non_zero_index = 0, 0, None
    # Handle a run that extends to the end of the array.
    if count != 0:
        result[first_non_zero_index] = avg / count
    return result
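
# Toy example (illustration only): two runs of consecutive non-zero values, each
# collapsed to its mean at the index where the run starts.
#
#   Dense_onsets(np.array([0.0, 0.4, 0.6, 0.0, 0.2, 0.0]))
#   # -> array([0. , 0.5, 0. , 0. , 0.2, 0. ])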


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--wav_path', type=str, default='audio_samples/pmta.wav')
    return parser


def plt_imshow(npimg, title=None, filename=None):
    """Saves `npimg` as a heat-map image under transcribed_sample_results/."""
    plt.figure(figsize=(20, 10))
    plt.imshow(npimg, aspect='auto', origin='lower', interpolation='nearest')
    if title is not None:
        plt.title(title)
    output_path = f'transcribed_sample_results/{title}_{filename}.png'
    plt.savefig(output_path)
    plt.close()  # release the figure so repeated calls do not accumulate open figures
    print('saved:', output_path)


if __name__ == "__main__":
    args = get_argument_parser().parse_args()
    print('input audio:', args.wav_path)

    audio_path = Path(args.wav_path)
    sample_first, sr = torchaudio.load(audio_path)

    device = torch.device('cpu')
    onset_model = RNNModel_onset(num_nmels=128, hidden_size=128).to(device)
    velocity_model = RNNModel_velocity(num_nmels=128, hidden_size=128).to(device)

    # Load the pre-trained onset and velocity models and switch them to eval mode.
    onset_model.load_state_dict(torch.load('models/onset_model_noQZ.pt', map_location=device))
    velocity_model.load_state_dict(torch.load('models/velocity_model_noQZ.pt', map_location=device))
    onset_model.eval()
    velocity_model.eval()

    # Compute the log-mel spectrogram and truncate it to the first 2756 frames.
    spec_converter = SpecConverter(sr=44100, n_fft=512, hop_length=128, n_mels=128)
    mel_spec = spec_converter.forward(sample_first.unsqueeze(0))
    mel_spec = mel_spec[0][0][:, :2756]
    mel_spec = mel_spec.to(device)

    # Predict onsets, binarise them, then predict velocities guided by the binarised onsets.
    threshold = 0.4
    onset_pred = onset_model(mel_spec.unsqueeze(0))
    onset_pred_guide = (onset_pred > threshold).float()  # time x 4
    velocity_pred = velocity_model(mel_spec.unsqueeze(0), onset_pred_guide.unsqueeze(0))
    velocity_pred = velocity_pred * onset_pred_guide.unsqueeze(0)

    # Keep only frames whose high-frequency content is above the 60th percentile.
    hfc_values = high_freq_content(mel_spec.cpu().detach().numpy())
    onset_idx = np.argwhere(hfc_values > np.percentile(hfc_values, 60))
    onset_pred_cleaned = torch.zeros_like(velocity_pred.squeeze())
    for idx in onset_idx.squeeze():
        onset_pred_cleaned[idx] = velocity_pred.squeeze()[idx]
    plt_imshow(onset_pred_cleaned.cpu().detach().numpy().T, title='onset_pred_cleaned', filename=audio_path.stem)

    # Quantise the time axis down to 128 steps and threshold the aggregated values.
    aggregate_factor = onset_pred_cleaned.shape[0] // 128
    db_mel_spec_cnvtd = reducing_time_resolution(onset_pred_cleaned.T, aggregate_factor, 128)  # time axis quantised to 128 steps
    threshold_idx = torch.argwhere(db_mel_spec_cnvtd > 0.3)
    drum_roll_QZ = torch.zeros_like(db_mel_spec_cnvtd)
    for row in threshold_idx:
        drum_roll_QZ[row[0], row[1]] = db_mel_spec_cnvtd[row[0], row[1]]

    # Trim leading silence, collapse consecutive onsets, and save the first 64 steps.
    denoised_drum_roll = denoise(drum_roll_QZ)
    densed_drumroll = np.zeros_like(denoised_drum_roll.cpu().detach().numpy())
    for idx, row in enumerate(denoised_drum_roll.cpu().detach().numpy()):
        densed_drumroll[idx] = Dense_onsets(row)
    plt_imshow(densed_drumroll, title='densed_drumroll', filename=audio_path.stem)

    with open(f'transcribed_sample_results/{audio_path.stem}.pkl', 'wb') as f:
        pickle.dump(densed_drumroll[:, :64], f)
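
# A downstream step could load the saved drum roll back like this (sketch only; the
# path mirrors the dump above for the default 'pmta' sample):
#
#   with open('transcribed_sample_results/pmta.pkl', 'rb') as f:
#       drum_roll = pickle.load(f)   # numpy array: drum channels x up to 64 quantised steps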