Proper way of learning multivariate Gaussian distribution with non-diagonal covariance matrix


Yen Ting Lin

Sep 9, 2021, 6:30:41 PM9/9/21
to TensorFlow Probability

Hi all,

I'm new here, and new to TFP, so apologies if it's been asked and solved before; I tried to search but didn't find a direct solution :) 

We are trying to perform some variational inference with TFP. The method requires both drawing samples from a multivariate Gaussian distribution and evaluating log_prob on those samples; in essence it is Bayes by Backprop, generalized so that we learn the full covariance matrix rather than constraining it to be diagonal.

Given the dimension N, the trainable variables are the N biases (means) and the N(N+1)/2 entries of the lower-triangular Cholesky factor. In the snippets below, the only way forward we found is to inject a full N-by-N tril matrix into tfp.distributions.MultivariateNormalTriL:

string = 'Inject tril as a matrix'

inferred_tril_mtx = tf.Variable(tril_GT.astype(np.float32), dtype='float32', trainable=True, name='inferred_tril_mtx')

trilFullMVN = tfp.distributions.MultivariateNormalTriL(inferred_Mean, scale_tril=inferred_tril_mtx)

with tf.GradientTape() as tape:
    samples = trilFullMVN.sample(1024)
    entropy = tf.reduce_sum(trilFullMVN.log_prob(samples))

variables = tape.watched_variables()

print(string)
print('learnable variables:')
for variable in variables:
    print(variable)

Output:

 

Inject tril as a matrix
learnable variables:
<tf.Variable 'inferred_Mean:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>
<tf.Variable 'inferred_tril_mtx:0' shape=(2, 2) dtype=float32, numpy=
array([[1. , 0. ],
       [0.8, 0.6]], dtype=float32)>

 
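One thing I did to convince myself the upper triangle cannot drift in this full-matrix parameterization: if MultivariateNormalTriL really only reads the lower triangle of scale_tril (its docs say the upper triangle is ignored), then the upper entries should receive exactly zero gradient and never move. A minimal plain-TF sketch of that masking effect, with tf.linalg.band_part standing in for the triangular operator (the 5.0 entry is deliberate junk):

```python
import tensorflow as tf

# A full 2x2 variable with deliberate junk (5.0) in the upper triangle.
m = tf.Variable([[1.0, 5.0], [0.8, 0.6]])

with tf.GradientTape() as tape:
    # band_part(-1, 0) keeps only the lower triangle -- a stand-in for
    # what a lower-triangular operator effectively sees.
    tril = tf.linalg.band_part(m, -1, 0)
    loss = tf.reduce_sum(tf.square(tril))

g = tape.gradient(loss, m)
print(g)  # the upper-triangular gradient entry is exactly zero
```

If the same holds inside MultivariateNormalTriL, upper entries initialized at zero stay at zero under gradient descent, so the triangular structure would be preserved even with the redundant N^2 parameterization.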

This makes the problem roughly twice as large (N^2 vs. N(N+1)/2 parameters), and I'm concerned about whether the lower-triangular structure is strictly preserved during training (experiments seemed to suggest it is). Other approaches seemed to break automatic differentiation: the trainable variable for the bias always flows through, but the ones in the covariance matrix do not. I was not able to declare an N(N+1)/2-entry vector and fill it into a lower-triangular matrix, neither via tfp.math.fill_triangular() nor via tfp.bijectors.FillScaleTriL():

 

string = 'Inject tril as a vector, use tfp.bijectors.FillScaleTriL() to convert it into a matrix'

inferred_tril_vec = tf.Variable(tril_vec_GT.astype(np.float32), dtype='float32', trainable=True, name='inferred_tril_vec')

trilVec1MVN = tfp.distributions.MultivariateNormalTriL(inferred_Mean, scale_tril=tfp.bijectors.FillScaleTriL().forward(inferred_tril_vec))

with tf.GradientTape() as tape:
    samples = trilVec1MVN.sample(1024)
    entropy = tf.reduce_sum(trilVec1MVN.log_prob(samples))

variables = tape.watched_variables()

print(string)
print('learnable variables:')
for variable in variables:
    print(variable)

 

Output:

 

Inject lower triangle as a vector, use tfp.bijectors.FillScaleTriL() to convert it into a matrix
learnable variables:
<tf.Variable 'inferred_Mean:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>

 
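My current guess at why the mean survives but the tril vector doesn't: in eager mode, tfp.bijectors.FillScaleTriL().forward(inferred_tril_vec) is evaluated at distribution-construction time, i.e. outside the tape, so the resulting scale_tril is just a constant tensor with no path back to the variable (the loc, by contrast, is passed as the variable itself and tracked). This is only a hypothesis about the TFP internals, but the underlying tape behavior is easy to show in plain TF:

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.0, 3.0])

# Computed eagerly, BEFORE the tape starts: y_outside is a plain constant
# tensor by the time the tape is recording.
y_outside = tf.square(v)

with tf.GradientTape(persistent=True) as tape:
    loss_outside = tf.reduce_sum(y_outside)    # no path back to v
    loss_inside = tf.reduce_sum(tf.square(v))  # recomputed inside the tape

print(tape.gradient(loss_outside, v))  # None
print(tape.gradient(loss_inside, v))   # gradient 2*v flows through
```

If that's what is happening, the fix would be to make the vector-to-tril transform run while the tape is recording, rather than once at construction.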

string = 'Inject tril as a vector, use tfp.math.fill_triangular() to convert it into a matrix'

inferred_tril_vec = tf.Variable(tril_vec_GT.astype(np.float32), dtype='float32', trainable=True, name='inferred_tril_vec')

trilVec2MVN = tfp.distributions.MultivariateNormalTriL(inferred_Mean, scale_tril=tfp.math.fill_triangular(inferred_tril_vec))

with tf.GradientTape() as tape:
    samples = trilVec2MVN.sample(1024)
    entropy = tf.reduce_sum(trilVec2MVN.log_prob(samples))

variables = tape.watched_variables()

print(string)
print('learnable variables:')
for variable in variables:
    print(variable)

 

Output:

 

<tf.Variable 'inferred_Mean:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>

 

 

Out of curiosity, I also tried the deprecated tfp.distributions.MultivariateNormalFullCovariance and followed its deprecation warning's instruction to combine tfp.distributions.MultivariateNormalTriL with tf.linalg.cholesky(); also no luck:

 

string = 'Inject full cov as a matrix, use tfp.linalg.cholesky() as suggested in tfp.distributions.MultivariateNormalFullCovariance'

fullCov = tf.Variable(cov_GT.astype(np.float32), dtype='float32', trainable=True, name='inferred_fullCov')

covFullMVN = tfp.distributions.MultivariateNormalTriL(inferred_Mean, scale_tril=tf.linalg.cholesky(fullCov))

with tf.GradientTape() as tape:
    samples = covFullMVN.sample(1024)
    entropy = tf.reduce_sum(covFullMVN.log_prob(samples))

variables = tape.watched_variables()

print(string)
print('learnable variables:')
for variable in variables:
    print(variable)

 

Output:

 

Inject full cov as a matrix, use tfp.linalg.cholesky() as suggested in tfp.distributions.MultivariateNormalFullCovariance
learnable variables:
<tf.Variable 'inferred_Mean:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>

 

The same syntax does work with the diagonal Gaussian tfp.distributions.MultivariateNormalDiag, though:

 

string = 'Inject diagonal of the covariance matrix as a vector'

inferred_covDiag = tf.Variable(np.ones(2).astype(np.float32), dtype='float32', trainable=True, name='inferred_covDiag')

diagMVN = tfp.distributions.MultivariateNormalDiag(inferred_Mean, inferred_covDiag)

with tf.GradientTape() as tape:
    samples = diagMVN.sample(1024)
    entropy = tf.reduce_sum(diagMVN.log_prob(samples))

variables = tape.watched_variables()

print(string)
print('learnable variables:')
for variable in variables:
    print(variable)

Output:

Inject diagonal of the covariance matrix as a vector

learnable variables:

<tf.Variable 'inferred_Mean:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>

<tf.Variable 'inferred_covDiag:0' shape=(2,) dtype=float32, numpy=array([1., 1.], dtype=float32)>

 

I wonder what the best practice here is. Any comments would be appreciated; thanks in advance!
