Custom distribution to do a distributed Inference with JAX


Angel Berihuete

May 25, 2021, 11:30:54 AM
to TensorFlow Probability
Dear TFP users.

I'm working on a model with a parameter for each observation, and I've got 2e5 observations. I think a good approach to this problem is to use the recently introduced distributed inference with JAX. It is very promising!

I've read the docs, and I understand that in my case it will be necessary to use Sharded distributions. But what happens when I have a custom distribution? Is it necessary to rewrite the distribution to deal with JAX, or will the TFP team perhaps give us a "magic function" to translate it to JAX? ;)

A notebook to show the problem:

Brian Patton 🚀

May 25, 2021, 11:59:22 AM
to Angel Berihuete, TensorFlow Probability
If you replace your TF import with `from tensorflow_probability.substrates import jax as tfp; tf = tfp.tf2jax` you might get pretty far with that. It's unsupported, but that's (largely) how the TFP on JAX code works under the hood. You might run into issues with JAX if you have any shapes that aren't static.
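Roughly (an untested sketch of the idea above, not an officially supported recipe), swapping the import lets existing TFP-style code operate on JAX arrays:

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tf = tfp.tf2jax        # `tf`-like shim backed by JAX
tfd = tfp.distributions

# Code written against `tf` / `tfd` now takes and returns JAX arrays:
dist = tfd.Normal(loc=jnp.zeros(3), scale=1.)
lp = jax.jit(dist.log_prob)(jnp.ones(3))  # jit works as long as shapes stay static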

If you don't want to use JAX, the Sharded distribution and its cooperating JointDistribution* classes also work with TF's TPUStrategy.run(spmd_inference) or a multi-physical-GPU MirroredStrategy.run(tf.function(spmd_inference, jit_compile=True)), we just don't have a notebook up for these cases.

The quick workaround for shape issues in JAX is to replace tf.shape / tf.concat with prefer_static.shape / prefer_static.concat (from the internal prefer_static package).
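Something like this (prefer_static lives in an internal package, so the exact import path may change; x and event_size here are just placeholders for your own tensors):

from tensorflow_probability.substrates.jax.internal import prefer_static as ps

# instead of tf.shape(x) and tf.concat([...], axis=0):
batch_shape = ps.shape(x)[:-1]
full_shape = ps.concat([batch_shape, [event_size]], axis=0)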

Brian Patton | Software Engineer | b...@google.com




Angel Berihuete

May 27, 2021, 10:51:35 AM
to TensorFlow Probability, Brian Patton, TensorFlow Probability, Angel Berihuete
Thanks Brian. Now I have two ways to do the inference with Sharded distributions: JAX or TF.

In order to reuse the code I wrote previously, I'll try TF with a multi-physical-GPU MirroredStrategy.run(tf.function(spmd_inference, jit_compile=True)).
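For context, `st` below is the strategy object, created roughly along these lines (a sketch, not copied verbatim from my notebook):

import tensorflow as tf

# one replica per local physical GPU
st = tf.distribute.MirroredStrategy()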

I've sharded my dataset like this:

GLOBAL_BATCH_SIZE = 10000

def dataset_fn(ctx):
    # ctx is a tf.distribute.InputContext supplied by the strategy
    ucds_batches = tf.data.experimental.make_csv_dataset(
        'data.csv', batch_size=GLOBAL_BATCH_SIZE, label_name="source_id")
    return ucds_batches.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

data_sharded = st.distribute_datasets_from_function(dataset_fn)

But I still do not understand the spmd_inference block. I've tried to do something like in your MovieLens recommendation system example (MCMC with Sharded distributions):

def make_run(*,
             num_chains=2,
             step_size=1e-2,
             num_leapfrog_steps=100,
             num_burnin_steps=1000,
             num_results=500):
  def run(data):
    ...
    def prior_fn():
      ...
    prior = tfed.JointDistributionCoroutine(prior_fn)

    def model_fn():
      ...
    model = tfed.JointDistributionCoroutine(model_fn)

    initial_state = ...
    def target_log_prob(...):
      ...
    momentum_distribution = ...
    kernel = ...
    return tfm.sample_chain(...)
  return run

and finally

spmd_inference = make_run()
spmd_inference_ds = functools.partial(spmd_inference, data=data_sharded)
output = st.run(tf.function(spmd_inference, jit_compile=True))

but the 'data' argument is not being passed correctly:

TypeError: run() missing 1 required positional argument: 'data'

So how do I link data_sharded with the run function?

Brian Patton 🚀

May 27, 2021, 11:21:28 AM
to Angel Berihuete, TensorFlow Probability
Instead of partialing the make_run, pass the sharded dataset as an argument to st.run:
st.run(tf.function(spmd_fn, jit_compile=True), sharded_dataset)
Does that fix it?

Brian Patton | Software Engineer | b...@google.com


Brian Patton 🚀

May 27, 2021, 11:22:27 AM
to Angel Berihuete, TensorFlow Probability
Literally:

# delete these lines: spmd_inference = ...
# spmd_inference_ds = ...
output = st.run(tf.function(make_run(), jit_compile=True), data_sharded)


Brian Patton | Software Engineer | b...@google.com


Angel Berihuete

May 28, 2021, 5:54:17 AM
to TensorFlow Probability, Brian Patton, TensorFlow Probability, Angel Berihuete
Thanks Brian. I'm doing something wrong in my model ... maybe it's the functools.partial before run. Now I obtain

ValueError: positional args must be a list or tuple, got <class 'tensorflow.python.distribute.input_lib.DistributedDatasetsFromFunction'>

I've written a notebook with the model:


I appreciate your help

Brian Patton 🚀

Jun 1, 2021, 11:04:22 AM
to Angel Berihuete, TensorFlow Probability
Oh right, the second argument to strategy.run is "args", so write
output = st.run(tf.function(make_run(), jit_compile=True), (data_sharded,))
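
Equivalently, with the keyword spelled out:

output = st.run(tf.function(make_run(), jit_compile=True), args=(data_sharded,))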

Brian Patton | Software Engineer | b...@google.com


Angel Berihuete

Jun 2, 2021, 6:53:59 AM
to TensorFlow Probability, Brian Patton, TensorFlow Probability, Angel Berihuete
Thanks Brian.

I have issues using the DistributedDataset. I need to extract parts of it in order to feed some distributions in the model. In chunk 6, line 9, I have a tf.gather call:

obs0 = tf.gather(data, indices=[0], axis=1)

but I fear this is not correct under a DistributedDataset strategy. I'll read more about tf.distribute.DistributedDataset and how to extract parts of it to feed the sharded distributions in the model. Again, any help is very welcome :)
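
One pattern I will try (just my reading of the tf.distribute docs, not something confirmed in this thread) is to iterate the distributed dataset and hand each per-replica batch to strategy.run, so that inside run the data argument is an ordinary per-replica batch and tf.gather behaves as usual:

run_fn = tf.function(make_run(), jit_compile=True)
for per_replica_batch in data_sharded:
    output = st.run(run_fn, args=(per_replica_batch,))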
