-
Notifications
You must be signed in to change notification settings - Fork 0
/
dict_to_csv.py
29 lines (26 loc) · 1.15 KB
/
dict_to_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import csv
import os
import argparse
# Command-line interface: where to find the training run's saved statistics
# and which TSV file the summary row is appended to.
parser = argparse.ArgumentParser(
    description="Summarize a training run's saved statistics as one row "
                "of a TSV file")
parser.add_argument('--save-dir', dest='save_dir',
                    help='Directory of the training run; statistics are '
                         'read from <save-dir>/excel_data/dict',
                    default='save_temp', type=str)
parser.add_argument('--output-file', dest='output_file',
                    help='TSV file the summary row is appended to '
                         '(a header row is written when the file is missing)',
                    default='output.tsv', type=str)
args = parser.parse_args()
def average(input):
    """Return the arithmetic mean of the values in *input*.

    *input* may be any finite iterable (list, tuple, generator, ...) whose
    items support ``+`` and division by an int.

    Raises:
        ValueError: if *input* contains no values.
    """
    # Parameter name kept for backward compatibility, although it shadows
    # the ``input`` builtin. Materialize first so iterables without len()
    # (e.g. generators) are supported.
    values = list(input)
    if not values:
        raise ValueError("average() requires at least one value")
    return sum(values) / len(values)
# Load the statistics dict saved by the training run.
# NOTE(review): torch.load is pickle-based; only point --save-dir at
# directories whose contents you trust.
dict_data = torch.load(os.path.join(args.save_dir, "excel_data", "dict"))
# TSV column order follows the dict's key order. The view stays valid because
# only values (not keys) are replaced below.
fields = dict_data.keys()
# Collapse these two entries into their mean so each fits in one TSV cell.
# Presumably they hold per-epoch/per-round sequences — confirm against the
# code that writes this dict.
dict_data.update({
    metric: average(dict_data[metric])
    for metric in ("avg test acc", "data transferred")
})
print(dict_data)
# Append one summary row to the TSV. The header is written only when the file
# is new or empty, so repeated runs accumulate rows under a single header.
# NOTE(review): if an existing file was created with a different key set,
# this row is appended under a mismatched header — consider validating it.
needs_header = (not os.path.isfile(args.output_file)
                or os.path.getsize(args.output_file) == 0)
# newline='' lets the csv module control line endings itself; without it,
# Windows output gets an extra blank line after every row.
with open(args.output_file, 'a', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fields, delimiter='\t')
    if needs_header:
        writer.writeheader()
    writer.writerow(dict_data)