Why does prompt caching reduce costs? I'm assuming that the primary cost driver is GPU/TPU FLOPS, as opposed to network / storage / etc. costs.
My understanding is that an LLM will take in the stream of text, tokenize it (which can be faster with caching, sure, but that's a drop in the bucket), then run a transformer on the entire sequence. You can't just cache the output of a transformer on a prefix to reduce workload.
Prompt caching has been a thing for LLMs since GPT-2 (e.g. `use_cache=True` and `past_key_values` in Hugging Face's `transformers`); it's more of a surprise that it took this long for the main LLM providers to ship a good implementation.
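Roughly what that looks like with Hugging Face's `transformers` (a minimal sketch, not how any hosted provider actually implements it; the model choice and prompt strings are just placeholders):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

prefix = "System prompt plus a long pile of few-shot examples..."
suffix = " The short, request-specific part."

with torch.no_grad():
    # Prefill: run the shared prefix once and keep its K/V cache.
    prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids
    cache = model(prefix_ids, use_cache=True).past_key_values

    # Later calls only pay for the new tokens; the prefix K/V are reused.
    suffix_ids = tokenizer(suffix, return_tensors="pt").input_ids
    out = model(suffix_ids, past_key_values=cache, use_cache=True)
```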
You actually can cache the "output" of a transformer on the prefix by caching what happens in the attention layer for that text string (specifically the "K" and "V" tensors). Since the attention layer is a big part of the compute cost of the transformer, this does cut down FLOPs dramatically.
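For what it's worth, here's a toy single-head version of that idea in PyTorch (ignoring batching, multiple heads, and the MLP blocks; all names are made up for illustration):

```python
import torch
import torch.nn.functional as F

def attend_with_cache(x_new, w_q, w_k, w_v, cache):
    """Single-head causal attention over (cached prefix + new tokens).

    x_new : (t_new, d) hidden states for the new tokens only
    cache : dict holding the prefix's "k" and "v" tensors, each (t_past, d);
            pass an empty dict on the first call
    """
    q = x_new @ w_q                       # queries only for the new tokens
    k_new = x_new @ w_k
    v_new = x_new @ w_v

    # Reuse the prefix's K and V instead of recomputing them.
    k = torch.cat([cache["k"], k_new]) if "k" in cache else k_new
    v = torch.cat([cache["v"], v_new]) if "v" in cache else v_new
    cache["k"], cache["v"] = k, v         # extend the cache for next time

    t_new, d = q.shape
    t_total = k.shape[0]
    scores = (q @ k.T) / d**0.5           # (t_new, t_total)

    # Causal mask: a new token may only attend to positions at or before it.
    t_past = t_total - t_new
    pos_q = torch.arange(t_past, t_total).unsqueeze(1)
    pos_k = torch.arange(t_total).unsqueeze(0)
    scores = scores.masked_fill(pos_k > pos_q, float("-inf"))

    return F.softmax(scores, dim=-1) @ v  # (t_new, d)

# Usage: prefill the prompt once, then feed only the new tokens.
# cache = {}
# _   = attend_with_cache(x_prefix, w_q, w_k, w_v, cache)  # prefill once
# out = attend_with_cache(x_new,    w_q, w_k, w_v, cache)  # reuse cached K/V
```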
My understanding is that the attention in all transformer layers is "causal", that is, the output of a transformer layer for token N depends only on tokens 0 through N.
This means that every attention layer can reuse its previously calculated outputs for the same prompt prefix, so it only needs to compute from scratch starting at the first token where the new prompt diverges from the cached one.
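Concretely, the serving side just has to find the longest shared token prefix and run the forward pass on the rest (toy sketch; the helper name is hypothetical):

```python
def shared_prefix_len(cached_tokens, new_tokens):
    """Length of the longest common token prefix between two prompts."""
    n = 0
    for a, b in zip(cached_tokens, new_tokens):
        if a != b:
            break
        n += 1
    return n

# Toy example: two prompts that share a long system prompt + examples.
cached = [101, 7, 8, 9, 42, 13]    # tokens of a previously served prompt
new    = [101, 7, 8, 9, 42, 99, 5]

n = shared_prefix_len(cached, new)
# Only new[n:] needs a forward pass; the K/V for new[:n] come from the cache.
print(f"reuse {n} cached tokens, compute {len(new) - n} fresh ones")
```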
I had the same question... my guess is you can do a layer-by-layer cache, i.e. a cache in the first layer, then another independent cache in the second layer, and so on.
The transformer only looks backwards, so if the first part of the sequence (the prompt) doesn't change, you don't need to rerun it on that part, just on the part after it that changed. For use cases where the prompt is large relative to the output (e.g. lots of examples in the prompt), this can significantly speed up the workload.
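Back-of-envelope version of that claim, using the rough "about 2 × parameter-count FLOPs per token" rule of thumb and made-up numbers:

```python
# Rough numbers only; ignores attention's quadratic term and hardware details.
n_params = 70e9          # assume a 70B-parameter model
prompt_tokens = 8000     # large shared prompt (system prompt + examples)
new_tokens = 200         # per-request unique suffix + output

per_token = 2 * n_params
without_cache = (prompt_tokens + new_tokens) * per_token
with_cache = new_tokens * per_token   # prefix K/V already cached

print(f"without cache: {without_cache:.2e} FLOPs")
print(f"with cache:    {with_cache:.2e} FLOPs")
print(f"~{without_cache / with_cache:.0f}x less compute for this request")
```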
I don't think the normalization makes it infeasible. They should be able to make an adjustment (the reverse of the normalization) in one operation. I think they are caching the attention calcs.
The hard thing (I think) is what to keep in the cache and where to keep it, given that you are serving lots of customers and the cached attention tensors can grow into a large set of numbers pretty quickly.
They cache the results of the attention calc. For prompt prefixes that are shared across many requests, this makes a lot of sense. I'm surprised they can make it work, though, given they are serving so many different users. Someone somewhere did some very clever engineering.
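One plausible shape for that engineering, purely a guess on my part: key the cache by a hash of the token prefix and evict least-recently-used entries, since the cached tensors get big fast:

```python
import hashlib
from collections import OrderedDict

class PrefixKVCache:
    """Toy LRU cache mapping a token-prefix hash -> that prefix's K/V tensors."""

    def __init__(self, max_entries=1024):
        self.max_entries = max_entries
        self._store = OrderedDict()

    @staticmethod
    def _key(tokens):
        # Hash the token sequence so lookups don't depend on its length.
        return hashlib.sha256(str(tuple(tokens)).encode("utf-8")).hexdigest()

    def get(self, tokens):
        key = self._key(tokens)
        if key in self._store:
            self._store.move_to_end(key)      # mark as recently used
            return self._store[key]
        return None

    def put(self, tokens, kv_tensors):
        key = self._key(tokens)
        self._store[key] = kv_tensors
        self._store.move_to_end(key)
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)   # evict least-recently-used entry
```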