MCMC sampler doesn't get away from starting point in high-dimensional MVN distribution


jalal haghigh

07/04/2021, 8:39:59 AM
to TensorFlow Probability
Dear TFP community,
I have a problem sampling from a high-dimensional multivariate normal distribution with TFP's Hamiltonian Monte Carlo. The sampler doesn't move away from its starting point under different configurations (step size, chain length, etc.), and all parameters keep the same value. In my real problem I need to define an MVN distribution with at least 300 dimensions (the toy example provided here assumes a simpler covariance matrix and a fixed mean). I'd really appreciate any help you can provide.
Thanks.

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
tfd = tfp.distributions
import matplotlib.pyplot as plt

dimen = 50

# 150x150 covariance: the 3x3 block tiled 50x50 times, plus a small diagonal jitter
cov_base = np.array([[40000, 21000, 4800], [21000, 22500, 2400], [4800, 2400, 1600]])
cov_mat1 = np.tile(cov_base, (dimen, dimen))
jit = np.diag(np.diag(cov_mat1)) * .00001
cov_mat2 = cov_mat1 + jit
cov_mat = cov_mat2.astype('float32')

m1 = np.zeros([dimen, dimen])
# Mean of length 150 in interleaved order: [1800, 1500, 2200, 1800, 1500, 2200, ...]
mu_real = tf.reshape(tf.concat([tf.reshape(tf.ones(dimen) * 1800, [dimen, 1]),
                                tf.reshape(tf.ones(dimen) * 1500, [dimen, 1]),
                                tf.reshape(tf.ones(dimen) * 2200, [dimen, 1])], 1), [-1])
mu_model = tf.cast(mu_real, dtype=tf.float32)

prior_dist = tfd.MultivariateNormalFullCovariance(loc=mu_model, covariance_matrix=cov_mat)

def joint_log_prob(param1, param2, param3):
    # Concatenates the three 50-vectors in block order: [param1..., param2..., param3...]
    mu_vec = tf.reshape(tf.concat([param1, param2, param3], 0), [-1])
    outpu_t = prior_dist.log_prob(mu_vec)
    return outpu_t

# MCMC sampling part
# Create an HMC TransitionKernel
target_log_prob_fn = joint_log_prob
num_burnin_steps = 1000
num_results = 500
num_chains = 16  # not used below: current_state has no chain dimension, so one chain runs
step_size = [0.05, 0.1, 0.4]  # one step size per state part
kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    num_leapfrog_steps=3,
    step_size=step_size)
kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel, num_adaptation_steps=int(num_burnin_steps * 0.8))

@tf.function()
def run_chain():
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=[tf.ones(dimen) * 1000, tf.ones(dimen) * 2000, tf.ones(dimen) * 3000],
        kernel=kernel,
        trace_fn=lambda _, pkr: [pkr.inner_results.is_accepted,
                                 pkr.inner_results.log_accept_ratio])

samples, [is_accepted, log_accept_ratio] = run_chain()
print("Acceptance rate:", is_accepted.numpy().mean())
print("param1:", tf.reduce_mean(samples[0]))
print("param2:", tf.reduce_mean(samples[1]))
print("param3:", tf.reduce_mean(samples[2]))
print(samples[0])
print(samples[1])
print(samples[2])

# Trace plots, one panel per state part
fig = plt.figure()
ax1 = fig.add_subplot(3, 1, 1)
ax1.plot(samples[0])
ax2 = fig.add_subplot(3, 1, 2)
ax2.plot(samples[1])
ax3 = fig.add_subplot(3, 1, 3)
ax3.plot(samples[2])
plt.show()

Pavel Sountsov

07/04/2021, 4:14:07 PM
to jalal haghigh, TensorFlow Probability
For approximately Gaussian target distributions, HMC has a well-defined condition number that determines the amount of work (read: num_leapfrog_steps) needed to generate samples efficiently. In your code, you can compute it as:

sigmas = np.sqrt(np.linalg.eigvalsh(cov_mat))
sigmas = sigmas.max() / sigmas
kappa = np.sum(sigmas**4)**0.25
print(kappa)
> 34606.4420635801

So, you'll need approximately that many leapfrog steps to see good progress. Since you typically won't be able to compute that condition number ahead of time on real problems, you'll also want to use an algorithm that can learn the number of leapfrog steps automatically (e.g. tfp.mcmc.NoUTurnSampler and tfp.experimental.mcmc.PreconditionedNoUTurnSampler, although I think this type of problem is actually one where NUTS is pathologically bad).

You'll find that using that many leapfrog steps is rather slow, so if your real problem is also badly conditioned, you'll want to look into the preconditioning options we have available, e.g. tfp.mcmc.TransformedTransitionKernel or tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo. There's also a bunch of APIs for attempting to learn the preconditioner, but if your problem is truly as badly conditioned as this test, I'd recommend doing some model engineering rather than trying to fix it post hoc.
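The two points in the reply (a huge condition number, and preconditioning as the cure) can be illustrated without TFP. The sketch below is a plain NumPy stand-in, not TFP's kernels: `hmc` is a hypothetical hand-rolled HMC with an identity mass matrix, and `hmc_kappa` implements the condition-number formula from the reply. It samples the question's covariance once directly and once after whitening with the Cholesky factor (sampling z with x = mu + L z), assuming the covariance is known exactly, which is the ideal case. In TFP the same idea would roughly correspond to a bijector such as tfp.bijectors.ScaleMatvecTriL inside tfp.mcmc.TransformedTransitionKernel. Both chains start at the mean, so the comparison isolates exploration speed rather than the bad starting point.

```python
import numpy as np

# Covariance and mean rebuilt from the question, in float64.
dimen = 50
cov_base = np.array([[40000., 21000., 4800.],
                     [21000., 22500., 2400.],
                     [4800., 2400., 1600.]])
cov_tiled = np.tile(cov_base, (dimen, dimen))          # rank 3 before jitter
cov = cov_tiled + np.diag(np.diag(cov_tiled)) * 1e-5    # jitter makes it positive definite
mu = np.tile([1800., 1500., 2200.], dimen)              # interleaved, like mu_real
n = cov.shape[0]                                        # 150


def hmc_kappa(c):
    """Condition number from the reply: (sum_i (sigma_max / sigma_i)^4)^(1/4)."""
    sigmas = np.sqrt(np.linalg.eigvalsh(c))
    return np.sum((sigmas.max() / sigmas) ** 4) ** 0.25


def hmc(log_prob, grad, x0, eps, n_leapfrog, n_samples, rng):
    """Hypothetical hand-rolled HMC with an identity mass matrix (illustration only)."""
    x, samples, accepted = x0.copy(), [], 0
    for _ in range(n_samples):
        p0 = rng.standard_normal(x.shape)
        x1 = x.copy()
        p1 = p0 + 0.5 * eps * grad(x1)                 # half step for momentum
        for i in range(n_leapfrog):
            x1 = x1 + eps * p1                         # full step for position
            if i < n_leapfrog - 1:
                p1 = p1 + eps * grad(x1)               # full step for momentum
        p1 = p1 + 0.5 * eps * grad(x1)                 # final half step for momentum
        log_ratio = (log_prob(x1) - 0.5 * p1 @ p1) - (log_prob(x) - 0.5 * p0 @ p0)
        if np.log(rng.uniform()) < log_ratio:          # Metropolis correction
            x, accepted = x1, accepted + 1
        samples.append(x)
    return np.array(samples), accepted / n_samples


# Whitening (preconditioning) with the Cholesky factor: x = mu + L z, z ~ N(0, I).
L = np.linalg.cholesky(cov)
L_inv = np.linalg.inv(L)
prec = L_inv.T @ L_inv                                  # inverse covariance
white_cov = L_inv @ cov @ L_inv.T                       # covariance of z, close to identity

rng = np.random.default_rng(0)

# Raw target: leapfrog needs eps below about 2 * sigma_min (sigma_min is about 0.126 here).
raw, acc_raw = hmc(lambda x: -0.5 * (x - mu) @ prec @ (x - mu),
                   lambda x: -prec @ (x - mu),
                   mu, eps=0.05, n_leapfrog=5, n_samples=1000, rng=rng)

# Whitened target: a standard normal, so an O(1) step size is fine.
zs, acc_white = hmc(lambda z: -0.5 * z @ z, lambda z: -z,
                    np.zeros(n), eps=0.3, n_leapfrog=5, n_samples=1000, rng=rng)
white = mu + zs @ L.T                                   # map samples back to x-space

print("rank of tiled block matrix:", np.linalg.matrix_rank(cov_tiled))
print("kappa raw:", hmc_kappa(cov), "kappa whitened:", hmc_kappa(white_cov))
print("acceptance raw / whitened:", acc_raw, acc_white)
print("true sd of x[0]:", np.sqrt(cov[0, 0]))
print("sample sd of x[0], raw / whitened:", raw[:, 0].std(), white[:, 0].std())
```

With the step size capped by the smallest scale, each raw iteration moves only about eps * n_leapfrog along the widest directions (standard deviations in the hundreds to thousands), so the raw chain's spread in x[0] should stay a small fraction of the true value of about 200, while the whitened chain should roughly recover it. After whitening, the reply's kappa drops to n^(1/4), about 3.5 for n = 150.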


jalal haghigh

08/04/2021, 2:44:20 PM
to TensorFlow Probability, Pavel Sountsov, jalal haghigh
Thank you very much, Pavel, for elaborating on HMC. I found that there is also a mistake in the input to joint_log_prob(). It seems each parameter must be passed separately, or all of them as a single tensor; the way I grouped them into three tensors of 50 components each caused the problem.
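One concrete way to see the grouping mistake is element order: in the question, mu_real (and the tiled covariance) use an interleaved layout, [1800, 1500, 2200, 1800, ...], while joint_log_prob concatenates the three parts block-wise. The NumPy check below reproduces that mismatch and shows one possible fix, interleaving the parts with stack + reshape. The names mirror the question's code; the fix is an assumption for illustration, not necessarily the change the poster made.

```python
import numpy as np

dimen = 50
# Mean as built in the question: three [dimen, 1] columns side by side, flattened
# row-major, giving the interleaved order [1800, 1500, 2200, 1800, 1500, 2200, ...].
mu_real = np.stack([np.full(dimen, 1800.), np.full(dimen, 1500.),
                    np.full(dimen, 2200.)], axis=1).reshape(-1)

# Three state parts holding exactly the mean values.
param1, param2, param3 = np.full(dimen, 1800.), np.full(dimen, 1500.), np.full(dimen, 2200.)

# What joint_log_prob in the question does: block order [param1..., param2..., param3...].
block_order = np.concatenate([param1, param2, param3], axis=0)

# One possible fix: interleave the parts the same way the mean was built.
interleaved = np.stack([param1, param2, param3], axis=1).reshape(-1)

print(np.array_equal(block_order, mu_real))
print(np.array_equal(interleaved, mu_real))
```

In the TensorFlow code, the analogous change inside joint_log_prob would be tf.reshape(tf.stack([param1, param2, param3], axis=1), [-1]) for single-chain states of shape [50]; with a leading chain dimension, the stack axis would become -1 and the reshape target [num_chains, 150]. Fixing the order does not address the conditioning problem described in the reply.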