Hi everyone, I am trying to train a fully connected neural network via variational inference using DenseFlipout layers, but once the network seems to have reached convergence, huge spikes appear in the training loss plot. On Stack Overflow I was advised to use gradient clipping (I added clipnorm=1.0 to the parameters of the SGD optimizer), but it doesn't seem to help.
My model is quite simple, although slightly overparameterized. I tried lowering the number of neurons per layer, but then the network no longer fits the data well.
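For reference, this is how I enabled the clipping (the learning rate here is just a placeholder; clipnorm=1.0 is the value from the suggestion):

```python
import tensorflow as tf

# Clip the global gradient norm at 1.0, as suggested; the learning rate
# value is illustrative, not the one I tuned.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, clipnorm=1.0)
```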
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Input, Model

tfd = tfp.distributions


def create_flipout_bnn_model(train_size):
    def normal_sp(params):
        # First column is the mean; the second parameterizes the scale via a
        # softplus, with a small floor to keep it strictly positive.
        return tfd.Normal(loc=params[:, 0:1],
                          scale=1e-3 + tf.math.softplus(0.05 * params[:, 1:2]))

    # Scale each layer's KL term by the dataset size so the total loss is the
    # per-example ELBO.
    kernel_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / train_size
    bias_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / train_size

    inputs = Input(shape=(1,), name="input_layer")
    hidden = tfp.layers.DenseFlipout(20,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     bias_divergence_fn=bias_divergence_fn,
                                     activation="relu",
                                     name="DenseFlipout_layer_1")(inputs)
    hidden = tfp.layers.DenseFlipout(20,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     bias_divergence_fn=bias_divergence_fn,
                                     activation="relu",
                                     name="DenseFlipout_layer_2")(hidden)
    hidden = tfp.layers.DenseFlipout(20,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     bias_divergence_fn=bias_divergence_fn,
                                     activation="relu",
                                     name="DenseFlipout_layer_3")(hidden)
    params = tfp.layers.DenseFlipout(2,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     bias_divergence_fn=bias_divergence_fn,
                                     name="DenseFlipout_layer_4")(hidden)
    dist = tfp.layers.DistributionLambda(normal_sp, name="normal_sp")(params)
    model = Model(inputs=inputs, outputs=dist)
    return model