Implement XLA device for only some ops?


Christian Convey

Oct 2, 2017, 9:13:45 PM
to XLA development
I'm taking my first stab at writing a custom XLA device, based on TF 1.3.  I'm trying to work incrementally, and so I'm trying to override just one method of xla::DfsHloVisitorWithDefault at a time.

In order to work incrementally, I'm trying to have TF/XLA only assign ops to my device when my device can truly handle them.  Otherwise some feature-complete device, such as (non-XLA) "cpu", should handle that node.

I only know of one way to indicate the subset of functionality that my device can support: selectively returning true or false from the callback function I supply to the "REGISTER_XLA_BACKEND" macro.  So my code looks roughly like the following:

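#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/util.h"

// Every HLO op without a dedicated Handle*() override falls through to
// DefaultAction(), which reports it as not yet implemented.
class MyEmitter : public xla::DfsHloVisitorWithDefault {
 public:
  xla::Status DefaultAction(xla::HloInstruction* hlo) override {
    return xla::Unimplemented("TBD");
  }
};

// Filter passed to REGISTER_XLA_BACKEND; it is called with TF KernelDefs.
static bool OpFilter(tensorflow::KernelDef* kdef) {
  if (kdef->op() == "BroadcastArgs" || kdef->op() == "BroadcastGradientArgs") {
    return false;  // don't place these ops on my device
  }
  return true;
}

REGISTER_XLA_BACKEND(MY_DEVICE_NAME_JIT, MY_SUPPORTED_SCALAR_TYPES, OpFilter);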


Assuming that I'm on the right track at all, my problem is that I'm unclear on the relationship between (the "KernelDef" objects supplied to "OpFilter") and (the particular virtual methods that will later be called on my "MyEmitter" instance).

For example, even with my "OpFilter" coded as shown above, I'm still finding that my "MyEmitter::HandleBroadcast(...)" method is being called.  Which is a problem, because I'm not yet ready to provide a meaningful implementation of that HLO op.

Any suggestions?

Bjarke Roune

Oct 2, 2017, 9:44:31 PM
to XLA development
If you just want to test each op as you write it, you can use the direct XLA test suite, which is what we usually use for work on XLA. It runs faster since it does not have TensorFlow in the loop and it is more thorough than the XLA-specific testing that goes through TensorFlow. Going direct to XLA for a test also makes it easier to hit corner cases, since you can make the XLA graph that you want directly instead of trying to get TensorFlow to more indirectly emit the XLA graph that you want.

You'll probably have an easier time debugging these single-op tests than debugging a model that's half-way running on your backend.  For the XLA broadcast op, we have


and


I'm not sure why we split this into two separate files, but combined these two files have a lot of tests for the XLA broadcast op. We have such tests for each individual XLA op. It still makes sense to also run the TensorFlow tests, though I usually only do that after getting the direct-to-XLA test suite passing. If you find any bug in XLA that is caught by the TensorFlow test suite but not the XLA one, feel free to contribute a test case for that.
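For single-op tests like these, a tiny golden model of the op being implemented can also help. The sketch below is a hypothetical, stand-alone reference for the classic XLA Broadcast semantics: the new dimensions are added on the left, so the output shape is broadcast_sizes followed by the operand's dimensions. `Array`, `ReferenceBroadcast`, and `SameArray` are illustrative names and are not part of the XLA test utilities.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Dense row-major array: data.size() == product(shape); shape {} is a scalar.
struct Array {
  std::vector<int64_t> shape;
  std::vector<float> data;
};

int64_t NumElements(const std::vector<int64_t>& shape) {
  int64_t n = 1;
  for (int64_t d : shape) n *= d;
  return n;
}

// Golden reference for XLA's Broadcast(operand, broadcast_sizes): the new
// dimensions are added on the left, so the output shape is
// broadcast_sizes ++ operand.shape and out[i..., j...] == operand[j...].
Array ReferenceBroadcast(const Array& operand,
                         const std::vector<int64_t>& broadcast_sizes) {
  assert(static_cast<int64_t>(operand.data.size()) ==
         NumElements(operand.shape));
  Array out;
  out.shape = broadcast_sizes;
  out.shape.insert(out.shape.end(), operand.shape.begin(), operand.shape.end());
  const int64_t operand_size = NumElements(operand.shape);
  const int64_t total = NumElements(out.shape);
  out.data.resize(total);
  // Because the new dims are major, the row-major output is the operand repeated.
  for (int64_t i = 0; i < total; ++i) out.data[i] = operand.data[i % operand_size];
  return out;
}

// Exact comparison; a real test would normally use an error tolerance.
bool SameArray(const Array& a, const Array& b) {
  return a.shape == b.shape && a.data == b.data;
}
```

In a single-op test, you would feed the same inputs to your backend's Broadcast and compare the two results with a helper like SameArray.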

Bjarke

Bjarke Roune

Oct 2, 2017, 10:00:14 PM
to XLA development
To get a bit more into your actual question, there are at least two different reasons why you could be seeing a broadcast op even though TensorFlow has not placed any broadcast ops on your XLA device.

The first reason is that the conversion from a TensorFlow graph to an XLA graph is not 1:1, so there are TensorFlow ops that map to multiple XLA ops. I'm not sure if there is such a case that emits a broadcast op from a non-broadcast op, but there might be. For example, if fused batch normalization were expanded at this layer of the software stack, it would generate a broadcast; it's actually handled at the HLO level, though, which brings us to the second reason.

The second reason is that XLA does a wide range of transformations and optimizations between receiving an op from TensorFlow and passing that graph to the backend. For example here's a case where the AlgebraicSimplifier pass converts a dot op to a broadcast:


The easiest way to figure out what's going on in a particular case is to look at the XLA graph both when it is first passed to XLA and right before it is passed to your backend, using --xla_generate_hlo_graph:
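To make the second reason concrete, here is a toy pass in the same spirit. It is not the AlgebraicSimplifier code and not the linked case. It models a simplification that is valid by construction: a [m,k] x [k,n] dot with k == 0 sums over nothing, so every element is 0 and the dot can be replaced by a broadcast of the constant 0. `ToyInstr` and `SimplifyZeroSizedDots` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Toy HLO-like instruction: an opcode plus the shapes it works on. This is
// NOT xla::HloInstruction; it only models enough to show a pass replacing
// one opcode with another before the backend ever sees the graph.
struct ToyInstr {
  std::string opcode;              // e.g. "dot", "broadcast", "constant"
  std::vector<int64_t> lhs_shape;  // used by "dot"
  std::vector<int64_t> rhs_shape;  // used by "dot"
  std::vector<int64_t> result_shape;
};

// A rank-2 dot whose contracting dimension (lhs dim 1) has size 0 produces
// all zeros, so it is replaced by constant 0 broadcast to the result shape.
std::vector<ToyInstr> SimplifyZeroSizedDots(const std::vector<ToyInstr>& in) {
  std::vector<ToyInstr> out;
  for (const ToyInstr& instr : in) {
    bool zero_contracting = instr.opcode == "dot" &&
                            instr.lhs_shape.size() == 2 &&
                            instr.lhs_shape[1] == 0;
    if (zero_contracting) {
      out.push_back({"constant", {}, {}, {}});  // scalar 0
      out.push_back({"broadcast", {}, {}, instr.result_shape});
    } else {
      out.push_back(instr);
    }
  }
  return out;
}
```

After such a pass, the backend's broadcast handler is called even though the original graph only contained a dot.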

Peter Hawkins

Oct 2, 2017, 10:10:32 PM
to Bjarke Roune, XLA development
Hi...

On Mon, Oct 2, 2017 at 10:00 PM 'Bjarke Roune' via XLA development <xla...@googlegroups.com> wrote:
To get a bit more into your actual question, there are at least two different reasons that you could be seeing a broadcast op without TensorFlow placing any broadcast ops on your XLA device.

The first reason is that the conversion from a TensorFlow graph to an XLA graph is not 1:1, so there are TensorFlow ops that map to multiple XLA ops.

In fact, I would go further and say that most TensorFlow operators map to multiple HLO operators.
 
I'm not sure if there is such a case that emits a broadcast op from a non-broadcast op, but there might be. E.g. if fused batch normalization were expanded at this layer of the software stack, it would be generating a broadcast, though it's actually handled at the HLO level, bringing us to the second reason.

There are plenty of cases that emit HLO Broadcasts from a non-Broadcast TensorFlow op. Here are a couple of examples:

* A common idiom to fill a tensor with a scalar is to use an explicit HLO broadcast:

* Most binary operators (e.g., addition) support implicit broadcasting on the TensorFlow level, which is lowered to an explicit HLO Broadcast operator:
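The second bullet can be sketched in plain C++. This is a toy model and not the tf2xla implementation. It computes the NumPy/TF-style broadcast shape of two operands, then materializes each operand at that shape with an explicit broadcast before a same-shape elementwise add. That is why an HLO broadcast reaches the backend even though the TF graph only contained an Add. All names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// Implicit (NumPy/TF-style) broadcasting: align shapes on the right; each
// pair of dims must match or one must be 1; missing leading dims count as 1.
Shape BroadcastShape(const Shape& a, const Shape& b) {
  size_t rank = std::max(a.size(), b.size());
  Shape out(rank);
  for (size_t i = 0; i < rank; ++i) {
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1)
      throw std::invalid_argument("incompatible shapes");
    out[rank - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}

// Explicit broadcast of `data` (shape `in`) to shape `out`, mapping `in` onto
// the trailing dims of `out` and repeating size-1 dims.
std::vector<float> ExplicitBroadcast(const std::vector<float>& data,
                                     const Shape& in, const Shape& out) {
  assert(in.size() <= out.size());
  int64_t total = 1;
  for (int64_t d : out) total *= d;
  std::vector<float> result(total);
  size_t offset = out.size() - in.size();
  for (int64_t linear = 0; linear < total; ++linear) {
    // Decompose the row-major output index and build the operand index.
    int64_t rem = linear, src = 0, stride = 1;
    for (size_t k = out.size(); k-- > 0;) {
      int64_t idx = rem % out[k];
      rem /= out[k];
      if (k >= offset) {
        int64_t in_dim = in[k - offset];
        src += (in_dim == 1 ? 0 : idx) * stride;
        stride *= in_dim;
      }
    }
    result[linear] = data[src];
  }
  return result;
}

// A TF-style Add with implicit broadcasting, lowered as two explicit
// broadcasts followed by a same-shape elementwise add.
std::vector<float> LoweredAdd(const std::vector<float>& a, const Shape& sa,
                              const std::vector<float>& b, const Shape& sb) {
  Shape out = BroadcastShape(sa, sb);
  std::vector<float> ea = ExplicitBroadcast(a, sa, out);
  std::vector<float> eb = ExplicitBroadcast(b, sb, out);
  for (size_t i = 0; i < ea.size(); ++i) ea[i] += eb[i];
  return ea;
}
```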

Peter

 
--
You received this message because you are subscribed to the Google Groups "XLA development" group.
To unsubscribe from this group and stop receiving emails from it, send an email to xla-dev+u...@googlegroups.com.
To post to this group, send email to xla...@googlegroups.com.
To view this discussion on the web visit https://groups.google.com/d/msgid/xla-dev/4a08c832-13c6-4678-b5c8-2ba42cc2a20c%40googlegroups.com.
For more options, visit https://groups.google.com/d/optout.

Christian Convey

Oct 3, 2017, 11:03:10 AM
to XLA development
Thanks for the info, it's all very helpful. 

My impression is that XLA is meant to be independent of TF.  So XLA's operators are defined entirely by HLO, without any appeal to TF concepts.  And (perhaps wrongly) I think of KernelDef as a TF-specific concept.

So I can't tell if this is just a leaky abstraction, or if I'm misunderstanding the design.

Peter Hawkins

Oct 3, 2017, 11:07:30 AM
to Christian Convey, XLA development
Hi...

On Tue, Oct 3, 2017 at 11:03 AM Christian Convey <christia...@gmail.com> wrote:
Thanks for the info, it's all very helpful. 

My impression is that XLA is meant to be independent of TF. 

Yes, they are. TensorFlow is a layer on top of XLA. The code is quite well separated. You can use XLA entirely independently of TensorFlow itself, although we don't exactly go out of our way to tell people about it or encourage it. They just happen to live in the same code repository for logistical reasons and TensorFlow knows how to make use of XLA.
 
So XLA's operators are defined entirely by HLO, without any appeal to TF concepts.  And (perhaps wrongly) I think of KernelDef as a TF-specific concept.

They are. XLA's operator semantics are defined here:

For each TF operator, we lower it onto one or more XLA operators. KernelDef is entirely a TF concept. Can you say more about what you mean?

Peter
 

So I can't tell if this is just a leaky abstraction, or if I'm misunderstanding the design.


Christian Convey

Oct 3, 2017, 11:21:27 AM
to XLA development
On Tuesday, October 3, 2017 at 11:07:30 AM UTC-4, Peter Hawkins wrote:
They are. XLA's operator semantics are defined here:

For each TF operator, we lower it onto one or more XLA operators. KernelDef is entirely a TF concept. Can you say more about what you mean?

I'm talking about the final parameter to the REGISTER_XLA_BACKEND macro, which is a filter function that deals with KernelDef objects.

My understanding of that macro is that it's the way we tell TF about the capabilities of the XLA device being registered.  So it's surprising to me that it would be in terms of KernelDef objects, rather than HLO operators.

Have I misunderstood the intent of that macro?

Peter Hawkins

Oct 3, 2017, 11:26:54 AM
to Christian Convey, XLA development
Hi...

You've understood the intent correctly.

However, the filter is expressed in terms of "which TF ops do we support?", not "which XLA ops do we support?". This is so that, for example, we can tell the TF device placement logic that device:MY_HARDWARE does not support certain TF ops.

The only device-wide filtering logic is based on types. You can say "device MY_HARDWARE supports DT_INT32 but not DT_INT64". Otherwise you'll have to manually whitelist/blacklist TF ops that work in the filtering function.

There's no easy way to identify which TF ops lower to which XLA ops other than looking at the code, but ultimately for these kinds of core ops like Broadcast you really will just have to implement them before you get very far :-)
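Here is a toy model of the point above. The filter only ever sees TF op names, while the backend later sees whatever HLO opcodes the surviving TF ops lower to. The lowering table is hypothetical and heavily simplified, but it shows why rejecting BroadcastArgs does not stop an HLO broadcast from reaching the emitter.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified TF-op -> HLO-opcode lowering table.
const std::map<std::string, std::vector<std::string>> kToyLowering = {
    {"Add", {"broadcast", "broadcast", "add"}},  // implicit broadcasting
    {"Fill", {"constant", "broadcast"}},         // fill via broadcast
    {"BroadcastArgs", {"constant"}},             // shape-only op
    {"MatMul", {"dot"}},
};

// Device-level filter in terms of TF ops, mirroring OpFilter in the question.
bool ToyOpFilter(const std::string& tf_op) {
  return tf_op != "BroadcastArgs" && tf_op != "BroadcastGradientArgs";
}

// HLO opcodes the backend would be asked to handle for the TF ops that
// pass the filter.
std::set<std::string> HloOpsReachingBackend(const std::vector<std::string>& graph) {
  std::set<std::string> hlo;
  for (const std::string& op : graph) {
    if (!ToyOpFilter(op)) continue;  // placed elsewhere, e.g. on "cpu"
    auto it = kToyLowering.find(op);
    if (it == kToyLowering.end()) continue;  // unknown to this toy table
    hlo.insert(it->second.begin(), it->second.end());
  }
  return hlo;
}
```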

Peter
 


Christian Convey

Oct 3, 2017, 11:55:20 AM
to XLA development
Hi Peter, 

Thanks, that's super helpful.  Much appreciated.

- Christian

Vincent Mirian

Jul 30, 2020, 3:17:30 AM
to XLA development
Hi,

I have the same objective and am attempting to implement the solution in TF 2.2. My problem is described at: https://groups.google.com/forum/#!topic/xla-dev/Mo_73_bJHTA. Where in the TF 2.2 code are TF ops lowered to XLA ops (HLO IR)?

Thank you in advance,
Vincent Mirian

Chris Leary

Jul 30, 2020, 1:21:54 PM
to Vincent Mirian, XLA development
Hi Vincent,

There's no way to say "I only support this subset" and have arbitrary programs work. That said, if programs happen to use only the supported subset, I believe they'll work, simply because they never hit any of the unsupported paths. For example, if your backend has no FFT and the TF program you run doesn't use FFT, that would not cause an error.

To make arbitrary programs work with a limited subset you'd need to define the conversion from the normally available set of HLO ops (which TF assumes it can target) to your more limited set, which you can do in an XLA backend with a custom compiler plugin that desugars HLOs into ones you support.
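Here is a minimal sketch of that desugaring idea. It assumes a straight-line list of opcodes rather than a real HloModule. Each unsupported opcode is replaced by a sequence of supported ones if a rule exists, and the pass fails otherwise. The rule table and `Desugar` are hypothetical; a real compiler plugin would rewrite HloInstructions together with their operands.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

using OpSeq = std::vector<std::string>;

// Hypothetical rewrite rules, using HLO-like opcode names.
const std::map<std::string, OpSeq> kDesugarRules = {
    {"subtract", {"negate", "add"}},  // a - b == a + (-b)
    {"abs", {"negate", "maximum"}},   // |x| == max(x, -x), up to NaN/-0 details
};

// Rewrites `in` using only `supported` opcodes. Returns false if some op is
// neither supported nor rewritable, i.e. the program can't run on the device.
bool Desugar(const OpSeq& in, const std::set<std::string>& supported, OpSeq* out) {
  out->clear();
  for (const std::string& op : in) {
    if (supported.count(op)) {
      out->push_back(op);
      continue;
    }
    auto rule = kDesugarRules.find(op);
    if (rule == kDesugarRules.end()) return false;
    for (const std::string& r : rule->second) {
      if (!supported.count(r)) return false;  // no multi-step rewrites here
      out->push_back(r);
    }
  }
  return true;
}
```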

I believe the entry point you're looking for as tf2xla conversion may be here (though my info is all a bit stale so somebody can correct me if I'm giving outdated info):


That's looping over the nodes in the TF graph and translating via the XlaCompilationDevice, see the comment here:


You can see such translations in the kernels registered on that device, which turn TF graph nodes into XLA operations, for example: https://github.com/tensorflow/tensorflow/blob/19624f9650e87d55f8b2910ef68cd47fb332ea0f/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc#L33

So conceptually, it's an abstract execution of the TF graph that uses some typical TF constructs (like devices and kernels) to run the abstract execution of the TF nodes. At the bottom of the kernel files you'll see things like:

REGISTER_XLA_OP(Name("AddN").AllowVariantTypes(), AddNOp);

If you run through a basic XLA example and set some breakpoints or add some printfs, the flow may also become a bit clearer.
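The REGISTER_XLA_OP line follows a common static-registration pattern. Below is a self-contained toy of that pattern, with hypothetical names throughout; it is not the tf2xla API. A global map goes from TF op name to a kernel that emits HLO-like opcodes, a static object created by a macro populates it, and a lookup plays the role of the abstract execution.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

using ToyBuilder = std::vector<std::string>;  // records emitted HLO opcodes
using ToyKernel = std::function<void(int num_inputs, ToyBuilder*)>;

// Function-local static avoids static-initialization-order problems.
std::map<std::string, ToyKernel>& ToyRegistry() {
  static std::map<std::string, ToyKernel> registry;
  return registry;
}

struct ToyRegistrar {
  ToyRegistrar(const std::string& op, ToyKernel kernel) {
    ToyRegistry()[op] = std::move(kernel);
  }
};

// Creates a static registrar object that registers KERNEL before main().
#define TOY_REGISTER_OP(NAME, KERNEL) \
  static ToyRegistrar toy_registrar_##KERNEL(NAME, KERNEL)

// AddN sums N inputs; lowered here as a chain of N-1 binary adds.
void ToyAddNKernel(int num_inputs, ToyBuilder* b) {
  for (int i = 1; i < num_inputs; ++i) b->push_back("add");
}
TOY_REGISTER_OP("AddN", ToyAddNKernel);

// "Abstract execution" of one TF node: find its kernel and let it emit ops.
bool ToyCompileNode(const std::string& op, int num_inputs, ToyBuilder* b) {
  auto it = ToyRegistry().find(op);
  if (it == ToyRegistry().end()) return false;  // no kernel registered
  it->second(num_inputs, b);
  return true;
}
```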

It's likely possible to either customize the translations (by registering XlaCompilationDevice kernels) or use an XLA compiler plugin that desugars HLOs into your more limited set (when possible).

HTH!

- Leary

Jacques Pienaar

Jul 30, 2020, 2:14:42 PM
to Chris Leary, Vincent Mirian, XLA development
Hey Vincent,

The conversion target in the new TF-to-HLO legalization framework could be useful here. It lets you specify "I support HLO ops A and B". Then, during conversion from TF ops to HLO ops, it only converts the TF ops that can be expressed using only those HLO ops. See https://github.com/tensorflow/tensorflow/blob/a3116681da4cc86afcf1675f94233b1593ff6fe4/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc#L707 for an "interesting" example where we convert a TF op to another TF op by way of HLOs, using a conversion target and multi-hop lowering, before lowering those to TFLite. The result of this lowering is a graph with both TF and HLO ops (but only supported HLOs). Given that graph, you'd need to extract the HLO parts for your target into TF ops. For TPU we group these into a compile op and an execute op; https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc and some other passes there could be of interest. However, there is no generic pass upstream that does this yet (e.g., converting these clusters of HLOs into ops that would invoke your compiler). After that, you can translate these to HloProto (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/xla/tests/translate/export.mlir for an example). You'd probably want to use the C++ helper functions, though; the translate tool is mostly intended for testing and exploration during development.

Then you don't need to know the mapping. You just need to express the desired target, and support for new ops doesn't require changes on your side. As Chris mentioned, though, you might actually want to add additional patterns there, converting TF ops to other TF ops (which have supported HLOs as lowerings) or HLO ops to supported HLOs.
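Here is a toy model of that partial conversion. It assumes each TF op has exactly one lowering pattern, given as a list of HLO op names. The op is converted only if every op in its pattern is in the target's legal set; otherwise it is left as a TF op. Patterns and names are hypothetical and far simpler than MLIR's dialect conversion framework.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  bool is_hlo;
};

// Hypothetical lowering patterns (softplus(x) = log1p(exp(x)), ignoring
// overflow handling).
const std::map<std::string, std::vector<std::string>> kToyPatterns = {
    {"tf.AddV2", {"mhlo.add"}},
    {"tf.Rsqrt", {"mhlo.rsqrt"}},
    {"tf.Softplus", {"mhlo.exponential", "mhlo.log_plus_one"}},
};

std::vector<ToyOp> LegalizePartially(const std::vector<std::string>& tf_ops,
                                     const std::set<std::string>& legal_hlo) {
  std::vector<ToyOp> result;
  for (const std::string& op : tf_ops) {
    auto it = kToyPatterns.find(op);
    bool convertible = it != kToyPatterns.end();
    if (convertible) {
      // Every op the pattern would produce must be legal for the target.
      for (const std::string& h : it->second)
        if (!legal_hlo.count(h)) convertible = false;
    }
    if (convertible) {
      for (const std::string& h : it->second) result.push_back({h, true});
    } else {
      result.push_back({op, false});  // left as a TF op
    }
  }
  return result;
}
```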


-- Jacques

Chris Leary

Jul 30, 2020, 2:42:15 PM
to Jacques Pienaar, Vincent Mirian, XLA development
Thanks Jacques (I did mention my info was likely stale :-).

So, to try to rephrase and see if I understand: via this non-tf2xla path (where the tf2xla path is presumably working but not the future), it will partition the original TF graph into "these are in the supported HLO set that you told me about" and "these are not, so I left them as TF ops"? And it sounds like there's a way to add more desugarings / mappings as customizations? Then it would be up to Vincent to figure out how to execute the ops in the "I left them as TF ops" partition independently of his HLO backend? LMK if I understood that right.

Is there a test target that can be run to see example output from this sort-of-partitioning TF/HLO process?

- Leary

Jacques Pienaar

Jul 30, 2020, 3:41:14 PM
to Chris Leary, Vincent Mirian, XLA development
Well this path also uses parts of tf2xla path for now :)

Yes and yes (although there is no hook exposed on the TF side beyond registering a GraphOptimization pass and running it internally, and support for using this via experimental_compile=true just landed).

There's no good example that shows the partial state, unfortunately; we should add one, good point. I'd say the easiest way is to build the tf-opt tool and then run `tf-opt -xla-legalize-tf=allow-partial-conversion file.mlir` with a file such as

func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  %1 = "tf.Acosh"(%0#0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  return %1 : tensor<8x8x8x8xf32>
}

(tf.Acosh is lowered by a different pass, so it would remain; that will change, but not in the next day :)). After that, I'd suggest running canonicalize to fold away the shape computations (tf-opt -xla-legalize-tf=allow-partial-conversion -canonicalize file.mlir), and then you get

  func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8x8x8x8xf32> {
    %0 = mhlo.constant dense<1.000000e-03> : tensor<8xf32>
    %1 = mhlo.add %arg4, %0 : tensor<8xf32>
    %2 = "mhlo.rsqrt"(%1) : (tensor<8xf32>) -> tensor<8xf32>
    %3 = mhlo.multiply %arg2, %2 : tensor<8xf32>
    %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>) -> tensor<8x8x8x8xf32>
    %5 = mhlo.multiply %arg0, %4 : tensor<8x8x8x8xf32>
    %6 = "tf.Acosh"(%5) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
    return %6 : tensor<8x8x8x8xf32>
  }
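To check a backend against this output, here is a hypothetical reference in plain C++ for what the HLO portion computes. It reads arg0 as y_backprop, arg2 as scale, and arg4 as variance, following FusedBatchNormGrad's operand order with is_training = false. The result is out = y_backprop * scale * rsqrt(variance + epsilon), with the per-channel factor broadcast along dimension 3 (channels last in NHWC). The sketch computes scale / sqrt(...) rather than using an rsqrt op, so results may differ from a backend's rsqrt in the last bits.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major NHWC input flattened to 1-D; channel is the fastest-varying dim.
std::vector<float> BatchNormGradInferenceRef(const std::vector<float>& y_backprop,
                                             const std::vector<float>& scale,
                                             const std::vector<float>& variance,
                                             float epsilon) {
  const size_t channels = scale.size();
  assert(variance.size() == channels && y_backprop.size() % channels == 0);
  // %0..%3: per-channel factor scale * rsqrt(variance + epsilon).
  std::vector<float> factor(channels);
  for (size_t c = 0; c < channels; ++c)
    factor[c] = scale[c] / std::sqrt(variance[c] + epsilon);
  // %4..%5: broadcast the factor over the last dim and multiply elementwise.
  std::vector<float> out(y_backprop.size());
  for (size_t i = 0; i < y_backprop.size(); ++i)
    out[i] = y_backprop[i] * factor[i % channels];
  return out;
}
```

The remaining tf.Acosh would then be applied to this result by whatever executes the TF part of the graph.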

[We recently renamed the dialect to meta HLO (mhlo), per request, to avoid ambiguity, as it has some ops that are not in XLA HLO.] From there it's open to users: if you want to execute via TF at the end, you could group all the HLO ops into some tf.MyDeviceCompileAndExecute op (where you could even encode the HloProto as a string attribute) and execute the TF graph as normal.
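Here is a toy sketch of that grouping step. It assumes ops arrive as a flat list in program order: consecutive mhlo ops are collected into one node named after the tf.MyDeviceCompileAndExecute op suggested above, and TF ops are left alone. A real pass would work on the dataflow graph (operands and uses) and would also have to wire up each cluster's inputs and outputs. `Node` and `ClusterHloOps` are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::string> body;  // HLO ops inside a cluster, else empty
};

std::vector<Node> ClusterHloOps(const std::vector<std::string>& ops) {
  const std::string kCluster = "tf.MyDeviceCompileAndExecute";
  std::vector<Node> out;
  for (const std::string& op : ops) {
    bool is_hlo = op.rfind("mhlo.", 0) == 0;  // prefix check
    if (!is_hlo) {
      out.push_back({op, {}});  // TF op stays as-is
      continue;
    }
    // Start a new cluster unless the previous node already is one.
    if (out.empty() || out.back().name != kCluster) out.push_back({kCluster, {}});
    out.back().body.push_back(op);
  }
  return out;
}
```

Applied to the opcode sequence of the canonicalized function above, this yields one cluster of six mhlo ops followed by tf.Acosh.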

-- Jacques
