-
Notifications
You must be signed in to change notification settings - Fork 0
/
partition_data.py
129 lines (115 loc) · 5.08 KB
/
partition_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Data-partitioning utilities for federated-learning experiments:
# split a training set across n_parties clients, either by fixed
# classes-per-party (partition_class) or a Dirichlet draw (partition_dirichlet).
# NOTE(review): the star import is assumed to provide the load_*_data helpers
# and numpy as `np` — confirm against load_data.py.
from load_data import *
import logging
import random
# Configure the root logger once at import time; per-party class counts are
# reported at INFO level by record_net_data_stats.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
def record_net_data_stats(y_train, net_dataidx_map):
    """Count how many samples of each class every party holds.

    Parameters
    ----------
    y_train : np.ndarray
        1-D array of integer class labels for the full training set.
    net_dataidx_map : dict[int, array-like]
        Maps party id -> indices into ``y_train`` owned by that party.

    Returns
    -------
    dict[int, dict]
        ``{party_id: {class_label: count, ...}, ...}``.

    Side effects: prints the mean/std of per-party sample counts and logs the
    full per-class breakdown at INFO level.
    """
    net_cls_counts = {}
    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True)
        net_cls_counts[net_i] = dict(zip(unq, unq_cnt))
    # Total samples per party, for a quick balance check.
    data_list = [sum(data.values()) for data in net_cls_counts.values()]
    print('mean:', np.mean(data_list))
    print('std:', np.std(data_list))
    # Root logger — same logger this module configures at import time.
    # Lazy %-args: the str() conversion only happens if INFO is enabled.
    logging.getLogger().info('Data statistics: %s', str(net_cls_counts))
    return net_cls_counts
def partition_class(dataset, n_parties, n_class, n_label=3):
    """Partition a dataset so each party holds samples of exactly n_label classes.

    Parameters
    ----------
    dataset : str
        One of "cifar10", "cifar10l", "cifar100", "tiny-imagenet",
        "mnist", "fmnist", "SVHN".
    n_parties : int
        Number of federated clients to split the training data across.
    n_class : int
        Total number of classes in the dataset.
    n_label : int
        Number of distinct classes assigned to each party.

    Returns
    -------
    tuple
        (X_train, y_train, X_test, y_test, net_dataidx_map,
        traindata_cls_counts) where net_dataidx_map maps party id to a
        np.int64 index array into the training set.
    """
    if n_class == n_label:
        # Every party sees all classes: fall back to a near-IID split
        # (a huge beta makes the Dirichlet draw almost uniform).
        (X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts) = partition_dirichlet(dataset, n_parties, beta=100000)
    else:
        if dataset == "cifar10":
            X_train, y_train, X_test, y_test = load_cifar10_data()
        elif dataset == "cifar10l":
            # "cifar10l" shares the raw CIFAR-10 images.
            X_train, y_train, X_test, y_test = load_cifar10_data()
        elif dataset == 'cifar100':
            X_train, y_train, X_test, y_test = load_cifar100_data()
        elif dataset == "tiny-imagenet":
            X_train, y_train, X_test, y_test = load_tiny_imagenet_data()
        elif dataset == "mnist":
            X_train, y_train, X_test, y_test = load_mnist_data()
        elif dataset == "fmnist":
            X_train, y_train, X_test, y_test = load_fmnist_data()
        elif dataset == "SVHN":
            X_train, y_train, X_test, y_test = load_svhn_data()
        # times[k]: how many parties were assigned class k.
        times = [0 for _ in range(n_class)]
        # contain[i]: the n_label distinct class ids given to party i.
        contain = []
        count = 0
        for i in range(n_parties):
            contain.append([])
            for _ in range(n_label):
                if count < n_class:
                    # First hand out every class once, in order, so no class
                    # is left unowned (assuming n_parties * n_label >= n_class).
                    label_id = count
                    count += 1
                else:
                    # Then draw random classes, avoiding duplicates per party.
                    while True:
                        label_id = random.randint(0, n_class - 1)
                        if label_id not in contain[i]:
                            break
                times[label_id] += 1
                contain[i].append(label_id)
        net_dataidx_map = {i: np.ndarray(0, dtype=np.int64) for i in range(n_parties)}
        # BUG FIX: the original ran this split-and-assign loop twice
        # (verbatim copy-paste), so every party received each of its classes'
        # index pool twice — two independently shuffled copies of the same
        # indices. Run the assignment exactly once.
        for i in range(n_class):
            idx_k = np.where(y_train == i)[0]
            np.random.shuffle(idx_k)
            # Split class i's indices into times[i] nearly-equal chunks, one
            # per owning party. NOTE(review): times[i] == 0 (possible when
            # n_parties * n_label < n_class) would make array_split raise.
            split = np.array_split(idx_k, times[i])
            ids = 0
            for j in range(n_parties):
                if i in contain[j]:
                    net_dataidx_map[j] = np.append(net_dataidx_map[j], split[ids])
                    ids += 1
        traindata_cls_counts = record_net_data_stats(y_train, net_dataidx_map)
    return (X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts)
def partition_dirichlet(dataset, n_parties, beta=0.4):
    """Dirichlet label-skew partition of a dataset across n_parties clients.

    For every class, party shares are drawn from Dirichlet(beta); a small
    beta yields highly skewed (non-IID) splits, a large beta approaches IID.
    The draw is repeated until every party holds at least 10 samples.

    Parameters
    ----------
    dataset : str
        One of 'cifar10', 'cifar100', 'tiny-imagenet', 'mnist', 'fmnist', 'SVHN'.
    n_parties : int
        Number of federated clients.
    beta : float
        Dirichlet concentration parameter.

    Returns
    -------
    tuple
        (X_train, y_train, X_test, y_test, net_dataidx_map,
        traindata_cls_counts) where net_dataidx_map maps party id to a
        shuffled list of training-set indices.
    """
    loader_by_name = {
        'cifar10': load_cifar10_data,
        'cifar100': load_cifar100_data,
        'tiny-imagenet': load_tiny_imagenet_data,
        'mnist': load_mnist_data,
        'fmnist': load_fmnist_data,
        'SVHN': load_svhn_data,
    }
    if dataset in loader_by_name:
        X_train, y_train, X_test, y_test = loader_by_name[dataset]()
    min_require_size = 10
    # Number of classes; every dataset here has 10 unless noted.
    K = {'cifar100': 100, 'tiny-imagenet': 200}.get(dataset, 10)
    N = y_train.shape[0]
    net_dataidx_map = {}
    smallest = 0
    while smallest < min_require_size:
        per_party = [[] for _ in range(n_parties)]
        for cls in range(K):
            cls_idx = np.where(y_train == cls)[0]
            np.random.shuffle(cls_idx)
            weights = np.random.dirichlet(np.repeat(beta, n_parties))
            # Zero the share of any party already at its fair-share cap
            # (N / n_parties samples), then renormalize.
            weights = np.array([w if len(bucket) < N / n_parties else 0.0
                                for w, bucket in zip(weights, per_party)])
            weights = weights / weights.sum()
            # Convert cumulative shares to cut points inside this class's pool.
            cut_points = (np.cumsum(weights) * len(cls_idx)).astype(int)[:-1]
            pieces = np.split(cls_idx, cut_points)
            per_party = [bucket + piece.tolist()
                         for bucket, piece in zip(per_party, pieces)]
            smallest = min(len(bucket) for bucket in per_party)
    for party in range(n_parties):
        np.random.shuffle(per_party[party])
        net_dataidx_map[party] = per_party[party]
    traindata_cls_counts = record_net_data_stats(y_train, net_dataidx_map)
    return (X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts)