Jax – Composable transformations of Python and NumPy programs (github.com/google)
235 points by lelf on April 8, 2020 | 65 comments



Hello from the JAX team!

We'd like to take this opportunity to give a shout out to some of the awesome projects folks are building on top of JAX, e.g.,

* Flax, a neural network library for JAX (https://github.com/google/flax)

* Haiku, a neural network library for JAX inspired by Sonnet (https://github.com/deepmind/dm-haiku)

* RLax, a library for building reinforcement learning agents (https://github.com/deepmind/rlax)

* NumPyro, a probabilistic programming library on top of JAX (https://github.com/pyro-ppl/numpyro)

* JAX-MD, a differentiable molecular dynamics package built on top of JAX (https://github.com/google/jax-md)


I've noticed that many JAX libraries (including those from Google) seem to adopt an object-oriented style more similar to Torch/Keras rather than JAX's functional style demonstrated in modules like jax.experimental.stax. This is disappointing since stax is quite clean and these libraries seem to use a lot of hacks to make OO work with JAX. Is there an effort to implement and maintain more full-featured functional libraries in the jax/stax style?


I've been involved w. jax/stax/trax/flax - I think the real issue w. the stax-like functional form is that it gets unwieldy very quickly when dealing w. more complicated models that are natively general-graphs as opposed to simple sequential pipelines that can be trivially mapped to a combinator expression tree. Of course there are many solutions here, but ultimately if you're building an NN library you need to build something that ML researchers actually want to use daily, and that tends to look closer to hackable pytorch-like DSLs rather than higher-order functional code - which often looks elegant but tends to hurt readability and rework speed.
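To make the contrast concrete, here is a minimal sketch of the stax-style pattern under discussion. Every layer is a pair of pure functions `(init_fn, apply_fn)`, `serial` composes them, and parameters live in plain tuples rather than objects. `Dense`, `Relu` and `serial` below are hand-written stand-ins that mimic that pattern. They are not imported from `jax.experimental.stax`, whose location and signatures have varied across JAX versions.

```python
import jax
import jax.numpy as jnp

def Dense(out_dim):
    """A stax-style layer: returns an (init_fn, apply_fn) pair of pure functions."""
    def init_fn(rng, input_shape):
        k1, _ = jax.random.split(rng)
        w = jax.random.normal(k1, (input_shape[-1], out_dim)) * 0.1
        b = jnp.zeros((out_dim,))
        return input_shape[:-1] + (out_dim,), (w, b)

    def apply_fn(params, x):
        w, b = params
        return x @ w + b

    return init_fn, apply_fn

def Relu():
    # Parameter-free layer: shape passes through, params are an empty tuple.
    init_fn = lambda rng, input_shape: (input_shape, ())
    apply_fn = lambda params, x: jnp.maximum(x, 0.0)
    return init_fn, apply_fn

def serial(*layers):
    """Sequential combinator: threads shapes through init and activations through apply."""
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng, input_shape):
        params = []
        for init in init_fns:
            rng, layer_rng = jax.random.split(rng)
            input_shape, p = init(layer_rng, input_shape)
            params.append(p)
        return input_shape, tuple(params)

    def apply_fn(params, x):
        for apply, p in zip(apply_fns, params):
            x = apply(p, x)
        return x

    return init_fn, apply_fn

init_fn, apply_fn = serial(Dense(16), Relu(), Dense(1))
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 4))
y = apply_fn(params, jnp.ones((8, 4)))
print(out_shape, y.shape)
```

A purely sequential model composes cleanly this way. A skip connection or any other non-sequential graph already needs extra fan-out/fan-in combinators, which is where the pattern starts to get unwieldy.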


Don't forget Neural Tangents, a high level library for building and running experiments with infinite width neural networks: https://github.com/google/neural-tangents


What about Trax (https://github.com/google/trax) and how does it compare with Flax or Haiku?

Interesting that googlers who are supposed to use Tensorflow are now actively developing a new autograd engine and at least three new DL frameworks on top of it. What do you think about this segmentation?


Good catch, I missed Trax! Trax is a configuration-driven neural network framework focused on sequence model research, as a successor to Tensor2Tensor.

Comparisons are hard in general and I don't have a good answer for you right now, but keep in mind most of these libraries are from researchers openly sharing the codebases they develop for their own work. We see the role of JAX as analogous to NumPy, that is, a common substrate on which folks can build these sorts of tools.


JAX is similar to many other autodifferentiation frameworks being developed for machine learning. It's a spiritual successor to the popular autograd framework (same authors).

The advantages JAX brings are:

* NumPy-adjacent interface

* Auto-vectorization

* Great tie-in with XLA

* Higher-order derivatives (critical for some applications)

* Simplified, functional interface

The disadvantage is that it's younger and not trying to be a fully-fledged competitive neural network framework, and thus it is behind and less resourced compared with other libraries implementing auto-differentiation.

I think it's really exciting to see more people learning about JAX and beginning to use it for serious projects.
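A small sketch of how several of those advantages combine: NumPy-style code, `grad` composed with itself for a higher-order derivative, and `vmap` plus `jit` for a vectorized, XLA-compiled derivative. The function `f` is an arbitrary example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Written exactly like NumPy code, but on jax.numpy arrays.
    return jnp.sin(x) * x ** 2

df = jax.grad(f)                      # first derivative via reverse-mode AD
d2f = jax.grad(jax.grad(f))           # higher-order derivative by composing grad
batched_df = jax.jit(jax.vmap(df))    # vectorize over a batch, then compile with XLA

xs = jnp.linspace(0.0, 1.0, 5)
print(batched_df(xs))                 # f'(x) = x**2 cos(x) + 2 x sin(x) at each point
print(d2f(1.0))                       # f''(1) = sin(1) + 4 cos(1)
```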


> The disadvantage is that it's [...] not trying to be a fully-fledged competitive neural network framework

Perhaps this is an unpopular opinion, but to me this is an advantage. JAX is a library upon which you can build a neural-network framework -- or a framework for something else.

It seems clear that we're still figuring out the Right Way to write code that defines a neural network. The fact that JAX lets you write different, competing libraries/APIs/DSLs -- however you want to think of it -- lets us innovate more freely.


Honestly, I don't disagree, but I feel like it's harder to get resources if you're not aiming for neural networks somewhat directly.


I continue to wonder how JAX compares to Zygote.jl; see https://twitter.com/oxinabox_frames/status/12394957113324789... for a good start.


It would be great to see a comparison of Zygote to JAX (both features and performance). I think the Julia language, with the ability to manipulate its AST via macros, makes this kind of source-to-source transformation a lot easier than it would be in Python. I'm not sure how they achieved this in Python, but I'm going to guess it was a lot more work.

Also, Zygote is much more ambitious than Jax. Zygote aims to support all of the Julia language whereas Jax is limited to a subset of Python. I wonder if the Zygote folks are biting off too much here? Though currently Zygote doesn't support mutation.
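On the JAX side, the restriction on mutation shows up directly in the array type: `jax.numpy` arrays are immutable, and "in-place" updates are written with the functional `.at[...]` syntax, which returns a new array. A small sketch:

```python
import jax.numpy as jnp

x = jnp.zeros(3)

try:
    x[1] = 10.0                      # NumPy-style in-place assignment
    mutated = True
except TypeError:                    # JAX arrays are immutable
    mutated = False

y = x.at[1].set(10.0)                # functional update: returns a new array
z = y.at[0].add(5.0)                 # other updates include .add, .mul, .min, .max
print(x, y, z)                       # x itself is unchanged
```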


More than AST manipulation (which is probably not a key element in Zygote), the most important aspect is Julia's multi-stage JIT compilation (which is actually an aggressive mixture of JIT and AoT). A Julia program has access to its own compiler at runtime, which you can call to compile any program down to Julia's Intermediate Representation. That IR happens to be a Julia structure that you can freely manipulate just like the AST, and it is in Static Single Assignment form, which makes it already similar to an execution graph. The Zygote library then traverses this structure systematically, writing the backward pass as if it had been part of the code in the first place (plus applying any kind of optimization), and finally compiles everything together down to machine code through LLVM.

Since Python is not compiled like Julia, the approach usually involves using custom types that store a separate intermediate representation from python (a custom one or the MLIR for example) and overloading functions (and perhaps other more sophisticated ways of metaprogramming) to map all the required operations while interpreting the Python code before compiling. That's why Zygote says it works on Zygote unaware libraries, as it doesn't store the graph within any of the types or directly overload any of the methods (it does need to know how to reverse them though), and also why it can directly operate on control structures since they are part of the full IR even when you can't overload them.

And while Zygote aims to fully support the Julia language, they'll certainly work on getting the most important operations working well enough before focusing on the more complicated stuff like mutation.
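For comparison, here is a small sketch of the Python-side approach described above as it looks in JAX: `jax.make_jaxpr` shows the separate IR (a "jaxpr") that JAX records by running the function on tracer values. Python's own `if` cannot be overloaded, so under `jit` a data-dependent branch has to be written with a JAX primitive such as `lax.cond`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    return jnp.sin(x) * 2.0

# JAX calls f with abstract "tracer" values and records each primitive it hits
# into its own intermediate representation, a jaxpr, rather than Python's IR.
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

def abs_python_if(x):
    if x > 0:            # Python's `if` cannot be overloaded by the tracer
        return x
    return -x

def abs_lax_cond(x):
    # The branch is expressed as a JAX primitive, so it ends up in the jaxpr.
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

try:
    jax.jit(abs_python_if)(-2.0)
    python_if_ok = True
except Exception:        # JAX raises a tracer-concretization error here
    python_if_ok = False

print(python_if_ok, jax.jit(abs_lax_cond)(-2.0))
```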


AST manipulation does not do it for AD in general (though XGrad.jl takes that route as a symbolic differentiation approach: https://github.com/dfdx/XGrad.jl/).

You want to work with the code reduced down to static single assignment (SSA) form, and work with code blocks and the control flow graph, for a number of reasons:

1. All control flow looks the same (no two different kinds of loops plus GOTOs).

2. One expression per line, so there is no need to untangle larger expressions.

3. A host of techniques from the compiler world can be used.

This is how Zygote works: it uses code that has been lowered into this form.
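A hand-written illustration of why SSA form helps (plain Python, no Julia or Zygote involved): once `sin(x * x) + x` is lowered to one primitive per line with each name assigned once, the reverse pass is just the same lines walked backwards, with one local derivative rule per line.

```python
import math

def f_nested(x):
    return math.sin(x * x) + x

def f_ssa_with_grad(x):
    # Forward pass in SSA form: one primitive per line, each name assigned once.
    t1 = x * x
    t2 = math.sin(t1)
    y = t2 + x
    # Reverse pass: visit the same lines in reverse, accumulating adjoints.
    y_bar = 1.0
    t2_bar = y_bar                      # from y = t2 + x
    x_bar = y_bar                       # from y = t2 + x
    t1_bar = t2_bar * math.cos(t1)      # from t2 = sin(t1)
    x_bar += t1_bar * 2 * x             # from t1 = x * x
    return y, x_bar

print(f_ssa_with_grad(0.7))
```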


JAX is focusing more on performance and will probably be the next big thing for Google machine learning teams, like https://github.com/google/flax (most people hate TensorFlow). Dunno about Zygote's future.


Why would this be 'the next big thing' if Google seems to be committing to Swift as its DL language of choice?


Yeah, I agree with Lyndon. Zygote is a much more ambitious project than JAX, so I would not be surprised if we see JAX make quicker progress than Zygote in this space.


I wonder if the Zygote folks would be better off focusing on a subset of Julia for now instead of supporting all of Julia. And then focusing on performance for that subset. Maybe this is more difficult because of the way Zygote is implemented (via Julia's macro system) which just lends itself to whole-language support?


Maybe. The nice thing with Julia is that all of the ADs just target Julia AST, so you can mix and match them. I'm using ForwardDiff + ReverseDiff compiled tapes with Zygote over it, because each have different optimizations that can be done to different parts of the code. ChainRules.jl is a great package that will allow this to pick up more steam. While I don't expect most users to go this deep into AD, if you actually want to use it in the backend of packages in a very deep and high performance way, there's nothing better than mixing the strengths of different approaches together.
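JAX exposes a similar kind of mode mixing through composable transformations. A common example is computing a Hessian forward-over-reverse. This is a JAX analogue of combining AD modes, not a reproduction of the Julia setup described above; `f` is an arbitrary test function.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 3)

# Forward-over-reverse: forward-mode Jacobian of a reverse-mode gradient.
hess_fwd_rev = jax.jacfwd(jax.grad(f))
# Reverse-over-reverse computes the same matrix with a different cost profile.
hess_rev_rev = jax.jacrev(jax.grad(f))

x = jnp.array([1.0, 2.0, 3.0])
print(hess_fwd_rev(x))   # the Hessian, diag(6 * x)
```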


It's hard to say, I really do think that what the Zygote folks are trying to do is quite achievable, but I also worry that Zygote's prominence may have been bad for the Julia autodiff community.

A year or two ago, it at least seemed to be a much more vibrant situation where there was more effort going into lots of different approaches, but then Zygote really picked up steam and it seemed like it was going to solve all our problems and cure cancer so development on the other packages slowed down, but now the 'finish line' for Zygote doesn't seem to be as close as we initially thought.


On the other hand, there are far fewer Julia developers than Python developers, so focusing on Zygote instead of splitting the effort over several AD packages makes some sense, doesn't it? There were already a few AD packages that were fairly mature: https://www.juliadiff.org/

Also, I think performance is very important because Zygote is used in Flux (Julia's main deep learning framework at this point), and if it's keeping Flux's performance from matching, say, PyTorch's, then that's going to limit adoption of Flux and Julia. One of Julia's main claims to fame is performance, so people coming to kick the tires are going to be disappointed if it's actually slower for this domain.


I think my concern is that the main Zygote dev is not very responsive or communicative. That's fine, he doesn't owe the rest of the community anything, but it also seems irresponsible to put all our eggs in his basket (note I'm by no means knowledgeable about this and I'm open to the idea that it's a mis-characterization. This is merely my impression as a relative outsider looking in).

Besides, any implementation of AD is going to have downsides. Having many different, hot swappable implementations is quite nice because you can tailor towards your needs.


We have even more AD packages than that in Julia. Julia suffers from the Lisp curse when it comes to AD: http://winestockwebdesign.com/Essays/Lisp_Curse.html It is a real problem.

One of my ongoing projects is ChainRules (http://www.juliadiff.org/ChainRulesCore.jl/dev/), which will unite them under one set of custom sensitivities and, more generally, make it easier to mix and match them.

I need to update the JuliaDiff website; I want to list all of them in a table with some key points.

Forward mode:

- [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl)

- [ForwardDiff2](https://github.com/YingboMa//ForwardDiff2.jl)

Reverse Mode:

- [Nabla](https://github.com/invenia/Nabla.jl/)

- [Tracker](https://github.com/FluxML/Tracker.jl)

- [Yota](https://github.com/dfdx/Yota.jl)

- [Zygote](https://github.com/FluxML/Zygote.jl)

- [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl)

- [AutoGrad.jl](https://github.com/denizyuret/AutoGrad.jl)

- [NiLang](https://github.com/GiggleLiu/NiLang.jl) (arguably not reverse mode)

Symbolic:

- [ModelingToolKit](https://github.com/JuliaDiffEq/ModelingToolkit.jl)

- [XGrad.jl](https://github.com/dfdx/XGrad.jl)

Finite Differencing:

- [Calculus](https://github.com/JuliaMath/Calculus.jl) (please stop)

- [FiniteDifferences](https://github.com/JuliaDiff/FiniteDifferences.jl)

- [FiniteDiff](https://github.com/JuliaDiff/FiniteDiff.jl)
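The categories in this list compute derivatives in genuinely different ways. As a rough illustration in plain Python (not any of the packages above): forward mode can be done with dual numbers, where each value carries its derivative along, while finite differencing only evaluates the function at perturbed points and is therefore approximate.

```python
import math

class Dual:
    """A value a + b*eps with eps**2 == 0; the b part carries the derivative."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (a + a' eps)(b + b' eps) = ab + (a' b + a b') eps
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __radd__ = __add__
    __rmul__ = __mul__

def sin(x):
    if isinstance(x, Dual):
        return Dual(math.sin(x.val), math.cos(x.val) * x.der)  # chain rule
    return math.sin(x)

def f(x):
    return 3 * x * x + sin(x)

def forward_mode_derivative(fn, x):
    # Seed the input's derivative with 1 and read the derivative off the output.
    return fn(Dual(x, 1.0)).der

def central_difference(fn, x, h=1e-5):
    # Finite differencing: no derivative rules at all, but only approximate.
    return (fn(x + h) - fn(x - h)) / (2 * h)

x0 = 0.5
print(forward_mode_derivative(f, x0), central_difference(f, x0), 6 * x0 + math.cos(x0))
```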


Here is an earlier submission of JAX from December 2018:

https://news.ycombinator.com/item?id=18636054

JAX is pretty neat because it is effectively a derivatives compiler: it can automatically differentiate a function and JIT compile the result. This makes training in machine learning both fast and easy because gradient descent no longer has to be written by hand.
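A minimal sketch of that workflow: write the loss, let `jax.grad` derive the gradient and `jax.jit` compile it, and the gradient-descent loop itself reduces to one line per step. The tiny least-squares problem, the step count and the learning rate are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model.
    return jnp.mean((x @ w - y) ** 2)

# Differentiate with respect to the first argument (w), then compile with XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
true_w = jnp.array([2.0, -1.0])
y = x @ true_w

w = jnp.zeros(2)
for _ in range(500):
    w = w - 0.1 * grad_loss(w, x, y)   # the gradient-descent step itself
print(w)   # close to [2, -1]
```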


> gradient descent no longer has to be written by hand

Nobody's been writing derivatives by hand for 5+ years. All major frameworks (PyTorch, Tensorflow, MXNet, autodiff, Chainer, Theano, etc.) have decent to great automatic differentiation.

The differences and improvements are more subtle (easy parallelization/vectorization, higher-order gradients, good XLA support).


For high performance CUDA kernels people still need to write derivatives by hand. I know this as for my own research, and for many production systems, I'd still need to write it myself. Many of my architectures wouldn't have been possible without writing the CUDA myself (Quasi-Recurrent Neural Network[1]) or using optimized hand written black boxes (cuDNN RNN). The lack of open optimized hand written CUDA kernels has actually been an impediment to progress in the field.

Automatic differentiation allows for great flexibility and composability but the performance is still far from good, even with the various JITs available. Jax seems to be one of the most flexible and optimized for many use cases for now however.

[1]: https://github.com/salesforce/pytorch-qrnn


Right, you still need to write derivative rules by hand for the primitive operations of an auto-diff system. Automatic differentiation provides composition, it doesn't solve the root mathematical problem of differentiating operations at the lowest level.

So yes, if you need a new primitive to add an efficient CUDA kernel, you will probably have to write its derivative manually too. JAX has a few shortcuts that occasionally make this easier, but fundamentally it faces the same challenge as any auto-diff system.
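In JAX, one way to supply such a hand-written rule is `jax.custom_vjp`. The sketch below attaches a manual backward pass to softplus. Softplus is only a stand-in: a real hand-written CUDA kernel would also need to be exposed to JAX as an operation in the first place, which this sketch does not cover.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return jnp.logaddexp(x, 0.0)

def softplus_fwd(x):
    # Forward pass: return the output plus residuals saved for the backward pass.
    return softplus(x), x

def softplus_bwd(x, g):
    # Hand-written derivative: d/dx log(1 + e^x) = sigmoid(x).
    return (g * jax.nn.sigmoid(x),)

softplus.defvjp(softplus_fwd, softplus_bwd)

print(jax.grad(softplus)(0.0))
```

The custom rule still composes with the other transformations, so `jax.vmap(jax.grad(softplus))` works as usual.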


I still strongly disagree. Few of these hand written CUDA kernels outside of the frameworks are about implementing derivative rules, they're about eliminating the CUDA call overheads or avoiding the layered computational / memory inefficiencies that existing ML compilers have trouble handling.

Next to none of the frameworks are yet able to JIT you a performant RNN, yet RNNs only use very standard components[1]. OpenAI had a massive speed and memory usage boost for attention by implementing what amounts to a few standard primitives together[2].

There are massive gaps in the optimizations that existing ML compilers provide. The landscape is starting to get better, but it's still filled with many potholes.

[1]: https://twitter.com/stanfordnlp/status/1224106217192087552

[2]: https://openai.com/blog/sparse-transformer/
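For reference, a plain RNN in JAX is usually written with `lax.scan`, so the time loop is traced once and compiled as a single XLA loop instead of being unrolled in Python. This sketch only shows how the model is expressed. It says nothing about whether XLA then produces code competitive with cuDNN, which is exactly the gap described above. Sizes and initialization are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

def rnn(params, h0, xs):
    w_h, w_x, b = params

    def step(h, x):
        # One Elman-style recurrence step; carry the new state, also emit it.
        h_new = jnp.tanh(h @ w_h + x @ w_x + b)
        return h_new, h_new

    h_final, hs = lax.scan(step, h0, xs)   # loops over the leading (time) axis
    return h_final, hs

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
hidden, feat, steps = 8, 4, 10
params = (jax.random.normal(k1, (hidden, hidden)) * 0.1,
          jax.random.normal(k2, (feat, hidden)) * 0.1,
          jnp.zeros(hidden))
h0 = jnp.zeros(hidden)
xs = jnp.ones((steps, feat))

h_final, hs = jax.jit(rnn)(params, h0, xs)
print(hs.shape)
```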


It depends what you define as primitive. I've had plenty of compositions of existing primitives for which the auto-derived backprop was orders of magnitude slower than a hand written one. I didn't need to write my own backprop, but I benefited tremendously from it. I don't think my experience is particularly rare.


But is autodiff combined with a black-box JIT a real solution? The JIT either works for your new model or it does not. If it does not, you can do pretty much nothing about it, other than ping the JAX authors or get your own hands dirty with JAX's internal code. Why is no one working on a usable low-level framework where I can implement QRNN or more complicated stuff without relying on a black-box JIT? JAX could have chosen to be this, but instead it is a fancy solution to a non-problem.


How has your experience with CUDA been? Is it as painful as it appears at first glance? I've done a ton of python and C, and yet whenever I look at C++ code, it just screams stay away.

But I have some almost-reasonably-performant pytorch that I'd rather not just use as a cash burning machine, so it looks like it might be time to dive into CUDA :-\


The CUDA I've written has never been joyous but it also hasn't been as horrific as I'd expected. There's a period of hair pulling but persistence will get you through it. The majority of CUDA code is closer to C than C++ too which is helpful. I'll be looking at diving back into CUDA in the near future given the exact speed issues we've been mentioning so feel free to get in touch.


The function that typically gets differentiated is a loss function defined as (DesiredOutput(x) - HugeNumberOfParametersAppliedTo(x))^2. Are you saying that the symbolic expression gets transformed and is then used to represent the gradient?

I thought that PyTorch, Tensorflow and similar already do that.


Many frameworks already compute derivatives, but they don't use a symbolic representation. Instead they use a method called "automatic differentiation" which does something along the lines of (a) extracts a trace of the algorithm by executing the code with dummy arguments, then (b) uses the chain rule to compute component derivatives at each node in the execution tree and combine them into the final answer.

These methods are much faster than perturbation-based derivatives and much more applicable than symbolic methods (which cannot be automatically extracted from a program).
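A toy reverse-mode sketch of steps (a) and (b), in plain Python: each operation records its inputs and local derivatives as it runs (the trace), and `backward` then applies the chain rule from the output back to the inputs. This toy records the trace while computing with concrete values. `Var`, `sin` and `backward` are names invented for this example, and only `+`, `*` and `sin` are supported.

```python
import math

class Var:
    """A scalar that records how it was computed."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents   # tuple of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

def sin(v):
    return Var(math.sin(v.value), ((v, math.cos(v.value)),))

def backward(out):
    # Topologically order the recorded trace, then apply the chain rule in reverse.
    order, seen = [], set()
    def visit(v):
        if id(v) not in seen:
            seen.add(id(v))
            for p, _ in v.parents:
                visit(p)
            order.append(v)
    visit(out)
    out.grad = 1.0
    for v in reversed(order):
        for p, local in v.parents:
            p.grad += v.grad * local

x = Var(0.5)
y = Var(2.0)
z = x * y + sin(x)
backward(z)
print(x.grad, y.grad)   # dz/dx = y + cos(x), dz/dy = x
```

Each operation's local derivative (`1.0`, `other.value`, `math.cos(...)`) is still written by hand; the machinery only composes them.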


Not sure what you mean by “automatically extracted from a program”, all DL frameworks manually write backward pass for each op.


I mean the tracing operation that produces a structure appropriate for AD computation. I agree with you that there's work needed to specify the node derivatives.

Although, honestly, I misspoke. The difference between AD and symbolic differentiation is more subtle. Really AD is profiting because it uses AST representations to keep a graph of intermediate values while symbolic methods can blow up exponentially (or require clever, difficult to generalize tricks to reconstruct that graph).
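A small illustration of that blow-up, as a plain-Python sketch (the tuple encoding and helper names are invented for this example). The program `v = v * v` repeated n times is only n lines, but its expression tree, and the symbolic derivative tree built from it without sharing, grow exponentially in n. Forward-mode AD over the same program does O(n) work because it keeps the intermediate values.

```python
def f_expr(n):
    # Expression tree for x**(2**n), built by squaring n times.
    e = "x"
    for _ in range(n):
        e = ("*", e, e)
    return e

def d(e):
    """Symbolic derivative w.r.t. x using the sum and product rules, no simplification."""
    if e == "x":
        return "1"
    if isinstance(e, str):
        return "0"
    op, a, b = e
    if op == "+":
        return ("+", d(a), d(b))
    if op == "*":
        return ("+", ("*", d(a), b), ("*", a, d(b)))
    raise ValueError(f"unknown op {op!r}")

def size(e):
    # Node count when the expression is written out as a tree.
    return 1 if isinstance(e, str) else 1 + size(e[1]) + size(e[2])

def square_chain_ad(x, n):
    # Forward-mode AD on the n-line program `v = v * v`: carry (value, derivative).
    v, dv = x, 1.0
    for _ in range(n):
        v, dv = v * v, 2 * v * dv
    return v, dv

for n in range(1, 6):
    print(n, size(f_expr(n)), size(d(f_expr(n))))
```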


Jax is compelling because it’s a lot simpler than other solutions. It’s hard to quantify the value of simplicity over new features, but I’d argue we usually undervalue simplicity; it helps move way faster and makes debugging & maintenance exponentially easier. Helps focus on the 80/20 of helping customers instead of internal bullshit...


I am wondering what the state of things is for fixing gradients at kinks, or whether there is no hope for general automatic differentiation libraries.

    import jax.numpy as jnp
    from jax import grad
    
    def relu(x):
      # Slope 1 where x > 0, slope 0 elsewhere (including exactly at x == 0)
      return jnp.where(x > 0, x, jnp.zeros_like(x))
    
    def identity(x):
      # Equal to x for every input, so its derivative "should" be 1 everywhere
      return relu(x) - relu(-x)
    
    derivative_identity = grad(identity)
    derivative_identity(0.0)
Does this return 1.0 or 0.0? It currently returns 0.0. (Edit: typo in sign)


For many kinks, it doesn't in practice seem to matter that much. Most applications are computing large stochastic derivatives which smooth out kinks through averaging.

Critical kinks are those that affect the geometry of the gradient in systematic ways. For instance, a model with a mixture of discrete and continuous parameters. These are serious blockers and require more complex methods to solve such as Rao-Blackwellization (marginalizing out the discrete parameters). Generally this appears as model bias or substantially increased, often fatal variance in loss curves.



The derivative of x->x (identity) should be 1 regardless of your definition of derivative.

Subgradients are only applicable when summing convex functions. Here relu(x) - relu(-x) is a sum of a convex function and a concave function.


0.0>0 is false.
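Whichever value is "right" at the kink, JAX lets you choose the convention yourself with `jax.custom_jvp`. A minimal sketch, assuming you want relu'(0) = 0.5 (the midpoint of the subdifferential [0, 1]); with that choice the derivative of `relu(x) - relu(-x)` comes out as 1.0 everywhere, including at 0.0. The 0.5 convention is one possible choice, not JAX's default behaviour.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def relu(x):
    return jnp.where(x > 0, x, jnp.zeros_like(x))

@relu.defjvp
def relu_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Explicit convention: slope 1 for x > 0, 0 for x < 0, 0.5 exactly at x == 0.
    slope = jnp.where(x > 0, 1.0, jnp.where(x < 0, 0.0, 0.5))
    return relu(x), slope * x_dot

def identity(x):
    return relu(x) - relu(-x)

derivative_identity = jax.grad(identity)
print([float(derivative_identity(v)) for v in (-3.0, 0.0, 2.0)])   # [1.0, 1.0, 1.0]
```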


Has anyone benchmarked Jax? Curious how it compares to PyTorch for nontrivial networks, say ResNet.


Check out the Flax ResNet50 example: https://github.com/google/flax/tree/master/examples/imagenet

It runs about as fast as any of the other popular machine learning frameworks, occasionally faster.

Disclaimer: I work for Google and use JAX, although I'm not on the Jax team.


It is compiled to XLA, so it should be a lot faster than pure PyTorch, but it will probably be slower than TVM (https://tvm.apache.org/). I can prepare some benchmarks in the next few days if you're interested :)


That would be wonderful if you’re able to! Also doubles as a good intro to Jax :). Please feel free to tweet at me (@tbenst) or email [same username at stanford dot edu] if you do get around to it.


Sure (I don't have Twitter yet), but I'll probably post it here on Hacker News next week :)


> XLA so should be a lot faster

I've yet to see anything get "a lot faster" because of XLA. It's a ton of complicated code, but then you end up spending the vast majority of time in NVIDIA's cuDNN anyway, so any benefits you might have hoped for will be marginal at best.



Easily outperformed by the more traditional TensorRT, which TF also supports: https://devblogs.nvidia.com/tensorrt-integration-speeds-tens....

In fact, also seems to be outperformed by plain PyTorch using a single V100: https://github.com/NVIDIA/DeepLearningExamples/tree/master/P...


Interesting. I wonder why there's such a difference between the Nvidia PyTorch benchmark and the Exxact results: Nvidia is more than twice as fast for a single GPU. A V100 should only be ~10% faster than a Quadro 8000. Either Exxact is incompetent, or Nvidia has some special sauce.


FWIW, NVIDIA TensorRT pre-profiles the models before it runs them. I don't know how it does that exactly (that part is closed source) but I'd guess they just try different algorithms on each op individually (i.e. plain conv vs Winograd) and pick a good balance of speed and memory usage according to heuristics. On some nets this can make all the difference in the world, and ResNet50 is basically the most studied architecture in existence, so you can bet it's in every single benchmark for this kind of thing, and as such it receives disproportionate attention.


I thought all frameworks can do this type of profiling (e.g. torch.backends.cudnn.benchmark = True).

Nvidia might have eliminated any potential data pipeline bottlenecks (with careful DALI tuning), but I'd still expect a lot less speedup. Maybe they compiled pytorch with certain tricks, and used newer CUDA/CuDNN code, idk.


TRT profiling is more extensive. On the model I'm currently working with (which runs on Jetson Xavier), initial TRT profiling takes something like 4 minutes. The model is an object detector. You can save the result, but it's hardware dependent then, so the resulting model is only optimal for the particular hardware it was optimized for. I cache it on disk - I can't wait 4 minutes every time I run a test.

PyTorch, as far as I can tell, does much lighter cuDNN profiling. It's more pareto optimal, I suppose, but the benefit is nowhere near as significant.

Another framework which does amazing optimization (but on the CPU) is OpenVINO. Normally I don't expect much on the software side from Intel, but this thing really blows the doors off everything else if you don't have a GPU at your disposal, provided that you have an Intel processor. The way they do it is they generate kernels that fit your data, but not the way XLA does it. They hand code them in a DSL that produces assembly, using Xbyak, and incorporate their deep knowledge of Intel hardware into that. When it's time to run the model, that DSL spits out optimal kernels just for that particular model. It's pretty neat work, IMO.


I see, thanks. In my day job I develop hardware accurate simulations of a deep learning accelerator. This involves looking at a Spice model, simplifying it into some set of abstractions using Numpy, then accelerating this Numpy code using GPUs. Currently I'm porting a resnet-50 model from Numpy to Pytorch, and the next step is to speed up the Pytorch code (because right now I get ~1 image per second, which is about 10 times better than Numpy). Perhaps I should look into porting the model from Pytorch to TensorRT.


If you're working with pytorch, porting basically means export to ONNX. Sometimes you'll run into an op that doesn't work with ONNX, but there are a lot fewer of those in TRT7. Unfortunately I have to work with TRT6, so I have to use PyTorch 1.2 and be "creative" to work around TRT6 bugs. That said, it could very well be painless for you. No reason not to try. Just export the model, and benchmark it with `trtexec`, in both fp32 and fp16. An hour of work at most.


I'm very interested in seeing Resnet results. Especially FP16 precision running on tensor cores (V100 cards). Please use synthetic input vectors so that the input pipeline is not a bottleneck (as it is often the case, and it varies per framework).
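On the JAX side, two details matter when timing with synthetic inputs: the first call includes XLA compilation, and JAX dispatches work asynchronously, so the timer has to wait on `block_until_ready()`. A minimal sketch; the dense-layer workload and sizes are placeholders, not a ResNet.

```python
import time
import jax
import jax.numpy as jnp

def benchmark(fn, *args, warmup=3, iters=20):
    """Average seconds per call of a jitted fn, excluding compilation time."""
    for _ in range(warmup):
        fn(*args).block_until_ready()      # the first call triggers compilation
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args).block_until_ready()      # wait for the asynchronous dispatch
    return (time.perf_counter() - start) / iters

@jax.jit
def dense_layer(w, x):
    return jax.nn.relu(x @ w)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (1024, 1024))
x = jax.random.normal(key_x, (256, 1024))   # synthetic input: no data pipeline involved
print(f"{benchmark(dense_layer, w, x) * 1e3:.3f} ms per call")
```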


I assume that this only works for total derivatives (e.g., with respect to time) not general partial derivatives (i.e., the n-th total derivative with respect to another variable), right? Otherwise how would one avoid generating massive amounts of code, and temporary data along with it, if I want to


You can run k steps of Adam/SGD and differentiate through the learning rate and the scalarising parameters of a composite loss function in order to meta-learn them.

Support for general n-th total derivatives is rather good :)
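A minimal sketch of differentiating through an optimizer run: the inner loop is plain SGD on a toy quadratic (not Adam, and not a composite loss), unrolled for `k` steps, and `jax.grad` then gives the derivative of the final loss with respect to the learning rate.

```python
import jax
import jax.numpy as jnp

def inner_loss(w):
    return jnp.sum((w - 3.0) ** 2)

def loss_after_training(lr, w0, k=5):
    w = w0
    for _ in range(k):                     # k unrolled SGD steps
        w = w - lr * jax.grad(inner_loss)(w)
    return inner_loss(w)

# d(final loss) / d(learning rate), via differentiation through all k updates.
meta_grad = jax.grad(loss_after_training)
print(meta_grad(0.1, jnp.array(0.0)))
```

For this quadratic the result can be checked by hand: each step multiplies the error by (1 - 2·lr), so starting from w0 = 0 the final loss is 9(1 - 2·lr)^10, whose derivative at lr = 0.1 is -180·0.8^9 ≈ -24.16.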


Does the resulting chain of transformation pull all data into main memory at once, or can JAX handle larger than memory datasets?


Anyone from the Numba team care to comment?


While there is an overlap in how they work technically, I think the overlap is not too big. So if you ask about a Numba vs JAX comparison here, I'm not sure if such a comparison makes too much sense.

JAX core is an extensible system for transforming numerical Python functions. This core is used to implement automatic differentiation, translation to TF XLA, etc.

Numba does not have such a generic function transformation framework - it just supports a single transformation, that is from numerical Python function to machine code.


We (JAX) see JAX and Numba as mostly complementary; they have different strengths and they are focused on different things.

We haven't tried combining them yet, but we think it would be fun to explore (https://github.com/google/jax/issues/1870). For example, you could use Numba to hand write a numerical kernel that then participates in a machine learning model that uses JAX automatic differentiation.


A big difference is jax having a tracing compiler while numba does not. In general, jax is more suited towards vectorizable code (so using numpy functions broadcasted over axes) while numba is better for accelerating manual for loops.
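To illustrate that split with a toy pairwise-distance computation: in JAX you would express it with broadcasting (or `vmap`) so it runs as whole-array operations, whereas the explicit triple loop is the style Numba's `@njit` is designed to compile. Numba itself is not used here; the loop version is plain Python/NumPy, shown only for comparison.

```python
import numpy as np
import jax
import jax.numpy as jnp

def pairwise_dist_broadcast(a, b):
    # Vectorized: broadcast (n, 1, d) against (1, m, d), reduce over d.
    diff = a[:, None, :] - b[None, :, :]
    return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))

def dist(u, v):
    return jnp.sqrt(jnp.sum((u - v) ** 2))

# Same computation written per pair and lifted to all pairs with nested vmap.
pairwise_dist_vmap = jax.jit(
    jax.vmap(jax.vmap(dist, in_axes=(None, 0)), in_axes=(0, None)))

def pairwise_dist_loops(a, b):
    # Explicit loops: the style Numba's @njit targets (not compiled here).
    out = np.empty((a.shape[0], b.shape[0]))
    for i in range(a.shape[0]):
        for j in range(b.shape[0]):
            s = 0.0
            for k in range(a.shape[1]):
                diff = a[i, k] - b[j, k]
                s += diff * diff
            out[i, j] = np.sqrt(s)
    return out

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 3)).astype(np.float32)
b = rng.normal(size=(4, 3)).astype(np.float32)
print(np.allclose(pairwise_dist_broadcast(a, b), pairwise_dist_loops(a, b), atol=1e-5))
```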

(I'm not from either jax or numba, but a keen jax user for non-ML research.)


Out of curiosity what are you using JAX for?


How is it different from ChainerX?



