Model training

Model training is implemented in the fit(...) method, which takes the following parameters:

  • train_X: array_like, shape (n_samples, n_features). Training data.
  • train_Y: array_like, shape (n_samples, n_classes). Training labels.
  • val_X: array_like, shape (N, n_features), optional (default = None). Validation data, where N is the number of validation samples.
  • val_Y: array_like, shape (N, n_classes), optional (default = None). Validation labels.
  • graph: tf.Graph, optional (default = None). TensorFlow Graph object.
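The shapes above imply that labels are passed as a 2-D matrix with one column per class (one-hot encoding) and that validation arrays share the training arrays' feature and class dimensions. The following is a minimal NumPy sketch of that contract, assuming integer class labels as the starting point; to_one_hot and check_fit_inputs are hypothetical helpers for illustration, not part of the original code.

```python
import numpy as np


def to_one_hot(labels, n_classes):
    """Hypothetical helper: turn integer labels of shape (n_samples,)
    into a one-hot matrix of shape (n_samples, n_classes)."""
    labels = np.asarray(labels, dtype=int)
    one_hot = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
    # Set a single 1.0 per row, in the column of that row's class.
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot


def check_fit_inputs(train_X, train_Y, val_X=None, val_Y=None):
    """Hypothetical helper: validate array shapes against the fit(...)
    parameter list and return (n_features, n_classes)."""
    train_X, train_Y = np.asarray(train_X), np.asarray(train_Y)
    if train_X.ndim != 2 or train_Y.ndim != 2:
        raise ValueError("train_X and train_Y must both be 2-D arrays")
    if train_X.shape[0] != train_Y.shape[0]:
        raise ValueError("train_X and train_Y must have the same n_samples")
    # Validation data is optional, but X and Y must be supplied together.
    if (val_X is None) != (val_Y is None):
        raise ValueError("pass both val_X and val_Y, or neither")
    if val_X is not None:
        val_X, val_Y = np.asarray(val_X), np.asarray(val_Y)
        if val_X.ndim != 2 or val_Y.ndim != 2:
            raise ValueError("val_X and val_Y must both be 2-D arrays")
        if val_X.shape[0] != val_Y.shape[0]:
            raise ValueError("val_X and val_Y must have the same N")
        if val_X.shape[1] != train_X.shape[1] or val_Y.shape[1] != train_Y.shape[1]:
            raise ValueError("validation arrays must match n_features and n_classes")
    return train_X.shape[1], train_Y.shape[1]


X = np.random.rand(6, 4)                         # 6 samples, 4 features
Y = to_one_hot([0, 2, 1, 0, 1, 2], n_classes=3)  # 6 samples, 3 classes
print(Y.shape)                 # (6, 3)
print(check_fit_inputs(X, Y))  # (4, 3)
```

Checking shapes up front like this surfaces a mismatched label encoding before any graph is built, rather than partway through training.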

Next, we look at the implementation of the fit(...) function, in which the model is trained and then saved to the path specified by model_path.

    def fit(self, train_X, train_Y, val_X=None, val_Y=None, graph=None):
        if len(train_Y.shape) != 1:
            num_classes = train_Y.shape[1]
        else:
            ...  # the excerpt is truncated at this point
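The visible part of fit(...) derives num_classes from the second dimension of 2-D labels. Because the rest of the listing is cut off, the sketch below only illustrates, under stated assumptions, the overall flow the text describes (infer dimensions, build, train, save to model_path). It is framework-free: the graph is a plain placeholder value rather than a tf.Graph, the class SupervisedModelSketch and its methods _build_model, _train_model and _save are hypothetical stubs, and rejecting 1-D labels is this sketch's own choice, not necessarily what the original else branch does.

```python
import numpy as np


class SupervisedModelSketch:
    """Hypothetical outline of the fit(...) control flow. The stub methods
    only record what they would do; they are not TensorFlow APIs."""

    def __init__(self, model_path, default_graph="default-graph"):
        self.model_path = model_path
        self.default_graph = default_graph  # placeholder standing in for a tf.Graph
        self.log = []

    def fit(self, train_X, train_Y, val_X=None, val_Y=None, graph=None):
        # As in the listing: 2-D (one-hot) labels give the number of classes.
        if len(train_Y.shape) != 1:
            num_classes = train_Y.shape[1]
        else:
            # Assumption: the original branch is not shown; this sketch rejects 1-D labels.
            raise ValueError("expected labels of shape (n_samples, n_classes)")
        # Assumption: fall back to a default graph when none is supplied.
        g = graph if graph is not None else self.default_graph
        self._build_model(g, train_X.shape[1], num_classes)
        self._train_model(train_X, train_Y, val_X, val_Y)
        self._save(self.model_path)
        return self

    def _build_model(self, graph, n_features, n_classes):
        self.log.append(("build", graph, n_features, n_classes))

    def _train_model(self, train_X, train_Y, val_X, val_Y):
        self.log.append(("train", val_X is not None))

    def _save(self, path):
        self.log.append(("save", path))


model = SupervisedModelSketch(model_path="models/demo")  # hypothetical path
model.fit(np.zeros((8, 4)), np.zeros((8, 3)))
print(model.log)
# [('build', 'default-graph', 4, 3), ('train', False), ('save', 'models/demo')]
```

Keeping the shape inference inside fit(...) means the caller never passes the class count explicitly; it is read from the labels each time the model is trained.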
