Cassava Leaf Disease Classification: Part II
Note that all of the code referenced in this blog post can be found at https://www.kaggle.com/bcotler/cassava-transfer-v1-data2040-lobster-rolls
or https://colab.research.google.com/drive/19O85X_5AD_S5hifse_MpOyrCDjvKz_Ry#scrollTo=EBV8jCca1bI-
This blog post is intended as a follow-up discussion of the content presented in our last blog post. If you have not yet read that post, we encourage you to read through it before continuing (https://alex-zimbalist.medium.com/cassava-leaf-disease-classification-first-steps-9bc6a6478ec6).
To recap, we began our attempt to classify cassava leaf disease from labeled images by creating a baseline model that predicted every image had the most common disease: Cassava Mosaic Disease. This naive model made correct predictions about 61% of the time. Our goal was to create a neural network that would perform better than this baseline. In our last blog post, we constructed a neural network by hand, somewhat arbitrarily choosing the number of convolutional and dense layers, when and where to use techniques such as kernel regularization, batch normalization, dropout, pooling, and padding, and what values to choose for hyperparameters such as learning rate and batch size. This yielded a convolutional neural network (CNN) that performed about the same as the baseline that predicted Cassava Mosaic Disease every time. Clearly, we need to be a bit more scientific in designing a network architecture. Fortunately, others have trained carefully designed CNN architectures on massive numbers of images, and we can adapt these models to our goal of identifying cassava leaf disease. This process, known as fine-tuning, will be the main focus of this blog post.
Before we delve too far into our fine-tuning efforts, let’s revisit something our last blog post largely skimmed over: image data augmentation. The images we are training on are presumably high resolution, properly oriented, centered, and not smudged or otherwise compromised. However, it is possible that our model will be asked to make predictions on less-perfect images. Even if it is not, training on distorted images will hopefully allow our network to better learn the features of images that are truly telling of the proper disease classification.
For this project, we used TFRecords to load, preprocess, and augment the image data before training CNNs. Much of this code is adapted from the following Keras example: https://keras.io/examples/keras_recipes/tfrecord/
The image above shows a code snippet of the get_dataset function, which takes a list of TFRecord files as input. TFRecord files store the image data as a sequence of binary strings. To use the information contained in these files, get_dataset calls load_dataset, which in turn uses helper functions to decode and parse each record into image data and label data. For the sake of brevity, we won't walk through all of the code here, but it is well explained in the Keras documentation referenced above. After the dataset is loaded, get_dataset shuffles the data, separates it into batches of a given size, prefetches, and returns the dataset. We can call this function on the training, validation, and test files to form the respective datasets.
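For reference, a minimal sketch of such a pipeline is shown below, adapted from the Keras recipe linked above; the feature names ("image" and "target"), image size, and batch size are assumptions for illustration and may differ from our actual code.

```python
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE
IMAGE_SIZE = [512, 512]   # assumed image size
BATCH_SIZE = 32           # assumed batch size

def decode_image(image_bytes):
    # Decode the raw JPEG bytes and resize to a fixed shape.
    image = tf.image.decode_jpeg(image_bytes, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)
    return tf.cast(image, tf.float32)

def read_tfrecord(example):
    # Each serialized example stores the encoded image and its integer label.
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_description)
    image = decode_image(parsed["image"])
    label = tf.cast(parsed["target"], tf.int32)
    return image, label

def load_dataset(filenames):
    # Read the binary TFRecord files and parse each example in parallel.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    return dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

def get_dataset(filenames):
    # Shuffle, batch, and prefetch so the GPU is not left waiting on I/O.
    dataset = load_dataset(filenames)
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset
```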
The next step is data augmentation. Keras provides several built-in methods to preprocess and augment image data. In the code snippet above, we used tf.keras.layers.experimental.preprocessing to apply horizontal and vertical flips, a rotation, a zoom, and horizontal and vertical translations. The parameter values define the interval within which each random transformation is drawn. For example, a RandomZoom parameter of 0.3 means that the layer will zoom in or out by a random amount between -30% and +30%.
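A sketch of that augmentation block might look like the following; aside from the 0.3 zoom factor mentioned above, the factors shown here are placeholders rather than the exact values we used.

```python
from tensorflow import keras
from tensorflow.keras.layers.experimental import preprocessing

img_augmentation = keras.Sequential(
    [
        preprocessing.RandomFlip("horizontal_and_vertical"),   # random flips
        preprocessing.RandomRotation(factor=0.15),             # rotation (placeholder factor)
        preprocessing.RandomZoom(height_factor=0.3),           # zoom between -30% and +30%
        preprocessing.RandomTranslation(height_factor=0.1,     # translations (placeholder factors)
                                        width_factor=0.1),
    ],
    name="img_augmentation",
)
```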
We tried fine-tuning with 3 different pre-trained models: VGG16, ResNet101, and EfficientNetB0. For the sake of brevity, we will demonstrate how we fine-tuned EfficientNetB0 (hereafter EfficientNet), and only report some of the parameters and results for our attempts at fine-tuning with other base models. Before proceeding, we also want to credit https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/ for helping us fine-tune our own models.
The first step in fine-tuning is to initialize the model using EfficientNet and freeze all of its layers except the last few (after a bit of trial and error, we chose to unfreeze the last 10 layers of EfficientNet). This looks something like this:
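In code, a sketch of this setup might be the following (the input shape is an assumption):

```python
from tensorflow.keras.applications import EfficientNetB0

# Load EfficientNetB0 with ImageNet weights and without its classification head.
base_model = EfficientNetB0(
    include_top=False,
    weights="imagenet",
    input_shape=(512, 512, 3),  # assumed input size
)

# Freeze every layer except the last 10.
for layer in base_model.layers[:-10]:
    layer.trainable = False
for layer in base_model.layers[-10:]:
    layer.trainable = True
```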
Next, we want to add a few of our own dense layers on top of the base model in order to give the network the chance to adapt to the specific types of images we are trying to classify. After a little trial and error, we decided to add 3 dense layers: one with 128 nodes, one with 64 nodes, and, finally, one with 5 nodes (for making the final predictions). We also added batch normalization and dropout layers, as well as l2 regularization on the dense layers, to combat the overfitting we observed without these measures. The full code specifying the network architecture looks like this:
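A sketch of that architecture is below, reusing the augmentation block and partially unfrozen base model from the earlier sketches; the dropout rates, l2 factor, and global pooling layer are assumptions on top of what is described above.

```python
from tensorflow import keras
from tensorflow.keras import layers, regularizers

inputs = keras.Input(shape=(512, 512, 3))
x = img_augmentation(inputs)            # augmentation block from earlier
x = base_model(x)                       # EfficientNetB0 with last 10 layers unfrozen
x = layers.GlobalAveragePooling2D()(x)  # collapse feature maps to a vector
x = layers.BatchNormalization()(x)
x = layers.Dense(128, activation="relu",
                 kernel_regularizer=regularizers.l2(0.01))(x)  # l2 factor is a placeholder
x = layers.Dropout(0.3)(x)              # dropout rate is a placeholder
x = layers.BatchNormalization()(x)
x = layers.Dense(64, activation="relu",
                 kernel_regularizer=regularizers.l2(0.01))(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(5, activation="softmax")(x)  # one node per disease class

model = keras.Model(inputs, outputs, name="cassava_efficientnet")
```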
We compile using the Adam optimizer with a learning rate of 0.0001 and fit the model for up to 100 epochs, or until early stopping causes the training process to terminate early:
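Something like the following; the loss function, early-stopping settings, and dataset variable names (whatever get_dataset returned for the training and validation files) are assumptions.

```python
from tensorflow import keras

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss="sparse_categorical_crossentropy",  # integer labels from the TFRecords
    metrics=["accuracy"],
)

early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",           # assumed monitored metric
    patience=5,                   # assumed patience
    restore_best_weights=True,
)

history = model.fit(
    train_dataset,                # datasets built with get_dataset above
    validation_data=valid_dataset,
    epochs=100,
    callbacks=[early_stopping],
)
```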
After 10 epochs, this model achieved a training accuracy of 82% and a validation accuracy of 80%. We can visualize the model's performance as follows (in the iteration plotted here, the model actually performed slightly worse, only reaching 78% validation accuracy and appearing perhaps slightly overfit):
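Curves like these can be generated from the History object returned by model.fit; a minimal sketch:

```python
import matplotlib.pyplot as plt

# Plot training and validation accuracy per epoch from the fit history.
plt.plot(history.history["accuracy"], label="training accuracy")
plt.plot(history.history["val_accuracy"], label="validation accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
```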
Using ResNet101, we achieved a training accuracy of 78.4% and a validation accuracy of 79.4% after 10 epochs. For this model, we added dense layers with 256, 128, and 5 neurons on top of the base model. We used 2 batch normalization layers and only one dropout layer, with a dropout rate of 0.1. This model excluded kernel regularization and did not unfreeze any of the layers in the base model. Otherwise, the parameters were very similar to those shown above for the model built on top of EfficientNet.
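For comparison, a sketch of how that ResNet101 variant might be assembled; the input size and pooling choice are assumptions, and we fold ResNet's own preprocessing into the graph here.

```python
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet101
from tensorflow.keras.applications.resnet import preprocess_input

base_resnet = ResNet101(include_top=False, weights="imagenet",
                        input_shape=(512, 512, 3))
base_resnet.trainable = False  # no layers unfrozen for this variant

inputs = keras.Input(shape=(512, 512, 3))
x = preprocess_input(inputs)                # ResNet-specific preprocessing
x = base_resnet(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.BatchNormalization()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.1)(x)
outputs = layers.Dense(5, activation="softmax")(x)

resnet_model = keras.Model(inputs, outputs, name="cassava_resnet101")
```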
Using VGG16, our results were a little worse. Trying various parameters, we have so far not been able to get the validation accuracy above 75%. Further, models built on top of VGG16 seem slightly more prone to overfitting, which has made them more challenging to tweak. Therefore, we will probably focus most of our future efforts on continuing to tweak the models built on top of ResNet and EfficientNet.
Next Steps:
We have successfully created several models that far outshine both our baseline model and our homemade CNN, which both sat at just barely over 62% validation accuracy. However, we believe there are still tweaks to be made that will improve model performance. Should we use more or less dropout, fewer or more batch normalization layers, a smaller or larger learning rate? Should we unfreeze fewer or more layers of the base models (EfficientNet, ResNet, and VGG)? Perhaps we need to add more dense layers on top of the base model? Our image augmentation efforts also have room to grow. Before our next blog post, we will tweak our image augmentation code as well as the model architectures themselves. Stay tuned for our final blog post, in which we will present our final results!