
It definitely works; JAX only sees the unrolled loop:

  x = 0
  x += y
  x += y
  x += y
  x += y
  x += y
  return x
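A quick way to confirm what JAX actually traces is `jax.make_jaxpr` (a sketch; the function name and scalar argument are just illustrative):

  import jax

  def accumulate(y):
      x = 0.0
      for _ in range(5):   # plain Python loop, unrolled at trace time
          x += y
      return x

  # the printed jaxpr contains five separate add ops, one per iteration
  print(jax.make_jaxpr(accumulate)(2.0))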
The reason you might need `jax.lax.fori_loop` or the like is when you have a long loop with a complex body: replicating that body many times leaves you with a huge computation graph and slow compilation.
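A rough `fori_loop` version of the same accumulation, for comparison (again just a sketch, not the only way to write it):

  import jax

  def accumulate(y):
      # body takes (loop index, carry) and returns the new carry;
      # the loop stays a single node in the trace instead of being unrolled
      return jax.lax.fori_loop(0, 5, lambda i, x: x + y, 0.0)

  print(jax.jit(accumulate)(2.0))  # 10.0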



And how does TinyGrad solve this?


Fused into one operation, since the Tensor isn't resolved until I call .numpy():

  kafka@tubby:/tmp$ cat fuse.py 
  from tinygrad.tensor import Tensor
  x = Tensor.zeros(1)
  for i in range(5):
    x += i
  print(x.numpy())

  kafka@tubby:/tmp$ OPT=2 GPU=1 DEBUG=2 python3 fuse.py 
  using [<pyopencl.Device 'Apple M1 Max' on 'Apple' at 0x1027f00>]
  **CL**      0 elementwise_0        args     1  kernels [1, 1, 1]          None         OPs     0.0M/   0.00G  mem  0.00 GB tm      0.15us/     0.00ms (    0.03 GFLOPS)
  **CL**        copy OUT (1,)
  [10.]


How does this differ from XLA? Would tinygrad's lazy approach also just see the same unrolled loop right before compilation?



