
How to calculate the FLOPs if one module has an 'einsum' operation? #105

Closed
JunyaoHu opened this issue Mar 21, 2023 · 2 comments
Labels
enhancement (New feature or request) · question (Further information is requested) · wontfix (This will not be worked on)

Comments

@JunyaoHu

In a diffusion model like this one, the NIN module uses an einsum operation. How can its FLOPs be calculated?

import string

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# default_init is a weight-initialization helper defined elsewhere in the
# original NCSN++/MCVD codebase.

def _einsum(a, b, c, x, y):
  einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
  return torch.einsum(einsum_str, x, y)

def contract_inner(x, y):
  """tensordot(x, y, 1)."""
  x_chars = list(string.ascii_lowercase[:len(x.shape)])
  y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
  y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
  out_chars = x_chars[:-1] + y_chars[1:]
  return _einsum(x_chars, y_chars, out_chars, x, y)

class NIN(nn.Module):
  def __init__(self, in_dim, num_units, init_scale=0.1):
    super().__init__()
    self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
    self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

  def forward(self, x):
    x = x.permute(0, 2, 3, 1)
    y = contract_inner(x, self.W) + self.b
    return y.permute(0, 3, 1, 2)


class AttnBlockpp(nn.Module):
  """Channel-wise self-attention block. Modified from DDPM."""

  def __init__(self, channels, skip_rescale=False, init_scale=0., n_heads=1, n_head_channels=-1):
    super().__init__()
    num_groups = min(channels // 4, 32)
    while(channels % num_groups != 0): # must find another value
      num_groups -= 1
    self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels,
                                  eps=1e-6)
    self.NIN_0 = NIN(channels, channels)
    self.NIN_1 = NIN(channels, channels)
    self.NIN_2 = NIN(channels, channels)
    self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
    self.skip_rescale = skip_rescale
    if n_head_channels == -1:
      self.n_heads = n_heads
    else:
      if channels < n_head_channels:
        self.n_heads = 1
      else:
        assert channels % n_head_channels == 0
        self.n_heads = channels // n_head_channels

  def forward(self, x):
    B, C, H, W = x.shape
    h = self.GroupNorm_0(x)
    q = self.NIN_0(h)
    k = self.NIN_1(h)
    v = self.NIN_2(h)

    C = C // self.n_heads

    w = torch.einsum('bchw,bcij->bhwij', q.reshape(B * self.n_heads, C, H, W), k.reshape(B * self.n_heads, C, H, W)) * (int(C) ** (-0.5))
    w = torch.reshape(w, (B * self.n_heads, H, W, H * W))
    w = F.softmax(w, dim=-1)
    w = torch.reshape(w, (B * self.n_heads, H, W, H, W))
    h = torch.einsum('bhwij,bcij->bchw', w, v.reshape(B * self.n_heads, C, H, W))
    h = h.reshape(B, C * self.n_heads, H, W)
    h = self.NIN_3(h)
    if not self.skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)

When I try to calculate the FLOPs of the MCVD diffusion model with ptflops, the algorithm ignores the NIN modules:

DataParallel(
  59.02 M, 93.938% Params, 19.31 GMac, 100.000% MACs, 
  (module): UNetMore_DDPM(
    59.02 M, 93.938% Params, 19.31 GMac, 100.000% MACs, 
    (unet): NCSNpp(
      59.02 M, 93.938% Params, 19.31 GMac, 100.000% MACs, 
      (act): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
      (all_modules): ModuleList(
        59.02 M, 93.938% Params, 19.31 GMac, 100.000% MACs, 
        (0): Linear(37.25 k, 0.059% Params, 37.25 KMac, 0.000% MACs, in_features=96, out_features=384, bias=True)
        (1): Linear(147.84 k, 0.235% Params, 147.84 KMac, 0.001% MACs, in_features=384, out_features=384, bias=True)
        (2): Conv2d(13.06 k, 0.021% Params, 53.48 MMac, 0.277% MACs, 15, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ResnetBlockBigGANppGN(
          313.92 k, 0.500% Params, 681.2 MMac, 3.528% MACs, 
          (actnorm0): get_act_norm(
            73.92 k, 0.118% Params, 467.14 KMac, 0.002% MACs, 
            (act): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
            (act_emb): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
            (Dense_0): Linear(73.92 k, 0.118% Params, 73.92 KMac, 0.000% MACs, in_features=384, out_features=192, bias=True)
            (Norm_0): GroupNorm(0, 0.000% Params, 393.22 KMac, 0.002% MACs, 24, 96, eps=1e-05, affine=False)
          )
         ...
        (7): AttnBlockpp(
          384, 0.001% Params, 393.22 KMac, 0.002% MACs, 
          (GroupNorm_0): GroupNorm(384, 0.001% Params, 393.22 KMac, 0.002% MACs, 32, 192, eps=1e-06, affine=True)
          (NIN_0): NIN(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
          (NIN_1): NIN(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
          (NIN_2): NIN(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
          (NIN_3): NIN(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
        )
        (8): ResnetBlockBigGANppGN(
          959.62 k, 1.527% Params, 680.56 MMac, 3.525% MACs, 
          (actnorm0): get_act_norm(
            147.84 k, 0.235% Params, 344.45 KMac, 0.002% MACs, 
            (act): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
            (act_emb): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
            (Dense_0): Linear(147.84 k, 0.235% Params, 147.84 KMac, 0.001% MACs, in_features=384, out_features=384, bias=True)
            (Norm_0): GroupNorm(0, 0.000% Params, 196.61 KMac, 0.001% MACs, 32, 192, eps=1e-05, affine=False)
          )
          (Conv_0): Conv2d(331.97 k, 0.528% Params, 339.94 MMac, 1.761% MACs, 192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (actnorm1): get_act_norm(
            147.84 k, 0.235% Params, 344.45 KMac, 0.002% MACs, 
            (act): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
            (act_emb): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
            (Dense_0): Linear(147.84 k, 0.235% Params, 147.84 KMac, 0.001% MACs, in_features=384, out_features=384, bias=True)
            (Norm_0): GroupNorm(0, 0.000% Params, 196.61 KMac, 0.001% MACs, 32, 192, eps=1e-05, affine=False)
          )
          (Dropout_0): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.1, inplace=False)
          (Conv_1): Conv2d(331.97 k, 0.528% Params, 339.94 MMac, 1.761% MACs, 192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (act): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
        )
        (9): AttnBlockpp(
          384, 0.001% Params, 393.22 KMac, 0.002% MACs, 
          (GroupNorm_0): GroupNorm(384, 0.001% Params, 393.22 KMac, 0.002% MACs, 32, 192, eps=1e-06, affine=True)
          (NIN_0): NIN(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
          (NIN_1): NIN(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
          (NIN_2): NIN(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
          (NIN_3): NIN(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
        )
        ...
@JunyaoHu (Author) commented Mar 21, 2023

I also found some references, but I am not sure whether they are correct:

Bihaqo/t3f#220

v923z/micropython-ulab#320

dgasmith/opt_einsum#103

Bihaqo/t3f#219

https://stackoverflow.com/questions/53183222/numpy-einsum-path-reports-more-flop-and-speeddown

https://stackoverflow.com/questions/31187512/how-does-architecture-affect-numpy-array-operation-performance

https://stackoverflow.com/questions/62395075/efficient-tensor-contraction-with-python/62415935#62415935

Example:

import numpy as np

At = np.random.randn(8, 10)
G = np.random.randn(10, 3)
x = np.random.randn(100, 300, 75, 10)
Bt = np.random.randn(10, 10)
w = np.random.randn(100, 300, 3)

np.einsum("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal")
res = np.einsum_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal")

for r in res:
    print(r)

"""
    ['einsum_path', (1, 2), (1, 2), (1, 2), (0, 1)]
  Complete contraction:  mn,nr,fcr,nh,octh->oftm
         Naive scaling:  8
     Optimized scaling:  5
      Naive FLOP count:  2.700e+12
  Optimized FLOP count:  5.072e+09
   Theoretical speedup:  532.355
  Largest intermediate:  2.250e+07 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                 fcr,nr->nfc                     mn,nh,octh,nfc->oftm
   5               octh,nh->ntoc                        mn,nfc,ntoc->oftm
   5              ntoc,nfc->ntfo                            mn,ntfo->oftm
   5               ntfo,mn->oftm                               oftm->oftm
"""
# extract the "Optimized FLOP count" value from the einsum_path report
float(res[1].split('\n')[4].split(':')[1].strip())
# 5072000000.0
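
The same np.einsum_path report can also be used to estimate the cost of the attention contraction in AttnBlockpp above. A minimal sketch, assuming a single head and hypothetical sizes B=1, C=192, H=W=32 (illustrative values, not taken from the MCVD config):

import numpy as np

# Hypothetical tensor sizes for one attention block (illustrative only).
B, C, H, W = 1, 192, 32, 32
q = np.zeros((B, C, H, W), dtype=np.float32)
k = np.zeros((B, C, H, W), dtype=np.float32)

# einsum_path analyses the contraction without actually performing it;
# the second return value is the human-readable report string.
path, report = np.einsum_path('bchw,bcij->bhwij', q, k, optimize='optimal')
print(report)  # contains a line "Optimized FLOP count:  ..."

Counting by hand, this contraction needs about B * C * (H*W)^2 multiply-accumulate operations, which should agree with the report up to a small constant factor.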

@sovrasov
Owner

Thanks for the investigation! Unfortunately, ptflops can't track functional-style operations like torch.einsum or even torch.nn.functional.upsample (that's stated in the readme as well). In your sample, NIN is not counted because ptflops doesn't see any familiar module inside it. The only solution here is to implement a custom hook for your NIN module.
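
For reference, a minimal sketch of such a hook, passed to ptflops via the custom_modules_hooks argument of get_model_complexity_info. The model variable, the NIN import, and the (15, 64, 64) input resolution below are placeholders for your setup, not values from the MCVD config:

from ptflops import get_model_complexity_info
# NIN is assumed to be importable from wherever the model code defines it.

def nin_flops_counter_hook(module, input, output):
    # NIN applies an (in_dim, num_units) weight matrix to every spatial
    # position: contract_inner(x, W) is a matmul over the channel axis, so
    # the MAC count is B * H * W * in_dim * num_units, plus the bias adds.
    x = input[0]                      # shape (B, C_in, H, W)
    batch, _, height, width = x.shape
    in_dim, num_units = module.W.shape
    macs = batch * height * width * in_dim * num_units
    bias = batch * height * width * num_units
    module.__flops__ += int(macs + bias)

macs, params = get_model_complexity_info(
    model, (15, 64, 64),                                   # placeholder input shape
    custom_modules_hooks={NIN: nin_flops_counter_hook},
    as_strings=True, print_per_layer_stat=True)

Note that the two torch.einsum calls inside AttnBlockpp.forward (the q-k and attention-v contractions) are also functional-style and therefore invisible to ptflops, so a similar hook on AttnBlockpp itself is needed to capture the full attention cost.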

sovrasov added the enhancement, question, and wontfix labels on Apr 17, 2023