
It definitely works; JAX only sees the unrolled loop:

  def f(y):
    x = 0
    x += y
    x += y
    x += y
    x += y
    x += y
    return x
The reason you might need `jax.lax.fori_loop` or similar is a long loop with a complex body. A Python loop is unrolled at trace time, so a complex body gets replicated once per iteration, and you end up with a huge computation graph and slow compilation.
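To make the difference concrete, here is a minimal sketch comparing the two styles. It assumes a reasonably recent JAX install; `N_STEPS`, `unrolled`, and `rolled` are illustrative names, not anything from the thread. Both functions compute the same value, but the first is traced as five separate adds while the second is traced as one loop primitive whose body is traced once.

```python
import jax
import jax.numpy as jnp

N_STEPS = 5  # hypothetical trip count, matching the five additions above


@jax.jit
def unrolled(y):
    # The Python `for` runs during tracing, so the traced graph
    # contains N_STEPS separate add operations.
    x = jnp.zeros_like(y)
    for _ in range(N_STEPS):
        x = x + y
    return x


@jax.jit
def rolled(y):
    # fori_loop traces the body once and keeps the loop as a single
    # primitive, so graph size does not grow with N_STEPS.
    return jax.lax.fori_loop(0, N_STEPS, lambda i, x: x + y, jnp.zeros_like(y))


y = jnp.float32(2.0)
print(unrolled(y), rolled(y))
```

For a body this small the unrolled version is fine, and XLA can fuse the adds anyway. The rolled form mainly pays off when the body is large or the trip count is high.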



And how does TinyGrad solve this?


It's fused into one operation, since the Tensor isn't resolved until I call .numpy():

  kafka@tubby:/tmp$ cat fuse.py 
  from tinygrad.tensor import Tensor
  x = Tensor.zeros(1)
  for i in range(5):
    x += i
  print(x.numpy())

  kafka@tubby:/tmp$ OPT=2 GPU=1 DEBUG=2 python3 fuse.py 
  using [<pyopencl.Device 'Apple M1 Max' on 'Apple' at 0x1027f00>]
  **CL**      0 elementwise_0        args     1  kernels [1, 1, 1]          None         OPs     0.0M/   0.00G  mem  0.00 GB tm      0.15us/     0.00ms (    0.03 GFLOPS)
  **CL**        copy OUT (1,)
  [10.]


How does this differ from XLA? Would tinygrad's lazy approach also just see the same unrolled loop right before compilation?



