# Minibatch-based optimization of the ELBO


### Josh Chang

May 24, 2022, 8:10:56 PM
to TensorFlow Probability
I don't see much in the documentation about using minibatch-based optimization for variational inference. I'm looking for an easy way to reweight the two terms in the ELBO when I feed one batch of data at a time into `tfp.vi.monte_carlo_variational_loss` and take a gradient-descent step per batch.

It appears that I would need to reweight the two terms (`log_weights = target_log_prob - q_lp`) here: https://github.com/tensorflow/probability/blob/88d217dfe8be49050362eb14ba3076c0dc0f1ba6/tensorflow_probability/python/vi/csiszar_divergence.py#L1132

Am I missing a simpler way of implementing minibatch training than modifying the `_make_importance_weighted_divergence_fn` function?

### Christopher Suter

May 25, 2022, 12:16:53 PM
to Josh Chang, TensorFlow Probability
I think you can apply a rescaling term to the KL by wrapping the default `discrepancy_fn` argument. By default it is `tfp.vi.kl_reverse`, which you can replace with `lambda *a: beta * tfp.vi.kl_reverse(*a)`.
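A minimal sketch of this wrapping, without requiring TFP: `kl_reverse` below is a stand-in for `tfp.vi.kl_reverse` (which, for unnormalized weights, evaluates the reverse-KL Csiszar function `f(log u) = -log u`), and `beta` is a hypothetical rescaling factor.

```python
import numpy as np

# Stand-in for tfp.vi.kl_reverse: the reverse-KL Csiszar f-function,
# applied elementwise to log_weights = target_log_prob - q_log_prob.
def kl_reverse(logu):
    return -logu

beta = 0.1  # hypothetical rescaling factor, e.g. batch_size / dataset_size

# The wrapper suggested above: scale the discrepancy before it is
# averaged over Monte Carlo samples inside the variational loss.
scaled_discrepancy = lambda *a: beta * kl_reverse(*a)

log_weights = np.array([-1.2, 0.3, -0.5])  # toy target_log_prob - q_log_prob
print(scaled_discrepancy(log_weights))      # each term scaled by beta
```

In TFP the wrapped function would be passed as the `discrepancy_fn` argument of `tfp.vi.monte_carlo_variational_loss`; note that it rescales the entire Monte Carlo discrepancy, not the prior term alone.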


### Josh Chang

May 26, 2022, 1:29:11 PM
to Christopher Suter, TensorFlow Probability
Hi Christopher, thanks for the reply. I'm not convinced that rescaling the `discrepancy_fn` term would do what I need. The per-data-point contribution to the negative ELBO (the loss being minimized) should be

$$\left(\frac{1}{N}D_{KL}\left(q(\theta|\xi)\,\|\,P(\theta)\right) - \mathbb{E}_q \log P(D_n \mid \theta) \right)$$

which actually implies that I would need to reweight the prior within `unnormalized_log_prob` itself.
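A small numpy sketch of this decomposition (all names hypothetical; scalars stand in for the Monte Carlo expectation estimates): when the prior/entropy term is downweighted by `B/N` per batch of `B` points, the per-batch losses sum to the full-data negative ELBO.

```python
import numpy as np

rng = np.random.default_rng(0)
N, B = 12, 3                  # hypothetical dataset size and batch size
log_prior = -2.0              # stand-in for E_q[log P(theta)]
log_q = -1.5                  # stand-in for E_q[log q(theta | xi)]
log_lik = rng.normal(size=N)  # stand-ins for E_q[log P(D_n | theta)]

# Full-data negative ELBO: KL-like term minus the total log-likelihood.
full_loss = (log_q - log_prior) - log_lik.sum()

# Per-batch losses with the KL term scaled by B/N, matching the
# per-data-point decomposition above (1/N of the KL per point).
batch_losses = [
    (B / N) * (log_q - log_prior) - log_lik[i:i + B].sum()
    for i in range(0, N, B)
]

# Summing over all batches recovers the full-data loss exactly.
assert np.isclose(sum(batch_losses), full_loss)
```

This is why the rescaling has to touch only the prior (and entropy) contribution: scaling the whole discrepancy would also shrink the likelihood term.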

Warm regards