MVNPrecisionCholeskyWithUncerts well-defined but not for inference

Angel Berihuete

Apr 1, 2021, 1:35:31 PM
to TensorFlow Probability
Dear TensorFlow Probability group,

In order to do inference in a Gaussian Mixture Model with uncertainties in the observations, I've defined a multivariate normal distribution parameterized by a precision Cholesky factor. I need to insert a function inside the new distribution (lines 95-115 of the attached file). This would take a long time to explain, but the main idea is to avoid inverting the sum of two matrices (the covariance and the uncertainties in the observations) and to use only Cholesky decompositions (the distribution is attached).

I've defined and used this distribution (some code below), and it seems to work properly, even when calculating the likelihood in the model. But when I try to run MCMC (HMC or NUTS), the tf.map_fn inside the distribution is not applied properly, i.e., tf.map_fn does not split the precision matrices (error below).
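For context, assuming Gaussian mixture components and Gaussian measurement noise, the density such a distribution evaluates for observation $x_n$ under component $k$ is a normal whose covariance is the component covariance plus the observation covariance. Writing $P_k$ for the component precision, $\Sigma_k = P_k^{-1}$, and $S_n$ for the uncertainty matrix of observation $n$, the log density needs only a Cholesky factor $L_{nk}$ of the sum and a triangular solve, with no explicit inverse of the sum. This is the standard identity, not necessarily how the attached file computes it:

$$
\log \mathcal{N}(x_n \mid \mu_k, \Sigma_k + S_n)
  = -\tfrac{1}{2}\,\bigl\lVert L_{nk}^{-1}(x_n - \mu_k)\bigr\rVert^2
    - \sum_{i=1}^{D} \log\,(L_{nk})_{ii}
    - \tfrac{D}{2}\log 2\pi,
\qquad L_{nk} L_{nk}^{\top} = \Sigma_k + S_n .
$$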

```python
from __future__ import absolute_import, division, print_function
# from CholeskyWishart import CholeskyWishart
from MVNPrecisionCholeskyWithUncerts import MVNPrecisionCholeskyWithUncerts

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import tensorflow_probability as tfp
import tensorflow as tf
tfd = tfp.distributions
tfb = tfp.bijectors

DIMS = 2
VALIDATE_ARGS = True
ALLOW_NAN_STATS = False
NUM_SOURCES = 1000

dtype = np.float32

# True amplitudes
true_amp = np.array([0.5, 0.3, 0.2], dtype)

# True locations
true_loc = np.array([[-5., -5], [0, 0], [5, 5]], dtype)

# True correlations
true_cor0 = np.array([[1.0, 0.9], [0.9, 1.0]], dtype)
true_cor1 = np.array([[1.0, -0.9], [-0.9, 1.0]], dtype)
true_cor2 = np.array([[1.0, 0.1], [0.1, 1.0]], dtype)

# True variances
true_var0 = np.array([4.0, 1.0], dtype)
true_var1 = np.array([3.0, 1.0], dtype)
true_var2 = np.array([2.0, 1.0], dtype)

# Combine the variances and correlations into a covariance matrix
true_cov0 = np.expand_dims(np.sqrt(true_var0), axis=1).dot(
    np.expand_dims(np.sqrt(true_var0), axis=1).T) * true_cor0
true_cov1 = np.expand_dims(np.sqrt(true_var1), axis=1).dot(
    np.expand_dims(np.sqrt(true_var1), axis=1).T) * true_cor1
true_cov2 = np.expand_dims(np.sqrt(true_var2), axis=1).dot(
    np.expand_dims(np.sqrt(true_var2), axis=1).T) * true_cor2

# We'll be working with precision matrices, so we'll go ahead and compute the
# true precision matrix here
true_precision0 = np.linalg.inv(true_cov0)
true_precision1 = np.linalg.inv(true_cov1)
true_precision2 = np.linalg.inv(true_cov2)

true_precisions = tf.stack(
    [true_precision0, true_precision1, true_precision2])

true_cov = tf.stack([true_cov0, true_cov1, true_cov2])

true_chol_cov = tf.linalg.cholesky(true_cov)

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=true_amp),
    components_distribution=tfd.MultivariateNormalTriL(
        loc=true_loc,
        scale_tril=true_chol_cov))

obs = gm.sample(NUM_SOURCES)

# plt.scatter(obs.numpy()[:, 0], obs.numpy()[:, 1])
# plt.show()

# Observed uncertainties/covariances
# NOTE: The shape is augmented to broadcast into
# Mixture shape

uncerts_obs = tf.linalg.diag(
    tfd.Uniform(low=0.1, high=0.9).sample([NUM_SOURCES, DIMS])
)


# Sanity check
a = MVNPrecisionCholeskyWithUncerts(true_loc,
                                    tf.linalg.cholesky(true_precisions),
                                    uncerts_obs, name="MVNPrecisionCholeskyWithUncerts")
samples = a.sample()
plt.scatter(samples.numpy()[:, 0, 0], samples.numpy()[:, 0, 1], c="green")
plt.scatter(samples.numpy()[:, 1, 0], samples.numpy()[:, 1, 1], c="green")
plt.scatter(samples.numpy()[:, 2, 0], samples.numpy()[:, 2, 1], c="green")
plt.scatter(obs.numpy()[:, 0], obs.numpy()[:, 1])
plt.show()
```

When I try to run HMC, I get the following error:


```text
WARNING:tensorflow:From /home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/kernel.py:104: calling HamiltonianMonteCarlo.__init__ (from tensorflow_probability.python.mcmc.hmc) with step_size_update_fn is deprecated and will be removed after 2019-05-22.
Instructions for updating:
The `step_size_update_fn` argument is deprecated. Use `tfp.mcmc.SimpleStepSizeAdaptation` instead.
WARNING:tensorflow:From /home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:507: calling HamiltonianMonteCarlo.__init__ (from tensorflow_probability.python.mcmc.hmc) with seed is deprecated and will be removed after 2020-09-20.
Instructions for updating:
The `seed` argument is deprecated (but will work until removed). Pass seed to `tfp.mcmc.sample_chain` instead.
WARNING:tensorflow:From /home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_diag.py:167: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.
Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.
WARNING:tensorflow:From /home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:574: calling map_fn_v2 (from tensorflow.python.ops.map_fn) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Use fn_output_signature instead
Traceback (most recent call last):
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1848, in _create_c_op
    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimensions must be equal, but are 3 and 1000 for '{{node mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/strided_slice, mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/MatMul/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/Mul)' with input shapes: [3,2,2], [1000,2,2].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "BGMM05.py", line 479, in <module>
    pilot_samples, pilot_sampler_stat = sample_hmc(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 787, in __call__
    result = self._call(*args, **kwds)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 830, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 717, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2955, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3355, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3190, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 626, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "BGMM05.py", line 269, in sample_hmc
    return tfp.mcmc.sample_chain(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py", line 361, in sample_chain
    (_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 460, in trace_scan
    _, final_state, _, trace_arrays = tf.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 450, in _body
    state = loop_fn(state, elem)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py", line 354, in _trace_scan_fn
    seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 349, in smart_for_loop
    return tf.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 351, in <lambda>
    body=lambda i, *args: [i + 1] + list(body_fn(*args)),
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py", line 351, in _seeded_one_step
    kernel.one_step(*state_and_results, **one_step_kwargs))
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py", line 456, in one_step
    new_state, new_inner_results = self.inner_kernel.one_step(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py", line 380, in one_step
    transformed_next_state, kernel_results = self._inner_kernel.one_step(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/hmc.py", line 573, in one_step
    next_state, kernel_results = self._impl.one_step(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/metropolis_hastings.py", line 215, in one_step
    ] = self.inner_kernel.one_step(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/hmc.py", line 774, in one_step
    ] = integrator(current_momentum_parts,
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py", line 282, in __call__
    ] = tf.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py", line 284, in <lambda>
    body=lambda i, *args: [i + 1] + list(_one_step(  # pylint: disable=no-value-for-parameter,g-long-lambda
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py", line 325, in _one_step
    [next_target, next_target_grad_parts] = mcmc_util.maybe_call_fn_and_grads(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 292, in maybe_call_fn_and_grads
    result, grads = _value_and_gradients(fn, fn_arg_list, result, grads)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 252, in _value_and_gradients
    result = fn(*fn_arg_list)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py", line 106, in transformed_log_prob_fn
    tlp = log_prob_fn(*fn(state_parts))
  File "BGMM05.py", line 412, in target_log_prob_fn
    return tf.map_fn(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 649, in map_fn_v2
    return map_fn(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 509, in map_fn
    _, r_a = control_flow_ops.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 499, in compute
    result_value = autographed_fn(elems_value)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 617, in wrapper
    return func(*args, **kwargs)
  File "BGMM05.py", line 413, in <lambda>
    fn=lambda t: log_prob_fn(t),
  File "BGMM05.py", line 404, in <lambda>
    log_prob_fn = lambda *x: log_prior_fn(*x) + log_likelihood_fn(*x)
  File "BGMM05.py", line 386, in log_likelihood_fn
    mvn = MVNPrecisionCholeskyWithUncerts(
  File "<decorator-gen-344>", line 2, in __init__
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 296, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/home/angel/gitrepos/HierarchicalModel/Preprocessing/Field/tf2/MVNPrecisionCholeskyWithUncerts.py", line 115, in __init__
    C3 = tf.map_fn(get_inv_sum,
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 649, in map_fn_v2
    return map_fn(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 509, in map_fn
    _, r_a = control_flow_ops.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 499, in compute
    result_value = autographed_fn(elems_value)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 617, in wrapper
    return func(*args, **kwargs)
  File "/home/angel/gitrepos/HierarchicalModel/Preprocessing/Field/tf2/MVNPrecisionCholeskyWithUncerts.py", line 97, in get_inv_sum
    tf.linalg.matmul(precision_cholesky[x], B1)))
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 3251, in matmul
    return gen_math_ops.batch_mat_mul_v2(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/gen_math_ops.py", line 1571, in batch_mat_mul_v2
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 748, in _apply_op_helper
    op = g._create_op_internal(op_type_name, inputs, dtypes=None,
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 594, in _create_op_internal
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3523, in _create_op_internal
    ret = Operation(
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 2010, in __init__
    self._c_op = _create_c_op(self._graph, node_def, inputs,
  File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1851, in _create_c_op
    raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 3 and 1000 for '{{node mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/strided_slice, mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/MatMul/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/Mul)' with input shapes: [3,2,2], [1000,2,2].
```
MVNPrecisionCholeskyWithUncerts.py

Dave Moore

Apr 1, 2021, 1:58:36 PM
to Angel Berihuete, TensorFlow Probability
Just a few quick thoughts:

1. It would help to have a reproducible example, including your MCMC code. Colab (colab.research.google.com) is a good way to share executable reproductions.

2. TFP generally assumes that Tensors passed to distribution constructors and methods can have an arbitrary number of leftmost batch dimensions; the MCMC API uses this to do multi-chain MCMC, among other things. I'm guessing that your distribution violates this assumption somehow. One potential red flag is that you access `precision_cholesky.shape[0]`. Indexing shapes from the left usually breaks batch semantics; you almost always want to index from the right, e.g., `shape[-3]` or similar, so that the semantics don't change when a batch dimension is added.

3. This won't solve your problem, but you may find that `tf.vectorized_map` is much faster than `tf.map_fn`.

4. Instead of creating a new TransformedDistribution inside of your log_prob and sample methods every time they're called, you might as well just inherit from TransformedDistribution so you get all of the relevant methods for free:

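```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors


class MVNPrecisionCholeskyWithUncerts(tfd.TransformedDistribution):

    # Argument names are illustrative; they follow the positional call
    # MVNPrecisionCholeskyWithUncerts(loc, precision_cholesky, uncerts, name=...).
    def __init__(self, loc, precision_cholesky, uncerts,
                 name="MVNPrecisionCholeskyWithUncerts"):
        self._loc = loc
        # .... math to compute `precision_cholesky_plus_uncerts`,
        # stored as self._precision_cholesky_plus_uncerts
        super(MVNPrecisionCholeskyWithUncerts, self).__init__(
            # Standard normal base distribution with event shape [..., d].
            distribution=tfd.Independent(
                tfd.Normal(loc=tf.zeros(self._precision_cholesky_plus_uncerts.shape[:-1]),
                           scale=tf.ones(self._precision_cholesky_plus_uncerts.shape[:-1])),
                reinterpreted_batch_ndims=1),
            # x = loc + L^{-T} z, so Cov[x] = (L L^T)^{-1}.
            bijector=tfb.Chain([
                tfb.Shift(shift=self._loc),
                tfb.Invert(tfb.ScaleMatvecTriL(
                    scale_tril=self._precision_cholesky_plus_uncerts,
                    adjoint=True))]),
            name=name)
```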

Best,
Dave

--
You received this message because you are subscribed to the Google Groups "TensorFlow Probability" group.
To unsubscribe from this group and stop receiving emails from it, send an email to tfprobabilit...@tensorflow.org.
To view this discussion on the web visit https://groups.google.com/a/tensorflow.org/d/msgid/tfprobability/4b408e97-f7aa-4d20-b330-f6b8970d5fc9n%40tensorflow.org.

Angel Berihuete

Apr 2, 2021, 6:18:14 AM
to TensorFlow Probability, dav...@google.com, TensorFlow Probability, Angel Berihuete
Many thanks for the quick response!

I've added all your suggestions to our code, but the problem persists. I agree that the problem is related to the way HMC or NUTS uses the shapes. I've added a reproducible example in a Colab:


Any help will be welcome!
Cheers!

Dave Moore

Apr 5, 2021, 7:23:46 PM
to Angel Berihuete, TensorFlow Probability
I think I found your problem; it's quite subtle.

Your HMC step sizes are initialized to have the same rank as your state variables. But HMC steps are applied to *unconstrained* state variables (as determined by your constraining bijector), and when the `FillTriangular` bijector transforms a value, it changes its rank (the unconstrained representation of a triangular matrix is just a vector). When HMC tries to propose a new iterate, it does something like

```text
  bijector.inverse(initial_state.precisions)  # shape [         nchains, 3, 3]
+ step_size.precisions                        # shape [nchains,       1, 1, 1]
* gradient wrt unconstrained precision        # shape [         nchains, 3, 3]
= new_state.precisions                        # shape [nchains, nchains, 3, 3]
```

and because the step size has one too many dimensions, broadcasting results in two `num_chains` dimensions for the proposed next state. This extra dimension propagates through the target log prob to cause the error you're seeing. To avoid this, it should be sufficient to use the ranks of the unconstrained state parts when defining initial step sizes, e.g.:

```python
unconstrained_init_state = tfb.JointMap(bijectors).inverse(init_state)
step_size = [tf.fill([nchain] + [1] * (len(s.shape) - 1),
                     tf.constant(0.5, tf.float32))
             for s in unconstrained_init_state]
```

As a TODO on our end, TFP should probably try to detect cases like this and provide a more informative error message when a step attempts to increase the state dimension.

Best,
Dave

Angel Berihuete

Apr 6, 2021, 7:00:39 AM
to TensorFlow Probability, dav...@google.com, TensorFlow Probability, Angel Berihuete
Great!! It works perfectly.
Thank you very much Dave, you have made my day ;)
Cheers!
Angel