Source code for lenet.trainer

import tensorflow as tf
from global_definitions import DROPOUT_PROBABILITY

[docs]class trainer(object): """ Trainer for networks Args: network: A network class object dataset: A tensorflow dataset object Attributes: network: This is the network we initialized with. We pass this as an argument and we add it to the current trainer class. dataset: This is also the initializer. It comes from the :class:`lenet.dataset.mnist` module. session: This is a session created with trainer. This session is used for training. tensorboard: Is a summary writer tool. This writes things into the tensorboard that is then setup on the tensorboard server. At the end of the trainer, it closes this tensorboard. """ def __init__ (self, network, dataset): """ Class constructor """ self.network = network self.dataset = dataset self.session = tf.InteractiveSession() tf.global_variables_initializer().run() self.summaries()
[docs] def bp_step(self, mini_batch_size): """ Sample a minibatch of data and run one step of BP. Args: mini_batch_size: Integer Returns: tuple of tensors: total objective and cost of that step """ x, y = self.dataset.train.next_batch(mini_batch_size) _, obj, cost = self.session.run( fetches = [self.network.back_prop, self.network.obj, self.network.cost], \ feed_dict = {self.network.images:x, self.network.labels:y, \ self.network.dropout_prob: DROPOUT_PROBABILITY}) return (obj, cost)
[docs] def accuracy (self, images, labels): """ Return accuracy Args: images: images labels: labels Returns: float: accuracy """ acc = self.session.run(self.network.accuracy,\ feed_dict = { self.network.images: images, self.network.labels: labels, self.network.dropout_prob: 1.0} ) return acc
[docs] def summaries(self, name = "tensorboard"): """ Just creates a summary merge bufer Args: name: a name for the tensorboard directory """ self.summary = tf.summary.merge_all() self.tensorboard = tf.summary.FileWriter("tensorboard") self.tensorboard.add_graph(self.session.graph)
[docs] def test (self): """ Run validation of the model Returns: float: accuracy """ x = self.dataset.test.images y = self.dataset.test.labels acc = self.accuracy (images =x, labels = y) return acc
[docs] def training_accuracy (self, mini_batch_size = 500): """ Run validation of the model on training set Args: mini_batch_size: Number of samples in a mini batch Returns: float: accuracy """ x, y = self.dataset.train.next_batch(mini_batch_size) acc = self.accuracy (images =x, labels = y) return acc
[docs] def write_summary (self, iter = 0, mini_batch_size = 500): """ This method updates the summaries Args: iter: iteration number to index values with. mini_batch_size: Mini batch to evaluate on. """ x = self.dataset.test.images y = self.dataset.test.labels s = self.session.run(self.summary, feed_dict = {self.network.images: x, self.network.labels: y, self.network.dropout_prob: 1.0}) self.tensorboard.add_summary(s, iter)
[docs] def train ( self, iter= 10000, mini_batch_size = 500, update_after_iter = 1000, training_accuracy = False, summarize = True): """ Run backprop for ``iter`` iterations Args: iter: number of iterations to run mini_batch_size: Size of the mini batch to process with training_accuracy: if ``True``, will calculate accuracy on training data also. update_after_iter: This is the iteration for validation summarize: Tensorboard operation """ for it in range(iter): obj, cost = self.bp_step(mini_batch_size) if it % update_after_iter == 0: train_acc = self.training_accuracy(mini_batch_size = 50000) acc = self.test() print( " Iter " + str(it) + " Objective " + str(obj) + " Cost " + str(cost) + " Test Accuracy " + str(acc) + " Training Accuracy " + str(train_acc) ) if summarize is True: self.write_summary(iter = it, mini_batch_size = mini_batch_size) acc = self.test() print ("Final Test Accuracy: " + str(acc))