How to dump MLIR after every optimization phase during TFLite conversion


JPR

unread,
Apr 24, 2020, 5:42:30 PM4/24/20
to MLIR
Hi,

I posted this question on the TFLite group but didn't get any response, so I thought this might be a more appropriate group to ask. Thanks!

I built TensorFlow 2.1.0 from source. For the following model, I would like to dump the sequence of MLIR optimizations and transformations that happen during TensorFlow-to-TFLite conversion when tf.lite.TFLiteConverter is called.

I understand that mlir-opt is responsible for doing the MLIR optimizations, but I don't know how mlir-opt gets invoked when we make the tf.lite.TFLiteConverter call. Could you please provide the steps to dump the MLIR passes for this program?


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

SAVE_AND_LOAD_MODEL = True


class MultilayerLinear(keras.Model):
    def __init__(self, hidden_size, num_layers):
        super().__init__()
        self.lin_layers = []
        for _ in range(num_layers):
            self.lin_layers.append(layers.Dense(hidden_size, activation='relu'))

    @tf.function
    def call(self, x):
        for layer in self.lin_layers:
            x = layer(x)
        return x


hidden_size = 64
input_size = 128
num_layers = 1
model = MultilayerLinear(hidden_size, num_layers)
x = tf.random.normal((8, input_size))
y = model(x)
model.build((None, input_size))

if SAVE_AND_LOAD_MODEL:
    export_dir = "./tflite_conversion_test"
    tf.saved_model.save(model, export_dir)
    model = tf.saved_model.load(export_dir)

# Trace a concrete function from an input signature.
concrete_function = model.call.get_concrete_function(
    tf.TensorSpec(shape=(None, input_size), dtype=tf.float32))
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
tflite_model = converter.convert()


One last question: when I built TensorFlow for CPU, I didn't see an mlir-opt binary after the build. When I built MLIR separately, I was able to see mlir-opt. Do you know in which directory the mlir-opt binary gets generated during the TensorFlow build?

Thanks a lot!!

Mehdi AMINI

unread,
Apr 26, 2020, 12:51:18 AM4/26/20
to JPR, MLIR, ash...@google.com, Jacques Pienaar
Hi,

On Fri, Apr 24, 2020 at 2:42 PM JPR <rohi...@gmail.com> wrote:
Hi,

I posted this question on TFLite group but didn't get any response. So thought this might be a more appropriate group to ask this question. Thanks!

I built tensorflow 2.1.0 from source. For the following model, I would like to dump the sequence of MLIR optimizations and transformations that happen during TensorFlow to TFLite conversion when tf.lite.TFLiteConverter is called . 

I don't think there is such a debugging option at the Python level; this seems like a good feature request though!
Adding Ashwin and Jacques in case I missed the options, and to advise on the best way to achieve this right now.


I understand that mlir-opt is responsible for doing the MLIR optimizations but I don't know how mlir-opt gets invoked while we make tf.lite.TFLiteConverter call. Could you please provide me the steps to dump the MLIR passes for this program?


mlir-opt is a testing tool for invoking passes/pipelines on MLIR files. It is also specific to upstream MLIR (in the LLVM repo) and does not include any TensorFlow-specific aspects.
The equivalent of mlir-opt in TensorFlow is tf-opt, which you can build as //tensorflow/compiler/mlir:tf-opt; it is really a superset of mlir-opt: it has the same functionality but also includes all of the TF/TFLite support.
Note that tf-opt is still a testing tool: it isn't directly invoked by TFLite or the converter. Instead the converter invokes a pass pipeline directly through the C++ API. It is possible to reproduce this pipeline with `tf-opt`, though in the case of TFLite I don't know if the invocation is straightforward.
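For reference, here is how you might build and inspect tf-opt yourself, using the Bazel target named above (the bazel-bin path is the usual layout; exact build flags may vary with your setup):

```shell
# Build the tf-opt testing tool (superset of mlir-opt with TF/TFLite support).
bazel build -c opt //tensorflow/compiler/mlir:tf-opt

# List the registered passes and flags it accepts, mirroring mlir-opt's CLI.
./bazel-bin/tensorflow/compiler/mlir/tf-opt --help
```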

If you are interested in the details, the entry point is ConvertMLIRToTFLiteFlatBuffer, which populates the MLIR passes by calling AddTFToTFLConversionPasses.
Ideally this should be turned into an MLIR "pass pipeline" so that it could be called directly from `tf-opt`. Right now on the command line it is wired into the tf_tfl_translate tool, which you can invoke with //tensorflow/compiler/mlir/lite:tf_tfl_translate.

There are tests under tensorflow/compiler/mlir/lite/tests/ that you can take as examples, for example: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt

Here is how to invoke the translator to get a TFLite flatbuffer:

bazel run -c opt //tensorflow/compiler/mlir/lite:tf_tfl_translate -- -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add `pwd`/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt -o output.lite

You can add `--print-ir-after-all` and other options similar to what `mlir-opt` accepts to get the IR printed after each pass (you'll need to sync to a recent version of the repository to get these options though).
`mlir-opt` isn't useful for anything in TensorFlow, which is why it does not get built by default in TensorFlow. The tests in TF use tf-opt though, which should get built automatically.
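Putting the two together, the invocation above with IR printing enabled might look like this (the flag is the one named above and requires a recent checkout; MLIR pass printing goes to stderr, which is redirected to a file here):

```shell
# Run the TF->TFLite translator and dump the IR after each pass.
bazel run -c opt //tensorflow/compiler/mlir/lite:tf_tfl_translate -- \
  -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 \
  -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add \
  `pwd`/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt \
  -o output.lite --print-ir-after-all 2> ir_after_each_pass.mlir
```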


-- 
Mehdi

JPR

unread,
Apr 26, 2020, 1:50:00 AM4/26/20
to MLIR
Hi Mehdi,

Thanks a lot for letting me know how MLIR is integrated with TFLite. Really useful information. 

Just need one clarification. I understand that we cannot dump MLIR from Python source code.

[Mehdi] You can add `--print-ir-after-all` and other options similar to what `mlir-opt` accepts to get the IR printed after each pass (you'll need to sync to a recent version of the repository to get these options though).

Can I use `--print-ir-after-all` and other MLIR IR dump options in the bazel build command? (That is, options MLIR would recognize while converting TF->TFLite when tf.lite.TFLiteConverter is called, so the IR gets dumped after every pass?)

Thanks and Regards,
JPR

Mehdi AMINI

unread,
Apr 26, 2020, 1:53:25 AM4/26/20
to JPR, MLIR
On Sat, Apr 25, 2020 at 10:50 PM JPR <rohi...@gmail.com> wrote:
Hi Mehdi,

Thanks a lot for letting me know how MLIR is integrated with TFLite. Really useful information. 

Just need one clarification. I understand that we cannot dump MLIR from Python source code.

[Mehdi] You can add `--print-ir-after-all` and other options similar to what `mlir-opt` accepts to get the IR printed after each pass (you'll need to sync to a recent version of the repository to get these options though).

Can I use  `--print-ir-after-all` and other MLIR IR dump options in the bazel build command ? ( that MLIR could recognize while it is converting TF->TFLite when tf.lite.TFLiteConverter is called and dump the IR after every pass?)

Yes: I linked the commit that enables it in my previous email; with it you can just append the flag to the Bazel invocation I gave.

