Error when speeding up Custom Layers in Keras with @tf.function


Niclas Danielsson

Aug 22, 2019, 6:10:37 PM
to devel...@tensorflow.org

Hi,


I might have come across another bug.

I can't find any references on whether one would want to put a @tf.function decorator on the call function of a custom layer to speed up inference.

However, if I use the functional API for pure inference, it is not clear whether the decorator is applied automatically or not (i.e., whether the CNN would be executed in graph mode).

Colab example is included here:

https://colab.research.google.com/drive/1WTuizL6U7scO9qmnsqGMcuSG8C2rCmQ4


In that example I constructed the mini-example below, which accepts the decorator without any problem:


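import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.layers import BatchNormalization, Conv2D


class CustomConv(layers.Layer):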

    def __init__(self):
        super(CustomConv, self).__init__()
        
        self.conv = Conv2D(filters=32,
                            kernel_size=1,
                            strides=1,
                            padding="same",
                            data_format="channels_last",
                            use_bias=True)
        
    @tf.function
    def call(self, x):
        return self.conv(x)

      
def LittleConvModel(input_shape=None):

    if input_shape is None:
        input_shape = (416, 416, 3)

    input_img = layers.Input(shape=input_shape)
    x = CustomConv()(input_img)  
    model = models.Model(input_img, x)
    return model


model = LittleConvModel()


However, the example below, where I just swap the convolution for a batchnorm layer, fails with the error:

ValueError: Trying to capture a tensor from an inner function. This can be caused by accessing a tensor defined inside a loop or conditional body, or a subfunction, from a calling function, without going through the proper return value mechanism. Consider using TensorFlow mechanisms such as TensorArrays to return tensors from inner functions or loop / conditional bodies. Tensor: Tensor("batch_normalization/batch_normalization_trainable:0", dtype=bool); tensor graph: FuncGraph(name=call, id=140151478939432); this graph: FuncGraph(name=keras_graph, id=140151584677840)


class CustomBatchNorm(layers.Layer):

    def __init__(self):
        super(CustomBatchNorm, self).__init__()
        
        self.bn = BatchNormalization(epsilon=1e-4)
        
    @tf.function
    def call(self, x):
        return self.bn(x)

      
def LittleBatchNormModel(input_shape=None):

    if input_shape is None:
        input_shape = (416, 416, 3)

    input_img = layers.Input(shape=input_shape)
    x = CustomBatchNorm()(input_img)  
    model = models.Model(input_img, x)
    return model


model = LittleBatchNormModel()


I don't see any reason why this would fail in one case but not the other. So I have two questions:

1: Is this a bug in the batchnorm layer definition that needs to be fixed?

2: Does it make sense to use the @tf.function decorator this way at all? Or how else would you make sure that a simple inference pass like the one below is accelerated by being executed as a graph?


import numpy as np

img = np.random.randint(255, size=(224, 224, 3))
img = img.astype(np.float32)
img_exp = np.expand_dims(img, axis=0)
result = model(img_exp)


BR,

/Niclas


Martin Wicke

Aug 22, 2019, 6:40:50 PM
to Niclas Danielsson, Kibeom Kim, devel...@tensorflow.org

--
You received this message because you are subscribed to the Google Groups "TensorFlow Developers" group.
To unsubscribe from this group and stop receiving emails from it, send an email to developers+...@tensorflow.org.
To view this discussion on the web visit https://groups.google.com/a/tensorflow.org/d/msgid/developers/1566511830792.67501%40axis.com.

Francois Chollet

Aug 22, 2019, 7:28:04 PM
to Martin Wicke, Niclas Danielsson, Kibeom Kim, devel...@tensorflow.org
The solution here is that your `call` method should look like `call(self, inputs, training=False)`, and every time you use your layer, you should pass a value for the `training` argument.

Otherwise the BN layer receives the global learning phase tensor as argument (if you don't specify anything) and tf.function doesn't like that.
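A minimal sketch of the signature described above, assuming TensorFlow 2.x with `tf.keras`: the custom layer takes a `training` argument and forwards it to the wrapped `BatchNormalization`, and the caller passes it explicitly. `@tf.function` is deliberately left off `call`, since later in this thread the decorated version still failed on 2.0.0-beta1. The 8x8x3 input shape is an arbitrary choice for illustration.

```python
import numpy as np
import tensorflow as tf


class CustomBatchNorm(tf.keras.layers.Layer):
    """Wraps BatchNormalization and forwards the `training` flag explicitly."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.bn = tf.keras.layers.BatchNormalization(epsilon=1e-4)

    def call(self, inputs, training=False):
        # Pass the flag through instead of relying on a global learning phase.
        return self.bn(inputs, training=training)


# Small functional model; the 8x8x3 shape is arbitrary.
inputs = tf.keras.Input(shape=(8, 8, 3))
outputs = CustomBatchNorm()(inputs)
model = tf.keras.Model(inputs, outputs)

x = np.random.rand(2, 8, 8, 3).astype(np.float32)
# Explicit inference mode: BN normalizes with its moving statistics.
y = model(x, training=False).numpy()
```

With `training=False`, a freshly initialized layer uses its initial moving statistics (mean 0, variance 1), so the output is just the input scaled by 1/sqrt(1 + epsilon).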


> Does it make sense to use the @tf.function decorator this way at all? Or how would you else make sure a simple inference pass like the one below is accelerated by being executed as a graph?

In general you should never add this decorator on a `call` method, it is more likely to slow down your code than to speed it up. You should add a tf.function decorator on your end-to-end training step (the function that computes the forward pass and backwards pass for your entire model in one go), not on individual layer calls. And if you're using `fit` you are not expected to decorate anything.
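A sketch of what decorating the end-to-end training step (rather than individual `call` methods) can look like, assuming TF 2.x with `tf.keras`. The linear model, the toy data (y = 2x), and the SGD settings are made up for illustration.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Toy regression data, assumed for illustration: y = 2 * x.
x_train = np.linspace(-1.0, 1.0, 64, dtype=np.float32).reshape(-1, 1)
y_train = 2.0 * x_train

model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)


@tf.function  # one graph for forward pass, backward pass, and update together
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


initial_loss = float(loss_fn(y_train, model(x_train, training=False)))
for _ in range(100):
    final_loss = float(train_step(x_train, y_train))
```

The layers inside the model stay undecorated; the single `tf.function` on `train_step` is what turns the whole step into a graph.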


Niclas Danielsson

Aug 23, 2019, 3:34:48 AM
to Francois Chollet, Martin Wicke, Kibeom Kim, devel...@tensorflow.org
I see. The strange thing is that this happened when I tried to convert a model I had defined completely in the Subclassing API to a model using the Functional API in order to benchmark performance.

It seems strange because in that subclassed model I had wrapped the top-level model's call function with a @tf.function decorator, and the model contained batchnorm layers as components inside custom layers (with two levels of nested custom layer definitions), yet I got no error from having forgotten the training parameter (I only used the model for inference).

So this post naturally raises some follow-up questions (numbered below):

1: Why do you get an error when you put the decorator on a custom layer that directly wraps the batchnorm, but not when you put it on a top-level model class's call function?

Also, if you do not put the decorator on the call function, what is the best practice when you just want to run a single inference (just a model(input) call) optimized so that it is executed as a graph?

2: Do I need to wrap it like this?
@tf.function
def my_inference_pass(model, input):
    return model(input)

I might want to do post-processing, so in this case I do not want to convert to tensors.
3: But if I use the predict function, does it apply this decorator automatically? I guess that is one of the purposes of the predict function?
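A sketch comparing the two routes in questions 2 and 3, assuming TF 2.x with `tf.keras`. The wrapped function here closes over the model instead of taking it as an argument, which keeps the trace cache keyed only on the input tensor. `model.predict` runs its loop through a traced function by default (unless the model is compiled with `run_eagerly=True`), batches the input internally, and returns NumPy arrays, whereas the wrapped call returns a tensor. The toy model and the name `run_inference` are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Toy model, assumed for illustration only.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="relu"),
    tf.keras.layers.BatchNormalization(),
])


@tf.function
def run_inference(inputs):
    # Traced on first use; later calls with compatible inputs reuse the graph.
    return model(inputs, training=False)


x = np.random.rand(5, 4).astype(np.float32)
graph_out = run_inference(x).numpy()        # tensor, converted to NumPy explicitly
predict_out = model.predict(x, verbose=0)   # already a NumPy array
```

Both paths call the model in inference mode, so BatchNormalization uses its moving statistics and the two results agree.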


I found an example that indeed does what you say about the training parameter here:

But the relevance of this training parameter is very easy to miss, since the tutorial does not explain at all why it is needed, and it is not passed to the other layers, so it is the exception.


More importantly, your comment: "Otherwise the BN layer receives the global learning phase tensor as argument (if you don't specify anything) and tf.function doesn't like that."

does explain why things go wrong, though not why tf.function does not like the "global learning phase tensor". 


4: I guess this is intentional (though I don't see why)? However, the global learning phase tensor seems useless if it cannot be applied universally. Or does that tensor exist for some other purpose?


5: Either way, is this global learning phase tensor documented anywhere (including when and how you would use it)? It seems important to understand.


Sorry if this gets long-winded, but I ask for several reasons: partly to understand it myself, of course, but more importantly because I think these things need to be documented, so I hope this can serve as input for improving the documentation for others and help TensorFlow adoption as a whole. :-)

(I am also about to give an internal training on TF 2.0 soon, and this forces me to turn over every stone and think about every way you might use the framework, as well as the ways you should NOT use it.)


I believe that now, when graph optimizations mostly happen automatically and partly opaquely in the background, it is still very important to understand exactly what those mechanisms are and how to use them correctly, because any serious implementation should obviously be optimized for maximum performance when deployed to production (and ideally during training as well).


 Thanks in advance,

/Niclas


 



Niclas Danielsson

Aug 23, 2019, 4:46:10 AM
to Francois Chollet, Martin Wicke, Kibeom Kim, devel...@tensorflow.org

One more thing...


Regarding the comment: "In general you should never add this decorator on a `call` method, "


In the following description of saving models (see "Exporting from a function without a fixed signature:"), the example seems to refer to a subclassed model and explicitly wraps the call function with @tf.function, again without any explanation of whether this is really needed for the export.

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/saved_model/save


class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
        ...

m = Model()
tf.saved_model.save(
    m, '/tmp/saved_model/',
    signatures=m.call.get_concrete_function(
        tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))

Is this a special case only for exporting models, or is it a mistake in the docs, and should the tf.function decorator be removed?
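In the documented example, `@tf.function` is what makes `m.call.get_concrete_function(...)` available at all, since a plain Python method has no `get_concrete_function`. A sketch of an alternative that leaves `call` alone, assuming TF 2.x: put a `tf.function` with a fixed `input_signature` on a separate serving method and export that. It uses a plain `tf.Module` to sidestep differences in how Keras versions save models; the `Scaler` class and `serve` method are hypothetical.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module: multiplies inputs by a trainable scalar."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    # A fixed input_signature yields one concrete function, which export can use directly.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")])
    def serve(self, x):
        return x * self.w


module = Scaler()
export_dir = os.path.join(tempfile.mkdtemp(), "saved_model")
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.serve})

# Reload and call the exported function; the batch dimension stays flexible.
restored = tf.saved_model.load(export_dir)
result = restored.serve(tf.ones([2, 3]))
```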
BR,
/Niclas




Niclas Danielsson

Aug 23, 2019, 8:08:27 AM
to Francois Chollet, Martin Wicke, Kibeom Kim, devel...@tensorflow.org

I will once more return to the statement:

"The solution here is that your `call` method should look like `call(self, inputs, training=False)`, and every time you use your layer, you should pass a value for the `training` argument. "


This does not seem to make any difference...


In the updated colab here:

https://colab.research.google.com/drive/1WTuizL6U7scO9qmnsqGMcuSG8C2rCmQ4

I show that adding the training parameter to the call function does not help, because the error occurs while the model itself is being constructed, not during inference. So it is really NOT possible to add the decorator to the call function if the function contains a batchnorm layer.


Furthermore, if I INSTEAD add the decorator to a function that wraps the inference pass on an already created model, there is no error, and contrary to what you suggested, I DO NOT need to add the training parameter at all (at least it does not generate an error).


That is, I wrap the inference pass like this:

@tf.function
def inference_pass(model, image):
    return model(image)
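A sketch of the wrapped inference pass, assuming TF 2.x with `tf.keras`: passing `training=False` explicitly avoids depending on the default, and a fixed `input_signature` with a `None` batch dimension keeps a single traced graph for every batch size. The `trace_count` counter is a hypothetical instrumentation helper, and the small Dense model stands in for the real one.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])
# Warm-up call so all weights exist before tracing.
model(np.zeros((1, 4), dtype=np.float32))

trace_count = 0


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
def inference_pass(image):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing, not on every call
    return model(image, training=False)


out_small = inference_pass(np.zeros((2, 4), dtype=np.float32))
out_large = inference_pass(np.zeros((5, 4), dtype=np.float32))
```

Without the `input_signature`, a new batch size could trigger a retrace; with it, both calls above reuse the same graph.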


Of course this is more in line with your recommendation, but it seems wrong that it should fail if I decorate the call function. There is no obvious reason for this (and, as I showed in the previous mail, such examples exist in the documentation).


What should I make of this?


Note that this is with the beta1 release:

!pip install -q tensorflow-gpu==2.0.0-beta1

BR,

/Niclas


The code that fails is included below:

------------------------------------------------------------------

# This setup still fails in graph mode, even with the training parameter added


class CustomBatchNorm(layers.Layer):

    def __init__(self):
        super(CustomBatchNorm, self).__init__()
        
        self.bn = BatchNormalization(epsilon=1e-4)
        
    @tf.function
    def call(self, x, training=False):
        return self.bn(x, training=training)


def LittleBatchNormModel(input_shape=None):

    if input_shape is None:
        input_shape = (416, 416, 3)

    input_img = layers.Input(shape=input_shape)
    x = CustomBatchNorm()(input_img)  
    model = models.Model(input_img, x)
    return model


model = LittleBatchNormModel()






Francois Chollet

Aug 23, 2019, 1:49:40 PM
to Niclas Danielsson, Martin Wicke, Kibeom Kim, devel...@tensorflow.org
This seems like a bug, I will file an issue and we will look into it.

Niclas Danielsson

Aug 24, 2019, 5:20:29 AM
to Francois Chollet, Martin Wicke, Kibeom Kim, devel...@tensorflow.org
Great!

BR,
/Niclas


Sebastian Taciak

Jun 14, 2020, 11:29:28 AM
to TensorFlow Developers, wi...@google.com, niclas.d...@axis.com, kkim...@google.com
Hi Francois, 

When you say "And if you're using `fit` you are not expected to decorate anything", does it mean that the .fit() method already includes all the speedups you can expect from tf.function()? I am working with the standard keras.layers and wonder whether wrapping them in tf.function can give any speedup, especially when we provide the signatures. Could you please comment on that and, if possible, add a discussion of this point to the TF guide or a similar place?

Best,

Sebastian

Francois Chollet

Jun 15, 2020, 1:09:31 AM
to Sebastian Taciak, TensorFlow Developers, Martin Wicke, Niclas Danielsson, kkim...@google.com
> when you say " And if you're using `fit` you are not expected to decorate anything" does it mean that applying the .fit() method already contains all speedups you can expect from tf.function(). I am working the standard keras.layers and wonder if wrapping them up in tf.function can give any speed up especially when we provide the signatures. Could you comment on that please and if possible add discussion of that point in tf guide or similar place.

Yes, calling `fit` by default will use a compiled graph. If you train via `fit` you don't need to apply any tf.function manually.

Note that you can also make `fit` run eagerly by passing `run_eagerly=True` to `compile()`.
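A sketch of the two modes described above, assuming TF 2.x with `tf.keras`: the default `compile()`/`fit()` path, and the same kind of model compiled with `run_eagerly=True` for debugging. The toy data and the `make_model` helper are made up for illustration.

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(32, 3).astype(np.float32)
y = x.sum(axis=1, keepdims=True)


def make_model():
    # Hypothetical helper: a tiny linear model.
    return tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])


# Default: fit() runs its train step as a compiled graph; no manual tf.function needed.
graph_model = make_model()
graph_model.compile(optimizer="sgd", loss="mse")
graph_history = graph_model.fit(x, y, epochs=2, verbose=0)

# Debugging: run_eagerly=True makes fit() execute the step eagerly, op by op.
eager_model = make_model()
eager_model.compile(optimizer="sgd", loss="mse", run_eagerly=True)
eager_history = eager_model.fit(x, y, epochs=2, verbose=0)
```

The eager mode is mainly useful for stepping through the model with a debugger or print statements; for speed, the default graph mode is the one to keep.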

Best,

Francois
