Apply a bijector to a beta output layer


Borja Arroyo

May 12, 2022, 4:40:08 PM
to TensorFlow Probability
Hi all.

I am trying to build a probabilistic output layer for real-valued targets so that I can compute the log probability (i.e. the log-likelihood). I am trying a Beta distribution to take advantage of the flexibility its two shape parameters offer, but of course the results have to be mapped somehow to R (instead of [0, 1]), so I am trying to use a bijector, although I am not sure whether this is the best approach.

The code that I have implemented is the following:

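import tensorflow as tf
import tensorflow_probability as tfp

k = tf.keras  # assumed alias; x, out, losses and idx come from the surrounding model code

# Two non-negative units used as Beta(concentration1, concentration0)
x = k.layers.Dense(2, activation=k.activations.relu)(x)
x = tfp.layers.DistributionLambda(
  lambda t: tfp.distributions.Beta(1e-03 + t[..., :1], 1e-03 + t[..., 1:])
)(x)
# Invert(Sigmoid()) is the logit bijector: it maps the Beta support (0, 1) onto R
out.append(
  tfp.layers.DistributionLambda(
    lambda t: tfp.distributions.TransformedDistribution(
      t, tfp.bijectors.Invert(tfp.bijectors.Sigmoid())
    ),
    name=f"out_{idx}"
  )(x))
# Negative log-likelihood loss
losses.append(lambda y_true, y_pred: -y_pred.log_prob(y_true))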
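A side note on the parametrization above: with relu plus 1e-3, every unit that relu clips to zero gives a concentration of exactly 1e-3 and a zero gradient, and a concentration that small piles the Beta mass up against 0 and/or 1. A commonly used alternative (swapped in here, not what the post uses) is softplus plus a small floor, which is smooth and strictly positive. The NumPy sketch below only compares the two mappings on a few raw outputs; `relu_concentration`, `softplus_concentration` and `EPS` are illustrative names.

```python
import numpy as np

EPS = 1e-3  # same floor as in the post; an illustrative choice


def relu_concentration(raw):
    # relu sends every negative raw output to exactly 0, so the
    # concentration collapses to EPS and its gradient w.r.t. raw is 0 there
    return EPS + np.maximum(raw, 0.0)


def softplus_concentration(raw):
    # softplus(x) = log(1 + exp(x)) is smooth and strictly positive;
    # np.logaddexp(0, x) evaluates it without overflow for large x
    return EPS + np.logaddexp(0.0, raw)


raw = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
print(relu_concentration(raw))
print(softplus_concentration(raw))
```

In the Keras layer this would roughly amount to `activation=k.activations.softplus` in place of relu.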

The problem is that the output is not always what I expect. For instance, I am comparing this model with one that uses an IndependentNormal layer and with a typical regressor trained with MSE. In some cases where there is no correlation between inputs and outputs, the other two simply don't learn, whereas the Beta model still reports some accuracy.

Moreover, I see this warning when the Beta model is trained: "WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass"

I may be missing something, as I am not really used to working with TFP, so any help, hint or suggestion is appreciated.

Many thanks in advance.

Borja



Borja Arroyo

May 17, 2022, 2:44:14 PM
to TensorFlow Probability, Borja Arroyo
I think I found a solution: the problem was that I was mixing up shapes because I was not using an Independent wrapper. Below is a possible implementation for anyone with a similar problem:

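out.append(
    tfp.layers.DistributionLambda(
        lambda t: tfp.distributions.TransformedDistribution(
            tfp.distributions.Independent(
                tfp.distributions.Beta(t[..., :1], t[..., 1:], allow_nan_stats=False),
                reinterpreted_batch_ndims=1,
            ),
            # The post is cut off here; the bijector and layer name below
            # are assumed to match the first post.
            tfp.bijectors.Invert(tfp.bijectors.Sigmoid()),
        ),
        name=f"out_{idx}",
    )(x)
)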
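For context on why the Independent wrapper matters: with `t` of shape [N, 2], `Beta(t[..., :1], t[..., 1:])` has batch shape [N, 1] and a scalar event, while `Independent(..., reinterpreted_batch_ndims=1)` turns the last batch axis into an event axis, so `log_prob` sums over it and returns one value per example. The NumPy sketch below mimics that reduction by hand; `beta_logpdf` is a helper written for this illustration, not a TFP function. It also shows one way shapes can silently go wrong: evaluating densities with shape [N, 1] on targets of shape [N] broadcasts to [N, N].

```python
import numpy as np
from math import lgamma


def beta_logpdf(x, a, b):
    # log Beta(x | a, b) = (a - 1) log x + (b - 1) log(1 - x) - log B(a, b)
    log_norm = np.vectorize(lambda p, q: lgamma(p) + lgamma(q) - lgamma(p + q))(a, b)
    return (a - 1.0) * np.log(x) + (b - 1.0) * np.log1p(-x) - log_norm


n = 4
a = np.full((n, 1), 2.0)  # concentration1, shape [N, 1] like t[..., :1]
b = np.full((n, 1), 3.0)  # concentration0, shape [N, 1] like t[..., 1:]
y = np.array([[0.1], [0.3], [0.5], [0.7]])  # targets, shape [N, 1]

# Plain Beta: batch shape [N, 1], one log-prob per element
elementwise = beta_logpdf(y, a, b)
# Independent(..., reinterpreted_batch_ndims=1): sum over the event axis
per_example = elementwise.sum(axis=-1)

# Pitfall: targets of shape [N] broadcast against [N, 1] to [N, N]
broadcast = beta_logpdf(y[:, 0], a, b)
print(elementwise.shape, per_example.shape, broadcast.shape)
```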

Borja Arroyo

May 24, 2022, 10:23:56 AM
to TensorFlow Probability, Borja Arroyo
I also found another possible solution for this problem:

out.append(
    tfp.layers.DistributionLambda(
        lambda t: tfp.distributions.Independent(
            tfp.distributions.SigmoidBeta(t[..., :1], t[..., 1:], allow_nan_stats=False),
            reinterpreted_batch_ndims=1,
        ),
        name=f"out_{idx}",
    )(x)
)
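`SigmoidBeta` describes a variable Y whose sigmoid follows a Beta distribution, i.e. Y = logit(X) with X ~ Beta(concentration1, concentration0), which is the same distribution the first post builds with `TransformedDistribution` and `Invert(Sigmoid())`. The plain-Python sketch below writes out that change of variables (`sigmoid_beta_logpdf` is a hypothetical helper, not the TFP implementation): since d sigmoid(y)/dy = sigmoid(y)(1 - sigmoid(y)), the log-density becomes a·log sigmoid(y) + b·log(1 - sigmoid(y)) - log B(a, b). A quick numerical integral confirms it is a proper density on R.

```python
import math


def log_sigmoid(y):
    # numerically stable log(sigmoid(y)); only exponentiates non-positive numbers
    return -math.log1p(math.exp(-y)) if y >= 0 else y - math.log1p(math.exp(y))


def sigmoid_beta_logpdf(y, a, b):
    # Y = logit(X), X ~ Beta(a, b). Beta log-density at sigmoid(y) plus the
    # log-Jacobian log sigmoid(y) + log(1 - sigmoid(y)), using 1 - sigmoid(y) = sigmoid(-y)
    log_beta_fn = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return a * log_sigmoid(y) + b * log_sigmoid(-y) - log_beta_fn


def integrate(f, lo=-40.0, hi=40.0, n=80_000):
    # midpoint rule; adequate for a smooth density with exponentially decaying tails
    h = (hi - lo) / n
    return h * sum(f(lo + (i + 0.5) * h) for i in range(n))


total = integrate(lambda y: math.exp(sigmoid_beta_logpdf(y, 2.0, 3.0)))
print(total)
```

Working in log-sigmoid space never evaluates log(0) at the edges of (0, 1), which is one reason the reparametrized form tends to be better behaved than transforming the Beta samples explicitly.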

Borja Arroyo

May 24, 2022, 10:35:22 AM
to TensorFlow Probability, Borja Arroyo
I have a similar problem when building a VAE with a Beta output layer, based on the TFP example of a Bernoulli VAE on MNIST. The following code (the final part of the decoder) is not working:

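# assumes tfkl = tf.keras.layers, tfpl = tfp.layers, tfd = tfp.distributions
tfkl.Permute((3, 1, 2)),  # change channels last to channels first: [B, 2, 28, 28]
tfkl.Reshape((2, 28, 28, 1)),
tfpl.DistributionLambda(
    lambda t: tfd.Independent(
        # t[:, 0] is concentration1, t[:, 1] is concentration0, each [B, 28, 28, 1]
        tfd.Beta(t[:, 0, ...], t[:, 1, ...], allow_nan_stats=False),
        reinterpreted_batch_ndims=3,
    )
),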
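The decoder tail above is meant to split a two-channel output of shape [B, 28, 28, 2] into two [B, 28, 28, 1] concentration tensors. The NumPy sketch below checks that the Permute, Reshape and slicing steps do exactly that, with `np.transpose` standing in for Keras `Permute((3, 1, 2))` (which reorders only the non-batch axes); `batch` and the random values are placeholders. It compares the result against slicing the channel axis directly, as in the earlier posts.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 3  # placeholder batch size
raw = rng.random((batch, 28, 28, 2))  # stand-in for the last Conv2D output (channels last)

# Permute((3, 1, 2)) reorders the non-batch axes: [B, 28, 28, 2] -> [B, 2, 28, 28]
permuted = np.transpose(raw, (0, 3, 1, 2))
# Reshape((2, 28, 28, 1)) then adds a trailing channel axis
reshaped = permuted.reshape(batch, 2, 28, 28, 1)
alpha_via_permute = reshaped[:, 0, ...]  # what t[:, 0, ...] selects
beta_via_permute = reshaped[:, 1, ...]   # what t[:, 1, ...] selects

# Slicing the channel axis directly yields the same [B, 28, 28, 1] tensors
alpha_direct = raw[..., :1]
beta_direct = raw[..., 1:]
print(alpha_via_permute.shape, np.array_equal(alpha_via_permute, alpha_direct))
```

If that equivalence holds in the real model, the reshaping by itself should not be what produces the NaN loss.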

I have seen a similar post about the Beta output layer and some problems related to the loss, which is also NaN in my case. I am not sure what the issue is, since the dataset is normalized into the [0, 1] interval.
I have tried:
  1. Using the exponential of the last Conv2D layer as the inputs
  2. Applying a relu activation to that layer
  3. Adding a tiny value (1e-6) to the inputs
  4. Checking for spikes at the edges and in the middle (as Brian Patton suggested); only the edge spikes seem to appear: https://medium.com/datadl-ai/mnist-exploration-to-execution-25136ca00570
Moreover, as in the other post, I have sanity-checked the same VAE with a Normal distribution instead of the Beta, and that VAE gets finite losses.
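One thing worth checking, since the data are normalized into [0, 1]: many MNIST pixels sit exactly at 0 or 1, on the boundary of the Beta support, where (a - 1)·log x or (b - 1)·log(1 - x) is infinite unless the matching concentration is exactly 1, and any value nudged above 1 leaves the support and gives NaN. A common workaround, named here as an assumption rather than something confirmed in the thread, is to clip the targets into [eps, 1 - eps] before computing the log-likelihood. The sketch uses plain NumPy with a hand-written `beta_logpdf`; `eps = 1e-6` and the concentrations are illustrative values.

```python
import numpy as np
from math import lgamma


def beta_logpdf(x, a, b):
    # log Beta(x | a, b) for scalar concentrations a, b and an array x
    return ((a - 1.0) * np.log(x) + (b - 1.0) * np.log1p(-x)
            - (lgamma(a) + lgamma(b) - lgamma(a + b)))


pixels = np.array([0.0, 0.25, 1.0, 1.0 + 1e-6])  # exact 0, interior, exact 1, just above 1
a, b = 2.0, 0.5  # illustrative concentrations

with np.errstate(divide="ignore", invalid="ignore"):
    raw = beta_logpdf(pixels, a, b)  # -inf at 0, +inf at 1, nan above 1

eps = 1e-6  # illustrative clipping margin
clipped = beta_logpdf(np.clip(pixels, eps, 1.0 - eps), a, b)
print(raw)
print(clipped)
```

In the Keras model, the same idea would be something like `lambda y_true, y_pred: -y_pred.log_prob(tf.clip_by_value(y_true, eps, 1 - eps))` as the loss.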