Using Horseshoe prior in variational autoencoder


Bryor Snefjella

Jul 18, 2021, 2:53:13 PM
to TensorFlow Probability
I'm interested in using a horseshoe prior instead of the standard Gaussian one in a variational autoencoder. Doing this with the implemented horseshoe distribution is simple:

tfk = tf.keras
tfkl = tf.keras.layers
tfd = tfp.distributions
tfpl = tfp.layers

latent_dim = 1
prior = tfd.Horseshoe(0.01 * tf.ones(latent_dim))

encoder = tfk.Sequential([
  tfkl.Dense(2),
  tfpl.DistributionLambda(
      make_distribution_fn=lambda t: tfd.Normal(
          loc=t[..., 0], scale=tf.exp(t[..., 1]))),
  tfpl.KLDivergenceAddLoss(prior)
])

But it seems like most work on horseshoe priors in variational inference uses a different parameterization, like the one from this previous thread:


Root = tfd.JointDistributionCoroutine.Root

def horseshoe_prior_in_compound_representation(num_features=100, global_scale=.1):
  local_scale_variance = yield Root(tfd.Independent(
      tfd.InverseGamma(
          0.5 * tf.ones([num_features]),
          0.5 * tf.ones([num_features]),
      ),
      reinterpreted_batch_ndims=1
  ))
  local_scale_noncentered = yield Root(tfd.Independent(
      tfd.HalfNormal(scale=tf.ones([num_features])),
      reinterpreted_batch_ndims=1
  ))

  local_scale = local_scale_noncentered * tf.sqrt(local_scale_variance)

  weights_noncentered = yield Root(tfd.Independent(
      tfd.Normal(
          loc=tf.zeros([num_features]),
          scale=tf.ones([num_features])
      ),
      reinterpreted_batch_ndims=1
  ))

  weights = weights_noncentered * local_scale * global_scale
  return weights

What I'm wondering about is 1) whether this alternative parameterization would also be advantageous in my use case (prior on the latent space of an autoencoder) and 2) how to use this parameterization of the horseshoe prior in a tfpl.KLDivergenceAddLoss() layer.

Thanks!
Bryor Snefjella