TensorFlow Probability on JAX help


Mingo Damian

Jan 21, 2022, 12:16:34 PM
to tfprob...@tensorflow.org

Dear All,

 

I want to use TensorFlow Probability on JAX. The logistic regression example works, but when I try multiple regression I get the error below. Thank you for any help. Here is a link to a notebook:

https://colab.research.google.com/drive/1IhktKj6bAJvGDv71NR-QQBndpCKYKsKf?usp=sharing


TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x1, x2)
    677     fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
    678   else:
--> 679     fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
    680   fn = jit(fn, inline=True)
    681   if lax_doc:


TypeError: sub got incompatible shapes for broadcasting: (100,), (5,).



 

Pavel Sountsov

Jan 21, 2022, 1:21:58 PM
to Mingo Damian, TensorFlow Probability
Your `tfd.Independent` constructor should take 2 as the number of reinterpreted dimensions (`reinterpreted_batch_ndims`), not 1: you have one dimension corresponding to the data size (100) and one to the number of outputs (1). A general check you can do when writing JointDistributions is to look at each component's `dist.batch_shape` and make sure they are all the same (typically empty).


Mingo Damian

Jan 21, 2022, 3:12:10 PM
to Pavel Sountsov, TensorFlow Probability
Thank you so much, and thanks for the tip on `dist.batch_shape`. This solves my problem.
