First of all, thank you for the profiling tools; they have been very helpful.
As the title says, the reported number of operations increased every time I ran the `profile` function. I'm using it in a notebook and running it several times while debugging. I found that the hooks are registered again on every call, so each forward pass also fires all the hooks left over from previous calls. I fixed this by keeping a list of the handles returned by every `register_forward_hook` call; once the function has finished counting, I iterate over the handles and remove them. I don't know whether this is efficient or scales well to larger models, but for my specific case (a very small model) I did not notice any slowdown.
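To see why the count grows, here is a minimal sketch with a toy hook that only counts how often it fires. `count_calls` and `run_once` are hypothetical names standing in for the real counting hooks and `profile`. PyTorch's `register_forward_hook` returns a handle whose `.remove()` detaches the hook. Without removing the handles, every call stacks one more hook per layer; removing them keeps the result constant.

```python
import torch
import torch.nn as nn

def count_calls(module, inputs, output):
    # Toy stand-in for a real counting hook: just count forward passes
    module.calls += 1

def run_once(model, remove_hooks):
    """Register a hook on every Linear layer, run one forward pass,
    and optionally remove the hooks afterwards."""
    handles = []
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.calls = 0  # reset the counter, like re-registering the buffer does
            handles.append(m.register_forward_hook(count_calls))
    model(torch.zeros(1, 4))
    if remove_hooks:
        for h in handles:
            h.remove()
    return sum(m.calls for m in model.modules() if isinstance(m, nn.Linear))

leaky = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
print([run_once(leaky, remove_hooks=False) for _ in range(3)])  # [2, 4, 6]

clean = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
print([run_once(clean, remove_hooks=True) for _ in range(3)])   # [2, 2, 2]
```

The counters are reset on every call, yet the leaky model still reports more each time, because each forward pass runs every hook that was ever registered.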
Edit: Another solution would be to check whether the hooks already exist (i.e. check that the handle list is not `None`) and, if so, skip registering them again, which would be more efficient, but I can't figure out how to do that right now.
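One possible reading of that alternative (register the hooks once and reuse them) is sketched below, under assumptions. A list local to `profile` is recreated on every call, so it cannot carry over between calls; this sketch instead stores the handles on the model under a hypothetical attribute `_profile_handles`, and on later calls only resets the counters. `count_calls`, `profile_once` and `clear_profile_hooks` are likewise hypothetical names.

```python
import torch
import torch.nn as nn

def count_calls(module, inputs, output):
    # Toy stand-in for a real counting hook
    module.calls += 1

def profile_once(model, x):
    # Register hooks only if this model has none from a previous call.
    # `_profile_handles` is a hypothetical attribute name.
    if getattr(model, "_profile_handles", None) is None:
        model._profile_handles = [
            m.register_forward_hook(count_calls)
            for m in model.modules()
            if isinstance(m, nn.Linear)
        ]
    # Reset the counters instead of adding new hooks
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.calls = 0
    model(x)
    return sum(m.calls for m in model.modules() if isinstance(m, nn.Linear))

def clear_profile_hooks(model):
    # Remove the stored hooks once profiling is finished
    for h in getattr(model, "_profile_handles", None) or []:
        h.remove()
    model._profile_handles = None

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
x = torch.zeros(1, 4)
print([profile_once(model, x) for _ in range(3)])  # [2, 2, 2]
clear_profile_hooks(model)
```

The trade-off is that the hooks stay attached between calls, so they also fire during ordinary inference until `clear_profile_hooks` is called.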
Here's the code I used:
```python
import torch
import torch.nn as nn

# count_conv2d, count_bn2d, count_relu, count_maxpool, count_avgpool and
# count_linear are the counting hooks from the profiling tool; they are
# assumed to be in scope.

def profile(model, input_size, custom_ops=None):
    # custom_ops: optional {module type: counting hook} overrides
    custom_ops = custom_ops or {}
    model.eval()
    # Keep every hook handle so the hooks can be removed after counting
    handles = []

    def add_hooks(m):
        # Only leaf modules (no children) are counted
        if len(list(m.children())) > 0:
            return
        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))
        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])

        if type(m) in custom_ops:
            handles.append(m.register_forward_hook(custom_ops[type(m)]))
        elif isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(count_conv2d))
        elif isinstance(m, nn.BatchNorm2d):
            handles.append(m.register_forward_hook(count_bn2d))
        elif isinstance(m, nn.ReLU):
            handles.append(m.register_forward_hook(count_relu))
        elif isinstance(m, (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)):
            handles.append(m.register_forward_hook(count_maxpool))
        elif isinstance(m, (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)):
            handles.append(m.register_forward_hook(count_avgpool))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(count_linear))
        elif isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            pass  # dropout is not counted
        else:
            print("Not implemented for", m)

    model.apply(add_hooks)
    try:
        # One dummy forward pass fires the hooks, which fill total_ops
        x = torch.zeros(input_size)
        with torch.no_grad():
            model(x)
    finally:
        # Remove the hooks (even if the forward pass fails) so the next call
        # does not stack a second set on top of these
        for h in handles:
            h.remove()

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:
            continue
        total_ops += m.total_ops
        total_params += m.total_params
    return total_ops, total_params
```