diff --git a/data/femnist/preprocess/data_to_json.py b/data/femnist/preprocess/data_to_json.py index 9f2d2b31..7d19bfcc 100644 --- a/data/femnist/preprocess/data_to_json.py +++ b/data/femnist/preprocess/data_to_json.py @@ -44,17 +44,17 @@ def relabel_class(c): num_json = int(math.ceil(len(writers) / MAX_WRITERS)) -users = [[] for _ in range(num_json)] -num_samples = [[] for _ in range(num_json)] -user_data = [{} for _ in range(num_json)] +users = [] +num_samples = [] +user_data = {} writer_count = 0 json_index = 0 for (w, l) in writers: - users[json_index].append(w) - num_samples[json_index].append(len(l)) - user_data[json_index][w] = {'x': [], 'y': []} + users.append(w) + num_samples.append(len(l)) + user_data[w] = {'x': [], 'y': []} size = 28, 28 # original image size is 128, 128 for (f, c) in l: @@ -69,16 +69,16 @@ def relabel_class(c): nc = relabel_class(c) - user_data[json_index][w]['x'].append(vec) - user_data[json_index][w]['y'].append(nc) + user_data[w]['x'].append(vec) + user_data[w]['y'].append(nc) writer_count += 1 if writer_count == MAX_WRITERS: all_data = {} - all_data['users'] = users[json_index] - all_data['num_samples'] = num_samples[json_index] - all_data['user_data'] = user_data[json_index] + all_data['users'] = users + all_data['num_samples'] = num_samples + all_data['user_data'] = user_data file_name = 'all_data_%d.json' % json_index file_path = os.path.join(parent_path, 'data', 'all_data', file_name) @@ -90,3 +90,7 @@ def relabel_class(c): writer_count = 0 json_index += 1 + + users[:] = [] + num_samples[:] = [] + user_data.clear() diff --git a/data/utils/remove_users.py b/data/utils/remove_users.py index 3221e3ff..88bec01c 100644 --- a/data/utils/remove_users.py +++ b/data/utils/remove_users.py @@ -55,7 +55,7 @@ if 'hierarchies' in data: curr_hierarchy = data['hierarchies'][i] curr_num_samples = data['num_samples'][i] - if (curr_num_samples > args.min_samples): + if (curr_num_samples >= args.min_samples): user_data[curr_user] = data['user_data'][curr_user] users.append(curr_user) if curr_hierarchy is not None: diff --git a/models/client.py b/models/client.py index 1a454653..09eda5b3 100644 --- a/models/client.py +++ b/models/client.py @@ -48,10 +48,10 @@ def test(self, set_to_use='test'): Return: dict of metrics returned by the model. """ - assert set_to_use in ['train', 'test'] + assert set_to_use in ['train', 'test', 'val'] if set_to_use == 'train': data = self.train_data - elif set_to_use == 'test': + elif set_to_use == 'test' or set_to_use == 'val': data = self.eval_data return self.model.test(data) diff --git a/models/main.py b/models/main.py index db584bc4..f2bb8461 100644 --- a/models/main.py +++ b/models/main.py @@ -61,7 +61,7 @@ def main(): server = Server(client_model) # Create clients - clients = setup_clients(args.dataset, client_model) + clients = setup_clients(args.dataset, client_model, args.use_val_set) client_ids, client_groups, client_num_samples = server.get_clients_info(clients) print('Clients in Total: %d' % len(clients)) @@ -69,7 +69,7 @@ def main(): print('--- Random Initialization ---') stat_writer_fn = get_stat_writer_function(client_ids, client_groups, client_num_samples, args) sys_writer_fn = get_sys_writer_function(args) - print_stats(0, server, clients, client_num_samples, args, stat_writer_fn) + print_stats(0, server, clients, client_num_samples, args, stat_writer_fn, args.use_val_set) # Simulate training for i in range(num_rounds): @@ -88,7 +88,7 @@ def main(): # Test model if (i + 1) % eval_every == 0 or (i + 1) == num_rounds: - print_stats(i + 1, server, clients, client_num_samples, args, stat_writer_fn) + print_stats(i + 1, server, clients, client_num_samples, args, stat_writer_fn, args.use_val_set) # Save server model ckpt_path = os.path.join('checkpoints', args.dataset) @@ -112,14 +112,15 @@ def create_clients(users, groups, train_data, test_data, model): return clients -def setup_clients(dataset, model=None): +def setup_clients(dataset, model=None, use_val_set=False): """Instantiates clients based on given train and test data directories. Return: all_clients: list of Client objects. """ + eval_set = 'test' if not use_val_set else 'val' train_data_dir = os.path.join('..', 'data', dataset, 'data', 'train') - test_data_dir = os.path.join('..', 'data', dataset, 'data', 'test') + test_data_dir = os.path.join('..', 'data', dataset, 'data', eval_set) users, groups, train_data, test_data = read_data(train_data_dir, test_data_dir) @@ -147,15 +148,16 @@ def writer_fn(num_round, ids, metrics, groups, num_samples): def print_stats( - num_round, server, clients, num_samples, args, writer): + num_round, server, clients, num_samples, args, writer, use_val_set): train_stat_metrics = server.test_model(clients, set_to_use='train') print_metrics(train_stat_metrics, num_samples, prefix='train_') writer(num_round, train_stat_metrics, 'train') - test_stat_metrics = server.test_model(clients, set_to_use='test') - print_metrics(test_stat_metrics, num_samples, prefix='test_') - writer(num_round, test_stat_metrics, 'test') + eval_set = 'test' if not use_val_set else 'val' + test_stat_metrics = server.test_model(clients, set_to_use=eval_set) + print_metrics(test_stat_metrics, num_samples, prefix='{}_'.format(eval_set)) + writer(num_round, test_stat_metrics, eval_set) def print_metrics(metrics, weights, prefix=''): diff --git a/models/utils/args.py b/models/utils/args.py index 975b5211..522d000a 100644 --- a/models/utils/args.py +++ b/models/utils/args.py @@ -46,6 +46,9 @@ def parse_args(): type=str, default='metrics', required=False) + parser.add_argument('--use-val-set', + help='use validation set;', + action='store_true') # Minibatch doesn't support num_epochs, so make them mutually exclusive epoch_capability_group = parser.add_mutually_exclusive_group()