Idea: JAX interoperability through its IR

bill huang
Sep 1, 2023, 8:10:42 AM
to Numerical Elixir (Nx)
It's just an idea. The motivation is that a lot of people are using JAX, and many existing libraries are written in JAX. If we could somehow bring JAX functions into Nx, it would open up a whole new ecosystem to us.

Intuitively, it's possible, since JAX programs are not normal Python programs and JAX is basically a frontend for XLA. So we could spin up a Python environment just to get the IR of a JAX function, send it back to Elixir, and wrap it with Nx's `defn`. However, since I am not familiar with Nx's architecture, I am not sure how this could be implemented.
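For the "get the IR" step, JAX's ahead-of-time lowering API already exposes the module that would be handed to XLA. A minimal sketch, assuming a recent JAX install; the function `f` is a hypothetical example, and the printed dialect (StableHLO in recent releases, MHLO in older ones) depends on the JAX version. How Elixir would load this text is left open.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical example function: elementwise sine plus a matrix product.
    return jnp.sin(x) + x @ y


x = jnp.ones((2, 2), dtype=jnp.float32)
y = jnp.ones((2, 2), dtype=jnp.float32)

# Trace f for these input shapes/dtypes and lower it to XLA's input IR
# without executing it; as_text() returns the lowered module as a string.
lowered = jax.jit(f).lower(x, y)
hlo_text = lowered.as_text()
print(hlo_text)
```

Note that the lowered module is specialized to the shapes and dtypes of the example arguments, so each new input shape would need another round trip through Python.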

What do you think? Is it even possible to implement?

Bill Huang

bill huang
Sep 7, 2023, 1:02:19 AM
to Numerical Elixir (Nx)
I did some digging: JAX has `make_jaxpr`, which returns an expression (a jaxpr) similar to `Nx.Defn.Expr`, so I just need to find a way to interpret jaxprs.
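For reference, a minimal sketch of what `make_jaxpr` returns and how its equations can be walked. The function `g` is a hypothetical example; `eqn.primitive.name`, `eqn.invars`, `eqn.outvars` and `eqn.params` are the fields a translator would read, and the exact printed form of a jaxpr varies between JAX versions.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Hypothetical example: sine, scale by a constant, add the input back.
    return jnp.sin(x) * 2.0 + x


closed = jax.make_jaxpr(g)(jnp.ones(3, dtype=jnp.float32))
print(closed)  # the jaxpr in its textual form

# Each equation applies one primitive to input variables (or literals)
# and binds the result(s) to output variables.
prims = []
for eqn in closed.jaxpr.eqns:
    prims.append(eqn.primitive.name)
    print(eqn.primitive.name, eqn.invars, "->", eqn.outvars, eqn.params)
```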
One way to interpret it is to translate the jaxpr directly into `Nx.Defn.Expr` and invoke the JIT compiler on the Elixir side.
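The direct-translation route amounts to a walk over the equations that builds one graph node per primitive. The sketch below builds nested Python tuples as a stand-in for `Nx.Defn.Expr` nodes; the `NX_OPS` table and the tuple format are hypothetical, only a handful of primitives are covered, and primitive `params` (axes, dtypes, and so on) are ignored, so it is far from a complete translator.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from JAX primitive names to the Nx functions a real
# translator might target; a complete table would also need to map eqn.params.
NX_OPS = {"sin": "Nx.sin", "cos": "Nx.cos", "exp": "Nx.exp",
          "add": "Nx.add", "sub": "Nx.subtract", "mul": "Nx.multiply"}


def jaxpr_to_tree(closed):
    jaxpr = closed.jaxpr
    env = {}  # maps each jaxpr variable to the expression node defining it
    for i, var in enumerate(jaxpr.invars):
        env[var] = ("parameter", i)
    for i, var in enumerate(jaxpr.constvars):
        env[var] = ("constant", i)  # the value lives in closed.consts[i]

    def read(atom):
        # Literals carry their value in .val; variables are looked up in env.
        if hasattr(atom, "val"):
            val = atom.val
            return ("literal", val.item() if hasattr(val, "item") else val)
        return env[atom]

    for eqn in jaxpr.eqns:
        name = eqn.primitive.name
        if name not in NX_OPS or len(eqn.outvars) != 1:
            raise NotImplementedError(f"unsupported primitive: {name}")
        env[eqn.outvars[0]] = (NX_OPS[name], *[read(a) for a in eqn.invars])
    return [read(v) for v in jaxpr.outvars]


closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0 + x)(
    jnp.ones(3, dtype=jnp.float32))
print(jaxpr_to_tree(closed))
```

Because the jaxpr is already in SSA form, each variable maps to exactly one node, which is what would make a one-pass translation into a graph IR like `Nx.Defn.Expr` plausible.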
Another way would be to expand the jaxpr into a regular Elixir/Nx program using macros. This is more costly, but the result feels more like a native Nx function, and we might be able to use a different backend like torchx.
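The macro route would end up producing a straight-line Nx program with one call per equation. As a rough illustration of that expansion, here is a Python sketch that prints an equivalent `defn` as source text; an Elixir macro would build the same calls as quoted AST rather than strings. The `NX_OPS` table, the generated variable names and the default function name `jax_fn` are all hypothetical, and only primitives whose `params` can be ignored are handled.

```python
import jax
import jax.numpy as jnp

# Hypothetical primitive-to-Nx mapping (params such as axes are not handled).
NX_OPS = {"sin": "Nx.sin", "cos": "Nx.cos", "exp": "Nx.exp",
          "add": "Nx.add", "sub": "Nx.subtract", "mul": "Nx.multiply"}


def jaxpr_to_defn_source(closed, fn_name="jax_fn"):
    jaxpr = closed.jaxpr
    if jaxpr.constvars:
        raise NotImplementedError("captured constants are not handled")
    names = {}

    def name_of(atom):
        if hasattr(atom, "val"):  # literal: inline its Python value
            val = atom.val.item() if hasattr(atom.val, "item") else atom.val
            return repr(val)
        return names[atom]

    params = [f"a{i}" for i in range(len(jaxpr.invars))]
    names.update(zip(jaxpr.invars, params))
    lines = [f"defn {fn_name}({', '.join(params)}) do"]
    for i, eqn in enumerate(jaxpr.eqns):
        name = eqn.primitive.name
        if name not in NX_OPS or len(eqn.outvars) != 1:
            raise NotImplementedError(f"unsupported primitive: {name}")
        args = ", ".join(name_of(a) for a in eqn.invars)
        names[eqn.outvars[0]] = f"v{i}"
        lines.append(f"  v{i} = {NX_OPS[name]}({args})")
    outs = [name_of(v) for v in jaxpr.outvars]
    lines.append("  " + (outs[0] if len(outs) == 1 else "{" + ", ".join(outs) + "}"))
    lines.append("end")
    return "\n".join(lines)


closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0 + x)(
    jnp.ones(3, dtype=jnp.float32))
print(jaxpr_to_defn_source(closed))
```

Since the emitted body only calls `Nx` functions, it is not tied to EXLA, which is what would in principle let the same expansion run on another backend.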