They’re both exploring the same space: optimizing the memory needed by the KV cache, which is essentially the in-memory form of the context window (no one skips the KV cache, since otherwise you’re redoing the N^2 attention math at every step). They take different approaches to the same goal, and it may be possible to apply both simultaneously to shrink the attention mechanism’s memory usage dramatically, which would be really cool, but I’m curious how they compare against each other individually.
The only memory mechanism within an LLM, as far as I know, is attention, which lets each new token draw on all previous tokens when the model produces a probability distribution for the next token. Attention uses a KV cache to bring the per-token cost down from O(n^2) to O(n) by caching the key and value vectors computed for previous tokens and reusing them. The maximum number of tokens the context can cover is called the context window (e.g. 128k for Llama 3.1).
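A minimal single-head, single-layer sketch of what the cache buys you, assuming plain Python lists instead of tensors and random matrices in place of learned weights (`W_q`, `W_k`, `W_v`, `KVCache` and the helper functions are illustrative names, not any library's API). With the cache, each step projects only the new token and appends its K and V; without it, K and V are recomputed for the whole prefix every step. Both paths give the same attention outputs.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(w, x):
    # w is a list of rows, so this computes w @ x.
    return [dot(row, x) for row in w]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values):
    # Scaled dot-product attention for a single query over all cached positions.
    weights = softmax([dot(q, k) / math.sqrt(len(q)) for k in keys])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


class KVCache:
    """Cache for one head in one layer: the K and V vector of every token seen so far."""

    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)


rng = random.Random(0)
d = 4
# Random matrices stand in for learned projection weights (hypothetical).
W_q, W_k, W_v = ([[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(3))
hidden = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(6)]  # toy per-token inputs

cache = KVCache()
cached_outputs = []
for x in hidden:
    # With a cache: project only the new token, then attend over the n cached entries.
    cache.append(matvec(W_k, x), matvec(W_v, x))
    cached_outputs.append(attend(matvec(W_q, x), cache.keys, cache.values))


def without_cache(prefix):
    # Without a cache: recompute K and V for the whole prefix at every step.
    keys = [matvec(W_k, t) for t in prefix]
    values = [matvec(W_v, t) for t in prefix]
    return attend(matvec(W_q, prefix[-1]), keys, values)


recomputed = [without_cache(hidden[: i + 1]) for i in range(len(hidden))]
```

In a real model this is repeated per head and per layer, and the work the cache saves is re-running the entire prefix through every layer for each new token.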
The articles use very similar verbiage.
> The context window can be considered the model’s working memory
Snip
> Universal transformer memory optimizes prompts using neural attention memory models (NAMMs), simple neural networks that decide whether to “remember” or “forget” each given token stored in the LLM’s memory.
snip
> Meanwhile, by discarding unnecessary tokens, NAMM enabled the LLM model to save up to 75% of its cache memory while performing the tasks.
You just have to be familiar with the wording in the space and read enough literature. Here’s more direct wording from the NAMM paper:
> NAMMs use evolution to optimize the performance of LMs by pruning their KV cache memory. Evolved NAMMs can be zero-shot transferred to other transformers, even across input modalities and task domains.
This is all related work on shrinking the KV cache as the context grows. That saves memory, and it also speeds things up because you don’t have to attend to every token (per-token attention cost goes from O(n) to something smaller, potentially sublinear in the context size).
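A toy sketch of the general "decide which tokens to forget" idea. NAMM's actual scorer is an evolved neural network; this swaps in a much simpler heuristic, accumulated attention (similar in spirit to heavy-hitter eviction schemes), and every name and number here is hypothetical. Keeping 1 of 4 entries is where a figure like "75% of cache memory saved" comes from arithmetically.

```python
def accumulated_attention(attn_rows):
    """Total attention each cached token has received so far.

    attn_rows[t][i] is the weight query t placed on token i. The model is causal,
    so row t has t + 1 entries.
    """
    totals = [0.0] * len(attn_rows[-1])
    for row in attn_rows:
        for i, w in enumerate(row):
            totals[i] += w
    return totals


def prune_kv_cache(keys, values, scores, keep_fraction):
    """Keep the highest-scoring fraction of cached entries, in their original order."""
    n_keep = max(1, int(len(keys) * keep_fraction))
    ranked = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)
    kept = sorted(ranked[:n_keep])
    return [keys[i] for i in kept], [values[i] for i in kept], kept


# Toy causal attention weights for 4 tokens (each row sums to 1).
attn = [
    [1.0],
    [0.2, 0.8],
    [0.1, 0.6, 0.3],
    [0.1, 0.5, 0.1, 0.3],
]
keys = [[0.0], [1.0], [2.0], [3.0]]    # stand-in K vectors
values = [[0.5], [1.5], [2.5], [3.5]]  # stand-in V vectors

scores = accumulated_attention(attn)
new_keys, new_values, kept = prune_kv_cache(keys, values, scores, keep_fraction=0.25)
print(kept)                           # [1]
print(1 - len(new_keys) / len(keys))  # 0.75
```

One known weakness of this particular heuristic is that early tokens are seen by more queries and so accumulate more score; learned scorers like NAMM's are one way around hand-picked rules like this.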
Context is critical for the LLM to answer correctly and to remember all the information given to it plus everything it has said. Typical limits for open models these days are 128k tokens, but techniques like this could push that further, allowing better performance on things like code completion.
I thought the context would also contain floating-point numbers, so that tokens would be included in a fuzzier way, and that each request would end up loading slightly different tokens into the cache. Yeah, my understanding is certainly limited and I’d like to study it more. Thanks for the response, I see more similarity now.
The word you're looking for is latent space, and yes, everything in the compute graph, including the context cache and the compute, happens in latent space. Literal input tokens are first converted into latent space by the embedding layer, and literal output tokens are generated by converting the final hidden state into token probabilities and then taking the most probable token (or sampling from that distribution). Everything in between happens in the "floating point" latent space.
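A toy sketch of that round trip: discrete ids go into floating-point vectors, stay in latent space through the (here, fake) model, and only become a discrete token again at the very end. The vocabulary, the 3-dim embedding table, the averaging `stand_in_transformer`, and the tied input/output weights are all assumptions for illustration, not how any particular model is built.

```python
import math

vocab = ["the", "cat", "sat", "mat"]
# Hypothetical 3-dim embedding table: row i is the latent vector for token id i.
embedding = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.6, 0.6, 0.0],
    [0.0, 0.0, 1.0],
]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def embed(token_ids):
    # Input side: discrete token ids -> floating-point latent vectors.
    return [embedding[t] for t in token_ids]


def stand_in_transformer(states):
    # Placeholder for the real layer stack: just averages the latent vectors.
    return [sum(col) / len(states) for col in zip(*states)]


def unembed(h):
    # Output side (tied weights assumed): one logit per vocab entry.
    return [sum(a * b for a, b in zip(h, row)) for row in embedding]


probs = softmax(unembed(stand_in_transformer(embed([0, 1]))))
next_id = max(range(len(probs)), key=lambda i: probs[i])  # greedy: most probable token
print(vocab[next_id])  # sat
```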
When you hear something like "it's attending to all previous tokens", IMHO it's not strictly the correct explanation, since you're attending through latent space, which doesn't correspond 1:1 with tokens: each cached entry is a multidimensional representation of that token & all preceding tokens as understood by that attention head. It's conceptually described that way because the cache grows by one key vector and one value vector (per head, per layer) for every token you process. In a causal decoder the entries for earlier tokens don't change when new tokens arrive (that's what makes caching valid in the first place), but each of them already folds in context from the tokens before it.

Also important to note that each attention head within each layer has its own KV cache (with grouped-query attention, several query heads share one KV head). LLMs are a stack of transformer layers where the output of each layer feeds into the input of the next and every layer performs its own attention; they're also autoregressive, meaning each generated token is fed back in as input. That's another reason why it's not strictly correct to think of tokens as making up your context: there are actually many, many contexts within a transformer model. That's why your 128k context window can take ~16 GiB in a naive inference implementation: 2 (K and V) * 32 layers * 8 KV heads * 128 elements per head * 2 bytes per element (fp16) * 128k tokens, for a Llama-3.1-8B-like configuration. And that's what this work is talking about shrinking (as does HeadKV[1]).
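A rough way to check that kind of estimate is the usual KV-cache size formula: 2 (K and V) * layers * KV heads * head dim * tokens * bytes per element. The shape below is an assumption modeled on Llama 3.1 8B (32 layers, 8 KV heads via grouped-query attention, head dim 128, fp16); `kv_cache_bytes` is just an illustrative helper.

```python
GIB = 1024 ** 3


def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    # Factor of 2: one K vector and one V vector per token, per KV head, per layer.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


gqa = kv_cache_bytes(num_layers=32, num_kv_heads=8, head_dim=128, seq_len=128 * 1024)
mha = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=128 * 1024)
print(gqa / GIB)  # 16.0
print(mha / GIB)  # 64.0
```

The second line shows why grouped-query attention matters: if all 32 query heads kept their own K and V, the same context would need 4x the memory, which is exactly the kind of cost NAMM- and HeadKV-style pruning also goes after.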
> tokens would be included in a more fuzzy way, and that when requests are sent it would result in loading slightly different tokens into the cache
LLM inference is generally 100% deterministic given the same inputs and a fixed seed for the RNG (modulo bugs in the inference math or in the accelerator's hardware/software). Some inference implementations don't guarantee this property in the face of concurrent requests, and you can't control the seed for hosted LLMs, which is why the same query can seem to get random responses.
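A toy illustration of the seed point: sampling draws from an RNG, so with a private, fixed-seed RNG the whole autoregressive sequence repeats exactly. The bigram "model" in `NEXT_LOGITS`, the temperature, and `generate` itself are hypothetical stand-ins for a real model's next-token logits and sampler.

```python
import math
import random

# Hypothetical bigram "model": next-token logits depend only on the previous token id.
NEXT_LOGITS = {
    0: [0.1, 2.0, 0.5],
    1: [1.5, 0.2, 1.0],
    2: [0.3, 0.3, 2.5],
}


def generate(first_token, steps, temperature, seed):
    rng = random.Random(seed)  # private RNG: same seed -> same sequence of draws
    tokens = [first_token]
    for _ in range(steps):
        logits = NEXT_LOGITS[tokens[-1]]
        m = max(logits)
        weights = [math.exp((x - m) / temperature) for x in logits]
        # Autoregressive step: sample the next token and feed it back in.
        tokens.append(rng.choices(range(len(weights)), weights=weights)[0])
    return tokens


run_a = generate(0, steps=20, temperature=0.8, seed=1234)
run_b = generate(0, steps=20, temperature=0.8, seed=1234)
print(run_a == run_b)  # True
```

A hosted service effectively picks a different seed (or shares RNG state across requests) each time, so from the outside the same prompt looks like it produces random answers.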
The KV cache feels more like a graph to me, like in the RDF sense. Each parameter could be numbered and given a URL, it seems. I have some studying to do. I think building a simple neural net and looking at the raw context data in whatever LLM I’m playing with in Ollama are good things to try.
This isn't like lossless compression. Both techniques involve throwing lots of information away, with the justification that doing so does not significantly affect the end result.
How much using both techniques together helps will depend on how much overlap there is between the information each one ends up discarding.
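A back-of-the-envelope sketch of the overlap point, under a simplifying assumption: both methods are treated as dropping whole token entries from one shared cache (HeadKV actually allocates budgets per head, so this is only an approximation). The 100-token cache and the dropped ranges are made-up numbers, and `pruning_overlap` is a hypothetical helper.

```python
def pruning_overlap(n_tokens, dropped_a, dropped_b):
    """Fractions of an n-token cache dropped by A, by B, by both, and by A or B."""
    return {
        "a": len(dropped_a) / n_tokens,
        "b": len(dropped_b) / n_tokens,
        "overlap": len(dropped_a & dropped_b) / n_tokens,
        "combined": len(dropped_a | dropped_b) / n_tokens,
    }


# Toy example: A drops tokens 0-59, B drops tokens 40-79 of a 100-token cache.
stats = pruning_overlap(100, set(range(0, 60)), set(range(40, 80)))
print(stats)  # {'a': 0.6, 'b': 0.4, 'overlap': 0.2, 'combined': 0.8}
```

If the two methods dropped exactly the same entries, stacking them would save nothing extra; the less they overlap, the more memory the combination saves, but also the more information is thrown away, so quality has to be re-checked.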
Modern LLMs are still quite inefficient in how they represent information. We're in something like the DEFLATE era and have yet to invent the zstd of this space, after which only marginal incremental gains remain; right now there's a lot of waste to prune away.
[1] https://arxiv.org/html/2410.19258v3