flops about mamba2 #51

Open
dumpmemory opened this issue Aug 2, 2024 · 2 comments

@dumpmemory

Hi,

Do you know how to calculate the FLOPs of the function called here?

y = mamba_chunk_scan_combined(

Thanks.

@alxndrTL
Owner

alxndrTL commented Aug 8, 2024

Hello, I haven't had time to really dig into Mamba2 (I just reused the official code to integrate it with muP and a caching mechanism). I have yet to make it work entirely in PyTorch; once that is done, I expect it will be easier to compute the FLOPs.

@dumpmemory
Author

Thanks for your reply.

I am wondering about the FLOPs of the Triton part when using mamba-ssm.

I have tried to write a function based on the Mamba2 paper, as follows:

def ssd_flops(T, Q, P, N):
    # Center blocks: quadratic SMA computation
    center_blocks_sma_compute = T*Q*N + T*Q*Q + T*P*N
    # Low-rank blocks: right factors (B terms)
    b_compute = T*N*P
    # Low-rank blocks: A terms
    a_compute = T*N*P/Q
    # Low-rank blocks: left factors (C terms)
    c_compute = T*P*N
    return center_blocks_sma_compute + b_compute + a_compute + c_compute
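
One way to sanity-check an estimate like this is to count operations directly from the matrix shapes of a chunked computation. The sketch below is only that: the helper name `ssd_op_count_from_shapes` is hypothetical, it assumes T is the sequence length, Q the chunk size, P the head dimension and N the state dimension, it counts one operation per multiply-accumulate for a single head, and it does not describe what the Triton kernels in mamba-ssm actually execute. Under these assumptions the mask term comes out as T*Q and the final diagonal matmul as T*Q*P, so the result may differ from the function above.

```python
def ssd_op_count_from_shapes(T, Q, P, N):
    """Hypothetical helper: per-head operation count for a chunked SSD pass.

    Assumed meanings (not confirmed by the thread):
    T = sequence length, Q = chunk size, P = head dim, N = state dim.
    Each multiply-accumulate or elementwise multiply counts as one operation.
    """
    if T % Q != 0:
        raise ValueError("this sketch assumes T is a multiple of Q")
    num_chunks = T // Q

    terms = {
        # Diagonal blocks: (Q, N) x (N, Q) score matrix per chunk
        "diag_scores": num_chunks * Q * Q * N,
        # Diagonal blocks: elementwise mask on a (Q, Q) matrix per chunk
        "diag_mask": num_chunks * Q * Q,
        # Diagonal blocks: (Q, Q) x (Q, P) product with the inputs per chunk
        "diag_output": num_chunks * Q * Q * P,
        # Low-rank blocks, B terms: (N, Q) x (Q, P) chunk state per chunk
        "chunk_states": num_chunks * N * Q * P,
        # Low-rank blocks, A terms: one (N, P) state update per chunk
        "state_passing": num_chunks * N * P,
        # Low-rank blocks, C terms: (Q, N) x (N, P) per chunk
        "state_to_output": num_chunks * Q * N * P,
    }
    terms["total"] = sum(terms.values())
    return terms


if __name__ == "__main__":
    # Example sizes chosen only for illustration.
    for name, count in ssd_op_count_from_shapes(T=2048, Q=64, P=64, N=128).items():
        print(f"{name}: {count / 1e6:.2f} M ops")
```

Multiplying the matmul terms by 2 gives the common "one multiply plus one add" FLOP convention; the result is per head and per sequence, so it would still need to be scaled by batch size and number of heads.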
