Array API guidelines?

Thomas Robitaille

Mar 13, 2026, 6:02:18 AM
to astropy-dev mailing list
Hi everyone,

I am starting to think about how to make some features in various packages compatible with the Array API (which allows non-NumPy arrays to be used, e.g. JAX, PyTorch, etc.), and I was curious whether anyone here has recommendations for the best way to support this API in the astropy ecosystem.

From what I can tell, the main change for a function would be that inside it we do:

    xp = array_namespace(a, b)

and then use xp.function instead of np.function afterwards. In my experiments this works reasonably well, but one doesn't automatically benefit from, e.g., JIT compilation with JAX.
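For concreteness, here is a minimal sketch of that pattern. It assumes array_namespace comes from the array_api_compat package (the snippet above doesn't say where it is imported from), and hypot_sum is a made-up example function, not something that exists in astropy. The fallback is only there so the sketch still runs with NumPy alone.

```python
import numpy as np

try:
    # Assumption: array_namespace is the helper from array_api_compat.
    from array_api_compat import array_namespace
except ImportError:
    # Stand-in for this sketch only: treat every input as a NumPy array.
    def array_namespace(*arrays):
        return np


def hypot_sum(a, b):
    """Hypothetical example: sum of sqrt(a**2 + b**2) over all elements."""
    # Look up the array library that matches the inputs ...
    xp = array_namespace(a, b)
    # ... and call xp.<function> where we would have written np.<function>.
    return xp.sum(xp.sqrt(a * a + b * b))


print(hypot_sum(np.array([3.0, 6.0]), np.array([4.0, 8.0])))
```

With array_api_compat installed, the same function body should also accept arrays from other supported libraries, since only xp changes.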

So I then tried a three-layer approach: the original Array API kernel function, a JIT-ed wrapper around it, and on top of that a wrapper that selects the JIT-ed version only when JAX arrays are passed in. It does work, but I'm not sure it's the optimal pattern. Another option would be to simply expose both the original function and a JIT-ed version for JAX in the public API, but that seems messy.

So I guess my question is: what guidelines should we use to support the Array API, and do we want to encourage any particular patterns for supporting JIT-ed versions, or should we not worry about this? If there are existing guidelines that people feel are optimal (e.g. https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html), we should probably link to them in the developer docs, or we could of course write our own.

Thanks!
Tom

Adrian Price-Whelan

Mar 13, 2026, 10:04:31 AM
to astro...@googlegroups.com
Hi Tom, all --

Cool!! I'm so glad to hear you are thinking about this!

I agree we should not have multiple implementations of functions. But why do you say that using the xp.function approach wouldn't benefit from JIT in, e.g., JAX? Isn't that the whole benefit of the Array API? If you have a pure function implemented using functions from xp = array_namespace(a, b), you should then be able to use jax.jit() on that function. Here's a very simple demo: https://gist.github.com/adrn/691718dcd9aa2af064d9283b8825ae32
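To spell out the idea inline, here is a small sketch (not the contents of the gist). The function rms is a made-up example, and array_namespace below is a hypothetical stand-in for the real helper so that the snippet runs with NumPy alone; the JAX part only runs if jax is installed.

```python
import numpy as np

try:
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None  # the JAX part of the demo is skipped


def array_namespace(*arrays):
    # Hypothetical stand-in: pick jax.numpy for JAX inputs, NumPy otherwise.
    if jax is not None and any(isinstance(x, jax.Array) for x in arrays):
        return jnp
    return np


def rms(a):
    """Pure function written only against the array namespace."""
    xp = array_namespace(a)
    return xp.sqrt(xp.mean(a * a))


# Plain NumPy call.
print(rms(np.array([3.0, 4.0])))

if jax is not None:
    # The user JIT-compiles the very same function, with no second implementation.
    rms_jit = jax.jit(rms)
    print(rms_jit(jnp.array([3.0, 4.0])))
```

The point is that the package ships only rms, and anyone who wants compilation wraps it in jax.jit themselves.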

But maybe I misunderstood your question?

best,
Adrian


--
Adrian M. Price-Whelan (he / him)
Research Scientist @ CCA/Flatiron Institute
Asst. Director for Scientific Software @ Simons Foundation

Thomas Robitaille

Mar 13, 2026, 5:38:02 PM
to astro...@googlegroups.com
Hi Adrian,

On Fri, 13 Mar 2026 at 14:04, Adrian Price-Whelan <adri...@gmail.com> wrote:
> Hi Tom, all --
>
> Cool!! I'm so glad to hear you are thinking about this!
>
> I agree we should not have multiple implementations of functions. But why do you say that using the xp.function approach wouldn't benefit from JIT in, e.g., JAX? Isn't that the whole benefit of the Array API? If you have a pure function implemented using functions from xp = array_namespace(a, b), you should then be able to use jax.jit() on that function. Here's a very simple demo: https://gist.github.com/adrn/691718dcd9aa2af064d9283b8825ae32

Sorry, I should have explained better. Indeed, calling jax.jit(...) on the function is easy. My question is more about whether, for a package that provides functions, it's best to leave it up to the user to jax.jit them manually, or whether we should provide the jax.jit-ed functions as part of the package API.

Cheers,
Tom
