Can't use subnetwork inference on custom module #127
Hi, could you post code or describe the module? In general, there are two ways:
Hi, thank you for your quick reply. If you happen to know the FiLM layer (https://arxiv.org/abs/1709.07871), that's the module I'm using: it applies an affine transformation to each channel of the image. Otherwise, please see the attached code:

```python
import torch
import torch.nn as nn


class SimpleFiLM(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        # Both parameters start at zero, so the layer is the identity at initialization.
        self.scale = nn.Parameter(torch.zeros(num_channels))
        self.shift = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        """
        x: [batch_size, num_channels, height, width]
        """
        scale = self.scale + 1.
        # Reshape to [1, C, 1, 1] so the per-channel values broadcast over batch, height and width.
        cur_scale = scale.reshape(1, self.num_channels, 1, 1)
        cur_shift = self.shift.reshape(1, self.num_channels, 1, 1)
        x = cur_scale * x + cur_shift
        return x
```

I suppose in my case reimplementing it using standard modules might be easier? So far I have only been digging into asdl; do you happen to know which backend is better suited?
I would recommend extending asdl, since it's faster and more flexible with respect to architectures. For this layer, I think it's easier to implement an extension in asdl following the
Thank you for your reply. Unfortunately
In this case, you need to implement the module extension yourself, but you can use the simple
I'll work on it when I'm back from the holiday. There's a high chance I'll run into problems, so many thanks in advance :)
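The asdl extension interface itself is not shown in this thread, so the sketch below does not use asdl at all. It only works out the math a custom `SimpleFiLM` extension would have to provide: per-sample gradients of `scale` and `shift`, from which curvature quantities such as a diagonal empirical Fisher can be assembled. Since `y = (1 + scale_c) * x + shift_c`, the gradient for sample `n` is a sum over height and width. The function names are hypothetical, and which quantity asdl actually expects from an extension (per-sample gradients, Kronecker factors, diagonal blocks) is an assumption to verify against its source.

```python
import torch


def film_per_sample_grads(x: torch.Tensor, grad_out: torch.Tensor):
    """Hypothetical helper: per-sample gradients of SimpleFiLM's parameters.

    x:        layer input,                          [batch, C, H, W]
    grad_out: d(loss_n) / d(layer output) per sample, [batch, C, H, W]
      dloss_n/dscale_c = sum_{h,w} grad_out[n, c, h, w] * x[n, c, h, w]
      dloss_n/dshift_c = sum_{h,w} grad_out[n, c, h, w]
    Returns two tensors of shape [batch, C].
    """
    grad_scale = (grad_out * x).sum(dim=(2, 3))
    grad_shift = grad_out.sum(dim=(2, 3))
    return grad_scale, grad_shift


def diag_empirical_fisher(grad_scale: torch.Tensor, grad_shift: torch.Tensor):
    """Diagonal empirical Fisher over [scale, shift]: sum_n g_n ** 2, shape [2C]."""
    per_sample = torch.cat([grad_scale, grad_shift], dim=1)
    return per_sample.pow(2).sum(dim=0)


if __name__ == "__main__":
    torch.manual_seed(0)
    B, C = 4, 3
    scale = torch.zeros(C, requires_grad=True)
    shift = torch.zeros(C, requires_grad=True)
    x = torch.randn(B, C, 5, 5)
    y = (scale + 1.0).reshape(1, C, 1, 1) * x + shift.reshape(1, C, 1, 1)
    losses = (y ** 2).mean(dim=(1, 2, 3))  # one scalar loss per sample
    # Samples are independent, so grad of the summed loss gives per-sample grad_out.
    (grad_out,) = torch.autograd.grad(losses.sum(), y, retain_graph=True)
    gs, gb = film_per_sample_grads(x, grad_out)
    # Cross-check sample 0 against plain autograd.
    ref_s, ref_b = torch.autograd.grad(losses[0], (scale, shift))
    print(torch.allclose(gs[0], ref_s, atol=1e-6), torch.allclose(gb[0], ref_b, atol=1e-6))
    print(diag_empirical_fisher(gs, gb).shape)  # torch.Size([6])
```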
Hi,
I tried to use subnetwork inference on my model and it didn't work. I traced the bug to the fact that the Hessian is not being computed for my custom PyTorch module. How can I solve this? Thank you.
Kind regards,
Rui