[machine learning, ML, performance, Python, overfitting, K-fold cross-validation, regularization, hyperparameter tuning]


Detecting Overfitting with K-Fold Cross-Validation

5 mins

Overview

In machine learning, model evaluation is crucial for assessing a model's reliability: it helps detect overfitting and underfitting after an algorithm has been trained on a dataset. Cross-validation is one of the most reliable ways to pinpoint overfitting. In this short article, you'll become familiar with the primary causes of overfitting and some solutions to address them. Additionally, we'll explore K-fold cross-validation, a widely used method for detecting overfitting.

Overfitting

Overfitting occurs when a model performs well on the training data but poorly on the test data. For example, in the image below, the training accuracy is close to 1, while the test accuracy is much lower, less than 0.7. In this case, the model fails to generalize to new samples.

Why overfitting happens

Overfitting generally happens when the model is too complex relative to the amount of available data, for example, a neural network with a complex architecture trained on a small dataset. Common causes of overfitting include:

  1. Complex Model: When the model has many features or parameters, it can become overly complex and fit noise in the training data as well as the underlying pattern (see the sketch after this list).

  2. Small Training Dataset: If the training dataset is small and doesn't adequately represent the full range of possible inputs, the model can memorize its specifics rather than learn patterns that generalize to new data.

  3. Training Data Errors and Outliers: Errors and outliers in the training data introduce noise. If the algorithm interprets this noise as meaningful patterns, it can negatively impact the model’s ability to generalize to new data.
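
To make the first cause concrete, here is a minimal illustrative sketch that fits a deliberately over-flexible polynomial model to a small, noisy dataset; the degree, noise level, and sample size are arbitrary choices for illustration. The training MSE typically comes out far smaller than the test MSE, which is the signature of overfitting.

import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# A small, noisy dataset: y = sin(x) plus random noise
rng = np.random.RandomState(0)
X = np.sort(rng.uniform(0, 6, 30)).reshape(-1, 1)
y = np.sin(X).ravel() + rng.normal(scale=0.3, size=30)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# A degree-15 polynomial is far more flexible than 21 training points justify
complex_model = make_pipeline(PolynomialFeatures(degree=15), LinearRegression())
complex_model.fit(X_train, y_train)

print("Training MSE:", mean_squared_error(y_train, complex_model.predict(X_train)))
print("Test MSE:", mean_squared_error(y_test, complex_model.predict(X_test)))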

Solutions for overfitting

Several solutions can mitigate overfitting:

  1. Reducing Model Complexity: Simplify the model by decreasing the number of parameters or attributes. For example, in linear regression, include only the most predictive features instead of all available features. Another example is simplifying the structure of neural network models by reducing the number of nodes and layers in a complex architecture.

  2. Increasing Sample Size: Gathering more real data can help, although it may be costly. Alternatively, data augmentation can generate additional samples from existing ones.

  3. Noise Reduction: Clean the data by removing outliers and erroneous data to prevent the model from fitting to noisy data.

  4. Regularization: In linear regression, regularization prevents overfitting by discouraging large coefficients, which signal that the model is fitting the training data too closely. L1 and L2 penalty terms work by penalizing large coefficients, shrinking them toward zero (L1 can set some coefficients exactly to zero). This makes the model smoother and less prone to overfitting (see the sketch after this list).

  5. Hyperparameter Tuning: Finding the best combination of hyperparameters, such as the learning rate in gradient descent or the maximum depth of a decision tree, keeps the model's complexity under control. Overall, hyperparameter tuning helps strike a balance between model complexity and generalization performance, ultimately reducing the risk of overfitting.
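
To illustrate point 4, here is a minimal sketch comparing plain linear regression with an L1-regularized (Lasso) model on the diabetes dataset used later in this article; alpha=1.0 is an arbitrary illustrative value. The L1 penalty shrinks the coefficients and drives several of them to exactly zero. The penalty strength alpha is itself a hyperparameter that could be tuned as described in point 5.

import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression, Lasso

X, y = load_diabetes(return_X_y=True)

# Plain linear regression: every feature receives a coefficient
plain = LinearRegression().fit(X, y)

# L1-regularized regression: the penalty shrinks coefficients and sets some to zero
lasso = Lasso(alpha=1.0).fit(X, y)

print("Non-zero coefficients (plain):", np.sum(plain.coef_ != 0))
print("Non-zero coefficients (Lasso):", np.sum(lasso.coef_ != 0))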

K-Fold Cross-Validation

One method to check how well a model works on new data is a simple train/test split: we set aside some of the data to test the model after training it. A more robust method is K-fold cross-validation. In this technique, we split the dataset into k subsets, or folds. The model is trained k times, each time using a different fold for testing and the remaining folds for training. This yields k evaluation scores, which we average to get the overall performance, a more reliable estimate than a single train/test split. K-fold cross-validation doesn't directly prevent overfitting, but it helps in detecting and addressing it. Let's explore it with an example.
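
Before the full example, the short sketch below (purely illustrative) uses scikit-learn's KFold on ten toy samples to show how the data is divided: with k = 5, each sample lands in the test fold exactly once across the five iterations.

import numpy as np
from sklearn.model_selection import KFold

# Ten toy samples, just to show how the indices are split into folds
X = np.arange(10).reshape(-1, 1)

kf = KFold(n_splits=5)
for fold, (train_idx, test_idx) in enumerate(kf.split(X), start=1):
    # On each iteration, one fold is held out for testing and the rest are used for training
    print(f"Fold {fold}: train={train_idx}, test={test_idx}")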

We’ll use the diabetes dataset from Sklearn and train a DecisionTreeRegressor to predict diabetes progression.

Tip

Tree-based algorithms are particularly prone to overfitting, because an unconstrained tree can keep splitting until it essentially memorizes the training data.

Running the next code block splits the dataset into training and testing sets, trains the model on the training set, and assesses its performance by calculating the Mean Squared Error (MSE) on both the training and test data.

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error

# Load the diabetes dataset
diabetes = load_diabetes()
X = diabetes.data
y = diabetes.target

# A simple train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the model
model = DecisionTreeRegressor()
model.fit(X_train, y_train)

# Make predictions on the training and test sets
y_train_pred = model.predict(X_train)
y_test_pred = model.predict(X_test)

# Calculate Mean Squared Error (MSE)
mse_train = mean_squared_error(y_train, y_train_pred)
mse_test = mean_squared_error(y_test, y_test_pred)

print("MSE for Training Data:", mse_train)
print("MSE for Test Data:", mse_test)

The result shows that the training MSE is zero, whereas the test MSE is roughly 4933, indicating a severe case of overfitting: the error on the test set is far larger than on the training set. When we apply K-fold cross-validation in the next code block, it becomes even more evident that the model overfits the training data.

# Perform k-fold cross-validation (k=5)
cv_scores = cross_val_score(model, X, y, cv=5, scoring='neg_mean_squared_error')

# cross_val_score returns negative MSE values; negate the mean to get the average MSE across folds
mse_cross_val = -cv_scores.mean()
print("MSE with 5-fold cross-validation:", mse_cross_val)

By executing the above code, you will get an MSE of roughly 6740 through cross-validation, which is higher than the previous number obtained using a simple train_test_split. This example shows that cross-validation can reveal overfitting more clearly than a single train/test split.

Using cross-validation during model training improves our chances of building a model that not only fits the training data well but also predicts accurately on new, unseen data. By evaluating the model on multiple subsets of the data, cross-validation helps detect overfitting and checks that performance is consistent across different data segments. Averaging the model's performance across k folds reduces the variability introduced by any single choice of training and test sets, giving more robust and reliable results; the benefit is greatest when there is enough data for each fold to be representative. With small sample sizes, however, cross-validation might not be ideal, since the already limited data is divided into even smaller subsets, making the evaluation less robust and reliable.
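
One way to put this into practice on the example above is to combine cross-validation with hyperparameter tuning. The sketch below is an illustrative extension, not part of the original example: it uses GridSearchCV, which performs K-fold cross-validation internally, to choose a depth limit for the decision tree; the candidate max_depth values are arbitrary. A depth limit chosen this way typically shows a much smaller gap between training error and cross-validated error than the unconstrained tree trained earlier.

from sklearn.datasets import load_diabetes
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)

# Candidate depth limits; None lets the tree grow until it memorizes the training data
param_grid = {"max_depth": [2, 3, 5, 10, None]}

search = GridSearchCV(
    DecisionTreeRegressor(random_state=42),
    param_grid,
    cv=5,
    scoring="neg_mean_squared_error",
)
search.fit(X, y)

print("Best max_depth:", search.best_params_["max_depth"])
print("Cross-validated MSE of the best model:", -search.best_score_)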

Summary
  • Overfitting: We defined overfitting and discussed its causes and solutions.
  • K-fold cross-validation: We explored the advantages of K-fold cross-validation over a simple train/test split and walked through an example of its implementation in Python.
Contributed by Kheiry Sohooli