Issues implementing a custom distribution in TensorFlow Probability on JAX

Damian Ndiwago

Jun 30, 2022, 11:40:21 AM
to TensorFlow Probability

I have implemented a custom distribution in TFP, but I get an error when I try to use it on JAX. It works with TFP on TensorFlow, but I want to use it with the JAX substrate. Here is the error I get:

 TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)> While tracing the function <lambda> at /usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution.py:367 for eval_shape, this value became a tracer due to JAX operations on these lines:

Here is a link to my notebook: https://colab.research.google.com/drive/1REDfedxTbozoKJh-nTNVcETPaJFU9jft?usp=sharing

Thank you for any help.

Damian

Colin Carroll

Jun 30, 2022, 12:02:49 PM
to TensorFlow Probability, ndiwag...@gmail.com
Hey Damian,
I got this working with 3 changes:

1. Replace imports like `from tensorflow_probability.python.distributions import distribution` with `from tensorflow_probability.substrates.jax.distributions import distribution`
2. Remove the `reparameterization` import, and use `tfd.FULLY_REPARAMETERIZED` instead
3. Use `tf.convert_to_tensor` instead of `ps.convert_to_tensor` in `sample_n`

--Colin

Damian Ndiwago

Jul 1, 2022, 5:17:09 AM
to Colin Carroll, TensorFlow Probability
Thank you so much. It solved the problem.