Hey Tom --
Nathaniel (or Dan F-M, if he's still lurking here!) can probably weigh in here as well, but in my experience it's best to leave the JIT compilation to the user anyway. JIT behavior (and the interaction between JIT and JAX's other optimizations) is often very context-dependent, so users will generally get the best performance by putting the JIT on the outermost function call. For example, imagine a function you might want to provide in a future Astropy where coordinate objects contain array API-compatible objects:
def offset_by(skycoord, offset):
  # ...logic to do the angular offset properly...
  return new_skycoord
In this hypothetical, "skycoord" and "offset" are probably pytree or dataclass-like structures (i.e., they act a bit like dictionaries), and they probably contain Quantity-like objects to store the data. There is some cost to unpacking these structures and their underlying Quantity objects, and then repacking the result into new_skycoord.
Now imagine a user has their own function like this one, which works only with plain arrays:
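To make the "pytree" idea concrete, here is a minimal JAX sketch of what such containers could look like. Quantity and SimpleCoord are hypothetical stand-ins (not Astropy's real classes): the array data become pytree leaves, while the unit string is treated as static metadata, so flattening a coordinate yields just the two underlying arrays.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Quantity:
    """Hypothetical Quantity-like container: an array plus a static unit string."""

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def tree_flatten(self):
        # The array is the only pytree leaf; the unit is static auxiliary data.
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)


@register_pytree_node_class
class SimpleCoord:
    """Hypothetical stand-in for a SkyCoord-like object holding two Quantities."""

    def __init__(self, ra, dec):
        self.ra = ra
        self.dec = dec

    def tree_flatten(self):
        # Children are the two Quantities; JAX recursively flattens them too.
        return (self.ra, self.dec), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


c = SimpleCoord(Quantity(jnp.array([10.0, 20.0]), "deg"),
                Quantity(jnp.array([-5.0, 5.0]), "deg"))

# Flattening strips away the containers, leaving only the raw arrays;
# treedef remembers the structure (including the static units).
leaves, treedef = jax.tree_util.tree_flatten(c)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

This unpack/repack round trip is the kind of overhead described above when it runs eagerly, call after call.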
def offset_my_coord(ra, dec, dra, ddec):
  # create skycoord
  c = SkyCoord(Quantity(ra, "deg"), ...)
  offset = ...
  new_c = offset_by(c, offset)
  return new_c.ra.degree, new_c.dec.degree
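So the user's function takes in plain arrays and outputs plain arrays. Something I've learned from Nathaniel is that if you JIT this function, much of the pytree manipulation does not impact the runtime: the packing and unpacking happen in Python while JAX traces the function, and the compiled code only operates on the underlying arrays (though all that manipulation will make the compile a bit slower).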
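A minimal JAX sketch of that effect, assuming a NamedTuple stand-in for SkyCoord (Coord and this naive offset_by are hypothetical, not Astropy API, and the offset is a simple addition rather than a proper spherical offset). A Python-side counter shows that the container-building code runs only once, at trace time, even though the jitted outer function is called several times.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Coord(NamedTuple):
    # Hypothetical coordinate container; NamedTuples are pytrees in JAX by default.
    ra: jax.Array   # degrees
    dec: jax.Array  # degrees


trace_count = 0


def offset_by(coord, dra, ddec):
    """Hypothetical library function, deliberately NOT jitted itself."""
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX is tracing
    # Naive offset for illustration only (not a correct spherical offset).
    return Coord(coord.ra + dra, coord.dec + ddec)


@jax.jit  # the user puts the JIT on their outermost, array-in/array-out function
def offset_my_coord(ra, dec, dra, ddec):
    c = Coord(ra, dec)                 # packing: happens at trace time only
    new_c = offset_by(c, dra, ddec)
    return new_c.ra, new_c.dec         # unpacking: also trace time only


ra = jnp.array([10.0, 20.0])
dec = jnp.array([-5.0, 5.0])
for _ in range(3):
    # Same shapes and dtypes each call, so JAX reuses the cached compilation.
    new_ra, new_dec = offset_my_coord(ra, dec, 1.0, 2.0)

print(trace_count)  # 1
```

The library code stays JIT-agnostic; whether and where to compile is left to the caller.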
Anyway, that might be a bit of a convoluted example, but I think the principle still holds: it's best to leave JIT up to the user, so it's fine to use the array API as you were thinking (or as in my demo).
best,
Adrian