-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] FalconMamba Support #9325
Changes from 6 commits
465dee3
e1a1a02
4193204
402758b
a80adf5
f66774a
ff57d44
42bc94c
e0e65e9
7738d2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,20 +14,24 @@ class RMSNorm(CustomOp): | |
Refer to https://arxiv.org/abs/1910.07467 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
hidden_size: int, | ||
eps: float = 1e-6, | ||
var_hidden_size: Optional[int] = None, | ||
) -> None: | ||
def __init__(self, | ||
hidden_size: int, | ||
eps: float = 1e-6, | ||
var_hidden_size: Optional[int] = None, | ||
is_learnable: bool = True) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RMSNorm weights are non learnable for FalconMamba model. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain this a bit more? It seems like this might have been done to work around some issues that popped up during weight loading. Is that right? And am I right that the weights will always be 1.0 for Falcon Mamba, i.e. we could skip the application of the weights for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
2.Yes , you are right. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation -- I think it would be better to handle this in In load_weights, could you add a condition to check if dt_layernorm, b_layernorm, or c_layernorm is in the name? If this is the case, we can set the weight loader to a function that explicitly sets all of the elements to 1.0, which will make things explicitly clear. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the review I managed to integrate FalconMamba inside mamba.py. for rmsnorm , i reveretd the changes , but i think there is no need to handle dt_layernorm, b_layernorm, or c_layernorm inside load_weights since they have been initialised as nn.parameters(torch.ones(hidden_size)) inside RMSNorm initial implementation which is compatible with FalconMamba dt,b,c rmsnorms. |
||
super().__init__() | ||
|
||
self.hidden_size = hidden_size | ||
self.variance_epsilon = eps | ||
self.variance_size_override = (None if var_hidden_size == hidden_size | ||
else var_hidden_size) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: best not to introduce whitespace-only changes to files |
||
self.weight = nn.Parameter(torch.ones(hidden_size)) | ||
if is_learnable: | ||
self.register_parameter("weight", | ||
nn.Parameter(torch.ones(hidden_size))) | ||
else: | ||
self.register_buffer('weight', | ||
torch.ones(hidden_size), | ||
persistent=False) | ||
|
||
def forward_native( | ||
self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: spurious whitespace