
Can't use subnetwork inference on custom module #127

Open
ruili-pml opened this issue Jun 2, 2023 · 6 comments

Comments

@ruili-pml
Contributor

Hi,

I tried to use subnetwork inference on my model and it didn't work. I pinpointed the bug to the fact that the Hessian is not being computed for my custom PyTorch module. I was wondering how I can solve this. Thank you.

Kind regards,
Rui

@aleximmer
Owner

Hi, could you post code or describe the module? In general, there are two ways:

  1. implement the relevant extension for the module in backpack or asdl yourself
  2. implement your module in terms of standard modules (e.g. torch.nn.Linear) combined with non-parametrized transformations like reshape or view

Which is possible and easier depends on your module; a generic sketch of option 2 follows below.
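As a minimal sketch of option 2 on a hypothetical module (the name FlattenProject is an assumption for illustration, not from this issue): a custom "flatten then project" layer can be rewritten so that the only parametrized part is a standard nn.Linear, with non-parametrized reshapes around it, so backpack and asdl can compute its Hessian blocks out of the box.

import torch
import torch.nn as nn

# Hypothetical example: the parametrized part is a standard nn.Linear,
# and the only custom work is a non-parametrized reshape around it.
class FlattenProject(nn.Module):
    def __init__(self, in_channels, height, width, out_features):
        super().__init__()
        self.linear = nn.Linear(in_channels * height * width, out_features)

    def forward(self, x):
        # x: [batch, channels, height, width]
        return self.linear(x.reshape(x.shape[0], -1))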

@ruili-pml
Contributor Author

ruili-pml commented Jun 2, 2023

Hi, thank you for your quick reply. If you happen to know the FiLM layer (https://arxiv.org/abs/1709.07871), that's the module I'm using; basically, it applies an affine transformation to each channel of the image. Otherwise, please see the attached code:

import torch
import torch.nn as nn

class SimpleFiLM(nn.Module):
    """FiLM layer: a learnable affine transformation per channel."""

    def __init__(self, num_channels):
        super().__init__()

        self.num_channels = num_channels

        # Both parameters start at zero; the forward pass adds 1 to the
        # scale, so the layer is initialized as the identity.
        self.scale = nn.Parameter(torch.zeros(num_channels))
        self.shift = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        """
        x: [batch_size, num_channels, height, width]
        """
        scale = self.scale + 1.

        # Broadcast the per-channel parameters over batch and spatial dims.
        cur_scale = scale.reshape(1, self.num_channels, 1, 1)
        cur_shift = self.shift.reshape(1, self.num_channels, 1, 1)

        x = cur_scale * x + cur_shift

        return x

I suppose in my case reimplementing it using standard modules might be easier? Previously I was only digging into asdl; do you happen to know which backend is better suited?
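For reference, a per-channel affine like this maps onto one standard module: a depthwise 1x1 nn.Conv2d, whose weight (shape [C, 1, 1, 1]) and bias (shape [C]) are exactly one scale and one shift per channel. A minimal sketch, assuming the chosen backend's Conv2d extension supports groups != 1 (worth checking before relying on it):

import torch
import torch.nn as nn

# Sketch of FiLM as a depthwise 1x1 convolution. Initializing the weight
# to ones mirrors the +1 in SimpleFiLM without any extra term in forward.
# Caveat: this only helps if the backend handles grouped convolutions.
class ConvFiLM(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=1,
                              groups=num_channels, bias=True)
        nn.init.ones_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        # x: [batch, num_channels, height, width]
        return self.conv(x)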

@aleximmer
Owner

I would recommend extending asdl since it's faster and more flexible with respect to architectures. For this layer, I think it's easiest to implement an extension in asdl following the scale and bias modules, which are essentially scalar scale and shift parameter modules (see https://github.com/kazukiosawa/asdl/tree/master/asdl/operations). However, their extension is only correct if you remove the scale = self.scale + 1. Is there any necessity for this? I would argue you can instead just add the one at initialization. Let me know if anything is unclear or if you run into problems with this.
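For orientation, the per-example quantities such an extension ultimately has to produce follow from the chain rule: with y = scale * x + shift applied per channel, dL/d(scale) is the spatial sum of x * dL/dy, and dL/d(shift) is the spatial sum of dL/dy. A minimal sketch in plain PyTorch (the helper name is hypothetical; this is not asdl's API):

import torch

def film_per_example_grads(x, out_grads):
    """Per-example gradients for y = scale * x + shift (per channel).

    x, out_grads: [batch, channels, height, width]; out_grads is dL/dy.
    Returns (grad_scale, grad_shift), each of shape [batch, channels].
    An additive constant on the scale (the +1) drops out of these
    derivatives, since d(scale + 1)/d(scale) = 1.
    """
    grad_scale = (x * out_grads).sum(dim=(2, 3))
    grad_shift = out_grads.sum(dim=(2, 3))
    return grad_scale, grad_shift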

@ruili-pml
Contributor Author

Thank you for your reply. Unfortunately, scale = self.scale + 1 is necessary: Section 7.2 (Model Details) in the appendix points out that it has a large influence on performance. Does that mean I basically need to implement the whole module?

@aleximmer
Owner

In this case, you need to implement the module extension yourself, but you can use the simple Bias and Scale operations as a template and only extend them to the multivariate setting. The +1 will be correctly incorporated into the partial gradient of the output, i.e., out_grads in asdl, so you don't need to handle it separately. I am happy to help if you run into problems, just let me know.
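A quick autograd sanity check (an illustrative sketch reusing SimpleFiLM and the hypothetical film_per_example_grads helper from above, not code from this thread) confirms that the +1 flows through out_grads and the summed per-example gradients match PyTorch's:

import torch

torch.manual_seed(0)
film = SimpleFiLM(num_channels=3)
x = torch.randn(4, 3, 8, 8)

y = film(x)
loss = (y ** 2).sum()

# dL/dy, the quantity asdl captures as out_grads; the +1 is already
# folded into y, so nothing needs special handling here.
(out_grads,) = torch.autograd.grad(loss, y, retain_graph=True)

grad_scale, grad_shift = film_per_example_grads(x, out_grads)

# Per-example gradients summed over the batch must equal autograd's
# gradients with respect to the parameters.
auto_scale, auto_shift = torch.autograd.grad(loss, (film.scale, film.shift))
assert torch.allclose(grad_scale.sum(0), auto_scale, atol=1e-5)
assert torch.allclose(grad_shift.sum(0), auto_shift, atol=1e-5)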

@ruili-pml
Contributor Author

I'll work on it when I'm back from the holiday, and there's a high chance I'll run into problems; many thanks in advance :)
