TFP exact GP regression in eager mode


Maxim Osipov

Aug 14, 2019, 8:46:35 AM
to TensorFlow Probability
Cross-posting from https://stackoverflow.com/questions/57493949/tfp-exact-gp-regression-in-eager-mode

I'm trying to perform exact GP regression in TF 2.0 eager mode, based on the original graph-based example from https://colab.research.google.com/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Regression_In_TFP.ipynb


amplitude = (
    np.finfo(np.float64).tiny +
    tf.nn.softplus(tf.Variable(initial_value=1., name='amplitude', dtype=np.float64))
)
length_scale = (
    np.finfo(np.float64).tiny +
    tf.nn.softplus(tf.Variable(initial_value=1., name='length_scale', dtype=np.float64))
)
observation_noise_variance = (
    np.finfo(np.float64).tiny +
    tf.nn.softplus(tf.Variable(initial_value=1e-6,
                               name='observation_noise_variance',
                               dtype=np.float64))
)

kernel = tfk.ExponentiatedQuadratic(amplitude, length_scale)

gp = tfd.GaussianProcess(
    kernel=kernel,
    index_points=tf.expand_dims(x, 1),
    observation_noise_variance=observation_noise_variance
)

neg_log_likelihood = lambda: -gp.log_prob(y)

optimizer = tf.optimizers.Adam(learning_rate=.01)

num_iters = 1000
lls_ = np.zeros(num_iters, np.float64)
for i in range(num_iters):
    lls_[i] = neg_log_likelihood()
    optimizer.minimize(neg_log_likelihood, var_list=[amplitude, length_scale, observation_noise_variance])

However, optimization fails with:

'tensorflow.python.framework.ops.EagerTensor' object has no attribute '_in_graph_mode'

If I instead create amplitude, length_scale and observation_noise_variance as raw tf.Variables and apply the softplus transform to each one separately, like:

amplitude = tf.Variable(initial_value=1., name='amplitude', dtype=np.float64)
amplitude_ = (
    np.finfo(np.float64).tiny +
    tf.nn.softplus(amplitude)
)

Optimization fails with:

ValueError: No gradients provided for any variable: ['amplitude:0', 'length_scale:0', 'observation_noise_variance:0'].

What am I doing wrong?

Christopher Suter

Aug 14, 2019, 3:31:01 PM
to Maxim Osipov, TensorFlow Probability
Hi, the issue is that in eager mode (the default in TF2), calling functions on tf.Variables returns a Tensor that is the result of computing that function at the Variable's current value. The connection to the original Variable is lost. In this case, we're taking the Variable and constraining it to be positive, then handing the result to the kernel, handing the kernel to the GP, and ultimately trying to optimize the marginal log prob via gradient descent w.r.t. those variables. But the connection to the variables has been lost, so the gradients are `None`.
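To make the lost connection concrete, here is a minimal sketch using plain TensorFlow only (no GP). The names raw, amplitude_outside and amplitude_inside are made up for illustration, and the squared value stands in for the real negative log likelihood. The same softplus transform is applied once outside a tf.GradientTape and once inside it:

```python
import tensorflow as tf

# Hypothetical stand-in for the unconstrained 'amplitude' variable.
raw = tf.Variable(1.0, dtype=tf.float64, name="raw_amplitude")

# Case 1: transform applied eagerly, before any tape is recording.
# The result is an ordinary Tensor that only holds the value softplus(1.0).
amplitude_outside = tf.nn.softplus(raw)
with tf.GradientTape() as tape:
    loss_disconnected = tf.square(amplitude_outside)  # toy loss
grad_disconnected = tape.gradient(loss_disconnected, raw)

# Case 2: transform applied while the tape is recording, so the path
# from the Variable to the loss is tracked.
with tf.GradientTape() as tape:
    amplitude_inside = tf.nn.softplus(raw)
    loss_connected = tf.square(amplitude_inside)  # same toy loss
grad_connected = tape.gradient(loss_connected, raw)

print(grad_disconnected)  # None: no recorded path back to `raw`
print(grad_connected)     # a float64 Tensor: 2 * softplus(raw) * sigmoid(raw)
```

The first case mirrors the code in the question, where the softplus is evaluated once up front; the second mirrors a loss callable that starts from the raw Variables.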

The easiest immediate workaround is to make sure that the `loss` you optimize is a function that takes the raw Variables as input, and returns the log prob. This means constructing the kernel and GP instance inside the loss callable. To avoid extra overhead in the optimization loop, you can decorate that loss function with @tf.function, which should improve performance.

Longer term: We recently added tfp.util.DeferredTensor to help work around these issues, as well as making changes to all our Distributions to prevent calling tf.convert_to_tensor on any Variable (or DeferredTensor) inputs. We haven't updated GP yet (the migration is still in progress), so it's possible it's still going to cause problems in some cases (e.g., if you were training a GP LVM, optimizing variable index_points, I expect we'd fail today).

In short: the quick fix is to make sure your loss callable maps tf.Variables to the loss. Slightly more ambitious (and possibly not 100% working today): use tfp.util.DeferredTensor (see the docstrings and unit tests for more info and examples).

HTH! Please come back with any followup questions.

--
You received this message because you are subscribed to the Google Groups "TensorFlow Probability" group.
To unsubscribe from this group and stop receiving emails from it, send an email to tfprobabilit...@tensorflow.org.
To view this discussion on the web visit https://groups.google.com/a/tensorflow.org/d/msgid/tfprobability/88817062-29a5-4ee8-a6cf-5e89f7b5cd46%40tensorflow.org.

Maxim Osipov

Aug 15, 2019, 4:09:59 AM
to TensorFlow Probability, maxim....@gmail.com
Thank you Christopher!

Just for the record, the workaround using GradientTape is:

amplitude_ = tf.Variable(initial_value=1., name='amplitude_', dtype=np.float64)
length_scale_ = tf.Variable(initial_value=1., name='length_scale_', dtype=np.float64)
observation_noise_variance_ = tf.Variable(initial_value=1e-6,
                                         name='observation_noise_variance_',
                                         dtype=np.float64)

@tf.function
def neg_log_likelihood():
    amplitude = np.finfo(np.float64).tiny + tf.nn.softplus(amplitude_)
    length_scale = np.finfo(np.float64).tiny + tf.nn.softplus(length_scale_)
    observation_noise_variance = np.finfo(np.float64).tiny + tf.nn.softplus(observation_noise_variance_)

    kernel = tfk.ExponentiatedQuadratic(amplitude, length_scale)

    gp = tfd.GaussianProcess(
        kernel=kernel,
        index_points=tf.expand_dims(x, 1),
        observation_noise_variance=observation_noise_variance
    )

    return -gp.log_prob(y)

optimizer = tf.optimizers.Adam(learning_rate=.01)

num_iters = 1000

nlls = np.zeros(num_iters, np.float64)
for i in range(num_iters):
    # Evaluate the loss once per step, under the tape, and reuse it for logging.
    with tf.GradientTape() as tape:
        loss = neg_log_likelihood()
    nlls[i] = loss
    grads = tape.gradient(loss, [amplitude_, length_scale_, observation_noise_variance_])
    optimizer.apply_gradients(zip(grads, [amplitude_, length_scale_, observation_noise_variance_]))


Sincerely
Max

Christopher Suter

Aug 16, 2019, 5:03:32 PM
to Maxim Osipov, TensorFlow Probability
That looks right! Have you gotten it running?


Ali

Dec 26, 2019, 7:09:52 AM
to TensorFlow Probability
Hi Maxim,

I'm also curious to see if you got this GP code example to run in TF 2.0 - would you mind sharing your notebook if you have?

Thank you