Introduction to Cross Validation Techniques

A tutorial covering Cross Validation techniques, complete with code and interactive visualizations. Made by Saurav Maheshkar using Weights & Biases

Accompanying Colab →


Introduction 👋

Cross Validation techniques are used to evaluate the performance of machine learning models. Let's walk through a typical machine learning pipeline.
We train a model on some data, using a particular cost function, to improve some metric. But how do we evaluate the model's performance afterwards? If we feed in the entire training dataset again, the results will be heavily skewed, if not perfect, because the model was trained on exactly that data. For all intents and purposes, it has memorized that data.
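To make the memorization point concrete, here is a minimal sketch that is not part of the original report. It assumes a synthetic dataset from make_classification with deliberately noisy labels and an unconstrained DecisionTreeClassifier; both are illustrative choices, and the 400/100 slice is an arbitrary way to keep some rows unseen.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (an assumption for illustration); flip_y adds label noise
X, y = make_classification(n_samples=500, n_features=20, flip_y=0.2, random_state=0)

# Keep the last 100 rows away from the model entirely
X_seen, y_seen = X[:400], y[:400]
X_unseen, y_unseen = X[400:], y[400:]

# An unconstrained tree keeps splitting until every training row is classified correctly
clf = DecisionTreeClassifier(random_state=0).fit(X_seen, y_seen)

seen_score = clf.score(X_seen, y_seen)        # evaluated on the data it was trained on
unseen_score = clf.score(X_unseen, y_unseen)  # evaluated on rows it never saw

print(f"accuracy on training rows: {seen_score:.2f}")
print(f"accuracy on unseen rows:   {unseen_score:.2f}")
```

The tree should score a perfect 1.0 on the rows it was trained on, noise included, while the score on the unseen rows is typically much lower. That gap is exactly why the training data cannot be reused for evaluation.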

The Much Simpler (not always better): Train Test Split

Therefore, in most machine learning pipelines, we split the original dataset into separate Training and Testing datasets. This is where train_test_split from sklearn comes in. It randomly splits the dataset into two parts according to some percentage. For example, a test size of 0.2 means that 20% of the dataset is set aside as the test dataset. During the training stage we feed in the Training dataset, and when it comes to evaluation we use the Test dataset, which the model has never seen, to get a more honest evaluation metric.
Figure 1: How the much simpler train test split works.
This approach seems nice and easy, right? But if we think about it, a purely random split doesn't really care about the inherent distribution of the dataset. For example, consider the widely popular cats vs dogs problem. It might happen that, because of the random split, the training dataset gets the majority of the cat images, so the model doesn't really learn how to distinguish dogs from cats but rather what cats look like. We don't want that, do we?
As suave as this little dude is, we don't want to bias our models with too much feline energy
We need a method that allows the model to learn all the important features of the dataset while still not allowing it to memorize the dataset by rote.
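The cats vs dogs concern can be checked with a small sketch. The labels here are a hypothetical toy set (10 cats, 10 dogs) and the features are placeholders, so this only illustrates how the class mix of the training split can shift with the random seed.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy labels: 0 = cat, 1 = dog, perfectly balanced overall
y = np.array([0] * 10 + [1] * 10)
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features, one per sample

cat_fractions = []
for seed in range(5):
    # Same 80/20 split, different random seed each time
    _, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=seed)
    cat_fractions.append(float(np.mean(y_train == 0)))

# Fraction of cats in each training split; the full dataset is exactly 0.5
print(cat_fractions)
```

With only 16 training samples per split, the cat fraction can land anywhere between 6/16 and 10/16 depending on the seed. On a small or imbalanced dataset, that drift is the bias described above.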

Code 🧑🏼‍💻

Let's go over a simple code snippet on how you can use train_test_split from sklearn in your pipeline.
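from sklearn.model_selection import train_test_split

# =============== Fetch dataset ===============
# X, y = ...  (load your features and labels here)

# Hold out 40% of the samples as the test set; random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

# =============== Train the model ===============
# clf = ...  (any scikit-learn estimator)
# clf.fit(X_train, y_train)

# =============== Evaluate the model ===============
# clf.score(X_test, y_test)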
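For a version that runs end to end, here is a sketch with the placeholders filled in. The iris dataset and LogisticRegression are assumptions chosen only because they ship with sklearn; they are not the dataset or model used in the report.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Fetch dataset (illustrative choice: the iris dataset bundled with sklearn)
X, y = load_iris(return_X_y=True)

# Hold out 40% of the 150 samples for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

# Train the model on the training split only
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

# Evaluate on the held-out split (mean accuracy for classifiers)
test_score = clf.score(X_test, y_test)
print(f"test accuracy: {test_score:.3f}")
```

Changing random_state gives a different split and usually a slightly different score, which is one motivation for averaging over several splits.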

K-Fold Cross Validation

Figure 2: What K-Fold Cross Validation looks like.
In this technique, we split the dataset into a number of folds (say, k folds). In the first iteration, we hold out the first fold as the test set, train the model on the remaining k-1 folds, and evaluate it on the held-out fold. In the next iteration, we use the second fold as the test set and train on the other folds, and we continue this process until every fold has served as the test set once. This way we get k scores, one for each fold used as the test dataset.
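The fold rotation described above can be seen directly in the indices that KFold produces. This is a small sketch on six hypothetical samples, with shuffling left off so the folds stay in order.

```python
import numpy as np
from sklearn.model_selection import KFold

# Six hypothetical samples; only the number of rows matters for the split
X = np.arange(6).reshape(-1, 1)

# Without shuffling, folds are consecutive blocks of indices
kfold = KFold(n_splits=3)

splits = []
for train_index, test_index in kfold.split(X):
    splits.append((train_index.tolist(), test_index.tolist()))
    print("train:", train_index, "test:", test_index)

# train: [2 3 4 5] test: [0 1]
# train: [0 1 4 5] test: [2 3]
# train: [0 1 2 3] test: [4 5]
```

Each sample appears in exactly one test fold, and in the training set for every other iteration.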

Code 🧑🏼‍💻

Let's go over how you can use KFold from sklearn to split your dataset and then use it in your pipeline.
from sklearn.model_selection import KFold

# =============== Fetch dataset ===============
# X, y = ...  (load your features and labels as arrays)

# Define your CV strategy
# Parameters:
#   n_splits: number of folds you want
#   shuffle: whether to shuffle the data before splitting
#   random_state: seed used for the shuffle
kfold = KFold(n_splits=3, shuffle=True, random_state=1)

# =============== Split the dataset ===============
for train_index, test_index in kfold.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    # =============== Train the model ===============
    # clf = ...  (create a fresh estimator for each fold)
    # clf.fit(X_train, y_train)

    # =============== Evaluate the model ===============
    # clf.score(X_test, y_test)
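Here is the same loop as a runnable sketch that collects the k scores and averages them. As before, the iris dataset and LogisticRegression are stand-ins chosen for convenience, not the report's own setup.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

# Illustrative dataset; iris is sorted by class, so shuffling before splitting matters
X, y = load_iris(return_X_y=True)

kfold = KFold(n_splits=3, shuffle=True, random_state=1)

scores = []
for train_index, test_index in kfold.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    # Fit a fresh model on this iteration's training folds
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train)

    # Score it on the fold that was held out
    scores.append(clf.score(X_test, y_test))

mean_score = float(np.mean(scores))
print("per-fold accuracy:", scores)
print(f"mean accuracy: {mean_score:.3f}")
```

Reporting the mean, and ideally the spread, of the k scores gives a steadier estimate than any single train test split.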


And that wraps up our short post on why you should use Cross Validation and how you can use sklearn to correctly evaluate your models. To see the full suite of wandb features, please check out this short 5-minute guide.
Check out these other reports on Fully Connected covering other fundamental concepts like Linear Regression and Decision Trees.