There are at least four sources of inefficiency in the TFP implementation:
1. TensorFlow will only parallelize within atomic operations (e.g. it'll run matrix multiplies/einsums in parallel) but not much else. This is unlike Stan, where each chain sits completely independently on its own core, achieving near-perfect parallelization (cache issues aside).
2. TFP's NUTS is batch-aware and will run all 4 chains in SIMD fashion, which implies that each NUTS step takes as many leapfrog steps as the slowest chain, i.e. the maximum number of leapfrog steps across chains (see the first sketch after this list).
3. TFP's NUTS has quite a bit of control flow, which TensorFlow struggles with.
4. You're using XLA compilation, which compiles the program the first time you run a @tf.function-decorated function. This compilation can be surprisingly slow at times. In Stan, you're measuring compilation in a separate cell (the second sketch after this list shows how to time the two phases separately in TensorFlow).
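To make point 2 concrete, here's a tiny NumPy sketch of why lockstep chains pay for the slowest chain. The leapfrog counts are made up for illustration; this is not TFP internals:

```python
import numpy as np

# Made-up leapfrog counts per NUTS step for 4 chains; NUTS tree sizes
# vary from chain to chain, so the counts differ within each step.
leapfrogs = np.array([[3, 7, 15, 7],
                      [7, 3, 31, 15]])

# Independent chains (Stan-style): each chain pays only its own counts.
per_chain_cost = leapfrogs.sum(axis=0)      # [10, 10, 46, 22]

# Batched chains (TFP-style SIMD): every chain pays the per-step maximum.
batched_cost = leapfrogs.max(axis=1).sum()  # 15 + 31 = 46, for every chain
```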
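And for point 4, a minimal sketch of separating compile time from run time: the first call to a jit_compile=True function traces and XLA-compiles, while later calls with the same input signature reuse the compiled program:

```python
import time
import tensorflow as tf

@tf.function(autograph=False, jit_compile=True)
def f(x):
  return tf.reduce_sum(x * x)

x = tf.random.normal([1000])

start = time.time()
f(x)  # First call: traces the function and runs XLA compilation.
print('first call (compile + run):', time.time() - start)

start = time.time()
f(x)  # Subsequent calls reuse the compiled program.
print('second call (run only):', time.time() - start)
```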
I can believe all those points multiplied together could result in the 10x slowdown you're seeing.
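For a rough sense of how that multiplication plays out (the factors here are purely illustrative, not measured): 2x from lockstep leapfrog steps, times 2x from control-flow overhead, times 2.5x from lost cross-chain parallelism already comes out to 10x.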
Points 1 and 2 could potentially be addressed by using TensorFlow's SPMD facilities, which explicitly run computations on different threads. Sadly, the API is a little byzantine, but here's a rather large snippet to get you started:
```python
import tensorflow as tf
import tensorflow_probability as tfp

NUM_CORES = 2
NUM_CHAINS = 4
NUM_CHAINS_PER_CORE = NUM_CHAINS // NUM_CORES
assert NUM_CORES * NUM_CHAINS_PER_CORE == NUM_CHAINS

# Split the one physical device into NUM_CORES virtual devices, so that
# MirroredStrategy has multiple devices to mirror across.
physical_devices = tf.config.experimental.list_physical_devices()
tf.config.experimental.set_virtual_device_configuration(
    physical_devices[0],
    [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_CORES)
print(tf.config.list_logical_devices())

strategy = tf.distribute.MirroredStrategy(
    devices=tf.config.list_logical_devices())


def target_log_prob_fn(x):
  return -tf.reduce_sum((x / tf.linspace(0.1, 1., 10))**2, -1)


@tf.function(autograph=False, jit_compile=True)
def sample(seed):
  # Each replica only steps NUM_CHAINS_PER_CORE chains in lockstep.
  return tfp.mcmc.sample_chain(
      1000,
      tf.zeros([NUM_CHAINS_PER_CORE, 10]),
      kernel=tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=0.1),
      trace_fn=None,
      seed=seed)


# Give each replica its own seed, then run the replicas in parallel and
# concatenate the per-replica chains along the chain dimension.
seeds = tfp.random.split_seed((0, 0), NUM_CORES)
seeds = strategy.experimental_distribute_values_from_function(
    lambda ctx: seeds[ctx.replica_id_in_sync_group])

chain = tf.nest.map_structure(lambda x: tf.concat(x.values, 1),
                              strategy.run(sample, (seeds,)))
```
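If everything is wired up correctly, chain should come back with shape [1000, NUM_CHAINS, 10]: sample_chain produces [1000, NUM_CHAINS_PER_CORE, 10] per replica, and the tf.concat along axis 1 stitches the replicas' chains back together. The virtual-device trick at the top is what gives MirroredStrategy something to distribute over on a single CPU.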