Issues with implementing a custom distribution in TensorFlow Probability on JAX


Damian Ndiwago

Jun 30, 2022, 11:40:21 AM
to TensorFlow Probability

I have implemented a custom distribution in TFP on JAX, but I run into issues when I try to use it. It works with regular TFP, but I want to use it with JAX. Here is the error I get:

 TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)> While tracing the function <lambda> at /usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/ for eval_shape, this value became a tracer due to JAX operations on these lines:

And here is a link to my notebook.


Thank you for any help.


Colin Carroll

Jun 30, 2022, 12:02:49 PM
to TensorFlow Probability,
Hey Damian -- 
I got this working with 3 changes:

1. Replace imports like "from tensorflow_probability.python.distributions import distribution" with "from tensorflow_probability.substrates.jax.distributions import distribution"
2. Remove the `reparameterization` import, and use `tfd.FULLY_REPARAMETERIZED` instead
3. Use `tf.convert_to_tensor` instead of `ps.convert_to_tensor` in `sample_n`
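The error above says NumPy's `__array__()` was called on a JAX Tracer while TFP was running `eval_shape`. The sketch below does not use TFP at all; it only reproduces that class of error with plain JAX, assuming the failing code path ended in a NumPy conversion of a traced value. Converting with `np.asarray` fails during tracing, while `jnp.asarray` keeps the value traceable. That is presumably why swapping the conversion in item 3 helped. The function names `numpy_conversion` and `jax_conversion` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_conversion(x):
    # Hypothetical helper: np.asarray calls x.__array__(). During tracing, x is
    # an abstract Tracer with no concrete value, so JAX raises
    # TracerArrayConversionError here.
    return np.asarray(x) * 2.0


def jax_conversion(x):
    # Hypothetical helper: jnp.asarray stays inside JAX, so x can remain a Tracer.
    return jnp.asarray(x) * 2.0


x = jnp.float32(3.0)

# eval_shape traces the function abstractly, as in the error message above.
try:
    jax.eval_shape(numpy_conversion, x)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True

shape_info = jax.eval_shape(jax_conversion, x)  # a ShapeDtypeStruct, no error
result = jax.jit(jax_conversion)(x)             # also traces cleanly under jit

print(numpy_failed)      # True
print(shape_info.shape)  # ()
print(float(result))     # 6.0
```

The same rule applies inside a custom `sample_n`. Any step that forces a concrete NumPy array from a value TFP is tracing will fail under `eval_shape` or `jit`. Staying with JAX-backed conversions avoids that.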


Damian Ndiwago

Jul 1, 2022, 5:17:09 AM
to Colin Carroll, TensorFlow Probability
Thank you so much. It solved the problem.