Per-sample gradient: gradient is 0 when using grad(params_tograd, params) with respect to part of the model's parameters #1122

Ancientshi opened this issue Apr 1, 2023 · 1 comment
Hi PyTorch team, I recently need to calculate per-sample gradients with respect to part of a model's parameters. It works for the toy example below, but for the Wide & Deep model it doesn't: all the returned gradients are 0, and I don't know why.

Here is the toy example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional_with_buffers, vmap, grad

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)  # nll_loss expects log-probabilities
        return output
    

device = 'cuda'
batch_size = 64

data = torch.randn(batch_size, 1, 28, 28, device=device)
targets = torch.randint(10, (batch_size,), device=device)

model = SimpleCNN().to(device=device)
model = model.eval()
fmodel, params, buffers = make_functional_with_buffers(model)

def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

def compute_loss_stateless_model(params_tograd, params, buffers, sample, target):
    # Splice the parameters we want gradients for back into the full parameter list;
    # only params_tograd (the first argument) is differentiated by grad().
    for key, value in params_tograd.items():
        params[key] = value
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch)
    loss = loss_fn(predictions, targets)
    return loss

# grad differentiates with respect to the first argument (params_tograd) only
ft_compute_grad = grad(compute_loss_stateless_model)

# vmap over the sample and target dimensions to get per-sample gradients
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, None, 0, 0))

# select a subset of parameters to differentiate
params_tograd = {}
for i in [-2, 0]:
    params_tograd[i] = params[i]

ft_per_sample_grads = ft_compute_sample_grad(params_tograd, [p for p in params], buffers, data, targets)
print(ft_per_sample_grads)

The result is:
(screenshot: the per-sample gradients for the selected parameters are non-zero, as expected)
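For reference, the same subset-gradient idea can be written so that only the selected parameters are passed as the differentiated argument and the full parameter list is rebuilt inside the loss function, without mutating a shared list. This is only a sketch; grad_idx and the helper names are illustrative, not part of the original code:

import torch.nn.functional as F
from functorch import make_functional_with_buffers, vmap, grad

fmodel, params, buffers = make_functional_with_buffers(model)

grad_idx = [0, len(params) - 2]                  # indices to differentiate (illustrative)
params_sub = tuple(params[i] for i in grad_idx)  # differentiated subset
params_frozen = [p.detach() for p in params]     # everything else, detached

def compute_loss(params_sub, buffers, sample, target):
    # Rebuild the full parameter list, splicing in the differentiated subset
    full = list(params_frozen)
    for i, p in zip(grad_idx, params_sub):
        full[i] = p
    predictions = fmodel(full, buffers, sample.unsqueeze(0))
    return F.nll_loss(predictions, target.unsqueeze(0))

per_sample_grads = vmap(grad(compute_loss), in_dims=(None, None, 0, 0))(
    params_sub, buffers, data, targets
)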

However, when I apply this method to the real scenario, it doesn't work: all the returned gradients are 0.

model.load_state_dict(w_tao)
fmodel, params, buffers = make_functional_with_buffers(model)

def loss_fn(predictions, targets):
    return F.mse_loss(predictions, targets)

def compute_loss_stateless_model(params_tograd, params, buffers, sample, target):
    for key, value in params_tograd.items():
        params[key] = value
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch)
    loss = loss_fn(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss_stateless_model)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, None, 0, 0))

# select the parameters to differentiate, depending on the dataset
params_tograd = {}
if dataset == 'lastfm-1k':
    params_tograd[-1] = params[-1]
    params_tograd[4] = params[4]
else:
    params_tograd[0] = params[0]

prod_all = []
for batch_idx, (inputs, targets) in tqdm(enumerate(train_data_loader)):
    inputs, targets = inputs.to(self.device).float(), targets.to(self.device).float()

    ft_per_sample_grads = ft_compute_sample_grad(params_tograd, [p for p in params], buffers, inputs, targets)
    print(ft_per_sample_grads)
    sys.exit()  # stop here for debugging; everything printed above is all zeros
    if dataset != 'lastfm-1k':
        params_grads = ft_per_sample_grads[-1].reshape(ft_per_sample_grads[-1].shape[0], -1)
    elif dataset == 'lastfm-1k':
        params_grad_dnn = ft_per_sample_grads[-1].reshape(ft_per_sample_grads[-1].shape[0], -1)
        params_grad_linear = ft_per_sample_grads[4].reshape(ft_per_sample_grads[4].shape[0], -1)
        params_grads = torch.cat([params_grad_dnn, params_grad_linear], -1)

    prod = torch.mm(params_grads, grad_mean.unsqueeze(1)).squeeze().detach().to('cpu').numpy()
    prod_all.extend(prod)
return dict(zip(range(1, len(prod_all) + 1), prod_all))

(screenshots: the printed per-sample gradients are all zeros)
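As a sanity check (illustrative only, not part of the original code), one can first verify whether the selected parameters receive non-zero gradients under ordinary autograd, independently of functorch/vmap; shapes may need adjusting for the model's actual output:

# Illustrative sanity check: do the parameters get non-zero gradients with plain autograd?
model.zero_grad()
loss = F.mse_loss(model(inputs), targets)
loss.backward()
for i, p in enumerate(model.parameters()):
    grad_sum = None if p.grad is None else p.grad.abs().sum().item()
    print(i, tuple(p.shape), grad_sum)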

Also, in the Wide & Deep module, the shape of linear_logit should be (batch_size, 1), but when I apply this method an error happens here: the shapes of linear_logit and sparse_feat_logit do not match. I have attached the printed output below. (I suppose that when using this method X.shape[0] = 0, but why?)

linear_logit = torch.zeros([X.shape[0], 1]).to(self.device)
if len(sparse_embedding_list) > 0:
    # sparse_embedding_cat: torch.Size([1000, 1, 7])
    sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
    if sparse_feat_refine_weight is not None:
        # w_{x,i} = m_{x,i} * w_i (in IFM and DIFM)
        sparse_embedding_cat = sparse_embedding_cat * sparse_feat_refine_weight.unsqueeze(1)

    sparse_feat_logit = torch.sum(sparse_embedding_cat, dim=-1, keepdim=False)
    try:
        linear_logit += sparse_feat_logit
    except:
        # debugging output for the shape mismatch
        print(linear_logit.shape)
        print(sparse_feat_logit.shape)
        print('linear_logit\n', linear_logit)
        print('sparse_feat_logit\n', sparse_feat_logit)
        sys.exit()
        linear_logit = sparse_feat_logit

(screenshot: the printed shapes and values of linear_logit and sparse_feat_logit)
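I am not sure this is the root cause, but vmap generally does not support in-place arithmetic where the destination tensor was created inside the vmapped function (that tensor is not batched, while sparse_feat_logit is). A hypothetical out-of-place variant of the accumulation above would be:

# Hypothetical out-of-place rewrite of the block above; `+=` into a tensor created
# inside the vmapped function can fail under vmap, whereas an out-of-place add
# broadcasts the per-sample (batched) logit without mutating the zeros tensor.
linear_logit = torch.zeros([X.shape[0], 1], device=X.device)
if len(sparse_embedding_list) > 0:
    sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
    if sparse_feat_refine_weight is not None:
        sparse_embedding_cat = sparse_embedding_cat * sparse_feat_refine_weight.unsqueeze(1)
    sparse_feat_logit = torch.sum(sparse_embedding_cat, dim=-1, keepdim=False)
    linear_logit = linear_logit + sparse_feat_logit  # out-of-place add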

Ancientshi reopened this Oct 19, 2023

Ancientshi (Author) commented:
When I try to calculate the gradients of the embedding layer, I get:

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
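For context, this is PyTorch's generic error for in-place updates of a leaf tensor that has requires_grad=True. A minimal illustration (not the actual Wide & Deep code) and the usual workarounds:

import torch

w = torch.nn.Parameter(torch.zeros(3))  # a leaf tensor with requires_grad=True

# w += 1  # raises: "a leaf Variable that requires grad is being used in an in-place operation."

# Usual workarounds (illustrative; whether either fits depends on the model code):
with torch.no_grad():
    w += 1           # in-place update is allowed while autograd is disabled

w2 = w.clone() + 1   # or operate on a non-leaf clone so autograd can track the op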
