reddit dataset v2
scaldas committed Dec 5, 2019
2 parents 5c45514 + 0bf6a9a commit fb5b1da
Showing 5 changed files with 32 additions and 23 deletions.
26 changes: 15 additions & 11 deletions data/femnist/preprocess/data_to_json.py
@@ -44,17 +44,17 @@ def relabel_class(c):

 num_json = int(math.ceil(len(writers) / MAX_WRITERS))

-users = [[] for _ in range(num_json)]
-num_samples = [[] for _ in range(num_json)]
-user_data = [{} for _ in range(num_json)]
+users = []
+num_samples = []
+user_data = {}

 writer_count = 0
 json_index = 0
 for (w, l) in writers:

-    users[json_index].append(w)
-    num_samples[json_index].append(len(l))
-    user_data[json_index][w] = {'x': [], 'y': []}
+    users.append(w)
+    num_samples.append(len(l))
+    user_data[w] = {'x': [], 'y': []}

     size = 28, 28  # original image size is 128, 128
     for (f, c) in l:
@@ -69,16 +69,16 @@ def relabel_class(c):

         nc = relabel_class(c)

-        user_data[json_index][w]['x'].append(vec)
-        user_data[json_index][w]['y'].append(nc)
+        user_data[w]['x'].append(vec)
+        user_data[w]['y'].append(nc)

     writer_count += 1
     if writer_count == MAX_WRITERS:

         all_data = {}
-        all_data['users'] = users[json_index]
-        all_data['num_samples'] = num_samples[json_index]
-        all_data['user_data'] = user_data[json_index]
+        all_data['users'] = users
+        all_data['num_samples'] = num_samples
+        all_data['user_data'] = user_data

         file_name = 'all_data_%d.json' % json_index
         file_path = os.path.join(parent_path, 'data', 'all_data', file_name)
@@ -90,3 +90,7 @@ def relabel_class(c):

         writer_count = 0
         json_index += 1
+
+        users[:] = []
+        num_samples[:] = []
+        user_data.clear()
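
The three hunks above drop the per-shard containers indexed by json_index in favor of flat users/num_samples/user_data accumulators that are dumped once MAX_WRITERS writers have been collected and then cleared in place. A minimal sketch of that write-and-reset pattern, not the repository's exact code (the write_shards helper and the (writer_id, samples) input shape are assumptions):

    import json
    import os

    MAX_WRITERS = 100  # shard size; the real value is defined in data_to_json.py

    def write_shards(writers, out_dir):
        # writers: iterable of (writer_id, samples); samples: list of (x_vector, label)
        users, num_samples, user_data = [], [], {}
        writer_count, json_index = 0, 0
        for w, samples in writers:
            users.append(w)
            num_samples.append(len(samples))
            user_data[w] = {'x': [x for x, _ in samples], 'y': [y for _, y in samples]}
            writer_count += 1
            if writer_count == MAX_WRITERS:
                all_data = {'users': users, 'num_samples': num_samples, 'user_data': user_data}
                with open(os.path.join(out_dir, 'all_data_%d.json' % json_index), 'w') as f:
                    json.dump(all_data, f)
                writer_count = 0
                json_index += 1
                users[:] = []        # clear in place, as in the diff
                num_samples[:] = []
                user_data.clear()
        # a trailing group smaller than MAX_WRITERS is ignored in this sketch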
2 changes: 1 addition & 1 deletion data/utils/remove_users.py
@@ -55,7 +55,7 @@
         if 'hierarchies' in data:
             curr_hierarchy = data['hierarchies'][i]
         curr_num_samples = data['num_samples'][i]
-        if (curr_num_samples > args.min_samples):
+        if (curr_num_samples >= args.min_samples):
             user_data[curr_user] = data['user_data'][curr_user]
             users.append(curr_user)
             if curr_hierarchy is not None:
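
The only change in remove_users.py is making the minimum-sample filter inclusive: a user holding exactly args.min_samples samples is now kept instead of dropped. A standalone illustration of the boundary (the counts dict is made up):

    min_samples = 10
    counts = {'user_a': 9, 'user_b': 10, 'user_c': 11}
    kept_before = [u for u, n in counts.items() if n > min_samples]   # ['user_c']
    kept_after = [u for u, n in counts.items() if n >= min_samples]   # ['user_b', 'user_c']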
4 changes: 2 additions & 2 deletions models/client.py
@@ -48,10 +48,10 @@ def test(self, set_to_use='test'):
         Return:
             dict of metrics returned by the model.
         """
-        assert set_to_use in ['train', 'test']
+        assert set_to_use in ['train', 'test', 'val']
         if set_to_use == 'train':
            data = self.train_data
-        elif set_to_use == 'test':
+        elif set_to_use == 'test' or set_to_use == 'val':
            data = self.eval_data
         return self.model.test(data)

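
With this change Client.test also accepts 'val' and serves it from the same eval_data the client already holds; which split that data actually is depends on whether the eval directory was built from the test or the val split. A hedged usage sketch, assuming an already constructed Client named client:

    train_metrics = client.test(set_to_use='train')  # runs on client.train_data
    val_metrics = client.test(set_to_use='val')      # runs on client.eval_data
    test_metrics = client.test(set_to_use='test')    # also runs on client.eval_data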
20 changes: 11 additions & 9 deletions models/main.py
@@ -61,15 +61,15 @@ def main():
     server = Server(client_model)

     # Create clients
-    clients = setup_clients(args.dataset, client_model)
+    clients = setup_clients(args.dataset, client_model, args.use_val_set)
     client_ids, client_groups, client_num_samples = server.get_clients_info(clients)
     print('Clients in Total: %d' % len(clients))

     # Initial status
     print('--- Random Initialization ---')
     stat_writer_fn = get_stat_writer_function(client_ids, client_groups, client_num_samples, args)
     sys_writer_fn = get_sys_writer_function(args)
-    print_stats(0, server, clients, client_num_samples, args, stat_writer_fn)
+    print_stats(0, server, clients, client_num_samples, args, stat_writer_fn, args.use_val_set)

     # Simulate training
     for i in range(num_rounds):
@@ -88,7 +88,7 @@

         # Test model
         if (i + 1) % eval_every == 0 or (i + 1) == num_rounds:
-            print_stats(i + 1, server, clients, client_num_samples, args, stat_writer_fn)
+            print_stats(i + 1, server, clients, client_num_samples, args, stat_writer_fn, args.use_val_set)

     # Save server model
     ckpt_path = os.path.join('checkpoints', args.dataset)
@@ -112,14 +112,15 @@ def create_clients(users, groups, train_data, test_data, model):
     return clients


-def setup_clients(dataset, model=None):
+def setup_clients(dataset, model=None, use_val_set=False):
     """Instantiates clients based on given train and test data directories.

     Return:
         all_clients: list of Client objects.
     """
+    eval_set = 'test' if not use_val_set else 'val'
     train_data_dir = os.path.join('..', 'data', dataset, 'data', 'train')
-    test_data_dir = os.path.join('..', 'data', dataset, 'data', 'test')
+    test_data_dir = os.path.join('..', 'data', dataset, 'data', eval_set)

     users, groups, train_data, test_data = read_data(train_data_dir, test_data_dir)

@@ -147,15 +148,16 @@ def writer_fn(num_round, ids, metrics, groups, num_samples):


 def print_stats(
-    num_round, server, clients, num_samples, args, writer):
+    num_round, server, clients, num_samples, args, writer, use_val_set):

     train_stat_metrics = server.test_model(clients, set_to_use='train')
     print_metrics(train_stat_metrics, num_samples, prefix='train_')
     writer(num_round, train_stat_metrics, 'train')

-    test_stat_metrics = server.test_model(clients, set_to_use='test')
-    print_metrics(test_stat_metrics, num_samples, prefix='test_')
-    writer(num_round, test_stat_metrics, 'test')
+    eval_set = 'test' if not use_val_set else 'val'
+    test_stat_metrics = server.test_model(clients, set_to_use=eval_set)
+    print_metrics(test_stat_metrics, num_samples, prefix='{}_'.format(eval_set))
+    writer(num_round, test_stat_metrics, eval_set)


 def print_metrics(metrics, weights, prefix=''):
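
main.py threads args.use_val_set from the CLI into setup_clients, which reads data/<dataset>/data/val instead of .../test when the flag is set, and into print_stats, which evaluates and logs metrics under a val_ prefix instead of test_. A minimal sketch of just that selection logic, outside the repository's classes (eval_paths is an illustrative helper, not part of the codebase):

    import os

    def eval_paths(dataset, use_val_set=False):
        # mirrors the eval_set = 'test' if not use_val_set else 'val' selection above
        eval_set = 'test' if not use_val_set else 'val'
        train_dir = os.path.join('..', 'data', dataset, 'data', 'train')
        eval_dir = os.path.join('..', 'data', dataset, 'data', eval_set)
        return eval_set, train_dir, eval_dir

    eval_set, _, eval_dir = eval_paths('femnist', use_val_set=True)
    print(eval_set, eval_dir)              # -> val ../data/femnist/data/val (POSIX paths)
    print('{}_accuracy'.format(eval_set))  # metric names get a val_ prefix instead of test_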
3 changes: 3 additions & 0 deletions models/utils/args.py
@@ -46,6 +46,9 @@ def parse_args():
                     type=str,
                     default='metrics',
                     required=False)
+    parser.add_argument('--use-val-set',
+                    help='use validation set;',
+                    action='store_true')

     # Minibatch doesn't support num_epochs, so make them mutually exclusive
     epoch_capability_group = parser.add_mutually_exclusive_group()
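
The new --use-val-set option is a plain store_true flag: it defaults to False and argparse exposes it as args.use_val_set (dashes become underscores). A standalone sketch of that behavior, with a hypothetical invocation in the final comment (the other flags shown there are illustrative, not checked against the repo):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--use-val-set',
                        help='use validation set;',
                        action='store_true')

    print(parser.parse_args([]).use_val_set)                 # False -> evaluate on the test split
    print(parser.parse_args(['--use-val-set']).use_val_set)  # True  -> evaluate on the val split

    # hypothetical invocation: python main.py -dataset femnist -model cnn --use-val-set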
