-
Notifications
You must be signed in to change notification settings - Fork 88
/
llama_model_optimized.py
554 lines (492 loc) · 20.8 KB
/
llama_model_optimized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
from loguru import logger
from typing import List
from tqdm import tqdm
import torch
import ttnn
from ttnn import ShardTensorToMesh, ReplicateTensorToMesh
from models.utility_functions import nearest_32, profiler
from models.demos.t3000.llama2_70b.tt.llama_decoder_optimized import TtLlamaDecoder_optimized
from models.demos.t3000.llama2_70b.tt.llama_embedding import TtLlamaEmbedding
from models.demos.t3000.llama2_70b.tt.llama_common import (
freqs_to_rotation_matrix,
get_rotation_mat,
precompute_freqs,
gather_cos_sin,
get_rot_transformation_mat,
)
from models.demos.t3000.falcon40b.tt.model_utils import matmul_2d_config
from models.demos.t3000.llama2_70b.tt.llama_rope import TtLlamaRotarySetup
class TtLlamaModel_optimized:
    """Tensor-parallel Llama model (decoder stack, final norm and LM head) on a ttnn mesh device.

    Builds one ``TtLlamaDecoder_optimized`` per layer, the token embedding, the rotary
    embedding helpers for prefill and decode, and the final norm / LM-head weights.
    ``__call__`` dispatches to ``prefill_forward`` or ``decode_forward``.
    """

    def __init__(
        self,
        mesh_device,
        state_dict,
        base_url,
        n_layers,
        model_config,
        configuration,
        cache_path=None,
        read_cache=False,
        paged_attention_config=None,
        vllm=False,
    ):
        """Create all decoder layers and load the model-level weights.

        Args:
            mesh_device: ttnn mesh device; its device count is the tensor-parallel degree.
            state_dict: checkpoint tensors keyed by parameter name.
            base_url: forwarded to every decoder layer
                (NOTE(review): presumably a state_dict key prefix — confirm).
            n_layers: number of decoder layers to build.
            model_config: dict of ttnn memory / program / kernel configs.
            configuration: model hyper-parameters (dim, n_heads, max_seq_len, vocab_size,
                norm_eps, rope_theta, optional use_scaled_rope).
            cache_path: directory for ttnn weight cache files (joined with ``/``), or None
                to disable file caching for the tensors created in this class.
            read_cache: if True, weights are loaded from cache files instead of ``state_dict``.
            paged_attention_config: forwarded to every decoder layer.
            vllm: if True, ``__call__`` requires ``page_table`` and ``kv_cache``.
        """
        self.state_dict = state_dict
        self.mesh_device = mesh_device
        self.num_devices = mesh_device.get_num_devices()
        self.model_config = model_config
        self.read_cache = read_cache
        self.vllm = vllm

        self.hidden_size = configuration.dim
        self.n_heads = configuration.n_heads
        self.n_local_heads = self.n_heads // self.num_devices
        self.padded_local_heads = 32
        self.head_dim = self.hidden_size // self.n_heads
        self.max_seq_len = configuration.max_seq_len
        self.vocab_size = configuration.vocab_size
        self.norm_eps = configuration.norm_eps
        # Llama 3 is recognised by its 128256-token vocabulary; otherwise rope theta is fixed at 10000.
        self.llama3 = self.vocab_size == 128256
        self.rope_theta = configuration.rope_theta if self.llama3 else 10000.0
        self.use_scaled_rope = getattr(configuration, "use_scaled_rope", False)

        self.cache_path = cache_path

        # Transformation matrix for rotary embeddings (prefill), replicated on every device
        transformation_mat_torch = get_rot_transformation_mat(32)  # 32 for tile size
        transformation_mats_prefill = ttnn.as_tensor(
            transformation_mat_torch,
            dtype=ttnn.bfloat16,
            layout=ttnn.TILE_LAYOUT,
            device=mesh_device,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
            mesh_mapper=ReplicateTensorToMesh(mesh_device),
        )
        transformation_mats_prefill = ttnn.to_device(transformation_mats_prefill, mesh_device)

        # Transformation matrix for rotary embeddings (decode)
        self.rope_setup_decode = TtLlamaRotarySetup(
            self.mesh_device, self.head_dim, self.max_seq_len, self.rope_theta, self.use_scaled_rope
        )
        transformation_mats_decode = self.rope_setup_decode.get_trans_mats()
        transformation_mats = {"prefill": transformation_mats_prefill, "decode": transformation_mats_decode}

        logger.info("Creating Layers")
        self.layers = [
            TtLlamaDecoder_optimized(
                mesh_device,
                state_dict,
                base_url,
                layer_num,
                model_config,
                configuration,
                transformation_mats,
                cache_path=cache_path,
                read_cache=read_cache,
                paged_attention_config=paged_attention_config,
                vllm=vllm,
            )
            for layer_num in tqdm(range(n_layers))
        ]
        logger.info("Done creating layers")

        # Rotary embedding tables for prefill; rows are picked per call by gather_cos_sin in prepare_inputs
        self.cos, self.sin = precompute_freqs(
            self.head_dim, self.max_seq_len * 2, self.rope_theta, self.use_scaled_rope
        )

        # Embedding
        self.tt_embd = TtLlamaEmbedding(
            mesh_device,
            state_dict,
            cache_path,
        )

        self.load_weights()

    def set_model_config(self, model_config):
        """Replace the model config on this model and on every decoder layer."""
        self.model_config = model_config
        for layer in self.layers:
            layer.set_model_config(model_config)

    def _cache_file(self, name):
        """Return ``cache_path / name``, or None (no ttnn cache file) when no cache_path was given."""
        return None if self.cache_path is None else self.cache_path / name

    def load_weights(self):
        """Load the LM head and final-norm weights onto the mesh device.

        When ``read_cache`` is False the torch weights come from ``state_dict`` and are handed to
        ``ttnn.as_tensor`` together with their cache file name; when True, ``ttnn.as_tensor`` is
        given no host tensor and only the cache file name.

        Sets ``self.lm_head`` (sharded along dim 3, the vocab dim, across devices), ``self.norm``
        (replicated) and ``self.norm_sharded`` (sharded along dim 2).

        Raises:
            ValueError: if ``read_cache`` is True but no ``cache_path`` was given.
        """
        norm_str = "norm.weight"
        norm_sharded_str = "norm_sharded.weight"
        lm_head_str = "output.weight"

        if not self.read_cache:
            # Transpose the LM head to (hidden, vocab) and zero-pad the vocab dim (the dim that is
            # sharded across devices below) to a fixed size.
            padded_vocab = 128 * 1024 if self.llama3 else 32 * 1024
            padded_lm_head = torch.zeros(1, 1, self.hidden_size, padded_vocab)
            padded_lm_head[:, :, :, : self.vocab_size] = self.state_dict[lm_head_str].transpose(-2, -1)
            # Row-major norm weight laid out as rows of 32 values
            pt_norm_weight = self.state_dict[norm_str].reshape([1, 1, -1, 32])
        elif self.cache_path is None:
            raise ValueError("read_cache=True requires a cache_path to load weights from")
        else:
            padded_lm_head = None
            pt_norm_weight = None

        padded_lm_head_ttnn = ttnn.as_tensor(
            padded_lm_head,
            dtype=ttnn.bfloat8_b,
            layout=ttnn.TILE_LAYOUT,
            device=self.mesh_device,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
            mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3),
            cache_file_name=self._cache_file(lm_head_str),
        )
        self.lm_head = ttnn.to_device(padded_lm_head_ttnn, self.mesh_device)

        norm_ttnn = ttnn.as_tensor(
            pt_norm_weight,
            dtype=ttnn.bfloat16,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            device=self.mesh_device,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
            mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
            cache_file_name=self._cache_file(norm_str),
        )
        self.norm = ttnn.to_device(norm_ttnn, self.mesh_device)

        norm_sharded_ttnn = ttnn.as_tensor(
            pt_norm_weight,
            dtype=ttnn.bfloat16,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            device=self.mesh_device,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
            mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2),
            cache_file_name=self._cache_file(norm_sharded_str),
        )
        self.norm_sharded = ttnn.to_device(norm_sharded_ttnn, self.mesh_device)

    def validate_input_shape(self, inp_ids, mode):
        """Assert that ``inp_ids`` is 2-D (batch, seq_len) and within the configured limits.

        ``mode`` is currently unused.
        """
        assert inp_ids.dim() == 2
        batch, seq_len = inp_ids.shape
        assert (
            batch <= self.model_config["MAX_BATCH_SIZE"]
        ), f"Batch size {batch} exceeds MAX_BATCH_SIZE {self.model_config['MAX_BATCH_SIZE']}"
        assert (
            seq_len <= self.model_config["MAX_CONTEXT_LEN"]
        ), f"Sequence length {seq_len} exceeds MAX_CONTEXT_LEN {self.model_config['MAX_CONTEXT_LEN']}"

    def prepare_inputs(
        self, inp_ids, start_pos, valid_seq_len=None, mode="decode", page_table=None, chunk_page_table=None
    ):
        """Convert token ids and paging metadata into model inputs for prefill or decode.

        Args:
            inp_ids: torch tensor of token ids, shape (batch, seq_len).
            start_pos: position of the first token. In decode mode either an int shared by the
                whole batch or a tensor of per-user positions.
            valid_seq_len: unused; kept for API compatibility.
            mode: "prefill" or "decode".
            page_table: optional page table; a torch tensor is converted to a ttnn int32 tensor.
            chunk_page_table: optional chunk page table (prefill only); converted to ttnn if given.

        Returns:
            prefill: ``(xs, start_pos, rot_mats, None, None, page_table, chunk_page_table)`` where
                ``xs`` is the on-device embedding of shape
                (batch, 1, seq_len, hidden_size // num_devices) and ``rot_mats`` is
                ``[cos, sin]``, each (1, 1, seq_len, head_dim), replicated on device.
            decode: ``(xs, start_pos, None, rot_idxs_tt, cache_idxs_tt, page_table)`` where
                ``xs`` is the host-side uint32 token tensor of shape (1, 1, 1, PADDED_BATCH_SIZE),
                ``cache_idxs_tt`` holds per-user cache positions (host side) and ``rot_idxs_tt``
                comes from the decode rotary setup. Embedding and the device transfer are done by
                ``prepare_device_inputs_decode``.

        Raises:
            ValueError: if ``mode`` is neither "prefill" nor "decode".
        """
        if mode not in ("prefill", "decode"):
            raise ValueError(f"Unknown mode: {mode}")
        self.validate_input_shape(inp_ids, mode)
        batch, seq_len = inp_ids.shape

        def cache_name(name):
            # Llama 3 rope tables get their own cache files
            return self._cache_file(f"{'llama3_' if self.llama3 else ''}{name}")

        if mode == "decode":
            inp_ids = inp_ids.reshape(seq_len, 1, 1, batch)
            # Pad the batch (last dim) to PADDED_BATCH_SIZE
            inp_ids = torch.nn.functional.pad(inp_ids, (0, self.model_config["PADDED_BATCH_SIZE"] - batch), value=0)
        else:
            inp_ids = inp_ids.reshape(batch, 1, 1, seq_len)

        x = ttnn.as_tensor(
            inp_ids,
            dtype=ttnn.uint32,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
        )

        if mode == "prefill":
            assert seq_len % 32 == 0 and seq_len > 0, "Prefill mode only supports seqlen as a multiple of 32"
            assert batch == 1, "prefill mode only supports batch size 1"

            x = ttnn.to_device(x, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
            xs = self.tt_embd(x)
            assert xs.shape == (batch, 1, seq_len, self.hidden_size // self.num_devices)

            # cos/sin rows for positions [start_pos, start_pos + seq_len)
            cos_gathered, sin_gathered = gather_cos_sin(
                torch.arange(start_pos, start_pos + seq_len), self.cos, self.sin
            )
            assert cos_gathered.size() == (1, 1, seq_len, self.head_dim)
            assert sin_gathered.size() == (1, 1, seq_len, self.head_dim)

            cos_gathereds = ttnn.as_tensor(
                cos_gathered,
                dtype=ttnn.bfloat16,
                layout=ttnn.TILE_LAYOUT,
                cache_file_name=cache_name(f"cos_gathered_prefill_{start_pos}_{start_pos+seq_len}"),
                memory_config=ttnn.DRAM_MEMORY_CONFIG,
                device=self.mesh_device,
                mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
            )
            sin_gathereds = ttnn.as_tensor(
                sin_gathered,
                dtype=ttnn.bfloat16,
                layout=ttnn.TILE_LAYOUT,
                cache_file_name=cache_name(f"sin_gathered_prefill_{start_pos}_{start_pos+seq_len}"),
                memory_config=ttnn.DRAM_MEMORY_CONFIG,
                device=self.mesh_device,
                mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
            )
            cos_gathereds = ttnn.to_device(cos_gathereds, self.mesh_device)
            sin_gathereds = ttnn.to_device(sin_gathereds, self.mesh_device)
            rot_mats = [cos_gathereds, sin_gathereds]

            if isinstance(page_table, torch.Tensor):
                # Support vLLM tensor page_table input
                page_table = ttnn.as_tensor(
                    page_table,
                    device=self.mesh_device,
                    memory_config=ttnn.DRAM_MEMORY_CONFIG,
                    dtype=ttnn.int32,
                    layout=ttnn.ROW_MAJOR_LAYOUT,
                    mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
                )
            if chunk_page_table is not None:
                chunk_page_table = ttnn.as_tensor(
                    chunk_page_table,
                    device=self.mesh_device,
                    memory_config=ttnn.DRAM_MEMORY_CONFIG,
                    dtype=ttnn.int32,
                    layout=ttnn.ROW_MAJOR_LAYOUT,
                    mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
                )
            # rot_idxs_tt and cache_idxs_tt are unused in prefill mode
            return (xs, start_pos, rot_mats, None, None, page_table, chunk_page_table)

        # Decode mode
        assert seq_len == 1, "Decode mode only supports seq_len=1"
        xs = x

        # A single int start position applies to the whole batch; otherwise one position per user
        if isinstance(start_pos, int):
            cache_idxs = torch.full((batch,), start_pos, dtype=torch.int64)
        else:
            cache_idxs = start_pos.to(dtype=torch.int64)

        cache_idxs_tt = ttnn.as_tensor(
            cache_idxs,
            dtype=ttnn.int32,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
        )

        rot_mats = None  # Created in prepare_device_inputs_decode
        # The rotary lookup needs non-negative positions; cache_idxs_tt keeps the original values
        rot_cache_idxs = torch.clamp(cache_idxs, min=0)
        rot_idxs_tt = self.rope_setup_decode.get_rot_idxs(rot_cache_idxs)

        if isinstance(page_table, torch.Tensor):
            # Support vLLM tensor page_table input; kept on host here (see prepare_device_inputs_decode)
            page_table = ttnn.as_tensor(
                page_table,
                dtype=ttnn.int32,
                layout=ttnn.ROW_MAJOR_LAYOUT,
                mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
            )

        return (xs, start_pos, rot_mats, rot_idxs_tt, cache_idxs_tt, page_table)

    def prepare_device_inputs_decode(
        self,
        tokens: torch.Tensor,
        start_pos: int,
        valid_seq_len=None,
        mode="decode",
        page_table=None,
        return_tokens=False,  # if true, return tokens for decode mode
        return_rot_idxs=False,  # if true, return rot_idxs for decode mode
    ):
        """Prepare decode-mode inputs and place them on the device.

        Calls ``prepare_inputs`` in decode mode, embeds the tokens on device and reshards the
        embedding to ``WORD_EMBEDDING_OUTPUT_MEMCFG``, moves the cache indices (and page table, if
        any) to DRAM, and builds the decode rotation matrices.

        Returns:
            ``(tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table)``, followed by the
            on-device token tensor if ``return_tokens`` and then the rot_idxs returned by
            ``get_rot_mats`` if ``return_rot_idxs``.
        """
        assert mode == "decode"

        tt_inp, start_pos, rot_mat, rot_idxs_tt, cache_idxs_tt, tt_page_table = self.prepare_inputs(
            tokens, start_pos, valid_seq_len=valid_seq_len, mode=mode, page_table=page_table
        )

        tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
        tt_inp_emb = self.tt_embd(tt_inp)
        tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])

        cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
        rot_mat, rot_idxs_tt = self.rope_setup_decode.get_rot_mats(
            rot_idxs_tt, return_rot_idxs=True
        )  # Sends rot_idxs to device internally

        if tt_page_table is not None:
            tt_page_table = ttnn.to_device(tt_page_table, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

        return_out = [tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table]
        if return_tokens:
            return_out.append(tt_inp)
        if return_rot_idxs:
            return_out.append(rot_idxs_tt)
        return tuple(return_out)

    def __call__(
        self,
        xs: List[ttnn.Tensor],
        rot_mats: List[ttnn.Tensor],
        start_pos: int,
        user_id: int = 0,
        cache_idxs=None,
        last_token_idx=None,
        page_table=None,
        kv_cache=None,
        mode="decode",
        chunk_page_table=None,
        chunk_start_idx=None,
    ) -> ttnn.Tensor:
        """Run the model in ``mode`` ("prefill" or "decode") and return the LM-head output.

        Raises:
            ValueError: if ``mode`` is neither "prefill" nor "decode".
        """
        if self.vllm:
            # In vLLM mode both page_table and kv_cache are required
            assert page_table is not None
            assert kv_cache is not None
        if mode == "prefill":
            return self.prefill_forward(
                xs,
                rot_mats,
                start_pos,
                user_id,
                last_token_idx=last_token_idx,
                page_table=page_table,
                kv_cache=kv_cache,
                chunk_page_table=chunk_page_table,
                chunk_start_idx=chunk_start_idx,
            )
        elif mode == "decode":
            return self.decode_forward(xs, rot_mats, start_pos, cache_idxs, page_table=page_table, kv_cache=kv_cache)
        else:
            raise ValueError(f"Unknown llm_mode: {mode}")

    def decode_forward(
        self,
        xs: List[ttnn.Tensor],
        rot_mats: List[ttnn.Tensor],
        start_pos: int,
        cache_idxs,
        page_table=None,
        kv_cache=None,
    ) -> ttnn.Tensor:
        """Decode step: run all layers, all-gather, final RMSNorm and this device's LM-head slice.

        Returns:
            LM-head output in DRAM as bfloat16; each device holds the columns of its
            ``self.lm_head`` shard (the LM head is sharded along dim 3).
        """
        ### Run all layers
        for layer in self.layers:
            xs = layer(
                xs, rot_mats, start_pos, cache_idxs=cache_idxs, page_table=page_table, kv_cache=kv_cache, mode="decode"
            )  # xs is sharded

        xs = ttnn.all_gather(
            xs,
            dim=3,
            num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
            memory_config=self.model_config["FINAL_ALL_GATHER_OUTPUT_MEMCFG"],
        )

        # Final RMSNorm on the gathered activations. NOTE(review): the original comment called this
        # in-place (it may be, depending on LN_F_PROGCFG), so xs is intentionally not deallocated here.
        norm_out_replicated = ttnn.rms_norm(
            xs,
            epsilon=self.norm_eps,
            weight=self.norm,
            program_config=self.model_config["LN_F_PROGCFG"],
            memory_config=self.model_config["FINAL_ALL_GATHER_OUTPUT_MEMCFG"],
            compute_kernel_config=self.model_config["LN_COMPUTE_KERNEL_CONFIG"],
        )

        ### Each device does an LM head fracture
        lm_head_out = ttnn.matmul(
            norm_out_replicated,
            self.lm_head,
            program_config=(
                self.model_config["LLAMA3_LM_HEAD_MM_PROGCFG"]
                if self.llama3
                else self.model_config["LM_HEAD_MM_PROGCFG"]
            ),
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
            dtype=ttnn.bfloat16,
            compute_kernel_config=self.model_config["COMPUTE_KERNEL_CONFIG"],
        )
        norm_out_replicated.deallocate(True)

        return lm_head_out

    def tt_distributed_rmsnorm(self, inp, epsilon, gamma):
        """RMSNorm of an activation sharded across devices.

        Each device computes partial statistics, the statistics are all-gathered along dim 3
        (into DRAM), and each device then normalizes its own shard with ``gamma``.
        """
        # Run distributed rmsnorm part 1
        tt_stats = ttnn.rms_norm_pre_all_gather(
            inp, compute_kernel_config=self.model_config["LN_COMPUTE_KERNEL_CONFIG"], dtype=ttnn.bfloat16
        )

        # AllGather stats
        tt_stats = ttnn.all_gather(
            tt_stats,
            dim=3,
            num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
        )

        # Run distributed rmsnorm part 2
        tt_out = ttnn.rms_norm_post_all_gather(
            inp,
            tt_stats,
            epsilon=epsilon,
            weight=gamma,
            compute_kernel_config=self.model_config["LN_COMPUTE_KERNEL_CONFIG"],
        )

        tt_stats.deallocate(True)

        return tt_out

    def prefill_forward(
        self,
        xs: List[ttnn.Tensor],
        rot_mats: List[ttnn.Tensor],
        start_pos: int,
        user_id: int = 0,
        last_token_idx=None,
        page_table=None,
        kv_cache=None,
        chunk_page_table=None,
        chunk_start_idx=None,
    ) -> ttnn.Tensor:
        """Prefill: run all layers, distributed RMSNorm, all-gather and the LM head.

        Args:
            xs: on-device embeddings from ``prepare_inputs(mode="prefill")``; deallocated here.
            rot_mats: ``[cos, sin]`` from ``prepare_inputs``.
            start_pos, user_id, page_table, kv_cache, chunk_page_table, chunk_start_idx:
                forwarded to every decoder layer.
            last_token_idx: if not None (0 included), only the 32-row tile containing this
                position is passed through the LM head; otherwise every position is.

        Returns:
            LM-head output; presumably (1, 1, rows, vocab_shard) with rows = 32 when
            ``last_token_idx`` is given, else the full sequence length.

        Raises:
            ValueError: if the sequence is at least MAX_MM_SEQ_LEN long but not a multiple of it.
        """
        ### Run all layers
        for layer in self.layers:
            xs = layer(
                xs,
                rot_mats,
                start_pos,
                user_id,
                page_table=page_table,
                kv_cache=kv_cache,
                mode="prefill",
                chunk_page_table=chunk_page_table,
                chunk_start_idx=chunk_start_idx,
            )  # xs is sharded

        # Distributed rmsnorm
        norm_out = self.tt_distributed_rmsnorm(xs, self.norm_eps, self.norm_sharded)
        norm_out_replicated = ttnn.all_gather(
            norm_out,
            dim=3,
            num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
        )

        # Deallocate original input to rmsnorm
        xs.deallocate(True)

        ### Each device does an LM head fracture
        _, _, seq_len, dmodel = norm_out_replicated.shape
        folded = False  # True when the sequence was folded into dim 1 for the LM-head matmul
        if last_token_idx is not None:
            # Only the 32-row tile containing the last token is needed
            last_token_tile = last_token_idx // 32
            norm_out_replicated = ttnn.slice(
                norm_out_replicated,
                (0, 0, last_token_tile * 32, 0),
                (1, 1, (last_token_tile + 1) * 32, dmodel),
                memory_config=ttnn.DRAM_MEMORY_CONFIG,
            )
            pc_lm_head = (
                self.model_config["PREFILL_LLAMA3_LM_HEAD_MM_PROGCFG"]
                if self.llama3
                else self.model_config["PREFILL_LM_HEAD_MM_PROGCFG"]
            )
        else:
            max_mm_seq_len = self.model_config["MAX_MM_SEQ_LEN"]
            if seq_len >= max_mm_seq_len:
                if seq_len % max_mm_seq_len != 0:
                    raise ValueError(f"Sequence length {seq_len} is not divisible by {max_mm_seq_len}")
                # Fold the sequence into chunks of max_mm_seq_len rows along dim 1
                batch_dim = seq_len // max_mm_seq_len
                norm_out_replicated = ttnn.reshape(norm_out_replicated, (1, batch_dim, seq_len // batch_dim, -1))
                folded = True
                pc_lm_head = (
                    self.model_config["PREFILL_LLAMA3_LM_HEAD_MM_PROGCFG"]
                    if self.llama3
                    else self.model_config["PREFILL_LM_HEAD_MM_PROGCFG"]
                )
            elif seq_len == 128:
                pc_lm_head = (
                    self.model_config["PREFILL_LLAMA3_LM_HEAD_MM_PROGCFG_128"]
                    if self.llama3
                    else self.model_config["PREFILL_LM_HEAD_MM_PROGCFG_128"]
                )
            else:
                pc_lm_head = matmul_2d_config(
                    m=norm_out_replicated.shape[2],
                    k=norm_out_replicated.shape[3],
                    n=self.lm_head.shape[3],
                    overwrite_per_core_k=1,
                    grid=ttnn.CoreGrid(y=min(8, norm_out_replicated.shape[2] // 32), x=8),
                    is_fp32_accumulate=False,
                    overwrite_subblock_h=1,
                    overwrite_subblock_w=1,
                )

        lm_head_out = ttnn.linear(
            norm_out_replicated,
            self.lm_head,
            # TODO: increase precision?
            compute_kernel_config=self.model_config["COMPUTE_KERNEL_FP16_ACC_CONFIG"],
            core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_lm_head else None,
            dtype=ttnn.bfloat16,
            program_config=pc_lm_head,
        )
        norm_out_replicated.deallocate(True)

        if folded:
            # Undo the fold so the output is (1, 1, seq_len, -1) again
            lm_head_out = ttnn.reshape(lm_head_out, (1, 1, seq_len, -1))

        return lm_head_out