XLA device and autoclustering


Moshe Maor

Jul 8, 2021, 6:22:30 AM
to XLA development
Hi all, 

It seems, per the code, that the XLA_* devices are being deprecated and that the proper way to run through the XLA compiler is by triggering autoclustering. A couple of questions about that:

1. How does mark-for-compilation decide which XLA JIT device to use to lower and compile a cluster, if the XLA JIT devices are registered against the XLA_* devices during device registration?
2. It seems like DeviceInfoCache::GetIdFor has a hacked-together way of re-assigning already-registered JIT devices to existing CPU/GPU devices. Is that related?
3. If yes, how can that be supported for other XLA backends?
4. Given the deprecation of the XLA_* devices, what is the flow for eager execution via an XLA device?

Thanks, 
Moshe 

George Karpenkov

Jul 21, 2021, 6:46:21 PM
to Moshe Maor, XLA development
Hi Moshe,

Some of these questions sound like an XY problem. Could you describe what you would like to achieve, and we might be able to help?

In particular, questions (1) and (4) seem to make incorrect assumptions.

George

Moshe Maor

Jul 22, 2021, 2:19:46 AM
to XLA development
Hi George, 

My end goal is to decide whether to implement an XLA backend for an accelerator rather than a 'native' TF integration. To decide that, I'm trying to better understand XLA 'under the hood' from various aspects. One of these aspects is the XLA-device vs. the XLA-compilation-device: it seems an implementation can take one of two directions, either implementing/registering an XLA device along with an XLA backend (like TPU), or registering a 'regular' device (like CPU or GPU) along with an XLA backend for that device and then 'somehow' telling XLA that the device has a JIT compilation backend it can use as required. If I understand correctly, the latter would allow hybrid execution, where clustered regions of the graph go via the XLA path while unsupported ops take the native TF executor path on either that accelerator or the CPU.

Thanks, 
Moshe 

George Karpenkov

Aug 1, 2021, 12:53:02 PM
to Moshe Maor, XLA development
Hi Moshe,

On Wed, Jul 21, 2021 at 11:19 PM Moshe Maor <moshe...@gmail.com> wrote:

My end goal is to decide whether to implement an XLA backend for an accelerator rather than a 'native' TF integration.

Great!
 
To decide that, I'm trying to better understand XLA 'under the hood' from various aspects. One of these aspects is the XLA-device vs. the XLA-compilation-device: it seems an implementation can take one of two directions, either implementing/registering an XLA device along with an XLA backend (like TPU), or registering a 'regular' device (like CPU or GPU) along with an XLA backend for that device and then 'somehow' telling XLA that the device has a JIT compilation backend it can use as required.

No.

The XLA-compilation-device is an implementation detail and is not something you need to implement per device. The "XLA-device", by contrast, is something that does op-by-op translation and execution; it is only supported on TPUs and is not something you would need to build or change. For your backend, you should only need to add a single line registering an xla-compilation-device.

All you need to do for the new accelerator is to implement a new backend in compiler/xla and add registrations for it.
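To sketch what "a new backend in compiler/xla plus registrations" amounts to, here is a minimal, hypothetical example of the load-time registration pattern the existing backends follow. Every "My*" name below is a placeholder I made up, and the class body is elided:

#include <memory>

#include "tensorflow/compiler/xla/service/compiler.h"

namespace se = stream_executor;

// Hypothetical placeholders: a StreamExecutor platform id for the new
// accelerator, and a compiler implementing xla::Compiler's pure-virtual
// methods (RunHloPasses, RunBackend, Compile, ...).
extern const se::Platform::Id kMyPlatformId;
class MyCompiler : public xla::Compiler {
  // ... accelerator codegen goes here ...
};

// Register a factory at load time so XLA can build a compiler for
// computations placed on this platform. A TransferManager is registered
// analogously via xla::TransferManager::RegisterTransferManager.
static bool InitMyXlaBackend() {
  xla::Compiler::RegisterCompilerFactory(
      kMyPlatformId, [] { return std::make_unique<MyCompiler>(); });
  return true;
}
static bool my_backend_registered = InitMyXlaBackend();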
 
If I understand correctly, the latter would allow hybrid execution, where clustered regions of the graph go via the XLA path while unsupported ops take the native TF executor path on either that accelerator or the CPU.

Yes. There are multiple interfaces used to connect TF and XLA, and all of them can be used for that purpose: either using jit_compile=True to "manually cluster" compiled ops, using autoclustering to let a compiler pass do it for you, or even doing op-by-op execution (though that's not recommended).
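For completeness, a minimal sketch of how the autoclustering route is switched on from C++; the session-level knob shown here is the same one the TF_XLA_FLAGS=--tf_xla_auto_jit environment variable controls, and graph construction is elided:

#include <memory>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

int main() {
  // Enable global autoclustering: the mark-for-compilation pass will carve
  // compilable subgraphs into clusters and hand them to XLA.
  tensorflow::SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(options));
  // ... create and run a graph as usual; clustered regions go through XLA,
  // unsupported ops fall back to the normal TF executor.
  return 0;
}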
 

Moshe Maor

Aug 2, 2021, 7:43:45 AM
to XLA development
Hi George, 

A couple of follow-up questions:

1. Regarding the registration for a device "A" that I already have a 'native' integration for (i.e., I have a platform/device/StreamExecutor for device A): do I only need to call XlaOpRegistry::RegisterCompilationDevice("A", registration) with an A-specific DeviceRegistration definition?
2. Am I losing anything by not implementing a 'real' XLA device (like the XLA_GPU and XLA_CPU devices that were originally implemented) and instead hooking an XLA backend to a 'regular' device?
3. When using this flow, do I have to explicitly register the special XLA kernels for this "A" device? I see CPU/GPU registering these: _XlaCompile/_XlaRun/_XlaMerge. Are there any others required?

Thanks, 
Moshe 

George Karpenkov

Aug 3, 2021, 1:22:20 PM
to Moshe Maor, XLA development
On Mon, Aug 2, 2021 at 4:43 AM Moshe Maor <moshe...@gmail.com> wrote:
Hi George, 

A couple of follow-up questions:

1. Regarding the registration for a device "A" that I already have a 'native' integration for (i.e., I have a platform/device/StreamExecutor for device A): do I only need to call XlaOpRegistry::RegisterCompilationDevice("A", registration) with an A-specific DeviceRegistration definition?

Yes, ideally just replicating this section https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/jit/xla_gpu_device.cc;l=80-93?q=file:xla_gpu_device should be enough to give you tf2xla conversion.
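For readers of the archive, the linked section boils down to roughly the following (paraphrased rather than verbatim, with "A" and DEVICE_A_XLA_JIT as hypothetical names for the new device; the full set of DeviceRegistration knobs lives in tensorflow/compiler/tf2xla/xla_op_registry.h):

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {

// Hypothetical JIT-device name; the GPU equivalent is DEVICE_GPU_XLA_JIT.
constexpr char DEVICE_A_XLA_JIT[] = "XLA_A_JIT";

static bool RegisterAForCompilation() {
  XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = DEVICE_A_XLA_JIT;
  // Cluster only when autoclustering is enabled globally, e.g. via
  // TF_XLA_FLAGS=--tf_xla_auto_jit=2.
  registration.autoclustering_policy =
      XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
  // The real GPU block also sets clustering-policy knobs such as
  // cluster_stateful_rng_ops and cluster_variant_ops.
  XlaOpRegistry::RegisterCompilationDevice("A", registration);
  return true;
}
static bool a_jit_registered = RegisterAForCompilation();

}  // namespace tensorflow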
 
2. Am I losing anything by not implementing a 'real' XLA device (like the XLA_GPU and XLA_CPU devices that were originally implemented) and instead hooking an XLA backend to a 'regular' device?

No, those are deprecated and disabled by default; you have to pass a special flag to enable them.
 
3. When using this flow, do I have to explicitly register the special XLA kernels for this "A" device? I see CPU/GPU registering these: _XlaCompile/_XlaRun/_XlaMerge. Are there any others required?

Those are device-agnostic and should be registered for you. With tf.function(jit_compile=True), only a single XlaLocalLaunch kernel is currently required.
 

Moshe Maor

Aug 3, 2021, 2:51:02 PM
to George Karpenkov, XLA development
Thanks for the clarifications, George.

Re: the third item, registering the 'special' XLA ops: what entity is registering those for me? I do see explicit registration for compile/run (which are required for the 'main' clustering flow) for CPU/GPU/TPU devices, e.g. here:
and also another set of XLA* op registrations for both CPU and GPU under:

Another question that bothers me is whether the entire device and backend objects can be implemented in a loadable plugin.
It seems that there are several locations that require changes inside TF itself, such as the platform-kind enumeration (enum + string + some functions around it).
Is that on purpose? I don't see any way around changing this to add a new platform type, although the rest of the registrations seem less hardcoded.

Thanks, 
Moshe 

George Karpenkov

Aug 3, 2021, 7:04:55 PM
to Moshe Maor, XLA development
Hi Moshe,

>  I do see explicit registration for compile/run (which are required for the 'main' clustering flow) for CPU/GPU/TPU devices, e.g. here:
Sorry, you are right, they still have to be registered per-device.
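For a hypothetical device "A", that per-device registration is a handful of REGISTER_KERNEL_BUILDER lines mirroring what compiler/jit/kernels/xla_ops.cc does for CPU/GPU (DEVICE_A is a made-up constant here, and constraints such as HostMemory on the GPU variants are omitted for brevity):

#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical device-name constant for the new accelerator.
constexpr char DEVICE_A[] = "A";

// Used by tf.function(jit_compile=True) and manual clustering.
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_A), XlaLocalLaunchOp);

// Used by the autoclustering flow: _XlaCompile produces an executable,
// _XlaRun executes it, and _XlaMerge joins the compiled and fallback paths.
REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_A), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_A), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_A), XlaMergeOp);

}  // namespace tensorflow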
 
> Another question that bothers me is whether the entire device and backend objects can be implemented in a loadable plugin.
That I have no idea about; you would have to ask on the TensorFlow forums.
I think there were attempts to build backends for custom platforms before. It might also be possible to shift the enumeration to a registration-based approach if you can point out the concrete spots that need it.