It definitely works. The Python loop runs at trace time, so JAX only sees the unrolled loop:
x = 0
x += y
x += y
x += y
x += y
x += y
return x
The reason you might need `jax.lax.fori_loop` or something similar is a long loop with a complex body. Replicating a complex body many times gives you a huge computation graph and slow compilation.
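To make the difference concrete, here is a minimal sketch that traces the same five-step accumulation both ways and prints the resulting jaxprs. The function names `unrolled` and `rolled` and the input value are made up for illustration; only `jax.make_jaxpr` and `jax.lax.fori_loop` are real JAX calls.

```python
import jax
import jax.numpy as jnp

def unrolled(y):
    # A plain Python loop runs while JAX traces the function,
    # so the body is recorded once per iteration.
    x = jnp.zeros_like(y)
    for _ in range(5):
        x = x + y
    return x

def rolled(y):
    # fori_loop traces the body once and keeps the loop
    # as a single loop primitive in the graph.
    return jax.lax.fori_loop(0, 5, lambda i, x: x + y, jnp.zeros_like(y))

y = jnp.float32(2.0)  # illustrative input, not from the original post

unrolled_jaxpr = jax.make_jaxpr(unrolled)(y)
rolled_jaxpr = jax.make_jaxpr(rolled)(y)

print(unrolled_jaxpr)  # one equation per loop iteration
print(rolled_jaxpr)    # a single loop equation wrapping the body
```

With a body this small the unrolled version is perfectly fine. The rolled version should only start to pay off when the body is large or the trip count is high, since the graph size then stops growing with the number of iterations.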
It's fused into one operation, since the Tensor isn't resolved until I call .numpy():
kafka@tubby:/tmp$ cat fuse.py
from tinygrad.tensor import Tensor
x = Tensor.zeros(1)
for i in range(5):
    x += i
print(x.numpy())
kafka@tubby:/tmp$ OPT=2 GPU=1 DEBUG=2 python3 fuse.py
using [<pyopencl.Device 'Apple M1 Max' on 'Apple' at 0x1027f00>]
**CL** 0 elementwise_0 args 1 kernels [1, 1, 1] None OPs 0.0M/ 0.00G mem 0.00 GB tm 0.15us/ 0.00ms ( 0.03 GFLOPS)
**CL** copy OUT (1,)
[10.]
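The general idea behind that single kernel can be sketched with a toy lazy value in plain Python. This is not tinygrad's implementation: `LazyScalar`, its `pending` field, and the `kernels_launched` counter are hypothetical names invented for the sketch, and the "kernel" is just a simulated counter.

```python
class LazyScalar:
    """Toy stand-in for a lazy tensor: records pending adds instead of running them."""

    kernels_launched = 0  # simulated kernel launches, shared across instances

    def __init__(self, value, pending=()):
        self.value = value
        self.pending = tuple(pending)  # constants waiting to be added

    def __add__(self, other):
        # No computation happens here; the expression just gets longer.
        return LazyScalar(self.value, self.pending + (other,))

    def numpy(self):
        # Realize: evaluate the whole recorded chain in one simulated kernel.
        LazyScalar.kernels_launched += 1
        return self.value + sum(self.pending)


x = LazyScalar(0.0)
for i in range(5):
    x += i  # falls back to __add__, so nothing is evaluated yet

result = x.numpy()
print(result, LazyScalar.kernels_launched)
```

Because nothing is evaluated until `numpy()` is called, the five additions are visible as one expression at realization time, which is what gives a real framework the chance to emit a single fused kernel.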