Compiling TensorFlow/JAX with custom XLA optimisation passes on TPU

art....@gmail.com

Feb 8, 2022, 10:21:27 AM
to XLA development
Hello XLA devs,

I would like to compile TensorFlow/JAX with custom XLA optimisations and test the code on a TPU. Adding new XLA optimisations and their options was a pretty straightforward process for CPU and GPU. However, it is not clear how to get TPU support for these optimisations. I asked a similar question in https://github.com/google/jax/issues/9261, and as far as I understand, everything is hidden in the libtpu library, which is not open source, meaning I will not be able to test the optimisations on a TPU.

Could someone explain how to add TPU support to customised XLA code and compile it on machines with TPUs?

Best regards,
Artem Artemev

Peter Hawkins

Feb 8, 2022, 10:28:10 AM
to art....@gmail.com, XLA development
In general, you cannot at the moment. The TPU compiler is a closed-source black box, although most of the front-end HLO passes are the same ones as in open-source XLA. You could add additional compiler passes that rewrite the HLO *before* the HLO is passed to the TPU compiler, but everything after that point is closed source and not user-accessible. For example, JAX runs a few MHLO lowering passes before converting MHLO to HLO and compiling it with, e.g., the TPU compiler.

I'm sure we'd be interested to hear more about the use case, though. Is it something that could be done with an HLO->HLO rewrite, ideally early in the compilation flow?

Peter
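Peter's suggestion, an HLO->HLO rewrite applied before the module reaches the closed-source TPU compiler, can be illustrated with a toy model. The Python sketch below is hypothetical: `Instr`, `Module` and `simplify_add_zero` are made-up names, not XLA's `HloInstruction`/`HloModule` API, and a real pass would be written in C++ against XLA. It only shows the shape of such a pass: walk the instructions in topological order, replace `add(x, 0)` with `x`, rewire the users, and report whether the module changed (XLA passes likewise return a flag saying whether they modified the module).

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Instr:
    """Hypothetical toy stand-in for an HLO instruction (not XLA's HloInstruction)."""
    name: str
    opcode: str                                        # e.g. "parameter", "constant", "add"
    operands: List[str] = field(default_factory=list)  # names of operand instructions
    value: Optional[float] = None                      # only meaningful for "constant"


@dataclass
class Module:
    """Hypothetical toy module: instructions in topological order plus a root."""
    instrs: Dict[str, Instr]
    root: str


def simplify_add_zero(module: Module) -> bool:
    """Toy HLO->HLO rewrite: replace add(x, 0) and add(0, x) with x.

    Returns True if the module was modified.
    """
    def is_zero(name: str) -> bool:
        ins = module.instrs[name]
        return ins.opcode == "constant" and ins.value == 0

    replacement: Dict[str, str] = {}  # removed instruction -> value that replaces it
    for ins in module.instrs.values():
        # Rewire operands through earlier replacements (topological order
        # guarantees every operand was visited before its users).
        ins.operands = [replacement.get(op, op) for op in ins.operands]
        if ins.opcode == "add":
            lhs, rhs = ins.operands
            if is_zero(rhs):
                replacement[ins.name] = lhs
            elif is_zero(lhs):
                replacement[ins.name] = rhs

    if not replacement:
        return False
    module.root = replacement.get(module.root, module.root)
    for name in replacement:
        del module.instrs[name]  # the zero constant itself is left for a DCE pass
    return True


# Usage: out = (p0 + 0) * p0  becomes  out = p0 * p0
m = Module(
    instrs={
        "p0": Instr("p0", "parameter"),
        "zero": Instr("zero", "constant", value=0.0),
        "sum": Instr("sum", "add", ["p0", "zero"]),
        "out": Instr("out", "multiply", ["sum", "p0"]),
    },
    root="out",
)
changed = simplify_add_zero(m)
print(changed, m.instrs["out"].operands)  # True ['p0', 'p0']
```

Dead constants such as `zero` are left behind on purpose; in this toy setting a separate dead-code-elimination pass would remove them.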


art....@gmail.com

Feb 9, 2022, 11:27:57 AM
to XLA development
Interesting. The optimisations I introduced can be applied before the standard list of XLA passes, but I still have to apply them within the TF code. Could you advise on how to proceed? Should I create an HloPassPipeline like here [1] and add it before the TPU compiler runs its own HLO transformations here [2]?

[1] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tpu/tpu_on_demand_compiler.cc#L54

Best,
Artem

art....@gmail.com

Feb 9, 2022, 11:34:59 AM
to XLA development
[Wrong references in previous email; the order is wrong]

Interesting. The optimisations I introduced can be applied before the standard list of XLA passes, but I still have to apply them within the TF code. Could you advise on how to proceed? Should I create an HloPassPipeline like here [1] and add it before the TPU compiler runs its own HLO transformations here [2]?
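Artem's question is essentially about composition: collect the custom passes into a pipeline and run it on the module before the TPU compiler's own transformations. The Python sketch below models only that control flow, under stated assumptions: `PrePassPipeline`, `remove_copies`, `fake_backend` and `compile_with_custom_passes` are hypothetical names (not XLA or TensorFlow APIs), the "module" is a toy list of opcode strings, and `backend_compile` is an opaque callable standing in for the closed-source TPU compiler. It does not say where such a hook would live inside TensorFlow's TPU path, which is the open question in this thread.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical convention for this toy: a pass mutates the module in place and
# returns True if it changed anything.
PassFn = Callable[[Any], bool]


class PrePassPipeline:
    """Hypothetical sketch, not XLA's HloPassPipeline: runs user passes in order."""

    def __init__(self, name: str, max_iterations: int = 1):
        self.name = name
        self.max_iterations = max_iterations  # >1 reruns all passes until nothing changes
        self.passes: List[Tuple[str, PassFn]] = []

    def add_pass(self, name: str, fn: PassFn) -> "PrePassPipeline":
        self.passes.append((name, fn))
        return self

    def run(self, module: Any) -> bool:
        changed_any = False
        for _ in range(self.max_iterations):
            changed = False
            for _pass_name, fn in self.passes:
                changed |= fn(module)
            changed_any |= changed
            if not changed:
                break
        return changed_any


def compile_with_custom_passes(module: Any, pipeline: PrePassPipeline,
                               backend_compile: Callable[[Any], Any]) -> Any:
    pipeline.run(module)            # open part: user-controlled HLO->HLO rewrites
    return backend_compile(module)  # closed part: only its input can be influenced


def remove_copies(ops: List[str]) -> bool:
    """Toy pass on a toy module (a list of opcode strings): drop "copy" ops."""
    before = len(ops)
    ops[:] = [op for op in ops if op != "copy"]
    return len(ops) != before


def fake_backend(ops: List[str]) -> str:
    """Stands in for the closed-source TPU compiler in this sketch."""
    return "compiled[" + ",".join(ops) + "]"


ops = ["parameter", "copy", "add", "copy", "multiply"]
pipeline = PrePassPipeline("custom-pre-passes").add_pass("remove-copies", remove_copies)
print(compile_with_custom_passes(ops, pipeline, fake_backend))
# compiled[parameter,add,multiply]
```

The optional repeat-until-unchanged loop (`max_iterations`) is similar in spirit to XLA's `HloPassFix` wrapper, which reruns a pass until it stops changing the module; in an actual XLA build the equivalent would be a C++ `HloPassPipeline`.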
