-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpartition_data.py
210 lines (180 loc) · 7.12 KB
/
partition_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import math
import functools
import numpy as np
import torch
import torch.distributed as dist
class Partition(object):
    """Read-only view of ``data`` restricted to the positions in ``indices``.

    Item ``k`` of the view is ``data[indices[k]]`` and its length is
    ``len(indices)``; the underlying ``data`` is neither copied nor modified.
    """

    def __init__(self, data, indices):
        self.data, self.indices = data, indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        # Translate the view-local position into a position in ``data``.
        return self.data[self.indices[index]]
class DataPartitioner(object):
    """Split a dataset (or a ``Partition`` of one) into index partitions.

    The indices are first reordered according to ``partition_type`` and then cut
    into consecutive slices, one per entry of ``partition_sizes``; slice ``k``
    holds ``int(partition_sizes[k] * len(data))`` indices.

    Args:
        seed: seed of the private ``np.random.RandomState`` used by the
            ``"random"`` and ``"non_iid_dirichlet"`` schemes.
        data: an indexable dataset exposing its labels as ``targets`` (or,
            failing that, ``golds``), or a ``Partition`` over such a dataset, in
            which case only the Partition's positions are redistributed.
        partition_sizes: fraction of ``len(data)`` assigned to each partition.
        non_iid_alpha: concentration value handed to
            ``build_non_iid_by_dirichlet`` for the ``"non_iid_dirichlet"`` scheme.
        partition_type: ``"origin"`` (keep order), ``"random"`` (shuffle),
            ``"sorted"`` (order by label) or ``"non_iid_dirichlet"``.

    Raises:
        NotImplementedError: for any other ``partition_type``.
    """

    def __init__(
        self,
        seed,
        data,
        partition_sizes,
        non_iid_alpha=1.0,
        partition_type="non_iid_dirichlet",
    ):
        self.random_state = np.random.RandomState(seed)
        self.non_iid_alpha = non_iid_alpha
        self.partition_sizes = partition_sizes
        self.partition_type = partition_type
        self.partitions = []

        # Size of the (sub)dataset being partitioned.
        self.data_size = len(data)
        if isinstance(data, Partition):
            # Re-partition an existing view: index into its underlying dataset,
            # but only over the positions it exposes. Copy those positions so
            # the in-place shuffle of the "random" scheme cannot reorder the
            # caller's Partition.
            self.data = data.data
            indices = np.array(data.indices)
        else:
            self.data = data
            indices = np.arange(self.data_size)

        self.partition_indices(indices)

    def _get_targets(self):
        """Return the label sequence of the dataset (``targets``, else ``golds``)."""
        return self.data.targets if hasattr(self.data, "targets") else self.data.golds

    def partition_indices(self, indices):
        """Reorder ``indices`` per ``partition_type`` and slice them into partitions.

        Appends one slice per entry of ``partition_sizes`` to ``self.partitions``
        and stores each partition's label histogram in
        ``self.targets_of_partitions``. Indices left over by the ``int(...)``
        truncation are not assigned to any partition.
        """
        indices = self._create_indices(indices)

        from_index = 0
        for partition_size in self.partition_sizes:
            to_index = from_index + int(partition_size * self.data_size)
            self.partitions.append(indices[from_index:to_index])
            from_index = to_index

        # Per-partition class histogram, kept for inspection via ``use``.
        self.targets_of_partitions = record_class_distribution(
            self.partitions, self._get_targets()
        )

    def _create_indices(self, indices):
        """Return ``indices`` reordered according to ``self.partition_type``.

        ``"random"`` shuffles ``indices`` in place and returns it; ``"sorted"``
        and ``"non_iid_dirichlet"`` return new Python lists.
        """
        if self.partition_type == "origin":
            return indices

        if self.partition_type == "random":
            self.random_state.shuffle(indices)
            return indices

        if self.partition_type == "sorted":
            # Keep only the requested positions (ascending), then stable-sort by
            # label. A set makes the membership test O(1) instead of O(n).
            wanted = set(np.asarray(indices).tolist())
            labelled = [
                (idx, target)
                for idx, target in enumerate(self._get_targets())
                if idx in wanted
            ]
            return [idx for idx, _ in sorted(labelled, key=lambda pair: pair[1])]

        if self.partition_type == "non_iid_dirichlet":
            targets = self._get_targets()
            indices = np.asarray(indices, dtype=int)
            # (position, label) rows for exactly the positions being partitioned;
            # for a whole dataset this equals ``enumerate(targets)``.
            indices2targets = np.stack(
                [indices, np.asarray(targets)[indices]], axis=1
            )
            list_of_indices = build_non_iid_by_dirichlet(
                random_state=self.random_state,
                indices2targets=indices2targets,
                non_iid_alpha=self.non_iid_alpha,
                num_classes=len(np.unique(targets)),
                num_indices=len(indices),
                n_workers=len(self.partition_sizes),
            )
            # Flatten the per-worker lists in order (linear, unlike reduce(+)).
            return [idx for worker in list_of_indices for idx in worker]

        raise NotImplementedError(
            f"The partition scheme={self.partition_type} is not implemented yet"
        )

    def use(self, partition_ind):
        """Return ``(Partition, class histogram)`` for partition ``partition_ind``."""
        return Partition(self.data, self.partitions[partition_ind]), self.targets_of_partitions[partition_ind]
def build_non_iid_by_dirichlet(
    random_state, indices2targets, non_iid_alpha, num_classes, num_indices, n_workers
):
    """Distribute labelled indices over workers using Dirichlet class proportions.

    The rows of ``indices2targets`` are shuffled, then cut into consecutive
    chunks, each served by a group of at most ``n_auxi_workers`` (10) workers;
    the last chunk absorbs any remainder. Within a chunk, every class's samples
    are split among that group's workers according to a fresh Dirichlet draw,
    skipping workers that already hold at least the group's average share.
    A chunk is redrawn until every worker in its group holds at least half of
    that average share; there is no cap on the number of redraws.

    Args:
        random_state: ``np.random.RandomState`` used for the shuffle and draws.
        indices2targets: 2-D array of ``(index, label)`` rows; shuffled IN PLACE.
        non_iid_alpha: concentration value repeated once per worker of a group.
        num_classes: only labels in ``range(num_classes)`` are distributed.
        num_indices: number of rows of ``indices2targets`` to distribute.
        n_workers: total number of workers (length of the returned list).

    Returns:
        A list of ``n_workers`` lists holding values from column 0.
    """
    n_auxi_workers = 10

    random_state.shuffle(indices2targets)

    # Cut the shuffled rows into one chunk per worker group.
    num_splits = math.ceil(n_workers / n_auxi_workers)
    chunk_len = int(n_auxi_workers / n_workers * num_indices)
    splitted_targets = []
    from_index = 0
    for split_idx in range(num_splits):
        to_index = num_indices if split_idx == num_splits - 1 else from_index + chunk_len
        splitted_targets.append(indices2targets[from_index:to_index])
        from_index += chunk_len

    idx_batch = []
    remaining_workers = n_workers
    for _targets in splitted_targets:
        _targets = np.array(_targets)
        _targets_size = len(_targets)
        _n_workers = min(n_auxi_workers, remaining_workers)
        remaining_workers -= n_auxi_workers
        fair_share = _targets_size / _n_workers
        min_size_required = int(0.50 * fair_share)

        # Draw at least once: with fewer than 2 samples per worker the required
        # minimum is 0 and a plain ``while`` would never run, leaving the batch
        # undefined (or stale from the previous chunk).
        while True:
            _idx_batch = [[] for _ in range(_n_workers)]
            for _class in range(num_classes):
                # Original indices (column 0) of this chunk's samples of _class.
                idx_class = _targets[_targets[:, 1] == _class, 0]

                proportions = random_state.dirichlet(
                    np.repeat(non_iid_alpha, _n_workers)
                )
                # Balance: workers already at or above the fair share get nothing.
                open_workers = np.array(
                    [len(idx_j) < fair_share for idx_j in _idx_batch]
                )
                masked = proportions * open_workers
                total = masked.sum()
                if total > 0:
                    proportions = masked / total
                else:
                    # No open worker received weight; fall back to the raw draw
                    # instead of dividing by zero (NaN cut points).
                    proportions = proportions / proportions.sum()

                cut_points = (np.cumsum(proportions) * len(idx_class)).astype(int)[:-1]
                _idx_batch = [
                    idx_j + chunk.tolist()
                    for idx_j, chunk in zip(_idx_batch, np.split(idx_class, cut_points))
                ]

            min_size = min(len(idx_j) for idx_j in _idx_batch)
            if min_size >= min_size_required:
                break

        idx_batch += _idx_batch
    return idx_batch
def record_class_distribution(partitions, targets):
    """Return the label histogram of every partition.

    Args:
        partitions: sequence of index collections into ``targets``.
        targets: label sequence, converted once with ``np.array``.

    Returns:
        Dict mapping each partition's position in ``partitions`` to a list of
        ``(label, count)`` pairs, labels in ``np.unique`` (ascending) order.
    """
    labels = np.array(targets)

    def _histogram(member_indices):
        # Unique labels among this partition's members and how often each occurs.
        classes, counts = np.unique(labels[member_indices], return_counts=True)
        return list(zip(classes, counts))

    return {part_id: _histogram(part) for part_id, part in enumerate(partitions)}