Dear TensorFlow-Probability group.
In order to do inference in a Gaussian Mixture Model with uncertainties in the observations, I've defined a multivariate normal distribution parameterized by the Cholesky factor of the precision matrix. I need to insert a function inside the new distribution (lines 95-115 of the attached file). It would take a long time to explain in full, but the main idea is to avoid inverting the sum of two matrices (the component covariance and the observational uncertainty) and to use only Cholesky decompositions instead (see the attached distribution).
I've defined and used this distribution (some code below), and it seems to work properly, even when calculating the likelihood in the model. However, when I try to run MCMC (HMC or NUTS), the tf.map_fn inside the distribution is not applied properly, i.e., tf.map_fn does not split the precision matrices (error below).
from __future__ import absolute_import, division, print_function
# from CholeskyWishart import CholeskyWishart
from MVNPrecisionCholeskyWithUncerts import MVNPrecisionCholeskyWithUncerts
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import tensorflow_probability as tfp
import tensorflow as tf
tfd = tfp.distributions
tfb = tfp.bijectors
# --- Experiment configuration ------------------------------------------------
DIMS = 2  # dimensionality of each observation
VALIDATE_ARGS = True
ALLOW_NAN_STATS = False
NUM_SOURCES = 1000  # number of simulated observations
dtype = np.float32

# --- Ground-truth mixture parameters (three components) ----------------------
# Mixing weights of the three components.
true_amp = np.array([0.5, 0.3, 0.2], dtype)
# Component means, one row per component.
true_loc = np.array([[-5., -5], [0, 0], [5, 5]], dtype)

# Component correlation matrices.
true_cor0 = np.array([[1.0, 0.9], [0.9, 1.0]], dtype)
true_cor1 = np.array([[1.0, -0.9], [-0.9, 1.0]], dtype)
true_cor2 = np.array([[1.0, 0.1], [0.1, 1.0]], dtype)

# Per-dimension component variances.
true_var0 = np.array([4.0, 1.0], dtype)
true_var1 = np.array([3.0, 1.0], dtype)
true_var2 = np.array([2.0, 1.0], dtype)

# cov_ij = sigma_i * sigma_j * rho_ij with sigma = sqrt(variance).
# np.outer(sigma, sigma) is the same column-times-row product as
# expand_dims(sigma, 1).dot(expand_dims(sigma, 1).T).
true_cov0 = np.outer(np.sqrt(true_var0), np.sqrt(true_var0)) * true_cor0
true_cov1 = np.outer(np.sqrt(true_var1), np.sqrt(true_var1)) * true_cor1
true_cov2 = np.outer(np.sqrt(true_var2), np.sqrt(true_var2)) * true_cor2

# The model is parameterized by precision matrices, so invert the true
# covariances once up front.
true_precision0 = np.linalg.inv(true_cov0)
true_precision1 = np.linalg.inv(true_cov1)
true_precision2 = np.linalg.inv(true_cov2)

# Stack along a new leading (component) axis -> shape [3, DIMS, DIMS].
true_precisions = tf.stack([true_precision0, true_precision1, true_precision2])
true_cov = tf.stack([true_cov0, true_cov1, true_cov2])
true_chol_cov = tf.linalg.cholesky(true_cov)

# --- Simulated observations --------------------------------------------------
# Ground-truth Gaussian mixture; NUM_SOURCES draws from it form the data set.
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=true_amp),
    components_distribution=tfd.MultivariateNormalTriL(
        loc=true_loc,
        scale_tril=true_chol_cov,
    ),
)
obs = gm.sample(NUM_SOURCES)

# --- Observational uncertainties ---------------------------------------------
# One diagonal DIMS x DIMS covariance per observation, diagonal entries drawn
# from Uniform(0.1, 0.9); resulting shape is [NUM_SOURCES, DIMS, DIMS].
# NOTE(review): no extra axis is inserted here, so this tensor does not by
# itself broadcast against the 3 mixture components. The reported MCMC
# failure is a BatchMatMulV2 between shapes [3,2,2] and [1000,2,2] inside the
# distribution's tf.map_fn -- confirm which shape the distribution expects
# (e.g. uncerts_obs[:, tf.newaxis] for [NUM_SOURCES, 1, DIMS, DIMS]).
_uncert_diag = tfd.Uniform(low=0.1, high=0.9).sample([NUM_SOURCES, DIMS])
uncerts_obs = tf.linalg.diag(_uncert_diag)
# Sanity check: draw one sample from the custom distribution built from the
# true parameters and overlay those draws (green) on the simulated data.
a = MVNPrecisionCholeskyWithUncerts(
    true_loc,
    tf.linalg.cholesky(true_precisions),
    uncerts_obs,
    name="MVNPrecisionCholeskyWithUncerts",
)
samples = a.sample()

# Convert to NumPy once. The indexing assumes samples is rank 3 with axis 1
# selecting the mixture component (presumably [NUM_SOURCES, 3, DIMS]) --
# TODO confirm against the distribution's batch shape.
_samples_np = samples.numpy()
for _component in range(3):
    plt.scatter(
        _samples_np[:, _component, 0],
        _samples_np[:, _component, 1],
        c="green",
    )

# Simulated observations drawn with the default colour.
_obs_np = obs.numpy()
plt.scatter(_obs_np[:, 0], _obs_np[:, 1])
plt.show()
WARNING:tensorflow:From /home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/kernel.py:104: calling HamiltonianMonteCarlo.__init__ (from tensorflow_probability.python.mcmc.hmc) with step_size_update_fn is deprecated and will be removed after 2019-05-22.
Instructions for updating:
The `step_size_update_fn` argument is deprecated. Use `tfp.mcmc.SimpleStepSizeAdaptation` instead.
WARNING:tensorflow:From /home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:507: calling HamiltonianMonteCarlo.__init__ (from tensorflow_probability.python.mcmc.hmc) with seed is deprecated and will be removed after 2020-09-20.
Instructions for updating:
The `seed` argument is deprecated (but will work until removed). Pass seed to `tfp.mcmc.sample_chain` instead.
WARNING:tensorflow:From /home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_diag.py:167: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.
Instructions for updating:
Do not pass `graph_parents`. They will no longer be used.
WARNING:tensorflow:From /home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:574: calling map_fn_v2 (from tensorflow.python.ops.map_fn) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Use fn_output_signature instead
Traceback (most recent call last):
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1848, in _create_c_op
c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimensions must be equal, but are 3 and 1000 for '{{node mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/strided_slice, mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/MatMul/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/Mul)' with input shapes: [3,2,2], [1000,2,2].
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "BGMM05.py", line 479, in <module>
pilot_samples, pilot_sampler_stat = sample_hmc(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 787, in __call__
result = self._call(*args, **kwds)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 830, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 717, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2955, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3355, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3190, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 626, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "BGMM05.py", line 269, in sample_hmc
return tfp.mcmc.sample_chain(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py", line 361, in sample_chain
(_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 460, in trace_scan
_, final_state, _, trace_arrays = tf.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 450, in _body
state = loop_fn(state, elem)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py", line 354, in _trace_scan_fn
seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 349, in smart_for_loop
return tf.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 351, in <lambda>
body=lambda i, *args: [i + 1] + list(body_fn(*args)),
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py", line 351, in _seeded_one_step
kernel.one_step(*state_and_results, **one_step_kwargs))
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py", line 456, in one_step
new_state, new_inner_results = self.inner_kernel.one_step(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py", line 380, in one_step
transformed_next_state, kernel_results = self._inner_kernel.one_step(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/hmc.py", line 573, in one_step
next_state, kernel_results = self._impl.one_step(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/metropolis_hastings.py", line 215, in one_step
] = self.inner_kernel.one_step(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/hmc.py", line 774, in one_step
] = integrator(current_momentum_parts,
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py", line 282, in __call__
] = tf.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py", line 284, in <lambda>
body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py", line 325, in _one_step
[next_target, next_target_grad_parts] = mcmc_util.maybe_call_fn_and_grads(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 292, in maybe_call_fn_and_grads
result, grads = _value_and_gradients(fn, fn_arg_list, result, grads)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py", line 252, in _value_and_gradients
result = fn(*fn_arg_list)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py", line 106, in transformed_log_prob_fn
tlp = log_prob_fn(*fn(state_parts))
File "BGMM05.py", line 412, in target_log_prob_fn
return tf.map_fn(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 649, in map_fn_v2
return map_fn(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 509, in map_fn
_, r_a = control_flow_ops.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 499, in compute
result_value = autographed_fn(elems_value)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 617, in wrapper
return func(*args, **kwargs)
File "BGMM05.py", line 413, in <lambda>
fn=lambda t: log_prob_fn(t),
File "BGMM05.py", line 404, in <lambda>
log_prob_fn = lambda *x: log_prior_fn(*x) + log_likelihood_fn(*x)
File "BGMM05.py", line 386, in log_likelihood_fn
mvn = MVNPrecisionCholeskyWithUncerts(
File "<decorator-gen-344>", line 2, in __init__
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 296, in wrapped_init
default_init(self_, *args, **kwargs)
File "/home/angel/gitrepos/HierarchicalModel/Preprocessing/Field/tf2/MVNPrecisionCholeskyWithUncerts.py", line 115, in __init__
C3 = tf.map_fn(get_inv_sum,
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 649, in map_fn_v2
return map_fn(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 509, in map_fn
_, r_a = control_flow_ops.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 989, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py", line 499, in compute
result_value = autographed_fn(elems_value)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 617, in wrapper
return func(*args, **kwargs)
File "/home/angel/gitrepos/HierarchicalModel/Preprocessing/Field/tf2/MVNPrecisionCholeskyWithUncerts.py", line 97, in get_inv_sum
tf.linalg.matmul(precision_cholesky[x], B1)))
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
return target(*args, **kwargs)
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 3251, in matmul
return gen_math_ops.batch_mat_mul_v2(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/ops/gen_math_ops.py", line 1571, in batch_mat_mul_v2
_, _, _op, _outputs = _op_def_library._apply_op_helper(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 748, in _apply_op_helper
op = g._create_op_internal(op_type_name, inputs, dtypes=None,
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 594, in _create_op_internal
return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3523, in _create_op_internal
ret = Operation(
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 2010, in __init__
self._c_op = _create_c_op(self._graph, node_def, inputs,
File "/home/angel/anaconda3/envs/tf-n/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1851, in _create_c_op
raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 3 and 1000 for '{{node mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/strided_slice, mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/map/while/MatMul/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/map/while/MVNPrecisionCholeskyWithUncerts/Mul)' with input shapes: [3,2,2], [1000,2,2].