
> One advantage, however, of doing a more whole-program approach to AD rather than individual operators

I was under the impression that the big ML frameworks (and surely JAX with jit) are doing optimization on the complete compute graph, too.

I didn't want to make this discussion too TF/PyTorch focused (I'm not even an ML researcher). But your optimization claims make it sound as if the other AD frameworks are not doing any optimization at all, which is not the case.

I was also thinking about derivatives of functions that do something iterative internally, like a matrix decomposition (combined with a linear solve and/or matrix inversion). While a "high-level" AD tracer can substitute an efficient known derivative for these operations, wouldn't your LLVM-level introspection only be able to compute the derivative by going through all the internal steps of the matrix decomposition?
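A minimal sketch of that distinction, using a scalar stand-in (Newton's method for sqrt(a)) rather than a real matrix decomposition: differentiating through every iteration versus applying the implicit-function rule dy/da = 1/(2y) once at the converged solution. Forward-mode dual numbers are used only to keep it short; the class `Dual` and all function names are hypothetical. In reverse mode the unrolled version would additionally have to store every iterate.

```python
from dataclasses import dataclass


def _lift(x):
    # Treat plain floats as constants (zero tangent).
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


@dataclass
class Dual:
    """Hypothetical forward-mode dual number: val + dot * eps, with eps**2 == 0."""
    val: float
    dot: float

    def __add__(self, other):
        o = _lift(other)
        return Dual(self.val + o.val, self.dot + o.dot)

    __radd__ = __add__

    def __mul__(self, other):
        o = _lift(other)
        return Dual(self.val * o.val, self.dot * o.val + self.val * o.dot)

    __rmul__ = __mul__

    def __truediv__(self, other):
        o = _lift(other)
        return Dual(self.val / o.val,
                    (self.dot * o.val - self.val * o.dot) / (o.val * o.val))


def newton_sqrt(a, iters=30):
    """sqrt(a) via Newton's method y <- 0.5 * (y + a / y); works on floats or Duals."""
    y = a
    for _ in range(iters):
        y = 0.5 * (y + a / y)
    return y


def sqrt_unrolled(a: Dual) -> Dual:
    # Differentiate *through* every iteration: tangent work is done at each step.
    return newton_sqrt(a)


def sqrt_implicit(a: Dual) -> Dual:
    # Solve with plain floats (no derivative work), then apply the
    # implicit-function rule for y**2 - a = 0:  dy/da = 1 / (2 * y).
    y = newton_sqrt(a.val)
    return Dual(y, a.dot / (2.0 * y))
```

Both give the same derivative here, but the implicit rule costs one division regardless of the iteration count, which is the kind of shortcut a high-level tool can apply when it recognizes the operation.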




Oh for sure, any ML framework worth its salt should do some amount of graph rewriting / transformations.

I was (perhaps poorly) trying to explain that while AD (regardless of whether it is implemented in Enzyme, PyTorch, etc.) _can_ avoid caching values using clever tricks, it can't always get away with it. How much caching can be eliminated really depends on the level of abstraction at which a tool represents what the reverse pass needs. If a tool can only represent the binary choice of whether an input is needed or not, it could miss the fact that perhaps only the first element (and not the whole array/tensor) is needed.
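To make the "only the first element is needed" point concrete, here is a hand-written sketch (hypothetical function names, standing in for what an AD tool might generate) of a reverse pass for f(x) = sin(x[0]) + sum(x[1:]). The backward pass reads only x[0] and the length of x, so a tool that tracks need per element can cache one float, where a tool with a binary needed/not-needed flag caches a copy of all of x. The copy is needed at all only because the caller may overwrite x before the reverse pass runs.

```python
import math


def forward_coarse(x):
    """Forward pass whose cache records all of x (one needed/not-needed flag per input)."""
    y = math.sin(x[0]) + sum(x[1:])
    cache = list(x)          # copy: len(x) floats
    return y, cache


def forward_fine(x):
    """Forward pass whose cache records only the element the reverse pass reads."""
    y = math.sin(x[0]) + sum(x[1:])
    cache = [x[0]]           # 1 float
    return y, cache


def backward(cache, n, ybar=1.0):
    """Reverse pass for f(x) = sin(x[0]) + sum(x[1:]).

    Only cache[0] (the saved x[0]) and the length n are read:
    df/dx[0] = cos(x[0]) and df/dx[i] = 1 for i >= 1.
    """
    grad = [ybar] * n
    grad[0] = ybar * math.cos(cache[0])
    return grad
```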

Regarding Enzyme vs. JAX/etc., again I think that's the wrong way to think about these tools. They solve problems at different levels and in fact can be used together for mutual benefit.

For example, a high-level AD tool in a particular DSL might know algebraically that you don't need to compute the derivative of something because, from domain knowledge, it is always a constant. Without that domain knowledge, a tool has to actually compute it. On the other side of the coin, there's no way such a high-level AD tool would do all the drudgery of loop-invariant code motion, or even lower-level scheduling/register allocation (see the Enzyme paper for reasons why these can be really useful optimizations for AD).
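One plausible illustration of why loop-invariant code motion matters for AD (a hand-written sketch with hypothetical names, not taken from the Enzyme paper): for y = sum_i x[i] * exp(c), a tool that differentiates the loop as written might cache exp(c) once per iteration, while hoisting it out of the loop first means the reverse pass needs only a single cached value. The last return value is the peak number of floats cached for the reverse pass.

```python
import math


def grad_naive(x, c):
    """y = sum_i x[i] * exp(c), differentiated with exp(c) left inside the loop."""
    tape = []
    y = 0.0
    for xi in x:
        e = math.exp(c)          # recomputed every iteration...
        tape.append(e)           # ...and cached every iteration for the reverse pass
        y += xi * e
    peak_cached = len(tape)
    xbar = [0.0] * len(x)
    cbar = 0.0
    for i in reversed(range(len(x))):
        e = tape.pop()           # cached values come back in reverse order
        xbar[i] = e              # d(x[i] * e) / dx[i]
        cbar += x[i] * e         # d(x[i] * exp(c)) / dc
    return y, xbar, cbar, peak_cached


def grad_hoisted(x, c):
    """Same derivative after hoisting the loop-invariant exp(c) out of the loop."""
    e = math.exp(c)              # computed once, cached once
    y = 0.0
    for xi in x:
        y += xi * e
    xbar = [0.0] * len(x)
    cbar = 0.0
    for i in reversed(range(len(x))):
        xbar[i] = e
        cbar += x[i] * e
    return y, xbar, cbar, 1
```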

In an ideal world you want to combine all of this, performing AD at whichever level offers meaningful optimization (and ideally removing abstraction barriers like, say, a black-box call to cuDNN). We demonstrate this combination of high- and low-level AD in a minimal test case against Zygote [a high-level Julia AD tool], letting Enzyme handle a piece of scalar code, which is something Zygote is particularly bad at. This enables both the high-level algebraic transformations of Zygote and the low-level scalar performance of Enzyme, which is what you'd really want.

It looks like discussion of this has died down for now, but I'm sure shoyer would do a much better job of listing interesting high-level tricks JAX does [and perhaps low-level ones it misses] as a consequence of where it chose to sit on the abstraction spectrum.

Also, thanks for reminding me about matrix decompositions. I actually think there's a decent chance of handling those fairly well at a low level using various loop analyses, but I got distracted by a large Fortran code for nuclear particles.



