

Mar 22, 2020, 11:07:26 AM

to Rust for TensorFlow

Hi! First I want to say thanks in advance for this project. It is great to see two exciting things (Rust and TensorFlow) meeting here.

I recently came to the project hoping to do a bit of reinforcement learning with TensorFlow, preferably all called from Rust. I am following the Python code in the second edition of the book "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow". In one of the key algorithms (Deep Q-Networks, DQNs), the code uses `tf.GradientTape` to implement a custom training step, updating the trainable weights directly with the gradients the tape computes automatically.
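For readers less familiar with DQNs: the training step described above minimizes the squared difference between the network's estimate Q(s, a) and the Q-learning target r + gamma * max over a' of Q(s', a'). Below is a minimal sketch of just that arithmetic in plain Rust (no TensorFlow), assuming one transition at a time. The names `td_target` and `mse` are illustrative only, not part of any library, and the book's own code may organize this differently.

```rust
/// Q-learning target for one transition: just the reward if the episode
/// ended, otherwise reward + gamma * max_a' Q(s', a').
pub fn td_target(reward: f64, done: bool, gamma: f64, next_q: &[f64]) -> f64 {
    if done {
        return reward;
    }
    // Greedy value of the next state under the current Q estimates.
    let max_next = next_q.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    reward + gamma * max_next
}

/// Mean squared error between predicted Q-values and their targets.
pub fn mse(predicted: &[f64], targets: &[f64]) -> f64 {
    assert_eq!(predicted.len(), targets.len(), "length mismatch");
    let n = predicted.len() as f64;
    predicted
        .iter()
        .zip(targets)
        .map(|(p, t)| (p - t).powi(2))
        .sum::<f64>()
        / n
}
```

For example, with reward 1.0, gamma = 0.9 and next-state estimates [0.5, 2.0, 1.0], the target is 1.0 + 0.9 * 2.0 = 2.8. The gradient tape's job is to differentiate that loss with respect to the network weights.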

It seems the Rust for TensorFlow bindings do not support `tf.GradientTape`, and from what I can see in the Python source for that class, it is implemented in the Python layer rather than exposed through the C API. I also don't see anything about a gradient tape in the TensorFlow C API itself.

My question, then: what is the scope of this project? Is it only to provide a convenient Rust wrapper around the TF C API? Or do you envision reproducing functionality that is implemented only in the Python library?

Additionally, if I were to write a Rust implementation of `GradientTape` (I guess it would have to be a macro; I'm not even sure it's possible, or whether I'm capable of it), would such a contribution be welcome?
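For what it's worth, a gradient tape does not necessarily need a macro: it can be a runtime structure that records each operation as it executes and then walks that record backwards (reverse-mode differentiation). The toy sketch below works on scalars only. `Tape`, `Var`, and their methods are hypothetical names; this is neither how TensorFlow implements `tf.GradientTape` internally nor an API of the Rust bindings.

```rust
/// Index of a recorded value on the tape.
#[derive(Clone, Copy, Debug)]
pub struct Var(pub usize);

/// How a recorded value was produced.
#[derive(Clone, Copy)]
enum Op {
    Leaf,
    Add(usize, usize),
    Mul(usize, usize),
}

/// Toy scalar gradient tape: records operations as they run, then
/// replays them in reverse to accumulate gradients.
#[derive(Default)]
pub struct Tape {
    values: Vec<f64>,
    ops: Vec<Op>,
}

impl Tape {
    pub fn new() -> Self {
        Self::default()
    }

    fn record(&mut self, value: f64, op: Op) -> Var {
        self.values.push(value);
        self.ops.push(op);
        Var(self.values.len() - 1)
    }

    /// Create an input ("watched") variable.
    pub fn var(&mut self, value: f64) -> Var {
        self.record(value, Op::Leaf)
    }

    pub fn add(&mut self, a: Var, b: Var) -> Var {
        let v = self.values[a.0] + self.values[b.0];
        self.record(v, Op::Add(a.0, b.0))
    }

    pub fn mul(&mut self, a: Var, b: Var) -> Var {
        let v = self.values[a.0] * self.values[b.0];
        self.record(v, Op::Mul(a.0, b.0))
    }

    pub fn value(&self, v: Var) -> f64 {
        self.values[v.0]
    }

    /// Reverse pass: returns d(output)/d(value) for every recorded value.
    pub fn gradient(&self, output: Var) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[output.0] = 1.0;
        // Walk the tape backwards, applying the chain rule to each op.
        for i in (0..=output.0).rev() {
            let g = grads[i];
            match self.ops[i] {
                Op::Leaf => {}
                Op::Add(a, b) => {
                    grads[a] += g;
                    grads[b] += g;
                }
                Op::Mul(a, b) => {
                    grads[a] += g * self.values[b];
                    grads[b] += g * self.values[a];
                }
            }
        }
        grads
    }
}
```

Recording f = x * y + x at x = 3, y = 4 and calling `gradient(f)` gives df/dx = y + 1 = 5 and df/dy = x = 3. A tensor version would need a gradient function for every op type it records.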

Thanks

- Josh

Mar 26, 2020, 12:54:35 AM

to Josh Hansen, Rust for TensorFlow

The project does rely on the C API, and unfortunately, a lot of TensorFlow is implemented in Python, which we can't bind in easily. I am working on providing a higher-level API than what the C API provides, and would eventually like to provide a basic Keras API. I'm not very familiar with GradientTape, and I can't say how easy or hard it would be to implement in Rust. It would probably rely on experimental functionality which is still being implemented (e.g. generated code for built-in ops). Contributions are definitely welcome, but I'd recommend sketching out an outline of what the API would look like and what the approach would be before sending a pull request so we can discuss it and make sure the approach lines up well with the rest of the library.

--

You received this message because you are subscribed to the Google Groups "Rust for TensorFlow" group.

To unsubscribe from this group and stop receiving emails from it, send an email to rust+uns...@tensorflow.org.

To view this discussion on the web visit https://groups.google.com/a/tensorflow.org/d/msgid/rust/542fccd5-428f-48d5-a5de-5fa696387333%40tensorflow.org.

Apr 11, 2020, 1:14:46 AM

to Adam Crume, Rust for TensorFlow

I recently saw this discussion about native differentiable programming support for Rust and thought it was relevant: https://internals.rust-lang.org/t/native-differential-programming-support-for-rust/9626

If we could ever have something along the lines of:

    let grad = gradient! {
        // specify computations in the graph via the usual API
    };

That would be amazing. Can you tell me more about "generated code for built-in ops"? I'm curious how heavy a lift this would be.


Apr 23, 2020, 11:58:49 AM

to Josh Hansen, Rust for TensorFlow

The gradient computation is not performed natively in Rust, but you can call Graph::add_gradients to manually add gradient computation to your graph. If you just want to train a model, it's easier to avoid calling that manually and use an optimizer instead (example usage can be found in examples/xor.rs).
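Conceptually, an optimizer automates two steps per iteration: compute the gradient of the loss with respect to each variable, then move each variable against its gradient. Here is a minimal sketch of plain gradient descent on a two-parameter linear model, in plain Rust with hand-derived gradients. The `fit_linear` function is hypothetical and is not part of the tensorflow crate, whose optimizers (as I understand it) add equivalent operations to the graph instead.

```rust
/// Plain gradient descent on y ~ w * x + b, minimizing the mean
/// squared error with hand-derived gradients.
pub fn fit_linear(xs: &[f64], ys: &[f64], lr: f64, steps: usize) -> (f64, f64) {
    let (mut w, mut b) = (0.0, 0.0);
    let n = xs.len() as f64;
    for _ in 0..steps {
        // Step 1: gradients of L = mean((w*x + b - y)^2).
        let (mut gw, mut gb) = (0.0, 0.0);
        for (x, y) in xs.iter().zip(ys) {
            let err = w * x + b - y;
            gw += 2.0 * err * x / n;
            gb += 2.0 * err / n;
        }
        // Step 2: apply the update, parameter -= learning_rate * gradient.
        w -= lr * gw;
        b -= lr * gb;
    }
    (w, b)
}
```

With points on y = 2x + 1, a learning rate of 0.05 and a few thousand steps, this converges to w close to 2 and b close to 1.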

The generated code I was referring to is the fact that you can now (for example) write code like:

    let z = ops::add(x, y, &mut scope)?;

instead of:

    let z = {
        let mut nd = graph.new_operation("Add", "Add_0")?;
        nd.add_input(x);
        nd.add_input(y);
        nd.finish()?
    };

which means that you don't need to worry about having constants for the names of ops and their attributes, or generating names for the ops if you don't care what they are (e.g. Add_0, Add_1, etc.), and of course the code is shorter.
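To illustrate the naming convenience mentioned above, here is a sketch of how default names such as Add_0, Add_1 can be generated with one counter per op type. `NameGen` is a hypothetical helper; the crate's actual naming scheme may differ.

```rust
use std::collections::HashMap;

/// Hypothetical helper that hands out unique default names per op type,
/// in the style "Add_0", "Add_1", ...
#[derive(Default)]
pub struct NameGen {
    counters: HashMap<String, usize>,
}

impl NameGen {
    /// Return the next unused name for `op_type` and bump its counter.
    pub fn next_name(&mut self, op_type: &str) -> String {
        let count = self.counters.entry(op_type.to_string()).or_insert(0);
        let name = format!("{}_{}", op_type, count);
        *count += 1;
        name
    }
}
```

Calling `next_name("Add")` twice and then `next_name("Mul")` yields Add_0, Add_1 and Mul_0, since each op type keeps its own counter.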

Note that both the optimizers and the generated ops currently require the "experimental_training" feature, although I plan to enable them by default soon.

Apr 23, 2020, 4:30:36 PM

to Adam Crume, Rust for TensorFlow

Thank you, that's very helpful. Perhaps you could guide me in which direction to pursue. I'm building a DNN to work with the `rsrl` reinforcement learning package, to provide a better AI for my game, Umpire. As such, I need to be able to update the model given a specific error value for a specific state-action pair. My current approach has been to build a model using Keras, save it without training, load it in Rust, and then invoke from Rust the optimizers defined for each output of the Keras model to improve the state-action value estimates. But I can't find where the Keras-defined optimizers connect to the Rust TF API. Is it possible to access Keras optimizers from Rust? Or would I need to define them using the lower-level `tensorflow` API in Python?

If I can't make that work, then `add_gradients` could be an alternative. But I'd rather avoid manually building the gradients for a complex network if at all possible. Then again, the docs do suggest that it can derive the gradients automatically for some functions?
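For the per-action update described above, one common approach (an assumption here, not something established in this thread) is to let only the output for the action actually taken contribute to the loss, so every other output gets a zero gradient. A sketch of that output-side gradient for a network with one output per action, in plain Rust with a hypothetical `output_gradient` function:

```rust
/// Gradient of 0.5 * (Q(s, a) - target)^2 with respect to every action
/// output of a network that has one output per action.
pub fn output_gradient(q_values: &[f64], action: usize, target: f64) -> Vec<f64> {
    // Outputs for actions that were not taken do not appear in the loss,
    // so their gradient is zero.
    let mut grad = vec![0.0; q_values.len()];
    // d/dQ(s,a) of 0.5 * (Q(s,a) - target)^2 is Q(s,a) - target.
    grad[action] = q_values[action] - target;
    grad
}
```

Back-propagating this vector from the outputs to the weights is the part that would still be left to TensorFlow.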

- Josh


Apr 27, 2020, 11:16:57 AM

to Josh Hansen, Rust for TensorFlow

The Keras optimizers can definitely be used in Rust; it's just a matter of finding the right operations and running them. You should be able to find them by looking at SavedModelBundle::load(...)?.meta_graph_def().signatures(). This currently requires enabling the experimental_training feature.


May 11, 2020, 1:48:51 AM

to Adam Crume, Rust for TensorFlow

Hmm... I'm not sure I see a way to use the Keras optimizers, since the only two signatures I see are "__saved_model_init_op" and "serving_default" (whose method name is "tensorflow/serving/predict"). There are also the functions "__call__", "_default_save_signature", and "call_and_return_all_conditional_losses". Is that what you would expect to see? I do have "experimental_training" enabled.

I may be able to do the training on the Rust side using the `Optimizer` trait, but it seems to require a list of all variables to be optimized. Is that available somewhere, like Python's "trainable_variables" on `Model`?

The script defining the model is here.

And here is the saved_model_cli output in case it's helpful:

    MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

    signature_def['__saved_model_init_op']:
      The given SavedModel SignatureDef contains the following input(s):
      The given SavedModel SignatureDef contains the following output(s):
        outputs['__saved_model_init_op'] tensor_info:
            dtype: DT_INVALID
            shape: unknown_rank
            name: NoOp
      Method name is:

    signature_def['serving_default']:
      The given SavedModel SignatureDef contains the following input(s):
        inputs['1d_features'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 14)
            name: serving_default_1d_features:0
        inputs['is_enemy_belligerent'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 121)
            name: serving_default_is_enemy_belligerent:0
        inputs['is_neutral'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 121)
            name: serving_default_is_neutral:0
        inputs['is_observed'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 121)
            name: serving_default_is_observed:0
      The given SavedModel SignatureDef contains the following output(s):
        outputs['action_value_00'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:0
        outputs['action_value_01'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:1
        outputs['action_value_02'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:2
        outputs['action_value_03'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:3
        outputs['action_value_04'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:4
        outputs['action_value_05'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:5
        outputs['action_value_06'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:6
        outputs['action_value_07'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:7
        outputs['action_value_08'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:8
        outputs['action_value_09'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:9
        outputs['action_value_10'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:10
        outputs['action_value_11'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:11
        outputs['action_value_12'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:12
        outputs['action_value_13'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:13
        outputs['action_value_14'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:14
        outputs['action_value_15'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:15
        outputs['action_value_16'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:16
        outputs['action_value_17'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:17
        outputs['action_value_18'] tensor_info:
            dtype: DT_DOUBLE
            shape: (-1, 1)
            name: StatefulPartitionedCall:18
      Method name is: tensorflow/serving/predict
    WARNING:tensorflow:From /home/josh/Projects/Umpire/venv/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py:1813: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
    Instructions for updating:
    If using Keras pass *_constraint arguments to layers.

    Defined Functions:
      Function Name: '__call__'
        Option #1
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='1d_features'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_enemy_belligerent'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_observed'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_neutral')]
            Argument #2
              DType: bool
              Value: True
            Argument #3
              DType: NoneType
              Value: None
        Option #2
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='inputs/0'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/1'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/2'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/3')]
            Argument #2
              DType: bool
              Value: True
            Argument #3
              DType: NoneType
              Value: None
        Option #3
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='1d_features'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_enemy_belligerent'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_observed'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_neutral')]
            Argument #2
              DType: bool
              Value: False
            Argument #3
              DType: NoneType
              Value: None
        Option #4
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='inputs/0'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/1'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/2'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/3')]
            Argument #2
              DType: bool
              Value: False
            Argument #3
              DType: NoneType
              Value: None

      Function Name: '_default_save_signature'
        Option #1
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='1d_features'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_enemy_belligerent'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_observed'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_neutral')]

      Function Name: 'call_and_return_all_conditional_losses'
        Option #1
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='inputs/0'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/1'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/2'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/3')]
            Argument #2
              DType: bool
              Value: True
            Argument #3
              DType: NoneType
              Value: None
        Option #2
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='1d_features'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_enemy_belligerent'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_observed'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_neutral')]
            Argument #2
              DType: bool
              Value: False
            Argument #3
              DType: NoneType
              Value: None
        Option #3
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='1d_features'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_enemy_belligerent'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_observed'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='is_neutral')]
            Argument #2
              DType: bool
              Value: True
            Argument #3
              DType: NoneType
              Value: None
        Option #4
          Callable with:
            Argument #1
              DType: list
              Value: [TensorSpec(shape=(None, 14), dtype=tf.float64, name='inputs/0'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/1'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/2'), TensorSpec(shape=(None, 121), dtype=tf.float64, name='inputs/3')]
            Argument #2
              DType: bool
              Value: False
            Argument #3
              DType: NoneType
              Value: None


May 11, 2020, 1:36:59 PM

to Josh Hansen, Rust for TensorFlow

I forgot that tf.saved_model.save doesn't save the training ops by default. One possibility is to manually add a signature (i.e. the "signatures" argument to model.save) for training the model. You may also need to set include_optimizer=True.

I wouldn't recommend using the Optimizer impls unless you're building the model in Rust. There isn't currently a way to list all variables; they have to be tracked manually. We'll probably need to add this at some point.
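Until something like `trainable_variables` exists, manual tracking could look roughly like the sketch below: a registry that records each variable as it is created and can later list the trainable ones to hand to an optimizer. `VariableRegistry` and `VarInfo` are hypothetical names; they only store metadata and do not wrap the crate's actual `Variable` type.

```rust
/// Metadata recorded for each variable at creation time.
#[derive(Debug, Clone, PartialEq)]
pub struct VarInfo {
    pub name: String,
    pub shape: Vec<usize>,
    pub trainable: bool,
}

/// Hypothetical stand-in for Python's `Model.trainable_variables`:
/// every variable is registered here when it is created.
#[derive(Default)]
pub struct VariableRegistry {
    vars: Vec<VarInfo>,
}

impl VariableRegistry {
    /// Record a new variable and return its index in the registry.
    pub fn create(&mut self, name: &str, shape: &[usize], trainable: bool) -> usize {
        self.vars.push(VarInfo {
            name: name.to_string(),
            shape: shape.to_vec(),
            trainable,
        });
        self.vars.len() - 1
    }

    /// Only the variables marked trainable, in creation order.
    pub fn trainable_variables(&self) -> Vec<&VarInfo> {
        self.vars.iter().filter(|v| v.trainable).collect()
    }
}
```

The design choice is simply to make variable creation go through one place, so nothing has to be rediscovered from the graph afterwards.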


May 15, 2020, 10:06:02 PM

to Adam Crume, Rust for TensorFlow

I've been banging my head against this the past few days and can't seem to get training in Rust to work using a Keras model. I've tried a million things and am really close (I think) but just can't get the last inch. Some details:

In spite of what you said above, I am skeptical that Keras models can be trained from C/Rust, at least not without building your own optimization. I've tried a million things to get `Model.fit` into the saved model. I tried providing `signatures` explicitly, but you can't get one for `fit` because it doesn't have `get_concrete_function`. I've tried saving with `include_optimizer=True`, but the functions I find in the saved model don't seem to actually allow inference. Here's a sample from a two-unit network:

    "__inference_model_layer_call_and_return_conditional_losses_244"
    "__inference_model_layer_call_and_return_conditional_losses_262"
    "__inference_tf_op_layer_weighted_loss/num_elements_layer_call_fn_471"
    "__inference_model_layer_call_and_return_conditional_losses_354"
    "__inference_tf_op_layer_weighted_loss/num_elements/Cast_layer_call_and_return_conditional_losses_220"
    "__inference_tf_op_layer_Mean_layer_call_and_return_conditional_losses_152"
    "__inference_model_layer_call_and_return_conditional_losses_312"
    "__inference__wrapped_model_100"
    "__inference_model_layer_call_fn_387"
    "__inference_y_hat_layer_call_fn_416"
    "__inference_y_hat_layer_call_and_return_conditional_losses_407"
    "__inference_tf_op_layer_Mean_layer_call_fn_439"
    "__inference_tf_op_layer_weighted_loss/value_layer_call_and_return_conditional_losses_498"
    "__inference_tf_op_layer_weighted_loss/Sum_1_layer_call_and_return_conditional_losses_207"
    "__inference_tf_op_layer_weighted_loss/num_elements/Cast_layer_call_and_return_conditional_losses_487"
    "__inference_tf_op_layer_weighted_loss/value_layer_call_and_return_conditional_losses_234"
    "__inference_tf_op_layer_weighted_loss/Mul_layer_call_and_return_conditional_losses_445"
    "__inference_tf_op_layer_weighted_loss/num_elements_layer_call_and_return_conditional_losses_466"
    "__inference_tf_op_layer_weighted_loss/Sum_layer_call_and_return_conditional_losses_193"
    "__inference_model_layer_call_and_return_conditional_losses_284"
    "__inference__traced_save_538"
    "__inference_tf_op_layer_weighted_loss/num_elements/Cast_layer_call_fn_492"
    "__inference_tf_op_layer_weighted_loss/Sum_1_layer_call_and_return_conditional_losses_477"
    "__inference_tf_op_layer_weighted_loss/Mul_layer_call_and_return_conditional_losses_166"
    "__inference_tf_op_layer_SquaredDifference_layer_call_and_return_conditional_losses_137"
    "__inference_y_hat_layer_call_and_return_conditional_losses_115"
    "__inference_tf_op_layer_Mean_layer_call_and_return_conditional_losses_434"
    "__inference_tf_op_layer_SquaredDifference_layer_call_and_return_conditional_losses_422"
    "__inference_tf_op_layer_weighted_loss/value_layer_call_fn_504"
    "__inference_model_layer_call_fn_291"
    "__inference_tf_op_layer_weighted_loss/num_elements_layer_call_and_return_conditional_losses_179"
    "__inference_model_layer_call_fn_319"
    "__inference_tf_op_layer_weighted_loss/Sum_1_layer_call_fn_482"
    "__inference_tf_op_layer_weighted_loss/Sum_layer_call_and_return_conditional_losses_456"
    "__inference_model_layer_call_and_return_conditional_losses_377"
    "__inference_tf_op_layer_weighted_loss/Mul_layer_call_fn_450"
    "__inference_tf_op_layer_weighted_loss/Sum_layer_call_fn_461"
    "__inference_signature_wrapper_331"
    "__inference__traced_restore_556"
    "__inference_model_layer_call_fn_397"
    "__inference_tf_op_layer_SquaredDifference_layer_call_fn_428"

I've tried subclassing `tf.Module` and adding `tf.function`s to call `Model.fit` for me, but autograph chokes. I tried calculating gradients with `tf.GradientTape` but never get anything but `None` no matter which variables I watch. I tried using the Keras `optimizers` API directly but get "No gradients provided for any variable" and was told on Reddit that Keras' optimizers API is not meant to work with Keras outputs, go figure.

Here's where I'm closest: I can calculate gradients directly using `tf.keras.backend.gradients` and I can serialize them with the saved model if I use `tf.compat.v1.disable_v2_behavior`. I can see looking at the raw file contents that those gradients have been included, or at least something by the same name, but they don't show up using `saved_model_cli show` or in the signatures or functions seen by the Rust API.

My main options now:

1) find a way to access those gradients within Rust and apply them. This would be my preference.

2) get that `trainable_variables` call implemented in Rust and use the experimental training API. This could also be great but I've taken a look at doing it myself and have no idea where to get started.

3) give up on training in Rust and embed a Python interpreter in my game and do Keras training there.

4) give up on training in Rust and embed my game into Python to do the training all in Python/Keras.

5) Something else? You hinted above that training of Keras models in Rust should be possible. I've seen no sign of that, but maybe I'm missing something?

If you've read this far, thanks for your time. And thanks for your help getting me this far.

- Josh


May 20, 2020, 12:33:12 AM

to Josh Hansen, Rust for TensorFlow

I've been trying to figure this out and running into a lot of the same issues. I assume the Keras API is using the same functionality under the hood, but I can't convince it to expose the training op. You could try loading the gradients in Rust and applying an optimizer, but the API assumes that you're building the graph in Rust, and there may be rough edges (e.g. creating Variable instances). Another possibility is to stop using the Keras API, but that does require you to build the graph a bit more manually.

We could also ask for help on the main TensorFlow project, since they know a lot more about the inner workings of TensorFlow than I do.

May 23, 2020, 1:00:44 AM

to Adam Crume, Rust for TensorFlow

Thanks for taking a look. I put a question up on StackOverflow which seems like the main forum they're recommending: https://stackoverflow.com/questions/61907288/can-keras-models-be-trained-using-the-tensorflow-c-api

Feel free to upvote it, it has no answers so far.

I've made a bit of progress and wanted to run it by you to see whether you know of any tensorflow-rust-specific issues before I ask about this on a more general mailing list. I *MIGHT* have a `tf.function` that will do the training, based on an example from a book. While Keras's `Model.fit` might not work, it seems that `GradientTape` and `apply_gradients` do work with AutoGraph, so I can build my own training step. But I'm having trouble actually trying it out, since it's unclear how to call the function from Rust. This is my module:

```python
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

POSSIBLE_ACTIONS = 19  # matches the (1, 19) shape saved_model_cli reports below


class TrainableAI(tf.Module):
    def __init__(self, keras_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = keras_model
        self.mse = tf.keras.losses.MeanSquaredError()
        self.optimizer = Adam()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(1, 14), dtype=tf.float32),                # 1d_features
        tf.TensorSpec(shape=(1, 121), dtype=tf.float32),               # is_enemy_belligerent
        tf.TensorSpec(shape=(1, 121), dtype=tf.float32),               # is_observed
        tf.TensorSpec(shape=(1, 121), dtype=tf.float32),               # is_neutral
        tf.TensorSpec(shape=(1, POSSIBLE_ACTIONS), dtype=tf.float32),  # true_action_values
    ])
    def fit_action(self, _1d_features, is_enemy_belligerent, is_observed, is_neutral, true_action_values):
        inputs = [_1d_features, is_enemy_belligerent, is_observed, is_neutral, true_action_values]
        with tf.GradientTape() as tape:
            estimated_action_values = self.model(inputs)
            # Keras losses take (y_true, y_pred).
            loss = self.mse(true_action_values, estimated_action_values)
        grads = tape.gradient(loss, self.model.trainable_variables)
        # tf.print runs on every call; a plain print() would only run while tracing.
        tf.print("Gradients:", grads)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
```

And here's what `saved_model_cli show --dir $DIR --all` shows:

```
$ ./venv/bin/saved_model_cli show --dir ai/umpire_regressor --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:
WARNING:tensorflow:From /home/josh/Projects/Umpire/venv/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py:1813: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

Defined Functions:
  Function Name: 'fit_action'
    Option #1
      Callable with:
        Argument #1
          _1d_features: TensorSpec(shape=(1, 14), dtype=tf.float32, name='_1d_features')
        Argument #2
          is_enemy_belligerent: TensorSpec(shape=(1, 121), dtype=tf.float32, name='is_enemy_belligerent')
        Argument #3
          is_observed: TensorSpec(shape=(1, 121), dtype=tf.float32, name='is_observed')
        Argument #4
          is_neutral: TensorSpec(shape=(1, 121), dtype=tf.float32, name='is_neutral')
        Argument #5
          true_action_values: TensorSpec(shape=(1, 19), dtype=tf.float32, name='true_action_values')
```

So the `tf.function` seems to be saved properly. But Rust sees the world a bit differently: it doesn't see a function called `fit_action`, but rather `__inference_fit_action_N`, where `N` is some integer. At least that's the best match I can find. And when I create an operation whose type is the name of that function, I get this:

Error: "Error finishing operation fit_action: InvalidArgument: 0 inputs specified of 129 inputs in Op while building NodeDef \'fit_action_op\' using Op<name=__inference_fit_action_2397; signature=placeholder:float, is_enemy_belligerent:float, is_observed:float, is_neutral:float, true_action_values:float, umpire_regressor_dense0_matmul_readvariableop_dense0_kernel:resource, umpire_regressor_dense0_biasadd_readvariableop_dense0_bias:resource, umpire_regressor_action_value_18_matmul_readvariableop_action_value_18_kernel:resource, umpire_regressor_action_value_18_biasadd_readvariableop_action_value_18_bias:resource, umpire_regressor_action_value_17_matmul_readvariableop_action_value_17_kernel:resource, umpire_regressor_action_value_17_biasadd_readvariableop_action_value_17_bias:resource, umpire_regressor_action_value_16_matmul_readvariableop_action_value_16_kernel:resource, umpire_regressor_action_value_16_biasadd_readvariableop_action_value_16_bias:resource, umpire_regressor_action_value_15_matmul_readvariableop_action_value_15_kernel:resource, umpire_regressor_action_value_15_biasadd_readvariableop_action_value_15_bias:resource, umpire_regressor_action_value_14_matmul_readvariableop_action_value_14_kernel:resource, umpire_regressor_action_value_14_biasadd_readvariableop_action_value_14_bias:resource, umpire_regressor_action_value_13_matmul_readvariableop_action_value_13_kernel:resource, umpire_regressor_action_value_13_biasadd_readvariableop_action_value_13_bias:resource, umpire_regressor_action_value_12_matmul_readvariableop_action_value_12_kernel:resource, umpire_regressor_action_value_12_biasadd_readvariableop_action_value_12_bias:resource, umpire_regressor_action_value_11_matmul_readvariableop_action_value_11_kernel:resource, umpire_regressor_action_value_11_biasadd_readvariableop_action_value_11_bias:resource, umpire_regressor_action_value_10_matmul_readvariableop_action_value_10_kernel:resource, 
umpire_regressor_action_value_10_biasadd_readvariableop_action_value_10_bias:resource, umpire_regressor_action_value_09_matmul_readvariableop_action_value_09_kernel:resource, umpire_regressor_action_value_09_biasadd_readvariableop_action_value_09_bias:resource, umpire_regressor_action_value_08_matmul_readvariableop_action_value_08_kernel:resource, umpire_regressor_action_value_08_biasadd_readvariableop_action_value_08_bias:resource, umpire_regressor_action_value_07_matmul_readvariableop_action_value_07_kernel:resource, umpire_regressor_action_value_07_biasadd_readvariableop_action_value_07_bias:resource, umpire_regressor_action_value_06_matmul_readvariableop_action_value_06_kernel:resource, umpire_regressor_action_value_06_biasadd_readvariableop_action_value_06_bias:resource, umpire_regressor_action_value_05_matmul_readvariableop_action_value_05_kernel:resource, umpire_regressor_action_value_05_biasadd_readvariableop_action_value_05_bias:resource, umpire_regressor_action_value_04_matmul_readvariableop_action_value_04_kernel:resource, umpire_regressor_action_value_04_biasadd_readvariableop_action_value_04_bias:resource, umpire_regressor_action_value_03_matmul_readvariableop_action_value_03_kernel:resource, umpire_regressor_action_value_03_biasadd_readvariableop_action_value_03_bias:resource, umpire_regressor_action_value_02_matmul_readvariableop_action_value_02_kernel:resource, umpire_regressor_action_value_02_biasadd_readvariableop_action_value_02_bias:resource, umpire_regressor_action_value_01_matmul_readvariableop_action_value_01_kernel:resource, umpire_regressor_action_value_01_biasadd_readvariableop_action_value_01_bias:resource, umpire_regressor_action_value_00_matmul_readvariableop_action_value_00_kernel:resource, umpire_regressor_action_value_00_biasadd_readvariableop_action_value_00_bias:resource, adam_identity_readvariableop_adam_learning_rate:resource, adam_readvariableop_adam_iter:resource, adam_identity_1_readvariableop_adam_beta_1:resource, 
adam_identity_2_readvariableop_adam_beta_2:resource, adam_adam_update_dense0_kernel_resourceapplyadam_adam_dense0_kernel_m:resource, adam_adam_update_dense0_kernel_resourceapplyadam_adam_dense0_kernel_v:resource, adam_adam_update_dense0_bias_resourceapplyadam_adam_dense0_bias_m:resource, adam_adam_update_dense0_bias_resourceapplyadam_adam_dense0_bias_v:resource, adam_adam_update_action_value_00_kernel_resourceapplyadam_adam_action_value_00_kernel_m:resource, adam_adam_update_action_value_00_kernel_resourceapplyadam_adam_action_value_00_kernel_v:resource, adam_adam_update_action_value_00_bias_resourceapplyadam_adam_action_value_00_bias_m:resource, adam_adam_update_action_value_00_bias_resourceapplyadam_adam_action_value_00_bias_v:resource, adam_adam_update_action_value_01_kernel_resourceapplyadam_adam_action_value_01_kernel_m:resource, adam_adam_update_action_value_01_kernel_resourceapplyadam_adam_action_value_01_kernel_v:resource, adam_adam_update_action_value_01_bias_resourceapplyadam_adam_action_value_01_bias_m:resource, adam_adam_update_action_value_01_bias_resourceapplyadam_adam_action_value_01_bias_v:resource, adam_adam_update_action_value_02_kernel_resourceapplyadam_adam_action_value_02_kernel_m:resource, adam_adam_update_action_value_02_kernel_resourceapplyadam_adam_action_value_02_kernel_v:resource, adam_adam_update_action_value_02_bias_resourceapplyadam_adam_action_value_02_bias_m:resource, adam_adam_update_action_value_02_bias_resourceapplyadam_adam_action_value_02_bias_v:resource, adam_adam_update_action_value_03_kernel_resourceapplyadam_adam_action_value_03_kernel_m:resource, adam_adam_update_action_value_03_kernel_resourceapplyadam_adam_action_value_03_kernel_v:resource, adam_adam_update_action_value_03_bias_resourceapplyadam_adam_action_value_03_bias_m:resource, adam_adam_update_action_value_03_bias_resourceapplyadam_adam_action_value_03_bias_v:resource, adam_adam_update_action_value_04_kernel_resourceapplyadam_adam_action_value_04_kernel_m:resource, 
adam_adam_update_action_value_04_kernel_resourceapplyadam_adam_action_value_04_kernel_v:resource, adam_adam_update_action_value_04_bias_resourceapplyadam_adam_action_value_04_bias_m:resource, adam_adam_update_action_value_04_bias_resourceapplyadam_adam_action_value_04_bias_v:resource, adam_adam_update_action_value_05_kernel_resourceapplyadam_adam_action_value_05_kernel_m:resource, adam_adam_update_action_value_05_kernel_resourceapplyadam_adam_action_value_05_kernel_v:resource, adam_adam_update_action_value_05_bias_resourceapplyadam_adam_action_value_05_bias_m:resource, adam_adam_update_action_value_05_bias_resourceapplyadam_adam_action_value_05_bias_v:resource, adam_adam_update_action_value_06_kernel_resourceapplyadam_adam_action_value_06_kernel_m:resource, adam_adam_update_action_value_06_kernel_resourceapplyadam_adam_action_value_06_kernel_v:resource, adam_adam_update_action_value_06_bias_resourceapplyadam_adam_action_value_06_bias_m:resource, adam_adam_update_action_value_06_bias_resourceapplyadam_adam_action_value_06_bias_v:resource, adam_adam_update_action_value_07_kernel_resourceapplyadam_adam_action_value_07_kernel_m:resource, adam_adam_update_action_value_07_kernel_resourceapplyadam_adam_action_value_07_kernel_v:resource, adam_adam_update_action_value_07_bias_resourceapplyadam_adam_action_value_07_bias_m:resource, adam_adam_update_action_value_07_bias_resourceapplyadam_adam_action_value_07_bias_v:resource, adam_adam_update_action_value_08_kernel_resourceapplyadam_adam_action_value_08_kernel_m:resource, adam_adam_update_action_value_08_kernel_resourceapplyadam_adam_action_value_08_kernel_v:resource, adam_adam_update_action_value_08_bias_resourceapplyadam_adam_action_value_08_bias_m:resource, adam_adam_update_action_value_08_bias_resourceapplyadam_adam_action_value_08_bias_v:resource, adam_adam_update_action_value_09_kernel_resourceapplyadam_adam_action_value_09_kernel_m:resource, 
adam_adam_update_action_value_09_kernel_resourceapplyadam_adam_action_value_09_kernel_v:resource, adam_adam_update_action_value_09_bias_resourceapplyadam_adam_action_value_09_bias_m:resource, adam_adam_update_action_value_09_bias_resourceapplyadam_adam_action_value_09_bias_v:resource, adam_adam_update_action_value_10_kernel_resourceapplyadam_adam_action_value_10_kernel_m:resource, adam_adam_update_action_value_10_kernel_resourceapplyadam_adam_action_value_10_kernel_v:resource, adam_adam_update_action_value_10_bias_resourceapplyadam_adam_action_value_10_bias_m:resource, adam_adam_update_action_value_10_bias_resourceapplyadam_adam_action_value_10_bias_v:resource, adam_adam_update_action_value_11_kernel_resourceapplyadam_adam_action_value_11_kernel_m:resource, adam_adam_update_action_value_11_kernel_resourceapplyadam_adam_action_value_11_kernel_v:resource, adam_adam_update_action_value_11_bias_resourceapplyadam_adam_action_value_11_bias_m:resource, adam_adam_update_action_value_11_bias_resourceapplyadam_adam_action_value_11_bias_v:resource, adam_adam_update_action_value_12_kernel_resourceapplyadam_adam_action_value_12_kernel_m:resource, adam_adam_update_action_value_12_kernel_resourceapplyadam_adam_action_value_12_kernel_v:resource, adam_adam_update_action_value_12_bias_resourceapplyadam_adam_action_value_12_bias_m:resource, adam_adam_update_action_value_12_bias_resourceapplyadam_adam_action_value_12_bias_v:resource, adam_adam_update_action_value_13_kernel_resourceapplyadam_adam_action_value_13_kernel_m:resource, adam_adam_update_action_value_13_kernel_resourceapplyadam_adam_action_value_13_kernel_v:resource, adam_adam_update_action_value_13_bias_resourceapplyadam_adam_action_value_13_bias_m:resource, adam_adam_update_action_value_13_bias_resourceapplyadam_adam_action_value_13_bias_v:resource, adam_adam_update_action_value_14_kernel_resourceapplyadam_adam_action_value_14_kernel_m:resource, 
adam_adam_update_action_value_14_kernel_resourceapplyadam_adam_action_value_14_kernel_v:resource, adam_adam_update_action_value_14_bias_resourceapplyadam_adam_action_value_14_bias_m:resource, adam_adam_update_action_value_14_bias_resourceapplyadam_adam_action_value_14_bias_v:resource, adam_adam_update_action_value_15_kernel_resourceapplyadam_adam_action_value_15_kernel_m:resource, adam_adam_update_action_value_15_kernel_resourceapplyadam_adam_action_value_15_kernel_v:resource, adam_adam_update_action_value_15_bias_resourceapplyadam_adam_action_value_15_bias_m:resource, adam_adam_update_action_value_15_bias_resourceapplyadam_adam_action_value_15_bias_v:resource, adam_adam_update_action_value_16_kernel_resourceapplyadam_adam_action_value_16_kernel_m:resource, adam_adam_update_action_value_16_kernel_resourceapplyadam_adam_action_value_16_kernel_v:resource, adam_adam_update_action_value_16_bias_resourceapplyadam_adam_action_value_16_bias_m:resource, adam_adam_update_action_value_16_bias_resourceapplyadam_adam_action_value_16_bias_v:resource, adam_adam_update_action_value_17_kernel_resourceapplyadam_adam_action_value_17_kernel_m:resource, adam_adam_update_action_value_17_kernel_resourceapplyadam_adam_action_value_17_kernel_v:resource, adam_adam_update_action_value_17_bias_resourceapplyadam_adam_action_value_17_bias_m:resource, adam_adam_update_action_value_17_bias_resourceapplyadam_adam_action_value_17_bias_v:resource, adam_adam_update_action_value_18_kernel_resourceapplyadam_adam_action_value_18_kernel_m:resource, adam_adam_update_action_value_18_kernel_resourceapplyadam_adam_action_value_18_kernel_v:resource, adam_adam_update_action_value_18_bias_resourceapplyadam_adam_action_value_18_bias_m:resource, adam_adam_update_action_value_18_bias_resourceapplyadam_adam_action_value_18_bias_v:resource -> ; is_stateful=true>"

I know I still need to add inputs for the first five arguments (the ones specified in the `tf.function`), but the 124 arguments after that are a lot more to ask for. It seems to want me to supply every variable of the Keras model (the kernel and bias of `dense0` and of each `action_value_NN` layer) as well as every optimizer variable (the Adam learning rate, iteration count, `beta_1`, `beta_2`, and the `m`/`v` slots for each model variable), all of which I had thought were serialized with the model.
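The 129 inputs in the error can be accounted for from the names in the message alone. Here is a quick tally in Python. It assumes the model consists only of `dense0` plus the 19 `action_value_NN` layers listed there. The name strings are simplified labels, not the exact captured-input names.

```python
# Layers that appear in the error message: dense0 plus action_value_00 .. action_value_18.
LAYERS = ["dense0"] + [f"action_value_{i:02d}" for i in range(19)]

# Each Dense layer owns a kernel and a bias.
model_vars = [f"{layer}_{kind}" for layer in LAYERS for kind in ("kernel", "bias")]
# Scalar optimizer variables captured by the function.
optimizer_scalars = ["adam_learning_rate", "adam_iter", "adam_beta_1", "adam_beta_2"]
# Adam keeps an m and a v slot for every model variable.
adam_slots = [f"adam_{var}_{slot}" for var in model_vars for slot in ("m", "v")]
# The arguments declared in the tf.function's input_signature.
explicit_args = ["_1d_features", "is_enemy_belligerent", "is_observed", "is_neutral", "true_action_values"]

resource_inputs = model_vars + optimizer_scalars + adam_slots
print(len(explicit_args), len(resource_inputs), len(explicit_args) + len(resource_inputs))
# 5 124 129
```

That is 40 model variables, 4 optimizer scalars and 80 Adam slots, all passed as `resource` inputs because the raw function takes the captured variables as explicit arguments.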

Anyway, do you have any idea what's going on here? I just wanted to make sure this issue isn't specific to tensorflow-rust before sending it on to a more general mailing list.

Thanks
