Model-personalization demo app - fixed class-incremental scenario - seeking feedback

Nikolas Stavrou

Apr 24, 2023, 2:57:23 PM
to TensorFlow Lite
I wanted to make an update to this.

I've changed the app code to reload the TFLite model from the converter every time we train, and this seems to fix the class-incremental scenario and make it work correctly. However, this means that if we use hidden layers in the head (which the default TFLite model converter from the repo does not use), their weights are lost on every reload. If we want to extend this to work with trainable weights, I guess we could use the save and restore weight signature functions from the model converter's Python code. Any feedback on this?

I also stopped keeping the raw samples between training cycles and used a replay buffer instead. This seems to offer some continual-learning capability: without it, keeping only the most recently taken samples, the model no longer classifies the old classes correctly at inference time and only remembers the newest one (which it does get right).
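To make the replay-buffer idea concrete, here is a minimal Python sketch (the class and all names are mine, not code from the demo app): keep a bounded store of (bottleneck, label) pairs across training cycles and draw a random subset of old pairs into each new batch, so the head keeps seeing the earlier classes.

```python
import random

class ReplayBuffer:
    """Bounded store of (bottleneck, label) pairs kept across training cycles."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.samples = []  # list of (bottleneck, label) pairs

    def add(self, bottleneck, label):
        if len(self.samples) >= self.capacity:
            # Evict a random old sample to stay within the bound.
            self.samples.pop(random.randrange(len(self.samples)))
        self.samples.append((bottleneck, label))

    def sample_batch(self, batch_size):
        # Random subset of stored pairs to mix into the next training batch.
        k = min(batch_size, len(self.samples))
        return random.sample(self.samples, k)

# Each training cycle would then train on the new samples plus a replayed batch:
buffer = ReplayBuffer(capacity=100)
for i in range(5):
    buffer.add([0.1 * i], "class_a")
batch = buffer.sample_batch(3)
```

Random eviction is just the simplest policy; reservoir sampling or per-class quotas would also fit here.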

My last question: with the replay buffer plus the samples taken in the last training cycle, what should the loss look like? It is always a low value that almost immediately converges to 0.000. Is this expected, or could it indicate a bug?
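For context on why near-zero loss may be plausible: with only a handful of samples per cycle, a single softmax head can usually memorize the batch. A small self-contained sketch (pure NumPy standing in for the TF head; all names here are mine) shows the same behaviour:

```python
import numpy as np

# Stand-in for the demo's softmax head: logits = X @ W + b, trained with
# plain gradient descent on the mean categorical cross-entropy.
rng = np.random.default_rng(0)
num_features, num_classes = 8, 3
X = rng.normal(size=(6, num_features))       # six bottleneck vectors
y = np.eye(num_classes)[[0, 0, 1, 1, 2, 2]]  # one-hot labels
W = np.zeros((num_features, num_classes))
b = np.zeros(num_classes)

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

lr = 0.2
for _ in range(2000):
    p = softmax(X @ W + b)
    loss = -np.mean(np.sum(y * np.log(p + 1e-12), axis=1))
    g = (p - y) / len(X)   # gradient of the mean cross-entropy w.r.t. logits
    W -= lr * (X.T @ g)
    b -= lr * g.sum(axis=0)

# With six separable points the head memorizes the batch, so `loss` ends
# close to zero, similar to the near-0.000 values seen in the app.
```

A loss that drops toward zero after a few steps on a tiny batch is consistent with this; a loss that is exactly zero before any training step would be more suspicious.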

Warm regards,
Nikolas

Nikolas Stavrou

Apr 24, 2023, 3:05:55 PM
to TensorFlow Lite, Nikolas Stavrou
I would like to make a correction: the training function does indeed update the weights and biases, as seen here:
  @tf.function(input_signature=[
      tf.TensorSpec([None, NUM_FEATURES], tf.float32),
      tf.TensorSpec([None, NUM_CLASSES], tf.float32),
  ])
  def train(self, bottleneck, label):
    """Runs one training step with the given bottleneck features and labels.

    Args:
      bottleneck: A tensor of bottleneck features generated from the base model.
      label: A tensor of class labels for the given batch.

    Returns:
      Map of the training loss.
    """
    with tf.GradientTape() as tape:
      logits = tf.matmul(bottleneck, self.ws) + self.bs
      prediction = tf.nn.softmax(logits)
      loss = self.loss_fn(prediction, label)
    gradients = tape.gradient(loss, [self.ws, self.bs])
    self.optimizer.apply_gradients(zip(gradients, [self.ws, self.bs]))
    result = {'loss': loss}
    for grad in gradients:
      result[grad.name] = grad
    return result

Is there an example anywhere of how the save and restore signature functions are used, so I can see how to call them properly? I want to set up the model again every time we call startTraining() by calling setupModelPersonalization(), but save the weights beforehand and restore them afterwards. Could anyone give feedback on how to do that?
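To make the intended flow concrete, here is a toy Python analogue of what I want the Kotlin side to do (none of these names are the app's real API): snapshot the head's trainable weights, rebuild the head from scratch, then assign the saved values back so training resumes where it left off.

```python
import numpy as np

class Head:
    """Toy stand-in for the trainable head created by setupModelPersonalization()."""

    def __init__(self, num_features, num_classes):
        # Fresh initialization, as a model re-setup would produce.
        self.ws = np.zeros((num_features, num_classes))
        self.bs = np.zeros(num_classes)

    def save(self):
        # Analogue of the 'save' signature: copy out the current values.
        return {"ws": self.ws.copy(), "bs": self.bs.copy()}

    def restore(self, checkpoint):
        # Analogue of the 'restore' signature: assign saved values back.
        self.ws = checkpoint["ws"].copy()
        self.bs = checkpoint["bs"].copy()

head = Head(num_features=4, num_classes=2)
head.ws += 1.0                               # pretend some training happened
ckpt = head.save()                           # save before tearing the model down
head = Head(num_features=4, num_classes=2)   # re-setup: weights reset to zero
head.restore(ckpt)                           # restore afterwards
```

The real app would route save() and restore() through interpreter.runSignature() with a checkpoint file path instead of an in-memory dict, but the ordering (save, re-setup, restore) is the same.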

Here are the saving and restore functions for reference:
  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    """Saves the trainable weights to the given checkpoint file.

    Args:
      checkpoint_path: A file path to save the model.

    Returns:
      Map of the checkpoint file path.
    """
    tensor_names = [self.ws.name, self.bs.name]
    tensors_to_save = [self.ws.read_value(), self.bs.read_value()]
    tf.raw_ops.Save(
        filename=checkpoint_path,
        tensor_names=tensor_names,
        data=tensors_to_save,
        name='save')
    return {'checkpoint_path': checkpoint_path}

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    """Restores the serialized trainable weights from the given checkpoint file.

    Args:
      checkpoint_path: A path to a saved checkpoint file.

    Returns:
      Map of restored weight and bias.
    """
    restored_tensors = {}
    restored = tf.raw_ops.Restore(
        file_pattern=checkpoint_path,
        tensor_name=self.ws.name,
        dt=np.float32,
        name='restore')
    self.ws.assign(restored)
    restored_tensors['ws'] = restored
    restored = tf.raw_ops.Restore(
        file_pattern=checkpoint_path,
        tensor_name=self.bs.name,
        dt=np.float32,
        name='restore')
    self.bs.assign(restored)
    restored_tensors['bs'] = restored
    return restored_tensors

Nikolas Stavrou

Apr 24, 2023, 4:37:21 PM
to TensorFlow Lite, Nikolas Stavrou
It is worth noting that I've made an attempt to call the save and restore functions myself (code below the stack trace). However, the app crashes the second time we call startTraining(), with the following error:
E/AndroidRuntime: FATAL EXCEPTION: main
    Process: org.tensorflow.lite.examples.modelpersonalization, PID: 22275
    java.lang.IllegalArgumentException: Internal error: Cannot copy empty/scalar Tensors.
        at org.tensorflow.lite.TensorImpl.readMultiDimensionalArray(Native Method)
        at org.tensorflow.lite.TensorImpl.copyTo(TensorImpl.java:219)
        at org.tensorflow.lite.NativeInterpreterWrapper.runSignature(NativeInterpreterWrapper.java:213)
        at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:261)
        at org.tensorflow.lite.examples.modelpersonalization.TransferLearningHelper.startTraining(TransferLearningHelper.kt:172)
        at org.tensorflow.lite.examples.modelpersonalization.fragments.CameraFragment.onViewCreated$lambda-15$lambda-11(CameraFragment.kt:239)
        at org.tensorflow.lite.examples.modelpersonalization.fragments.CameraFragment.$r8$lambda$kJ43nXk7KWjLovHKsDpGLNEcwME(Unknown Source:0)
        at org.tensorflow.lite.examples.modelpersonalization.fragments.CameraFragment$$ExternalSyntheticLambda14.onClick(Unknown Source:2)
        at android.view.View.performClick(View.java:6256)
        at android.view.View$PerformClick.run(View.java:24701)
        at android.os.Handler.handleCallback(Handler.java:789)
        at android.os.Handler.dispatchMessage(Handler.java:98)
        at android.os.Looper.loop(Looper.java:164)
        at android.app.ActivityThread.main(ActivityThread.java:6541)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.Zygote$MethodAndArgsCaller.run(Zygote.java:240)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:767)

fun startTraining() {
    if (interpreter == null || firstTrainingFlag) {
        setupModelPersonalization()
        firstTrainingFlag = false
    } else {
        // Save weights
        val checkpointPath = MainActivity.getCheckpointPath(context)
        val saveInputs: MutableMap<String, Any> = HashMap()
        saveInputs[SAVE_INPUT_KEY] = checkpointPath
        val saveOutputs: MutableMap<String, Any> = HashMap()
        saveOutputs[SAVE_OUTPUT_KEY] = checkpointPath
        interpreter?.runSignature(saveInputs, saveOutputs, SAVE_KEY)

        setupModelPersonalization()

        // Load weights
        val restoreInputs: MutableMap<String, Any> = HashMap()
        restoreInputs[RESTORE_INPUT_KEY] = checkpointPath
        val restoreOutputs: MutableMap<String, Any> = HashMap()
        val restoredTensors = HashMap<String, FloatArray>()
        restoreOutputs[RESTORE_OUTPUT_KEY] = restoredTensors
        interpreter?.runSignature(restoreInputs, restoreOutputs, RESTORE_KEY)
    }
    // (the rest of the original startTraining() training code follows here)
}

fun getCheckpointPath(context: Context): String {
    val checkpointDir = context.getDir("checkpoints", Context.MODE_PRIVATE)
    if (!checkpointDir.exists()) {
        checkpointDir.mkdirs()
    }
    return File(checkpointDir, "checkpoint").absolutePath
}


Any help is appreciated.

Nikolas Stavrou

Apr 27, 2023, 5:52:29 AM
to TensorFlow Lite, Nikolas Stavrou
I've managed to print the weights of the head layer before the softmax activation, and I wanted to check whether they are preserved every time I call setupModelPersonalization().

If anyone could give feedback: the weights below are for the class-incremental scenario without calling the model setup on every training cycle (as in the original demo app, which does not classify correctly at inference), and for the scenario where we call the setup every time (which handles the class-incremental scenario correctly, but I am unsure whether the weights are actually preserved; if they are not, no on-device training is happening).
weightswithsetupeverytime.txt
weightswithnosetupeverytime.txt