Hacker News
Jax and Equinox: What are they and why should I bother? (garymm.org)
58 points by spearman 5 days ago | 16 comments





Equinox has great idioms; it really pioneered the PyTree perspective. Penzai is also great.

JAX feels close to achieving a sort of high-level GPGPU ecosystem. It's still very fledgling, but I keep finding more small libraries that build on JAX (and, because of that, can be composed with other JAX libraries).

The only problem is that heavy composition produces large programs and therefore long XLA compile times.
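One general way to keep the program handed to XLA small (a common technique, not something the commenter mentions) is to replace Python loops, which are unrolled at trace time, with `jax.lax.scan`, which traces the loop body once. The sketch below uses the number of equations in `jax.make_jaxpr`'s output as a rough proxy for program size; the loop body and `N_STEPS` are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 100


def unrolled(x):
    # A Python loop is unrolled during tracing: the body is copied N_STEPS times.
    for _ in range(N_STEPS):
        x = jnp.sin(x) + 0.1 * x
    return x


def scanned(x):
    # lax.scan traces the body once and emits a single loop primitive.
    def body(carry, _):
        return jnp.sin(carry) + 0.1 * carry, None

    x, _ = lax.scan(body, x, None, length=N_STEPS)
    return x


x = jnp.ones(3)
n_unrolled = len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns)
n_scanned = len(jax.make_jaxpr(scanned)(x).jaxpr.eqns)
print(n_unrolled, n_scanned)
```

Both functions compute the same result; only the size of the traced program differs, and that size is what XLA has to compile.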


Personally, I'm a big fan of Flax. The way it separates the params from the compute graph is, IMO, the right (TM) way. Saying this after many years of doing ML :)

There are quite a few other libraries associated with Equinox in the JAX ecosystem:

https://github.com/patrick-kidger/equinox?tab=readme-ov-file...

I've enjoyed using Equinox and Diffrax for performing ODE simulations. To my knowledge the only other peer library with similar capabilities is the Julia DifferentialEquations.jl package.


I wish JAX had everything I need to experiment with DL models built in natively, like PyTorch does. Instead there are many third-party libraries (Flax, Trax, Haiku, this one, etc.). I have no idea which one to use. This was the case when I first played with JAX 5 years ago, and it's still the case today (even worse, it seems). This makes it a non-starter for me.

Why?

Use any. I used to work with Flax; now I work with Equinox more. Choose any of Flax, Equinox, and Haiku.


> Use any

Too much overhead when I just want to get shit done.


Okay, I am telling you to use Flax.

Its high-level API is quite similar to that of PyTorch, so you will feel right at home.


So, “do as I say, not as I do”? I just looked at Equinox and it does look a little better than Flax, but Flax seems to be more widely used. Both Haiku and Trax seem to be on the way out. All four were intended to do exactly the same thing. What a mess of an ecosystem.

For me, these are the questions to answer before deciding whether I should bother:

Will it try to bind me to other technologies?

Does it work out of the box on ${GPU}?

Is it well supported?

Will it continue to be supported?


Re support: JAX is open source and X, Apple, and Google all use it, so I can't imagine it being abandoned in the next 5 years at least.

Playing with JAX on Google Colab (Nvidia T4), everything works great.

Sadly, I cannot get JAX to work with the built-in GPU on my M1 MacBook Air. In theory it's supposed to work:

https://developer.apple.com/metal/jax/

But it crashes Python when I try to run a compiled function. And that's only after discovering that I need a specific older version of jax-metal, because newer versions apparently don't work with M1 anymore (only M2/M3); they don't even report the GPU as existing. And even if you get it running, it's missing support for complex numbers.

It's not clear to me whether Google or Apple is building and maintaining support for Apple M-series chips, though.

JAX works perfectly in CPU mode on my MBA, though, so at least I can use it for development and debugging.
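For that CPU fallback, a minimal sketch that forces JAX onto its CPU backend through the `JAX_PLATFORMS` environment variable (it has to be set before JAX initializes its backends, here before `import jax`) and then checks which backend is active. It also runs a complex-valued computation, since complex dtypes are supported on the CPU backend. The function `f` is only a placeholder.

```python
import os

# Must be set before JAX initializes its backends, i.e. before importing jax here.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

print(jax.default_backend())  # the backend JAX selected
print(jax.devices())          # the devices it can see


@jax.jit
def f(x):
    # Placeholder computation to confirm that compiled functions run.
    return jnp.sum(jnp.sin(x) ** 2)


out = f(jnp.arange(4.0))

# Complex arithmetic under jit: exp(i*pi) should be close to -1.
z = jax.jit(lambda t: jnp.exp(1j * t))(jnp.pi)
```

The same environment variable can also be set from the shell before launching Python.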


Pretty sure it's Apple building it, and they're using JAX in-house so I imagine it will get better over time. Though they do love to drop support for old things so maybe M1 will never work again...

I think they’re likely using MLX in-house now, no? (Probably not everyone, of course, but it seems likely that many will just use the native array framework designed explicitly for M-series chips.)

From https://machinelearning.apple.com/research/introducing-apple...

> Our foundation models are trained on Apple's AXLearn framework, an open-source project we released in 2023. It builds on top of JAX


Wow! Thanks for the ref.

Outside training, M1 also makes sense as a lower-bound deployment target ("we support Apple Silicon").



