It's just an idea. The motivation is that a lot of people are using JAX, and many existing libraries are written in it. If we could somehow bring JAX functions into Nx, it would open up a whole new ecosystem to us.
Intuitively, it seems possible: JAX programs are not normal Python programs, since JAX is basically a frontend for XLA. So we could spin up a Python environment just to get the IR of a JAX function, send it back to Elixir, and wrap it with Nx's defn. However, since I am not familiar with Nx's architecture, I am not sure how this could be implemented.
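For the "get the IR" step, a minimal sketch on the Python side might look like the following. It assumes a recent JAX version, where `jax.jit(f).lower(...)` traces the function for concrete input shapes and `Lowered.as_text()` returns the module as StableHLO (MLIR) text. The function `f` and the output path `f.mlir` are purely illustrative. How Nx/`defn` would load and run that text on the Elixir side is exactly the open question, so that part is not shown.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function; any jittable JAX function would do.
def f(x, y):
    return jnp.sin(x) * y + 1.0

# Lowering needs concrete shapes/dtypes, so the IR is specialized to them.
x = jnp.ones((3,), dtype=jnp.float32)
lowered = jax.jit(f).lower(x, x)

# Textual IR (StableHLO/MLIR in recent JAX versions) that could be shipped to Elixir.
ir_text = lowered.as_text()

# Illustrative hand-off: write it to a file the Elixir side could read.
with open("f.mlir", "w") as fh:
    fh.write(ir_text)

print(ir_text)
```

One consequence of this approach is that the exported IR is fixed to the shapes and dtypes used at lowering time, so the Elixir side would likely need to request a new export for each distinct input signature.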
What do you think? Is it even possible to implement?
Bill Huang