Concatenating or merging many distribution outputs into one multivariate distribution

Cristián Serpell

Nov 23, 2020, 9:45:58 AM
to TensorFlow Probability
Hi all,

I need some help: I have several DistributionLambda layers as the outputs of one model, and I would like to combine them with a Concatenate-like operation into a new layer, so that the model has a single output that is the joint distribution of all of them, assuming they are independent. Then I can apply a log-likelihood loss to the output of the model. Otherwise I cannot apply the loss, because the output of a Concatenate layer no longer has the log_prob method. I have been trying with the Blockwise distribution, but with no luck so far.

Any help would be greatly appreciated,
Cristián Serpell
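Assuming the per-step outputs are independent, as stated above, the log-likelihood of the combined output is simply the sum of the per-step log-likelihoods, so a merged distribution only needs a log_prob that adds them up. Below is a framework-free sketch of that reduction; the helper names normal_log_prob and joint_log_prob are hypothetical and not part of TFP. In TFP, tfd.Independent(tfd.Normal(loc, scale), reinterpreted_batch_ndims=1) computes this same sum over the last axis of a batch of Normals.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log-density of a univariate Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def joint_log_prob(values, locs, scales):
    """Log-density of independent Normals: the sum of the per-component terms."""
    if not (len(values) == len(locs) == len(scales)):
        raise ValueError("values, locs and scales must have the same length")
    return sum(normal_log_prob(x, m, s) for x, m, s in zip(values, locs, scales))


# Three forecast steps, each treated as an independent Normal
print(joint_log_prob([0.5, 1.0, -0.2], [0.0, 1.0, 0.0], [1.0, 2.0, 0.5]))
```

The negative of this sum is what a single negative log-likelihood loss over the merged output would minimize.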

Cristián Serpell

Nov 23, 2020, 10:54:13 AM
to TensorFlow Probability, Cristián Serpell

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
from tensorflow_probability import distributions
from tensorflow_probability import layers as tfp_layers


def likelihood_loss(y_true, y_pred):
    """Negative log-likelihood loss; y_pred is the output distribution."""
    return -y_pred.log_prob(y_true)


def distribution_fn(params):
    """Builds a Normal from the two Dense outputs (loc, pre-softplus scale)."""
    # Slicing with 0:1 and 1:2 keeps a (batch, 1) shape. Softplus keeps the
    # scale positive and, unlike math.log/math.exp, operates on tensors.
    return distributions.Normal(
        loc=params[:, 0:1], scale=tf.math.softplus(params[:, 1:2]))


output_steps = 3
...
lstm_layer = layers.LSTM(10, return_state=True)
last_layer, l_h, l_c = lstm_layer(last_layer)
lstm_states = [l_h, l_c]
dense_layer = layers.Dense(2)
last_layer = dense_layer(last_layer)
last_layer = tfp_layers.DistributionLambda(
    make_distribution_fn=distribution_fn)(last_layer)
output_layers = [last_layer]
# Get output sequence, re-injecting the output of each step
for _ in range(1, output_steps):
    # The distribution is converted to a (batch, 1) tensor (a sample, by
    # default) and reshaped to (batch, 1 time step, 1 feature) for the LSTM
    last_layer = layers.Reshape((1, 1))(last_layer)
    last_layer, l_h, l_c = lstm_layer(last_layer, initial_state=lstm_states)
    # Storing state for next time step
    lstm_states = [l_h, l_c]
    last_layer = tfp_layers.DistributionLambda(
        make_distribution_fn=distribution_fn)(dense_layer(last_layer))
    output_layers.append(last_layer)

# This does not work
# last_layer = distributions.Blockwise(output_layers)

# This works for the model but cannot compute loss
# last_layer = layers.Concatenate(axis=1)(output_layers)

the_model = models.Model(inputs=[input_layer], outputs=[last_layer])
the_model.compile(loss=likelihood_loss,
                  optimizer=optimizers.Adam(learning_rate=0.001))
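The scale in the first version of distribution_fn was computed as math.log(1.0 + math.exp(x)), which is the softplus function. Apart from the math module not operating on tensors, that naive form overflows for large inputs. A stdlib-only sketch of the numerically stable rewrite max(x, 0) + log1p(exp(-|x|)) follows; tf.math.softplus is TensorFlow's built-in equivalent, and the function names here are illustrative only.

```python
import math


def naive_softplus(x):
    """softplus written as log(1 + exp(x)); overflows for large x."""
    return math.log(1.0 + math.exp(x))


def stable_softplus(x):
    """Numerically stable softplus: max(x, 0) + log1p(exp(-|x|))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


for x in (-5.0, 0.0, 5.0):
    # Both forms agree wherever the naive one does not overflow
    print(x, naive_softplus(x), stable_softplus(x))

print(stable_softplus(1000.0))  # naive_softplus(1000.0) raises OverflowError
```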

Cristián Serpell

Nov 30, 2020, 4:11:29 PM
to TensorFlow Probability, Cristián Serpell
I found a way, although now I cannot save the model once it is trained. Saving fails with the long traceback below:

last_layer = tfp_layers.DistributionLambda(
    make_distribution_fn=lambda t: distributions.Blockwise(t))(output_layers)

Traceback (most recent call last):
  File "scripts/train_forecaster.py", line 156, in <module>
    main()
  File "scripts/train_forecaster.py", line 151, in main
    the_model.save('model'))
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1978, in save
    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 133, in save_model
    saved_model_save.save(model, filepath, overwrite, include_optimizer,
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 80, in save
    save_lib.save(model, filepath, signatures, options)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 975, in save
    _, exported_graph, object_saver, asset_info = _build_meta_graph(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1046, in _build_meta_graph
    signatures = signature_serialization.find_function_to_export(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 144, in list_functions
    obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2589, in _list_functions_for_serialization
    functions = super(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3018, in _list_functions_for_serialization
    return (self._trackable_saved_model_saver
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 87, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 78, in functions_to_serialize
    return (self._get_serialized_attributes(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
    object_dict, function_dict = self._get_serialized_attributes_internal(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 56, in _get_serialized_attributes_internal
    super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 104, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 163, in wrap_layer_functions
    call_fn_with_losses = call_collection.add_function(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 505, in add_function
    self.add_trace(*self._input_signature)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 420, in add_trace
    trace_with_training(True)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 418, in trace_with_training
    fn.get_concrete_function(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 549, in get_concrete_function
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1167, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1073, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 696, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 526, in wrapper
    ret = method(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 167, in wrap_with_training_arg
    return tf_utils.smart_cond(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 64, in smart_cond
    return smart_module.smart_cond(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 169, in <lambda>
    lambda: replace_training_and_call(True),
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 165, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 569, in call_and_return_conditional_losses
    call_output = layer_call(inputs, *args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 385, in call
    return self._run_internal_graph(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 508, in _run_internal_graph
    outputs = node.layer(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 245, in __call__
    distribution, _ = super(DistributionLambda, self).__call__(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 985, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 71, in return_outputs_and_add_losses
    outputs, losses = fn(inputs, *args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 167, in wrap_with_training_arg
    return tf_utils.smart_cond(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 64, in smart_cond
    return smart_module.smart_cond(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 169, in <lambda>
    lambda: replace_training_and_call(True),
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 165, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 543, in __call__
    self.call_collection.add_trace(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 420, in add_trace
    trace_with_training(True)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 418, in trace_with_training
    fn.get_concrete_function(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 549, in get_concrete_function
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1167, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1073, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 696, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 526, in wrapper
    ret = method(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 167, in wrap_with_training_arg
    return tf_utils.smart_cond(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 64, in smart_cond
    return smart_module.smart_cond(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 169, in <lambda>
    lambda: replace_training_and_call(True),
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 165, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 569, in call_and_return_conditional_losses
    call_output = layer_call(inputs, *args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 251, in call
    distribution, value = super(DistributionLambda, self).call(
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py", line 903, in call
    result = self.function(inputs, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 172, in _fn
    d = make_distribution_fn(*fargs, **fkwargs)
  File "/home/cserpell/git/latent_prediction/models/model.py", line 132, in <lambda>
    make_distribution_fn=lambda t: distributions.Blockwise(t))(output_layers)
  File "<decorator-gen-144>", line 2, in __init__
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 334, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/blockwise.py", line 190, in __init__
    joint_distribution_sequential.JointDistributionSequential(
  File "<decorator-gen-140>", line 2, in __init__
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 334, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 208, in __init__
    self._build(model)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 236, in _build
    self._dist_fn_wrapped, self._dist_fn_args = zip(*[
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 237, in <listcomp>
    _unify_call_signature(i, dist_fn)
  File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 471, in _unify_call_signature
    raise TypeError('{} must be either `tfd.Distribution`-like or '
TypeError: Tensor("inputs:0", shape=(None, 1), dtype=float32) must be either `tfd.Distribution`-like or `callable`.

Brian Patton 🚀

Dec 10, 2020, 9:15:38 AM
to Cristián Serpell, TensorFlow Probability
You might consider writing a custom subclass of tfd.Distribution for this. It could take the Keras LSTM layer as an __init__ parameter, and as long as that gets set as an attribute, *I think* tf.Module should pick up the linkage to any underlying variables. The latest TF nightly builds should have Keras traversing tf.Module as well as keras.Model when searching for variables. I don't recall whether keras.Model also subclasses tf.Module (I hope it will eventually). Then _sample_n can return a tf.stack(...) of the LSTM outputs.

Brian Patton | Software Engineer | b...@google.com
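The structure suggested above (one distribution object that owns the recurrent model, runs it step by step when sampling, and stacks the per-step outputs) can be sketched without any framework. This is a minimal sketch under stated assumptions: AutoregressiveNormal and step_fn are hypothetical names, step_fn(prev_value, state) -> (loc, raw_scale, new_state) stands in for the shared LSTM and Dense layers, and sample/log_prob only loosely mirror the _sample_n and _log_prob hooks a real tfd.Distribution subclass would implement. It does not use TFP's API.

```python
import math
import random


def softplus(x):
    """Numerically stable softplus, used to keep the scale positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


class AutoregressiveNormal:
    """Hypothetical stand-in for a custom distribution over a whole sequence."""

    def __init__(self, step_fn, initial_value, initial_state, num_steps):
        # step_fn plays the role of the LSTM + Dense pair (an assumption)
        self.step_fn = step_fn
        self.initial_value = initial_value
        self.initial_state = initial_state
        self.num_steps = num_steps

    def sample(self, rng):
        """Draws one sequence, re-injecting each sample as the next input."""
        value, state, out = self.initial_value, self.initial_state, []
        for _ in range(self.num_steps):
            loc, raw_scale, state = self.step_fn(value, state)
            value = rng.gauss(loc, softplus(raw_scale))
            out.append(value)
        return out  # analogue of tf.stack(...) over the steps

    def log_prob(self, values):
        """Sums per-step Normal log-densities, feeding in observed values."""
        prev, state, total = self.initial_value, self.initial_state, 0.0
        for x in values:
            loc, raw_scale, state = self.step_fn(prev, state)
            scale = softplus(raw_scale)
            z = (x - loc) / scale
            total += -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)
            prev = x
        return total


# Hypothetical step function: next mean = previous value, constant raw scale
walk = AutoregressiveNormal(lambda v, s: (v, 0.0, s), initial_value=0.0,
                            initial_state=None, num_steps=3)
seq = walk.sample(random.Random(0))
print(seq, walk.log_prob(seq))
```

In this sketch log_prob conditions each step on the observed previous value (teacher forcing), whereas the Keras graph in the second message feeds each step a sample from the previous step's distribution; which choice fits depends on how the model is meant to be trained.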