lambda defend, sd_def, attack, sd_att, home, alpha: (
tfd.Independent(
tfd.Poisson(
rate=(
tf.exp(
alpha
+ home
+ tf.gather(attack, home_id, axis=-1)
- tf.gather(defend, away_id, axis=-1)
)
),
name="s1"
),
reinterpreted_batch_ndims=1,
)
),
lambda s1, defend, sd_def, attack, sd_att, home, alpha: (
tfd.Independent(
tfd.Poisson(
rate=(
tf.exp(
alpha +
tf.gather(attack, away_id, axis=-1)
- tf.gather(defend, home_id, axis=-1)
)
),
name="s2"
),
reinterpreted_batch_ndims=1,
)
),
]
)
```
with an output on `model(home_id, away_id).resolve_graph()` of
```
(('alpha', ()),
('home', ()),
('sd_att', ()),
('attack', ('sd_att',)),
('sd_def', ()),
('defend', ('sd_def',)),
('s1', ('defend', 'sd_def', 'attack', 'sd_att', 'home', 'alpha')),
('x', ('s1', 'defend', 'sd_def', 'attack', 'sd_att', 'home', 'alpha')))
```
All seems good! Although not sure why the last distribution is called `x` and not `s2`.
However, I'm getting errors now when I try and sample. Code for sampling:
```
@tf.function
def target_log_prob(alpha, home, sd_att, attack, sd_def, defend):
"""Computes joint log prob pinned at `s1` and `s2`."""
return model(home_id, away_id).log_prob(
[alpha, home, sd_att, attack, sd_def, defend, s1, s2]
)
@tf.function(autograph=False, jit_compile=True)
def sample(num_chains, num_results, num_burnin_steps):
"""Samples from the partial pooling model."""
hmc = tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn=target_log_prob,
num_leapfrog_steps=10,
step_size=0.01
)
initial_state = [
tf.zeros([num_chains], name='init_alpha'),
tf.zeros([num_chains], name='init_home'),
tf.ones([num_chains], name='init_sd_att'),
tf.zeros([num_chains, num_teams], name='init_attack'),
tf.ones([num_chains], name='init_sd_def'),
tf.zeros([num_chains, num_teams], name='init_defend'),
]
unconstraining_bijectors = [
tfp.bijectors.Identity(), # alpha
tfp.bijectors.Identity(), # home
tfp.bijectors.Exp(), # sd_att
tfp.bijectors.Identity(), # attack
tfp.bijectors.Exp(), # sd_def
tfp.bijectors.Identity(), # defend
]
kernel = tfp.mcmc.TransformedTransitionKernel(
inner_kernel=hmc,
bijector=unconstraining_bijectors
)
samples, kernel_results = tfp.mcmc.sample_chain(
num_results=num_results,
num_burnin_steps=num_burnin_steps,
current_state=initial_state,
kernel=kernel
)
return samples
samples = sample(num_chains=4, num_results=1000, num_burnin_steps=1000)
```
But I get the error `Dimensions must be equal, but are 4 and 330 for '{{node JointDistributionSequential/log_prob/add_1}} = AddV2[T=DT_FLOAT](JointDistributionSequential/log_prob/add, JointDistributionSequential/log_prob/GatherV2)' with input shapes: [4], [4,330].` Obviously 4 is the number of chains, and 330 is the number of games (`len(home_id)`=`len(away_id)`). I'm not sure where I introduced this error though.
Best,
Theo