Saving models trained in Java with SavedModel


Karl Lessard

Jun 24, 2020, 12:23:27 AM
to SIG JVM
Hi everyone, 

During our last (core) session, we discussed comparing two approaches for saving models trained in Java for our first release: saved models and frozen graphs. I was tasked with looking at the former, while Adam Pocock would investigate the latter.

For saved models, I have some good news! It looks like this technique is more straightforward than expected, at least for minimal support. After one evening of research and a few lines of code, I was able to train a model in Java, save it as a saved model with its trained variables, reload it from that saved model, and run inference on it.

Most of the work for saving/restoring variables is done by the Save and Restore ops. Other than that, it is just a matter of properly initializing some protobuf metadata with the graph definition.
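
To illustrate the idea, here is a minimal sketch built directly on those ops (the variable name, checkpoint prefix, and exact constant helper overloads are my assumptions, not the draft API itself):

import java.util.Arrays;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.Save;
import org.tensorflow.types.TFloat32;

public final class SaveSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph()) {
      Ops tf = Ops.create(g);
      // A toy 2x2 variable standing in for trained weights.
      Variable<TFloat32> weights =
          tf.variable(tf.constant(new float[][] {{1f, 2f}, {3f, 4f}}));

      // Save takes a checkpoint prefix, the names to record, shape-and-slice
      // specs ("" means "the whole tensor"), and the tensors themselves.
      Save save = tf.train.save(
          tf.constant("/tmp/model/variables/variables"), // hypothetical prefix
          tf.constant(new String[] {"weights"}),
          tf.constant(new String[] {""}),
          Arrays.asList(weights));

      try (Session s = new Session(g)) {
        s.runner().addTarget(save).run(); // writes the checkpoint files
      }
    }
  }
}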

So I suggest that we don't spend more time looking at frozen graphs, as they are a legacy format. I'll start working on a proper API for saving models with TensorFlow Java and present it to you. We can discuss this further during our next framework session, which is scheduled for next Friday.

Karl

Zahra Badey

Jun 24, 2020, 12:36:54 PM
to SIG JVM
Hi Karl,
I use this API in TensorFlow 1.x for inference on saved models; it will be great to have it in the release.
Happy to help if needed, and looking forward to more info at the next session.

Thanks,
- Zahra

Alexey Zinoviev

Jun 26, 2020, 12:48:08 PM
to SIG JVM
Adam and Karl, please share your approaches to the Save/Restore ops that were discussed in the demo.
I'd like to be an early adopter, experimenting with different kinds of models.

Karl Lessard

Jun 29, 2020, 12:37:41 AM
to Alexey Zinoviev, SIG JVM
Hi Alexey,

Please take a look at this branch; it has a first draft of an API for exporting saved models from Java (really just a start, expect it to change!). The export can be invoked as in this unit test; it is as simple as that. Of course, it just covers the very basics right now; I'll keep you posted on my progress if you are interested. Also let me know if you find something that doesn't work for you. If you don't want to check out the whole branch and would rather copy pieces of it into yours, look at the SavedModelBundle.Exporter class and the Graph.addVariableSaver method.
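
To give a rough idea of the shape (a sketch only: apart from SavedModelBundle.Exporter itself, which is in the branch, the builder methods below are my guesses at a draft API that is expected to change):

// Hypothetical invocation of the draft exporter; 'session' is assumed to
// already hold a trained model.
SavedModelBundle.exporter("/tmp/my-model") // export directory
    .withSession(session)                  // session holding the trained variables
    .export();                             // writes saved_model.pb and variables/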

Good luck!
Karl


Alexey Zinoviev

Jun 29, 2020, 1:39:17 AM
to Karl Lessard, SIG JVM
Great, will test it soon


Andrew Schaumberg

Aug 19, 2020, 11:52:42 PM
to SIG JVM
Hi Alexey,

This code is in the session_saved_model branch of a fork, if you'd like it: https://github.com/aday00/java/tree/session_saved_model

Some early diffs of this change are available here, if you prefer: https://github.com/tensorflow/java/issues/100#issuecomment-674596955

The testing I ran is listed here: https://github.com/tensorflow/java/issues/100#issuecomment-674791320

Hope this helps,
-Andrew


Alexey Zinoviev

Aug 20, 2020, 5:41:51 AM
to Andrew Schaumberg, SIG JVM
Great, I'll try your approach on my graphs and variables, or on models from the java-models repository.
Many thanks!


Andrew Schaumberg

Sep 9, 2020, 3:50:29 PM
to Alexey Zinoviev, SIG JVM
Hi Alexey,

Excellent, I really appreciate your java-models work!  I've been working from your VGG implementation toward a ResNet: https://github.com/tensorflow/java-models/blob/master/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java

ResNet and others use GlobalAveragePooling2D, but I didn't see a reduce_mean/reduceMean in TensorFlow Java, only reduce_sum/reduceSum.

Not sure if there's a better way to do this, but in Scala this is:

      val global_avg_pool2d = tf.withName("global_avg_pool2d").math.div(
        // global average pooling works across dimensions height=1 and width=2,
        // reducing in depth=3 (i.e. depth holds all the conv kernel outputs at this position)
        tf.reduceSum(conv, tf.array(3)),
        tf.constant(64f) // channel depth is 64 from the conv layer
      )

Thanks again!
-Andrew

Jim Clarke

Sep 9, 2020, 4:53:18 PM
to Andrew Schaumberg, Alexey Zinoviev, SIG JVM
Andrew, 

If I remember right, you have to call tf.math.mean() and set the axis argument to all the axes in the shape.
Maybe there is a better way, but this worked for me.

Try something like this
===============================

private static <T extends TType> int[] allAxis(Operand<T> op) {
    int rank = op.asOutput().shape().numDimensions();
    int[] axes = new int[rank];
    for (int i = 0; i < rank; i++) {
      axes[i] = i;
    }
    return axes; // returns 0, 1, 2, 3, ...
}

public static <T extends TType> Operand<TInt32> allAxis(Ops tf, Operand<T> op) {
    int[] axes = allAxis(op);
    return tf.constant(axes);
}

// usage: mean over every axis of xf
Operand<TInt32> axis = allAxis(tf, xf);
return tf.math.mean(xf, axis, Mean.keepDims(false));

===============================

jim

Alexey Zinoviev

Sep 10, 2020, 2:31:49 AM
to Jim Clarke, Andrew Schaumberg, SIG JVM
Hi Andrew,

I was blocked on an advanced ResNet example too (I've made simple branches with union/merging, without GlobalAvgPooling).
It looks like the custom, general-purpose reduceMean in Jim's example (via math.mean) is the best way to work around the missing reduceMean op.

On the other hand, your approach with

      val global_avg_pool2d = tf.withName("global_avg_pool2d").math.div(
        // global average pooling works across dimensions height=1 and width=2,
        // reducing in depth=3 (i.e. depth holds all the conv kernel outputs at this position)
        tf.reduceSum(conv, tf.array(3)),
        tf.constant(64f) // channel depth is 64 from the conv layer
      )

is very transparent and fits the spirit of the examples in the repository (special cases with specific numbers and operations). It depends on your goal.

Please share (make a PR) a correctly working ResNet example if possible; I could review it.

Sincerely yours,
         Alexey




Andrew Schaumberg

Sep 10, 2020, 2:37:40 AM
to Jim Clarke, Alexey Zinoviev, SIG JVM
Hi Jim,

Thanks for suggesting mean(), a much better solution!  I noticed Python Keras just calls TensorFlow's reduce_mean, and wasn't sure if TensorFlow Java planned to support reduceMean too.

As a test, I'm trying to implement models like ResNet to match the original Python GlobalAveragePooling2D closely:
https://github.com/tensorflow/tensorflow/blob/ee1c2f112dc666a19e0d9e4ad679c0e948ebde8a/tensorflow/python/keras/applications/resnet.py#L172

GlobalAveragePooling2D calls Keras' backend mean():

Keras' mean() calls TensorFlow's reduce_mean():

As a TensorFlow Java newbie, I was surprised that it seemed to have TensorFlow's reduce_sum as reduceSum, but not TensorFlow's reduce_mean as reduceMean:

TensorFlow Java has ReduceSum and others, but no ReduceMean:
...
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceMax.java
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceMin.java
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceProd.java
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceSum.java
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ReduceJoin.java
...

Anyway, I'm just exploring and trying to learn what to expect.  I'd hope a Keras ResNet and a TensorFlow Java ResNet do very similar computations on the GPU, so one can be checked against the other.

Thanks again,
-Andrew

Karl Lessard

Sep 10, 2020, 9:02:39 AM
to Andrew Schaumberg, Jim Clarke, Alexey Zinoviev, SIG JVM
An interesting fact: when you look at the code of `ReduceSum.java` and `Sum.java`, for example, you can see that they are exactly the same; they both call the `Sum` kernel of the TensorFlow runtime. Furthermore, there is no definition to be found for the "Reduce" variants of these ops in the list of API defs we are carrying for Java. So it sounds to me like all these `Reduce*` ops you were mentioning, Andrew, are remnants of the past that we should probably get rid of to avoid any more confusion.

This discovery is particularly timely as we get closer to a first release: we probably want to clean up our list of ops and make sure that, in the future, obsolete operations are automatically discarded by our build. I'll add a topic about this for tomorrow's meeting.

Thanks for sharing this!
Karl

Andrew Schaumberg

Sep 10, 2020, 3:22:11 PM
to Karl Lessard, Jim Clarke, Alexey Zinoviev, SIG JVM
Hi Karl,

Thanks for the API lesson!  Given that interesting fact, I looked more closely at the TF math ops, and it seems ReduceSum is a typedef for Sum, and ReduceMean a typedef for Mean, though I don't see that any of these ReduceX functions are deprecated: https://www.tensorflow.org/api_docs/cc/group/math-ops

I'd agree with either (a) keeping all the typedefs in TensorFlow Java, or (b) keeping none of them, with a note on why they're not there so newbies like me understand.  Currently, TensorFlow Java seems to have some typedefs (e.g. ReduceSum) but not others (e.g. ReduceMean).  Maybe there's a more Java-like way of handling typedefs, aliasing ReduceMean to Mean etc. without the boilerplate — something like the sketch below?  Anyhow, I'm glad experts like you have it all in mind.  I'm here to learn.
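
For instance, a hand-written alias could just delegate (a hypothetical sketch, not a proposal for the generated API):

// A Java-flavoured "typedef": reduceMean simply forwards to the surviving
// mean endpoint, since ReduceMean and Mean run the same kernel anyway.
public static <T extends TType> Mean<T> reduceMean(
    Ops tf, Operand<T> input, Operand<? extends TNumber> axis) {
  return tf.math.mean(input, axis);
}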

To continue this TensorFlow Java ResNet implementation exercise, and ultimately compare against Keras, I wonder if TensorBoard is available for TensorFlow Java graphs?  According to https://www.tensorflow.org/tensorboard/graphs#graphs_of_tffunctions I'd start with tf.summary.trace_on or so, and end with tf.summary.trace_export or so.  I see TensorFlow Java has:
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Trace.java
./tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SummaryOps.java
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/SummaryWriter.java
./tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Summary.java

The trouble is, tf.summary.createFileWriter and tf.summary.traceOn don't seem to exist in TensorFlow Java's SummaryOps, so I'm unsure how to start TensorBoard logging in TensorFlow Java.

A bit of my Scala code is:
      val datetime = java.time.LocalDateTime.now
      val tensorboard_logdir = "tensorboard-%d%02d%02dt%02d%02d%02d".format(datetime.getYear, datetime.getMonthValue, datetime.getDayOfMonth, datetime.getHour, datetime.getMinute, datetime.getSecond)
      val tensorboard_writer = tf.summary.createFileWriter(tensorboard_logdir)
      tf.summary.traceOn(true, true)

Unfortunately scalac complains:
tf.scala:1279: error: value createFileWriter is not a member of org.tensorflow.op.SummaryOps
      val tensorboard_writer = tf.summary.createFileWriter(tensorboard_logdir)
                                          ^
tf.scala:1280: error: value traceOn is not a member of org.tensorflow.op.SummaryOps
      tf.summary.traceOn(true, true)
                 ^
two errors found

Any TensorBoard tips would be greatly appreciated!  Have a great weekend, and thanks for TensorFlow Java!
-Andrew

Jim Clarke

Sep 10, 2020, 4:57:32 PM
to Andrew Schaumberg, Karl Lessard, Alexey Zinoviev, SIG JVM
All,

I looked at the Python TF source for reduce_mean, and it simply calls mean.
In Python, however, you can pass None for the axis, and it gets interpreted as all axes.
I traced the Python code for mean when axis is None, and it does the same thing as allAxis does.
It would be nice to have similar behavior, perhaps another method that calls allAxis before calling the low-level mean op.
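
A sketch of what that could look like, building on the allAxis helper from my earlier message (a hypothetical method, not in the API):

// Mirror Python's axis=None: when no axes are given, reduce over all of them,
// so the result is the scalar mean of the whole tensor.
public static <T extends TNumber> Operand<T> mean(Ops tf, Operand<T> x) {
  return tf.math.mean(x, allAxis(tf, x));
}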

I agree with Karl that reduceMean is superfluous given mean, as are some of the other ReduceXXX methods, like ReduceSum relative to Sum.

There are also some other math ops that are higher-level in Python, built on the lower-level ops,
such as tensordot, confusion_matrix, etc., that exist in Python's tf.math but are not yet encoded in Java.
Plus, there are many higher-level Python ops that do things like type casts and reshapes before calling their corresponding lower-level ops.
But I don't think we have reached a consensus yet on how to handle these in the Java API.


jim

Karl Lessard

Sep 10, 2020, 5:50:18 PM
to Andrew Schaumberg, Jim Clarke, Alexey Zinoviev, SIG JVM
I recalled that I did this MNIST example a long time ago, based on TF 1.x, which makes use of the summary ops; maybe you can take a look at what I did there?

For the typedefs, I think that at least from the raw list of operations, we should stick to a single endpoint per op, unless you are adding more logic around it (like Jim did with `SparseCrossentropyWithLogits`). But let's see what the SIG has to say about it tomorrow.

- Karl

Andrew Schaumberg

Sep 14, 2020, 1:50:21 PM
to Karl Lessard, Jim Clarke, Alexey Zinoviev, SIG JVM
Hi Karl & all,

Thanks for the MNIST example!  If anyone's interested, kindly find attached a screenshot of TensorBoard from TF Java training.  There are a bunch of accuracies listed, but accuracy_3 is the most important (acc_3 is the batch mean, acc_6 the batch stdev, acc_13 the batch max, and acc_17 the batch min, I believe).

For transfer learning, is it possible in TensorFlow Java to load an ImageNet-pretrained Python Keras ResNet50 model and add a Dense layer of neurons on top in Java/Scala?
Relatedly, is it possible to get the loss of the Keras ResNet50 while training in TensorFlow Java?

To explain further, with Python code like this to save a pretrained model:
from keras.applications.resnet50 import ResNet50
from keras.engine.topology import Input
from keras.layers import GlobalAveragePooling2D
from keras.models import Model
image_size = 224
channel_count = 3
input_tensor = Input(shape=(image_size, image_size, channel_count), name='r50in')
base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=input_tensor)
model = base_model.output
model = GlobalAveragePooling2D(name='r50globavgpool')(model)  # name belongs to the layer constructor, not the call
model = Model(inputs=input_tensor, outputs=model)
model.summary()
model.save("imagenet_r50", overwrite=True, include_optimizer=False, save_format='tf')

I'd like to load the model in TensorFlow Java, but I get input_info/output_info TensorInfo objects, not Operands, so I'm not sure how to add Dense layers on top of the Keras model in Scala.
      val saved_model: org.tensorflow.SavedModelBundle = org.tensorflow.SavedModelBundle.load(pretrained_model_dirname, "saved_model")
      val signature = saved_model.metaGraphDef.getSignatureDefMap.get("serving_default") // SignatureDef
      val input_info = signature.getInputsMap.get("r50in") // TensorInfo
      val output_info = signature.getOutputsMap.get("r50globavgpool") // TensorInfo

I realize I could use this pretrained model for training, feeding/fetching the layers by name, but I'd like access to the Operands, to engineer the model architecture further.
      saved_model.session.runner
        .feed(input_info.getName, x_tensor)
        .fetch(output_info.getName)
        .run
        .get(0).expect(org.tensorflow.types.TFloat32.DTYPE)

I'd like to do something like this, but I can't use the TensorInfo output_info object to add a Dense layer on top:
      val dense_layer = tf.withName("dense_%s".format(name_suffix)).math.add(
        tf.linalg.matMul(output_info, weights),
        biases
      )

Really appreciate this discussion, learning a lot, thanks for your time!
-Andrew
[attachment: tensorflow-java-tensorboard.png]

Jim Clarke

Sep 14, 2020, 1:58:54 PM
to Andrew Schaumberg, Karl Lessard, Alexey Zinoviev, SIG JVM
Andrew,

We are only starting to add Keras-type elements to TensorFlow Java, so what you are asking for is not available yet.

jim

Andrew Schaumberg

Sep 14, 2020, 3:38:40 PM
to Jim Clarke, Karl Lessard, Alexey Zinoviev, SIG JVM
Hi Jim & all,

Thanks for your prompt reply!  A different way to save a Keras model involves saving a NumPy array.  I'd often have a callback in Keras that included this:
  np.save(model_save_fn, snapshotted_weights)

So, if I printed the pretrained ResNet50 weights from Python Keras into a plain-text file, then parsed that text in Java/Scala, could I initialize layers in TensorFlow Java with those weights?
It might work something like this, but initializing from the weights rather than a truncated normal:
      val weights = tf.withName("dense_weight_%s".format(name_suffix)).variable(
        tf.math.mul(
          tf.random.truncatedNormal(
            tf.array(input_size, output_size),
            TFloat32.DTYPE,
            org.tensorflow.op.random.TruncatedNormal.seed(seed)
          ),
          tf.constant(0.1f)
        )
      )

I guess I'd have something like:
      val weights = tf.withName("dense_weight_%s".format(name_suffix)).variable(
        tf.math.mul(
          tf.array(/* ...parsed weights go here or so... */),
          tf.constant(0.1f)
        )
      )


Unfortunately, I'd need to define BatchNorm etc. in TensorFlow Java, so maybe pretrained VGG19 weights are a simpler first step.
In VGG19, I can print out the first block1_conv1 weights like this:
base_model = VGG19(weights='imagenet', include_top=False, input_tensor=input_tensor)
...
block1_conv1 = base_model.get_layer('block1_conv1').get_weights()
print("==block1_conv1 weights==")
print(block1_conv1)

That prints:
...
==block1_conv1 weights==
[array([[[[ 3.41195226e-01,  9.56311151e-02,  1.77448951e-02,
           2.89807528e-01, -9.20122489e-02,  2.07042053e-01,
           6.46437183e-02,  2.16439571e-02,  1.08167537e-01,
           4.84041087e-02,  5.45682721e-02, -6.10647909e-02,
          -1.58045009e-01,  5.19339405e-02, -6.87913522e-02,
           1.33999050e-01, -1.58462003e-02,  1.73906349e-02,
           1.66022182e-01,  5.78112081e-02,  3.48867804e-01,
           2.11988866e-01,  1.48273855e-01, -1.69187382e-01,
           3.48284580e-02,  1.27128616e-01, -3.71840857e-02,
          -2.00428069e-01, -3.16871032e-02, -1.86070353e-01,
          -2.19486892e-01,  1.27115071e-01, -9.16607082e-02,
          -3.43449079e-02, -1.90350711e-01, -2.66607553e-01,
           3.13598178e-02, -3.12656164e-01,  1.40622392e-01,
           1.21161930e-01, -9.97794569e-02,  2.96889722e-01,
          -5.61806671e-02,  2.05685452e-01, -1.03926789e-02,
           9.24662501e-02, -1.07948564e-01, -3.37282866e-01,
           4.27512638e-02,  6.48847446e-02, -9.78276972e-03,
           3.77967954e-01,  3.66937593e-02, -2.69813687e-01,
           1.28001258e-01, -1.02722347e-01,  1.93587355e-02,
           3.05614114e-01, -2.40945131e-01, -1.63531616e-01,
          -2.92619884e-01, -1.14364550e-01, -5.09986579e-02,
          -2.99792644e-03],
         [ 4.64183718e-01,  3.35566774e-02,  1.02450453e-01,
           4.35352564e-01, -1.08011074e-01, -1.64764345e-01,
           8.33548680e-02,  6.91149086e-02, -1.98017612e-01,
          -1.48166239e-01,  1.24934725e-01,  5.46611026e-02,
           3.00729215e-01,  1.84157230e-02, -1.21154279e-01,
          -1.85422197e-01, -7.28116706e-02,  1.85673743e-01,
          -1.73196927e-01, -6.17760159e-02,  2.37114772e-01,
           2.84024507e-01,  6.23529106e-02, -4.54035282e-01,
           1.11567155e-01,  7.88022876e-02, -6.66245446e-02,
           3.54866609e-02, -6.33498430e-02, -1.74995512e-02,
...
...
          -1.30720837e-02,  2.11829692e-01,  6.30676225e-02,
          -1.69432253e-01,  1.14183865e-01,  1.58425927e-01,
           2.94493884e-01, -1.00173786e-01, -1.56037942e-01,
          -3.25661480e-01],
         [-4.16022718e-01, -1.14913411e-01, -1.46728873e-01,
          -1.96428165e-01, -2.61094384e-02, -3.41196507e-02,
          -9.46008414e-03,  9.66064408e-02,  1.05738558e-01,
           6.76928088e-02, -3.86121631e-01, -1.00944348e-01,
          -1.95946872e-01, -1.12268366e-01,  3.15226912e-01,
           1.22262374e-01,  1.81769550e-01, -1.86081201e-01,
           9.39305127e-02,  1.99210957e-01, -3.11680824e-01,
          -2.52262384e-01, -1.59127533e-01,  3.43767434e-01,
          -1.10607758e-01, -1.41195193e-01,  1.82832837e-01,
           2.69443803e-02,  2.73368835e-01,  3.57156396e-02,
           1.29292026e-01,  1.27877861e-01,  1.06675653e-02,
          -1.85635537e-02, -3.01205404e-02,  1.97462142e-01,
          -1.47572801e-01,  1.76026970e-01, -2.24853754e-01,
          -5.00783511e-02, -7.94276670e-02,  2.06059963e-02,
           2.04005763e-02, -1.00091748e-01, -1.30253002e-01,
           1.26242280e-01, -3.39091308e-02,  3.62772673e-01,
          -4.56045792e-02,  6.26502335e-02, -1.58212170e-01,
          -3.25717837e-01,  1.59315526e-01,  3.15451205e-01,
           7.69315362e-02,  5.67030907e-02,  1.59861729e-01,
          -3.58525515e-02,  8.61789584e-02,  9.54354554e-02,
           2.41779909e-01, -1.30795062e-01, -1.37962803e-01,
          -2.65884489e-01]]]], dtype=float32), array([ 0.7301776 ,  0.06493629,  0.03428847,  0.8260386 ,  0.2578029 ,
        0.54867655, -0.01243854,  0.34789944,  0.5510871 ,  0.06297145,
        0.6069906 ,  0.26703122,  0.649414  ,  0.17073655,  0.4772309 ,
        0.38250586,  0.46373144,  0.21496128,  0.46911287,  0.23825859,
        0.4751922 ,  0.70606434,  0.27007523,  0.6855273 ,  0.03216552,
        0.6025288 ,  0.3503486 ,  0.446798  ,  0.7732652 ,  0.58191687,
        0.39083108,  1.7519354 ,  0.66117406,  0.30213955,  0.53059655,
        0.6773747 ,  0.33273223,  0.49127793,  0.26548928,  0.18805602,
        0.07412001,  1.1081088 ,  0.28224325,  0.86755145,  0.19422948,
        0.810332  ,  0.36062282,  0.5072004 ,  0.42472315,  0.49632648,
        0.15117475,  0.79454446,  0.33494323,  0.47283995,  0.41552398,
        0.08496041,  0.37947032,  0.6006739 ,  0.47174454,  0.8130921 ,
        0.45521152,  1.0892007 ,  0.47757268,  0.4072122 ], dtype=float32)]


That looks like the weights, with the biases at the end.  So I'll parse that in Java/Scala and initialize TensorFlow Java Operands with it to get the pretrained weights in, yes?
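
If so, the last step might look like this in Java (a sketch: the parsing helper and array shape are my assumptions; only the idea of seeding a variable from literal values comes from this thread):

// Hypothetical: parsedWeights holds the block1_conv1 kernel values recovered
// from the Keras dump, shaped [height, width, in_channels, out_channels].
float[][][][] parsedWeights = parseKerasDump("block1_conv1.txt"); // assumed helper

// Initialize the variable directly from the pretrained values instead of a
// truncated normal, so training starts from the Keras weights.
Variable<TFloat32> block1Conv1 =
    tf.withName("block1_conv1_w").variable(tf.constant(parsedWeights));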

Thanks again,
-Andrew

Jim Clarke

Sep 14, 2020, 3:46:44 PM
to Andrew Schaumberg, Karl Lessard, Alexey Zinoviev, SIG JVM
Andrew,

With the new SavedModel support that Karl was talking about, you should be able to load a Python-trained model into the Java API.
I assume you could run predictions on that model. 
What we are currently missing are the Keras pieces, and I haven’t really looked at that yet.

jim

Andrew Schaumberg

Sep 14, 2020, 4:19:30 PM
to Jim Clarke, Karl Lessard, Alexey Zinoviev, SIG JVM
Hi Jim & all,

Admittedly, I'm a bit desperate to get TensorFlow Java CNN training to converge, so I'm hoping these VGG19 weights will work in TF Java conv layers.  Without CNN pretraining, local minima seem problematic for me, even for simple ML tasks.  Training an MLP on hand-engineered features works wonderfully -- it's just learning from naive convolutional features that's difficult.

Thanks again,
-Andrew

Karl Lessard

Sep 15, 2020, 9:44:06 AM
to Andrew Schaumberg, Jim Clarke, Alexey Zinoviev, SIG JVM
Hey Andrew,

You can already load and run saved models with the current code, but this PR, which will be merged soon, makes it a lot simpler. For ResNet50, according to your previous post, you can do something like this:

try (SavedModelBundle savedModel = SavedModelBundle.load(r50DirName, "saved_model");
     ConcreteFunction globAvgPooling2D = ConcreteFunction.create(...)) { // this is your custom layer

    try (Tensor<TFloat32> inputImages = TFloat32.tensorOf(...); // there are different ways to load images into a tensor, pick any
         Tensor<TFloat32> r50Result = savedModel.function(Signature.DEFAULT_KEY).call(inputImages).expect(TFloat32.DTYPE);
         Tensor<TFloat32> result = globAvgPooling2D.call(r50Result).expect(TFloat32.DTYPE)) {
        ...
    }
}

Keras' ResNet50 accepts a batch of images as input, so you can call the inner block in a loop and train your new layer. But this is not as nice as in Python, where your model just "extends" the pretrained model; also, you won't be able to save both the pretrained model and the new layer in the same saved model. Still, I have never played with transfer learning so far, so I invite you to do some experiments, either by checking out the PR's branch or by waiting a few days until it gets merged.

- Karl

Adam Pocock

Sep 15, 2020, 3:06:04 PM
to SIG JVM, schaumbe...@gmail.com, karl.l...@gmail.com, zalesl...@gmail.com, jimcla...@gmail.com
First, you should do it the way that Jim and Karl explain, as it's much better. However, there are alternative (much worse) ways.

You can do the thing you suggest, or you could build the graph structure, add a placeholder for each node, and then run an assign op per node to load the values back in (it's what we do in Tribuo, as we're still targeting TF 1.14 until the alpha release from this group). You can see an example here: https://github.com/oracle/tribuo/blob/31566c5e2e866f0a73cdfc2c835b874e83792ebe/Interop/Tensorflow/src/main/java/org/tribuo/interop/tensorflow/TensorflowUtil.java#L334. This would allow you to modify certain layers if you want to do transfer learning, but there should be a better way of doing it.
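
A minimal sketch of that placeholder-and-assign trick (all names here are hypothetical; the linked Tribuo code is the real reference):

// 'weights' is a Variable already built into the graph; 'pretrained' is a
// FloatNdArray parsed from a Keras dump. Feed the values through a
// placeholder and run only the Assign op to load them into the variable.
Placeholder<TFloat32> ph = tf.withName("weights_ph").placeholder(TFloat32.DTYPE);
Assign<TFloat32> load = tf.assign(weights, ph);

try (Tensor<TFloat32> value = TFloat32.tensorOf(pretrained)) {
  session.runner()
      .feed(ph, value) // supply the pretrained values
      .addTarget(load) // run just the assignment
      .run();
}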

Adam

Andrew Schaumberg

Sep 16, 2020, 5:22:14 PM
to Adam Pocock, SIG JVM, karl.l...@gmail.com, zalesl...@gmail.com, jimcla...@gmail.com
Hi Adam & all,

I really appreciate everyone's guidance!  Though it's not a good way to go (alas, deadlines), I can load the VGG19 block1 conv layer weights/biases by dumping them from Keras and initializing TensorFlow Java layers with the values, then train this highly truncated VGG19 CNN in TensorFlow Java with a dense layer on top.  However, if I also load the VGG19 block2 layers and train in TensorFlow Java, I get NaN losses, even if the learning rate is zero.  For perspective, I also get NaN losses training a few-layer CNN from scratch with a learning rate of zero.  MLP and shallow-CNN training, however, work for me.

To proceed systematically, is there a TensorFlow Java API to freeze & unfreeze layers?  I'd like to try loading the VGG19 weights, freezing all of that, training just one dense layer on top until convergence, then slowly unfreezing VGG19 layers one by one, hopefully avoiding NaN losses while ultimately training end-to-end.  Alternatively, a dynamic learning rate would probably help, but I don't think TensorFlow Java is there yet.

Many thanks again for your ideas and suggestions, learning a lot here!
-Andrew

Adam Pocock

Sep 16, 2020, 5:47:46 PM
to Andrew Schaumberg, SIG JVM, karl.l...@gmail.com, zalesl...@gmail.com, jimcla...@gmail.com
There isn't any notion of layers in the API just yet. You can find the gradients for a subset of the operations and update those, but the optimisers don't really support that at the moment either. You could do it by hand by matching the names of the gradients up with the appropriate layers, making a no-op that has those as control dependencies, and then only running that no-op, but that's likely to be error-prone (and tricky to do with more complex optimisers than regular gradient descent).

Alternatively, while we're in the land of bad solutions to this problem, you could make n graphs, where each one has only a single layer made from variables, and the layers below are constants. Then you could train the first graph, pull out the tensor for the first layer, make a second graph using that first layer's weights as a constant, add a second layer as a variable, train it, etc. That would be a lot of ugly code, but it would work.
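
Something like this, roughly (hypothetical names; closer to TF-Java-flavoured pseudocode than a recipe):

// After training graph #1, fetch the first layer's weights as a tensor, then
// bake them into graph #2 as a constant so only the new layer's variables train.
try (Tensor<TFloat32> layer1 =
         session1.runner().fetch("layer1/weights").run().get(0).expect(TFloat32.DTYPE)) {
  Ops tf2 = Ops.create(graph2);
  Operand<TFloat32> frozen = tf2.constant(layer1.data()); // the layer below, now a constant
  // ... build the second layer as a Variable on top of 'frozen' and train it ...
}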

We're working on a keras style API, but as mentioned it's not complete yet, so currently training proceeds through regular tf 1.x style operations.

I'll have to look into how Keras does freezable layers; it's probably feeding different variable subsets into the optimizer construction, but I'm not sure we're exposing the right entry points for that.

Adam

Andrew Schaumberg

Sep 18, 2020, 11:13:55 AM
to Adam Pocock, SIG JVM, karl.l...@gmail.com, zalesl...@gmail.com, jimcla...@gmail.com
Hi Adam & all,

Thanks very much for sharing your expertise, and for thinking outside the box with me as we ventured deep into the land of bad solutions.  I'm pleased to report that Keras-pretrained deep VGG19 networks now train end-to-end for me in TensorFlow Java, all the way from blocks 1 to 4 of VGG19 -- millions of parameters.  I imagine block 5 would be OK to include too, but this is enough of an encoder for me, for now.

During the 2020-08-28 SIG JVM meeting that Adam chaired, he commented that the initialization of neural networks is really important.  It turns out I needed to initialize the dense layer's weights and biases with numbers orders of magnitude smaller to avoid NaN losses [presumably numerical instability from exploding gradients].  Generally now, if I get NaN losses when training deep nets, I reduce the learning rate and the magnitude of the randomly initialized weights until NaN losses aren't a problem.  Thanks for your insightful commentary, Adam; it helped me appreciate the importance of initialization and focus there.

Thanks again,
-Andrew

Karl Lessard

Sep 18, 2020, 7:49:28 PM
to Andrew Schaumberg, Adam Pocock, SIG JVM, zalesl...@gmail.com, jimcla...@gmail.com
I'm happy things turned out that way, congrats Andrew!

Don't hesitate to share more of your experiences with transfer learning in Java; I don't think many of us have done this so far. If you have some time to spare, I invite you to write up and publish a small example of what you have accomplished under https://github.com/tensorflow/java-models, as I think that could be very useful for other users who want to achieve the same task.

Cheers,
Karl

Andrew Schaumberg

Sep 18, 2020, 10:22:33 PM
to Karl Lessard, Adam Pocock, SIG JVM, zalesl...@gmail.com, jimcla...@gmail.com
Hi Karl & all,

Yes, if our luck holds, perhaps we'll publish a follow-up to our prior work, which is available at https://doi.org/10.1038/s41379-020-0540-1 and as a preprint at https://doi.org/10.1101/396663
Kindly email me privately if you, Adam, Alexey, etc are interested in being acknowledged in any forthcoming manuscripts, or included as coauthors.  Really appreciate your code and insights.
I'd encourage SIG JVM to publish a technical report on arxiv.org so it's citable in academic work, perhaps comparing speed, accuracy, lines of code, features, or default CPU support to Keras etc.  I'm happy to support that effort, but I obviously can't lead it.

Please also let me know if there's a home for TensorFlow Java-related Scala code, somewhat in the spirit of https://github.com/dhruvrajan/tensorflow-keras-scala
Tonight I think I may have gotten shared weights working too, though for convenience TensorBoard tracing is not comprehensive.
The uberjar approach has worked for me for many years, for stability and portability reasons, but I think I'd need help setting up sbt, Maven, etc. for proper development builds of a TensorflowScala or similar.  I still have a lot to learn.

Unfortunately, code & data from our prior and ongoing work may be released at a glacial academic pace, due to the need to publish, protect patient privacy, etc.  We'll open up eventually though, and intend to use Creative Commons licensing etc.

Cheers,
-Andrew