| // Performs forward-backward on the given trainingdata.
| // Source: https://github.com/tesseract-ocr/tesseract/blob/master/lstm/lstmtrainer.cpp |
|
| // Returns a Trainability enum to indicate the suitability of the sample. |
| Trainability LSTMTrainer::TrainOnLine(const ImageData* trainingdata, |
| bool batch) { |
| NetworkIO fwd_outputs, targets; |
| Trainability trainable = |
| PrepareForBackward(trainingdata, &fwd_outputs, &targets); |
| ++sample_iteration_; |
| if (trainable == UNENCODABLE || trainable == NOT_BOXED) { |
| return trainable; // Sample was unusable. |
| } |
| bool debug = debug_interval_ > 0 && |
| training_iteration() % debug_interval_ == 0; |
| // Run backprop on the output. |
| NetworkIO bp_deltas; |
| if (network_->IsTraining() && |
| (trainable != PERFECT || |
| training_iteration() > |
| last_perfect_training_iteration_ + perfect_delay_)) { |
| network_->Backward(debug, targets, &scratch_space_, &bp_deltas); |
| network_->Update(learning_rate_, batch ? -1.0f : momentum_, adam_beta_, |
| training_iteration_ + 1); |
| } |
| #ifndef GRAPHICS_DISABLED |
| if (debug_interval_ == 1 && debug_win_ != nullptr) { |
| delete debug_win_->AwaitEvent(SVET_CLICK); |
| } |
| #endif // GRAPHICS_DISABLED |
| // Roll the memory of past means. |
| RollErrorBuffers(); |
| return trainable; |
| } |
|
|
| // Prepares the ground truth, runs forward, and prepares the targets. |
| // Returns a Trainability enum to indicate the suitability of the sample. |
| // Returns UNENCODABLE for any sample that cannot be trained on: null data, |
| // unencodable or blank transcription, failed recognition, or failed target |
| // computation. |
| Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata, |
| NetworkIO* fwd_outputs, |
| NetworkIO* targets) { |
| if (trainingdata == nullptr) { |
| tprintf("Null trainingdata.\n"); |
| return UNENCODABLE; |
| } |
| // Ensure repeatability of random elements even across checkpoints. |
| bool debug = debug_interval_ > 0 && |
| training_iteration() % debug_interval_ == 0; |
| // Encode the transcription text into network label indices. |
| GenericVector<int> truth_labels; |
| if (!EncodeString(trainingdata->transcription(), &truth_labels)) { |
| tprintf("Can't encode transcription: '%s' in language '%s'\n", |
| trainingdata->transcription().string(), |
| trainingdata->language().string()); |
| return UNENCODABLE; |
| } |
| bool upside_down = false; |
| if (randomly_rotate_) { |
| // This ensures consistent training results. |
| SetRandomSeed(); |
| // Flip the image upside-down with probability 1/2. |
| upside_down = randomizer_.SignedRand(1.0) > 0.0; |
| if (upside_down) { |
| // Modify the truth labels to match the rotation: |
| // Apart from space and null, increment the label. This changes the |
| // script-id to the same script-id but upside-down. |
| // The labels need to be reversed in order, as the first is now the last. |
| for (int c = 0; c < truth_labels.size(); ++c) { |
| if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_) |
| ++truth_labels[c]; |
| } |
| truth_labels.reverse(); |
| } |
| } |
| // Skip leading spaces/nulls; if nothing else remains, the transcription |
| // is effectively blank and the sample is unusable. |
| int w = 0; |
| while (w < truth_labels.size() && |
| (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_)) |
| ++w; |
| if (w == truth_labels.size()) { |
| tprintf("Blank transcription: %s\n", |
| trainingdata->transcription().string()); |
| return UNENCODABLE; |
| } |
| float image_scale; |
| NetworkIO inputs; |
| // With no ground-truth boxes, enable inversion handling in recognition. |
| // NOTE(review): invert is passed for two parameters of RecognizeLine; |
| // confirm against the RecognizeLine signature which is which. |
| bool invert = trainingdata->boxes().empty(); |
| if (!RecognizeLine(*trainingdata, invert, debug, invert, upside_down, |
| &image_scale, &inputs, fwd_outputs)) { |
| tprintf("Image not trainable\n"); |
| return UNENCODABLE; |
| } |
| // Size the target buffer to match the forward outputs. |
| targets->Resize(*fwd_outputs, network_->NumOutputs()); |
| // Build targets according to the network's output loss type. |
| LossType loss_type = OutputLossType(); |
| if (loss_type == LT_SOFTMAX) { |
| if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) { |
| tprintf("Compute simple targets failed!\n"); |
| return UNENCODABLE; |
| } |
| } else if (loss_type == LT_CTC) { |
| if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) { |
| tprintf("Compute CTC targets failed!\n"); |
| return UNENCODABLE; |
| } |
| } else { |
| tprintf("Logistic outputs not implemented yet!\n"); |
| return UNENCODABLE; |