# Draw samples with an adaptive Hamiltonian Monte Carlo chain and report the
# wall-clock time taken. The enclosing function header is not shown here.

# NOTE(review): numberOfChains is assigned but not read in the lines below —
# confirm whether it is used elsewhere or can be removed.
numberOfChains=1
# Number of burn-in steps and of retained samples (10,000 each).
numberOfSamplesBurnIn=int(1E4)
numberOfSamples=int(1E4)
# Model built from 100 evenly spaced points on [0, 2*pi] and the list
# [2, 3, 0.5]. Presumably these are the sample grid and the reference
# parameters — TODO confirm against sm.SampleModel.
myModel=sm.SampleModel(jnp.linspace(0,2*jnp.pi,100), [2,3,0.5] )
# Initial chain state: three float32 values, all set to 1.
p0=jnp.array([1,1,1],dtype=jnp.float32)
# Start of the wall-clock measurement reported at the end.
t=time.time()
# Inner HMC kernel targeting the model's logLikelihood, with an initial step
# size of 1 and 3 leapfrog steps per proposal.
hcm_kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob_fn=myModel.logLikelihood, step_size=1, num_leapfrog_steps=3)
# This adapts the inner kernel's step_size.
# Adaptation runs for the first 80% of the burn-in steps.
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
inner_kernel = hcm_kernel,
num_adaptation_steps=int(numberOfSamplesBurnIn* 0.8)
)
# Run the chain from p0. The trace keeps only the inner kernel's per-step
# is_accepted flag, and the PRNG key is fixed at 0.
samples, is_accepted =tfp.mcmc.sample_chain(
num_results=numberOfSamples,
num_burnin_steps=numberOfSamplesBurnIn,
current_state=p0,
kernel=adaptive_hmc,
trace_fn=lambda _, pkr: pkr.inner_results.is_accepted,seed=jax.random.PRNGKey(0)
)
# NOTE(review): if the backend dispatches work asynchronously, this elapsed
# time may not include all of the sampling computation — confirm, and consider
# blocking on `samples` before reading the clock.
print("Elapsed: "+ str(time.time()-t))
# Return the drawn samples together with the per-step acceptance trace.
return samples,is_accepted