TensorFlow Lite on-device training assistance: continual learning project


Nikolas Stavrou

Mar 19, 2023, 9:03:59 PM
to TensorFlow Lite
Greetings everyone.

I am working on a continual learning project. I am trying to take a pretrained TFLite model and extend it with a latent replay buffer, so that it can be trained on device and continually learn new classes.

I am trying to extend this work: https://arxiv.org/abs/2105.01946
My first step, before making changes to the model, is to adopt the new on-device training approach shown here: https://www.tensorflow.org/lite/examples/on_device_training/overview

I am also using the following TensorFlow example as a reference: https://github.com/tensorflow/examples/blob/master/lite/examples/model_personalization/README.md

I am new to TensorFlow, and I'm having a hard time figuring out how to extend the example and use the new on-device training approach, but I'm working hard on it.

My current issue is that I'm unable to create the model and convert it to TFLite format. I followed the tutorial and only changed the code to use MobileNetV2, because I don't want to train a custom model; I want a pretrained one. However, I can't find a way to fix this error:

TypeError: The decorated function train has 3 required argument(s), but tf.function was only passed an input_signature of length 2. This covers 2 required argument(s): ['self', 'x'], but TensorSpecs are still required for the remaining 1 argument(s): ['y'].

I also tried running the original, unmodified code, but I get the same error. Any help is appreciated.

If you know of any resources (videos, etc.) that could help me with any of the tasks above, please share them; it would be a great help!


Haoliang Zhang

Mar 19, 2023, 11:46:35 PM
to Nikolas Stavrou, TensorFlow Lite

Assuming you are following the example train signature here: the train signature requires both the training example input and the target label. Have you provided the target label? I think that is what "but TensorSpecs are still required for the remaining 1 argument(s): ['y']" refers to.



Nikolas Stavrou

Mar 20, 2023, 3:52:31 AM
to TensorFlow Lite, haol...@google.com, TensorFlow Lite, Nikolas Stavrou
I was trying the example from the website, which seems a bit different from the one you sent. I will try yours and see. Why do we have to define the number of classes? Ideally, I would want classes to be added incrementally, whenever we choose, within the app. Since the train function takes the labels as an argument, doesn't that already avoid this restriction?
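A likely reason the class count must be fixed: the exported signatures have static shapes. In the model personalization script quoted later in this thread, the head's weight matrix `ws` is created with shape (NUM_FEATURES, NUM_CLASSES) and the train signature declares the label as `tf.TensorSpec([None, NUM_CLASSES], tf.float32)`, so passing labels at training time cannot make the head grow. One common workaround, sketched here in plain Python under the assumption of a fixed upper bound, is to export with a generous NUM_CLASSES and assign one-hot slots to new classes as they appear in the app. `ClassRegistry` and `MAX_CLASSES` are hypothetical names, not part of the example.

```python
from typing import Dict, List

# Hypothetical upper bound; it must equal NUM_CLASSES used when exporting,
# because the head's weights have a fixed [NUM_FEATURES, NUM_CLASSES] shape.
MAX_CLASSES = 4


class ClassRegistry:
    """Assigns class names to fixed one-hot slots in arrival order."""

    def __init__(self, capacity: int = MAX_CLASSES) -> None:
        self.capacity = capacity
        self.slots: Dict[str, int] = {}

    def slot_for(self, name: str) -> int:
        # Reuse an existing slot, or claim the next free one.
        if name not in self.slots:
            if len(self.slots) >= self.capacity:
                raise ValueError(
                    f"all {self.capacity} class slots are in use; "
                    "re-export the model with a larger NUM_CLASSES")
            self.slots[name] = len(self.slots)
        return self.slots[name]

    def one_hot(self, name: str) -> List[float]:
        # Label width is always `capacity`, matching [None, NUM_CLASSES].
        vec = [0.0] * self.capacity
        vec[self.slot_for(name)] = 1.0
        return vec


reg = ClassRegistry()
print(reg.one_hot("mug"), reg.one_hot("pen"))
```

Unused slots simply never receive a positive label; the trade-off is that the maximum number of classes is decided when the model is converted.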

Nikolas Stavrou

Mar 20, 2023, 6:20:44 AM
to TensorFlow Lite, haol...@google.com, TensorFlow Lite, Nikolas Stavrou
The one you sent me works fine; I managed to create a TFLite model from it. Should the head with the buffer be implemented inside Android Studio, using the training signature?

I also didn't understand what this signature is used for:

  @tf.function(input_signature=[])
  def initialize_weights(self):
    """Initializes the weights and bias of the head model.

    Returns:
      Map of initialized weight and bias.
    """
    # Overwrite the head with random values drawn from tf.random.uniform.
    self.ws.assign(tf.random.uniform((self.num_features, self.num_classes)))
    self.bs.assign(tf.random.uniform((1, self.num_classes)))
    return {'ws': self.ws, 'bs': self.bs}


Haoliang Zhang

Mar 23, 2023, 7:04:40 PM
to Nikolas Stavrou, TensorFlow Lite
I think the `initialize_weights` signature just fills the weight and bias tensors with random values.
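A plain-Python stand-in for what `initialize_weights` does, as a sketch: `initialize_head` is a hypothetical name and the `seed` argument exists only for reproducibility. `tf.random.uniform` with its default bounds draws from [0, 1), so the head ends up with small positive random values. Calling it discards whatever the head has learned so far, whereas `__init__` in the same script starts `ws` and `bs` at zeros.

```python
import random
from typing import List, Tuple

Matrix = List[List[float]]


def initialize_head(num_features: int, num_classes: int,
                    seed: int = 0) -> Tuple[Matrix, Matrix]:
    """Fill ws (num_features x num_classes) and bs (1 x num_classes) with
    values drawn uniformly from [0, 1), mirroring tf.random.uniform's
    default bounds."""
    rng = random.Random(seed)
    ws = [[rng.random() for _ in range(num_classes)]
          for _ in range(num_features)]
    bs = [[rng.random() for _ in range(num_classes)]]  # shape (1, num_classes)
    return ws, bs


ws, bs = initialize_head(num_features=8, num_classes=4)
print(len(ws), len(ws[0]), len(bs), len(bs[0]))
```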

Nikolas Stavrou

Mar 29, 2023, 1:34:48 PM
to TensorFlow Lite, haol...@google.com, TensorFlow Lite, Nikolas Stavrou
Thanks. I've been working a bit on the demo, and I think I managed to add a simple replay buffer using the same approach as `val trainingSamples: MutableList<TrainingSample> = mutableListOf()`.
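One way to keep a replay buffer bounded is reservoir sampling, sketched here in plain Python. This is a swapped-in technique for illustration, not the demo's code: every sample offered so far has an equal chance of being kept, without storing them all. `ReplayBuffer` is hypothetical, and `TrainingSample` only loosely mirrors the demo's Kotlin class (a bottleneck plus an encoded label).

```python
import random
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class TrainingSample:
    # Loose mirror of the demo's sample: bottleneck features + one-hot label.
    bottleneck: Sequence[float]
    label: Sequence[float]


class ReplayBuffer:
    """Fixed-capacity buffer maintained with reservoir sampling."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.items: List[TrainingSample] = []
        self.seen = 0  # total number of samples ever offered
        self.rng = random.Random(seed)

    def add(self, sample: TrainingSample) -> None:
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append(sample)
        else:
            # Keep the new sample with probability capacity / seen,
            # overwriting a uniformly chosen slot.
            j = self.rng.randrange(self.seen)
            if j < self.capacity:
                self.items[j] = sample

    def __len__(self) -> int:
        return len(self.items)


buffer = ReplayBuffer(capacity=10)
for i in range(50):
    buffer.add(TrainingSample([float(i)], [1.0, 0.0]))
print(len(buffer), buffer.seen)
```

Reservoir sampling does not guarantee per-class balance; when classes arrive in separate sessions, a per-class quota is a common alternative.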

I have two questions, in case anyone can help:

1: Is there a way to debug the app besides Logcat, so I can check that everything works correctly (bottlenecks, training cycles, batches, the replay buffer, etc.)?

Experimenting with the demo, I found the following. On the first training, I added 20 samples of one class and 20 samples of another class (two different objects), trained the model, and inference worked fine. I then performed a 2nd training by adding 20 samples of a 3rd class (again a different object). During training, the loss stayed high, around 3 to 5, and inference on the 3rd class was completely wrong. I want to understand why this happens. I am guessing it is expected with transfer learning, but since I implemented my replay buffer, I want a way to debug it and see what is not working. With the replay buffer, the 3rd class should be classified correctly as well.

From what I understood, each time I add new samples and press the training button, on-device training uses only the samples captured at that moment and not earlier ones, correct? If so, a simple mutable list that stores training samples and acts as a replay buffer should let each training cycle use those stored samples plus the ones captured during that cycle, right?
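Two checks that may help interpret a loss stuck around 3 to 5. First, for categorical cross-entropy computed with the arguments in the right order, a uniform guess over K classes scores ln K, and a per-sample loss L means probability exp(-L) on the true class. Second, Keras loss objects are called as `loss_fn(y_true, y_pred)`; if the two are swapped, the reported numbers are no longer on that scale. The plain-Python sketch below does this arithmetic and mimics how `CategoricalCrossentropy` treats probability inputs (normalize `y_pred`, clip it to [1e-7, 1 - 1e-7], take the negative log-likelihood). It is an approximation of the TF 2.x Keras behavior, not the library code, and all function names are hypothetical.

```python
import math
from typing import Sequence

EPSILON = 1e-7  # Keras' default backend epsilon, used for clipping


def chance_cross_entropy(num_classes: int) -> float:
    """Cross-entropy of a uniform guess over num_classes classes: ln(K)."""
    return math.log(num_classes)


def prob_on_true_class(loss: float) -> float:
    """Per-sample cross-entropy is -ln(p_true), so p_true = exp(-loss)."""
    return math.exp(-loss)


def keras_style_cce(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Approximate Keras categorical cross-entropy on probabilities for one
    sample: normalize y_pred, clip to [eps, 1 - eps], -sum(y_true * log)."""
    total = sum(y_pred)
    clipped = [min(max(p / total, EPSILON), 1.0 - EPSILON) for p in y_pred]
    return -sum(t * math.log(p) for t, p in zip(y_true, clipped))


label = [0.0, 0.0, 1.0, 0.0]
prediction = [0.1, 0.2, 0.6, 0.1]
print(f"uniform guess, 4 classes: {chance_cross_entropy(4):.3f}")
print(f"loss 3.0 -> p_true {prob_on_true_class(3.0):.4f}")
print(f"(y_true, y_pred): {keras_style_cce(label, prediction):.3f}")
print(f"swapped:          {keras_style_cce(prediction, label):.3f}")
```

With 4 classes, for example, the uniform-guess baseline is ln 4 ≈ 1.39, so a loss of 3 to 5 would mean only about 0.05 to 0.007 probability on the true class (a geometric mean over the batch). With the arguments swapped, a prediction that puts 0.6 on the correct class reports about 6.45 instead of about 0.51.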

2: I also went through the Python script that creates the training signatures used by the interpreter in Android Studio (generate_training_model.py). Can I add hidden layers to the training signature, besides the single fully connected layer with softmax activation, so that those hidden layers are trained on device? (The idea is to copy the last few MobileNetV2 layers and add them again in the training function so they can be trained on the device.)

For reference, this is the function I am talking about (I guess __init__ would need to change too, right?):

  # Methods of the transfer-learning tf.Module class in generate_training_model.py;
  # tf, NUM_FEATURES, NUM_CLASSES and IMG_SIZE are defined earlier in that script.
  def __init__(self, learning_rate=0.001):
    """Initializes a transfer learning model instance.

    Args:
      learning_rate: A learning rate for the optimizer.
    """
    self.num_features = NUM_FEATURES
    self.num_classes = NUM_CLASSES

    # Trainable weights and bias for the softmax head.
    self.ws = tf.Variable(
        tf.zeros((self.num_features, self.num_classes)),
        name='ws',
        trainable=True)
    self.bs = tf.Variable(
        tf.zeros((1, self.num_classes)), name='bs', trainable=True)

    # Pretrained base model without its classifier top; train() does not update it.
    self.base = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet')
    # Loss function and optimizer.
    self.loss_fn = tf.keras.losses.CategoricalCrossentropy()
    self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

  @tf.function(input_signature=[
      tf.TensorSpec([None, NUM_FEATURES], tf.float32),
      tf.TensorSpec([None, NUM_CLASSES], tf.float32),
  ])
  def train(self, bottleneck, label):
    """Runs one training step with the given bottleneck features and labels.

    Args:
      bottleneck: A tensor of bottleneck features generated from the base model.
      label: A tensor of one-hot class labels for the given batch.

    Returns:
      Map of the training loss and the gradients.
    """
    with tf.GradientTape() as tape:
      logits = tf.matmul(bottleneck, self.ws) + self.bs
      prediction = tf.nn.softmax(logits)
      # Keras losses take (y_true, y_pred), so the label must come first.
      loss = self.loss_fn(label, prediction)
    gradients = tape.gradient(loss, [self.ws, self.bs])
    self.optimizer.apply_gradients(zip(gradients, [self.ws, self.bs]))
    result = {'loss': loss}
    for grad in gradients:
      result[grad.name] = grad
    return result
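On adding hidden layers to the head: each extra layer needs its own variables created in `__init__`, and the train step would have to pass all of them to `tape.gradient` and `apply_gradients`, not only `ws` and `bs`. The plain-Python sketch below shows the forward pass of such a two-layer head (dense, ReLU, dense, softmax) and how quickly the trainable parameter count grows. `NUM_FEATURES = 7 * 7 * 1280` (a flattened MobileNetV2 feature map for 224x224 input) and `NUM_CLASSES = 4` are assumptions, and the helper names are hypothetical.

```python
import math
from typing import List

NUM_FEATURES = 7 * 7 * 1280  # assumed flattened MobileNetV2 bottleneck size
NUM_CLASSES = 4              # assumed head width


def head_params(num_features: int, num_classes: int, hidden: int = 0) -> int:
    """Trainable parameters of a softmax head, optionally preceded by one
    hidden dense layer with `hidden` units."""
    if hidden == 0:
        return num_features * num_classes + num_classes
    return (num_features * hidden + hidden) + (hidden * num_classes + num_classes)


def dense(x: List[float], w: List[List[float]], b: List[float]) -> List[float]:
    # y[j] = sum_i x[i] * w[i][j] + b[j]
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) + b[j]
            for j in range(len(b))]


def forward(x: List[float], w1: List[List[float]], b1: List[float],
            w2: List[List[float]], b2: List[float]) -> List[float]:
    """bottleneck -> dense + ReLU -> dense -> softmax probabilities."""
    h = [max(0.0, v) for v in dense(x, w1, b1)]
    logits = dense(h, w2, b2)
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(head_params(NUM_FEATURES, NUM_CLASSES))
print(head_params(NUM_FEATURES, NUM_CLASSES, hidden=128))
```

Under those assumptions the single softmax layer has 250,884 trainable parameters, while inserting a 128-unit hidden layer raises that to 8,028,804, which matters for memory and speed on a phone. Retraining copies of MobileNetV2's last blocks would also require taking the bottleneck from an earlier layer of the base model, since the current bottleneck is the output of its final layer.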

Nikolas Stavrou

Mar 29, 2023, 4:31:23 PM
to TensorFlow Lite, Nikolas Stavrou, haol...@google.com, TensorFlow Lite
A follow-up question came up regarding the trainingSamples list, where samples are added with their bottlenecks and class names, as seen below:
processInputImage(image, rotation)?.let { tensorImage ->
    val bottleneck = loadBottleneck(tensorImage)
    val newSample = TrainingSample(

By debugging with Logcat, I realized that samples in the trainingSamples list are never removed after a training cycle ends. This means that if we add 2 samples, train, perform inference, and then add 10 more, trainingSamples will contain 12 samples instead of 10. Is this really the case? It contradicts what I observed in my previous reply, where the model did not correctly classify a 3rd class added in the 2nd training cycle. Shouldn't it classify everything correctly, since it retrains on all of the samples gathered in the app so far?

This confused me a lot, because the point of my replay buffer was to fix the issue above, yet the unmodified model already retrains on all of the available data and still makes wrong predictions.

Please correct me if I have misunderstood something.

Here is my Logcat output for reference; it shows that trainingSamples contains all the images gathered so far, not just those captured with the camera for the latest training.

2023-03-29 23:09:49.006 21374-21374 TrainBatch              org...examples.modelpersonalization  D  Training samples: 2
2023-03-29 23:09:49.006 21374-21374 TrainBatch              org...examples.modelpersonalization  D  Replay buffer: 0
2023-03-29 23:09:49.006 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Replay buffer size: 0
2023-03-29 23:09:49.006 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Training samples size: 2
2023-03-29 23:09:49.006 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Combined samples size: 2
2023-03-29 23:09:54.572 21374-21374 PauseTraining           org...examples.modelpersonalization  D  Updating replay buffer
2023-03-29 23:09:54.572 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Number of trainingSamples before updating replayBuffer are: 2
2023-03-29 23:09:54.572 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Adding 0 samples to replay buffer
2023-03-29 23:09:54.572 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Replay buffer size before removing extra samples is now: 0
2023-03-29 23:09:54.572 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Replay buffer size after removing extra samples is now: 0
2023-03-29 23:15:08.067 21374-21374 TrainBatch              org...examples.modelpersonalization  D  Training samples: 12
2023-03-29 23:15:08.068 21374-21374 TrainBatch              org...examples.modelpersonalization  D  Replay buffer: 0
2023-03-29 23:15:08.068 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Replay buffer size: 0
2023-03-29 23:15:08.068 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Training samples size: 12
2023-03-29 23:15:08.068 21374-21374 ReplayBuffer            org...examples.modelpersonalization  D  Combined samples size: 12
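The log above reports only totals. As a debugging aid, here is a plain-Python sketch of one way to keep the current cycle's samples and the replay buffer separate, shuffle them together into mini-batches, and compute per-class counts that can be logged each cycle, so it is visible whether earlier classes actually reach training. `training_batches`, `class_counts`, and the (bottleneck, class index) sample layout are hypothetical, not the demo's API.

```python
import random
from collections import Counter
from typing import Iterator, List, Sequence, Tuple

# Hypothetical sample layout: (bottleneck features, class index).
Sample = Tuple[Sequence[float], int]


def training_batches(new_samples: List[Sample], replay: List[Sample],
                     batch_size: int, seed: int = 0) -> Iterator[List[Sample]]:
    """Shuffle this cycle's samples together with replayed ones and yield
    mini-batches, so old classes appear alongside the new one."""
    combined = list(new_samples) + list(replay)
    random.Random(seed).shuffle(combined)
    for start in range(0, len(combined), batch_size):
        yield combined[start:start + batch_size]


def class_counts(samples: List[Sample]) -> Counter:
    """Per-class sample counts: a cheap sanity check to log every cycle."""
    return Counter(label for _, label in samples)


new = [([0.0], 2)] * 20                       # third class, this cycle
old = [([0.0], 0)] * 20 + [([0.0], 1)] * 20   # replayed earlier classes
for i, batch in enumerate(training_batches(new, old, batch_size=16)):
    print(i, len(batch), dict(class_counts(batch)))
```

If the per-class counts look right but the third class is still misclassified, the problem is more likely in the training step itself than in how samples are stored.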

Ashok Kumar

Apr 2, 2023, 11:40:49 AM4/2/23
to TensorFlow Lite, nikolas....@gmail.com, haol...@google.com, TensorFlow Lite
Hi Nikolas,

I hope you are doing well. I am also working on a similar project on on-device training and continual learning on mobile. Instead of a custom model trained from scratch, I used a pretrained model, and I made some changes to this Python code for digit classification. I would truly appreciate it if you and the TensorFlow team could take a look at the code and give me some feedback on whether it works correctly.
Thank you!