Examples of Early Stopping in HuggingFace Transformers
In this article, we'll take a look at how to fine-tune your HuggingFace Transformer with Early Stopping regularization using TensorFlow and PyTorch.
In this article, we'll walk through examples of using early stopping regularization to fine-tune your HuggingFace Transformer. We'll cover early stopping in native PyTorch and native TensorFlow workflows, as well as with HuggingFace's Trainer API.

Native TensorFlow
If you are using TensorFlow (Keras) to fine-tune a HuggingFace Transformer, adding early stopping is very straightforward with the tf.keras.callbacks.EarlyStopping callback. It takes the name of the metric to monitor and the number of epochs after which training will stop if there is no improvement.
early_stopper = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
Here early_stopper is the callback that can be used with model.fit.
model.fit(trainloader, epochs=10, validation_data=validloader, callbacks=[early_stopper])
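For context, here's a minimal sketch of what the surrounding fine-tuning setup might look like, assuming a binary sequence-classification task. The checkpoint, learning rate, and the trainloader/validloader datasets are illustrative assumptions, not part of the original snippet.

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

# Illustrative assumptions: trainloader and validloader are tf.data.Dataset
# objects yielding (tokenized inputs, labels); any checkpoint would work here.
model = TFAutoModelForSequenceClassification.from_pretrained(
    'distilbert-base-uncased', num_labels=2
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)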
Observations
- The model quickly overfitted on the training dataset, which is evident from the validation loss.
- Training the model with early stopping leads to early termination of the training process. This, in turn, saves computational cost and time.
- Since the best instance of the model (lowest validation loss) was saved by the EarlyStopping callback, the resulting test accuracy indicates a more generalized model.
Native PyTorch
Native PyTorch does not have an off-the-shelf early stopping method. But if you are fine-tuning your HuggingFace Transformer using native PyTorch, here's a GitHub Gist that provides a working early stopping hook.
import numpy as np

class EarlyStopping(object):
    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # With zero patience, never report an improvement and never stop.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record the latest metric; return True when training should stop."""
        if self.best is None:
            self.best = metrics
            return False

        if np.isnan(metrics):
            return True

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print('terminating because of early stopping!')
            return True

        return False

    def _init_is_better(self, mode, min_delta, percentage):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        # min_delta is either an absolute margin or a percentage of the best value.
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (best * min_delta / 100)
es = EarlyStopping(patience=5)
num_epochs = 100

for epoch in range(num_epochs):
    train_one_epoch(model, data_loader)    # train the model for one epoch
    metric = eval(model, data_loader_dev)  # evaluation on the dev set
    if es.step(metric):
        break  # the early stopping criterion is met, we can stop now
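One difference from Keras's restore_best_weights=True: this hook only stops training, it doesn't keep the best checkpoint around. Here's a minimal sketch of adding that on top of the same loop; keeping the state_dict in memory with copy.deepcopy is an assumption, and you could equally torch.save it to disk.

import copy

es = EarlyStopping(patience=5)
best_state = None

for epoch in range(num_epochs):
    train_one_epoch(model, data_loader)
    metric = eval(model, data_loader_dev)
    # Snapshot the weights whenever the monitored metric improves.
    if es.best is None or es.is_better(metric, es.best):
        best_state = copy.deepcopy(model.state_dict())
    if es.step(metric):
        break

model.load_state_dict(best_state)  # restore the best checkpoint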
I highly recommend reorganizing your PyTorch code using PyTorch Lightning, which provides early stopping and many other techniques off the shelf. A sketch of what that looks like follows.
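Assuming a LightningModule (here a hypothetical LitTransformer) that logs "val_loss" during validation, wiring up early stopping in Lightning looks roughly like this:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# LitTransformer is a hypothetical LightningModule that logs "val_loss"
# in its validation_step via self.log("val_loss", loss).
early_stop = EarlyStopping(monitor='val_loss', patience=5, mode='min')
trainer = Trainer(max_epochs=100, callbacks=[early_stop])
trainer.fit(LitTransformer(), train_dataloaders=trainloader, val_dataloaders=validloader)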