# target_log_prob_fn for JointDistributionNamed


### Basosa

Nov 29, 2023, 1:19:16 PM
to TensorFlow Probability
Hi everyone,

I'm currently learning Bayesian approaches through the book "Rethinking" and would like to utilize TensorFlow Probability.
Specifically, I
In this example, the model is defined using a function. However, I prefer to use the JointDistributionNamed approach as it closely aligns with how models are written in "Rethinking."
I'm running into problems with HMC, specifically with target_log_prob_fn: I'm unsure what type of object HamiltonianMonteCarlo passes to it.

Here is my code and the corresponding error message:

import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
tfb = tfp.bijectors
import pandas as pd
height = d.height
model = tfd.JointDistributionNamedAutoBatched(dict(
    s = tfd.Sample(tfd.Exponential(1), sample_shape=1),
    alpha = tfd.Sample(tfd.Normal(0, 1), sample_shape=1),
    beta = tfd.Sample(tfd.Normal(0, 1), sample_shape=1),
    y = lambda s, alpha, beta: tfd.Independent(
        tfd.Normal(alpha + beta * d.weight.values, s),
        reinterpreted_batch_ndims=1),
))

def _trace_fn_transitioned(_, pkr):
    return pkr.inner_results.inner_results.log_accept_ratio

num_chains = 4
num_leapfrog_steps = 4
step_size = 0.8
burnin = 500
params = ['s', 'alpha', 'beta']
init_state = list(model.sample(num_chains))[:-1]
bijectors = [tfb.Identity() for i in init_state]
observed_data=(d.height.values)

def target_log_prob_fn(**x):
    print(x[0])
    return model.log_prob(model.sample(y=observed_data, **x))

hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn, num_leapfrog_steps=num_leapfrog_steps, step_size=step_size
)

inner_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=hmc_kernel, bijector=bijectors
)

inner_kernel=inner_kernel,
target_accept_prob=0.8,
log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
)

tfp.mcmc.sample_chain(
    num_results=544,
    num_burnin_steps=burnin,
    current_state=init_state,
    kernel=kernel,
    trace_fn=_trace_fn_transitioned,
)

error: target_log_prob_fn() takes 0 positional arguments but 3 were given
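
For reference, the same message can be reproduced in plain Python without TFP. This is only a minimal sketch, assuming the kernel hands the three entries of init_state to target_log_prob_fn as positional arguments, which is what the "3 were given" in the message suggests. The function named_log_prob and the values in state are hypothetical stand-ins, not TFP API and not the model above.

```python
# Hypothetical stand-in for a joint log-density that takes named parts.
# It is NOT a TFP call; it only returns a number so the example runs.
def named_log_prob(s, alpha, beta):
    return -0.5 * (s ** 2 + alpha ** 2 + beta ** 2)


# Keyword-only signature, as in the post: it accepts no positional arguments.
def target_log_prob_kwargs(**x):
    return named_log_prob(**x)


# Positional signature: one argument per state part, in the same order
# as the state list, forwarded by name to the stand-in density.
def target_log_prob_positional(s, alpha, beta):
    return named_log_prob(s=s, alpha=alpha, beta=beta)


# Illustrative values standing in for s, alpha, beta (assumption).
state = [1.0, 0.5, -0.5]

# Calling the **kwargs version with an unpacked list raises TypeError.
try:
    target_log_prob_kwargs(*state)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
print(target_log_prob_positional(*state))
```

The positional variant accepts the unpacked list, so under the assumption above a target function with one positional parameter per state part (s, alpha, beta) would at least get past this particular TypeError. Whether the rest of the pipeline then runs is a separate question.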

Thanks,

Sebastian