-
Notifications
You must be signed in to change notification settings - Fork 34
/
semantic_channel.py
126 lines (88 loc) · 3.56 KB
/
semantic_channel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 3 18:51:36 2020|
@author: wuzongze
"""
import pickle
import numpy as np
import pandas as pd
import argparse
import os
def LoadAMask(opt, max_index=1000):
    """Load and stack the per-batch alignment pickles found in ``opt.align_folder``.

    Files are looked up by bare integer name: ``0``, ``num_per``,
    ``2*num_per``, ... below ``max_index``. Missing files are reported (their
    index is printed) and skipped. Each loaded object must be indexable by
    layer; the number of layers is taken from the first file found, and the
    k-th item of every file is concatenated (``np.concatenate``, axis 0)
    into the k-th output array.

    Args:
        opt: namespace with ``align_folder`` (folder path) and ``num_per``
            (int or int-like string: step between file indices).
        max_index: exclusive upper bound of the scanned file indices
            (default 1000, the value this function previously hard-coded).

    Returns:
        list with one numpy array per layer.

    Raises:
        FileNotFoundError: if no file exists in the scanned index range.
    """
    step = int(opt.num_per)
    all_var_grad = None  # per-layer chunk lists, created once the first file reveals the layer count
    num_layer = 0
    for index in range(0, max_index, step):
        path = os.path.join(opt.align_folder, str(index))
        try:
            with open(path, 'rb') as handle:
                # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
                var_grad = pickle.load(handle)
        except FileNotFoundError:
            print(index)  # best-effort: report the missing index and move on
            continue
        if all_var_grad is None:
            num_layer = len(var_grad)
            all_var_grad = [[] for _ in range(num_layer)]
        for k in range(num_layer):
            all_var_grad[k].append(var_grad[k])

    if all_var_grad is None:
        raise FileNotFoundError(
            'no alignment files found in %r for indices range(0, %d, %d)'
            % (opt.align_folder, max_index, step))

    all_var_grad = [np.concatenate(chunks) for chunks in all_var_grad]
    print('num of sample:', all_var_grad[0].shape[0])
    return all_var_grad
def TopRate(all_var_grad, discount_factor=2):
    """For every channel, measure how often each semantic is its strongest match.

    Args:
        all_var_grad: list with one array per layer (as returned by
            ``LoadAMask``). Each array is indexed
            ``[sample, channel, semantic, stat]``. The per-semantic score is
            ``stat[0] / stat[2] ** discount_factor``; presumably ``stat[0]``
            is a gradient statistic and ``stat[2]`` a region area
            (TODO confirm against the producer of the alignment files).
        discount_factor: exponent applied to ``stat[2]``. A larger value puts
            more weight on precision, i.e. prefers small regions
            (default 2, the value previously hard-coded).

    Returns:
        list with one ``(num_channel, num_semantic)`` float array per layer.
        Entry ``[c, s]`` is the number of samples in which semantic ``s`` had
        the largest absolute score for channel ``c``, divided by the number of
        samples where ``all_var_grad[0][:, 0, s, 2]`` is not NaN (a zero
        count is replaced by 1 to avoid dividing by zero).

    Raises:
        ValueError: from ``np.nanargmax`` if a sample's scores for a channel
            are all NaN.
    """
    num_semantic = all_var_grad[0].shape[2]

    all_count_top = []
    for layer_g in all_var_grad:
        num_channel = layer_g.shape[1]
        count_top = np.zeros([num_channel, num_semantic])
        for cindex in range(num_channel):
            semantic_in = layer_g[:, cindex, :, 0] / (layer_g[:, cindex, :, 2] ** discount_factor)
            # nanargmax ignores NaN scores; one winning semantic index per sample
            semantic_top = np.nanargmax(np.abs(semantic_in), axis=1)
            # minlength keeps one slot per semantic even if some never win
            count_top[cindex] = np.bincount(semantic_top, minlength=num_semantic)
        all_count_top.append(count_top)

    # Per-semantic count of samples with a non-NaN stat[2] (layer 0, channel 0).
    mask_counts = (~np.isnan(all_var_grad[0][:, 0, :, 2])).sum(axis=0)
    mask_counts[mask_counts == 0] = 1  # avoid division by zero
    return [count_top / mask_counts for count_top in all_count_top]
def PadTRGB(opt,all_count_top2):
    """Align per-layer score matrices with the full style-name list.

    ``opt.s_path`` is unpickled into ``(s_names, all_s)``. Walking
    ``s_names`` in order, every name containing ``'ToRGB'`` receives a fresh
    all-zero matrix of shape ``(all_s[i].shape[1], num_semantic)``. Every
    other name receives the next unused entry of ``all_count_top2`` (the
    same object, not a copy).

    Args:
        opt: namespace with ``s_path``, the path of the pickled
            ``(s_names, all_s)`` pair.
        all_count_top2: list of ``(num_channel, num_semantic)`` arrays, one
            per non-ToRGB entry of ``s_names``, in order.

    Returns:
        list with one array per entry of ``s_names``.
    """
    with open(opt.s_path, "rb") as fp:  # unpickle (style names, style arrays)
        s_names, all_s = pickle.load(fp)

    num_sa = all_count_top2[0].shape[1]
    padded = []
    next_layer = 0
    for i, s_name in enumerate(s_names):
        if 'ToRGB' not in s_name:
            padded.append(all_count_top2[next_layer])
            next_layer += 1
            continue
        # ToRGB layers get no semantic scores: pad with zeros
        padded.append(np.zeros([all_s[i].shape[1], num_sa]))
    return padded
#%%
if __name__ == "__main__":
    # Script entry: load alignment pickles, compute per-channel semantic
    # top-match rates, optionally zero-pad ToRGB layers, and pickle the result
    # to <save_folder>/semantic_top_32.
    parser = argparse.ArgumentParser(
        description='compute per-channel semantic top-match rates from alignment masks')
    parser.add_argument('-align_folder', default='./npy/ffhq/align_mask_32', type=str,
                        help='path to align_mask_32 folder')
    parser.add_argument('-s_path', default='./npy/ffhq/S', type=str,
                        help='path to the pickled (style names, style arrays) file used to pad ToRGB layers')
    parser.add_argument('-save_folder', default='./npy/ffhq/', type=str,
                        help='path to save folder')
    parser.add_argument('-num_per', default=4, type=int,
                        help='step between indices of the alignment files to load')
    parser.add_argument('-include_trgb', action='store_true',
                        help='do not insert zero-score entries for ToRGB layers')
    opt = parser.parse_args()
    #%%
    all_var_grad = LoadAMask(opt)
    all_count_top2 = TopRate(all_var_grad)
    if not opt.include_trgb:
        all_count_top2 = PadTRGB(opt, all_count_top2)
    #%%
    os.makedirs(opt.save_folder, exist_ok=True)  # the output folder may not exist yet
    tmp = os.path.join(opt.save_folder, 'semantic_top_32')
    with open(tmp, 'wb') as handle:
        pickle.dump(all_count_top2, handle, protocol=pickle.HIGHEST_PROTOCOL)
#%%