HLO Passes - Categorization and Functions of Different Passes


Raghav Garg

Oct 16, 2020, 9:12:51 PM
to XLA development
Hi all,

I am working with XLA HLO passes, trying to determine whether individual passes affect a model's execution time, measured as the average step time.

Currently, I am using a benchmarked model (ResNet50 on CIFAR-10). I enable and disable these passes to observe which ones matter in general and which are specific to a model. I would like to know the function of each pass so that I can identify the optimization it performs for the model being run.

I ran the model for a few steps (around 10) and profiled its average step time (using the TensorFlow profiler) to see how disabling a set of passes changes the step time relative to the default state (no pass disabled). For now, I have tried 20 of the 51 available passes (chosen with no particular criteria) and gave each configuration a few runs to record the average step time. The results are attached.
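
For reference, this is roughly how I disable a set of passes and capture the profile (a minimal sketch; train_step is a placeholder for my actual ResNet50/CIFAR-10 training step, and the two pass names are only illustrative):

    import os

    # XLA_FLAGS must be set before TensorFlow initializes XLA.
    # --xla_disable_hlo_passes takes a comma-separated list of pass names;
    # the two names here are just illustrative.
    os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=cse,tuple-simplifier"

    import tensorflow as tf

    tf.profiler.experimental.start("/tmp/profile_logdir")
    for _ in range(10):      # roughly the 10 steps described above
        train_step()         # placeholder for the actual training step
    tf.profiler.experimental.stop()
    # The average step time is then read from the TensorBoard profiler overview.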

Observation: the pass "hlo-get-dimension-size-rewriter" appears important to keep enabled: disabling it increased the average step time by approximately 19.2 ms over the default state.

Please advise on the functionality of the listed passes so that I can better understand their usage.

I have also attached a snapshot of the TensorBoard profiler (with no pass disabled), just for reference. I observe that "Kernel Launch Time" and "Device Compute Time" are among the components that vary most compared to the other profiled results. Could you shed some light on what factors determine these times?

Looking forward to your reply.

Best,
Raghav



result.xlsx
profiler_snapshot.JPG

Raghav Garg

Oct 19, 2020, 5:37:15 AM
to Adrian Kuegel, xla...@googlegroups.com
Hi,

Even if you can't provide information for all the passes, describing even a few of them would help my case for now.

I am working with the 20 passes below (out of the 51 available) and trying to understand their behaviour, so please let me know if you can provide a brief explanation of them. (The sketch after the list shows how I enumerate which passes actually run.)

1. reduction-degenerate-dim-remover
2. cusolver-rewriter
3. simplify-sorts
4. cse
5. simplify-while-loops
6. reduction-layout-normalizer
7. stable-sort-expander
8. variadic-op-splitter
9. convolution_4d_expander
10. tuple-simplifier
11. all-reduce-combiner
12. simplify-conditional
13. simplification
14. fusion_merger
15. cublas-gemm-pad-for-speed
16. flatten-call-graph
17. cudnn_pad_for_convolutions
18. constant_folding
19. multi_output_fusion
20. hlo-get-dimension-size-rewriter
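
For reference, this is how I enumerate the passes (a sketch; the dump directory is arbitrary):

    import os

    # Standard XLA dump flags, set before TensorFlow starts. With
    # --xla_dump_hlo_pass_re matching every pass, XLA writes per-pass
    # HLO dumps whose filenames contain the pass names.
    os.environ["XLA_FLAGS"] = (
        "--xla_dump_to=/tmp/xla_dump "
        "--xla_dump_hlo_pass_re=.*"
    )

    import tensorflow as tf  # dumps appear once an XLA cluster compiles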

Looking forward to your input.

On Mon, Oct 19, 2020 at 4:30 AM Adrian Kuegel <aku...@google.com> wrote:
I think nobody can answer this question without looking at the code of each pass to see what it does. As you have noticed, there are quite a few passes, and I doubt anyone knows all of them by heart. In general, the passes are there for a reason, so disabling them usually doesn't make much sense. Some are even necessary for correctness; without them the model will fail when you run it.




--
Best,
Raghav

grzpaw...@gmail.com

Oct 19, 2020, 5:48:55 AM
to XLA development
Hey,

You should be able to go through that list and match the passes to the header files here: https://github.com/graphcore/tensorflow/tree/c595ee360fc25c7fc59cec187bc43b088fbc15dc/tensorflow/compiler/xla/service; the headers should describe what each pass does.
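
You could even do the matching mechanically; a rough sketch (assuming a local checkout of that repository, and that a pass's reported name appears verbatim in its sources):

    import pathlib

    # Assumes a local checkout; adjust the path to your tree.
    SERVICE_DIR = pathlib.Path("tensorflow/compiler/xla/service")

    def find_pass_sources(pass_name):
        """Return headers that mention the given pass name string.

        Each HLO pass reports its name via HloPassInterface::name(), and
        that string usually appears verbatim in the pass's header or .cc.
        """
        return [p for p in SERVICE_DIR.rglob("*.h")
                if pass_name in p.read_text(errors="ignore")]

    print(find_pass_sources("tuple-simplifier"))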

Raghav Garg

Oct 19, 2020, 8:49:38 PM
to XLA development
Hi,

Thank you for providing this information.

As you suggested, I tried matching the passes to the header files and found most of them. However, there are still some passes I couldn't directly find information about, such as "reduction-degenerate-dim-remover", "convolution_4d_expander", and "simplification". Please point me to any details or links I could refer to.

Also, I observed that some passes which did not appear when I dumped the pipeline in my default model run state, such as "dynamic_padder" and "reduction-splitter", now show up. Conversely, some passes that previously existed are no longer dumped, such as "convolution-group-converter", "depthwise-convolution-converter", and "reshape-mover". Could you please shed some light on this?

Best,
Raghav

Adrian Kuegel

Oct 20, 2020, 2:43:36 AM
to Raghav Garg, XLA development
Some passes are also in the gpu subdirectory (if they are GPU-specific); convolution_4d_expander is one of these. It is actually one I know about, since I wrote it: it is only useful if your model has convolutions with 4 spatial dimensions, at least one of which is a trivial dimension of size 1. In that case it removes such dimensions so that the convolution works with cuDNN (which supports at most 3 spatial dimensions).
"simplification" is the algebraic simplifier; it is a very important pass and should always be used. An easier way to find out which pass is which is to look at the header includes of gpu_compiler.cc.
"convolution-group-converter" and "depthwise-convolution-converter" are correctness passes, if you have any depthwise or group convolutions, 
Some nvidia specific passes are also registered here:
But in theory the only thing you can save by not running passes is compile time. Runtime should usually be better with the passes enabled, unless the model somehow triggers an unlucky case.
When you measure time, are you measuring compile time as well? That would not make much sense: XLA is known not to be very fast at compiling; it shines by optimizing runtime. So usually you would compile once, then run lots of iterations while training your model.
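
A minimal way to keep compile time out of the measurement (a rough sketch, assuming the first call to an XLA-compiled function pays the compile cost; jit_compile was named experimental_compile in TF 2.3-era releases):

    import time
    import tensorflow as tf

    @tf.function(jit_compile=True)  # experimental_compile=True on older TF
    def step(x):
        # Stand-in for a real training step.
        return tf.reduce_sum(tf.square(x))

    x = tf.random.normal([1024, 1024])

    t0 = time.perf_counter()
    step(x).numpy()                  # first call: traces, XLA-compiles, runs
    first_call = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(100):             # steady state: no recompilation
        step(x).numpy()              # .numpy() forces device synchronization
    avg_step = (time.perf_counter() - t0) / 100
    print(f"first call (incl. compile): {first_call:.3f} s, "
          f"avg step: {avg_step * 1e3:.3f} ms")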

Raghav Garg

Oct 20, 2020, 3:37:55 AM
to Adrian Kuegel, XLA development
Hi Adrian,

Thank you. That was very informative and helpful.
--
Best,
Raghav

Raghav Garg

Oct 24, 2020, 7:13:11 PM
to Adrian Kuegel, XLA development
Hi,

I have a follow-up question. Although this question was asked earlier, I don't believe I received an update on it.

I have attached a snapshot of the TensorBoard profiler (with no optimization pass disabled). I observe that "Kernel Launch Time" and "Device Compute Time" are among the components of the step time that vary most across the profiled results as I enable and disable passes. Could you shed some light on what factors determine these times?

Looking forward to hearing from you.
--
Best,
Raghav
profiler_snapshot.JPG