Hi everyone,
I'm currently learning Bayesian approaches through the book "Rethinking" and would like to utilize TensorFlow Probability.
In this example, the model is defined using a function. However, I prefer to use the JointDistributionNamed approach as it closely aligns with how models are written in "Rethinking."
I'm facing issues when running HMC, particularly with the target_log_prob_fn function. I'm unsure about the type of object returned within the HamiltonianMonteCarlo function.
Here is my code and the corresponding error message:
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
tfb = tfp.bijectors
import pandas as pd
d = pd.read_csv('rethinking-master/data/Howell1.csv',sep = ';')
height= d.height
model = tfd.JointDistributionNamedAutoBatched(dict(
s = tfd.Sample(tfd.Exponential(1), sample_shape=1),
alpha = tfd.Sample(tfd.Normal(0,1), sample_shape=1),
beta = tfd.Sample(tfd.Normal(0,1), sample_shape=1),
y = lambda s,alpha,beta: tfd.Independent(tfd.Normal( alpha + beta * d.weight.values,s),
reinterpreted_batch_ndims=1),
))
def _trace_fn_transitioned(_, pkr):
return pkr.inner_results.inner_results.log_accept_ratio
num_chains = 4
num_leapfrog_steps = 4
step_size = 0.8
burnin = 500
params = ['s', 'alpha', 'beta']
init_state = list(model.sample(num_chains))[:-1]
bijectors = [tfb.Identity() for i in init_state]
observed_data=(d.height.values)
def target_log_prob_fn (**x):
print(x[0])
return model.log_prob(model.sample(y = observed_data, **x))
hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn, num_leapfrog_steps=num_leapfrog_steps, step_size=step_size
)
inner_kernel = tfp.mcmc.TransformedTransitionKernel(
inner_kernel=hmc_kernel, bijector=bijectors
)
kernel = tfp.mcmc.SimpleStepSizeAdaptation(
inner_kernel=inner_kernel,
target_accept_prob=0.8,
num_adaptation_steps=int(0.8 * burnin),
log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
)
tfp.mcmc.sample_chain(
num_results=544,
num_burnin_steps=burnin,
current_state=init_state,
kernel=kernel,
trace_fn=_trace_fn_transitioned,
)
error: target_log_prob_fn() takes 0 positional arguments but 3 were given
Thanks,
Sebastian