Dumping HLO IR from a TensorFlow program


Hashim Sharif

Apr 17, 2018, 9:46:03 PM4/17/18
to XLA development

Hi,

As part of our work we need to analyze the Ops at the HLO IR level. I couldn't find documentation that shows how to dump the HLO IR from TensorFlow sources (in Python). The JIT compilation doc at https://www.tensorflow.org/performance/xla/jit suggests using TF_XLA_FLAGS=--xla_generate_hlo_graph=.* to dump HLO graphs; however, for me this doesn't dump the HLO graph to an output file. Specifically, we need to look at the HLO in its textual IR format. Any ideas?

Justin Lebar

Apr 17, 2018, 9:52:55 PM4/17/18
to hashim....@gmail.com, XLA development
A bit of a teach-a-person-to-fish:

If you search the TensorFlow repository on GitHub for "xla_generate_hlo_graph", there are three hits.

Since you're looking for another flag that might suit your purposes, the file where those flags are defined looks interesting. There are various flags in there that might be of interest to you, including --xla_generate_hlo_text_to.

If you want to find out exactly what these do (or even change what they do), I'd again recommend grep'ing the codebase, since these strings are easy to search for.
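
For example, from the root of a local TensorFlow clone (the same search also works in GitHub's search box), something like:

    grep -rn "xla_generate_hlo_graph" tensorflow/compiler/

will show you both where the flag is defined and where it is consumed.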

Good luck,
-Justin


Hashim Sharif

Apr 18, 2018, 2:15:03 AM4/18/18
to XLA development

Thanks Justin, that is helpful. However, I am still unable to extract the HLO IR. My TensorFlow source is as follows:

import tensorflow as tf

a = tf.placeholder(tf.int16)
b = tf.placeholder(tf.int16)

jit_scope = tf.contrib.compiler.jit.experimental_jit_scope  # Using JIT compilation
with jit_scope():
    add = tf.add(a, b)
    mul = tf.multiply(a, b)

with tf.Session() as sess:
    # Run every operation with variable input
    print("Addition with variables: %i" % sess.run(add, feed_dict={a: 2, b: 3}))
    print("Multiplication with variables: %i" % sess.run(mul, feed_dict={a: 2, b: 3}))

To run, I use the command: TF_XLA_FLAGS=--xla_generate_hlo_text_to=output python source.py

I am expecting the output IR files to appear under the output directory, but to no avail. I enabled logging with TF_CPP_MIN_VLOG_LEVEL=2 to check whether XLA actually compiled the code; however, the output does not seem to include any XLA-specific info. Am I not invoking XLA correctly?

Peter Hawkins

Apr 18, 2018, 6:16:23 AM4/18/18
to hashim....@gmail.com, XLA development
One possible reason for your issue is that I don't think tf.int16 is supported by any XLA backend. Try a different int type, say tf.int32?
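
In your snippet that would just mean changing the placeholder dtypes, e.g.:

    a = tf.placeholder(tf.int32)
    b = tf.placeholder(tf.int32)

with everything else left unchanged.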

(There is no fundamental reason int16 isn't supported, but no-one has done the necessary work to support it yet.)

Peter

Hashim Sharif

Apr 20, 2018, 6:59:54 PM4/20/18
to XLA development
Thanks Peter! Your solution helped; I am now able to dump the HLO IR. Is there any timeline for when int16 will be supported in XLA? I would guess tf.int16 is used quite frequently.

Justin Lebar

Apr 20, 2018, 7:44:31 PM4/20/18
to Hashim Sharif, XLA development
> Is there any timeline to when int16 will be supported in XLA?

Patches welcome. :)

If you look through the recent commit history, you'll see some patches adding support for new floating-point types to XLA. Search for patches by bi...@google.com. It's not necessarily a big undertaking.

Seyed Hashemi

Jun 22, 2018, 4:36:30 PM6/22/18
to XLA development
Any suggestions on how to do this on Windows?

I set an env variable TF_XLA_FLAGS=--xla_generate_hlo_text_to and still didn't get anything.
I also tried cmd /V /C "set TF_XLA_FLAGS=--xla_generate_hlo_text_to=c:\\XLA\\output&& python source.py"

and it looks like the flag is set right (see attached), but I still don't get an output file.
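
For what it's worth, a cross-platform way to sidestep the shell quoting would be to set the variable from inside the script itself, before TensorFlow is imported. A sketch (the output path is a placeholder, and it assumes the flag is read when XLA first parses its options):

    import os
    # Set the flag before importing TensorFlow so it is visible when the
    # XLA debug options are first parsed.
    os.environ["TF_XLA_FLAGS"] = r"--xla_generate_hlo_text_to=C:\XLA\output"

    import tensorflow as tf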
[Attachment: Capture.PNG]