I found a way to do this, but now I cannot save the model after training finishes. Saving fails with the long error log shown below:
# Final layer: combine the per-output results in `output_layers` into one
# `distributions.Blockwise` distribution via a DistributionLambda.
# NOTE(review): this is the line that breaks `the_model.save(...)`. While
# exporting the SavedModel, Keras re-traces this layer's call. During that
# trace the lambda receives plain Tensors (e.g. Tensor("inputs:0",
# shape=(None, 1))), not distribution objects. `Blockwise` forwards them to
# `JointDistributionSequential`, which requires each element to be
# `tfd.Distribution`-like or callable, hence the TypeError.
# Presumably `output_layers` are distribution-valued outputs of other tfp
# layers; confirm against where they are built.
# TODO: a direction to investigate is to feed parameter tensors into this layer
# and construct the component distributions inside `make_distribution_fn`, so
# the function still works when it is handed Tensors during save-time tracing.
last_layer = tfp_layers.DistributionLambda(
make_distribution_fn=lambda t: distributions.Blockwise(t))(output_layers)
Traceback (most recent call last):
File "scripts/train_forecaster.py", line 156, in <module>
main()
File "scripts/train_forecaster.py", line 151, in main
the_model.save('model'))
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1978, in save
save.save_model(self, filepath, overwrite, include_optimizer, save_format,
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 133, in save_model
saved_model_save.save(model, filepath, overwrite, include_optimizer,
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 80, in save
save_lib.save(model, filepath, signatures, options)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 975, in save
_, exported_graph, object_saver, asset_info = _build_meta_graph(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1046, in _build_meta_graph
signatures = signature_serialization.find_function_to_export(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
functions = saveable_view.list_functions(saveable_view.root)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 144, in list_functions
obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2589, in _list_functions_for_serialization
functions = super(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3018, in _list_functions_for_serialization
return (self._trackable_saved_model_saver
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 87, in list_functions_for_serialization
fns = self.functions_to_serialize(serialization_cache)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 78, in functions_to_serialize
return (self._get_serialized_attributes(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
object_dict, function_dict = self._get_serialized_attributes_internal(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 56, in _get_serialized_attributes_internal
super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 104, in _get_serialized_attributes_internal
functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 163, in wrap_layer_functions
call_fn_with_losses = call_collection.add_function(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 505, in add_function
self.add_trace(*self._input_signature)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 420, in add_trace
trace_with_training(True)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 418, in trace_with_training
fn.get_concrete_function(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 549, in get_concrete_function
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1167, in get_concrete_function
concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1073, in _get_concrete_function_garbage_collected
self._initialize(args, kwargs, add_initializers_to=initializers)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 696, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 526, in wrapper
ret = method(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 167, in wrap_with_training_arg
return tf_utils.smart_cond(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 64, in smart_cond
return smart_module.smart_cond(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
return true_fn()
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 169, in <lambda>
lambda: replace_training_and_call(True),
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 165, in replace_training_and_call
return wrapped_call(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 569, in call_and_return_conditional_losses
call_output = layer_call(inputs, *args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 385, in call
return self._run_internal_graph(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 508, in _run_internal_graph
outputs = node.layer(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 245, in __call__
distribution, _ = super(DistributionLambda, self).__call__(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 985, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 71, in return_outputs_and_add_losses
outputs, losses = fn(inputs, *args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 167, in wrap_with_training_arg
return tf_utils.smart_cond(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 64, in smart_cond
return smart_module.smart_cond(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
return true_fn()
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 169, in <lambda>
lambda: replace_training_and_call(True),
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 165, in replace_training_and_call
return wrapped_call(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 543, in __call__
self.call_collection.add_trace(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 420, in add_trace
trace_with_training(True)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 418, in trace_with_training
fn.get_concrete_function(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 549, in get_concrete_function
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1167, in get_concrete_function
concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1073, in _get_concrete_function_garbage_collected
self._initialize(args, kwargs, add_initializers_to=initializers)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 696, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 526, in wrapper
ret = method(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 167, in wrap_with_training_arg
return tf_utils.smart_cond(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 64, in smart_cond
return smart_module.smart_cond(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
return true_fn()
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 169, in <lambda>
lambda: replace_training_and_call(True),
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 165, in replace_training_and_call
return wrapped_call(*args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 569, in call_and_return_conditional_losses
call_output = layer_call(inputs, *args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 251, in call
distribution, value = super(DistributionLambda, self).call(
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py", line 903, in call
result = self.function(inputs, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 172, in _fn
d = make_distribution_fn(*fargs, **fkwargs)
File "/home/cserpell/git/latent_prediction/models/model.py", line 132, in <lambda>
make_distribution_fn=lambda t: distributions.Blockwise(t))(output_layers)
File "<decorator-gen-144>", line 2, in __init__
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 334, in wrapped_init
default_init(self_, *args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/blockwise.py", line 190, in __init__
joint_distribution_sequential.JointDistributionSequential(
File "<decorator-gen-140>", line 2, in __init__
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 334, in wrapped_init
default_init(self_, *args, **kwargs)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 208, in __init__
self._build(model)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 236, in _build
self._dist_fn_wrapped, self._dist_fn_args = zip(*[
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 237, in <listcomp>
_unify_call_signature(i, dist_fn)
File "/home/cserpell/git/latent_prediction/p3/lib/python3.8/site-packages/tensorflow_probability/python/distributions/joint_distribution_sequential.py", line 471, in _unify_call_signature
raise TypeError('{} must be either `tfd.Distribution`-like or '
TypeError: Tensor("inputs:0", shape=(None, 1), dtype=float32) must be either `tfd.Distribution`-like or `callable`.