Hi there,
I'd like to try creating a mixture of joint distributions.
However, when I try to do so I seem to come up against the below error.
Any help would be much appreciated.
Thanks,
Gareth
import tensorflow_probability as tfp
tfd = tfp.distributions
print(tfp.__version__)
components_distribution = tfd.JointDistributionSequential([
tfd.Gamma(concentration=[10.0, 12.0], rate=2.0), # g
tfd.Normal(loc=[0.0, 0.0], scale=2.), # n
lambda n, g: tfd.Normal(loc=n, scale=g) # m
])
mix = tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(
probs=[0.3, 0.7],
),
components_distribution=components_distribution,
)
leads to
0.13.0
Traceback (most recent call last):
File "/home/gwilliams/Code/talon/users/gwilliams/random_effects/mixture_example_not_working.py", line 44, in <module>
mix = tfd.MixtureSameFamily(
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/decorator.py", line 232, in fun
return caller(func, *(extras + args), **kw)
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/tensorflow_probability/python/distributions/distribution.py", line 346, in wrapped_init
default_init(self_, *args, **kwargs)
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/tensorflow_probability/python/distributions/mixture_same_family.py", line 189, in __init__
super(MixtureSameFamily, self).__init__(
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/tensorflow_probability/python/distributions/distribution.py", line 632, in __init__
d for d in self._parameter_control_dependencies(is_init=True)
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/tensorflow_probability/python/distributions/mixture_same_family.py", line 615, in _parameter_control_dependencies
tensorshape_util.with_rank_at_least(
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/tensorflow_probability/python/internal/tensorshape_util.py", line 370, in with_rank_at_least
return _cast_tensorshape(tf.TensorShape(x).with_rank_at_least(rank), type(x))
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/tensorflow/python/framework/tensor_shape.py", line 765, in __init__
self._dims = [Dimension(d) for d in dims]
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/tensorflow/python/framework/tensor_shape.py", line 765, in <listcomp>
self._dims = [Dimension(d) for d in dims]
File "/home/gwilliams/Code/talon/venv_talon/lib/python3.9/site-packages/tensorflow/python/framework/tensor_shape.py", line 206, in __init__
six.raise_from(
File "<string>", line 3, in raise_from
TypeError:
Dimension value must be integer or None or have an __index__ method,
got value 'TensorShape([2])' with type '<class
'tensorflow.python.framework.tensor_shape.TensorShape'>'