Minibatch-based optimization of the ELBO

Josh Chang

May 24, 2022, 8:10:56 PM
to TensorFlow Probability
I don't see much in the documentation about minibatch-based optimization for variational inference. I'm looking for an easy way to reweight the two terms in the ELBO when I'm feeding one batch of data at a time into `tfp.vi.monte_carlo_variational_loss` and taking a gradient descent step per batch.
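Concretely, my setup looks roughly like this (a simplified sketch; the model, surrogate, and data below are stand-ins for my real ones):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Stand-in data: 1000 scalar observations, streamed in minibatches of 32.
dataset = tf.data.Dataset.from_tensor_slices(tf.random.normal([1000]) + 3.).batch(32)

# Stand-in model: Normal prior on a scalar `theta`, Normal likelihood.
def unnormalized_log_prob(theta, batch):
  prior = tfd.Normal(0., 1.)
  likelihood = tfd.Normal(theta[..., tf.newaxis], 1.)  # broadcast over MC samples
  return prior.log_prob(theta) + tf.reduce_sum(likelihood.log_prob(batch), axis=-1)

surrogate_posterior = tfd.Normal(
    loc=tf.Variable(0.),
    scale=tfp.util.TransformedVariable(1., tfp.bijectors.Softplus()))
optimizer = tf.keras.optimizers.Adam(1e-2)

for batch in dataset:  # one gradient descent step per minibatch
  with tf.GradientTape() as tape:
    loss = tfp.vi.monte_carlo_variational_loss(
        lambda theta: unnormalized_log_prob(theta, batch),
        surrogate_posterior,
        sample_size=10)
  grads = tape.gradient(loss, surrogate_posterior.trainable_variables)
  optimizer.apply_gradients(zip(grads, surrogate_posterior.trainable_variables))
```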

It appears that I would need to reweight the two terms that make up `log_weights = target_log_prob - q_lp` here: https://github.com/tensorflow/probability/blob/88d217dfe8be49050362eb14ba3076c0dc0f1ba6/tensorflow_probability/python/vi/csiszar_divergence.py#L1132

Am I missing a simpler way of implementing minibatch training, rather than modifying the `_make_importance_weighted_divergence_fn` function?

Christopher Suter

May 25, 2022, 12:16:53 PM
to Josh Chang, TensorFlow Probability
I think you can apply a rescaling factor to the KL by wrapping the default `discrepancy_fn` argument. By default it is `tfp.vi.kl_reverse`, which you can replace with `lambda *a: beta * tfp.vi.kl_reverse(*a)`.
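Something like the following, roughly (an untested sketch; `beta`, the toy model, and the surrogate are placeholders for whatever you're actually using):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Placeholder model: Normal prior on `theta`, Normal likelihood over one minibatch.
data = tf.random.normal([32]) + 3.

def target_log_prob_fn(theta):
  prior = tfd.Normal(0., 1.)
  likelihood = tfd.Normal(theta[..., tf.newaxis], 1.)
  return prior.log_prob(theta) + tf.reduce_sum(likelihood.log_prob(data), axis=-1)

surrogate_posterior = tfd.Normal(
    loc=tf.Variable(0.),
    scale=tfp.util.TransformedVariable(1., tfp.bijectors.Softplus()))

beta = 0.1  # e.g. minibatch_size / num_data

# Rescale the default reverse-KL discrepancy by beta.
loss = tfp.vi.monte_carlo_variational_loss(
    target_log_prob_fn,
    surrogate_posterior,
    sample_size=10,
    discrepancy_fn=lambda logu: beta * tfp.vi.kl_reverse(logu))
```

The same `discrepancy_fn` argument is accepted by `tfp.vi.fit_surrogate_posterior`, if you're using that wrapper instead.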


Josh Chang

May 26, 2022, 1:29:11 PM
to Christopher Suter, TensorFlow Probability
Hi Christopher, thanks for the reply. I'm not convinced that rescaling the `discrepancy_fn` term would do what I need. The contribution to the ELBO from a single data point should be

$$
\frac{1}{N} D_{\mathrm{KL}}\!\left(q(\theta|\xi) \,\|\, P(\theta)\right) - \mathbb{E}_q \log P(D_n \mid \theta),
$$

which actually implies that I would need to reweight the prior within `unnormalized_log_prob` itself.
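For concreteness, here is roughly what I mean (a toy sketch; `num_data` and the model are made up):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_data = 1000  # N: total number of data points

def make_target_log_prob_fn(minibatch):
  """Unnormalized log prob with the prior downweighted by B/N."""
  batch_size = tf.cast(tf.size(minibatch), tf.float32)

  def target_log_prob_fn(theta):
    prior = tfd.Normal(0., 1.)
    likelihood = tfd.Normal(theta[..., tf.newaxis], 1.)
    # Scale the prior so that, summed over the N/B minibatches in an epoch,
    # it is counted exactly once.
    return ((batch_size / num_data) * prior.log_prob(theta)
            + tf.reduce_sum(likelihood.log_prob(minibatch), axis=-1))

  return target_log_prob_fn
```

(And, as above, the `q_lp` term would presumably need the same scaling, which is why I was looking at `_make_importance_weighted_divergence_fn`.)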

Warm regards

Christopher Suter

May 27, 2022, 5:37:40 PM
to Josh Chang, TensorFlow Probability
I'm not sure what you mean by reweighting the prior within `unnormalized_log_prob`. I think I'm not understanding what you want to do (for example, I'm not sure why you linked to `_make_importance_weighted_divergence_fn`, given the wording of your question; I'm probably missing something).

For full-batch VI, you want the expected log likelihood of all the data plus a single KL penalty. For a minibatch of B out of N data points, you need to downweight the KL term by B/N, so that after N/B minibatches you end up with the equivalent of one KL penalty. In the case you wrote down, with a single data point, B = 1 and you just need to reweight the KL by 1/N. My original suggestion would accomplish this, but I suspect you're actually trying to do something else. If you can clarify, maybe I can offer more help.
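In symbols, with B the minibatch size and N the total number of data points, the per-minibatch objective I have in mind is

$$
\frac{B}{N}\, D_{\mathrm{KL}}\!\left(q(\theta|\xi) \,\|\, P(\theta)\right) - \sum_{n \in \text{batch}} \mathbb{E}_q \log P(D_n \mid \theta),
$$

so that summing over the N/B minibatches in an epoch recovers the full-batch objective, and with B = 1 it reduces to the expression you wrote.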