Understanding the HLO IR Structure


Hashim Sharif

May 15, 2018, 11:02:45 PM
to XLA development

Hi,

As part of our work, we are looking to translate the HLO IR to our in-house compiler IR. Going through the XLA classes, the structure/representation of the IR is unclear to us. It appears that "HloModule" represents the total computation compiled from a higher-level tf.Graph, "HloComputation" represents instruction sequences, and "HloInstruction" represents the individual operations defined in the XLA Operation Semantics documentation [1]. If this understanding is incorrect, please correct me. Some points of confusion are:

1) What is the purpose of "HloComputation"? Does it represent a Basic Block of HloInstructions with CFG edges between HloComputations?

2) How are tensors represented in XLA? HloInstruction includes "operands" that point to other HloInstructions. In the context of a particular instruction type (say Add), how will tensor values be represented? For instance, if the tensor was a tf.Variable whose values got updated by gradient descent, how could we extract those from the HLO IR (including all elements in the tensor)?

3) When HLO IR is dumped using the TF_XLA_FLAGS, the IR files are dumped in clusters. Do these clusters refer to separate HloComputations? Why are these clusters required/formed?

Coming from an LLVM background, the IR structure doesn't look very familiar. However, since XLA has an LLVM backend, I am assuming HLO constructs do translate well enough to the LLVM IR. Any pointers are very much appreciated.

-Hashim

References
[1] https://www.tensorflow.org/performance/xla/operation_semantics

Justin Lebar

May 15, 2018, 11:17:05 PM
to Hashim Sharif, XLA development
> 1) What is the purpose of "HloComputation"? Does it represent a Basic Block of HloInstructions with CFG edges between HloComputations?

An HloModule is like a whole program.

An HloComputation is like a function.  A module contains one special computation, the "entry computation".  This is like "main".  Running a module always consists of running the entry computation from beginning to end.

A computation has some number of parameters, and exactly one output (its "root").
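
To make the module/computation/root picture concrete, here is a minimal sketch in Python. The `Module`, `Computation`, and `Instruction` classes are hypothetical toys, not XLA's C++ classes, and they use scalars instead of tensors. They only model the rules above: a module holds computations, one of them is the entry computation, and each computation yields exactly one root value.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

@dataclass
class Instruction:
    opcode: str                                   # "parameter", "constant", "add", "multiply"
    operands: List["Instruction"] = field(default_factory=list)
    value: Optional[float] = None                 # payload for "constant"
    param_index: Optional[int] = None             # payload for "parameter"

# Elementwise semantics for the non-leaf opcodes used in this toy.
OPS: Dict[str, Callable[..., float]] = {
    "add": lambda a, b: a + b,
    "multiply": lambda a, b: a * b,
}

@dataclass
class Computation:
    name: str
    root: Instruction                             # the single output of the computation

    def evaluate(self, args: List[float]) -> float:
        cache: Dict[int, float] = {}              # evaluate each instruction once

        def ev(instr: Instruction) -> float:
            key = id(instr)
            if key not in cache:
                if instr.opcode == "parameter":
                    cache[key] = args[instr.param_index]
                elif instr.opcode == "constant":
                    cache[key] = instr.value
                else:
                    cache[key] = OPS[instr.opcode](*[ev(op) for op in instr.operands])
            return cache[key]

        return ev(self.root)

@dataclass
class Module:
    computations: Dict[str, Computation]
    entry: str                                    # name of the entry computation ("main")

    def run(self, args: List[float]) -> float:
        # Running a module means running its entry computation from start to finish.
        return self.computations[self.entry].evaluate(args)

# main(x, y) = (x + y) * 2
x = Instruction("parameter", param_index=0)
y = Instruction("parameter", param_index=1)
two = Instruction("constant", value=2.0)
root = Instruction("multiply", [Instruction("add", [x, y]), two])
module = Module({"main": Computation("main", root)}, entry="main")
print(module.run([1.0, 2.0]))  # 6.0
```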

Control flow is represented by special HloInstructions, "while" and "conditional".  For example, the conditional HLO has two HLO computations.  The two computations must have the same number of parameters and corresponding parameters must have the same shape.  The two computations' root nodes must also have the same shape.  The conditional HLO takes as input a boolean plus some number of other inputs.  The boolean tells it which of the two computations to run.  The other inputs are passed to the computation that's run.  The "while" HLO is a similar idea.
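
A sketch of the `conditional` rules just described, again as a hypothetical Python toy rather than XLA's API. It checks the branch signatures when the conditional is built and dispatches on the boolean at run time; shapes are plain tuples of dimension sizes here.

```python
from typing import Callable, List, Sequence, Tuple

Shape = Tuple[int, ...]  # toy shape: just the dimension sizes

class Computation:
    """Hypothetical stand-in for an HLO computation: a function plus its signature."""

    def __init__(self, param_shapes: Sequence[Shape], root_shape: Shape,
                 fn: Callable[..., List[float]]) -> None:
        self.param_shapes = list(param_shapes)
        self.root_shape = root_shape
        self.fn = fn

def make_conditional(true_comp: Computation, false_comp: Computation):
    # Build-time checks mirroring the rules above: same number of parameters,
    # pairwise-equal parameter shapes, and equal root shapes.
    if true_comp.param_shapes != false_comp.param_shapes:
        raise ValueError("branch parameter shapes differ")
    if true_comp.root_shape != false_comp.root_shape:
        raise ValueError("branch root shapes differ")

    def conditional(pred: bool, *operands):
        # Run time: the boolean selects the branch; the other inputs are passed to it.
        branch = true_comp if pred else false_comp
        return branch.fn(*operands)

    return conditional

# Both branches take one f32[2] parameter and return f32[2].
double = Computation([(2,)], (2,), lambda v: [2 * e for e in v])
negate = Computation([(2,)], (2,), lambda v: [-e for e in v])
cond = make_conditional(double, negate)
print(cond(True, [1.0, 3.0]))   # [2.0, 6.0]
print(cond(False, [1.0, 3.0]))  # [-1.0, -3.0]
```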

> 2) How are tensors represented in XLA? HloInstruction includes "operands" that point to other HloInstructions. In the context of a particular instruction type (say Add), how will tensor values be represented? For instance, if the tensor was a tf.Variable whose values got updated by gradient descent, how could we extract those from the HLO IR (including all elements in the tensor)?

HLO is pure (mostly).  There are (mostly) no side-effects.  To extract a value from an HloModule, that value must (with some exceptions, such as the outfeed or send ops) be present in the root node of the root computation of the module.  You cannot retrieve intermediate values produced in a module except by returning them in the root node.  If you have multiple return values, you return a tuple.
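
Applied to the `tf.Variable` example from the question: because the program is pure, an updated weight is only observable if it is part of the root. Below is a plain-Python sketch (not HLO, and not how TF actually lowers variables) of a single gradient-descent step whose "root" is a tuple of the loss and the new weight.

```python
from typing import Tuple

def train_step(w: float, x: float, target: float, lr: float) -> Tuple[float, float]:
    pred = w * x                         # intermediate value: not observable unless returned
    loss = (pred - target) ** 2
    grad = 2.0 * (pred - target) * x     # d(loss)/dw
    new_w = w - lr * grad                # the "updated variable"
    return (loss, new_w)                 # the root: a tuple of everything the caller needs

loss, new_w = train_step(w=1.0, x=2.0, target=4.0, lr=0.1)
print(loss, new_w)
```

In this sketch the caller feeds `new_w` back in as the parameter of the next step, which is how state survives across invocations of a pure program.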

> 3) When HLO IR is dumped using the TF_XLA_FLAGS, the IR files are dumped in clusters. Do these clusters refer to separate HloComputations?

They are HloModules.

> Why are these clusters required/formed?

I think I answered this in my previous email to you about clustering?


Hashim Sharif

May 16, 2018, 12:39:01 AM
to XLA development

That is a great explanation of the IR structure. Thanks!

> HLO is pure (mostly).  There are (mostly) no side-effects.  To extract a value from an HloModule, that value must (with some exceptions, such as the outfeed or send ops) be present in the root node of the root computation of the module.  You cannot retrieve intermediate values produced in a module except by returning them in the root node.  If you have multiple return values, you return a tuple.

Allow me to better explain our usage scenario. We need to translate an HLO IR graph into an LLVM-like representation for the feed-forward phase (for a DNN). So essentially our IR would require not only the operators (HLO instructions) but also the tensor values that are learnt in the training phase. Since the weight tensors (for instance in a convolutional layer) would be constants for purposes of inference, we need to extract them as constant LLVM arrays. Does XLA not include a representation for data values, constant data arrays? If not, do you have a suggestion for how we can extract the HLO operations corresponding to higher-level TF operations, along with the tensor values?

-Hashim

 

Sanjoy Das

May 16, 2018, 12:47:53 AM
to hashim....@gmail.com, XLA development
On Tue, May 15, 2018 at 9:39 PM Hashim Sharif <hashim....@gmail.com> wrote:
> Allow me to better explain our usage scenario. We need to translate an
> HLO IR graph into an LLVM-like representation for the feed-forward phase
> (for a DNN). So essentially our IR would require not only the operators
> (HLO instructions) but also the tensor values that are learnt in the
> training phase. Since the weight tensors (for instance in a convolutional
> layer) would be constants for purposes of inference, we need to extract
> them as constant LLVM arrays. Does XLA not include a representation for
> data values, constant data arrays?

XLA has a "Constant" HLO instruction that can represent compile time
constant values. See
https://github.com/tensorflow/tensorflow/blob/b12c3bb1157245adf6230a2e045831348f679b5b/tensorflow/compiler/xla/client/xla_client/xla_builder.h#L169

The LLVM backends will lower these constant HLO instructions to constant
arrays in LLVM, as you're expecting.
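
As a rough illustration of what lowering a constant to a constant array means, the hypothetical helper below renders a flat list of doubles as an LLVM IR global. It is not XLA's emitter; the helper name, the `double` element type, and flattening the tensor in row-major order beforehand are all assumptions. Elements are written in LLVM's 64-bit hexadecimal float form (`0x` plus the IEEE-754 bits) so that every value is represented exactly.

```python
import struct
from typing import Sequence

def llvm_constant_array(name: str, values: Sequence[float]) -> str:
    """Render a flat list of doubles as an LLVM IR global constant (hypothetical helper)."""
    elems = ", ".join(
        # Big-endian packing gives the IEEE-754 bit pattern in natural hex order.
        "double 0x" + struct.pack(">d", v).hex().upper() for v in values
    )
    return f"@{name} = private unnamed_addr constant [{len(values)} x double] [{elems}]"

print(llvm_constant_array("conv_weights", [1.0, -0.5]))
# @conv_weights = private unnamed_addr constant [2 x double] [double 0x3FF0000000000000, double 0xBFE0000000000000]
```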

-- Sanjoy

Bjarke Roune

May 16, 2018, 2:19:22 PM
to XLA development
It's worth noting that LLVM does not deal well with very large constants, like those you'd get for the weights in an ML model with large weight tensors, so if your LLVM-like representation is originally based on actual LLVM, you might run into some issues. On other platforms, we are getting around this by passing the weights in as parameters, while leaving them on the device so that there is no data transfer per invocation for doing this. That works unless your low-level compilation phase needs to know the exact values in the weights as compile-time constants to compile the model, though so far this has not been necessary for us for large weight tensors.
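
A toy model of the "weights as parameters, left on the device" approach: the weights are uploaded once and every invocation refers to the resident buffer. `FakeDevice` and `compiled_model` are hypothetical names; nothing here corresponds to a real XLA or TF runtime API.

```python
from typing import Dict, List

class FakeDevice:
    """Hypothetical device that counts host-to-device transfers."""

    def __init__(self) -> None:
        self.buffers: Dict[str, List[float]] = {}
        self.transfers = 0

    def upload(self, name: str, data: List[float]) -> str:
        self.buffers[name] = list(data)
        self.transfers += 1
        return name                      # handle to the device-resident buffer

def compiled_model(device: FakeDevice, weights_handle: str, x: List[float]) -> float:
    # The "compiled program" only knows the weights' shape; the values are
    # read at run time from the parameter buffer already on the device.
    w = device.buffers[weights_handle]
    return sum(wi * xi for wi, xi in zip(w, x))

device = FakeDevice()
w_handle = device.upload("weights", [0.5, -1.0, 2.0])   # one transfer
for batch in ([1.0, 1.0, 1.0], [2.0, 0.0, 1.0], [0.0, 1.0, 0.0]):
    compiled_model(device, w_handle, batch)            # no further weight transfers
print(device.transfers)  # 1
```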

On HloComputation, it should help to know that a "computation" is precisely what would be called a "function" in other programming languages, which should clear up its role. E.g. a map operation takes a computation (i.e. a function) that it calls on the data that it receives. The "computation shape" is the type/signature of the function.
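
Reading a computation as a function, `map` is a higher-order op: it takes a computation and applies it elementwise. Here is a hypothetical Python sketch where `ComputationShape` plays the role of the signature (parameter types and result type, written as XLA-style shape strings such as `f32[]`).

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple

@dataclass(frozen=True)
class ComputationShape:
    """Signature of a computation: parameter types and result type (toy version)."""
    params: Tuple[str, ...]
    result: str

@dataclass(frozen=True)
class Computation:
    shape: ComputationShape
    fn: Callable[..., float]

def map_op(comp: Computation, *arrays: List[float]) -> List[float]:
    # Map applies a scalar computation elementwise across equally sized arrays;
    # the computation must take one parameter per input array.
    if len(arrays) != len(comp.shape.params):
        raise ValueError("computation arity does not match number of operands")
    if len({len(a) for a in arrays}) != 1:
        raise ValueError("operands must have the same number of elements")
    return [comp.fn(*elems) for elems in zip(*arrays)]

add_f32 = Computation(ComputationShape(("f32[]", "f32[]"), "f32[]"), lambda a, b: a + b)
print(map_op(add_f32, [1.0, 2.0], [10.0, 20.0]))  # [11.0, 22.0]
```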

Bjarke

Hashim Sharif

May 17, 2018, 4:16:59 AM
to XLA development

Thanks Sanjoy and Bjarke,

> It's worth noting that LLVM does not deal well with very large constants, like those you'd get for the weights in an ML model with large weight tensors, so if your LLVM-like representation is originally based on actual LLVM, you might run into some issues. On other platforms, we are getting around this by passing the weights in as parameters, while leaving them on the device so that there is no data transfer per invocation for doing this. That works unless your low-level compilation phase needs to know the exact values in the weights as compile-time constants to compile the model, though so far this has not been necessary for us for large weight tensors.

Since you mention the weights can be passed as parameters, what is the interface/invocation that passes the input tensors to an XLA graph/module? For purposes of our work, we need to relate the tensor values (of constant weights) with the corresponding XLA ops that consume these values. Given that the dumped XLA does not seem to represent the constant tensors, it is unclear i) how input tensors are passed as parameters to the XLA graph, or ii) how the tensor values (available in the higher-level TF graph, tf.Graph) can be related to the corresponding XLA ops.

-Hashim

Bjarke Roune

May 17, 2018, 1:55:18 PM
to Hashim Sharif, XLA development
I'll give the XLA-level answer. If the weights are passed in as parameters, then the compiler does not know what the weights are, only their shape (i.e. the type of the array including array bounds). When you run the XLA module (i.e. program), the parameters are passed in at runtime and they will be the values of the Parameter nodes in the graph. The parameters need to already be on the device, or otherwise they would need to be transferred to the device first. If you want to transfer data in and out of the program while it runs, you'd need support for infeed and outfeed nodes. The only drawback here is that the compiler does not know the exact values of the weights, though for large weight arrays this has so far for us not been useful information to have in the compiler anyway. If you want to map XLA ops back to TF ops for debugging, you can look at the metadata field on the XLA op, which contains a string from TF that gives you an idea of where this op came from originally.
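
For the debugging use case, the idea is to group HLO instructions by the TF origin recorded in their metadata. XLA's `OpMetadata` carries fields such as `op_type` and `op_name`; the dataclasses below are toy stand-ins, and the exact strings TF writes into them, as well as the instruction and node names in the usage, are made up for illustration.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class OpMetadata:
    op_type: str  # type of the originating TF op, e.g. "MatMul"
    op_name: str  # name of the originating TF node, e.g. "dense/MatMul"

@dataclass
class HloOp:
    name: str     # HLO instruction name
    opcode: str
    metadata: OpMetadata

def group_by_tf_node(ops: List[HloOp]) -> Dict[str, List[str]]:
    """Map each TF node name to the HLO instructions that came from it."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for op in ops:
        groups[op.metadata.op_name].append(op.name)
    return dict(groups)

ops = [
    HloOp("dot.7", "dot", OpMetadata("MatMul", "dense/MatMul")),
    HloOp("broadcast.8", "broadcast", OpMetadata("BiasAdd", "dense/BiasAdd")),
    HloOp("add.9", "add", OpMetadata("BiasAdd", "dense/BiasAdd")),
]
print(group_by_tf_node(ops))
# {'dense/MatMul': ['dot.7'], 'dense/BiasAdd': ['broadcast.8', 'add.9']}
```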


Sanjoy Das

May 17, 2018, 2:04:38 PM
to Bjarke Roune, hashim....@gmail.com, XLA development
On Thu, May 17, 2018 at 10:55 AM 'Bjarke Roune' via XLA development <xla...@googlegroups.com> wrote:
> I'll give the XLA-level answer. If the weights are passed in as
> parameters, then the compiler does not know what the weights are, only
> their shape (i.e. the type of the array including array bounds). When you
> run the XLA module (i.e. program), the parameters are passed in at runtime
> and they will be the values of the Parameter nodes in the graph. The
> parameters need to already be on the device, or otherwise they would need
> to be transferred to the device first. If you want to transfer data in and
> out of the program while it runs, you'd need support for infeed and outfeed
> nodes. The only drawback here is that the compiler does not know the exact
> values of the weights, though for large weight arrays this has so far for
> us not been useful information to have in the compiler anyway.

On the CPU backend we will sometimes flip the layout of large weight
matrices if we think that can make some operations faster (see
ShouldMakeOperandColumnMajor in cpu_layout_assignment.cc). We can't do
this for parameters since from XLA's perspective the parameters have a
fixed layout. I'm working on some more optimizations (targeting the CPU
backend, though in principle this restriction can be lifted) that will be
more effective with the weights expressed as constants in XLA IR.
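
The layout point can be seen with a small example: re-laying out a matrix is just a permutation of its flat buffer, which a compiler can apply ahead of time to data it owns (a constant) but not to a buffer whose layout the caller fixes (a parameter). This sketch only shows the relayout itself; it does not reproduce the heuristic in `ShouldMakeOperandColumnMajor`.

```python
from typing import List

def linear_index(i: int, j: int, rows: int, cols: int, column_major: bool) -> int:
    """Offset of element (i, j) of a rows x cols matrix in a flat buffer."""
    return j * rows + i if column_major else i * cols + j

def to_column_major(row_major: List[float], rows: int, cols: int) -> List[float]:
    """Copy a row-major constant into column-major order (done once, at compile time)."""
    out = [0.0] * (rows * cols)
    for i in range(rows):
        for j in range(cols):
            src = linear_index(i, j, rows, cols, column_major=False)
            dst = linear_index(i, j, rows, cols, column_major=True)
            out[dst] = row_major[src]
    return out

w = [1.0, 2.0, 3.0,
     4.0, 5.0, 6.0]              # 2x3 matrix, row-major
print(to_column_major(w, 2, 3))  # [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
```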

LLVM does have trouble with large constants, but this problem is already
addressed for the XLA JIT (when JITting we don't emit large constants in
the XLA IR as LLVM constants, but we instead use an "external constant
pool" for these). However, I have a better fix on the horizon: LLVM's
ConstantDataArray is much better at handling large constants, and I plan to
move both the JIT and the AOT backends to use it in the near future.
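
The "external constant pool" idea can be sketched as an emitter that keeps small constants inline in the generated IR and moves large ones into a side table referenced by name. `ConstantEmitter` and the `INLINE_LIMIT` cut-off are hypothetical; the thread does not say how the XLA JIT decides what counts as large.

```python
from typing import Dict, List

INLINE_LIMIT = 4  # hypothetical cut-off, in elements; not taken from XLA

class ConstantEmitter:
    """Toy emitter: small constants are written inline, large ones go to a side pool."""

    def __init__(self) -> None:
        self.pool: Dict[str, List[float]] = {}

    def emit(self, name: str, values: List[float]) -> str:
        if len(values) <= INLINE_LIMIT:
            return f"inline {name} = {values}"
        self.pool[name] = list(values)    # the data lives outside the generated IR
        return f"external {name} -> pool[{name!r}]"

emitter = ConstantEmitter()
print(emitter.emit("bias", [0.5, -0.5]))      # inline bias = [0.5, -0.5]
print(emitter.emit("weights", [0.0] * 1000))  # external weights -> pool['weights']
```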

-- Sanjoy

Bjarke Roune

May 17, 2018, 2:46:01 PM
to Sanjoy Das, Tom Jablin, Hashim Sharif, XLA development
Thanks for the update on the LLVM constant situation.

Layout is a challenge for parameters, it's true. A feature we have long wanted is for XLA to surface to TF what layout it would prefer for a parameter array, so that TF would then provide the parameter in that layout, though this is not in place and it's not clear when it might become available. In the meantime there is a specific way of doing layout that usually works out well that we use for TPU-enabled models, well enough so far to drop the priority of doing more about it.

My understanding is that we are focusing on keeping weights as parameters for TPUs also for inference, where advantages include more straightforward sharing of identical weights among similar models on the same device, though the other approach with constants was also discussed. It seems unfortunate if we are pursuing different strategies on this for different backends, though maybe the shared approach needs to be to support both ways well. +tjablin has been looking into this; are we still targeting weights as parameters as the recommended method, or has this changed?
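
One way to picture why parameters make weight sharing straightforward: if weights live in a device-side store keyed by content, two models with an identical weight array can be handed the same buffer. `WeightStore` is hypothetical, and content hashing with SHA-256 is an assumption chosen for this sketch, not a description of how TPU serving does it.

```python
import hashlib
import struct
from typing import Dict, List

class WeightStore:
    """Toy device-side store that keeps one buffer per distinct weight content."""

    def __init__(self) -> None:
        self.buffers: Dict[str, List[float]] = {}

    def put(self, values: List[float]) -> str:
        # Key the buffer by a hash of its exact IEEE-754 bytes.
        key = hashlib.sha256(struct.pack(f"<{len(values)}d", *values)).hexdigest()
        self.buffers.setdefault(key, list(values))  # stored only the first time
        return key                                  # handle passed as a parameter

store = WeightStore()
shared = [0.25, 0.5, 0.75]
handle_a = store.put(shared)          # model A's weights
handle_b = store.put(list(shared))    # model B has an identical copy
handle_c = store.put([1.0, 2.0])      # model C's weights differ
print(handle_a == handle_b, len(store.buffers))  # True 2
```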

Seyed Hashemi

Jun 26, 2018, 6:42:13 PM
to XLA development
I am new to TensorFlow. Can someone please explain these terms from XLA:

Kernel
Cluster
How do they relate to HloModule?

Thanks,
Seyed