Attention Mask for Packed Sequences (via Attention Bias) #139
Replies: 3 comments
-
Hello, and thanks for using EasyDeL. This would be a cool feature to add. Is there any existing code from someone who has implemented this before? For GPU flash attention, I want to use jax-triton to port the FlashAttention-2 implementation to JAX/Pallas and use that; it would be an easier, faster, and better option than re-creating FlashAttention-2 in Pallas for GPUs. Is there any related paper with additional information on the implementation? From what has been explained here, it doesn't seem that hard to implement: we just need a 3D attention mask or causal mask, which is not hard to build for GPUs, but for TPUs there are some small tricks needed for sharding and multi-host execution.
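To make the 3D-mask idea concrete, here is a minimal JAX sketch (not EasyDeL code) of a block-diagonal causal mask built from per-token segment ids; the function name and the segment-id convention are assumptions for illustration.

```python
import jax.numpy as jnp

def packed_causal_mask(segment_ids: jnp.ndarray) -> jnp.ndarray:
    """segment_ids: [batch, seq_len] ints, one id per packed example,
    e.g. [0, 0, 0, 1, 1] for two examples of lengths 3 and 2.
    Returns a boolean [batch, seq_len, seq_len] mask, True where attention is allowed."""
    seq_len = segment_ids.shape[-1]
    # Tokens may only attend within their own packed example ...
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]
    # ... and only to current or earlier positions (causal).
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    return same_segment & causal

# Two examples of lengths 3 and 2 packed into one row of length 5.
mask = packed_causal_mask(jnp.array([[0, 0, 0, 1, 1]]))
```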
-
Hi @erfanzar, thanks for your response! I think using jax-triton is a great idea! Regarding a specific implementation for packed attention: here is what I've implemented for Megatron-LLM, based on the issue I shared earlier. However, I don't know whether it will be super helpful, since it directly uses flash-attention's […]. We only have […]. Also, I wonder whether you have an estimated timeline for integrating jax-triton's attention into EasyDeL? I might also be able to take a look at this once jax-triton is integrated into the repo :-)!
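For reference, here is a small host-side (numpy) sketch of turning `attention_mask_in_length` into per-token segment ids that a mask or bias builder can consume. It assumes the "row of per-example lengths padded with zeros" convention from the flash-attention discussion, and the helper name is hypothetical.

```python
import numpy as np

def lengths_to_segment_ids(attention_mask_in_length: np.ndarray) -> np.ndarray:
    """attention_mask_in_length: [batch, seq_len], each row holding the lengths of the
    packed examples padded with zeros, e.g. [3, 2, 0, 0, 0] for a row of length 5."""
    batch, seq_len = attention_mask_in_length.shape
    segment_ids = np.zeros((batch, seq_len), dtype=np.int32)
    for b in range(batch):
        lengths = attention_mask_in_length[b]
        ends = np.cumsum(lengths[lengths > 0])  # end offset of each packed example
        # A token at position p belongs to the first example whose end is > p;
        # positions past the total packed length fall into a fresh trailing id
        # and should be treated as padding.
        segment_ids[b] = np.searchsorted(ends, np.arange(seq_len), side="right")
    return segment_ids

# [3, 2, 0, 0, 0, 0] -> segment ids [0, 0, 0, 1, 1, 2]
```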
-
I'll work on that ASAP. Actually, I am interested in working on this, since CUDA and Triton implementations are already available.
-
Hi @erfanzar,
Thanks for the great repo! It looks really useful for training open-source models on TPUs and GPUs!
I wonder if it would be easy to implement a feature that allows users to pass in packed sequences (see this), which lets us maximize hardware utilization.
The idea is that the user can provide a list of `attention_mask_in_length`, and then, under the hood, we tweak the `attention_mask` or adjust the attention bias accordingly, so that two different examples cannot attend to each other.
I see that you support an attention bias for flash attention here, which could be a good starting point (e.g., just set the attention across sequences to -inf). However, I find that `flash_func` has a different function signature: `flash_func` is `flash_attn_gpu.mha`, which does not actually have a `bias` argument, according to here. For the TPU version (`flash_attn_tpu.flash_attention`), that argument exists.
Is this intended? If not, do we need to fix the GPU implementation to match the TPU one? And is this the best way to implement such a packing feature?
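As an illustration of the bias route described above, here is a hedged sketch that converts a boolean packed mask into an additive attention bias with a very large negative value wherever attention should be blocked. Whether it can be fed directly to `flash_attn_tpu.flash_attention`'s bias argument depends on the shape and dtype that kernel expects, so treat this as a sketch rather than EasyDeL's implementation.

```python
import jax.numpy as jnp

def mask_to_bias(mask: jnp.ndarray, dtype=jnp.float32) -> jnp.ndarray:
    """mask: boolean [batch, seq_len, seq_len] (e.g. from packed_causal_mask above).
    Returns an additive bias [batch, 1, seq_len, seq_len] broadcastable over heads."""
    neg = jnp.finfo(dtype).min  # acts like -inf after softmax
    bias = jnp.where(mask, jnp.asarray(0, dtype), jnp.asarray(neg, dtype))
    return bias[:, None, :, :]
```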