Skip to content

Commit

Permalink
fixes #19
Browse files Browse the repository at this point in the history
  • Loading branch information
scaldas committed Mar 11, 2020
1 parent fb5b1da commit 94ff90c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
3. Shakespeare

* **Overview:** Text Dataset of Shakespeare Dialogues
* **Details:** 1129 users
* **Details:** 1129 users (reduced to 660 with our choice of sequence length. See [bug](https://github.com/TalwalkarLab/leaf/issues/19).)
* **Task:** Next-Character Prediction

4. Celeba
Expand Down
56 changes: 35 additions & 21 deletions data/utils/split_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,51 +196,65 @@ def create_jsons_for(user_files, which_set, max_users, include_hierarchy):

user_indices = [] # indices of users in data['users'] that are not deleted

removed = 0
for i, u in enumerate(data['users']):
user_data_train[u] = {'x': [], 'y': []}
user_data_test[u] = {'x': [], 'y': []}

curr_num_samples = len(data['user_data'][u]['y'])
if curr_num_samples >= 2:
user_indices.append(i)

# ensures number of train and test samples both >= 1
num_train_samples = max(1, int(args.frac * curr_num_samples))
if curr_num_samples == 2:
num_train_samples = 1

num_test_samples = curr_num_samples - num_train_samples
num_samples_train.append(num_train_samples)
num_samples_test.append(num_test_samples)

indices = [j for j in range(curr_num_samples)]
train_indices = rng.sample(indices, num_train_samples)
train_blist = [False for _ in range(curr_num_samples)]
for j in train_indices:
train_blist[j] = True

for j in range(curr_num_samples):
if (train_blist[j]):
user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
else:
user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
user_data_test[u]['y'].append(data['user_data'][u]['y'][j])
if args.name in ['shakespeare']:
train_indices = [i for i in range(num_train_samples)]
test_indices = [i for i in range(num_train_samples + 80 - 1, curr_num_samples)]
else:
train_indices = rng.sample(indices, num_train_samples)
test_indices = [i for i in range(curr_num_samples) if i not in train_indices]

if len(train_indices) >= 1 and len(test_indices) >= 1:
user_indices.append(i)
num_samples_train.append(num_train_samples)
num_samples_test.append(num_test_samples)
user_data_train[u] = {'x': [], 'y': []}
user_data_test[u] = {'x': [], 'y': []}

train_blist = [False for _ in range(curr_num_samples)]
test_blist = [False for _ in range(curr_num_samples)]

for j in train_indices:
train_blist[j] = True
for j in test_indices:
test_blist[j] = True

for j in range(curr_num_samples):
if (train_blist[j]):
user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
elif (test_blist[j]):
user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
user_data_test[u]['y'].append(data['user_data'][u]['y'][j])

users = [data['users'][i] for i in user_indices]

all_data_train = {}
all_data_train['users'] = users
all_data_train['num_samples'] = num_samples_train
all_data_train['user_data'] = user_data_train

all_data_test = {}
all_data_test['users'] = users
all_data_test['num_samples'] = num_samples_test
all_data_test['user_data'] = user_data_test
all_data_test['user_data'] = user_data_test

if include_hierarchy:
all_data_train['hierarchies'] = data['hierarchies']
all_data_test['hierarchies'] = data['hierarchies']
hierarchies = [data['hierarchies'][i] for i in user_indices]
all_data_train['hierarchies'] = hierarchies
all_data_test['hierarchies'] = hierarchies

file_name_train = '%s_train_%s.json' % ((f[:-5]), arg_label)
file_name_test = '%s_test_%s.json' % ((f[:-5]), arg_label)
Expand Down

0 comments on commit 94ff90c

Please sign in to comment.