Hi Jim & all,
Thanks for your prompt reply! A different way to save a Keras model involves saving a NumPy array. I'd often have a callback in Keras that included this:
np.save(model_save_fn, snapshotted_weights)
So, if I printed the pretrained ResNet50 weights from Python Keras into a plain-text file, then parsed that text in Java/Scala, could I initialize layers in TensorFlow Java with those weights?
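For the plain-text step, rather than parsing NumPy's printed repr, it might be easier to write each array with an explicit shape header. A minimal sketch of that idea (the format and the dump_weights_as_text helper are my own invention here, not a Keras API):

```python
import numpy as np

def dump_weights_as_text(arrays, path):
    """Write a list of NumPy arrays (e.g. from layer.get_weights()) to a
    plain-text file: one header line per array giving its shape, then its
    flattened values, one per line."""
    with open(path, "w") as f:
        for arr in arrays:
            f.write(" ".join(str(d) for d in arr.shape) + "\n")
            for v in arr.ravel():
                f.write(repr(float(v)) + "\n")

# Tiny stand-in for real layer weights (a 3x4 kernel and a 4-element bias).
kernel = np.arange(12, dtype=np.float32).reshape(3, 4)
bias = np.zeros(4, dtype=np.float32)
dump_weights_as_text([kernel, bias], "dense_weights.txt")
```

Using repr(float(v)) keeps enough digits that parsing the text back and casting to float32 recovers each value exactly, so the round trip through text is lossless.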
It might work something like this, but initializing from the weights rather than a truncated normal:
val weights = tf.withName("dense_weight_%s".format(name_suffix)).variable(
  tf.math.mul(
    tf.random.truncatedNormal(
      tf.array(input_size, output_size), TFloat32.DTYPE,
      org.tensorflow.op.random.TruncatedNormal.seed(seed)
    ),
    tf.constant(0.1f)
  )
)
I guess I'd have something like this (dropping the 0.1 scaling, since the parsed values already are the trained weights):
val weights = tf.withName("dense_weight_%s".format(name_suffix)).variable(
  tf.array(...parsed weights go here or so...)
)
Unfortunately, I'd need to define BatchNorm etc. in TensorFlow Java first, so maybe pretrained VGG19 weights (no batch-norm layers) are a simpler first step.
So, for VGG19, I can print out the block1_conv1 weights like this:
base_model = VGG19(weights='imagenet', include_top=False, input_tensor=input_tensor)
...
block1_conv1 = base_model.get_layer('block1_conv1').get_weights()
print("==block1_conv1 weights==")
print(block1_conv1)
That prints:
...
==block1_conv1 weights==
[array([[[[ 3.41195226e-01, 9.56311151e-02, 1.77448951e-02,
2.89807528e-01, -9.20122489e-02, 2.07042053e-01,
6.46437183e-02, 2.16439571e-02, 1.08167537e-01,
4.84041087e-02, 5.45682721e-02, -6.10647909e-02,
-1.58045009e-01, 5.19339405e-02, -6.87913522e-02,
1.33999050e-01, -1.58462003e-02, 1.73906349e-02,
1.66022182e-01, 5.78112081e-02, 3.48867804e-01,
2.11988866e-01, 1.48273855e-01, -1.69187382e-01,
3.48284580e-02, 1.27128616e-01, -3.71840857e-02,
-2.00428069e-01, -3.16871032e-02, -1.86070353e-01,
-2.19486892e-01, 1.27115071e-01, -9.16607082e-02,
-3.43449079e-02, -1.90350711e-01, -2.66607553e-01,
3.13598178e-02, -3.12656164e-01, 1.40622392e-01,
1.21161930e-01, -9.97794569e-02, 2.96889722e-01,
-5.61806671e-02, 2.05685452e-01, -1.03926789e-02,
9.24662501e-02, -1.07948564e-01, -3.37282866e-01,
4.27512638e-02, 6.48847446e-02, -9.78276972e-03,
3.77967954e-01, 3.66937593e-02, -2.69813687e-01,
1.28001258e-01, -1.02722347e-01, 1.93587355e-02,
3.05614114e-01, -2.40945131e-01, -1.63531616e-01,
-2.92619884e-01, -1.14364550e-01, -5.09986579e-02,
-2.99792644e-03],
[ 4.64183718e-01, 3.35566774e-02, 1.02450453e-01,
4.35352564e-01, -1.08011074e-01, -1.64764345e-01,
8.33548680e-02, 6.91149086e-02, -1.98017612e-01,
-1.48166239e-01, 1.24934725e-01, 5.46611026e-02,
3.00729215e-01, 1.84157230e-02, -1.21154279e-01,
-1.85422197e-01, -7.28116706e-02, 1.85673743e-01,
-1.73196927e-01, -6.17760159e-02, 2.37114772e-01,
2.84024507e-01, 6.23529106e-02, -4.54035282e-01,
1.11567155e-01, 7.88022876e-02, -6.66245446e-02,
3.54866609e-02, -6.33498430e-02, -1.74995512e-02,
...
...
-1.30720837e-02, 2.11829692e-01, 6.30676225e-02,
-1.69432253e-01, 1.14183865e-01, 1.58425927e-01,
2.94493884e-01, -1.00173786e-01, -1.56037942e-01,
-3.25661480e-01],
[-4.16022718e-01, -1.14913411e-01, -1.46728873e-01,
-1.96428165e-01, -2.61094384e-02, -3.41196507e-02,
-9.46008414e-03, 9.66064408e-02, 1.05738558e-01,
6.76928088e-02, -3.86121631e-01, -1.00944348e-01,
-1.95946872e-01, -1.12268366e-01, 3.15226912e-01,
1.22262374e-01, 1.81769550e-01, -1.86081201e-01,
9.39305127e-02, 1.99210957e-01, -3.11680824e-01,
-2.52262384e-01, -1.59127533e-01, 3.43767434e-01,
-1.10607758e-01, -1.41195193e-01, 1.82832837e-01,
2.69443803e-02, 2.73368835e-01, 3.57156396e-02,
1.29292026e-01, 1.27877861e-01, 1.06675653e-02,
-1.85635537e-02, -3.01205404e-02, 1.97462142e-01,
-1.47572801e-01, 1.76026970e-01, -2.24853754e-01,
-5.00783511e-02, -7.94276670e-02, 2.06059963e-02,
2.04005763e-02, -1.00091748e-01, -1.30253002e-01,
1.26242280e-01, -3.39091308e-02, 3.62772673e-01,
-4.56045792e-02, 6.26502335e-02, -1.58212170e-01,
-3.25717837e-01, 1.59315526e-01, 3.15451205e-01,
7.69315362e-02, 5.67030907e-02, 1.59861729e-01,
-3.58525515e-02, 8.61789584e-02, 9.54354554e-02,
2.41779909e-01, -1.30795062e-01, -1.37962803e-01,
-2.65884489e-01]]]], dtype=float32), array([ 0.7301776 , 0.06493629, 0.03428847, 0.8260386 , 0.2578029 ,
0.54867655, -0.01243854, 0.34789944, 0.5510871 , 0.06297145,
0.6069906 , 0.26703122, 0.649414 , 0.17073655, 0.4772309 ,
0.38250586, 0.46373144, 0.21496128, 0.46911287, 0.23825859,
0.4751922 , 0.70606434, 0.27007523, 0.6855273 , 0.03216552,
0.6025288 , 0.3503486 , 0.446798 , 0.7732652 , 0.58191687,
0.39083108, 1.7519354 , 0.66117406, 0.30213955, 0.53059655,
0.6773747 , 0.33273223, 0.49127793, 0.26548928, 0.18805602,
0.07412001, 1.1081088 , 0.28224325, 0.86755145, 0.19422948,
0.810332 , 0.36062282, 0.5072004 , 0.42472315, 0.49632648,
0.15117475, 0.79454446, 0.33494323, 0.47283995, 0.41552398,
0.08496041, 0.37947032, 0.6006739 , 0.47174454, 0.8130921 ,
0.45521152, 1.0892007 , 0.47757268, 0.4072122 ], dtype=float32)]
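One caveat with print(block1_conv1): by default NumPy summarizes any array with more than 1000 elements using "...", and the block1_conv1 kernel has 3*3*3*64 = 1728 values, so the printed text above has silently dropped most of them. If parsing printed output, disable summarization first (a quick demonstration):

```python
import sys
import numpy as np

big = np.arange(2000, dtype=np.float32)
assert "..." in repr(big)  # summarized: most values never reach the text

np.set_printoptions(threshold=sys.maxsize)  # print every element
assert "..." not in repr(big)
```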
That looks right: the (3, 3, 3, 64) kernel, followed by the 64 biases. So I'll parse that in Java/Scala, initialize TensorFlow Java Operands with it, and that gets the pretrained weights in, yes?
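For the parsing side, here's a minimal round-trip sketch in Python (the same logic ports line-for-line to Scala with scala.io.Source). It assumes a made-up format of a shape header line followed by one value per line, not anything Keras itself emits:

```python
import numpy as np

def parse_weights_text(path):
    """Parse arrays stored as: one header line with the shape, then the
    flattened values, one per line. (A hypothetical format -- adjust to
    whatever you actually print from Keras.)"""
    with open(path) as f:
        lines = [ln.strip() for ln in f if ln.strip()]
    arrays, i = [], 0
    while i < len(lines):
        shape = tuple(int(t) for t in lines[i].split())
        n = int(np.prod(shape))
        vals = np.array([float(x) for x in lines[i + 1:i + 1 + n]],
                        dtype=np.float32)
        arrays.append(vals.reshape(shape))
        i += 1 + n
    return arrays

# Round trip: a 2x3 "kernel" and a 4-element "bias" written by hand.
with open("block1_conv1.txt", "w") as f:
    f.write("2 3\n" + "\n".join(str(float(v)) for v in range(6)) + "\n")
    f.write("4\n" + "0.5\n" * 4)
kernel, bias = parse_weights_text("block1_conv1.txt")
assert kernel.shape == (2, 3) and bias.shape == (4,)
```

Once parsed into a Java float array (or NdArray) on the JVM side, I believe tf.constant has overloads that accept it directly, with no scaling factor needed.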
Thanks again,
-Andrew