tpu distributed strategy

Rohan Mahajan

May 2, 2021, 6:46:03 PM
to TPU Users

I had some questions about distribute strategy custom training.

  1. "If you are using regularization losses in your model then you need to scale the loss value by number of replicas. You can do this by using the tf.nn.scale_regularization_loss function."

Why are regularization losses treated differently from other types of losses?

  2. "Using tf.reduce_mean is not recommended. Doing so divides the loss by actual per replica batch size which may vary step to step."

How does the per-replica batch size vary from step to step? Also, if the batch size changes from step to step, I am not sure why we would not use tf.reduce_mean instead of dividing by the global batch size.
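
For reference, the arithmetic behind the second quote can be sketched in plain Python. This is an illustration only, not TensorFlow code: the helper names and the numbers are made up, and it assumes that per-replica results are combined by summation and that the last batch of an epoch can be smaller than the others, which is one common way the per-replica batch size ends up varying.

```python
# Plain-Python illustration (hypothetical helper names, made-up loss values).
# Assumption: results from the replicas are combined by summing them.

GLOBAL_BATCH_SIZE = 4  # assumed: 2 replicas x 2 examples per replica


def summed_replica_means(replica_losses):
    """Each replica takes the mean of its own examples; the means are summed."""
    return sum(sum(losses) / len(losses) for losses in replica_losses)


def summed_global_scaled(replica_losses, global_batch_size):
    """Each replica divides its summed loss by the global batch size; results are summed."""
    return sum(sum(losses) / global_batch_size for losses in replica_losses)


# A full step: every replica sees 2 examples.
full_step = [[1.0, 3.0], [2.0, 6.0]]
# A partial step (e.g. the end of an epoch): replica 1 sees only 1 example.
partial_step = [[1.0, 3.0], [2.0]]

# Full step: global scaling gives the true mean over 4 examples (12 / 4).
full_global = summed_global_scaled(full_step, GLOBAL_BATCH_SIZE)
# Full step: summing per-replica means is too large by the number of replicas.
full_means = summed_replica_means(full_step)

# Partial step: with global scaling every example is weighted 1/4.
partial_global = summed_global_scaled(partial_step, GLOBAL_BATCH_SIZE)
# Partial step: with per-replica means the lone example on replica 1 is
# weighted 1, while the examples on replica 0 are weighted 1/2 each.
partial_means = summed_replica_means(partial_step)
```

In this sketch the weight of each example stays fixed at 1 / GLOBAL_BATCH_SIZE under global scaling, whereas under per-replica means it depends on how many examples happened to land on the same replica in that step.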

Rohan Mahajan

May 3, 2021, 1:44:08 PM
to Russell Power, TPU Users
Hi Russell,
Thanks for your response. I am not a Google employee (only a Google Cloud user), so I can't read the source code link above.

When you say the default losses are already hacked, is that the case when we use reduction=sum in those losses? Is that the recommended behavior?
The passage below, from the same page, suggests we would have to make some changes if we use reduction=none.
"If labels is multi-dimensional, then average the per_example_loss across the number of elements in each sample. For example, if the shape of predictions is (batch_size, H, W, n_classes) and labels is (batch_size, H, W), you will need to update per_example_loss like:

    per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)

Caution: Verify the shape of your loss. Loss functions in tf.losses/tf.keras.losses typically return the average over the last dimension of the input. The loss classes wrap these functions. Passing reduction=Reduction.NONE when creating an instance of a loss class means "no additional reduction". For categorical losses with an example input shape of [batch, W, H, n_classes] the n_classes dimension is reduced. For pointwise losses like losses.mean_squared_error or losses.binary_crossentropy include a dummy axis so that [batch, W, H, 1] is reduced to [batch, W, H]. Without the dummy axis [batch, W, H] will be incorrectly reduced to [batch, W]."
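
The shape bookkeeping in the quoted passage can be illustrated without TensorFlow. The sketch below is plain Python over nested lists: the function names, the (batch, H, W) values, and the global batch size are all assumptions chosen for illustration. It averages the per-element losses within each sample, so there is one loss value per example, and only then divides by the global batch size.

```python
# Plain-Python illustration (hypothetical names, made-up values).
# pixel_losses has shape (batch, H, W): one loss value per label element.


def per_example_loss(pixel_losses):
    """Average the element-wise losses inside each sample -> shape (batch,)."""
    result = []
    for sample in pixel_losses:
        flat = [value for row in sample for value in row]
        result.append(sum(flat) / len(flat))  # divide by H * W
    return result


def scaled_step_loss(pixel_losses, global_batch_size):
    """Sum the per-example losses and divide by the global batch size."""
    return sum(per_example_loss(pixel_losses)) / global_batch_size


# One replica's share of the batch: 2 samples, each with H = 2 and W = 2.
replica_pixel_losses = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [4.0, 4.0]],
]

ASSUMED_GLOBAL_BATCH_SIZE = 4  # assumed: 2 replicas x 2 samples each

example_losses = per_example_loss(replica_pixel_losses)
replica_loss = scaled_step_loss(replica_pixel_losses, ASSUMED_GLOBAL_BATCH_SIZE)
```

If the division by H * W were skipped, each sample's contribution would scale with the image size rather than counting as one example.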

On Mon, May 3, 2021 at 8:36 AM Russell Power <> wrote:
DS doesn't override the behavior for reduce_mean (or other TF operations), so it doesn't do what you expect: instead of computing the global mean across all replicas, you get the per-replica mean value. There's explicit code in the optimizers/losses to perform global reductions.

The default losses are already "hacked" to handle global reductions :/
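
The first question in the thread (why regularization losses need tf.nn.scale_regularization_loss) follows from the same reasoning. A minimal plain-Python sketch, assuming that every replica holds an identical copy of the weights and that per-replica losses are summed across replicas; the function names and numbers are hypothetical:

```python
# Plain-Python illustration (hypothetical names, made-up values).
# Assumptions: every replica holds the same weights, so each one computes the
# same regularization term, and per-replica losses are summed across replicas.

NUM_REPLICAS = 4
REGULARIZATION_TERM = 0.5  # e.g. an L2 penalty; identical on every replica


def summed_unscaled(term, num_replicas):
    """Every replica adds the full penalty, so the sum counts it num_replicas times."""
    return sum(term for _ in range(num_replicas))


def summed_scaled(term, num_replicas):
    """Every replica adds term / num_replicas, so the sum counts the penalty once."""
    return sum(term / num_replicas for _ in range(num_replicas))


unscaled_total = summed_unscaled(REGULARIZATION_TERM, NUM_REPLICAS)
scaled_total = summed_scaled(REGULARIZATION_TERM, NUM_REPLICAS)
```

The data loss differs per replica because each replica sees different examples, so it is divided by the global batch size. The regularization term depends only on the weights, so it is divided by the number of replicas instead.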
