StatefulPartitionedCall on an XLA device


Moshe Maor

Nov 7, 2021, 1:31:55 PM
to XLA development

Hi, 

We have a backend implementation that is based on an XLA device. At the moment I am getting an error caused by bad data being sent from host to device after the StatefulPartitionedCall completes on the device.

Interestingly, I see that StatefulPartitionedCall is actually executed on the device, and not only that, it is running the entire XLA compilation path! (which doesn't make sense, as this should have been a fallback flow). Does that have anything to do with the fact that I am registering my device as an XLA device, so XlaCompileOnDemandOp kicks in to handle it?...

In addition, the results of the computation remain on the device, but somehow TF thinks they are located on the CPU and issues a CPU-to-device transfer (which might have made sense for a non-XLA execution of the StatefulPartitionedCall path, like what I see for the same script when running on GPU with XLA autoclustering enabled).

Any thoughts? 

Thanks, 

Moshe 

George Karpenkov

Nov 11, 2021, 10:49:10 PM
to Moshe Maor, XLA development
Hi Moshe,

I think we discussed this a while ago already: there should almost never be a reason to implement an XLA device or to use XlaCompileOnDemandOp.
Why not implement an XLA_JIT device and use jit_compile=True instead?
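For reference, the kind of usage I mean (a minimal, untested sketch using the public jit_compile flag on tf.function):

    import tensorflow as tf

    # jit_compile=True means "must compile": the whole function is lowered to a
    # single XLA computation, and it is an error if something inside cannot be
    # compiled (no silent fallback).
    @tf.function(jit_compile=True)
    def matmul_add(a, b, c):
      return tf.matmul(a, b) + c

    x = tf.random.normal([128, 128])
    print(matmul_add(x, x, x))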

George


Moshe Maor

Nov 17, 2021, 6:26:05 AM
to XLA development
Hi George, 

Let me see if I understand the direction you're suggesting; correct me if I'm wrong:

- Start with a regular StreamExecutor-based device 'D' (and not an XLA device).
- Implement a JIT device and register it for device D.
- Register the XlaLaunch kernel for device D (no need for the compile/run kernels if we're not using autoclustering, right?).
- Should we also implement other kernels? And if not, what is expected to happen to content that is not supported by XLA, like regular tf.function and eager execution?
- Whenever possible, use jit_compile=True to force XLA compilation for that device.
- But how can we force functions to use the jit_compile=True path if we don't write them explicitly in the script, e.g. the functions that Keras generates for compile, fit, etc.?

Thanks, 
Moshe 

George Karpenkov

Nov 22, 2021, 5:13:42 PM
to Moshe Maor, XLA development
Hi Moshe,

I think I'm missing a lot of context on what you would like to do --- can you send a general "introductions" email with your plan and current progress? Answering questions would then be a lot easier.

In general, as far as I understand, you are writing an XLA backend for an accelerator.
In that case, writing the backend should be 99% of the work, and once it is done, you can use whatever suits your needs to convert TF code to XLA.
E.g. registering an XLA device is a single line of code, and if it is convenient for you, you should be able to do it.

George

Moshe Maor

Nov 23, 2021, 1:10:58 AM
to XLA development
Hi George, 

Let me give a bit of context for the questions I have:

We currently have a TF backend for our accelerator which is not based on XLA. We integrate with TF on top of a device that is not a standard SE (StreamExecutor) device, i.e. a proprietary integration with TF (this is because, until some time ago, one couldn't register an SE device and platform from a loadable plugin...).
We want to implement an XLA-based backend in parallel to the current backend; later, we would like to completely replace the current backend with the XLA backend. We already have a working PoC of an XLA backend on top of an XLA device for our accelerator.

Per my understanding, the long-term direction from Google is:
[a] Don't instantiate an XlaDevice; prefer a standard SE device, as you're deprecating the currently implemented XLA_* devices for CPU/GPU. Is that correct?
     This is a bit confusing, as I guess TPU will remain an XlaDevice and will have only an XLA backend?
[b] Refrain from using autoclustering and prefer must-compile-based XLA activation. I saw comments along these lines from you and Sanjoy. I guess that means the direction is to explicitly annotate every function in the script as must-compile.

1. If [a] is correct, then I guess the direction is to implement a simple SE device and then register the XLA backend to it (like the registration you're doing here: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/tf2xla/xla_op_registry.cc;l=156-170). But then, how do we support eager and tf.function content on a regular SE device? The only way to do so seems to be the RegisterXlaDeviceKernels() method, which enables compiling and executing ops coming from eager mode and non-compiled functions. Is that the intention?

2. Per [b], how do I force all function content in a script to be must-compile? Unlike autoclustering, I don't see a global way of forcing that. Taking into account that a lot of function content is generated automatically (e.g. when the script uses Keras to build a training graph), it seems like this approach means that on many scripts we would not be able to enjoy XLA clustering. I would expect either a global way to force it and/or a way to turn must-compile on/off for regions of the script, e.g. tf.config.optimizer.set_jit('mustcompile') and tf.config.optimizer.set_jit('off'), that would allow users to force regions of the script onto the XLA path (the options that exist today are sketched after question 3 below).

3. Can we add XLA ops (from TF) via a plugin? It seems like mark-for-compilation has a static list of TF ops allowed for clustering. What's the direction on that?
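(Regarding question 2, for reference, these are the only mechanisms I'm aware of today; a rough sketch, in case I'm missing something:)

    import tensorflow as tf

    # Global switch, but it only enables autoclustering (best-effort clustering of
    # XLA-compatible ops inside tf.function graphs); it is not a must-compile mode.
    # The same can be done process-wide with TF_XLA_FLAGS=--tf_xla_auto_jit=2.
    tf.config.optimizer.set_jit(True)

    # Must-compile exists only as a per-function annotation, as far as I can tell.
    @tf.function(jit_compile=True)
    def step(x, w):
      return tf.nn.relu(tf.matmul(x, w))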

Thanks, 
Moshe 

George Karpenkov

Nov 23, 2021, 7:48:19 PM
to Moshe Maor, XLA development
For a new device, in your case it probably does make sense to define an XLA device (to guarantee that everything happens on that device), especially if you want to replace the "classic" TF backend with an XLA backend.

Alternatively, all entry points could be wrapped with `jit_compile=True` (similar to how one has to use tpu.rewrite to run code on TPUs).
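For example (a rough sketch; train_step is just a stand-in for whatever your entry points are):

    import tensorflow as tf

    def train_step(x, y):
      # existing, undecorated entry point
      return tf.reduce_mean(tf.square(x - y))

    # Wrap the entry point at the call site instead of changing its definition.
    compiled_step = tf.function(train_step, jit_compile=True)
    print(compiled_step(tf.ones([4, 4]), tf.zeros([4, 4])))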

> 3. Can we add XLA ops (from TF) via a plugin? It seems like mark-for-compilation has a static list of TF ops allowed for clustering. What's the direction on that?

I honestly don't know, sorry --- could you try and see whether it works? You could also try asking at TF SIG build (https://github.com/tensorflow/build).

Moshe Maor

Nov 24, 2021, 3:11:33 AM
to XLA development
Thanks George, 

I understand that defining an XLA device makes sense, to enable XLA-only content to run on the device and to get eager and tf.function content executed via the XlaCompileOnDemand kernel.
However, this would not lower a complete tf.function to a single XLA computation if it is not annotated with jit_compile=True, right?
So I am still looking for the best method to force jit_compile=True on all functions automatically. Is that something that exists today? Do you have it for TPU?
E.g., how do you force XLA compilation for the functions that are automatically generated by Keras model.compile and model.fit? After all, the script may use only the Keras high-level API and not define any function explicitly.
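(The closest thing I've found is the per-model flag on compile, assuming a TF/Keras version that exposes it, but that is still an explicit, per-model opt-in rather than a global switch:)

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    # jit_compile here asks Keras to build its generated train/predict functions
    # with jit_compile=True -- but only for this model, and only if set explicitly.
    model.compile(optimizer="sgd", loss="mse", jit_compile=True)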

If there is no existing mechanism, does it make sense to propose a PR that automatically applies the jit_compile=True annotation to any generated function in regions between tf.config.optimizer.set_jit('mustcompile') and tf.config.optimizer.set_jit('off')?

Thanks, 
Moshe 

George Karpenkov

Nov 29, 2021, 12:13:04 PM
to Moshe Maor, XLA development
Hi Moshe,

On Wed, Nov 24, 2021 at 12:11 AM Moshe Maor <moshe...@gmail.com> wrote:
> Thanks George,
>
> I understand that defining an XLA device makes sense, to enable XLA-only content to run on the device and to get eager and tf.function content executed via the XlaCompileOnDemand kernel.
> However, this would not lower a complete tf.function to a single XLA computation if it is not annotated with jit_compile=True, right?
> So I am still looking for the best method to force jit_compile=True on all functions automatically. Is that something that exists today? Do you have it for TPU?
> E.g., how do you force XLA compilation for the functions that are automatically generated by Keras model.compile and model.fit? After all, the script may use only the Keras high-level API and not define any function explicitly.
>
> If there is no existing mechanism, does it make sense to propose a PR that automatically applies the jit_compile=True annotation to any generated function in regions between tf.config.optimizer.set_jit('mustcompile') and tf.config.optimizer.set_jit('off')?


We can generalize the check above to allow for custom user-defined XLA devices as well.

George
 

Moshe Maor

Nov 29, 2021, 12:32:54 PM
to XLA development
Hi George, 

Indeed, it makes sense to generalize this approach and implicitly force must-compile for all XLA devices, as this is the main reason for such a device.

But on top of that, what do you think about the idea of setting must-compile on/off in regions of the script with a tf.config.optimizer.set_jit(...) directive, similar to autoclustering? 
(Such a directive should override the logic in the code you shared)

That way the user has full control over the behavior for both XLA devices and any other device that happens to have an XLA backend registered to it.
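To make the proposal concrete, the usage I have in mind would look roughly like this (purely hypothetical; neither the 'mustcompile' nor the 'off' value exists today):

    import tensorflow as tf
    import numpy as np

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

    # Hypothetical proposal: everything built/traced in this region, including the
    # functions Keras generates for fit(), would get jit_compile=True.
    tf.config.optimizer.set_jit('mustcompile')
    model.fit(np.ones((8, 4)), np.ones((8, 1)), epochs=1, verbose=0)
    tf.config.optimizer.set_jit('off')  # hypothetical: back to the default behavior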

Thanks, 
Moshe 

George Karpenkov

Nov 29, 2021, 12:47:17 PM
to Moshe Maor, XLA development
> But on top of that, what do you think about the idea of setting must-compile on/off in regions of the script with a tf.config.optimizer.set_jit(...) directive, similar to autoclustering? 

No, I don't think that makes sense. Global contexts which alter local behavior in non-obvious ways are bad; we should have fewer of them.

Moshe Maor

Nov 29, 2021, 1:09:39 PM
to XLA development
> No, I don't think that makes sense. Global contexts which alter local behavior in non-obvious ways are bad; we should have fewer of them.

But that would be aligned with how autoclustering is activated via tf.config.optimizer.set_jit for a hybrid device that has both an XLA backend and a native backend (like the native CPU/GPU devices that support XLA compilation via autoclustering).
Do you prefer a different approach for such control over must-compile on hybrid devices?

George Karpenkov

Nov 29, 2021, 1:42:42 PM
to Moshe Maor, XLA development
> But that would be aligned with how autoclustering is activated via tf.config.optimizer.set_jit for a hybrid device that has both an XLA backend and a native backend (like the native CPU/GPU devices that support XLA compilation via autoclustering).

Yes, and we think that tf.config.optimizer.set_jit is an incredibly confusing API which probably should not be used (e.g. it's cached in very non-obvious ways, it's unclear what its scope is, etc.).

> Do you prefer a different approach for such control over must-compile on hybrid devices?

Yes, does it make sense to define an XLA device for your backend, and then force compilation of all functions when run on that device?

Moshe Maor

Nov 29, 2021, 2:11:50 PM
to XLA development
>  Yes, does it make sense to define an XLA device for your backend, and then force compilation of all functions when run on that device?

For the long term, this definitely sounds like the proper direction, and for that we'll need a change in TF that auto-applies must-compile on all XLA devices.

However, I'm also looking for a mid-term solution while we're bringing up a complete XLA backend: a regular device that runs XLA compilation for some content and falls back to the regular (currently implemented) native implementation for everything else. This mid-term implementation should behave like the current CPU or GPU device, which has a native backend but can also execute XLA-compiled computations. For such a device with both native and XLA backends, we need a good method of switching between and controlling native vs. XLA execution.
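To illustrate what I mean by switching: on CPU/GPU today the two paths can be mixed per function, roughly like this (sketch); the question is whether that is a good enough control mechanism for such a hybrid device:

    import tensorflow as tf

    x = tf.random.normal([64, 64])
    w = tf.random.normal([64, 64])

    # Native path: goes through the device's regular (non-XLA) kernels.
    y_native = tf.matmul(x, w)

    # XLA path: only this function is compiled and executed via XLA.
    @tf.function(jit_compile=True)
    def xla_matmul(a, b):
      return tf.matmul(a, b)

    y_xla = xla_matmul(x, w)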