Distrax equivalent of TFP.layers


Jed Homer
Aug 10, 2022, 9:36:35 AM
to TensorFlow Probability
Hi all,

I am looking to reproduce exactly this type of model from this link:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfk = tf.keras
tfkl = tf.keras.layers
tfd = tfp.distributions
tfpl = tfp.layers

# Load data.
n = int(1e3)
scale_tril = np.array([[1.6180, 0.],
                       [-2.7183, 3.1416]]).astype(np.float32)
x = tfd.Normal(loc=0, scale=1).sample([n, 2])
eps = tfd.Normal(loc=0, scale=0.01).sample([n, 2])
y = tf.matmul(x, scale_tril) + eps

# Create model.
d = tf.compat.dimension_value(y.shape[-1])
model = tfk.Sequential([
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(d)),
    tfpl.MultivariateNormalTriL(d),
])

However, the JAX substrate of TFP has none of tfp.layers implemented.

The reason I want to do this is to allow a model to learn the mean and covariance matrix of an arbitrary Gaussian distribution.

I have checked the relevant source code for tfp.layers on GitHub, but I am a bit lost.

So my question is: what exactly is the tfpl.MultivariateNormalTriL layer doing? I assume it must be constraining the output to be a proper positive definite covariance matrix.

I have written my own layer that attempts to constrain a covariance matrix as C = S S^T, where S is the output of the network reshaped into a matrix, but my regression breaks at random points during training (I assume something stops the output from being a proper covariance matrix).
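Roughly, my attempt looks like this (a simplified, hypothetical sketch, not my exact code):

import jax.numpy as jnp

def cov_from_net(s_flat, d):
    # Reshape the raw network output into a (d, d) matrix S and take
    # C = S @ S^T. This is only positive *semi*-definite: S can be
    # singular, so the log-density can blow up, which may be why the
    # regression breaks.
    s = jnp.reshape(s_flat, (d, d))
    return s @ s.T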

Thanks in advance!

Christopher Suter
Aug 10, 2022, 9:46:12 AM
to Jed Homer, TensorFlow Probability
The TFP layer adds a small positive value to the diagonal (see
https://github.com/tensorflow/probability/blob/v0.17.0/tensorflow_probability/python/layers/distribution_layer.py#L371). Do you have something similar? There's not much else magical going on in the layer, afaict.
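
Concretely, what the layer does amounts to roughly this (untested sketch using the TFP JAX substrate; the helper name is made up):

from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

def mvn_tril_from_params(params, d):
    # params has size d + d * (d + 1) // 2, matching
    # tfpl.MultivariateNormalTriL.params_size(d).
    loc, tril_params = params[..., :d], params[..., d:]
    # FillScaleTriL fills the lower triangle from the flat vector,
    # softplus-transforms the diagonal, and adds diag_shift to it.
    scale_tril = tfb.FillScaleTriL(diag_shift=1e-5).forward(tril_params)
    return tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril)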


Jed Homer
Aug 10, 2022, 9:51:43 AM
to TensorFlow Probability, c...@google.com, TensorFlow Probability, Jed Homer
I saw this; I thought something deeper was happening, since they use a Shift bijector.

However, as you say, that bijector seems to just add a small float.

The docstring of the ScaleTriL part says:

"""
...
diag_shift: Float value broadcastable and added to all diagonal entries
  after applying the `diag_bijector`. Setting a positive
  value forces the output diagonal entries to be positive, but
  prevents inverting the transformation for matrices with
  diagonal entries less than this value.
  Default value: `1e-5`.
...
"""

So I guess you are right? I will try this.

Christopher Suter
Aug 10, 2022, 10:38:05 AM
to Jed Homer, TensorFlow Probability
Oh, the FillScaleTriL bijector is also constraining the diagonal to be positive. So something like exp(diag) + epsilon is what you'll want in order to match it; softplus might be more stable than exp.
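
In Distrax, the whole thing might look something like this (untested sketch; the network producing `params` is assumed, and the helper name is made up):

import distrax
import jax
import jax.numpy as jnp

def mvn_from_params(params, d, eps=1e-5):
    # params: flat network output of size d + d * (d + 1) // 2.
    loc, flat = params[:d], params[d:]
    # Fill the lower triangle of a (d, d) matrix.
    tril = jnp.zeros((d, d)).at[jnp.tril_indices(d)].set(flat)
    # Constrain the diagonal: softplus plus a small shift, mirroring
    # what tfpl.MultivariateNormalTriL does under the hood.
    diag = jax.nn.softplus(jnp.diagonal(tril)) + eps
    scale_tril = tril - jnp.diag(jnp.diagonal(tril)) + jnp.diag(diag)
    return distrax.MultivariateNormalTri(loc=loc, scale_tri=scale_tril)

The implied covariance is scale_tril @ scale_tril.T, which is positive definite as long as eps > 0.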