forked from intel-analytics/ipex-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
926 lines (829 loc) · 32.8 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The code is adapted from: https://github.com/state-spaces/mamba.
#
import json
import math
import os
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from transformers.generation import (
GreedySearchDecoderOnlyOutput,
SampleDecoderOnlyOutput,
TextStreamer,
)
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils.hub import cached_file
@dataclass
class MambaConfig:
    """Hyper-parameters for ``MambaLMHeadModel``.

    Field names match the keys of the checkpoint's ``config.json``; the loader
    builds the config with ``MambaConfig(**config_data)``.
    """

    # Width of the residual stream / token embeddings.
    d_model: int = 2560
    # Number of stacked residual blocks.
    n_layer: int = 64
    # Vocabulary size before rounding up to ``pad_vocab_size_multiple``.
    vocab_size: int = 50277
    # Extra keyword arguments forwarded to every ``Mamba`` mixer.
    ssm_cfg: dict = field(default_factory=dict)
    # Use RMSNorm (True) or nn.LayerNorm (False).
    rms_norm: bool = True
    # Accepted for config compatibility; not read by the models in this file.
    fused_add_norm: bool = False
    # Keep the residual stream in float32 between blocks.
    residual_in_fp32: bool = True
    # The embedding / LM-head vocabulary is padded up to a multiple of this.
    pad_vocab_size_multiple: int = 8
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
module,
n_layer,
initializer_range=0.02,
rescale_prenorm_residual=True,
n_residuals_per_layer=1,
):
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * n_layer)
def selective_scan(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False):
    """Reference (sequential, pure PyTorch) selective state-space scan.

    For every time step t it computes
        x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t
        y_t = <x_t, C_t>
    starting from x_{-1} = 0, then optionally adds ``D * u`` and gates by ``silu(z)``.

    Shapes (r = real, c = complex):
        u: r(B D L)
        delta: r(B D L)
        A: c(D N) or r(D N)
        B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32

    Returns:
        out: r(B D L) in ``u``'s original dtype, or ``(out, last_state)`` when
        ``return_last_state`` is true, with last_state r(B D N) or c(B D N).
    """
    in_dtype = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)

    n_batch = u.shape[0]
    n_channels, n_state = A.shape[0], A.shape[1]
    seq_len = u.shape[2]
    B_varies = B.dim() >= 3
    C_varies = C.dim() >= 3

    if not A.is_complex():
        B = B.float()
        C = C.float()
    else:
        # Time-varying B/C store (real, imag) pairs interleaved on the last axis.
        if B_varies:
            B = torch.view_as_complex(B.float().reshape(*B.shape[:-1], -1, 2))
        if C_varies:
            C = torch.view_as_complex(C.float().reshape(*C.shape[:-1], -1, 2))

    state = A.new_zeros((n_batch, n_channels, n_state))
    decay = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not B_varies:
        drive = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    elif B.dim() == 3:
        drive = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    else:
        # Grouped B (B G N L): each group is shared by dim // G consecutive channels.
        B = B.repeat_interleave(n_channels // B.shape[1], dim=1)
        drive = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if C_varies and C.dim() == 4:
        C = C.repeat_interleave(n_channels // C.shape[1], dim=1)

    per_step = []
    last_state = None
    for t in range(seq_len):
        state = decay[:, :, t] * state + drive[:, :, t]
        if not C_varies:
            y_t = torch.einsum("bdn,dn->bd", state, C)
        elif C.dim() == 3:
            y_t = torch.einsum("bdn,bn->bd", state, C[:, :, t])
        else:
            y_t = torch.einsum("bdn,bdn->bd", state, C[:, :, :, t])
        if t == seq_len - 1:
            last_state = state
        if y_t.is_complex():
            # Complex SSM: conjugate pairs are implicit, so take twice the real part.
            y_t = y_t.real * 2
        per_step.append(y_t)

    y = torch.stack(per_step, dim=2)  # (batch, dim, L)
    out = y if D is None else y + u * D.unsqueeze(-1)
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=in_dtype)
    if return_last_state:
        return out, last_state
    return out
def layer_norm(x, weight, bias, residual=None, eps=1e-6, prenorm=False):
    """LayerNorm over the last dimension with an optional residual add.

    Args:
        x: input tensor.
        weight, bias: affine parameters; the norm is computed in ``weight``'s dtype.
        residual: optional tensor added to ``x`` first (sum cast back to x's dtype).
        eps: numerical-stability term.
        prenorm: when true also return the (residual-added) pre-norm input.

    Returns:
        The normalised tensor in ``x``'s dtype, or ``(normalised, pre_norm_input)``.
    """
    orig_dtype = x.dtype
    if residual is not None:
        x = (x + residual).to(orig_dtype)
    normed = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    )
    normed = normed.to(orig_dtype)
    if prenorm:
        return normed, x
    return normed
def rms_norm(x, weight, bias, residual=None, eps=1e-6, prenorm=False):
    """RMS normalisation over the last dimension with an optional residual add.

    Computes ``x / sqrt(mean(x**2) + eps) * weight (+ bias)`` in ``x``'s dtype.

    Args:
        x: input tensor.
        weight: per-feature scale.
        bias: optional per-feature shift (``None`` to skip).
        residual: optional tensor added to ``x`` first (sum cast back to x's dtype).
        eps: numerical-stability term.
        prenorm: when true also return the (residual-added) pre-norm input.

    Returns:
        The normalised tensor, or ``(normalised, pre_norm_input)``.
    """
    orig_dtype = x.dtype
    if residual is not None:
        x = (x + residual).to(orig_dtype)
    inv_rms = 1 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + eps)
    scaled = x * inv_rms * weight
    if bias is not None:
        scaled = scaled + bias
    result = scaled.to(orig_dtype)
    return (result, x) if prenorm else result
def load_config_hf(model_name):
    """Resolve and parse the model's ``config.json`` (Hub id or local path).

    Args:
        model_name: identifier passed to ``transformers.utils.hub.cached_file``.

    Returns:
        dict: the parsed JSON configuration.

    Raises:
        FileNotFoundError: if no config file could be resolved.
    """
    resolved_archive_file = cached_file(
        model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False
    )
    # With missing-entry exceptions disabled, cached_file may return None rather
    # than raise; report that clearly instead of failing inside open().
    if resolved_archive_file is None:
        raise FileNotFoundError(
            f"Could not find {CONFIG_NAME} for model {model_name!r}"
        )
    # Context manager so the file handle is closed deterministically.
    with open(resolved_archive_file, encoding="utf-8") as f:
        return json.load(f)
def load_state_dict_hf(model_name, device=None, dtype=None):
    """Resolve and deserialize the model's weights file (``WEIGHTS_NAME``).

    Args:
        model_name: identifier passed to ``transformers.utils.hub.cached_file``.
        device: where to map tensors when ``dtype`` is float32 or None.
        dtype: target dtype of the model; for any other dtype, tensors are
            mapped to CPU instead of ``device``. Tensors are NOT cast here.

    Returns:
        The object stored in the checkpoint (a state dict for these checkpoints).
    """
    # Only full-precision (or unspecified) targets are loaded straight onto `device`.
    if dtype in (torch.float32, None):
        target_device = device
    else:
        target_device = "cpu"
    checkpoint_path = cached_file(
        model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
    )
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources.
    return torch.load(checkpoint_path, map_location=target_device)
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    # Number of positions already fed through the model (0 during the prompt pass).
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    # Per-layer cached states, keyed by layer index.
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Start a new run: update the limits and rewind the offsets.

        Cached entries in ``key_value_memory_dict`` are left as they are.
        """
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is None:
            return
        self.lengths_per_sample.zero_()
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits outside the top-p (nucleus) set to -inf. Done in-place.

    Tokens are sorted by ascending logit; those whose cumulative probability is
    ``<= 1 - top_p`` are masked. The largest logit (cumulative probability ~1)
    is therefore kept. Operates along the last dimension, so any number of
    leading batch dimensions is accepted.

    Args:
        logits: float tensor ``(..., vocab_size)``; modified in place.
        top_p: nucleus mass; values ``<= 0`` or ``>= 1`` disable filtering.

    Returns:
        None.
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Map the mask from sorted order back to vocabulary order (last dim, so
    # this works for 1-D as well as batched inputs).
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))
def modify_logit_for_repetition_penalty(
    logits, prev_output_tokens, repetition_penalty=1.0
):
    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858

    Every token id present in ``prev_output_tokens`` has its logit pushed toward
    lower probability. Positive logits are divided by the penalty and negative
    ones multiplied by it.

    Args:
        logits: (batch_size, vocab_size); modified in place unless the penalty is 1.0.
        prev_output_tokens: (batch_size, seq_len) token ids already produced.
        repetition_penalty: factor; 1.0 means no change.

    Returns:
        ``logits`` (the same tensor object).
    """
    if repetition_penalty == 1.0:
        return logits
    seen = logits.gather(1, prev_output_tokens)
    # Dividing a negative logit would raise its probability, so negatives are multiplied.
    penalized = torch.where(
        seen < 0, seen * repetition_penalty, seen / repetition_penalty
    )
    logits.scatter_(1, prev_output_tokens, penalized)
    return logits
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Pick one token id per row of ``logits``.

    * ``top_k == 1``: greedy argmax (temperature and top_p are ignored).
    * ``top_k > 1``: restrict to the ``top_k`` largest logits, apply temperature,
      then top-p filtering, then sample.
    * ``top_k <= 0``: no top-k restriction; temperature, top-p, then sample.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
        top_k: number of candidates kept (see above).
        top_p: nucleus mass in (0, 1]; 0 disables top-p filtering.
        temperature: logits are divided by this before sampling.

    Returns:
        Tensor of shape (batch_size,) with the chosen token ids.
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."

    if top_k <= 0:
        # Work on a copy: top-p filtering masks in place and must not touch the caller's logits.
        candidates = logits / temperature if temperature != 1.0 else logits.clone()
        modify_logits_for_top_p_filtering(candidates, top_p)
        probs = torch.softmax(candidates, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(dim=-1)

    k = min(top_k, logits.size(-1))  # never ask topk for more entries than exist
    candidates, candidate_ids = torch.topk(logits, k, dim=-1)
    if temperature != 1.0:
        candidates /= temperature
    modify_logits_for_top_p_filtering(candidates, top_p)
    rows = torch.arange(candidate_ids.shape[0], device=candidate_ids.device)
    picks = torch.multinomial(
        torch.softmax(candidates, dim=-1), num_samples=1
    ).squeeze(dim=-1)
    return candidate_ids[rows, picks]
@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_new_tokens,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    streamer: Optional[TextStreamer] = None,
):
    """Decoding, either greedy or with top-k or top-p sampling.

    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k
    is applied first, then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len) prompt token ids.
        model: called as ``model(input_ids, position_ids=..., inference_params=...,
            num_last_tokens=1)`` and must return an object with ``.logits``.
        max_new_tokens: int, maximum number of tokens to generate; ``<= 0``
            returns the prompt unchanged.
        repetition_penalty: when != 1.0, applied to every token already in the
            sequence (prompt included) before sampling.
        eos_token_id (optional): stop early once every row emits this id in the
            same step.
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs (indexed by
            absolute position). Useful for testing.
        vocab_size (optional): keep only the first ``vocab_size`` logits
            (e.g. to drop padded vocabulary entries).
        streamer (optional): receives the prompt, each sampled token, then ``end()``.
    Returns: GreedySearchDecoderOnlyOutput (top_k == 1) or SampleDecoderOnlyOutput, with:
        sequences: (batch, seq_len + number of generated tokens)
        scores: tuple of (batch, vocab_size) logits, one per generation step
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    max_length = input_ids.shape[1] + max_new_tokens
    batch_size = input_ids.shape[0]
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        """Run the model on the newest tokens and return last-position logits."""
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=1,
        ).logits.squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        """Sample the next token, or take it from teacher_outputs when available."""
        if (
            teacher_outputs is None
            or teacher_output_len <= inference_params.seqlen_offset
        ):
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        return token.unsqueeze(1)  # (batch, 1)

    def should_stop(current_token, inference_params):
        """Decide, before each model call, whether generation is finished."""
        if inference_params.seqlen_offset == 0:
            # Nothing generated yet: only stop when no new tokens were requested.
            return max_new_tokens <= 0
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    scores, sequences = [], [input_ids]
    # Full token history; only needed (and only extended) for the repetition penalty.
    sequences_cat = input_ids
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(scores[-1], inference_params)
        else:
            # Clone so the recorded scores keep the unpenalized logits.
            logits = modify_logit_for_repetition_penalty(
                scores[-1].clone(), sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)
        if streamer is not None:
            streamer.put(sampled_tokens.cpu())
    if streamer is not None:
        streamer.end()
    output_cls = (
        GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    )
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
class GenerationMixin:
    """Mixin giving a model a ``generate`` method backed by ``decode``."""

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Must be provided by the concrete model."""
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_new_tokens,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        """Generate up to ``max_new_tokens`` tokens after ``input_ids``.

        Extra keyword arguments are forwarded to ``decode``. Returns the token
        sequences, or the full output object when ``return_dict_in_generate``
        is true (its ``scores`` are cleared unless ``output_scores`` is true).
        """
        result = decode(
            input_ids,
            self,
            max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            **kwargs,
        )
        if not output_scores:
            result.scores = None
        if return_dict_in_generate:
            return result
        return result.sequences
class Block(nn.Module):
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, residual_in_fp32=False):
        """Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block. The standard block is LN -> MHA/MLP -> Add
        [Ref: https://arxiv.org/abs/2002.04745]; here we have Add -> LN -> Mixer,
        returning both the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as add and LayerNorm can be fused.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: width passed to ``mixer_cls`` and ``norm_cls``.
            mixer_cls: callable ``dim -> module``, called as ``mixer(x, inference_params=...)``.
            norm_cls: callable ``dim -> module`` exposing a ``weight`` attribute.
            residual_in_fp32: cast the outgoing residual to float32.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        # Registration order (mixer, then norm) determines state_dict key order.
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
    ):
        r"""Add the previous mixer output to the residual, normalise, and mix.

        Args:
            hidden_states: output of the previous block's mixer (or the embeddings).
            residual: running residual stream, ``None`` for the first block.
            inference_params: forwarded unchanged to the mixer.

        Returns:
            ``(mixer_output, residual)`` where ``mixer_output = Mixer(LN(residual))``.
        """
        if residual is None:
            residual = hidden_states
        else:
            residual = hidden_states + residual
        normed = self.norm(residual.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        mixed = self.mixer(normed, inference_params=inference_params)
        return mixed, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate cache allocation to the wrapped mixer."""
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )
class RMSNorm(torch.nn.Module):
    """RMS normalisation module with a learnable per-feature scale and no bias."""

    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(
            torch.empty(hidden_size, device=device, dtype=dtype)
        )
        # Registered as None so ``self.bias`` exists for rms_norm but holds no parameter.
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Set the scale to ones (identity scaling)."""
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        """Normalise ``x`` (plus optional ``residual``) via ``rms_norm``.

        ``residual_in_fp32`` is accepted for signature compatibility but ignored.
        """
        return rms_norm(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
        )
class Mamba(nn.Module):
    """Selective state-space (Mamba) mixer, pure-PyTorch implementation.

    Input and output are ``(batch, seqlen, d_model)``. When ``inference_params``
    is given, this layer's convolution and SSM states are cached in
    ``inference_params.key_value_memory_dict[layer_idx]``. Once
    ``inference_params.seqlen_offset > 0``, calls are routed to :meth:`step`,
    which consumes one token and updates those states in place.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        """
        Args:
            d_model: input/output width.
            d_state: SSM state size per channel.
            d_conv: kernel width of the depthwise causal convolution.
            expand: inner-width multiplier; ``d_inner = int(expand * d_model)``.
            dt_rank: rank of the dt projection; ``"auto"`` means ``ceil(d_model / 16)``.
            dt_min, dt_max: range of the initial ``softplus(dt_proj.bias)``.
            dt_init: ``"random"`` (uniform) or ``"constant"`` init of ``dt_proj.weight``.
            dt_scale: multiplier on the dt_proj weight init std.
            dt_init_floor: lower clamp on the initial dt.
            conv_bias: whether ``conv1d`` has a bias.
            bias: whether ``in_proj`` / ``out_proj`` have biases.
            use_fast_path: stored only; this implementation has no fused path.
            layer_idx: cache key for this layer; required with ``inference_params``.
            device, dtype: factory kwargs for the parameters.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        # Width of the dt slice of x_proj's output (equal to dt_rank).
        self.dt_proj_in_feature = self.dt_rank
        self.in_proj = nn.Linear(
            self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
        )
        # Depthwise conv padded by d_conv - 1 on both sides; forward() keeps the
        # first seqlen outputs, which makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        self.activation = "silu"
        self.act = nn.SiLU()
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner, bias=True, **factory_kwargs
        )
        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        # (log-uniform sample, clamped below by dt_init_floor).
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # _init_weights zeroes Linear biases; this flag makes it skip this one.
        self.dt_proj.bias._no_reinit = True
        # S4D real initialization: A[d, n] = n + 1 for every channel d.
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True
        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True
        self.out_proj = nn.Linear(
            self.d_inner, self.d_model, bias=bias, **factory_kwargs
        )

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, _ = hidden_states.shape
        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # Token-by-token decoding; the states are updated in place.
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out
        # Project and transpose to channel-first layout in one go: (B, 2*d_inner, L).
        xz = rearrange(
            self.in_proj(rearrange(hidden_states, "b l d -> d (b l)").t()).t(),
            "d (b l) -> b d l",
            l=seqlen,
        )
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        x, z = xz.chunk(2, dim=1)
        # Compute short convolution
        if conv_state is not None:
            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
            conv_state.copy_(
                F.pad(x, (self.d_conv - x.shape[-1], 0))
            )  # Update state (B D W)
        x = self.act(self.conv1d(x)[..., :seqlen])
        # Compute the input-dependent dt, B and C from the conv output.
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
        dt, B, C = torch.split(
            x_dbl, [self.dt_proj_in_feature, self.d_state, self.d_state], dim=-1
        )
        # dt_proj already adds its bias, so selective_scan gets delta_bias=None.
        dt = self.dt_proj(dt).t()
        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        assert self.activation in ["silu", "swish"]
        y = selective_scan(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=None,
            delta_softplus=True,
            return_last_state=ssm_state is not None,
        )
        if ssm_state is not None:
            y, last_state = y
            ssm_state.copy_(last_state)
        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """Process one token, updating ``conv_state`` and ``ssm_state`` in place.

        Args:
            hidden_states: (B, 1, d_model).
            conv_state: (B, d_inner, d_conv) rolling convolution window.
            ssm_state: (B, d_inner, d_state) recurrent SSM state.

        Returns:
            ``(out, conv_state, ssm_state)`` with out of shape (B, 1, d_model).
        """
        dtype = hidden_states.dtype
        assert (
            hidden_states.shape[1] == 1
        ), "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)
        # Conv step: shift the window left by one and append the new input.
        conv_state.copy_(
            torch.roll(conv_state, shifts=-1, dims=-1)
        )  # Update state (B D W)
        conv_state[:, :, -1] = x
        x = torch.sum(
            conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
        )  # (B D)
        if self.conv1d.bias is not None:
            x = x + self.conv1d.bias
        x = self.act(x).to(dtype=dtype)
        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(
            x_db, [self.dt_proj_in_feature, self.d_state, self.d_state], dim=-1
        )
        dt = self.dt_proj(dt)  # includes the dt bias
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # SSM step: discretize A and B, then advance the recurrence.
        dt = F.softplus(dt)
        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
        dB = torch.einsum("bd,bn->bdn", dt, B)
        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
        y = y + self.D.to(dtype) * x
        y = y * self.act(z)  # (B D)
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate zeroed ``(conv_state, ssm_state)`` for ``batch_size`` sequences.

        Shapes are (B, d_inner, d_conv) and (B, d_inner, d_state). ``dtype``
        overrides the default dtypes (conv1d / dt_proj weight dtypes).
        ``max_seqlen`` is unused: the recurrent state does not grow with length.
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        # d_inner (not d_model * expand) so non-integer `expand` yields an int size
        # that matches the conv/projection widths.
        conv_state = torch.zeros(
            batch_size,
            self.d_inner,
            self.d_conv,
            device=device,
            dtype=conv_dtype,
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        ssm_state = torch.zeros(
            batch_size,
            self.d_inner,
            self.d_state,
            device=device,
            dtype=ssm_dtype,
        )
        return conv_state, ssm_state

    def _get_states_from_cache(
        self, inference_params, batch_size, initialize_states=False
    ):
        """Return this layer's cached states, creating zeroed ones on first use.

        States are stored in ``inference_params.key_value_memory_dict[layer_idx]``.
        With ``initialize_states=True`` existing states are zeroed in place.
        """
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            conv_state = torch.zeros(
                batch_size,
                self.d_inner,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_inner,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (
                conv_state,
                ssm_state,
            )
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[
                self.layer_idx
            ]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """Build one residual ``Block`` with a ``Mamba`` mixer.

    Args:
        d_model: block width.
        ssm_cfg: extra keyword arguments for ``Mamba`` (``None`` means none).
        norm_epsilon: eps of the block's norm.
        rms_norm: use ``RMSNorm`` instead of ``nn.LayerNorm``.
        residual_in_fp32: passed through to ``Block``.
        layer_idx: stored on the mixer (cache key) and on the returned block.
        device, dtype: factory kwargs for all parameters.

    Returns:
        The constructed ``Block``.
    """
    factory_kwargs = dict(device=device, dtype=dtype)
    mixer_kwargs = {} if ssm_cfg is None else ssm_cfg
    mixer_cls = partial(Mamba, layer_idx=layer_idx, **mixer_kwargs, **factory_kwargs)
    norm_type = RMSNorm if rms_norm else nn.LayerNorm
    norm_cls = partial(norm_type, eps=norm_epsilon, **factory_kwargs)
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
class MixerModel(nn.Module):
    """Mamba backbone: token embedding, ``n_layer`` residual Blocks, final norm.

    ``forward(input_ids)`` returns normalised hidden states of width ``d_model``.
    """

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
        blocks = []
        for layer_idx in range(n_layer):
            blocks.append(
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    layer_idx=layer_idx,
                    **factory_kwargs,
                )
            )
        self.layers = nn.ModuleList(blocks)
        final_norm_cls = RMSNorm if rms_norm else nn.LayerNorm
        self.norm_f = final_norm_cls(d_model, eps=norm_epsilon, **factory_kwargs)
        init_overrides = {} if initializer_cfg is None else initializer_cfg
        self.apply(partial(_init_weights, n_layer=n_layer, **init_overrides))

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return ``{layer_index: cache}`` for every block."""
        caches = {}
        for idx, block in enumerate(self.layers):
            caches[idx] = block.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
        return caches

    def forward(self, input_ids, inference_params=None):
        """Embed ``input_ids``, run all blocks, and apply the final norm."""
        hidden_states = self.embedding(input_ids)
        residual = None
        for block in self.layers:
            hidden_states, residual = block(
                hidden_states, residual, inference_params=inference_params
            )
        # Each Block adds its input to the residual, so the last mixer output
        # still has to be folded in here.
        residual = hidden_states if residual is None else hidden_states + residual
        return self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
class MambaLMHeadModel(nn.Module, GenerationMixin):
    """Mamba causal language model: ``MixerModel`` backbone plus a weight-tied LM head."""

    # Output container for forward(); built once instead of on every call.
    CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device='cpu',
        dtype=torch.float32,
    ) -> None:
        """
        Args:
            config: model hyper-parameters.
            initializer_cfg: optional extra keyword arguments for ``_init_weights``.
            device, dtype: factory kwargs for all parameters.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Assign attributes only after nn.Module.__init__ has set up its registries.
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        # Round the vocabulary up to a multiple of pad_vocab_size_multiple.
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            ssm_cfg=config.ssm_cfg,
            rms_norm=config.rms_norm,
            initializer_cfg=initializer_cfg,
            residual_in_fp32=config.residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Share the LM head weight with the input embedding."""
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate per-layer cache allocation to the backbone."""
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        return self.CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device='cpu', dtype=torch.float32, **kwargs):
        """Build the model from a Hub id / local directory and load its weights.

        Extra keyword arguments are forwarded to the constructor.
        """
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory, using the same
        file names that load_config_hf / load_state_dict_hf look up.
        """
        os.makedirs(save_directory, exist_ok=True)
        model_path = os.path.join(save_directory, WEIGHTS_NAME)
        torch.save(self.state_dict(), model_path)
        config_path = os.path.join(save_directory, CONFIG_NAME)
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.config.__dict__, f)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device