
Creating a scientific sim on OpenXLA


Maksym Riabov

Sep 13, 2024, 7:24:54 AM
to OpenXLA Discuss
Sup OpenXLA community, I'm new to this library.
I've worked with JAX and like it, but now I need to work with C++ code: I have a scientific simulator that should benefit from a lot of matrix multiplications, vectorization and the like. It's a very numerically intensive program. On CPU a single simulation run usually takes 2 days, and I want to move it to GPU.

Instead of hand-optimizing it in raw CUDA, I'm considering XLA. Would XLA handle this more effectively, and would it be easier?
Is XLA suited for non-ML workloads? Can I use vmap like in JAX? Are there major limitations?

Thank you.

George Karpenkov

Sep 13, 2024, 11:20:21 AM
to Maksym Riabov, OpenXLA Discuss
Welcome!

XLA cannot optimize existing C++ or CUDA code directly. If you can rewrite the simulator using JAX primitives (possibly with Pallas for custom operations on top), then you get XLA compilation, vmap, etc. The main limitation is the set of supported operations: every operation has to be expressed either in HLO (https://openxla.org/stablehlo/spec) or as a custom kernel (which can do anything, but XLA's optimizations won't apply to it).
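To make "rewrite it using JAX primitives" concrete, here is a minimal sketch of how a time-stepping simulator might look in JAX, assuming a toy linear model dx/dt = A @ x integrated with explicit Euler (the model and the names `step`, `simulate`, `simulate_batch` are hypothetical stand-ins for the real physics). `jax.lax.scan` keeps the time loop inside one compiled XLA computation, `jax.jit` compiles it, and `jax.vmap` batches many independent runs without hand-written batching code.

```python
from functools import partial

import jax
import jax.numpy as jnp


def step(x, A, dt):
    # Hypothetical update rule: one explicit-Euler step of dx/dt = A @ x.
    # A real simulator's physics would go here instead.
    return x + dt * (A @ x)


def simulate(x0, A, dt, n_steps):
    # lax.scan traces the loop body once and lets XLA compile the whole
    # time loop, instead of running n_steps separate Python iterations.
    def body(x, _):
        x_next = step(x, A, dt)
        return x_next, x_next  # (new carry, value stacked into trajectory)

    final, trajectory = jax.lax.scan(body, x0, None, length=n_steps)
    return final, trajectory


@partial(jax.jit, static_argnames="n_steps")
def simulate_batch(x0_batch, A, dt, n_steps):
    # vmap maps simulate over the leading axis of x0_batch, so a batch of
    # independent runs (different initial conditions) becomes one
    # vectorized XLA computation. A and dt are shared by every run.
    run_one = lambda x0: simulate(x0, A, dt, n_steps)
    return jax.vmap(run_one)(x0_batch)


A = -jnp.eye(3)                            # toy system matrix (pure decay)
x0_batch = jnp.arange(12.0).reshape(4, 3)  # 4 runs, 3 state variables each
final, trajectory = simulate_batch(x0_batch, A, 0.1, n_steps=10)
print(final.shape, trajectory.shape)       # (4, 3) (4, 10, 3)
```

`n_steps` is marked static because it fixes the loop length, so each distinct value triggers a recompilation; `dt` and `A` stay traced and can change freely between calls.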

George

Frederic Bastien CA

Sep 13, 2024, 11:37:11 AM
to OpenXLA Discuss, George Karpenkov, Maksym Riabov
Hi,

First, write it in JAX as George suggested. If profiling shows that it isn't fast enough, you can write custom kernels in a few different ways:

- For block-based kernels (e.g. matmul/conv), look at Pallas or Triton.
- For thread-based computations, look at the Warp/JAX interop. It was useful in another simulator project.
- If you already have a CUDA kernel, look at the JAX custom-primitive documentation.
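As a rough illustration of the Pallas route, here is a minimal sketch of a fused element-wise update kernel, assuming a recent JAX with `jax.experimental.pallas` available. The kernel and function names are hypothetical; the pattern (a kernel reading and writing through refs, wrapped with `pl.pallas_call` and an explicit `out_shape`) follows the Pallas quickstart. `interpret=True` runs it through the Pallas interpreter so the sketch also works on CPU.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def euler_update_kernel(x_ref, v_ref, o_ref, *, dt):
    # Refs point at blocks of the inputs/output; with no grid specified,
    # the whole array is a single block. Read, fuse, and write back.
    o_ref[...] = x_ref[...] + dt * v_ref[...]


@partial(jax.jit, static_argnames="dt")
def euler_update(x, v, dt):
    return pl.pallas_call(
        partial(euler_update_kernel, dt=dt),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # interpret=True uses the Pallas interpreter (works on CPU); drop it
        # to compile for a supported GPU/TPU backend, which may impose
        # block-shape constraints.
        interpret=True,
    )(x, v)


x = jnp.arange(8.0)
v = jnp.ones(8)
print(euler_update(x, v, dt=0.5))
```

For a real simulator one would normally add a `grid` and `BlockSpec`s so that large arrays are split into tiles; this sketch keeps everything in one block for brevity.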

Jan Pfeifer

Sep 14, 2024, 11:34:21 AM
to Frederic Bastien CA, OpenXLA Discuss, George Karpenkov, Maksym Riabov
Let me offer yet another approach: the XlaBuilder library (see the description of the operations here) lets you compose computations in C++. It is more bare-bones than JAX, but requires no Python. It produces a protobuf describing your computation, which you can then feed to any PJRT plugin (CPU, GPU, etc.) to JIT-compile. Presumably it also works with closed-source PJRT plugins (TPU, and Intel's, I believe). See the pjrt_c_api.h file for the PJRT plugin C API (a plugin is a .so file on Linux), though I think there is a simpler C++ binding/wrapper in the same directory.
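For comparison, JAX follows a broadly similar path: a traced function is lowered to a StableHLO module, which is then compiled by the backend. The sketch below (in Python rather than the C++ XlaBuilder API, which is not shown here) makes the two stages visible with `jax.jit(...).lower(...)`, assuming a recent JAX; the function `f` is a hypothetical toy computation. Printing the module text is also a quick way to check which HLO operations your own code turns into.

```python
import jax
import jax.numpy as jnp


def f(A, x):
    # Hypothetical toy computation: a matrix-vector product followed by tanh.
    return jnp.tanh(A @ x)


A = jnp.ones((4, 4))
x = jnp.ones(4)

lowered = jax.jit(f).lower(A, x)  # trace and lower, but do not compile yet
module_text = lowered.as_text()   # StableHLO (MLIR) text of the computation
print(module_text)

compiled = lowered.compile()      # hand the module to the backend compiler
print(compiled(A, x))
```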

In my experience this works really nicely; I've built a lot on top of it, in particular a Go binding to both XlaBuilder and PJRT (see github.com/gomlx/gopjrt), which you may want to look at for its Bazel BUILD/WORKSPACE files. That said, there are no tutorials and the documentation is sometimes not detailed enough, so every now and then I still need to guess details, look at how JAX does things, or ask here.

cheers!
