Autobatching an unnormalized density

Charles Margossian

Feb 8, 2022, 1:18:06 PM
to TensorFlow Probability
I have a distribution for which I can write an unnormalized log density. I would like to autobatch it, perhaps using tfd.JointDistributionSequentialAutoBatched. However, this density cannot be written as a composition of TFP distributions.

Calculation of the target log density looks as follows:
# Assumed imports (not shown in the original post):
#   import jax.numpy as jnp
#   import jax.scipy as scp
# n_part, L, B, and beta are module-level constants.
def ag_marginal_lpdf(Z, L, B, beta, n_states):
    target_lpdf = 0
    for l in range(0, n_states - 1):
        q = jnp.matmul(L.transpose(), Z[:, l])
        target_lpdf += -0.5 * jnp.dot(q, q)
    for n in range(0, n_part - 1):
        terms_to_sum = jnp.matmul(B[n, :], Z)
        target_lpdf += scp.special.logsumexp(terms_to_sum)
    return target_lpdf

def unnormalized_log_prob(Z_flat):
    Z = Z_flat.reshape(n_part, n_states)
    return ag_marginal_lpdf(Z, L, B, beta, n_states)
(I realize the for loops can be vectorized, but I leave this step to later)
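One way the vectorization mentioned above might look, as a minimal sketch. It assumes the quadratic term is `-0.5 * jnp.dot(q, q)`, that `L` and `B` are 2-D arrays whose shapes are compatible with `Z`, and it passes `n_part` explicitly instead of reading a global. The function names are illustrative, and the loop bounds are kept exactly as in the post (the last state and the last row of `B` are skipped).

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def ag_marginal_lpdf_loops(Z, L, B, n_states, n_part):
    # Loop version, mirroring the post (same loop bounds).
    target_lpdf = 0.0
    for l in range(0, n_states - 1):
        q = jnp.matmul(L.transpose(), Z[:, l])
        target_lpdf += -0.5 * jnp.dot(q, q)
    for n in range(0, n_part - 1):
        target_lpdf += logsumexp(jnp.matmul(B[n, :], Z))
    return target_lpdf


def ag_marginal_lpdf_vectorized(Z, L, B, n_states, n_part):
    # Column l of Q equals L^T Z[:, l]; keep the first n_states - 1 columns.
    Q = jnp.matmul(L.transpose(), Z[:, : n_states - 1])
    quadratic = -0.5 * jnp.sum(Q ** 2)
    # Row n of `rows` equals B[n, :] @ Z; keep the first n_part - 1 rows.
    rows = jnp.matmul(B[: n_part - 1, :], Z)
    return quadratic + jnp.sum(logsumexp(rows, axis=1))
```

Both functions return a scalar for a single `Z`, so either one can serve as the per-chain density.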

I want to run mcmc to sample from the space of the Z array, which has shape (n_part * n_states, ). The other arguments are constant. To run multiple chains, I initialize the chain with an object of shape (num_chains, n_part * n_states).

My understanding is I would have to write ag_marginal_lpdf to return an array of densities with shape (num_chains, ). Given the linear algebra involved, I'm hoping to find a slicker solution using tfd.JointDistributionSequentialAutoBatched, but the examples I've seen build on top of already existing distributions.

Is there a slick way of doing this here? For context, the target density is defined in equation (5) of this paper.

Thank you for your help!

Junpeng Lao

Feb 9, 2022, 1:32:59 AM
to TensorFlow Probability, Charles Margossian
Hi Charles.

Charles Margossian

Feb 9, 2022, 5:43:59 PM
to TensorFlow Probability, Charles Margossian

Hi Junpeng,
Thanks for the example, I'll try it out. Looking at m0 in your demo, it looks like the parameters are defined inside the log_prob function, for example:

beta = yield root(tfd.Normal(0., 10., name='beta')) 

Ok, this makes sense. We're declaring beta, so that we can use it to compute potential_val, and we're putting a prior on it.

In my example, we technically don't have a prior on the matrix Z. (To wit, this is not a Bayesian model.) Is there a way to declare Z without assigning a distribution to it? Equivalently, I could assign a flat prior (how would one do this?), or I could extract a normal kernel from my joint density and put a normal "prior" on Z -- but this would be a hack, and it would hurt code clarity.

I'll see if I can make the hacks work but I'm all ears for a more elegant solution.

Pavel Sountsov

Feb 9, 2022, 6:11:35 PM
to Charles Margossian, TensorFlow Probability
You can autobatch the original `unnormalized_log_prob` by just doing `jax.vmap(unnormalized_log_prob)`, if I understand your question correctly. This will transform `unnormalized_log_prob` to handle an extra leading dimension.
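A minimal sketch of the `jax.vmap` approach, under stated assumptions: the sizes, the identity `L`, and the all-ones `B` are placeholders, and the body of `unnormalized_log_prob` is a compact stand-in for the density in the original post (assuming its quadratic term is `-0.5 * jnp.dot(q, q)`). Only the single-chain function is written by hand; `jax.vmap` adds the leading chain dimension.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

n_part, n_states, num_chains = 4, 3, 5  # illustrative sizes

# Placeholder constants; in the thread these come from the model.
L = jnp.eye(n_part)
B = jnp.ones((n_part, n_part))


def unnormalized_log_prob(Z_flat):
    # Single-chain density: Z_flat has shape (n_part * n_states,).
    Z = Z_flat.reshape(n_part, n_states)
    Q = jnp.matmul(L.transpose(), Z[:, : n_states - 1])
    rows = jnp.matmul(B[: n_part - 1, :], Z)
    return -0.5 * jnp.sum(Q ** 2) + jnp.sum(logsumexp(rows, axis=1))


# Map over the leading axis: (num_chains, n_part * n_states) -> (num_chains,).
batched_log_prob = jax.vmap(unnormalized_log_prob)

init_state = jax.random.normal(
    jax.random.PRNGKey(0), (num_chains, n_part * n_states)
)
log_probs = batched_log_prob(init_state)  # one log density per chain
```

The batched function can then be used wherever a target log-density over a batch of chains is expected, for example as the `target_log_prob_fn` of a TFP MCMC kernel.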

We don't have a way to ergonomically declare flat priors in TFP (the best workaround I know of is to put a normal prior or something like that and then cancel it out with IncrementLogProb). Frankly, I'd just quickly whip up my own distribution class, but maybe that's easier said than done for non-TFP developers...
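The cancellation idea can be illustrated in plain JAX without calling any TFP API. In this sketch, `target_log_density` is a hypothetical stand-in for the density of interest, and `correction` plays the role of the term one would hand to IncrementLogProb: the placeholder normal log density is added and then subtracted again, so only the target remains.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def target_log_density(z):
    # Hypothetical stand-in for the unnormalized density of interest.
    return -jnp.sum(z ** 4)


def placeholder_log_prior(z):
    # Standard normal "prior", used only so that z can be declared.
    return jnp.sum(norm.logpdf(z, loc=0.0, scale=1.0))


def joint_log_prob(z):
    # prior + (target - prior) = target, so the placeholder has no net effect.
    correction = target_log_density(z) - placeholder_log_prior(z)
    return placeholder_log_prior(z) + correction
```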

Brian Patton 🚀

Feb 10, 2022, 10:58:48 AM
to Pavel Sountsov, Charles Margossian, TensorFlow Probability
If you wanted to contribute an ImproperUniform to sit next to IncrementLogProb, that seems reasonable.

If you wanted a flat prior on positive values, tfb.Softplus()(tfexpdist.ImproperUniform()) wouldn't work. So maybe we'd want some way to also configure a default support bijector.

I guess _sample_n could just raise NotImplementedError.
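To make the proposal concrete, here is a standalone sketch of what such an object would need to do. `ImproperUniformSketch` and its `event_ndims` argument are hypothetical; the class does not subclass anything in TFP and does not exist in the library. It contributes zero to the log density and refuses to sample.

```python
import jax.numpy as jnp


class ImproperUniformSketch:
    """Hypothetical stand-in for the proposed distribution, not a TFP class."""

    def __init__(self, event_ndims=0):
        # Number of trailing dimensions that make up a single event.
        self.event_ndims = event_ndims

    def log_prob(self, x):
        # A flat density adds nothing: return zeros with the batch shape of x.
        x = jnp.asarray(x)
        batch_shape = x.shape[: x.ndim - self.event_ndims]
        return jnp.zeros(batch_shape, dtype=x.dtype)

    def sample(self, *args, **kwargs):
        # An improper uniform has no normalizable density to draw from.
        raise NotImplementedError("An improper uniform cannot be sampled.")
```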

Charles Margossian

Feb 10, 2022, 1:02:03 PM
to TensorFlow Probability, Brian Patton, Charles Margossian, Pavel Sountsov
Thank you everyone for your help!

Pavel's solution using vmap is exactly what I was looking for. Everything seems to run smoothly now.

The flat prior was more a hack than anything else. The idea was to declare a variable in a tfp.distributions object without changing the log prob. I think this could be a useful feature, though I wouldn't necessarily label it as declaring a flat prior. Priors only make sense in Bayesian models, and MCMC applies more generally to any probabilistic model.