Hello TFP community,
I am trying to code up a fairly simple joint distribution p(x,y), defined as follows:
x~Bernoulli(p_success=0.7)
y~Normal(0,1) if x else Laplace(0,1)
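Written out as a density, the model factorizes as p(x) p(y|x). In the sketch below, the exponent x switches on exactly one of the two conditional densities (x = 1 selects the Normal, x = 0 the Laplace):

$$
p(x, y) = 0.7^{x}\,0.3^{1-x}\;\mathcal{N}(y; 0, 1)^{x}\;\mathrm{Laplace}(y; 0, 1)^{1-x},
\qquad
\mathcal{N}(y; 0, 1) = \frac{e^{-y^2/2}}{\sqrt{2\pi}},\quad
\mathrm{Laplace}(y; 0, 1) = \frac{e^{-|y|}}{2},\quad x \in \{0, 1\}
$$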
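I believe this can be accomplished using the JointDistribution APIs, such as JointDistributionNamed. Here is the code I am using:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

@tf.function
def condDist(e):
    # Under @tf.function, AutoGraph turns this conditional expression into tf.cond
    return tfd.Normal(loc=0.0, scale=1.0) if e else tfd.Laplace(loc=0.0, scale=1.0)

joint = tfd.JointDistributionNamed(
    dict(x=tfd.Bernoulli(probs=0.7),
         y=lambda x: condDist(x)),
    batch_ndims=0,
    use_vectorized_map=True)

joint.sample()  # raises the TypeError shown below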
I am able to define this joint distribution, but I get an error when calling sample() (see the error message at the end). Apparently tf.cond() (which encodes the if statement in the conditional distribution) does not accept two different distributions (Normal vs. Laplace) being returned depending on whether x is 0 or 1. This should certainly be permissible, so my hunch is that it is some sort of bug. Any insight would be greatly appreciated.
File "/tmp/ipykernel_37/3096765418.py", line 3, in condDist *
return tfd.Normal(loc=0.0,scale=1.0) if e else tfd.Laplace(loc=0.0,scale=1.0)
TypeError: true_fn and false_fn arguments to tf.cond must have the same number, type, and overall structure of return values.
true_fn output: tfp.distributions.Normal("Normal_1_1", batch_shape=[], event_shape=[], dtype=float32)
false_fn output: tfp.distributions.Laplace("Laplace_1_1", batch_shape=[], event_shape=[], dtype=float32)
Error details:
The two structures don't have the same nested structure.
First structure: type=Normal str=tfp.distributions.Normal("Normal_1_1", batch_shape=[], event_shape=[], dtype=float32)
Second structure: type=Laplace str=tfp.distributions.Laplace("Laplace_1_1", batch_shape=[], event_shape=[], dtype=float32)
More specifically: Incompatible CompositeTensor TypeSpecs: type=Normal_ACTTypeSpec str=Normal_ACTTypeSpec(3, {'loc': TensorSpec(shape=(), dtype=tf.float32, name=None), 'scale': TensorSpec(shape=(), dtype=tf.float32, name=None)}, {'validate_args': False, 'allow_nan_stats': True, 'name': 'Normal_1_1'}, ('parameters',), (), ('name',), {}) vs. type=Laplace_ACTTypeSpec str=Laplace_ACTTypeSpec(3, {'loc': TensorSpec(shape=(), dtype=tf.float32, name=None), 'scale': TensorSpec(shape=(), dtype=tf.float32, name=None)}, {'validate_args': False, 'allow_nan_stats': True, 'name': 'Laplace_1_1'}, ('parameters',), (), ('name',), {})