In case you didn't know, you can parallelize the slow Python loop in selective_scan that computes all the x's:
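
    x = torch.zeros((b, d_in, n))
    for i in range(l):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        ⋮

with only two calls to the PyTorch API. See the examples here: https://github.com/glassroom/heinsen_sequence/blob/main/READ... .[a]

You can then compute all the y's with one einsum, instead of l sequential einsums.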
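
To make that concrete, here's a rough sketch of one way to drop the loop. It assumes deltaA and deltaB_u have shape (b, d_in, l, n), matching the indexing in the loop above, and that C has shape (b, l, n); those shapes and the function names are my assumptions. Note that this uses torch.cumprod and torch.cumsum directly rather than the log-space formulation in the linked repo, so it's only safe numerically when the running product of deltaA doesn't underflow (short sequences, values not too close to zero).

```python
import torch

def scan_sequential(deltaA, deltaB_u):
    # Reference loop: x_i = deltaA_i * x_{i-1} + deltaB_u_i, starting from x = 0.
    b, d_in, l, n = deltaA.shape
    x = torch.zeros((b, d_in, n), dtype=deltaA.dtype)
    xs = []
    for i in range(l):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        xs.append(x)
    return torch.stack(xs, dim=2)  # (b, d_in, l, n)

def scan_parallel(deltaA, deltaB_u):
    # Unrolling the recurrence gives x_i = P_i * sum_{j<=i} deltaB_u_j / P_j,
    # where P_i is the running product of deltaA up to step i.
    P = torch.cumprod(deltaA, dim=2)
    return P * torch.cumsum(deltaB_u / P, dim=2)

torch.manual_seed(0)
b, d_in, l, n = 2, 3, 16, 4
# Keep deltaA in [0.5, 1) so the running product stays well away from zero.
deltaA = torch.rand(b, d_in, l, n, dtype=torch.float64) * 0.5 + 0.5
deltaB_u = torch.randn(b, d_in, l, n, dtype=torch.float64)
C = torch.randn(b, l, n, dtype=torch.float64)  # assumed shape (b, l, n)

xs_seq = scan_sequential(deltaA, deltaB_u)
xs_par = scan_parallel(deltaA, deltaB_u)

# All y's with one einsum over the sequence axis, instead of l sequential einsums.
ys = torch.einsum('bdln,bln->bdl', xs_par, C)
```

For long sequences you'd want the log-space version from the repo, which avoids dividing by a product that shrinks toward zero.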
---
[a] Previous discussion on HN: https://news.ycombinator.com/item?id=38556669
For what it's worth, you can keep both, and make parallel vs sequential execution an option, with a boolean flag.
You can also leave the sequential code as a comment explaining what the parallel code does.
Or, if slow execution doesn't bother you, leave it as is.
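
Something like this is what I have in mind for the flag. It's only a sketch: the function name compute_states and the parallel argument are made up for illustration, the tensors are assumed to have shape (b, d_in, l, n), and the parallel branch uses a plain cumprod/cumsum trick that is only numerically safe when the running product of deltaA doesn't underflow.

```python
import torch

def compute_states(deltaA, deltaB_u, parallel=True):
    # Hypothetical helper: returns all x's with shape (b, d_in, l, n).
    if parallel:
        # Same result as the loop in the else branch, computed in one shot:
        # x_i = P_i * sum_{j<=i} deltaB_u_j / P_j, with P_i = prod_{k<=i} deltaA_k.
        P = torch.cumprod(deltaA, dim=2)
        return P * torch.cumsum(deltaB_u / P, dim=2)
    else:
        # Slow but easy to read: one step of the recurrence per iteration.
        b, d_in, l, n = deltaA.shape
        x = torch.zeros((b, d_in, n), dtype=deltaA.dtype)
        xs = []
        for i in range(l):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            xs.append(x)
        return torch.stack(xs, dim=2)

torch.manual_seed(0)
deltaA = torch.rand(2, 3, 8, 4, dtype=torch.float64) * 0.5 + 0.5
deltaB_u = torch.randn(2, 3, 8, 4, dtype=torch.float64)

fast = compute_states(deltaA, deltaB_u, parallel=True)
slow = compute_states(deltaA, deltaB_u, parallel=False)
```

Keeping the sequential branch around also gives you a cheap correctness check for the parallel one.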