Hi everyone,
I am starting to think about how to make some features in various packages compatible with the Array API (which allows non-NumPy arrays to be used, e.g. JAX, PyTorch, etc.), and I was curious whether anyone here has recommendations for the best way to support this API in the astropy ecosystem.
From what I can tell, the main change for a function would be that inside it we do:
xp = array_namespace(a, b)
and then use xp.function instead of np.function afterwards. From my experiments, this seems to work reasonably well, but one doesn't get the benefit of e.g. JIT compilation with JAX.
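For concreteness, a minimal sketch of that pattern could look like the following. The function weighted_mean is a made-up example, not taken from any package, and the fallback definition of array_namespace is only a crude stand-in (an assumption for illustration) so that the snippet runs even where the array-api-compat package is not installed.

```python
import numpy as np

try:
    # array_namespace as provided by the array-api-compat package
    from array_api_compat import array_namespace
except ImportError:
    # Crude stand-in (an assumption, for illustration only) so the sketch
    # still runs without array-api-compat installed.
    def array_namespace(*arrays):
        for x in arrays:
            if hasattr(x, "__array_namespace__"):
                return x.__array_namespace__()
        return np


def weighted_mean(values, weights):
    """Hypothetical example function written against the Array API."""
    # Look up the namespace that matches the input arrays
    xp = array_namespace(values, weights)
    # From here on, only xp.* calls are used, never np.* directly
    return xp.sum(values * weights) / xp.sum(weights)


result = weighted_mean(np.asarray([1.0, 2.0, 3.0]), np.asarray([1.0, 1.0, 2.0]))
print(float(result))
```

The same function body should then accept arrays from any library that array_namespace recognises, which is the whole appeal, but nothing here gets compiled.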
So then I tried a three-layer approach: the original Array API kernel function, a JIT-ed wrapper around it, and on top of that a wrapper that selects the JIT-ed version specifically when JAX arrays are being used. It does work, but I'm not sure it's the optimal pattern. Another option would be to simply expose both the original function and a JIT-ed version for JAX in the public API, but that seems messy.
So I guess my question is – what guidelines should we use to support the Array API, and do we want to encourage any patterns in terms of supporting JIT-ed versions, or should we not worry about this? If there are existing guidelines that people feel are optimal (e.g. https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html), we should probably link to them in the developer docs, or we could write our own of course.
Thanks!
Tom