target_log_prob_fn for JointDistributionNamed

39 views
Skip to first unread message

Basosa

unread,
Nov 29, 2023, 1:19:16 PM11/29/23
to TensorFlow Probability
Hi everyone,

I'm currently learning Bayesian approaches through the book "Rethinking" and would like to utilize TensorFlow Probability.
In this example, the model is defined using a function. However, I prefer to use the JointDistributionNamed approach as it closely aligns with how models are written in "Rethinking."
I'm facing issues when running HMC, particularly with the target_log_prob_fn function. I'm unsure about the type of object returned within the HamiltonianMonteCarlo function.

Here is my code and the corresponding error message:


    import tensorflow_probability as tfp
    from tensorflow_probability import distributions as tfd
    tfb = tfp.bijectors
    import pandas as pd
    d = pd.read_csv('rethinking-master/data/Howell1.csv',sep = ';')
    height= d.height
    model = tfd.JointDistributionNamedAutoBatched(dict(
        s = tfd.Sample(tfd.Exponential(1), sample_shape=1),
        alpha = tfd.Sample(tfd.Normal(0,1), sample_shape=1),
        beta = tfd.Sample(tfd.Normal(0,1), sample_shape=1),
        y = lambda s,alpha,beta: tfd.Independent(tfd.Normal( alpha + beta * d.weight.values,s),
                                              reinterpreted_batch_ndims=1),
    ))

    def _trace_fn_transitioned(_, pkr):
        return pkr.inner_results.inner_results.log_accept_ratio
   
    num_chains = 4
    num_leapfrog_steps = 4
    step_size = 0.8
    burnin = 500
    params = ['s', 'alpha', 'beta']
    init_state = list(model.sample(num_chains))[:-1]
    bijectors = [tfb.Identity() for i in init_state]
    observed_data=(d.height.values)
   
    def target_log_prob_fn (**x):
        print(x[0])
        return model.log_prob(model.sample(y = observed_data, **x))
   
   
    hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn, num_leapfrog_steps=num_leapfrog_steps, step_size=step_size
        )
   
    inner_kernel = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=hmc_kernel, bijector=bijectors
    )
   
    kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=inner_kernel,
        target_accept_prob=0.8,
        num_adaptation_steps=int(0.8 * burnin),
        log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
    )
   
    tfp.mcmc.sample_chain(
            num_results=544,
            num_burnin_steps=burnin,
            current_state=init_state,
            kernel=kernel,
            trace_fn=_trace_fn_transitioned,
        )


error: target_log_prob_fn() takes 0 positional arguments but 3 were given

Thanks,

Sebastian
Reply all
Reply to author
Forward
0 new messages