-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathstarcoder_fim_main.py
234 lines (193 loc) · 7.94 KB
/
starcoder_fim_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import os
import fire
import glob
import tqdm
import json
import numpy as np
from functools import partial
from transformers import AutoTokenizer
import multiprocessing
# Input root: pre-tokenized ``*.jsonl`` shards, one sub-directory per subset.
TOKENIZED_DATASETS_DIR = '/lustre/scratch/shared-folders/llm_project/bowen.tan/tokenized_datasets'
# Output root: main() writes to <OUTPUT_DIR>/<subset>.spm<spm_rate>/.
OUTPUT_DIR = '/lustre/scratch/shared-folders/llm_project/bowen.tan/fim_datasets'

# Tokenizer whose vocabulary main() extends with ADDITIONAL_SPECIAL_TOKENS.
TOKENIZER_NAME = 'huggyllama/llama-7b'

# Probability that a document segment is FIM-transformed (main() passes it
# through process_example() -> fim() -> permute()).
FIM_RATE = 1.0

# NOTE(review): CONTEXT_LENGTH and WORD_BUFFER_SIZE are not referenced by the
# code visible in this file -- confirm whether they are still needed.
CONTEXT_LENGTH = 2048
WORD_BUFFER_SIZE = 2 * 2048

# Examples read per Pool.map batch, and the chunksize handed to Pool.map.
MULTIPROCESSING_BUFFERSIZE = 12800
MULTIPROCESSING_CHUNKSIZE = 100

# Special tokens registered on the tokenizer. The four <fim_*> sentinels are
# looked up by permute()/fim(); the remaining ones presumably mirror the
# StarCoder special-token set (not used elsewhere in this file).
ADDITIONAL_SPECIAL_TOKENS = [
    '<fim_prefix>', '<fim_middle>', '<fim_suffix>', '<fim_pad>',
    '<filename>', '<gh_stars>',
    '<issue_start>', '<issue_comment>', '<issue_closed>',
    '<jupyter_start>', '<jupyter_text>', '<jupyter_code>', '<jupyter_output>',
    '<empty_output>',
    '<commit_before>', '<commit_msg>', '<commit_after>',
    '<reponame>',
]
def tokenize_text(text, tokenizer):
    """Encode ``text`` and return its ``input_ids``.

    ``add_special_tokens=False`` keeps the tokenizer from inserting its own
    special tokens (e.g. a BOS token), so the ids cover only ``text`` itself.
    """
    encoding = tokenizer(text, add_special_tokens=False)
    return encoding['input_ids']
# From https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py#L339
def permute(sample, tokenizer, fim_rate, spm_rate, truncate_or_pad):
    """Apply a fill-in-the-middle (FIM) transformation to one document segment.

    With probability ``fim_rate`` the segment is decoded to text, cut at two
    uniformly random character offsets into prefix / middle / suffix, each
    part is re-tokenized, and the parts are reassembled around the FIM
    sentinel tokens:

    * SPM (probability ``spm_rate``, variant 2 of the FIM paper):
      ``<fim_prefix> <fim_suffix> suffix <fim_middle> prefix middle``
    * PSM (otherwise):
      ``<fim_prefix> prefix <fim_suffix> suffix <fim_middle> middle``

    When the transformation is not applied, ``sample`` is returned as is and
    the tokenizer is not touched.

    Args:
        sample: 1-D integer numpy array holding the token ids of one segment.
        tokenizer: Tokenizer exposing ``vocab`` (which must contain the
            ``<fim_*>`` sentinels), ``decode`` and ``__call__``.
        fim_rate: Probability in [0, 1] of applying the transformation.
        spm_rate: Probability in [0, 1] of SPM rather than PSM ordering.
        truncate_or_pad: If true, trim the suffix (or right-pad it with
            ``<fim_pad>``) so the result has exactly ``len(sample)`` tokens;
            if trimming would consume the whole suffix, the untransformed
            ``sample`` is returned instead.

    Returns:
        1-D numpy integer array.
    """
    # Bernoulli(fim_rate) draw deciding whether this segment is transformed.
    if not np.random.binomial(1, fim_rate):
        return sample

    # Resolve the sentinel ids once, and only on the FIM path: on Hugging Face
    # fast tokenizers ``vocab`` builds a fresh dict on every attribute access.
    vocab = tokenizer.vocab
    prefix_tok_id = vocab['<fim_prefix>']
    middle_tok_id = vocab['<fim_middle>']
    suffix_tok_id = vocab['<fim_suffix>']
    pad_tok_id = vocab['<fim_pad>']

    contents = tokenizer.decode(sample, skip_special_tokens=False)
    # Cut points lie in [0, len(contents)]: a cut may be 0 (empty prefix) or
    # len(contents) (empty suffix), and both may coincide (empty middle).
    lo, hi = sorted(np.random.randint(low=0, high=len(contents) + 1, size=2))

    prefix = np.array(tokenize_text(contents[:lo], tokenizer=tokenizer), dtype=np.int64)
    middle = np.array(tokenize_text(contents[lo:hi], tokenizer=tokenizer), dtype=np.int64)
    suffix = np.array(tokenize_text(contents[hi:], tokenizer=tokenizer), dtype=np.int64)

    if truncate_or_pad:
        # Match the input length, counting the 3 sentinel tokens. Only the
        # suffix is adjusted, so the end of a file can be cut off; truncating
        # at the context level instead would avoid that.
        new_length = prefix.shape[0] + middle.shape[0] + suffix.shape[0] + 3
        diff = new_length - sample.shape[0]
        if diff > 0:  # too long
            if suffix.shape[0] <= diff:
                # Not enough suffix to trim: give up and keep the original.
                return sample
            suffix = suffix[:suffix.shape[0] - diff]
        elif diff < 0:  # too short
            suffix = np.concatenate([suffix, np.full(-diff, pad_tok_id)])

    if np.random.binomial(1, spm_rate):
        # SPM (variant 2 from the FIM paper)
        pieces = [[prefix_tok_id, suffix_tok_id], suffix, [middle_tok_id], prefix, middle]
    else:
        # PSM
        pieces = [[prefix_tok_id], prefix, [suffix_tok_id], suffix, [middle_tok_id], middle]
    return np.concatenate(pieces)
def fim(token_ids, tokenizer, fim_rate, spm_rate):
    """Apply FIM to every EOD-delimited document in a token sequence.

    The sequence is split on ``tokenizer.eos_token_id``. Each non-empty
    segment goes through ``permute`` (without per-segment length fixing),
    and every EOD separator is kept in place. The result is then truncated,
    or right-padded with ``<fim_pad>``, back to the input length.

    Args:
        token_ids: Sequence of integer token ids.
        tokenizer: Tokenizer exposing ``eos_token_id`` and a ``vocab`` that
            contains the ``<fim_*>`` sentinels.
        fim_rate: Per-segment probability in [0, 1] of applying FIM. 0 turns
            FIM off entirely, without touching the tokenizer.
        spm_rate: Probability in [0, 1] of SPM rather than PSM ordering.

    Returns:
        List of ints with the same length as ``token_ids``.

    Raises:
        ValueError: If ``fim_rate`` is not within [0, 1].
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not 0 <= fim_rate <= 1:
        raise ValueError('FIM rate must be a probability 0 <= rate <= 1')

    sample = np.array(token_ids, dtype=np.int64)
    if fim_rate == 0:
        return sample.tolist()

    sample_len = sample.shape[0]
    eod = tokenizer.eos_token_id
    pad = tokenizer.vocab['<fim_pad>']
    permute_fn = partial(
        permute,
        tokenizer=tokenizer,
        fim_rate=fim_rate,
        spm_rate=spm_rate,
        truncate_or_pad=False)

    pieces = []
    start = 0
    for loc in np.flatnonzero(sample == eod):  # split sample by document
        # Only permute non-empty segments (permute {prefix, suffix, middle}
        # or {suffix, prefix, middle}).
        if loc > start:
            pieces.append(permute_fn(sample=sample[start:loc]))
        # Keep every EOD, including ones after an empty segment (leading or
        # consecutive EODs), so document boundaries are preserved.
        pieces.append([eod])
        start = loc + 1  # jump over the EOD token

    # Segment after the last EOD (the whole sample when there is none); like
    # the others, it is only permuted when non-empty, so an empty tail never
    # turns into a bare run of sentinel tokens.
    tail = sample[start:]
    pieces.append(permute_fn(sample=tail) if tail.shape[0] > 0 else tail)
    sample = np.concatenate(pieces)

    # Truncate or pad the sequence back to its original length.
    diff = sample.shape[0] - sample_len
    if diff > 0:  # too long
        sample = sample[:sample_len]
    elif diff < 0:  # too short
        sample = np.concatenate([sample, np.full(-diff, pad)])
    assert sample.shape[0] == sample_len
    return sample.tolist()
def process_example(example, tokenizer, fim_rate, spm_rate):
    """Replace ``example['token_ids']`` with its FIM-transformed version.

    The dict is modified in place and also returned, so the function can be
    handed directly to ``Pool.map``.
    """
    transformed_ids = fim(
        token_ids=example['token_ids'],
        tokenizer=tokenizer,
        fim_rate=fim_rate,
        spm_rate=spm_rate,
    )
    example['token_ids'] = transformed_ids
    return example
def _fim_file(pool, process_fn, input_filename, output_filename):
    """Stream one tokenized jsonl shard through ``process_fn`` into ``output_filename``.

    Lines are parsed and processed in batches of MULTIPROCESSING_BUFFERSIZE
    on ``pool``; results are written in input order, one JSON object per
    line. Output goes to a temporary file that is renamed into place only
    after the whole shard succeeded, so an interrupted run never leaves a
    truncated file that a later run would skip as already finished.
    """
    tmp_filename = f'{output_filename}.tmp'
    with open(input_filename) as input_file, open(tmp_filename, 'w') as output_file:
        buffer = []
        for line in tqdm.tqdm(input_file, desc=output_filename):
            buffer.append(json.loads(line))
            if len(buffer) == MULTIPROCESSING_BUFFERSIZE:
                for example in pool.map(
                        process_fn, buffer, chunksize=MULTIPROCESSING_CHUNKSIZE):
                    output_file.write(json.dumps(example) + '\n')
                buffer = []
        # Flush the final, partially filled batch.
        for example in pool.map(
                process_fn, buffer, chunksize=MULTIPROCESSING_CHUNKSIZE):
            output_file.write(json.dumps(example) + '\n')
    os.replace(tmp_filename, output_filename)


def main(spm_rate=0.):
    """Write FIM-transformed copies of every ``starcoder.*`` tokenized subset.

    For each sub-directory of TOKENIZED_DATASETS_DIR whose name starts with
    ``starcoder.``, every ``*.jsonl`` shard is run through ``process_example``
    (with FIM_RATE and ``spm_rate``) and written to
    ``OUTPUT_DIR/<subset>.spm<spm_rate>/<shard name>``. Shards whose output
    already exists are skipped, so the script can be re-run to resume.

    Args:
        spm_rate: Probability of SPM (rather than PSM) ordering; exposed as a
            command-line flag through ``fire``.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    tokenizer.add_special_tokens(
        {'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
    process_fn = partial(
        process_example,
        tokenizer=tokenizer,
        fim_rate=FIM_RATE,
        spm_rate=spm_rate)

    # A single pool serves every shard and is shut down on exit. Each worker
    # reseeds NumPy's global RNG from OS entropy: with the fork start method
    # (the Linux default) workers would otherwise inherit identical RNG state
    # and draw identical FIM cut points.
    with multiprocessing.Pool(
            processes=os.cpu_count(), initializer=np.random.seed) as pool:
        for subset_name in sorted(os.listdir(TOKENIZED_DATASETS_DIR)):
            if not subset_name.startswith('starcoder.'):
                continue
            subset_output_dir = os.path.join(
                OUTPUT_DIR, f'{subset_name}.spm{spm_rate}')
            input_pattern = os.path.join(
                TOKENIZED_DATASETS_DIR, subset_name, '*.jsonl')
            for tokenized_chunk_filename in sorted(glob.glob(input_pattern)):
                output_filename = os.path.join(
                    subset_output_dir,
                    os.path.basename(tokenized_chunk_filename))
                if os.path.exists(output_filename):
                    print(f'{output_filename} exists. skipped.')
                    continue
                os.makedirs(subset_output_dir, exist_ok=True)
                _fim_file(pool, process_fn, tokenized_chunk_filename,
                          output_filename)
if __name__ == '__main__':
    # Expose main()'s keyword arguments (e.g. --spm_rate) as CLI flags.
    fire.Fire(component=main)