Code to load weights #1
This is the code I used to load the weights from numpy files 👍
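(The code itself did not survive the page scrape. A minimal sketch of what loading the numbered dumps might look like, assuming the files are named `0.npy`, `1.npy`, … in one directory; the count and naming are assumptions, not the author's actual script:)

```python
import os

import numpy as np

def load_npy_dumps(directory, count=16):
    """Load 0.npy .. (count-1).npy into a list, keeping None for missing files."""
    arrays = []
    for i in range(count):
        path = os.path.join(directory, f"{i}.npy")
        arrays.append(np.load(path) if os.path.exists(path) else None)
    return arrays
```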
Thanks for reverse engineering and sharing!
I added the lm.py file, which allows retraining the model on new data. It was used to create the model and load the weights.
I tried to map the original TF variables and the original .npy files to your .npy files. Is this mostly correct? Also, I don't know how 14.npy and 15.npy would be used (biases b0?), or which file corresponds to gmh in the PyTorch version.
#Embedding for ASCII one-hot
embd = 0.npy = embed.npy
#State
wh = 1.npy = wh.npy
wmx = concat(2:6.npy) = wmx.npy
wmh = empty? = wmh.npy
gx = empty
gh = empty
gmx = empty
gmh = 7.npy = ?
wx = 8.npy = wx.npy
wh = 9.npy = wh.npy
wmx = 10.npy = wmx.npy
wmh = 11.npy = wmh.npy
#Fully connected
w = 12.npy = w.npy
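(The mapping above can be sketched as a small loader. The dictionary below is illustrative only: it picks one candidate file per name where the table lists duplicates, and omits the gain variables and 14/15.npy whose roles are still unclear in this thread:)

```python
import os

import numpy as np

# Hypothetical mapping from the numbered dump files to named weights,
# loosely following the table above. Where the table lists two sources
# (e.g. wh = 1.npy and 9.npy, wmx = concat(2:6) and 10.npy), only one
# is kept here; which copy is correct is exactly what the thread debates.
NPY_MAP = {
    "embed": ["0.npy"],
    "wx":    ["8.npy"],
    "wh":    ["9.npy"],
    "wmx":   ["10.npy"],
    "wmh":   ["11.npy"],
    "w":     ["12.npy"],
}

def load_weights(directory, npy_map=NPY_MAP):
    """Load each named weight, concatenating multi-file entries."""
    weights = {}
    for name, files in npy_map.items():
        parts = [np.load(os.path.join(directory, f)) for f in files]
        weights[name] = parts[0] if len(parts) == 1 else np.concatenate(parts)
    return weights
```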
Things are more complicated than this, because the TF model uses L2 regularization. PyTorch handles this differently, which is why I had to hack the TensorFlow model to produce the different .npy files.
Interesting. I assume you extracted the variables from a live TF graph, then. I also found that L2 is (usually) added in PyTorch's optimizer, and I suspect that was the difference you mention. Thanks!
Forgive me, it is not L2 regularization but weight normalization that is the problem. And yes, I extracted the variables with TF code.
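(The weight-normalization difference can be removed by folding the gain into the direction matrix before export, so the PyTorch side only needs plain weights. A minimal numpy sketch of the reparameterization w = g · v / ‖v‖; the per-column norm axis is an assumption and must match what the original TF model normalized over:)

```python
import numpy as np

def fold_weight_norm(v, g, eps=1e-12):
    """Fold a weight-norm gain g into the direction matrix v.

    Returns the plain weight w = g * v / ||v||, with the norm taken
    per output column (axis=0). After folding, a plain matmul with w
    reproduces the weight-normalized layer, so the g* variables
    (gx, gh, gmx, gmh) are no longer needed.
    """
    norm = np.linalg.norm(v, axis=0, keepdims=True)  # shape (1, n_out)
    return g * v / (norm + eps)
```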
The code in models.py constructs the graph in a very sleek way. Is it possible to see how you transformed the weights into mlstm_ns.pt too?
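(While waiting for the author's actual conversion script, a plausible sketch of producing a .pt checkpoint from the numpy weights; the state-dict key names here are guesses and must match whatever models.py registers as parameters:)

```python
import numpy as np
import torch

def export_checkpoint(weights, path="mlstm_ns.pt"):
    """Convert a dict of numpy arrays into a torch state dict and save it.

    `weights` maps parameter names (which must line up with the PyTorch
    model's own names; the ones used by models.py are not shown in this
    thread) to numpy arrays.
    """
    state = {name: torch.from_numpy(np.ascontiguousarray(arr))
             for name, arr in weights.items()}
    torch.save(state, path)
    return state
```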