Code to load weights #1

Open

ahirner opened this issue May 6, 2017 · 7 comments
ahirner commented May 6, 2017

The code in models.py constructs the graph in a very sleek way. Is it possible to see how you transformed the weights into mlstm_ns.pt too?

guillitte (Owner) commented:

This is the code I used to load the weights from the numpy files:

import numpy as np
import torch

# embed and rnn are the modules built in models.py. TF stores dense
# weights as (in, out) while PyTorch's nn.Linear expects (out, in),
# hence the transposes (.t()).
embed.weight.data = torch.from_numpy(np.load("embd.npy"))
rnn.h2o.weight.data = torch.from_numpy(np.load("w.npy")).t()
rnn.h2o.bias.data = torch.from_numpy(np.load("b.npy"))
rnn.layers[0].wx.weight.data = torch.from_numpy(np.load("wx.npy")).t()
rnn.layers[0].wh.weight.data = torch.from_numpy(np.load("wh.npy")).t()
rnn.layers[0].wh.bias.data = torch.from_numpy(np.load("b0.npy"))
rnn.layers[0].wmx.weight.data = torch.from_numpy(np.load("wmx.npy")).t()
rnn.layers[0].wmh.weight.data = torch.from_numpy(np.load("wmh.npy")).t()
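
Once the tensors are in place, a checkpoint like mlstm_ns.pt can be written with torch.save. A minimal sketch, assuming the checkpoint simply holds the two state_dicts; the actual format of mlstm_ns.pt isn't shown here:

# Sketch only, not necessarily how mlstm_ns.pt was produced: serialize
# the populated modules so they can be restored with load_state_dict.
torch.save({"embed": embed.state_dict(), "rnn": rnn.state_dict()}, "mlstm_ns.pt")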


ahirner commented May 6, 2017

Thanks for reverse engineering and sharing!

guillitte (Owner) commented:

I added the lm.py file, which allows retraining the model on new data. It was used to create the model and load the weights.


ahirner commented May 7, 2017

I tried to map the original TF variables, via the original .npy files, to your .npy files. Is this mostly correct? Also, I can't tell how 14.npy and 15.npy would be used if they were needed (b0?), nor which file corresponds to gmh in the pytorch version.

#Embedding for ASCII one-hot
embd = 0.npy = embd.npy

#State
wh = 1.npy = wh.npy
wmx = concat(2:6.npy) = wmx.npy
wmh = empty? = wmh.npy

gx = empty
gh = empty
gmx = empty
gmh = 7.npy = ?

wx = 8.npy = wx.npy
wh = 9.npy = wh.npy
wmx = 10.npy = wmx.npy
wmh = 11.npy = wmh.npy

#Fully connected
w = 12.npy = w.npy
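
A mapping like this can be sanity-checked by comparing array shapes across the two dumps. A minimal sketch, assuming the numbered TF dumps (0.npy, 1.npy, ...) and the named files (embd.npy, wx.npy, ...) sit in the working directory:

import glob
import numpy as np

# Print every .npy shape so the numbered TF variables can be matched
# against the named files by shape.
for path in sorted(glob.glob("*.npy")):
    print(path, np.load(path).shape)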

guillitte (Owner) commented:

Things are more complicated than this, because the TF model uses L2 regularization and PyTorch handles it differently. This is why I had to hack the TensorFlow model to produce the different .npy files.


ahirner commented May 14, 2017

Interesting, I assume you extracted the variables from a live TF graph then. I also found that L2 regularization is usually applied through PyTorch's optimizer and suspect that was the difference you're talking about. Thanks!
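
For reference, PyTorch typically expresses L2 regularization as the optimizer's weight_decay argument rather than as part of the graph. A one-line sketch; the lr and weight_decay values are placeholders, not the ones used in this repo:

import torch.optim as optim

# weight_decay applies an L2 penalty inside the update step;
# model stands in for any nn.Module being trained.
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-5)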

guillitte (Owner) commented:

Forgive me, it is not L2 regularization but weight normalization that is the problem. And yes, I extracted the variables with TF code.
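
That would explain the extra g* variables in the TF dump: with weight normalization (Salimans & Kingma, 2016), each matrix is stored as a direction v and a gain g, and the effective weight is w = g * v / ||v||. To export to plain PyTorch Linear layers, g has to be folded into v. A minimal sketch, assuming v is normalized along the input axis (axis 0), the usual convention; the file pairing in the usage line is a guess:

import numpy as np

def fold_weight_norm(v, g, axis=0):
    # Fold a weight-norm (v, g) pair into one dense matrix:
    # w = g * v / ||v||, with the norm taken along `axis`.
    return g * v / np.linalg.norm(v, axis=axis, keepdims=True)

# Hypothetical usage (which numbered file pairs with which gain is a guess):
# wmh = fold_weight_norm(np.load("11.npy"), np.load("7.npy"))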
