Bug fix for named nn.Sequential in pytorch parser #848
Conversation
hls4ml/converters/pytorch_to_hls.py
Outdated
if "layer." in key: | ||
layerInKeyName = True | ||
|
||
if '_' in layer_name: |
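For orientation, a simplified, hypothetical sketch of the key translation this fragment is part of (not the actual parser code; it deliberately ignores underscores that are part of user-chosen layer names, which is what makes the real logic convoluted):

```python
# Hypothetical, simplified helper illustrating the key translation discussed
# in this thread; the real parser must also handle underscores that belong
# to user-chosen layer names and layers that are reused in the model.
def fx_name_to_state_dict_key(layer_name: str, layer_in_key_name: bool) -> str:
    # torch.FX reports "seq.0" as "seq_0"; state_dict keys use dots.
    key = layer_name.replace('_', '.')
    if not layer_in_key_name:
        # Strip the "layer" prefix the parser adds for unnamed Sequentials,
        # unless the model genuinely has a member called "layer".
        key = key.removeprefix('layer.')
    return key
```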
I am not sure I follow the logic here, but I am not a PyTorch expert. Can you double-check?
I had another look and this is a bit convoluted, but it works. The issue is that the layers inside a `torch.nn.Sequential` get named in the pattern `nameOfSequential_n`, where `n` just numbers the layers inside the sequential. If the sequential does not have a name, which happens if you just do `model = nn.Sequential(..)`, the layers will have names that start with just an underscore, which hls4ml doesn't like. So we add the prefix "layer" to those in the loop over the layers, which we have to remove when we go and load the tensors. The changes in this PR account for the fact that someone could create a named `nn.Sequential` that is itself named `layer`, which then clashes with our previous assumption that if a layer name starts with `layer_`, it is because we added it by hand to get around the issue with layer names starting with just an `_`.
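A minimal standalone sketch (not hls4ml code) of the naming behavior described above:

```python
import torch.fx
import torch.nn as nn

# Unnamed top-level Sequential: submodules are named '0', '1', '2', which
# torch.fx sanitizes into node names starting with an underscore.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
traced = torch.fx.symbolic_trace(model)

print([n.name for n in traced.graph.nodes if n.op == 'call_module'])
# expected: ['_0', '_1', '_2']
#   -> hence the parser prepends "layer" to these names
```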
A further complication is that torch.FX reports these structured layer names with underscores, while in the `state_dict` a `.` is used, so we have to replace the `_` with `.` here. The last complication is the case of layers that are used multiple times in the same model. torch.FX reports them as different layers, appending `_n` to the layer name for the n-th time it is used. But the tensors in the state dict will not have those modifications, so we also have to account for this.
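A minimal sketch (standalone, not the parser) of the FX-name versus `state_dict`-key mismatch, including the repeated-use case:

```python
import torch.fx
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

    def forward(self, x):
        x = self.seq(x)
        return self.seq(x)  # same submodule applied a second time

traced = torch.fx.symbolic_trace(Net())
print([n.name for n in traced.graph.nodes if n.op == 'call_module'])
# expected: ['seq_0', 'seq_1', 'seq_0_1', 'seq_1_1']
#   -> underscores as separators, plus an extra _1 suffix for the reuse
print(list(Net().state_dict().keys()))
# expected: ['seq.0.weight', 'seq.0.bias']
#   -> dots as separators, and no duplication for the reuse
```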
Bit of a mess, but I have tested all these cases and this implementation catches all the edge cases that I'm aware of.
This has now been solved in a significantly nicer way by Vladimir :)
pre-commit.ci autofix
This now includes the changes from #840. There'll be a follow-up PR with the remaining bits we need to parse the sPHENIX tracking GNN.
Parsing of nn.Sequentials that are named members of a model class results in a naming convention for the tensors in the `state_dict` of the model that differs from what the parser expects, since the parser was so far tested only on unnamed nn.Sequentials. This PR catches this case and adjusts the names of the tensors we import from the `state_dict` accordingly. A test is added to ensure that we keep parsing both cases successfully.

Type of change
For a new feature or function, please create an issue first to discuss it
with us before submitting a pull request.
Note: Please delete options that are not relevant.
Tests
To reproduce: the following will fail without this PR:
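A hedged sketch of a reproduction along these lines (the named member `nn.Sequential` is the essential ingredient; exact converter arguments may vary between hls4ml versions):

```python
import torch.nn as nn
import hls4ml

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # A *named* member Sequential: state_dict keys become 'layer.0.weight' etc.
        self.layer = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

    def forward(self, x):
        return self.layer(x)

model = Net()
config = hls4ml.utils.config_from_pytorch_model(model)
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model, (None, 8), hls_config=config, output_dir='test_prj'
)
```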
pytests have been added to verify that this keeps working.
Checklist
I have run `pre-commit` on the files I edited or added.