I am attempting to fit a model with variational inference (VI). Here's the model:
```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def gen_lognorm_model(y):
    """Based on the hierarchical Eight Schools model.

    This is a pooling model with a learned weight parameter, tau, that
    weights mean(y) and y. See the TFP example:
    https://www.tensorflow.org/probability/examples/Eight_Schools
    """
    # Somewhat convoluted, but ready to be expanded for further data.
    y, = tuple(tf.cast(v, tf.float32) for v in (y,))
    mean_loc = tf.math.log(tf.math.reduce_mean(y))
    mean_scale = tf.math.reduce_std(y)  # currently unused
    return tfd.JointDistributionSequential([
        tfd.LogNormal(loc=mean_loc, scale=1.),            # y_mean
        tfd.LogNormal(loc=tf.math.log(.5), scale=.2),     # tau
        tfd.Independent(                                   # lognorm_y
            tfd.LogNormal(loc=tf.math.log(y), scale=1.),
            reinterpreted_batch_ndims=1),
        tfd.Independent(                                   # est_scale
            # tfd.LogNormal(loc=tf.ones_like(y), scale=.1),
            tfd.Deterministic(tf.ones_like(y)),
            reinterpreted_batch_ndims=1),
        # tfd.Deterministic(tf.ones_like(y)),
        lambda est_scale, lognorm_y, tau, y_mean: (        # est_y
            tfd.Independent(
                tfd.LogNormal(
                    loc=((y_mean[..., tf.newaxis] + lognorm_y * tau[..., tf.newaxis])
                         / (tau[..., tf.newaxis] + 1.)),
                    scale=est_scale),
                reinterpreted_batch_ndims=1)),
    ])
```
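For context, here is a minimal sketch of how such a fit could be set up. The data, bijectors, optimizer, and step count are placeholders, and it assumes the `tfp.experimental.vi` surrogate-posterior utilities from recent TFP releases, so it may differ from my actual fitting code:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

y_obs = tf.constant([12., 7., 3., 9., 15., 6., 4., 11.])  # placeholder data
model = gen_lognorm_model(y_obs)

# The posterior is over the first four components; the last one (est_y) is
# pinned to the observed data.
def target_log_prob(y_mean, tau, lognorm_y, est_scale):
    return model.log_prob([y_mean, tau, lognorm_y, est_scale, y_obs])

# Factored surrogate over the four latent components, constrained to be
# positive since they all feed LogNormal loc/scale parameters.
surrogate = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape[:-1],
    bijector=[tfb.Softplus()] * 4)

losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob,
    surrogate,
    optimizer=tf.optimizers.Adam(learning_rate=0.05),
    num_steps=1000)
```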
In the `est_scale` component, the commented-out LogNormal distribution was producing a posterior that seemed pretty far off. I set the scale parameter of the final functional distribution (`est_y`) to `1.` and ran again, which gave different results, though not better ones. In the interest of exploring possible ways to model this problem, I then set `est_scale` to the `tfd.Deterministic` distribution that is present now. However, this (contrary to my expectation) gives results that differ from setting the scale in `est_y` to `1.`, and I'm unclear on why.

Given that the overall behavior of this model is not meeting my expectations, and that changing from `scale=1.` to a `tfd.Deterministic` scale is also giving (to me) unexpected results, I'm concerned that I've misunderstood something important. Why would using a deterministic distribution give results different from setting the scale to the same value as a float?
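For reference, the `scale=1.` variant replaced only the final component, roughly like this (an approximate sketch of what I described above, not the exact code from that run; `est_scale` is still accepted as an argument but ignored):

```python
lambda est_scale, lognorm_y, tau, y_mean: tfd.Independent(  # est_y, fixed scale
    tfd.LogNormal(
        loc=((y_mean[..., tf.newaxis] + lognorm_y * tau[..., tf.newaxis])
             / (tau[..., tf.newaxis] + 1.)),
        scale=1.),
    reinterpreted_batch_ndims=1)
```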