Understanding Regression, Dropout and Overfitting through a Traffic Jam Predictor Using Data from Waze
As part of my focus, a months-long process of learning about machine learning and AI, I replicated (and tried to improve) a traffic prediction algorithm made by Shuyi Wang, whose code won First Prize at HackNTX 2018. You can find the link to the original project here.
Understanding Libraries and Notebooks was My First Priority.
I had already learned most machine learning terms, and had overcome the intimidation of big words like “unsupervised learning,” “discriminative model” and “convolutional neural network.” Feeling that I had done enough prep, I set out to find a project that would match my novice skill level.
So I did a quick object detection algorithm that could recognize faces in a picture. This would have been great to talk about, but unfortunately my computer broke, and the project had not been saved by the time it was fixed and returned to me. Stuck on a school-issued Dell Chromebook that did not permit downloads or non-school accounts, I had to find another way to code.
I stumbled upon Google Colab, which let me run code and use libraries through an account in the browser, rather than through a local IDE like PyCharm. Around the same time, I found a perfect project to replicate.
I worked my way through Shuyi’s code, figuring out the different cells and their function as I went along. Through this, I learned the value of libraries.
To all you experienced data scientists and programmers out there, this may sound pretty obvious. But to my coding beginners out there, this one’s for you: libraries make the code. In my project, I used TensorFlow, Keras, NumPy, pandas, and everybody’s favorite, Matplotlib. I also imported pathlib and pickle.
However, as I was going through the code, I felt like I was blindly copying another project instead of taking the opportunity to learn from someone else’s code. One way I overcame this feeling was by familiarizing myself with each library, specifically Keras and pandas. This let me learn from the experience, and helped me move toward my overarching goal of creating my own project. I would recommend that other beginning programmers do this as well: really dig into your projects and take them apart to understand the inner workings.
Stagnation Shows Signs of Overfitting.
The coding was fairly easy, until I got to the “recurrent” part of the Recurrent Neural Network (RNN). While the accuracy in the original model improved from this:
My code stagnated, remaining like this even after adding recurrent layers:
As you can see, the training and validation points are far apart from each other, and the two never really come together.
So of course, I was confused. How was it that I had followed the original trafficJam predictor to a T, taking the time to understand each line of code, and yet it still wasn’t working? If our code was the same, why weren’t the results?
The answer to this question is overfitting. I didn’t even realize what was going on at first, because my accuracy and loss curves were improving. And then they would just stop. I finally recognized the signs of overfitting: a modeling error in which a function is fit too closely to a limited set of data points. It can be identified by checking the validation metrics, specifically accuracy and loss, against the training metrics. The validation metrics improve at first (accuracy increases, loss decreases), then stagnate or start getting worse, even while the training metrics keep improving. That is a sign the model fails to fit additional data or predict future observations: what it has learned corresponds only to a particular set of data, so it cannot generalize outside the training data. Think of the model as a kid in math class. They learn that 2+2 = 4 and that 3+3 = 6. But as an overfitted model, the kid has memorized those answers rather than learned addition, so they cannot figure out that 2+3 = 5.
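Here is a minimal sketch of how that check could look in code. It is not from the original project: the function name `diagnose_overfitting`, the `patience` threshold, and the loss values are all made up for illustration. In Keras, the per-epoch lists would come from the `History` object returned by `model.fit` (`history.history["loss"]` and `history.history["val_loss"]` when validation data is provided).

```python
def diagnose_overfitting(train_loss, val_loss, patience=3):
    """Flag overfitting from per-epoch loss lists.

    Returns (best_epoch, overfitting), where best_epoch is the 0-based epoch
    with the lowest validation loss. Overfitting is flagged when validation
    loss has not improved for at least `patience` epochs while training loss
    has kept falling. `patience` is a hypothetical threshold, not a standard.
    """
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
    epochs_since_best = len(val_loss) - 1 - best_epoch
    train_still_improving = train_loss[-1] < train_loss[best_epoch]
    overfitting = epochs_since_best >= patience and train_still_improving
    return best_epoch, overfitting


# Invented loss curves, for illustration only (not results from trafficJam).
train_loss = [0.69, 0.60, 0.52, 0.45, 0.39, 0.34, 0.30]
val_loss = [0.70, 0.63, 0.58, 0.57, 0.59, 0.62, 0.66]

best_epoch, overfitting = diagnose_overfitting(train_loss, val_loss)
print(best_epoch, overfitting)
```

With these made-up curves, validation loss bottoms out at epoch 3 and then climbs while training loss keeps dropping, which is the pattern described above.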
Overfitting Can Occur as a Result of Low Dropout.
If overfitting means that the model is technically learning “too much” from the training data, then the logical solution, even without coding knowledge, is to remove some of it. Keep the model adaptable, so to speak. The model is learning one specific scenario way too well, so if we can keep it on its toes, it stays able to generalize to data other than the training set. This can be done through dropout regularization. Dropout refers to ignoring a randomly chosen subset of neurons during each training step.
A dropout rate of 0.2 was too low for my model, even though it worked in the original program. Dropout sets a random fraction of activations (here, 20%) to 0 on each training step, so the network cannot lean on a few dominant neurons, and the weaker ones that had been ignored get a chance to shine.
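To make the mechanics concrete, here is a small NumPy sketch of what a dropout layer does to a batch of activations during training. This is my own illustration, not code from the original project; the function name `dropout`, the array shape, and the seed are arbitrary choices. It uses the "inverted dropout" convention, which Keras also follows: surviving activations are scaled up by 1 / (1 - rate) so their expected value stays the same.

```python
import numpy as np


def dropout(activations, rate, rng):
    """Zero each activation with probability `rate`, then rescale survivors."""
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    # True where the unit is kept, False where it is dropped.
    keep_mask = rng.random(activations.shape) >= rate
    # Rescale so the expected activation is unchanged (inverted dropout).
    return activations * keep_mask / (1.0 - rate)


rng = np.random.default_rng(seed=0)  # arbitrary seed, only for repeatability
activations = np.ones((1000, 64))    # hypothetical batch of activations

dropped = dropout(activations, rate=0.4, rng=rng)
zero_fraction = float(np.mean(dropped == 0.0))

print(f"fraction zeroed: {zero_fraction:.3f}")
print(f"mean after rescaling: {dropped.mean():.3f}")
```

The fraction zeroed should land close to the chosen rate, and the mean should stay close to 1, though the exact numbers depend on the random mask.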
Failure is the life of a programmer.
A rate of 0.5, where half of the activations are set to 0, is commonly treated as the point of maximum regularization for dropout. Past that, accuracy started declining for me. The magic number for my model, the one with the most consistent results, was 0.4. Although 0.5 might have worked if I had run trafficJam for longer, it was too much regularization for my model. Remember that dropout is random in the units it chooses to ignore, so results vary from run to run and finding the right rate is mostly trial and error.
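If you want to run this kind of trial and error yourself, one way is to make the dropout rate a parameter of the model and build one model per candidate rate. The sketch below is an assumption-heavy stand-in, not Shuyi Wang’s architecture: the LSTM layer, the layer sizes, the `n_features` value, and the name `build_model` are all hypothetical, and the training data placeholders are left commented out.

```python
import tensorflow as tf


def build_model(dropout_rate, n_features=8, n_units=32):
    """Small recurrent binary classifier; all sizes here are hypothetical."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(None, n_features)),  # variable-length sequences
        tf.keras.layers.LSTM(n_units),
        tf.keras.layers.Dropout(dropout_rate),     # the knob being tuned
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])


# The three rates discussed above.
candidate_rates = [0.2, 0.4, 0.5]
models = {rate: build_model(rate) for rate in candidate_rates}

for rate, model in models.items():
    model.compile(optimizer="adam",
                  loss="binary_crossentropy",
                  metrics=["accuracy"])
    # Train each model on your own data and compare validation curves, e.g.:
    # model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=...)
    # x_train, y_train, x_val, y_val are placeholders, not defined here.
```

Because dropout is random, it is worth training each rate more than once before deciding which one is the most consistent.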
Ultimately, I couldn’t get to the level of accuracy of Shuyi Wang, but I got close:
See how the curves fit the data points way better? This shows that dropout allowed our kid in math class to solve 2+3. Although my results weren’t as good as the original’s, I understand dropout and overfitting way better now. I can now generalize this and apply it to the next model I build (stay tuned for that).
- Understanding the function of ML libraries allows you to become a stronger and more independent programmer.
- When overfitting occurs, validation metrics such as accuracy and loss will stagnate, get worse, or improve and then get worse.
- Overfitting shows that your model is unable to generalize past a specific data set.
- Increasing dropout can mitigate overfitting, up to a point.
As usual, connect with me on LinkedIn and follow me on Medium!