Step 8 - Training the LSTM network

Now we will start training the LSTM network. Before getting started, let's define some buffers that record the loss and accuracy on both the training and test sets as training progresses:

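import scala.collection.mutable.ArrayBuffer

// Loss and accuracy history on the test set and on the training set
val testLosses = ArrayBuffer[Float]()
val testAccuracies = ArrayBuffer[Float]()
val trainLosses = ArrayBuffer[Float]()
val trainAccuracies = ArrayBuffer[Float]()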
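As a rough illustration of how such buffers are typically used, the following sketch pairs a loss buffer with an accuracy buffer and summarizes them. The class MetricHistory and its methods are hypothetical and not part of the book's code; the sketch assumes one (loss, accuracy) pair is appended per evaluation.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical helper: keeps loss and accuracy values in parallel buffers,
// one entry per evaluation (an assumption, not the book's code).
final class MetricHistory {
  val losses = ArrayBuffer[Float]()
  val accuracies = ArrayBuffer[Float]()

  // Append the metrics measured at one evaluation point
  def record(loss: Float, accuracy: Float): Unit = {
    losses += loss
    accuracies += accuracy
  }

  // Highest accuracy recorded so far, or None if nothing was recorded yet
  def bestAccuracy: Option[Float] =
    accuracies.reduceOption((a, b) => math.max(a, b))

  // Most recently recorded loss, or None if nothing was recorded yet
  def lastLoss: Option[Float] = losses.lastOption
}
```

Keeping losses and accuracies in parallel buffers makes it straightforward to plot both curves against the step number once training has finished.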

Then, we run the training loop. Each iteration takes the next mini-batch of batchSize samples, and the loop stops once step * batchSize exceeds trainingIters:

var step = 1
while (step * batchSize <= trainingIters) {
    // Select the next mini-batch; the start index wraps around the training set
    val (batchTrainData, batchTrainLabel) = {
        val idx = ((step - 1) * batchSize) % trainingDataCount
        if (idx + batchSize <= trainingDataCount) {
            // The whole batch fits before the end of the training set
            val datas = trainData.drop(idx).take(batchSize)
            val labels = trainLabels.drop(idx).take(batchSize)
            (datas, ...
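The batch selection above is cut off before its else branch. The following self-contained sketch illustrates the same cyclic indexing under stated assumptions: batchStart, nextBatch, and numSteps are hypothetical helpers, and the wrap-around branch (the tail of the dataset followed by samples from its head) is one plausible way to handle a batch that runs past the end, not necessarily what the original code does.

```scala
import scala.reflect.ClassTag

object Batching {
  // Start index of the mini-batch for a 1-based step, wrapped around the
  // dataset; mirrors the `idx` computation in the loop above.
  def batchStart(step: Int, batchSize: Int, dataCount: Int): Int =
    ((step - 1) * batchSize) % dataCount

  // Hypothetical helper: returns batchSize consecutive samples for `step`.
  // If the batch runs past the end, it continues from the beginning
  // (an assumed wrap-around strategy).
  def nextBatch[X: ClassTag, Y: ClassTag](
      data: Array[X], labels: Array[Y],
      step: Int, batchSize: Int): (Array[X], Array[Y]) = {
    require(data.length == labels.length, "data and labels must align")
    require(batchSize <= data.length, "batch larger than dataset")
    val n = data.length
    val idx = batchStart(step, batchSize, n)
    if (idx + batchSize <= n) {
      // The whole batch fits before the end of the dataset
      (data.slice(idx, idx + batchSize), labels.slice(idx, idx + batchSize))
    } else {
      // Tail of the dataset plus enough samples from the head
      val fromHead = idx + batchSize - n
      (data.drop(idx) ++ data.take(fromHead),
       labels.drop(idx) ++ labels.take(fromHead))
    }
  }

  // Iterations run by `while (step * batchSize <= trainingIters)` starting at step = 1
  def numSteps(batchSize: Int, trainingIters: Int): Int =
    trainingIters / batchSize
}
```

For example, with 5 samples and batchSize = 2, step 3 starts at index 4 and returns samples 4 and 0.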
