What is overfitting, and how to mitigate it?

Overfitting

Overfitting: A machine learning model overfits to a dataset when it contains more parameters than can be justified by the data.

An overfitted model’s performance declines significantly when evaluated on data it has not seen before.

There are multiple ways to mitigate overfitting, depending on the model type. They include training on more data, k-fold cross-validation, regularization, feature selection, ensembling, and early stopping of training.

Long answer

According to The Cambridge Dictionary of Statistics [page 314], overfitted models are models “that contain more unknown parameters than can be justified by the data.”

Overfitting is a property of a complex machine learning model that fails to generalize well to unseen data. It can happen in both supervised and unsupervised learning when the model fits the training data so closely that it learns the noise along with the underlying patterns.

Overfitting is what happens when you look at your teammates and conclude that all data scientists have to be men between the ages of 25 and 35, who never wear dress pants or exercise.

An overfitted model has high variance: its predictions depend heavily on the particular training sample it saw, so its performance varies a lot on unseen data. It does, however, have low bias: its predictions are not systematically skewed toward a specific wrong answer.
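To make the variance side concrete, here is a minimal sketch using only the Python standard library. It fits a k-nearest-neighbors regressor to noisy samples of a sine curve. The toy dataset, the function names, and the choice of k-NN are all assumptions made for illustration, not part of the original article. With k=1 the model effectively has one "parameter" per training point, so it memorizes the noise: it scores a perfect training error but a worse test error than a smoother k=15 model.

```python
import math
import random


def make_data(n, noise=0.3, seed=0):
    """Hypothetical toy data: y = sin(x) + Gaussian noise, x uniform on [0, 2*pi]."""
    rng = random.Random(seed)
    xs = [rng.uniform(0, 2 * math.pi) for _ in range(n)]
    ys = [math.sin(x) + rng.gauss(0, noise) for x in xs]
    return xs, ys


def knn_predict(train_x, train_y, x, k):
    """Average the targets of the k training points closest to x."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x))[:k]
    return sum(train_y[i] for i in nearest) / k


def mse(train_x, train_y, eval_x, eval_y, k):
    """Mean squared error of a k-NN model (built on train_*) evaluated on eval_*."""
    errors = [(knn_predict(train_x, train_y, x, k) - y) ** 2 for x, y in zip(eval_x, eval_y)]
    return sum(errors) / len(errors)


train_x, train_y = make_data(300, seed=1)
test_x, test_y = make_data(300, seed=2)

for k in (1, 15):
    print(f"k={k:2d}  train MSE={mse(train_x, train_y, train_x, train_y, k):.3f}  "
          f"test MSE={mse(train_x, train_y, test_x, test_y, k):.3f}")
```

The exact numbers depend on the random seeds. The pattern to look for is that k=1 drives the training error to zero (each training point is its own nearest neighbor) while its test error stays above that of k=15.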

How to detect overfitting

Test the model on data it has not seen during training. If its performance degrades significantly (say, by 10% or more) compared to its performance on the training data, it is overfitting: it has learned the noise in your training data rather than the patterns you want it to learn.
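The check above can be written as a tiny helper. This is a sketch under two assumptions the text does not pin down: the score is "higher is better" (such as accuracy), and the 10% threshold means a relative drop rather than 10 percentage points. The function names are hypothetical, and a positive training score is assumed.

```python
def relative_drop(train_score, test_score):
    """Fractional drop from training to test score, for 'higher is better' metrics."""
    return (train_score - test_score) / train_score


def looks_overfit(train_score, test_score, threshold=0.10):
    """Flag a model whose test score is at least `threshold` below its training score."""
    return relative_drop(train_score, test_score) >= threshold


print(looks_overfit(0.95, 0.80))  # True: about a 16% relative drop
print(looks_overfit(0.90, 0.88))  # False: about a 2% relative drop
```

If you prefer an absolute gap in percentage points, compare `train_score - test_score` against the threshold directly instead.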

How to mitigate overfitting

Do one or more of the following, if your model allows it, to alleviate overfitting.

Train on more data

If your training sample is small but has a rich set of features, a complex model can come up with an elaborate set of rules that hold only for the few examples you have. Training on more data exposes the model to more examples of the real patterns, which makes it harder to fit the noise and helps it generalize.
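A minimal sketch of this effect, using only the Python standard library: the model below is a hypothetical fixed-complexity regressor that stores one mean per bin (20 bins, so always 20 parameters). The sine-plus-noise data and all names are assumptions for illustration. With 25 training points most bins hold zero or one point, so the model memorizes noise; with 2,000 points each bin averages many points and the test error drops toward the noise floor.

```python
import math
import random


def make_data(n, noise=0.3, seed=0):
    """Hypothetical toy data: y = sin(x) + Gaussian noise, x uniform on [0, 2*pi]."""
    rng = random.Random(seed)
    xs = [rng.uniform(0, 2 * math.pi) for _ in range(n)]
    ys = [math.sin(x) + rng.gauss(0, noise) for x in xs]
    return xs, ys


def bin_index(x, n_bins, lo, hi):
    return min(int((x - lo) / (hi - lo) * n_bins), n_bins - 1)


def fit_bins(xs, ys, n_bins=20, lo=0.0, hi=2 * math.pi):
    """Fixed-complexity model: one mean per bin, regardless of how much data there is."""
    sums, counts = [0.0] * n_bins, [0] * n_bins
    for x, y in zip(xs, ys):
        b = bin_index(x, n_bins, lo, hi)
        sums[b] += y
        counts[b] += 1
    overall = sum(ys) / len(ys)  # fallback prediction for bins with no training points
    means = [s / c if c else overall for s, c in zip(sums, counts)]
    return means, lo, hi


def mse(model, xs, ys):
    means, lo, hi = model
    preds = [means[bin_index(x, len(means), lo, hi)] for x in xs]
    return sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)


test_x, test_y = make_data(1000, seed=99)
for n in (25, 2000):
    train_x, train_y = make_data(n, seed=n)
    model = fit_bins(train_x, train_y)
    print(f"n={n:4d}  train MSE={mse(model, train_x, train_y):.3f}  "
          f"test MSE={mse(model, test_x, test_y):.3f}")
```

The model's complexity never changes here; only the amount of training data does, which isolates the effect described above.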

Cross-validation (CV)

If your training data set is large enough, split it into k equally sized subsets, or folds (say k=10). Hold out fold #1 for testing and train your model on the combined remaining folds (2, 3, … 10). Repeat this process 9 more times, holding out fold #2, then #3, … #10 for testing and training on the remaining folds each time. Average the model’s performance on the 10 held-out folds to estimate how it performs on unseen data.

Figure: k-fold cross-validation (source: Wikipedia).

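Because every data point is held out exactly once, the averaged score is a more reliable estimate of performance on unseen data than a single train/test split. Choosing the model’s settings (its complexity, regularization strength, and so on) by this averaged score, rather than by its fit to any one subset, steers you toward models that generalize instead of models that overfit.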
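The procedure above can be sketched in plain Python. This is an illustration, not a definitive implementation: `fit_line` (ordinary least squares for a straight line) is a stand-in model chosen here for brevity, and the helper names are hypothetical. Any fit/predict pair could be swapped in. The indices are shuffled before being dealt into folds, a common precaution in case the data are ordered.

```python
import random


def k_fold_indices(n, k, seed=0):
    """Shuffle indices 0..n-1 and deal them into k folds of (nearly) equal size."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def fit_line(xs, ys):
    """Stand-in model: ordinary least squares for y = a + b*x."""
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    b = sxy / sxx
    return my - b * mx, b


def cross_val_mse(xs, ys, k=10, seed=0):
    """Average held-out MSE over k folds; each fold is held out exactly once."""
    scores = []
    for held_out in k_fold_indices(len(xs), k, seed):
        held = set(held_out)
        train = [i for i in range(len(xs)) if i not in held]
        a, b = fit_line([xs[i] for i in train], [ys[i] for i in train])
        errs = [(a + b * xs[i] - ys[i]) ** 2 for i in held_out]
        scores.append(sum(errs) / len(errs))
    return sum(scores) / k


# Usage on hypothetical noisy linear data.
rng = random.Random(1)
xs = [rng.uniform(0, 10) for _ in range(100)]
ys = [2 * x + 1 + rng.gauss(0, 1) for x in xs]
print(cross_val_mse(xs, ys, k=10))
```

To compare candidate models, compute `cross_val_mse` for each one on the same folds (same `seed`) and keep the one with the lowest average.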

Regularization

Regularization penalizes a model for its complexity. It adds a regularization term to the cost (loss) function that grows in value as the model’s complexity increases, so training has to trade goodness of fit against simplicity.
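Written out, using our own notation rather than the source’s (θ for the model parameters, L for the original loss, Ω for the complexity penalty, and λ for the regularization strength), the regularized objective is:

```latex
J(\theta) = L(\theta) + \lambda\,\Omega(\theta), \qquad \lambda \ge 0
```

With λ = 0 you recover the unregularized model; increasing λ pushes training toward simpler parameter settings at the cost of a looser fit to the training data.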

For example, the regularization term could be the L1 norm, L2 norm, or L-infinity norm of the model’s parameter vector, giving L1, L2, and L-infinity regularization, respectively. (In practice, L2 regularization usually penalizes the squared L2 norm.)
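A minimal sketch of these penalties in plain Python, followed by the simplest possible L2-regularized model: a one-feature least-squares line through the origin. The function names are hypothetical. The ridge example assumes the objective mean((y - w*x)^2) + lam * w^2, whose minimizer has the closed form used below, and shows the fitted weight shrinking toward zero as the penalty strength grows.

```python
def l1(w):
    return sum(abs(v) for v in w)


def l2_squared(w):
    # L2 (ridge) regularization conventionally penalizes the *squared* L2 norm
    return sum(v * v for v in w)


def linf(w):
    return max(abs(v) for v in w)


def ridge_slope(xs, ys, lam):
    """Minimizer of mean((y - w*x)**2) + lam * w**2 for a single feature, no intercept."""
    n = len(xs)
    sxy = sum(x * y for x, y in zip(xs, ys)) / n
    sxx = sum(x * x for x in xs) / n
    return sxy / (sxx + lam)


w = [3.0, -4.0]
print(l1(w), l2_squared(w), linf(w))  # 7.0 25.0 4.0

xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
for lam in (0.0, 1.0, 10.0):
    print(lam, ridge_slope(xs, ys, lam))  # the slope shrinks toward 0 as lam grows
```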

Early stopping

Figure: Early stopping of training machine learning models to mitigate overfitting (source: Wikipedia).

As you iteratively train your model, you tune its parameters to reduce its prediction error on the training data (blue curve above). At first, its prediction error on held-out data goes down as well (red curve above), until it reaches a minimum; beyond that point, overfitting kicks in and the held-out error starts to rise.

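To mitigate overfitting, stop training your model at this point. To find it, evaluate the model on your training data and on held-out data at each iteration, and stop once the held-out error stops falling and starts to rise while the training error keeps dropping. Ideally, use a validation set split off from the training data for this, so your final test set stays untouched.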

Early stopping is usually used to counter overfitting in deep learning models.
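The stopping rule can be sketched in plain Python. This version adds a "patience" window (a common refinement, so that a single noisy uptick in validation error does not end training prematurely). `train_one_epoch` and `validation_loss` are hypothetical callbacks standing in for your framework’s training step and validation pass; the usage below feeds a made-up, U-shaped validation curve instead of a real model.

```python
from typing import Callable, Tuple


def train_with_early_stopping(
    train_one_epoch: Callable[[], None],
    validation_loss: Callable[[], float],
    max_epochs: int = 100,
    patience: int = 5,
) -> Tuple[int, float]:
    """Stop after `patience` consecutive epochs without a new best validation loss.

    Returns (best_epoch, best_loss). A real training loop would also checkpoint
    the model at the best epoch and restore those weights afterwards.
    """
    best_epoch, best_loss = -1, float("inf")
    epochs_since_best = 0
    for epoch in range(max_epochs):
        train_one_epoch()
        loss = validation_loss()
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
            epochs_since_best = 0  # checkpoint the weights here in a real loop
        else:
            epochs_since_best += 1
            if epochs_since_best >= patience:
                break  # held-out error has stopped improving
    return best_epoch, best_loss


# Usage with a made-up validation curve standing in for a real model.
curve = [1.0, 0.8, 0.6, 0.5, 0.45, 0.47, 0.5, 0.55, 0.6, 0.7]
state = {"epochs_run": 0}


def fake_train_one_epoch() -> None:
    state["epochs_run"] += 1


def fake_validation_loss() -> float:
    return curve[state["epochs_run"] - 1]


best = train_with_early_stopping(fake_train_one_epoch, fake_validation_loss,
                                 max_epochs=len(curve), patience=3)
print(best, state["epochs_run"])  # (4, 0.45) 8
```

Training halts after the eighth epoch, three epochs past the validation minimum at epoch index 4, and the model from that best epoch is the one you would keep.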

Feature pruning (selection)

One way to reduce a model’s complexity is to drop features that have little predictive power, along with redundant features that are highly correlated with other features. This can be done manually or with an automated procedure.
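One simple automated procedure for the correlated-features case is sketched below in plain Python. It is a hypothetical greedy rule, not a standard named algorithm: walk through the features in order and keep each one only if its absolute Pearson correlation with every already-kept feature is at most a threshold (0.9 here, an arbitrary choice). It assumes numeric, non-constant features.

```python
import math


def pearson(a, b):
    """Pearson correlation coefficient (assumes neither column is constant)."""
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(va * vb)


def drop_correlated(features, threshold=0.9):
    """Keep a feature only if |r| with every already-kept feature is <= threshold."""
    kept = []
    for name in features:
        if all(abs(pearson(features[name], features[k])) <= threshold for k in kept):
            kept.append(name)
    return kept


features = {
    "a": [1, 2, 3, 4, 5],
    "b": [2, 4, 6, 8, 10],  # exactly 2 * a, so redundant
    "c": [5, 3, 4, 1, 2],   # r = -0.8 with a
}
print(drop_correlated(features))  # ['a', 'c']
```

Because the rule is greedy, the feature listed first in each correlated group is the one that survives; ordering features by predictive power first is one way to make that choice deliberate.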

Some algorithms, like Lasso regression, have built-in feature selection and regularization. During training, the L1 penalty can drive some feature coefficients all the way to zero, effectively eliminating those features’ influence on the model output.
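To see how coefficients reach exactly zero, here is a minimal Lasso sketch in plain Python using coordinate descent with soft-thresholding, one common way to fit Lasso. It minimizes (1/(2n)) * ||y - Xw||^2 + lam * ||w||_1 and, to stay short, assumes centered features and no intercept. The tiny dataset is made up: feature 0 drives the target strongly and feature 1 only weakly.

```python
def soft_threshold(a, lam):
    """Shrink a toward zero by lam; values inside [-lam, lam] become exactly 0."""
    if a > lam:
        return a - lam
    if a < -lam:
        return a + lam
    return 0.0


def lasso(X, y, lam, n_sweeps=100):
    """Coordinate descent for (1/(2n)) * ||y - Xw||^2 + lam * ||w||_1 (no intercept)."""
    n, p = len(X), len(X[0])
    w = [0.0] * p
    for _ in range(n_sweeps):
        for j in range(p):
            # correlation of feature j with the residual that leaves feature j out
            rho = sum(
                X[i][j] * (y[i] - sum(X[i][k] * w[k] for k in range(p) if k != j))
                for i in range(n)
            ) / n
            z = sum(X[i][j] ** 2 for i in range(n)) / n
            w[j] = soft_threshold(rho, lam) / z if z > 0 else 0.0
    return w


x0 = [-2.0, -1.0, 0.0, 1.0, 2.0]
x1 = [1.0, -2.0, 0.0, 2.0, -1.0]
X = [[a, b] for a, b in zip(x0, x1)]
y = [3 * a + 0.1 * b for a, b in zip(x0, x1)]

print(lasso(X, y, lam=0.0))  # roughly [3.0, 0.1]: no penalty, plain least squares
print(lasso(X, y, lam=0.5))  # roughly [2.75, 0.0]: the weak coefficient is zeroed out
```

The soft-threshold step is what distinguishes L1 from L2: an L2 penalty only rescales each coefficient, whereas the L1 update sets any coefficient whose residual correlation falls below `lam` to exactly zero.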

If the algorithm you use doesn’t have a built-in feature selection capability, you can do it separately. You can use algorithms such as the Fast Correlation-Based Filter (FCBF), or statistical tests such as the chi-squared test of independence, to identify irrelevant features so you can consider removing them.
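As an illustration of the statistical-test route (FCBF is not shown), here is a plain-Python sketch of Pearson’s chi-squared statistic between a categorical feature and a class label. It assumes both columns are categorical. Turning the statistic into a p-value requires the chi-squared distribution with (rows - 1) * (columns - 1) degrees of freedom; a library routine such as SciPy’s `scipy.stats.chi2_contingency` does both steps.

```python
from collections import Counter


def chi_squared(feature, label):
    """Pearson's chi-squared statistic for independence of two categorical columns.

    Larger values suggest the feature carries information about the label.
    """
    n = len(feature)
    feature_counts, label_counts = Counter(feature), Counter(label)
    observed = Counter(zip(feature, label))
    stat = 0.0
    for f, f_count in feature_counts.items():
        for lab, lab_count in label_counts.items():
            expected = f_count * lab_count / n  # count expected under independence
            stat += (observed[(f, lab)] - expected) ** 2 / expected
    return stat


label = [0, 0, 1, 1]
print(chi_squared(["a", "a", "b", "b"], label))  # 4.0: the feature tracks the label
print(chi_squared(["a", "b", "a", "b"], label))  # 0.0: the feature is independent of it
```

Features with a statistic near zero (a large p-value) are the candidates to consider removing.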

Alternatively, you can inspect the features manually and remove extraneous features if they don’t make sense to you.

Ensembling

Ensembling combines the outputs of multiple base models into a single prediction. The premise is that the base models complement each other, so the resulting ensemble performs better than any of its individual models. The main types of ensembling are:

Bagging (bootstrap aggregation): you train multiple strong learners in parallel, each on a different training subset drawn from the original training data set by bootstrap sampling (sampling with replacement). You then combine the base learners’ outputs using voting for classification problems and averaging for regression problems.

Boosting: you train multiple weak learners sequentially, giving more weight to the data points that earlier base models misclassified when you train subsequent ones. You then combine the base models’ outputs using weighted voting for classification problems and a weighted sum for regression problems, where the weights reflect the accuracy of the base models.

Stacking: the base models are typically of different types (e.g. decision tree, naive Bayes, logistic regression), and their outputs are used as features to train a higher-level model.
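Bagging, the first type above, is the simplest to sketch. The plain-Python example below uses 1-nearest-neighbor regression as the strong but high-variance base learner; that choice, the sine-plus-noise data, and all names are assumptions for illustration. Each base learner sees a bootstrap sample of the training set, and the ensemble averages their predictions. Averaging mainly reduces variance, so the bagged model typically shows a lower test error than a single 1-NN model on the same data. Boosting and stacking are not shown.

```python
import math
import random


def make_data(n, noise=0.3, seed=0):
    """Hypothetical toy data: y = sin(x) + Gaussian noise, x uniform on [0, 2*pi]."""
    rng = random.Random(seed)
    xs = [rng.uniform(0, 2 * math.pi) for _ in range(n)]
    ys = [math.sin(x) + rng.gauss(0, noise) for x in xs]
    return xs, ys


def nn_predict(points, x):
    """1-nearest-neighbor regression: return the target of the closest (x, y) point."""
    return min(points, key=lambda p: abs(p[0] - x))[1]


def bootstrap_bags(points, n_bags, seed=0):
    """Each bag has len(points) samples drawn *with replacement* from the training set."""
    rng = random.Random(seed)
    return [rng.choices(points, k=len(points)) for _ in range(n_bags)]


def bagged_predict(bags, x):
    """Bagging for regression: average the base learners' predictions."""
    return sum(nn_predict(bag, x) for bag in bags) / len(bags)


def mse(predict, xs, ys):
    return sum((predict(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


train = list(zip(*make_data(300, seed=1)))
test_x, test_y = make_data(400, seed=2)
bags = bootstrap_bags(train, n_bags=30)

single = mse(lambda x: nn_predict(train, x), test_x, test_y)
bagged = mse(lambda x: bagged_predict(bags, x), test_x, test_y)
print(f"single 1-NN test MSE={single:.3f}  bagged (30 bags) test MSE={bagged:.3f}")
```

For a classification problem you would replace the average in `bagged_predict` with a majority vote over the base learners’ predicted classes.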

What is underfitting?

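Underfitting is a property of a model that hasn’t picked up the underlying patterns in the data, often because it is too simple for the task. It performs poorly on both the training and test data sets, which is the signature of high bias.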
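For contrast with overfitting, here is a plain-Python sketch of underfitting: a straight line fitted by ordinary least squares to a full period of a noisy sine curve. The data and names are assumptions for illustration. The line is too simple to follow the curve, so its training and test errors are both well above the noise variance (0.09 here) and close to each other, rather than showing a large train/test gap.

```python
import math
import random


def make_data(n, noise=0.3, seed=0):
    """Hypothetical toy data: y = sin(x) + Gaussian noise, x uniform on [0, 2*pi]."""
    rng = random.Random(seed)
    xs = [rng.uniform(0, 2 * math.pi) for _ in range(n)]
    ys = [math.sin(x) + rng.gauss(0, noise) for x in xs]
    return xs, ys


def fit_line(xs, ys):
    """Ordinary least squares for y = a + b*x: too simple for a full sine period."""
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    b = sxy / sxx
    return my - b * mx, b


def mse(model, xs, ys):
    a, b = model
    return sum((a + b * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


noise = 0.3
train_x, train_y = make_data(500, noise=noise, seed=1)
test_x, test_y = make_data(500, noise=noise, seed=2)
line = fit_line(train_x, train_y)
print(f"train MSE={mse(line, train_x, train_y):.3f}  "
      f"test MSE={mse(line, test_x, test_y):.3f}  noise floor={noise ** 2:.3f}")
```

The usual remedies run in the opposite direction to those for overfitting: a more flexible model, more informative features, or weaker regularization.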
