Dear All,
I want to use TensorFlow Probability on JAX. The logistic regression example works, but when I try multiple regression I get the error below. Thank you for any help. Here is a link to the notebook:
https://colab.research.google.com/drive/1IhktKj6bAJvGDv71NR-QQBndpCKYKsKf?usp=sharing
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x1, x2)
    677     fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
    678   else:
--> 679     fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
    680   fn = jit(fn, inline=True)
    681   if lax_doc:

TypeError: sub got incompatible shapes for broadcasting: (100,), (5,).
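The message says a subtraction was attempted between an array of length 100 and one of length 5. Without running the notebook this is only a guess, but in multiple regression that pattern usually means the 100 observations are being compared against the 5-element coefficient vector instead of against the 100 fitted means. The sketch below reproduces the mismatch and the usual fix. It assumes, hypothetically, 100 observations, 5 predictors, synthetic data, and a hand-written Normal log-likelihood (`normal_log_lik` is an illustrative helper, not a TFP function); the notebook's actual model may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: 100 observations, 5 predictors.
n_obs, n_features = 100, 5
key_x, key_b, key_e = jax.random.split(jax.random.PRNGKey(0), 3)

# Synthetic data: design matrix X has shape (100, 5), targets y have shape (100,).
X = jax.random.normal(key_x, (n_obs, n_features))
true_beta = jax.random.normal(key_b, (n_features,))
y = X @ true_beta + 0.1 * jax.random.normal(key_e, (n_obs,))


def normal_log_lik(y, loc, scale):
    """Sum of Normal(loc, scale) log-densities; y - loc must broadcast."""
    z = (y - loc) / scale
    return jnp.sum(-0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2.0 * jnp.pi))


beta = jnp.zeros(n_features)  # one coefficient per predictor, shape (5,)
sigma = 1.0

# Suspected bug: passing the coefficient vector itself as the mean,
# so y - loc pairs shapes (100,) and (5,), which cannot broadcast.
mismatch_error = None
try:
    normal_log_lik(y, beta, sigma)
except (TypeError, ValueError) as err:  # older JAX raises TypeError, newer ValueError
    mismatch_error = err
    print("shape error:", err)

# Fix: each observation's mean is the linear predictor X @ beta, shape (100,).
loc = X @ beta
log_lik = normal_log_lik(y, loc, sigma)
print(loc.shape, log_lik)
```

With TFP's JAX substrate (`from tensorflow_probability.substrates import jax as tfp`, `tfd = tfp.distributions`), the corresponding likelihood would be `tfd.Normal(loc=X @ beta, scale=sigma).log_prob(y)`: the location has to be the per-observation mean of shape (100,), not the coefficients of shape (5,).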
--
You received this message because you are subscribed to the Google Groups "TensorFlow Probability" group.