For those of us using the 2D NATTEN kernel from their library along with torch.compile, is this faster? Especially given all their tricks (e.g., the non-deterministic KV-parallelism)?
In my (very amateurish) testing, I think the performance seemed pretty comparable (for non-dilated NATTEN). I need to do some proper benchmarking though!
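For context, here's a rough sketch of the FlexAttention side of what I compared (shapes and kernel size are made-up values, and this simple mask ignores NATTEN's edge clamping, so it's not exactly apples-to-apples):

    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask

    # Tokens are a flattened H_IMG x W_IMG grid; each query attends only to a
    # KERNEL x KERNEL window around its own position (non-dilated).
    H_IMG, W_IMG, KERNEL = 32, 32, 7
    B, H, S, D = 4, 8, H_IMG * W_IMG, 64

    def natten_2d_mask(b, h, q_idx, kv_idx):
        q_y, q_x = q_idx // W_IMG, q_idx % W_IMG
        kv_y, kv_x = kv_idx // W_IMG, kv_idx % W_IMG
        return ((q_y - kv_y).abs() <= KERNEL // 2) & ((q_x - kv_x).abs() <= KERNEL // 2)

    block_mask = create_block_mask(natten_2d_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))
    out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)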
It's interesting that optimizing a computation that can be described in a single line of math takes so much work. It took forever even to discover Flash attention. And in the 6 years since transformers were invented, thousands of papers worked on making it faster.
Attention(Q,K,V) = Softmax(Q*K^T/sqrt(d_k))*V
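In code, that one line is just the naive version below, which materializes the full S x S score matrix (exactly what FlashAttention-style kernels avoid):

    import torch
    import torch.nn.functional as F

    def naive_attention(q, k, v):
        # softmax(Q K^T / sqrt(d_k)) V, with an explicit S x S score matrix
        d_k = q.size(-1)
        scores = q @ k.transpose(-2, -1) / d_k**0.5
        return F.softmax(scores, dim=-1) @ v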
FlexAttention seems to have found the right abstraction for the task.
Yeah, because the math strips the whole thing down to "I have data, I do an operation on it," while in reality we deal with multi-head attention / grouped-query attention and positional encodings.
That's all without even taking into account the broadcasting over the batch dimension.
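A sketch of the shapes the clean formula glosses over (all sizes here are made up; this is plain grouped-query attention, nothing FlexAttention-specific):

    import torch

    B, H_Q, H_KV, S, D = 2, 8, 2, 128, 64   # 8 query heads sharing 2 KV heads

    q = torch.randn(B, H_Q, S, D)
    k = torch.randn(B, H_KV, S, D)
    v = torch.randn(B, H_KV, S, D)

    # Grouped-query attention: each KV head is shared by a group of query heads,
    # and the batch/head dimensions broadcast through the whole computation.
    k = k.repeat_interleave(H_Q // H_KV, dim=1)
    v = v.repeat_interleave(H_Q // H_KV, dim=1)

    scores = q @ k.transpose(-2, -1) / D**0.5    # [B, H_Q, S, S]
    out = torch.softmax(scores, dim=-1) @ v      # [B, H_Q, S, D]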
For most LLM workloads today (short text chats), hundreds or a couple thousand tokens suffice, so attention mechanisms don't dominate (< 30% of compute). But as the modalities inevitably grow, work on attention approximation/compression is going to be paramount.
Nice to see PyTorch already elegantly supporting this next step in research.
I didn't see any notice of this being CUDA only (like FlashAttention). I tried running on my Mac M3, python 3.11.8, following the quickstart (with the deviation of running it in a new venv). Got the following error:
    /attention-gym/.venv/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
      cpu = _conversion_method_template(device=torch.device("cpu"))
    Traceback (most recent call last):
      File "/attention-gym/attn_gym/masks/document_mask.py", line 7, in <module>
        from torch.nn.attention.flex_attention import _mask_mod_signature
    ModuleNotFoundError: No module named 'torch.nn.attention.flex_attention'
Ah sorry, should have put that in the blog post. This leverages Triton heavily, so it'll only work on machines that have Triton backends (at least, we've tested on NVIDIA and AMD GPUs).
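If it helps, a quick sanity check before running the attention-gym examples (just a sketch): make sure you're on a PyTorch build that ships flex_attention and have a GPU that Triton supports.

    import torch

    print(torch.__version__)          # needs a recent build that ships flex_attention
    print(torch.cuda.is_available())  # the Triton path needs a supported GPU (NVIDIA/AMD)

    # Raises ModuleNotFoundError on older builds, like the one in the traceback above:
    from torch.nn.attention.flex_attention import flex_attention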
I've always been curious to put something together with PyTorch, but it always seemed like either a steep learning curve or there wasn't a big motivator (a project, a problem to solve, something in my daily routine to optimize).
Does anybody have a good starting point for learning with hands-on projects, ideally one that could also accommodate FlexAttention?
A classifier for handwritten digits in the MNIST dataset is generally considered the "Hello World" of neural networks. I went over it in a course, but there are countless tutorials to be found online, e.g. https://www.digitalocean.com/community/tutorials/introductio...
Once you begin to understand how to handle data and how to define layers, you can start playing around with whatever your heart desires. The rabbit hole is vast and endless :)
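For a taste, here's a minimal version of what those tutorials walk you through (a sketch: one hidden layer, one training epoch, arbitrary hyperparameters):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    # Load MNIST and train a tiny fully-connected classifier for one epoch.
    train_set = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
    loader = DataLoader(train_set, batch_size=64, shuffle=True)

    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for images, labels in loader:
        opt.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        opt.step()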
Agreed that PyTorch tutorials are a great place to start. Specific to FlexAttention, the blog references the accompanying attention-gym, which has a series of examples of how to use flex: https://github.com/pytorch-labs/attention-gym/
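To give a flavor, here's roughly what one of the simpler examples boils down to: an ALiBi-style score_mod that biases scores by query/key distance (the slope here is just an illustrative value):

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    def alibi_style_bias(score, b, h, q_idx, kv_idx):
        # Penalize attention by distance, with a per-head slope (placeholder value).
        return score - 0.1 * (h + 1) * (q_idx - kv_idx).abs()

    q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))
    out = torch.compile(flex_attention)(q, k, v, score_mod=alibi_style_bias)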
We’re quite happy with this abstraction - happy to answer any questions about it!