I have implemented a custom distribution in TFP (TensorFlow Probability) using its JAX substrate, but I run into problems when I try to use it. It works with regular TensorFlow-based TFP, but I want to use it with JAX. Here is the error I get:
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32)>with<DynamicJaxprTrace(level=1/0)> While tracing the function <lambda> at /usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution.py:367 for eval_shape, this value became a tracer due to JAX operations on these lines:
Here is a link to my notebook: https://colab.research.google.com/drive/1REDfedxTbozoKJh-nTNVcETPaJFU9jft?usp=sharing
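A frequent cause of this error, though not confirmed to be the cause in the notebook, is a plain NumPy call (for example `np.asarray`) inside code that JAX traces. TFP's joint distributions call `jax.eval_shape` internally, as the traceback shows, so any NumPy conversion inside a custom distribution's methods (such as a hypothetical `_log_prob`) would hit a tracer instead of a concrete array. The minimal sketch below reproduces the error with two hypothetical functions, `log_prob_numpy` and `log_prob_jax`, that are not from the original code, and shows that switching to `jax.numpy` avoids it.

```python
import jax
import jax.numpy as jnp
import numpy as np


def log_prob_numpy(x):
    # Hypothetical unnormalized log-density written with NumPy.
    # np.asarray calls x.__array__(), which JAX tracers refuse, so this
    # fails whenever the function is traced (jax.jit, jax.eval_shape, ...).
    x = np.asarray(x)
    return -0.5 * x ** 2


def log_prob_jax(x):
    # Same computation with jax.numpy: operations stay on the tracer,
    # so the function can be traced.
    x = jnp.asarray(x)
    return -0.5 * x ** 2


x = jnp.arange(3.0)

# Eager call with a concrete array: the NumPy version works here.
print(log_prob_numpy(x))

# Traced calls: the jax.numpy version works under jit and eval_shape.
print(jax.jit(log_prob_jax)(x))
print(jax.eval_shape(log_prob_jax, x))

# The NumPy version fails as soon as it is traced.
try:
    jax.eval_shape(log_prob_numpy, x)
except jax.errors.TracerArrayConversionError as err:
    print(type(err).__name__)
```

If the custom distribution's methods contain calls like `np.asarray`, `np.array`, or `float(...)` on their inputs, replacing them with `jax.numpy` equivalents is the usual first thing to try.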
Thank you for any help.