Hi,
We have a backend implementation that registers itself as an XLA device. At the moment I am hitting an error where bad data is being sent from host to device after the StatefulPartitionedCall completes on the device.
Interestingly, I see that StatefulPartitionedCall is actually executed on the device, and not only that, it runs through the entire XLA compilation path! (which doesn't make sense, as this should have been a fallback flow). Does that have anything to do with the fact that I am registering my device as an XLA device, so that XlaCompileOnDemandOp kicks in to handle it?
In addition, the result of the computation remains on the device, but somehow TF thinks it is located on the CPU and issues a CPU-to-device transfer (which might have made sense for a non-XLA execution of the StatefulPartitionedCall path, similar to what I see for the same script when running on a GPU with XLA autoclustering enabled).
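For context, a minimal sketch of the kind of script I am testing with and the knobs I use to inspect placement (the device string "MY_XLA_DEV" is only a placeholder for whatever name my backend registers, and the dump path is arbitrary):

```python
import os
import tensorflow as tf

# Dump the HLO modules XLA actually compiles, to confirm the full
# compilation path is being taken (standard XLA flag; path is arbitrary).
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

# Log every op placement decision so the spurious CPU placement of the
# StatefulPartitionedCall result shows up in the output.
tf.debugging.set_log_device_placement(True)

@tf.function  # deliberately NOT annotated with jit_compile=True
def step(x):
    return tf.matmul(x, x) + 1.0

x = tf.random.normal([4, 4])

# "MY_XLA_DEV" is a placeholder for the custom XLA device my backend registers.
with tf.device("/device:MY_XLA_DEV:0"):
    y = step(x)       # shows up as a StatefulPartitionedCall on the device
    print(y.device)   # where TF believes the result lives
    z = step(y)       # reusing the result is where the bad transfer appears
```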
Any thoughts?
Thanks,
Moshe
Thanks George, I understand that defining an XLA device makes sense to let XLA-only content run on the device and to have eager and tf.function content execute via the XlaCompileOnDemand kernel. However, that would not lower a complete tf.function to a single XLA computation unless it is annotated with jit_compile=True, right? So I am still looking for the best way to force jit_compile=True on all functions automatically. Does such a mechanism exist today? Do you have one for the TPU?
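To be concrete, the per-function annotation I am referring to is the standard tf.function flag (a minimal example, nothing backend-specific):

```python
import tensorflow as tf

# Explicit opt-in: with jit_compile=True the whole function is lowered to a
# single XLA computation instead of going through the op-by-op fallback.
@tf.function(jit_compile=True)
def train_step(x, w):
    return tf.nn.relu(tf.matmul(x, w))

y = train_step(tf.random.normal([2, 3]), tf.random.normal([3, 4]))
```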
For example, how do you force XLA compilation for the functions that are automatically generated by Keras model.compile and model.fit? After all, a script may use only the high-level Keras API and never define a tf.function explicitly. If no such mechanism exists, does it make sense to propose a PR that automatically applies jit_compile=True to every generated function in regions between tf.config.optimizer.set_jit('mustcompile') and tf.config.optimizer.set_jit('off')?
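For reference, the closest existing knobs I am aware of are the Keras-level flag and process-wide autoclustering; as far as I can tell neither of them forces jit_compile=True onto every generated function, and the 'mustcompile' mode above is only my proposal, not an existing API. A sketch, assuming a recent TF 2.x:

```python
import tensorflow as tf

# Keras-level opt-in: recent TF releases let Model.compile request XLA for
# the train/eval/predict functions that Keras generates internally.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(optimizer="adam", loss="mse", jit_compile=True)

# Process-wide autoclustering: compatible ops are grouped into clusters and
# compiled with XLA, but this is best-effort, not a "must compile" mode.
tf.config.optimizer.set_jit("autoclustering")  # older TF: set_jit(True)
```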