reinterpreted_batch_ndims=1)
))
I've removed the `Independent`s on the priors but kept it on the likelihood. Why? Forget about batching for a second. We should think of the model structure passed to `JointDistribution` as describing the generative process for a single "sample path". In this case, the generative process is
- draw one rate each from two Exponential priors
- draw uniform change point
- for years 1...N, draw a count from a Poisson with either the earlier or later rate (drawn above).
This expresses the factorization of the joint distribution over all the random variables:
p(e, l, s, {d_t}_{t=1}^N) = p(e) p(l) p(s) p(d_1 | e, l, s) p(d_2 | e, l, s) ... p(d_N | e, l, s)
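To make the factorization concrete, here's a minimal sketch of that log-joint computed by hand with scipy (not TFP), using hypothetical values for the rates, change point, and data -- all of these numbers are just for illustration, as are the Exponential(1) prior rates:

```python
import numpy as np
from scipy import stats

N = 30                                   # number of years (hypothetical)
e, l, s = 2.5, 0.9, 12                   # early rate, late rate, change point
counts = np.random.default_rng(0).poisson(2.0, size=N)  # fake count data

# log p(e) + log p(l): two Exponential(1) priors (rate assumed for illustration)
log_joint = stats.expon.logpdf(e) + stats.expon.logpdf(l)
# log p(s): uniform over the N years
log_joint += stats.randint(0, N).logpmf(s)
# sum_t log p(d_t | e, l, s): early rate before the change point, late rate after
rates = np.where(np.arange(N) < s, e, l)
log_joint += stats.poisson(rates).logpmf(counts).sum()
```

The final `.sum()` over the N Poisson terms is exactly the role `Independent` plays in the `JointDistribution` version: it collapses the per-year log_probs into one term of the joint.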
Ok, why the `Independent` on the Poisson and not the others? When we instantiate the Poisson with a vector of rates, it acquires a "batch shape" -- in this case the shape is the number of years (I think that's 30 in your running example?). Here's the "fundamental law of Distribution batch_shape":
when you ask a Distribution with a batch_shape for the log_prob of a single datum, you will get a batch_shape-shaped answer. If my (scalar, say) distribution D has batch_shape [5] and I ask it for D.log_prob(0.), I will get an answer of shape [5], which is the corresponding log_prob at 0 for each of the batch of 5 (presumably differently parameterized) distributions.

Note that there's some implicit broadcasting going on here. I'd get the same answer if I asked for D.log_prob(np.zeros(batch_shape)). I could also pass in a different datum to be evaluated for each of the distributions: D.log_prob(np.arange(5)) would pass 0 to the 0th distribution, 1 to the 1st, 2 to the 2nd, etc., and I'd still get an answer of shape [5]. I could also pass in something with, say, shape [30, 1] (or [30, 5]) and the inner dimension would broadcast against the batch_shape of the distribution -- in each case I'd get back a log_prob result of shape [30, 5]. If you're familiar with numpy broadcasting behavior, this should all feel pretty familiar.

The nice thing about having batch_shapes, instead of, say, a Python list of separate Distribution instances, is that the computation will be "vectorized"; modern CPUs and GPUs can perform the same computation on multiple inputs more efficiently than doing them serially, or even in parallel on separate (process) threads (see SIMD).
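You can see the same shape behavior with scipy's frozen distributions, which broadcast parameters against data much like a TFP Distribution with a batch_shape (this is an analogue for illustration, not TFP itself; the loc values are arbitrary):

```python
import numpy as np
from scipy import stats

# A "batch" of 5 Normals, analogous to a TFP Distribution with batch_shape [5]
D = stats.norm(loc=np.arange(5.))

a = D.logpdf(0.)                 # shape (5,): the log_prob at 0 for each member
b = D.logpdf(np.zeros(5))        # shape (5,): same answer via explicit zeros
c = D.logpdf(np.arange(5.))      # shape (5,): datum i evaluated by distribution i
d = D.logpdf(np.zeros((30, 1)))  # shape (30, 5): [30, 1] broadcasts against [5]
```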
`Independent` lets us "reinterpret" a batch of distributions as a single distribution over multivariate samples. Primarily, this means that instead of getting batch_shape-shaped log_probs, we'll get a single log_prob out. The name `Independent` is meant to evoke the factorization structure of a distribution over several independent quantities, like the Poisson terms of the factorization above:
p(d_1 | e, l, s) p(d_2 | e, l, s) ... p(d_N | e, l, s) <-- N independent terms, all Poisson, but with different rates
The `reinterpreted_batch_ndims` argument says how many batch dimensions we should "reinterpret" as part of the event. Often this is just 1, but we don't set it by default (I don't know if there's a safe and sensible way to do that, but I haven't thought about it much).
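Numerically, reinterpreting one batch dimension just means summing the per-component log_probs over that dimension. A tiny sketch with scipy standing in for the batched Poisson (rates and data are made up):

```python
import numpy as np
from scipy import stats

rates = np.array([1.0, 2.0, 3.0])   # hypothetical per-year Poisson rates
data = np.array([0, 2, 4])          # one count per year

# Without Independent: a batch_shape-shaped log_prob, one entry per year
batched = stats.poisson(rates).logpmf(data)   # shape (3,)

# With Independent(..., reinterpreted_batch_ndims=1): the rightmost batch
# dimension is summed over, yielding one scalar log_prob for the whole vector
joint = batched.sum()
```

This is why the likelihood keeps its `Independent`: the 30 yearly counts form one observation of the model, so we want one log_prob for the whole vector, not 30 separate ones.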