Angel Berihuete
May 27, 2021, 10:51:35 AM
to TensorFlow Probability, Brian Patton, TensorFlow Probability, Angel Berihuete
Thanks, Brian. Now I have two ways to run inference with Sharded distributions: JAX or TF.
To reuse the code I wrote previously, I'll try TF across multiple physical GPUs, using MirroredStrategy.run(tf.function(spmd_inference, jit_compile=True)).
I've sharded my dataset using:

    GLOBAL_BATCH_SIZE = 10000

    def dataset_fn(ctx):
        ucds_batches = tf.data.experimental.make_csv_dataset(
            'data.csv', batch_size=GLOBAL_BATCH_SIZE, label_name="source_id")
        return ucds_batches.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

    data_sharded = st.distribute_datasets_from_function(dataset_fn)

But I still do not understand the spmd_inference block. I've tried to do something like what you do in the MovieLens recommendation system example (MCMC with Sharded distributions):

    def make_run(*,
                 num_chains=2,
                 step_size=1e-2,
                 num_leapfrog_steps=100,
                 num_burnin_steps=1000,
                 num_results=500,
                 ):
        def run(data):
            ...
            def prior_fn():
                ...
            prior = tfed.JointDistributionCoroutine(prior_fn)
            def model_fn():
                ...
            model = tfed.JointDistributionCoroutine(model_fn)
            initial_state
            def target_log_prob
            momentum_distribution
            kernel
            return tfm.sample_chain
        return run

and finally
spmd_inference = make_run()
spmd_inference_ds = functools.partial(spmd_inference, data=data_sharded)
output = st.run(tf.function(spmd_inference, jit_compile=True))
but the 'data' argument is not well defined:
TypeError: run() missing 1 required positional argument: 'data'
So how do I link data_sharded with the run function?
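
One thing that may explain the TypeError, independent of TensorFlow: functools.partial returns a new callable (spmd_inference_ds above), but the snippet then passes the original spmd_inference, which still has no 'data' bound. The plain-Python sketch below reproduces that. The names call_like_strategy_run and the dictionary returned by run are hypothetical stand-ins introduced only for illustration; they are not part of TensorFlow or TFP. tf.distribute.Strategy.run does accept args and kwargs, so the last line mirrors that calling convention, but whether a distributed dataset can be bound through partial, rather than iterated and passed batch by batch, is not something this sketch settles.

```python
import functools


def make_run(*, num_chains=2):
    """Hypothetical stand-in for make_run: returns a function that needs `data`."""
    def run(data):
        # Placeholder body; the real run would build the model and sample chains.
        return {"num_chains": num_chains, "num_rows": len(data)}
    return run


def call_like_strategy_run(fn, args=(), kwargs=None):
    """Hypothetical plain-Python stand-in that calls fn(*args, **kwargs)."""
    return fn(*args, **(kwargs or {}))


spmd_inference = make_run()
data = [1.0, 2.0, 3.0]  # placeholder for one batch of data

# partial() does not modify spmd_inference; it returns a new callable.
spmd_inference_ds = functools.partial(spmd_inference, data=data)

# Passing the original function reproduces the error from the post.
try:
    call_like_strategy_run(spmd_inference)
    error_message = None
except TypeError as err:
    error_message = str(err)

# Either pass the partial object, or pass the data through `args`.
result_partial = call_like_strategy_run(spmd_inference_ds)
result_args = call_like_strategy_run(spmd_inference, args=(data,))
```

Under these assumptions, the two working variants are equivalent here; with a real strategy, passing each batch through args is the form the calling convention is built around.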