FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention (pytorch.org)
210 points by limoce 3 months ago | 24 comments



Hi, I'm one of the authors of this blog post (Horace He), along with Driss Guessous, Yanbo Liang, and Joy Dong.

We’re quite happy with this abstraction - happy to answer any questions about it!


For those of us using the 2D NATTEN kernel from their library along with torch.compile, is this faster? Especially given all their tricks (e.g., the non-deterministic KV parallelism)?


In my (very amateurish) testing, the performance seemed pretty comparable (for non-dilated NATTEN). I need to do some proper benchmarking, though!


Is this Ampere-and-newer only, like FA2?


I believe it should run on V100 as well (although it's definitely not as well tested), and a user reported that they got it running on a T4 too.


It's interesting that optimizing a computation that can be described in a single line of math takes so much work. It took forever even to discover FlashAttention, and in the six years since transformers were invented, thousands of papers have worked on making attention faster.

Attention(Q,K,V) = Softmax(Q*K^T/sqrt(d_k))*V

FlexAttention seems to have found the right abstraction for the task.
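
For context, the abstraction is a user-supplied score_mod (and optionally a mask_mod) that rewrites each pre-softmax score while the kernel stays fused. A rough sketch, assuming PyTorch 2.5+ on a Triton-capable GPU; the relative-position bias is just an illustrative choice, and the shapes are made up:

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    def relative_positional(score, b, h, q_idx, kv_idx):
        # Tweak each pre-softmax attention score; the matmuls and softmax stay fused.
        return score + (q_idx - kv_idx)

    # (batch, heads, seq_len, head_dim) -- arbitrary shapes for the example
    q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))
    out = torch.compile(flex_attention)(q, k, v, score_mod=relative_positional)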


Yeah, because the math strips the whole thing down to "I have data and I do an operation on it", while in reality we deal with multi-head attention / grouped-query attention and positional encoding.

That's all without taking into account the broadcasting done over the batch dimension.
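
To make that concrete, here is a small sketch (shapes made up) of what the one-liner actually carries around in practice, using PyTorch's built-in scaled_dot_product_attention:

    import torch
    import torch.nn.functional as F

    # The textbook equation is per-head; real implementations broadcast it over
    # (batch, heads), and masking / positional tweaks live outside the formula.
    B, H, S, D = 4, 16, 2048, 64
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, S, D)
    v = torch.randn(B, H, S, D)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([4, 16, 2048, 64])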


I would agree with this. For example, how would you represent causal attention in the standard equation?
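
In FlexAttention itself, causal attention is expressed as a mask_mod rather than something you can fold into the closed-form equation. A minimal sketch, with shapes assumed and a CUDA/Triton-capable device:

    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask

    def causal(b, h, q_idx, kv_idx):
        # Keep only keys at or before the query position.
        return q_idx >= kv_idx

    B, H, S, D = 2, 8, 1024, 64
    block_mask = create_block_mask(causal, B=B, H=H, Q_LEN=S, KV_LEN=S)
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))
    out = flex_attention(q, k, v, block_mask=block_mask)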


This is true of even just matrix multiplication (A*B), of which attention has two.


For most LLM workloads today (short text chats), hundreds or a couple thousand tokens suffice, so attention doesn't dominate the compute (< 30%). But as the modalities inevitably grow, work on attention approximation/compression is going to be paramount.

Nice to see PyTorch already elegantly supporting this next step in research.


I didn't see any notice of this being CUDA only (like FlashAttention). I tried running it on my Mac M3, Python 3.11.8, following the quickstart (with the deviation of running it in a new venv), and got the following error:

    /attention-gym/.venv/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
      cpu = _conversion_method_template(device=torch.device("cpu"))
    Traceback (most recent call last):
      File "/attention-gym/attn_gym/masks/document_mask.py", line 7, in <module>
        from torch.nn.attention.flex_attention import _mask_mod_signature
    ModuleNotFoundError: No module named 'torch.nn.attention.flex_attention'


Ah, sorry, I should have put that in the blog post. This leverages Triton heavily, so it'll only work on machines that have Triton backends (at least, we've tested on Nvidia and AMD GPUs).


> FlexAttention achieves 90% of FlashAttention2’s performance in the forward pass and 85% in the backward pass.

It's very good. But note FlashAttention-3 is 1.5x - 2x faster than FlashAttention-2.


These benchmarks are on Ampere, where FA3 has no performance benefits over FA2.

On Hopper, FlexAttention is currently about 80% of FlashAttention3's performance (about 500 TFLOPs peak)


Not bad.


I've always been curious to put something together with PyTorch, but it always seemed like either a steep learning curve or there wasn't a big motivator (a project, a problem to solve, something in my daily routine to optimize).

Does anybody have a good starting point to learn with hands-on projects, ideally one that could also accommodate FlexAttention?


IMO the PyTorch getting started tutorials are really good (https://pytorch.org/tutorials/beginner/basics/intro.html).

A classifier for handwritten digits in the MNIST dataset is generally considered the "Hello World" of neural networks. I went over it in a course, but there are countless tutorials to be found online, e.g. https://www.digitalocean.com/community/tutorials/introductio...

Once you begin to understand how to handle data and how to define layers, you can start playing around with whatever your heart desires. The rabbit hole is vast and endless :)
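
To make the "Hello World" concrete, here is roughly what an MNIST training loop looks like; the architecture and hyperparameters are just reasonable defaults rather than anything from the tutorials above:

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    # Tiny MNIST classifier: flatten -> one hidden layer -> 10 logits.
    train_ds = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
    train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)

    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(3):
        for x, y in train_dl:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
        print(f"finished epoch {epoch}")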


Agreed that the PyTorch tutorials are a great place to start. Specific to FlexAttention, the blog references the accompanying attention-gym, which has a series of examples of how to use flex: https://github.com/pytorch-labs/attention-gym/


Check out Kaggle for the challenges.


This is so cool. I want to try to implement something with this right now.


Can someone do a short summary or TL;DR for this?


https://x.com/chhillee/status/1821253769147118004?s=46

Perhaps this tweet thread would be better.



Thanks, I just weaned myself off Twitter / X.



