Let me offer yet another approach: there is the XlaBuilder library (see the description of its operations in the XLA "Operation Semantics" documentation), which you can use to compose computations in C++. It is a more bare-bones approach than JAX, but it requires no Python. It produces a protobuf describing your computation, which you can then feed to any PJRT plugin (CPU, GPU, etc.) to JIT-compile. Presumably it also works with the closed-source PJRT plugins (TPU and, I believe, Intel's). See the xla/pjrt/c/pjrt_c_api.h file for the C API of PJRT plugins (.so files on Linux); I think there is also a simpler C++ binding/wrapper in that directory.
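On the plugin side, a PJRT plugin is a shared library that exports a C function named GetPjrtApi, which returns a pointer to the PJRT_Api struct of function pointers declared in pjrt_c_api.h. Below is a minimal sketch of loading one with dlopen (Linux/POSIX). To keep it self-contained it treats that struct as opaque instead of including pjrt_c_api.h; the LoadPjrtApi name and any plugin path you pass are hypothetical.

```cpp
#include <dlfcn.h>

#include <cstdio>

// Signature of the entry point every PJRT plugin exports. The real return
// type is `const PJRT_Api*` from pjrt_c_api.h; it is opaque here so the
// sketch compiles without XLA headers.
using GetPjrtApiFn = const void* (*)();

// Hypothetical helper: loads a plugin .so and returns its API table,
// or nullptr on failure.
const void* LoadPjrtApi(const char* plugin_path) {
  void* handle = dlopen(plugin_path, RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    const char* err = dlerror();
    std::fprintf(stderr, "dlopen failed: %s\n", err ? err : "unknown error");
    return nullptr;
  }
  dlerror();  // clear any stale error before calling dlsym
  void* symbol = dlsym(handle, "GetPjrtApi");
  if (symbol == nullptr) {
    const char* err = dlerror();
    std::fprintf(stderr, "GetPjrtApi not found: %s\n", err ? err : "null symbol");
    dlclose(handle);
    return nullptr;
  }
  auto get_api = reinterpret_cast<GetPjrtApiFn>(symbol);
  // The handle is intentionally left open: the returned table points into
  // the plugin's memory and must outlive this call.
  return get_api();
}
```

In real code you would include pjrt_c_api.h, use `const PJRT_Api*` as the return type, and then call through its function pointers (for example PJRT_Client_Create) to create a client and compile the protobuf produced by XlaBuilder.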
In my experience this works really well; I've built a lot on top of it, including a Go binding to both XlaBuilder and PJRT (github.com/gomlx/gopjrt), whose Bazel BUILD/WORKSPACE files you may want to look at. That said, there are no tutorials, and the documentation is sometimes not detailed enough, so every now and then I still have to guess at details, look at how JAX does things, or ask here.
cheers!