Hello, I haven't had time to really dig into Mamba2 (I just reused the official code and integrated it with muP and a caching mechanism). I have yet to make it work entirely in PyTorch; once it does, I guess it will be easier to compute the FLOPs.
I'm wondering about the FLOPs of the Triton part if I use mamba-ssm.
I have tried to write a function based on the Mamba2 paper, as follows:
```python
def ssd_flops(T, Q, P, N):
    # center blocks
    # print(T, Q, P, N)
    center_blocks_sma_compute = T * Q * N + T * Q * Q + T * P * N
    # print("center_blocks_sma_compute", center_blocks_sma_compute / 1e9, T * Q * N / 1e9, T * Q * Q / 1e9, T * P * N / 1e9)
    # low-rank blocks, right factors: B terms
    b_compute = T * N * P
    # low-rank blocks, right factors: A terms
    a_compute = T * N * P / Q
    # low-rank blocks, left factors: C terms
    c_compute = T * P * N
    return center_blocks_sma_compute + b_compute + a_compute + c_compute
```
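For comparison, here is a hedged sketch that follows the per-step breakdown of the chunked SSD algorithm in the Mamba2 paper. It counts each batched matmul of shape (M×K) @ (K×N) as 2·M·N·K FLOPs (one multiply and one add per term).

The helpers `bmm_flops` and `ssd_flops_by_step` are hypothetical. The result is per head and per sequence, so multiply by batch size × number of heads for a full layer. The sketch assumes Q divides T and ignores elementwise work such as the decay mask and the exponentials.

It counts the diagonal-block output product as a (Q×Q) @ (Q×P) matmul per chunk, i.e. 2·T·Q·P, which differs from the `T * Q * Q` term in the function above. This is an analytic estimate, not a measurement. The mamba-ssm Triton kernels may organize the work differently.

```python
def bmm_flops(batch: int, m: int, n: int, k: int) -> int:
    # `batch` products of (m x k) @ (k x n): one multiply and one add per term.
    return 2 * batch * m * n * k


def ssd_flops_by_step(T: int, Q: int, P: int, N: int) -> dict:
    """Hypothetical sketch: matmul FLOPs of one head of chunked SSD.

    T: sequence length, Q: chunk length (assumed to divide T),
    P: head dimension, N: state dimension.
    """
    assert T % Q == 0, "sketch assumes T is a multiple of Q"
    n_chunks = T // Q
    return {
        # Diagonal blocks: kernel C B^T inside each chunk, (Q x N) @ (N x Q).
        "diag_kernel": bmm_flops(n_chunks, Q, Q, N),
        # Diagonal blocks: masked kernel times X, (Q x Q) @ (Q x P).
        "diag_output": bmm_flops(n_chunks, Q, P, Q),
        # Right factors: per-chunk states B^T X, (N x Q) @ (Q x P).
        "chunk_states": bmm_flops(n_chunks, N, P, Q),
        # Center factors: recurrence h_c = a_c * h_{c-1} + s_c over chunk states,
        # approximated as 2 FLOPs per (N x P) state element per chunk.
        "state_passing": 2 * n_chunks * N * P,
        # Left factors: C times the incoming state, (Q x N) @ (N x P).
        "state_to_output": bmm_flops(n_chunks, Q, P, N),
    }


steps = ssd_flops_by_step(T=2048, Q=256, P=64, N=128)
for name, flops in steps.items():
    print(f"{name:>16}: {flops / 1e6:.2f} MFLOPs")
print(f"{'total':>16}: {sum(steps.values()) / 1e6:.2f} MFLOPs")
```

With these example sizes the diagonal-block kernel dominates: 2·T·Q·N ≈ 1.34e8 of roughly 2.69e8 total. Shrinking Q shifts cost from the diagonal blocks toward the state computations.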
Hi,
I was wondering if you know how to calculate the FLOPs of this:
`mamba.py/mambapy/mamba2.py`, line 256 in commit dcd6a32