[machine learning, ML, performance, Python, overfitting, K-fold cross-validation, regularization, hyperparameter tuning]


Detecting Overfitting with K-Fold Cross-Validation


Overview

In machine learning, model evaluation is crucial for assessing a model's reliability. It helps detect overfitting and underfitting after an algorithm has been trained on a dataset, and cross-validation is one of the most widely used approaches for pinpointing overfitting. In this short article, you'll become familiar with the primary causes of overfitting, briefly review solutions to address it, and explore K-fold cross-validation, a standard method for detecting overfitting.

Overfitting

Overfitting occurs when a model performs well on the training data but poorly on the test data. For example, a model might reach a training accuracy close to 1 while its test accuracy stays below 0.7. In this case, the model fails to generalize to new samples.

Why overfitting happens

Overfitting generally happens when the model is too complex relative to the amount of available data, for example, a neural network with a complex structure trained on a small dataset. Some of the main causes of overfitting are as follows (a short sketch after this list makes the first two concrete):

1- Complex Model: When the model includes many features, it can become overly complex and end up fitting noise in the training data as if it were signal.

2- Small Training Dataset: If the training dataset is small and doesn't adequately represent all possible input data, the model can memorize its quirks instead of learning patterns that generalize.

3- Training Data Errors and Outliers: Errors and outliers in the training data introduce noise. If the algorithm interprets this noise as meaningful patterns, it can negatively impact the model’s ability to generalize to new data.
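
To make the first two causes concrete, here is a minimal sketch (not part of the article's later example) that fits a simple and an overly complex model to a tiny, noisy dataset; the degree-9 polynomial, the noise level, and the sample sizes are arbitrary choices for illustration. The complex model typically fits the training points almost perfectly while doing noticeably worse on fresh data.

import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# A tiny training set drawn from a simple linear trend plus noise
rng = np.random.RandomState(0)
X_train = rng.uniform(0, 1, size=(10, 1))
y_train = 3 * X_train.ravel() + rng.normal(scale=0.3, size=10)

# Fresh samples from the same process play the role of unseen data
X_test = rng.uniform(0, 1, size=(50, 1))
y_test = 3 * X_test.ravel() + rng.normal(scale=0.3, size=50)

# Compare a simple model (degree 1) with an overly complex one (degree 9)
for degree in (1, 9):
    model = make_pipeline(PolynomialFeatures(degree), LinearRegression())
    model.fit(X_train, y_train)
    train_mse = mean_squared_error(y_train, model.predict(X_train))
    test_mse = mean_squared_error(y_test, model.predict(X_test))
    print(f"degree={degree}: train MSE={train_mse:.3f}, test MSE={test_mse:.3f}")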

Solutions for overfitting

Several solutions can mitigate overfitting:

1- Reducing Model Complexity: Simplify the model by decreasing the number of parameters or attributes. For example, in linear regression, include only the most predictive features instead of all available features. Another example is simplifying a neural network's architecture by reducing the number of nodes and layers.

2- Increasing Sample Size: Gathering more real data can help, although it may be costly. Alternatively, data augmentation methods can generate additional samples.

3- Noise Reduction: Clean the data by removing outliers and erroneous data to prevent the model from fitting to noisy data.

4- Regularization: Regularization in linear regression is a technique used to prevent overfitting. When a linear model has large coefficients for its features, it tends to overfit the training data. Regularization terms such as L1 and L2 work by penalizing large coefficients, shrinking them or, in the case of L1, even setting some of them exactly to zero. This makes the model smoother and helps prevent overfitting (a short sketch after this list illustrates both regularization and hyperparameter tuning).

5- Hyperparameter Tuning: By finding the best combination of hyperparameters, the model's complexity can be managed. Examples of hyperparameters include the learning rate in gradient descent and the depth of a decision tree. Overall, hyperparameter tuning helps strike a balance between model complexity and generalization performance, ultimately reducing the risk of overfitting.
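
As a rough illustration of the last two solutions, the sketch below applies L2 and L1 regularization with scikit-learn's Ridge and Lasso, then uses GridSearchCV to tune the regularization strength alpha. The dataset, the alpha values, and the grid are arbitrary choices made for illustration, not recommendations.

from sklearn.datasets import load_diabetes
from sklearn.linear_model import Ridge, Lasso
from sklearn.model_selection import GridSearchCV

# A small regression dataset (the same one used later in this article)
X, y = load_diabetes(return_X_y=True)

# L2 regularization (Ridge): the penalty shrinks large coefficients
ridge = Ridge(alpha=1.0).fit(X, y)

# L1 regularization (Lasso): the penalty can set some coefficients exactly to zero
lasso = Lasso(alpha=0.1).fit(X, y)
print("Coefficients shrunk to exactly zero by Lasso:", int((lasso.coef_ == 0).sum()))

# Hyperparameter tuning: pick the regularization strength by cross-validated score
param_grid = {"alpha": [0.01, 0.1, 1.0, 10.0]}
search = GridSearchCV(Ridge(), param_grid, cv=5, scoring="neg_mean_squared_error")
search.fit(X, y)
print("Best alpha found by grid search:", search.best_params_["alpha"])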

K-Fold Cross-Validation

One way to check how well a model works on new data is a train-test split, where we set aside some data to test the model after training it. A better method is K-fold cross-validation. In this technique, we split the dataset into k subsets, or folds. The model is trained k times, each time using a different fold for testing and the remaining folds for training. This gives us k evaluation scores, which we average to get the overall performance. This approach provides a more reliable evaluation than a single train-test split. K-fold cross-validation doesn't directly prevent overfitting, but it helps detect it so we can address it. Let's explore it further with an example.

We’ll use the diabetes dataset from Sklearn and train a DecisionTreeRegressor to predict diabetes progression.

Tip

Tree-based algorithms are particularly vulnerable to overfitting.

In the next code block, you split the dataset into training and testing sets, train the model on the training set, and assess its performance by calculating the Mean Squared Error (MSE) for both the training and test sets.

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error

# Load the diabetes dataset
diabetes = load_diabetes()
X = diabetes.data
y = diabetes.target

# a simple train-test split 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# training the model
model = DecisionTreeRegressor()
model.fit(X_train, y_train)

# Make prediction
y_train_pred = model.predict(X_train)
y_test_pred = model.predict(X_test)

# Calculate Mean Squared Error (MSE)
mse_train = mean_squared_error(y_train, y_train_pred)
mse_test = mean_squared_error(y_test, y_test_pred)

print("MSE for Training Data:", mse_train)
print("MSE for Test Data:", mse_test)

The result shows that the training MSE is zero, whereas the test MSE is approximately 4933, indicating a severe case of overfitting: the error on the test set is much larger than on the training set. When we run K-fold cross-validation in the next code block, it becomes even more evident that the model overfits the training data.

# Perform k-fold cross-validation (k=5)
cv_scores = cross_val_score(model, X, y, cv=5, scoring='neg_mean_squared_error')

# Calculate Mean Squared Error (MSE) for each fold and take the average
mse_cross_val = -cv_scores.mean()
print("MSE with 5-fold cross-validation:", mse_cross_val)

Using cross-validation during model training improves our chances of creating a model that not only fits the training data well but also accurately predicts outcomes for new, unseen data.

Summary
  • Overfitting: We define overfitting and discuss its causes and how to address it.
  • K-fold cross-validation: We explore the advantages of K-fold cross-validation compared to the traditional train-test-split method. We also provide an example of its implementation in Python.
Contributed by Kheiry Sohooli