Hello TFP community,
I am converting my deterministic neural network with 4 dense layers to a BNN using DenseVariational.
I noticed that it works well when 1 of the 4 layers is variational. Starting at 2/4 layers, performance degrades quickly, and with 4/4 layers (all parameters are distributions) the model just produces random noise.
Using Stochastic Weight Averaging (SWA, https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/SWA) improves training in general: with it, 2/4 layers work well, but 3/4 and 4/4 are still badly trained.
SWA is a simple trick that sometimes works.
Do you have any other tricks you can share that improve the training of deep BNNs?
Thank you very much.