It constructs an analytical gradient from the code, so the gradient can be computed directly instead of by recording and replaying a trace. This can enable optimizations such as avoiding caching big matrices, because you don't need to keep track of states or trace the graph, and it lets you compute the 2nd, 3rd, 4th... derivatives because you have an analytical gradient you can differentiate again.
For example, in an affine function z = Wx + b, the gradient of the bias/intercept is the gradient of the loss wrt the layer's output z (dL/dz), and for the weights it's the product of dL/dz and the input to the layer.
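To make the affine case concrete, here is a minimal pure-Python sketch. The names `affine` and `affine_backward` are made up for illustration, and `dL_dz` is just a pretend upstream gradient, not something produced by a real framework.

```python
def affine(W, b, x):
    """Forward pass z = W x + b, with W given as a list of rows."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
            for row, b_i in zip(W, b)]


def affine_backward(dL_dz, x):
    """Return (dL/dW, dL/db) from the upstream gradient dL/dz and the input x."""
    dL_db = list(dL_dz)                                   # bias gradient is dL/dz itself
    dL_dW = [[g_i * x_j for x_j in x] for g_i in dL_dz]   # outer product of dL/dz and x
    return dL_dW, dL_db


W = [[1.0, 2.0], [3.0, 4.0]]
b = [0.5, -0.5]
x = [2.0, 3.0]
dL_dz = [1.0, -1.0]   # assumed upstream gradient, for illustration only

z = affine(W, b, x)
dL_dW, dL_db = affine_backward(dL_dz, x)
```

Note that `affine_backward` needs `x` only for the weight gradient; the bias gradient never touches it.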
With automatic graph construction, e.g. eager TensorFlow/PyTorch, the layer needs to cache its input so that it can compute the gradient of the weights. If the layer receives inputs multiple times within the computation graph, you end up caching an input for each use.
With analytical gradients, you may be able to save memory by spotting simplifications in the gradient expression itself. E.g. above, when the two uses share the same upstream gradient dL/dz, you can sum the inputs, i.e. (dL/dz)input1 + (dL/dz)input2 = (dL/dz)(input1 + input2).
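A small sketch of that identity, under the assumption that both uses of the layer see the same upstream gradient `dL_dz` (in general each use has its own, and then the factoring does not apply). The helper names are hypothetical.

```python
def weight_grad(dL_dz, x):
    """Weight gradient of an affine layer: outer product of dL/dz and x."""
    return [[g * x_j for x_j in x] for g in dL_dz]


def mat_add(A, B):
    """Elementwise sum of two equally shaped matrices (lists of rows)."""
    return [[a + b for a, b in zip(row_a, row_b)] for row_a, row_b in zip(A, B)]


dL_dz = [1.0, -2.0]          # assumed to be shared by both uses
x1 = [1.0, 2.0, 3.0]
x2 = [0.5, 0.5, 0.5]

# Strategy 1: cache both inputs and accumulate two outer products.
separate = mat_add(weight_grad(dL_dz, x1), weight_grad(dL_dz, x2))

# Strategy 2: cache only the running sum of the inputs.
x_sum = [a + b for a, b in zip(x1, x2)]
combined = weight_grad(dL_dz, x_sum)
```

Both strategies give the same weight gradient here, but the second one stores a single vector no matter how many times the layer is used.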
Isn't the input of the layer fundamentally a part of the gradient computation? So even in this case (inspecting LLVM code) the computation still needs to look at the input.
Even so, you may be able to perform optimizations that are not possible when each layer is differentiated in isolation, e.g. when you have an exponent in the output of a layer followed by a log in the next. Think SoftMax and logloss.
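As a rough illustration of the SoftMax + logloss case, the sketch below compares a layer-by-layer version with a hand-fused one. This is the standard log-sum-exp rewrite written out manually, not something a particular tool is claimed to derive; function names are made up.

```python
import math


def naive_loss(z, y):
    """Softmax layer followed by a separate log loss: exp, divide, then log."""
    exps = [math.exp(v) for v in z]
    return -math.log(exps[y] / sum(exps))


def fused_loss(z, y):
    """Same loss with the log and exp combined: logsumexp(z) - z[y]."""
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z)) - z[y]


def fused_grad(z, y):
    """Gradient of the fused loss wrt z: softmax(z) - onehot(y)."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total - (1.0 if i == y else 0.0) for i, e in enumerate(exps)]


small = [1.0, 2.0, 3.0]
agree = abs(naive_loss(small, 2) - fused_loss(small, 2)) < 1e-12

# With large logits the naive version overflows in exp, the fused one does not.
large = [1000.0, 0.0]
try:
    naive_loss(large, 0)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True
stable = fused_loss(large, 0)
```

The fused gradient also never needs the softmax output and the log input as two separately cached tensors.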
You don't always need the input to compute the gradient. For example, the gradient of a sum function doesn't require the original input; it just sets the derivative with respect to every input to 1.
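A tiny sketch of the difference, with hypothetical helper names: the backward pass of a sum only needs to know how many inputs there were, while the backward pass of a product needs their values.

```python
def sum_backward(upstream, n):
    """Gradient of sum(xs) wrt each x_i: the upstream gradient, n times.
    The values of xs are never needed, only their count."""
    return [upstream] * n


def prod_backward(upstream, xs):
    """Gradient of the product of xs wrt each x_i: upstream times the
    product of all the other entries, so the values must be kept."""
    grads = []
    for i in range(len(xs)):
        others = 1.0
        for j, v in enumerate(xs):
            if j != i:
                others *= v
        grads.append(upstream * others)
    return grads


sum_grads = sum_backward(1.0, 3)
prod_grads = prod_backward(1.0, [2.0, 3.0, 4.0])
```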
I think in essence what PartiallyTyped is trying to say is that one potential optimization opportunity in whole-program AD is that you can avoid having to cache the original inputs of the program if you know that the derivative computation won't need them (e.g. a value is only used in a sum and not in a product or something else whose derivative depends on the value). Some ML frameworks must cache all of the inputs to an operation since they don't know whether they will be necessary for the reverse pass of that operation. You could go even further and decide to cache a different & smaller set of intermediate values that still lets you compute the gradient.
Beyond cache reduction, in our paper we demonstrate a lot of interesting ways that combining AD with a compiler can create potential speed-ups. For example, we are often able to dead-code eliminate part of the original forward-pass code since it's not needed to compute the gradient.
We also have a cool example in the paper showing an asymptotic [O(N^2) => O(N)] speedup on a code for normalizing a vector because doing AD in the compiler lets Enzyme run after optimization (and in that example benefit from loop invariant code motion in a way that tools that aren't in the compiler cannot do).
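To give a feel for why loop invariant code motion matters here, below is a Python sketch of the forward pass only. It is not the code from the paper: `mag`, `normalize_naive`, `normalize_hoisted` and the `work` counter are illustrative names, and Python itself performs no such hoisting. The point is just that the first version recomputes the magnitude on every iteration, and differentiating that version as written would inherit the same quadratic work.

```python
import math

work = {"elements_read": 0}   # crude cost counter, for illustration only


def mag(xs):
    """Euclidean norm; reads every element once."""
    work["elements_read"] += len(xs)
    return math.sqrt(sum(v * v for v in xs))


def normalize_naive(xs):
    """Calls mag inside the loop: O(N) work per element, O(N^2) in total."""
    return [v / mag(xs) for v in xs]


def normalize_hoisted(xs):
    """What the loop looks like after hoisting the invariant call: O(N)."""
    m = mag(xs)
    return [v / m for v in xs]


xs = [3.0, 4.0]

work["elements_read"] = 0
naive_out = normalize_naive(xs)
naive_reads = work["elements_read"]

work["elements_read"] = 0
hoisted_out = normalize_hoisted(xs)
hoisted_reads = work["elements_read"]
```

Both versions return the same vector; only the amount of work differs, and a tool that differentiates the hoisted form starts from the cheaper one.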
Yeah my best guess at that is that they were trying to say you'd only need to store one value: the sum, rather than the two individual values -- but I'm not completely sure.
The essence of what I was trying to say with the example is that a layer may be used multiple times throughout the computation graph. Without an analytical gradient, you may end up caching all of the inputs to the layer to compute the gradient. The alternative is to sum up the inputs, because the weight gradient is linear with respect to the inputs; with an analytical gradient you can find that and compute it within the code instead of looking for ad hoc optimisations within the graph.