import subprocess
import numpy as np
import argparse
import os
import time
import gc
from collections.abc import Mapping, Container
from sys import getsizeof
import h5py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pytorchtools import EarlyStopping
from sklearn import metrics
from DS1_model_retrieval_multi_task import *
from sklearn.preprocessing import StandardScaler
import csv
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score, precision_score, recall_score
import matplotlib.pyplot as plt
from loss_functions import *


# Recursively estimate the memory footprint of an object (in bytes)
def deep_getsizeof(o, ids):
    d = deep_getsizeof
    if id(o) in ids:
        return 0

    r = getsizeof(o)
    ids.add(id(o))

    if isinstance(o, (str, bytes)):
        return r

    if isinstance(o, Mapping):
        return r + sum(d(k, ids) + d(v, ids) for k, v in o.items())

    if isinstance(o, Container):
        return r + sum(d(x, ids) for x in o)

    return r


# Memory check
def memoryCheck():
    ps = subprocess.Popen(['nvidia-smi', '--query-gpu=memory.used,utilization.gpu', '--format=csv'],
                          stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    print(ps.communicate(), '\n')
    os.system("free -m")


# Free memory
def freeCacheMemory():
    torch.cuda.empty_cache()
    gc.collect()


# Build dataloaders
def myDataloader(videoFeatures, audioFeatures, labels, labels_vid, labels_aud, args, shuffleBool=False):
    class my_dataset(Dataset):
        def __init__(self, videoData, audioData, label, label_vid, label_aud):
            self.videoData = videoData
            self.audioData = audioData
            self.label = label
            self.label_vid = label_vid
            self.label_aud = label_aud

        def __getitem__(self, index):
            return self.videoData[index], self.audioData[index], self.label[index], self.label_vid[index], \
                   self.label_aud[index]

        def __len__(self):
            return len(self.videoData)

    # Build dataloaders
    my_dataloader = DataLoader(dataset=my_dataset(videoFeatures, audioFeatures, labels, labels_vid, labels_aud),
                               batch_size=args.batch_size, shuffle=shuffleBool)
    return my_dataloader


def myDataloader_retrieval(videoFeatures, audioFeatures, emoValues, args, shuffleBool=False):
    class my_dataset(Dataset):
        def __init__(self, videoData, audioData, emo):
            self.videoData = videoData
            self.audioData = audioData
            self.emo = emo

        def __getitem__(self, index):
            return self.videoData[index], self.audioData[index], self.emo[index]

        def __len__(self):
            return len(self.videoData)

    # Build dataloaders
    my_dataloader = DataLoader(dataset=my_dataset(videoFeatures, audioFeatures, emoValues),
                               batch_size=args.batch_size, shuffle=shuffleBool)
    return my_dataloader


# Train
def train_func(train_loader, validate_loader, the_model, optimizer, criter, criter_vid, criter_aud, device, n_epochs,
               patience):
    start_time = time.time()

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, n_epochs + 1):
        # epoch_acc = 0
        # Adjust learning rate
        # adjust_learning_rate(optimizer, epoch)

        #####################
        ## train the model ##
        #####################
        the_model.train()  # prep model for training
        count_batches = 0
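        # The multi-task objective combines three terms (see the batch loop below):
        #   (1) a contrastive loss on dist = 1 - cosine_similarity(video, audio) with the
        #       binary match/mismatch label,
        #   (2) a cross-entropy loss on the video-branch emotion predictions, and
        #   (3) a cross-entropy loss on the audio-branch emotion predictions.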
        for (video_feature, audio_feature, labels, labels_vid, labels_aud) in train_loader:
            video_feature, audio_feature, labels, labels_vid, labels_aud = \
                video_feature.to(device), audio_feature.to(device), labels.to(device), \
                labels_vid.to(device), labels_aud.to(device)
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            sim, out_vid, out_aud = the_model.forward(video_feature, audio_feature)
            dist = 1.0 - sim
            loss = criter(dist, labels.float()) + criter_vid(out_vid, labels_vid) + criter_aud(out_aud, labels_aud)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward(retain_graph=True)
            # perform a single optimization step (parameter update)
            optimizer.step()
            # epoch_acc += acc.item()
            # record training loss
            train_losses.append(loss.item())

            if (count_batches % 100) == 0:
                print('Batch: ', count_batches)
            count_batches += 1

            # Free cache memory
            del video_feature, audio_feature, labels_vid, labels_aud
            freeCacheMemory()

        ######################
        # validate the model #
        ######################
        the_model.eval()  # prep model for evaluation
        # val_epoch_acc = 0
        for (v_video_feature, v_audio_feature, vLabels, vLabels_vid, vLabels_aud) in validate_loader:
            v_video_feature, v_audio_feature, vLabels, vLabels_vid, vLabels_aud = \
                v_video_feature.to(device), v_audio_feature.to(device), vLabels.to(device), \
                vLabels_vid.to(device), vLabels_aud.to(device)

            vsim, vout_vid, vout_aud = the_model(v_video_feature, v_audio_feature)
            vdist = 1.0 - vsim
            # validation loss:
            batch_valid_losses = criter(vdist, vLabels.float()) + criter_vid(vout_vid, vLabels_vid) \
                                 + criter_aud(vout_aud, vLabels_aud)
            valid_losses.append(batch_valid_losses.item())

            del v_video_feature, v_audio_feature, vLabels
            freeCacheMemory()

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        avg_train_losses.append(train_loss)
        valid_loss = np.average(valid_losses)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(n_epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}]' +
                     f' train_loss: {train_loss:.8f} ' +
                     f' valid_loss: {valid_loss:.8f} ')
        print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        # early_stopping needs the loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss.item(), the_model)

        print('Epoch[{}/{}]: Training time: {} seconds '.format(epoch, n_epochs, time.time() - start_time))
        start_time = time.time()

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # load the last checkpoint with the best model
    the_model.load_state_dict(torch.load('checkpoint.pt'))

    return the_model, avg_train_losses, avg_valid_losses
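
# Reference sketch (assumption): ContrastiveLoss used below is imported from loss_functions.py and
# its exact definition is not shown in this file. A typical margin-based contrastive loss over a
# distance d = 1 - cosine_similarity and a binary match label y (1 = matched, 0 = mismatched) would
# look like the unused class below; the real implementation may differ.
class _ReferenceContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, dist, label):
        # Pull matched pairs together (small distance) and push mismatched pairs
        # at least `margin` apart.
        pos = label * torch.pow(dist, 2)
        neg = (1.0 - label) * torch.pow(torch.clamp(self.margin - dist, min=0.0), 2)
        return torch.mean(pos + neg)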

# Load extracted features and label files
def loadingfiles(feature_file, label_file):
    # Load extracted features (.h5) and match/mismatch + emotion labels (.csv)
    print('\n')
    print('Loading h5 files containing extracted features......')
    loading_time = time.time()
    h5file = h5py.File(feature_file, mode='r')
    getKey = list(h5file.keys())[0]
    getData = h5file.get(getKey)
    features = np.asarray(getData)
    features = torch.from_numpy(features)
    h5file.close()

    labelValues = []
    labels_vid = []
    labels_aud = []
    with open(label_file, 'r') as csvfile:
        csvReader = csv.reader(csvfile)
        for row in csvReader:
            labelValues.append(int(row[2]))
            labels_vid.append(int(row[5]))
            labels_aud.append(int(row[6]))

    labelValues = np.asarray(labelValues)
    labelValues = torch.from_numpy(labelValues)
    labels_vid = np.asarray(labels_vid)
    labels_vid = torch.from_numpy(labels_vid)
    labels_aud = np.asarray(labels_aud)
    labels_aud = torch.from_numpy(labels_aud)

    return features, labelValues, labels_vid, labels_aud


def loadingfiles_retrieval(feature_file, csv_filename):
    print('Loading h5 files containing extracted features......')
    loading_time = time.time()
    h5file = h5py.File(feature_file, mode='r')
    getKey = list(h5file.keys())[0]
    getData = h5file.get(getKey)
    features = np.asarray(getData)
    features = torch.from_numpy(features)
    h5file.close()
    print('Time for loading extracted features: ', time.time() - loading_time)

    all_filenames = []
    with open(csv_filename, 'r') as csvfile:
        csvReader = csv.reader(csvfile)
        for row in csvReader:
            all_filenames.append(row[2])  # column 1: filename, column 2: emotion label (used as retrieval label)

    return features, all_filenames


def get_audio_video_views(validate_loader, model_path, device):
    model = embedding_network().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    video_view = []
    audio_view = []
    with torch.no_grad():
        for (video_features, audio_features, _) in validate_loader:
            video_emb = model.video_projection(model.video_br(video_features))
            audio_emb = model.audio_projection(model.audio_br(audio_features))
            video_view.append(video_emb)
            audio_view.append(audio_emb)

    return torch.cat(video_view), torch.cat(audio_view)


def compute_similarity(query, data):
    cosine_sim = nn.CosineSimilarity()
    similarity = cosine_sim(query.unsqueeze(0), data)
    return similarity


# Average precision for one query: mean of precision@k over the ranks k at which
# a relevant item (same emotion label as the query) is retrieved
def AvgP(y_pred_rank, queries_label, label):
    score = 0.0
    count = 0.0
    lab_rel = list(queries_label).count(label)
    for i, p in enumerate(y_pred_rank):
        if int(p) == int(label):
            count += 1
            score += count / (i + 1.0)
    if lab_rel == 0:
        avgP = 0.0
    else:
        avgP = (score * 1.0) / lab_rel
    return avgP
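
# Worked example for AvgP (illustration only, not executed): with ranked labels [1, 0, 1, 1],
# query label 1, and three relevant items in queries_label, the relevant hits occur at ranks
# 1, 3 and 4, so AP = (1/1 + 2/3 + 3/4) / 3 ≈ 0.806.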

# Precision and recall at each cut-off rank in RelN for one query
def prec_rec(RelN, y_pred, queries_label, label):
    prec_list = []
    rec_list = []
    for _id in RelN:
        count = 0.0
        lab_retrieved = _id
        lab_rel = list(queries_label).count(label)
        for i, p in enumerate(y_pred[:_id]):
            if int(p) == int(label):
                count += 1
        if lab_retrieved != 0:
            prec = (count * 1.0) / lab_retrieved
        else:
            prec = 0.0
        if lab_rel != 0:
            rec = (count * 1.0) / lab_rel
        else:
            rec = 0.0
        prec_list.append(prec)
        rec_list.append(rec)
    return np.asarray(prec_list), np.asarray(rec_list)


def metric(queries, view, queries_label):
    Ap_sum_view = 0.0
    query_num = queries.shape[0]
    prec_all, rec_all = [], []
    acc_count_1 = 0
    acc_count_3 = 0
    acc_count_5 = 0
    acc_count_10 = 0
    pre_final, rec_final = [], []

    for _idx in range(query_num):
        label = queries_label[_idx]
        sim_vector = compute_similarity(queries[_idx], view)
        rank_view_index = torch.argsort(sim_vector, dim=-1, descending=True)
        pred_view_label = [queries_label[index] for index in rank_view_index]

        prec_list, rec_list = prec_rec(RelN, pred_view_label, queries_label, label)
        AP_view = AvgP(pred_view_label, queries_label, label)
        Ap_sum_view += AP_view
        prec_all.append(prec_list)
        rec_all.append(rec_list)

        # TopK
        values_top1, indices_top1 = torch.topk(sim_vector, 1)  # top1
        recommend_label_top1 = [queries_label[idx_top1] for idx_top1 in indices_top1]
        values_top3, indices_top3 = torch.topk(sim_vector, 3)  # top3
        recommend_label_top3 = [queries_label[idx_top3] for idx_top3 in indices_top3]
        values_top5, indices_top5 = torch.topk(sim_vector, 5)  # top5
        recommend_label_top5 = [queries_label[idx_top5] for idx_top5 in indices_top5]
        values_top10, indices_top10 = torch.topk(sim_vector, 10)  # top10
        recommend_label_top10 = [queries_label[idx_top10] for idx_top10 in indices_top10]

        if label in recommend_label_top1:
            acc_count_1 += 1
        if label in recommend_label_top3:
            acc_count_3 += 1
        if label in recommend_label_top5:
            acc_count_5 += 1
        if label in recommend_label_top10:
            acc_count_10 += 1

    mAp_view = float("{:.5f}".format((Ap_sum_view * 1.0) / query_num))
    print("mAP_view={}".format(mAp_view))

    prec_all, rec_all = np.array(prec_all), np.array(rec_all)
    [pre_final.append(np.mean(prec_all[:, i])) for i in range(prec_all.shape[1])]
    [rec_final.append(np.mean(rec_all[:, i])) for i in range(rec_all.shape[1])]

    print("accuracy_top1 (%): ", acc_count_1 / len(queries_label) * 100)
    print("accuracy_top3 (%): ", acc_count_3 / len(queries_label) * 100)
    print("accuracy_top5 (%): ", acc_count_5 / len(queries_label) * 100)
    # print("accuracy_top10 (%): ", acc_count_10 / len(queries_label) * 100)

    return mAp_view, pre_final, rec_final


def plot_precsion_recall_retrieval_task(rec, prec, query2retrieval=""):
    # Read this: https://ils.unc.edu/courses/2013_spring/inls509_001/lectures/10-EvaluationMetrics.pdf
    # => Note: When plotting a PR curve, we use the best precision for a level of recall or greater!
    rec_np = np.array(rec)
    prec_np = np.array(prec)
    indices = np.argsort(-prec_np)  # decreasing order
    new_prec_np = [prec_np[id] for id in indices]
    new_rec_np = [rec_np[id] for id in indices]

    init = new_rec_np[0]
    get_rec = []
    get_prec = []
    for i in range(1, len(new_rec_np)):
        if new_rec_np[i] > init:
            get_rec.append(new_rec_np[i])
            get_prec.append(new_prec_np[i])
            init = new_rec_np[i]

    plt.plot(get_rec, get_prec, label=query2retrieval + " audio-visual retrieval")
    # plt.plot(rec_cca_va, prec_cca_va, label="CCA")
    plt.title('')
    plt.xlabel('Recall', fontsize=14)
    plt.ylabel('Precision', fontsize=14)
    leg = plt.legend(bbox_to_anchor=(0.65, 1), ncol=1, mode=None, shadow=True, fancybox=True)
    leg.get_frame().set_alpha(0.65)
    plt.grid(True)
    # plt.savefig("./DS1_retrieval.png")
    plt.show()


# Main
def main(args):
    # Device configuration
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    # Manual seed
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    print('Device: ', device)
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Data
    # for model training
    video_train_features, train_labelValues, train_labelValues_vid, train_labelValues_aud = loadingfiles(
        video_feature_file_train, label_file_train)
    audio_train_features, _, _, _ = loadingfiles(audio_feature_file_train, label_file_train)
    memoryCheck()

    video_val_features, val_labelValues, val_labelValues_vid, val_labelValues_aud = loadingfiles(
        video_feature_file_validate, label_file_validate)
    audio_val_features, _, _, _ = loadingfiles(audio_feature_file_validate, label_file_validate)

    # standardize:
    scaler_1 = StandardScaler()
    video_train_features = torch.from_numpy(scaler_1.fit_transform(video_train_features)).float()
    video_val_features = torch.from_numpy(scaler_1.transform(video_val_features)).float()
    scaler_2 = StandardScaler()
    audio_train_features = torch.from_numpy(scaler_2.fit_transform(audio_train_features)).float()
    audio_val_features = torch.from_numpy(scaler_2.transform(audio_val_features)).float()
    memoryCheck()

    train_dataset = myDataloader(video_train_features, audio_train_features, train_labelValues,
                                 train_labelValues_vid, train_labelValues_aud, args, True)
    validate_dataset = myDataloader(video_val_features, audio_val_features, val_labelValues,
                                    val_labelValues_vid, val_labelValues_aud, args, False)
    memoryCheck()

    # ------------------------------------------------------------------------------------------------
    # input_size for the model
    video_dim = video_train_features.shape[1]
    audio_dim = audio_train_features.shape[1]

    m_start_time = time.time()
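    # embedding_network is defined in DS1_model_retrieval_multi_task.py. From its usage in this
    # script, it is assumed to take the video/audio feature dimensions at construction, to return
    # (cosine similarity, video emotion logits, audio emotion logits) from forward(), and to expose
    # video_br/audio_br branches plus video_projection/audio_projection heads that are reused in
    # get_audio_video_views() to extract retrieval embeddings.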
    # Build the model
    model = embedding_network(video_dim, audio_dim).to(device)
    model = model.to(device)
    memoryCheck()

    # Loss and optimizer
    # Contrastive loss for matching + cross-entropy losses for the video/audio emotion branches
    criterion = ContrastiveLoss()
    criterion_video = nn.CrossEntropyLoss()
    criterion_audio = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.wd)

    model, train_losses, valid_losses = train_func(train_dataset, validate_dataset, model, optimizer, criterion,
                                                   criterion_video, criterion_audio, device, args.num_epochs,
                                                   args.patience)
    print('Training time: ', time.time() - m_start_time)

    # save model
    saved_model = "_DS1_model_retrieval_multi_task.pth"
    torch.save(model.state_dict(), os.path.join(args.model_path, saved_model))
    print("Saved the best model!")

    # Find matches
    print("FOR TESTING: ")
    retrieval_video_test_features, retrieval_test_labelValues = loadingfiles_retrieval(video_feature_retrieval_test,
                                                                                       label_retrieval_test)
    retrieval_audio_test_features, _ = loadingfiles_retrieval(audio_feature_retrieval_test, label_retrieval_test)

    retrieval_video_test_features = torch.from_numpy(scaler_1.transform(retrieval_video_test_features)).float()
    retrieval_audio_test_features = torch.from_numpy(scaler_2.transform(retrieval_audio_test_features)).float()

    retrieval_test_dataset = myDataloader_retrieval(retrieval_video_test_features, retrieval_audio_test_features,
                                                    retrieval_test_labelValues, args, False)

    vid_view, aud_view = get_audio_video_views(retrieval_test_dataset, os.path.join(args.model_path, saved_model),
                                               device)

    print("Video query => Retrieve music: ")
    vid2aud_mAP, vid2aud_precision_list, vid2aud_recall_list = metric(vid_view, aud_view, retrieval_test_labelValues)
    print("Music query => Retrieve videos: ")
    aud2vid_mAP, aud2vid_precision_list, aud2vid_recall_list = metric(aud_view, vid_view, retrieval_test_labelValues)

    plot_precsion_recall_retrieval_task(vid2aud_recall_list, vid2aud_precision_list,
                                        query2retrieval="Video to music:")
    plot_precsion_recall_retrieval_task(aud2vid_recall_list, aud2vid_precision_list,
                                        query2retrieval="Music to video:")

    # os.remove('./checkpoint.pt')


if __name__ == "__main__":
    dir_path = "./DS1_EmoMV_A"  # path to extracted features
    model_path = os.path.join(dir_path, 'models')  # path to save models => remember to create this folder
    pred_path = os.path.join(dir_path, 'predicted_values')  # path to save predicted values => remember to create this folder

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default=model_path, help='path for saving trained models')
    parser.add_argument('--num_epochs', type=int, default=1000)
    parser.add_argument('--patience', type=int, default=20,
                        help='early stopping patience; how long to wait after last time validation loss improved')
    parser.add_argument('--batch_size', type=int, default=256, help='number of feature vectors loaded per batch')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='initial learning rate')
    parser.add_argument('--wd', type=float, default=0.1, help='weight decay')
    # parser.add_argument('--mm', type=float, default=0.9, help='momentum')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    args = parser.parse_args()
    print(args)
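    # Expected folder layout under ./DS1_EmoMV_A (inferred from the paths below):
    #   annotation/           CSV label files (match/mismatch + per-modality emotion labels)
    #   extracted_features/   SlowFast (video) and VGGish (audio) .h5 feature files
    #   for_retrieval/        test-set CSV and .h5 files for the retrieval evaluation
    #   models/               created manually; trained models are saved here
    #   predicted_values/     created manually; reserved for predicted values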
    # for training
    label_file_train = os.path.join(dir_path + "/" + "annotation", "DS1_TRAIN_MATCH_MISMATCH_labels.csv")
    video_feature_file_train = os.path.join(dir_path + "/" + "extracted_features",
                                            "SlowFast_DS1_TRAIN_MATCH_MISMATCH.h5")
    audio_feature_file_train = os.path.join(dir_path + "/" + "extracted_features",
                                            "VGGish_DS1_TRAIN_MATCH_MISMATCH.h5")

    # for validation
    label_file_validate = os.path.join(dir_path + "/" + "annotation", "DS1_VAL_MATCH_MISMATCH_labels.csv")
    video_feature_file_validate = os.path.join(dir_path + "/" + "extracted_features",
                                               "SlowFast_DS1_VAL_MATCH_MISMATCH.h5")
    audio_feature_file_validate = os.path.join(dir_path + "/" + "extracted_features",
                                               "VGGish_DS1_VAL_MATCH_MISMATCH.h5")

    # for testing
    dir_path_retrieval = "./DS1_EmoMV_A"
    label_retrieval_test = os.path.join(dir_path_retrieval + "/" + "for_retrieval",
                                        "DS1_from_MVED_test_set_5_classes_for_retrieval.csv")
    video_feature_retrieval_test = os.path.join(dir_path_retrieval + "/" + "for_retrieval",
                                                "SlowFast_DS1_from_MVED_test_set_5_classes_for_retrieval.h5")
    audio_feature_retrieval_test = os.path.join(dir_path_retrieval + "/" + "for_retrieval",
                                                "VGGish_DS1_from_MVED_test_set_5_classes_for_retrieval.h5")

    # -------------------------------------------------------------------------------------------------------------------
    main_start_time = time.time()
    lbl_range = ["Mismatched", "Matched"]  # mismatched: 0, matched: 1
    RelN = [i for i in range(1, 250)]  # cut-off ranks; for EmoMV-B, EmoMV-C, remember to update this number
    main(args)
    print('Total running time: {:.5f} seconds'.format(time.time() - main_start_time))
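
# Example invocation (hypothetical; replace this_script.py with this file's actual name):
#   python this_script.py --batch_size 256 --lr 0.0001 --wd 0.1 --patience 20 --num_epochs 1000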