forked from benbogin/spider-schema-gnn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
g_util.py
63 lines (56 loc) · 2.04 KB
/
g_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
# from path import Path
import numpy as np
from collections import defaultdict
import inspect
# tensor2DToCsv(tensor,path='/home/yj/Documents/Python/Github/seq2seq/data/gan.txt')
def tensor2DToCsv(tensor, path=None, token=',', write_name=True, name=None):
    """Append a 1-D or 2-D tensor to a text file as token-separated rows.

    Lines are terminated with a bare carriage return (CR) and ``token`` is
    written after every value, including the last value of a row. The file
    is opened in append mode, so repeated calls accumulate.

    Layout written per call:
      * a header line: ``name`` when ``write_name`` is true, otherwise an
        empty line;
      * one line per row for a 2-D tensor, or a single line for a 1-D tensor.

    Args:
        tensor: a torch tensor with 1 or 2 dimensions. It is detached and
            moved to the CPU first, so tensors requiring grad are accepted.
        path: file to append to. Required.
        token: separator written after each value.
        write_name: whether to write ``name`` as the header line.
        name: header label. Defaults to ``'tensor'``, which is exactly what
            the previous frame-inspection lookup always produced.

    Raises:
        ValueError: if ``path`` is None or the tensor is not 1-D or 2-D.
    """
    if path is None:
        raise ValueError("path must not be None")
    if name is None:
        # The former helper searched this function's *own* locals for the
        # argument object, so the only name it could ever find was 'tensor'.
        name = 'tensor'
    # detach() is required: .numpy()/.tolist() paths fail on grad tensors.
    tensor = tensor.detach().cpu()
    ndim = tensor.dim()
    if ndim == 2:
        rows = tensor.tolist()
    elif ndim == 1:
        rows = [tensor.tolist()]
    else:
        # Validate before opening so an unsupported tensor leaves no file behind.
        raise ValueError("Not support %dD tensor." % ndim)
    with open(path, 'a', encoding='utf-8') as f:
        f.write(name if write_name else '\r')
        f.write('\r')
        for row in rows:
            f.write(''.join(str(value) + token for value in row))
            f.write('\r')
# a = torch.tensor([[[1,2,3],[4,5,6]],[[2,2,2],[5,5,5]]])
# tensorToCsv(a,path='/home/yj/Documents/gan.txt')
def tensorToCsv(tensor, path=None, token=','):
    """Append a tensor of up to three dimensions to a text file.

    Tensors with fewer than three dimensions are forwarded to
    ``tensor2DToCsv`` as-is (it only supports 1-D and 2-D input). A 3-D
    tensor is written one 2-D slice at a time along its first dimension,
    each slice passed with ``write_name=False``.

    Args:
        tensor: a torch tensor; it is moved to the CPU and detached first.
        path: file to append to; validated by ``tensor2DToCsv``.
        token: separator written after each value.

    Raises:
        ValueError: if the tensor has more than three dimensions.
    """
    tensor = tensor.cpu().detach()
    # Use the tensor's own rank rather than np.shape(tensor.tolist()), which
    # copies the data and under-reports rank when a dimension is empty.
    ndim = tensor.dim()
    if ndim == 3:
        for matrix in tensor:
            tensor2DToCsv(matrix, path=path, token=token, write_name=False)
    elif ndim < 3:
        tensor2DToCsv(tensor, path=path, token=token)
    else:
        # Raising a plain string is a TypeError in Python 3; use a real exception.
        raise ValueError("Not support %dD tensor." % ndim)
# import torch
# a = torch.tensor([[[1,2,3],[4,5,6]],[[2,2,2],[5,5,5]]])
# tensorToCsv(a,path='/home/yj/Documents/gan22.txt')
# from g_util import tensorToCsv
# tensorToCsv(tensor,path='/home/yj/Documents/gan.txt')