Array API guidelines?


Thomas Robitaille

Mar 13, 2026, 6:02:18 AM
to astropy-dev mailing list
Hi everyone,

I am starting to think about how to make some features in various packages compatible with the Array API (which allows non-NumPy arrays to be used, e.g. JAX, PyTorch, etc.), and I was curious whether anyone here has recommendations for the best way to support this API in the astropy ecosystem.

From what I can tell, the main change for a function would be that inside it we do:

    xp = array_namespace(a, b)

and then use xp.function instead of np.function afterwards. From my experiments with this, it seems to work reasonably well, but one doesn't benefit from e.g. JIT with jax.
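For concreteness, here is a minimal sketch of that pattern. It assumes `array_namespace` comes from the array-api-compat package (with a NumPy-only stand-in if that package is missing), and `distance` is a made-up example function, not something that exists in astropy:

```python
import numpy as np

try:
    from array_api_compat import array_namespace
except ImportError:
    # Assumption: NumPy-only stand-in so the sketch runs without array-api-compat.
    def array_namespace(*arrays):
        return np


def distance(a, b):
    """Euclidean distance along the last axis (hypothetical example function)."""
    xp = array_namespace(a, b)  # NumPy, JAX, PyTorch, ... depending on the inputs
    return xp.sqrt(xp.sum((a - b) ** 2, axis=-1))


d = distance(np.array([[0.0, 0.0], [1.0, 1.0]]), np.array([[3.0, 4.0], [1.0, 1.0]]))
```

The function body never refers to `np` directly, so the same code should run on any array type that `array_namespace` recognizes.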

So then I tried a three-layer approach: the original array API kernel function, a jit-ed wrapper around it, and a wrapper on top of that which selects the jit-ed version only when jax arrays are passed in. It does work, but I'm not sure it's the optimal pattern. Another option would be to simply provide in the public API both the original function and a jit-ed version for jax, but that seems messy.
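Roughly, the layered version looks like the sketch below. All function names are hypothetical, and both JAX and array-api-compat are treated as optional: if either is missing, the sketch falls back to NumPy only.

```python
import numpy as np

try:
    import jax
    from array_api_compat import array_namespace
except ImportError:
    # Assumption: without JAX and array-api-compat, fall back to NumPy only.
    jax = None

    def array_namespace(*arrays):
        return np


def _distance_kernel(a, b):
    """Layer 1: backend-agnostic array API kernel."""
    xp = array_namespace(a, b)
    return xp.sqrt(xp.sum((a - b) ** 2, axis=-1))


# Layer 2: jitted wrapper, only defined when JAX is importable.
_distance_jitted = jax.jit(_distance_kernel) if jax is not None else None


def distance(a, b):
    """Layer 3: public function that picks the jitted version for JAX arrays."""
    if jax is not None and isinstance(a, jax.Array) and isinstance(b, jax.Array):
        return _distance_jitted(a, b)
    return _distance_kernel(a, b)


result = distance(np.array([0.0, 0.0]), np.array([3.0, 4.0]))
```

NumPy inputs go straight to the kernel, while JAX inputs go through the jitted wrapper, so callers only ever see `distance`.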

So I guess my question is – what guidelines should we use to support the array API, and do we want to encourage any patterns in terms of supporting JIT-ed versions, or should we not worry about this? If there are existing guidelines that people feel are optimal (e.g. https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html), we should probably link to them in the developer docs, or we could write our own of course.

Thanks!
Tom






Adrian Price-Whelan

Mar 13, 2026, 10:04:31 AM
to astro...@googlegroups.com
Hi Tom, all --

Cool!! I'm so glad to hear you are thinking about this!

I agree we should not have multiple implementations of functions. But why do you say that using the xp.function approach wouldn't benefit from JIT in, e.g., jax? Isn't that the whole benefit of the Array API? If you have a pure function implemented with the xp = array_namespace(a, b) functions, you should be able to then use jax.jit() on the function. Here's a very simple demo: https://gist.github.com/adrn/691718dcd9aa2af064d9283b8825ae32

But maybe I misunderstood your question?

best,
Adrian


--
Adrian M. Price-Whelan (he / him)
Research Scientist @ CCA/Flatiron Institute
Asst. Director for Scientific Software @ Simons Foundation

Thomas Robitaille

Mar 13, 2026, 5:38:02 PM
to astro...@googlegroups.com
Hi Adrian,

On Fri, 13 Mar 2026 at 14:04, Adrian Price-Whelan <adri...@gmail.com> wrote:
> Hi Tom, all --
>
> Cool!! I'm so glad to hear you are thinking about this!
>
> I agree we should not have multiple implementations of functions. But why do you say that using the xp.function approach wouldn't benefit from JIT in, e.g., jax? Isn't that the whole benefit of the Array API? If you have a pure function implemented with the xp = array_namespace(a, b) functions, you should be able to then use jax.jit() on the function. Here's a very simple demo: https://gist.github.com/adrn/691718dcd9aa2af064d9283b8825ae32

Sorry, I should have explained better. Calling jax.jit(...) on the function is indeed easy; my question is more about whether a package that provides such functions should leave it to the user to jax.jit them manually, or whether we can provide the jax.jit-ed functions as part of the package API.

Cheers,
Tom

Adrian Price-Whelan

Mar 19, 2026, 8:37:26 AM
to astro...@googlegroups.com
Hey Tom --

Nathaniel (or Dan F-M, if he's still lurking here!) can probably weigh in here as well, but in my experience it's best to leave the JIT compilation to the user anyway. JIT behavior (and the interaction between JIT and JAX's other optimizations) is often very context-dependent, so a user will generally get the best performance by putting the JIT on the outermost function call. For example, imagine a function you might want to provide in a future Astropy where coordinate objects contain array API-compatible objects:

    def offset_by(skycoord, offset):
        # ... logic to do the angular offset properly ...
        return new_skycoord

In this hypothetical, the "skycoord" and "offset" are probably pytree or dataclass-like structures (i.e. act a bit like dictionaries), and probably they contain Quantity-like objects to store the data. There is some cost to unpacking and then repacking (the new_skycoord) objects and the underlying Quantity objects.

Now imagine a user has their own function like this, that uses only arrays:

    def offset_my_coord(ra, dec, dra, ddec):
        # create skycoord
        c = SkyCoord(Quantity(ra, "deg"), ...)
        offset = ...
        new_c = offset_by(c, offset)
        return new_c.ra.degree, new_c.dec.degree

So the user's function takes in plain arrays and outputs plain arrays. Something I've learned from Nathaniel is that if you JIT this function, much of the pytree manipulation does not impact the runtime (though it will make the compile a bit slower), because the packing and unpacking run in Python while the function is being traced rather than in the compiled code.
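Here is a self-contained toy version of that idea. `Coord` is a hypothetical stand-in for a SkyCoord-like container, and the flat-sky arithmetic in `offset_by` is a deliberately naive placeholder, not a proper angular offset. JAX and array-api-compat are treated as optional, in which case nothing gets jitted.

```python
import math
from typing import Any, NamedTuple

import numpy as np

try:
    import jax
    from array_api_compat import array_namespace
    outer_jit = jax.jit
except ImportError:
    # Assumption: NumPy-only stand-ins so the sketch runs without JAX.
    def array_namespace(*arrays):
        return np

    def outer_jit(func):
        return func


class Coord(NamedTuple):
    """Hypothetical stand-in for a SkyCoord-like container, in degrees."""
    ra: Any
    dec: Any


def offset_by(coord, dra, ddec):
    """Hypothetical library function, not jitted; a naive flat-sky offset."""
    xp = array_namespace(coord.ra, coord.dec)
    cos_dec = xp.cos(coord.dec * (math.pi / 180.0))
    return Coord((coord.ra + dra / cos_dec) % 360.0, coord.dec + ddec)


@outer_jit  # the user jits only the outermost, arrays-in/arrays-out function
def offset_my_coord(ra, dec, dra, ddec):
    c = Coord(ra, dec)               # packing happens while tracing
    new_c = offset_by(c, dra, ddec)
    return new_c.ra, new_c.dec       # unpacking also happens while tracing


new_ra, new_dec = offset_my_coord(
    np.array([10.0, 359.5]), np.array([0.0, 0.0]), 1.0, 0.5
)
```

Only plain arrays cross the jit boundary here, so the library function itself needs no JAX-specific code.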

Anyway, that might be a bit of a convoluted example, but I think the principle still holds that it's best to leave JIT up to the user, so it's fine to use the array API like you were thinking / as in my demo.

best,
Adrian


Dan Foreman-Mackey

Mar 22, 2026, 10:31:09 AM
to astro...@googlegroups.com
> or Dan F-M, if he's still lurking here!

🫣

> a user probably will generally get the best performance by putting the JIT on the outermost function call.

This is a good rule of thumb, but it's not quite the whole story. AFAIK, having jit decorators below the top level never hurts, and can significantly improve JAX tracing time because the traced function will be cached. So, while the best advice is probably to put a jit at the top level, it's possible that adding conditional jits to some astropy functions that are expensive to trace (and likely to be traced multiple times within a program) would give better real-world performance.

I'm a bit out of the loop with the array API standard, but it seems like it wouldn't be too hard to write some sort of backend-conditional "maybe_jit" decorator. That said, it's probably not a big deal unless you're seeing compile-time performance issues.
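Something like the following rough sketch is what I have in mind. `maybe_jit` and `normalize` are hypothetical names, the backend check only looks at positional arguments, and static arguments are not handled. If JAX or array-api-compat is not installed, the decorator is simply a no-op.

```python
import functools

import numpy as np

try:
    import jax
    from array_api_compat import array_namespace
except ImportError:
    # Assumption: without JAX and array-api-compat, maybe_jit becomes a no-op.
    jax = None

    def array_namespace(*arrays):
        return np


def maybe_jit(func):
    """Hypothetical decorator: dispatch to jax.jit(func) only for JAX inputs."""
    if jax is None:
        return func
    jitted = jax.jit(func)  # created once, so JAX can reuse its trace cache

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if any(isinstance(arg, jax.Array) for arg in args):
            return jitted(*args, **kwargs)
        return func(*args, **kwargs)

    return wrapper


@maybe_jit
def normalize(v):
    """Scale vectors to unit length along the last axis."""
    xp = array_namespace(v)
    return v / xp.sqrt(xp.sum(v * v, axis=-1, keepdims=True))


unit = normalize(np.array([3.0, 4.0]))
```

NumPy inputs bypass JAX entirely, while JAX inputs (including tracers inside a user's own outer jit) go through the cached jitted version.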
