Model training is implemented in the fit(...) method, which takes the following parameters:
- train_X: array_like, shape (n_samples, n_features). Training data.
- train_Y: array_like, shape (n_samples, n_classes). Training labels.
- val_X: array_like, shape (N, n_features), optional (default = None). Validation data, where N is the number of validation samples.
- val_Y: array_like, shape (N, n_classes), optional (default = None). Validation labels.
- graph: tf.Graph, optional (default = None). TensorFlow Graph object.
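The label shape (n_samples, n_classes) implies one row per sample and one column per class, i.e. one-hot encoded labels. If your labels start out as integer class indices, a minimal conversion could look like the sketch below. The helper name to_one_hot is hypothetical and not part of the model's code; it uses plain Python lists, which also count as array_like.

```python
from typing import List, Sequence


def to_one_hot(labels: Sequence[int], n_classes: int) -> List[List[float]]:
    """Hypothetical helper: turn integer labels into rows of shape (n_samples, n_classes)."""
    one_hot = []
    for label in labels:
        if not 0 <= label < n_classes:
            raise ValueError(f"label {label} is outside [0, {n_classes})")
        row = [0.0] * n_classes
        row[label] = 1.0  # mark the true class with a 1 in its column
        one_hot.append(row)
    return one_hot


train_Y = to_one_hot([2, 0, 1], n_classes=3)
print(train_Y)  # [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```

With a NumPy array, the same result is commonly obtained by indexing an identity matrix with the label vector.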
Next, we look at the implementation of the fit(...) method, in which the model is trained and then saved to the path specified by model_path.
def fit(self, train_X, train_Y, val_X=None, val_Y=None, graph=None):
    if len(train_Y.shape) != 1:
        num_classes = train_Y.shape[1]
    else:
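The first branch of fit(...) reads the number of classes from the second dimension of the label array. The sketch below isolates that step as a standalone function operating on a shape tuple. The function name infer_num_classes is hypothetical, and so is the 1-D handling: the source's else branch is cut off, so raising a ValueError there is only an assumption, not the author's code.

```python
from typing import Tuple


def infer_num_classes(label_shape: Tuple[int, ...]) -> int:
    """Hypothetical helper mirroring the visible branch of fit()."""
    if len(label_shape) >= 2:
        # Labels laid out as (n_samples, n_classes): one column per class.
        return label_shape[1]
    # Assumed behavior for 1-D (or 0-D) labels; the original else branch is not shown.
    raise ValueError(
        "expected labels of shape (n_samples, n_classes); one-hot encode them first"
    )


print(infer_num_classes((60000, 10)))  # 10
```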