
How to handle training divergences with our new rewind feature

Easily handle common LLM training failures with Weights & Biases rewind!
Created on July 10 | Last edited on July 16
Weights & Biases is an essential tool for researchers and developers to track their machine learning experiments. It offers a comprehensive suite of features for monitoring, comparing, and managing model training runs. One of the recent additions to the Weights & Biases feature set is the "rewind a run" capability, which lets you seamlessly correct or modify the history of a run. This tutorial will explore the new feature, its practical applications, and how it differs from the "resume" functionality.



As of this writing, the ability to rewind a run is in private preview. Contact W&B Support at support@wandb.com to request access to this feature.
💡

Why use the rewind feature?

During model training, various issues such as loss spikes, divergences, or unexpected behaviors can occur. These anomalies can stem from several factors, including problematic data batches and numerical instabilities.
Loss spikes are sudden increases in training loss that may or may not recover quickly. Divergences occur when the model's training process deviates significantly, often requiring intervention. Resumption issues arise from restarting training, such as incorrect data loader states or inconsistent random number generators (RNGs).
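Before turning to rewind itself, it helps to be able to notice these events automatically. Below is a minimal sketch, not part of the W&B API, of one way to flag a loss spike by comparing each new value against a rolling average; the window size and threshold factor are arbitrary choices:

from collections import deque

def make_spike_detector(window=50, factor=2.0):
    """Return a callable that flags a loss value far above the recent average."""
    history = deque(maxlen=window)

    def is_spike(loss):
        # Only flag once we have a full window of recent losses to compare against.
        spike = len(history) == window and loss > factor * (sum(history) / window)
        history.append(loss)
        return spike

    return is_spike

# Usage inside a training loop (hypothetical):
# detect = make_spike_detector()
# if detect(loss):
#     ...  # stop training and rewind the run to a step before the spike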
The "rewind a run" feature is designed to address these issues by allowing users to reset the state of a run to a specific step, thereby preserving the original data and maintaining a consistent run ID.
Presently, rewind does not support log rewinds (logs are reset in the new run segment) or system metrics rewinds (only new system metrics after the rewind point are logged). Additionally, artifacts remain associated with the original run that produced them. For the latest updates, refer to the official docs.

Basic usage of rewind

The "rewind a run" feature allows users to reset the state of a run to a specific step, preserving the original data and maintaining a consistent run ID. This can be particularly useful for correcting errors or modifying the trajectory of an experiment without starting from scratch.
To use this feature, ensure you have wandb Python SDK version >= 0.17.1. The resume_from parameter in wandb.init() allows you to specify the step from which you want to rewind your run. Simply pass the run ID followed by the step, in the form "{run_id}?_step={step}", as the resume_from value. Here’s a basic example:
import wandb
import math

# Initialize the first run and log some metrics
run1 = wandb.init(project="rewind_article")

for i in range(300):
    if i < 250:
        run1.log({"metric": 100})  # Log constant metric 100 up to step 250
    else:
        # Introduce the spiky behavior starting from step 250
        subtle_spike = i + (2 * math.sin(i / 3.0))  # Apply a subtle spiky pattern
        run1.log({"metric": subtle_spike, "step": i})  # Log spiky metrics

run1.finish()

# Rewind the first run to step 249 (before the spikes) and continue logging from there
run2 = wandb.init(project="rewind_article", resume_from=f"{run1.id}?_step=249")

# Continue logging in the new run
for i in range(249, 300):
    run2.log({"metric": 300, "step": i})  # Log constant metric 300 from step 249

run2.finish()
In this example, the first run logs normal metrics up to step 250 and then introduces spiky behavior. The second run rewinds to step 249, just before the spikes start, and continues logging normal metrics from that point. This demonstrates how the rewind feature can be used to correct a run by rolling back to a step before an issue occurred and resuming from there.
Here's the chart showing just 'run1' before rewinding:

[Chart: 'metric' vs. step for run wild-feather-4]

After rewinding, we will see the following chart (with a constant metric value after the rewind point):

[Chart: 'metric' vs. step for run major-silence-3]


Rewind vs. resume

You may be familiar with the existing "resume" feature in Weights & Biases, but the "rewind a run" feature is distinct from it. Here’s a quick comparison to help you understand when you might want to use one over the other:

Rewind

The purpose of rewinding a run is to correct or modify its history from a specific step onward, allowing new data to be logged from that point. This is ideal for situations where you need to adjust the course of your experiment based on new findings, or correct mistakes, without starting over.
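For comparison, a rewind re-initializes an existing run from a chosen step. A minimal sketch (the run ID and step below are placeholders):

import wandb

# Rewind an existing run to step 100 and continue logging from there;
# "abc123xy" is a placeholder for a real run ID.
run = wandb.init(project="rewind_article", resume_from="abc123xy?_step=100")
run.log({"loss": 0.31})
run.finish()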

Resume

The purpose of resuming a run is to continue a run from where it left off after a stop or crash. This is perfect for long-running experiments that may be interrupted due to system failures or planned pauses, ensuring that the run continues seamlessly from the last checkpoint.
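For reference, a typical resume after a crash or interruption looks something like this, a minimal sketch using the documented id and resume arguments (the run ID is a placeholder):

import wandb

# Reattach to an interrupted run by its ID and continue where it left off;
# "abc123xy" is a placeholder for a real run ID.
run = wandb.init(project="rewind_article", id="abc123xy", resume="must")
run.log({"loss": 0.42})  # logging continues after the run's previous last step
run.finish()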
However, if you would like to resume training from a step before the end of the run you are continuing from, it's best to use rewind instead of resume. To understand this at a deeper level, I'll share a simple example of the core limitation of resume (i.e., what happens when it is used this way), which will hopefully give more context as to why rewind is relevant.
The core limitation with resume is that it prevents you from logging properly at steps before the end point of the previous run. Here is a script that demonstrates this:
import wandb

# Initialize W&B for the first run
r1 = wandb.init(project="resume_example_simple", name="initial_run")

# Simulate training for 10 steps and log constant loss values
for step in range(10):
    loss = 0.5  # Constant loss value
    wandb.log({"loss": loss, "step": step})

r1.finish()

# Roll back to step 5 and resume training
# Initialize W&B for the resumed run
wandb.init(project="resume_example_simple", name="resumed_run", id=r1.id, resume="must")

# Simulate training from step 5 onwards and log new constant loss values
for step in range(5, 10):
    loss = 0.3  # Different constant loss value after rollback
    wandb.log({"loss": loss, "step": step})

wandb.finish()
The chart below shows that Weights & Biases picks up at step 10 instead of step 5, which was the step specified in the second run. Correct usage of 'resume' would require picking up at step 11, not step 5. Luckily, we'll see later in the article how we can correct this.

[Chart: 'loss' vs. step for run resumed_run]

Based on our previous example, we'll make a new version using the "rewind" feature which solves the above limitation. Here we should see the chart pick up at step 5.

import wandb

# Initialize W&B for the first run
r1 = wandb.init(project="rewind_article")

# Simulate training and log constant loss values for steps 1 through 9
for step in range(1, 10):
    loss = 0.5  # Constant loss value
    wandb.log({"loss": loss, "step": step})

r1.finish()

# Roll back to step 5 by rewinding the run to step 4
# Initialize W&B for the rewound run
wandb.init(project="rewind_article", resume_from=f"{r1.id}?_step=4")

# Simulate training from step 5 onwards and log new constant loss values
for step in range(5, 10):
    loss = 0.3  # Different constant loss value after rollback
    wandb.log({"loss": loss, "step": step})

wandb.finish()
Below, we can see that the chart rewinds properly to step 5, which is the expected behavior.

[Chart: 'loss' vs. step for run giddy-glade-11]


Advanced usage with HuggingFace

To illustrate the rewind feature in a more realistic scenario, let's integrate it with Hugging Face’s Transformers library. This example involves training a language model and using the rewind functionality to handle training instabilities. Specifically, we will add noise to our model weights at a certain point during training to simulate a loss spike. This should demonstrate how the rewind feature can be used to recover from such instabilities and continue training effectively.
I’ll share the full script below, and then we'll cover each component individually:
import os
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    AutoModelForCausalLM,
    TrainerCallback,
)
from trl import SFTTrainer
import wandb
import json

# Initialize the Weights & Biases run
run = wandb.init(project="rewind_article")
# Seed for reproducibility
torch.manual_seed(42)

# Configuration
model_name = "sshleifer/tiny-gpt2"
max_seq_length = 1024
output_dir = "./results"
num_train_epochs = 1
per_device_train_batch_size = 2
per_device_eval_batch_size = 2
gradient_accumulation_steps = 16
learning_rate = 5e-6
logging_steps = 10
save_steps = 10
eval_steps = 10
warmup_steps = 0
save_total_limit = 5  # keep at most 5 checkpoints
train_file_path = './final_ds/train_completions.jsonl'
val_file_path = './final_ds/test_completions.jsonl'
loss_exploded = False
explosion_step = 0

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

def load_jsonl_data(file_path):
    with open(file_path, 'r') as f:
        data = [json.loads(line) for line in f]
    return data

train_data = load_jsonl_data(train_file_path)[:10000]
val_data = load_jsonl_data(val_file_path)[:200]

# Convert data to Hugging Face Dataset format
data_list_train = [dict(d) for d in train_data]
data_list_val = [dict(d) for d in val_data]

train_dataset = Dataset.from_list(data_list_train)
val_dataset = Dataset.from_list(data_list_val)

# Filter examples based on max_seq_length
def filter_examples(example):
    combined_text = example['input']
    tokens = tokenizer.encode(combined_text)
    return len(tokens) < max_seq_length

train_dataset = train_dataset.filter(filter_examples)
val_dataset = val_dataset.filter(filter_examples)

# Format chat template
def format_chat_template(example):
    return {'text': f"\n{example['input']}\n\n{example['model_name']}\n\n{example['output']}\n"}

# Format and prepare datasets
train_dataset = train_dataset.map(format_chat_template)
val_dataset = val_dataset.map(format_chat_template)

print(f"Number of examples in the train set: {len(train_dataset)}")
print(f"Number of examples in the validation set: {len(val_dataset)}")

def create_and_prepare_model():
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

model, tokenizer = create_and_prepare_model()

training_arguments = TrainingArguments(
    num_train_epochs=num_train_epochs,
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_total_limit=save_total_limit,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=False,
    bf16=False,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    warmup_steps=warmup_steps,
    lr_scheduler_type="linear",
    report_to='none',
    save_steps=save_steps,
    save_strategy="steps",
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

class NoiseInjectionCallback(TrainerCallback):
    def __init__(self, noise_start_step, noise_std):
        self.noise_start_step = noise_start_step
        self.noise_std = noise_std

    def on_step_end(self, args, state, control, **kwargs):
        global loss_exploded
        print(state.global_step)
        if state.global_step == self.noise_start_step and not loss_exploded:
            print(f"Injecting noise to model weights at step {state.global_step}")
            model = kwargs['model']
            with torch.no_grad():
                for param in model.parameters():
                    noise = torch.randn_like(param) * self.noise_std
                    param.add_(noise)

class WandbLoggingCallback(TrainerCallback):
    def __init__(self):
        self.initial_loss = None

    def on_log(self, args, state, control, **kwargs):
        global loss_exploded, explosion_step, run
        logs = kwargs.get('logs', {})
        if 'loss' in logs:
            current_loss = logs['loss']
            run.log({'train_loss': current_loss}, step=state.global_step)
            print("step", state.global_step, "loss", current_loss)
            if self.initial_loss is None:
                self.initial_loss = current_loss
            if current_loss > 2 * self.initial_loss:  # loss more than doubled
                print(f"Loss more than doubled: {current_loss}. Stopping training.")
                control.should_training_stop = True
                loss_exploded = True
                explosion_step = int(state.global_step - 2 * logging_steps)  # go back two logging intervals

wandb_logging_callback = WandbLoggingCallback()
noise_callback = NoiseInjectionCallback(noise_start_step=40, noise_std=4.0)

def reload_model_and_resume():
    global model, run, explosion_step
    print(f"Explosion step: {explosion_step}")
    rwd_step = str(int(explosion_step))
    # Construct the directory name of the checkpoint saved at the rewind step
    lowest_checkpoint_dir = f'checkpoint-{rwd_step}'
    model_dir = os.path.join(output_dir, lowest_checkpoint_dir)
    print("Loading model from", model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    run = wandb.init(project="rewind_article", resume_from=f"{run.id}?_step={rwd_step}")
    return model, model_dir


model_pth = ""
limit = 0
while True:
    limit += 1
    if loss_exploded:
        # Reload the model from the checkpoint at the rewind step and rewind the W&B run
        model, model_pth = reload_model_and_resume()
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        callbacks=[noise_callback, wandb_logging_callback],  # Add both callbacks
        packing=True,
    )

    # On the second pass, resume from the saved checkpoint to restore optimizer state
    trainer.train(resume_from_checkpoint=False if model_pth == "" else model_pth)
    run.finish()
    if limit == 2:
        break

Overview

The script begins by setting up the environment and configurations necessary for training the GPT-2 model. It initializes the W&B run, sets a random seed for reproducibility, and configures model parameters such as the model name, maximum sequence length, and training arguments.

Loading and preparing the dataset

Next, the script loads the training and validation datasets. It uses the HuggingFace datasets library to load the data and converts it to the required format. The dataset is filtered to ensure that the token length is within the maximum sequence length, and a specific chat format template is applied to each example.
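The article doesn't reproduce the dataset files, but based on the keys the script reads (input, model_name, and output), each line of train_completions.jsonl is assumed to look roughly like the record below. This is a hypothetical example for illustration only:

import json

# Hypothetical example of one JSONL record; only the three keys the script
# uses (input, model_name, output) matter here.
record = {
    "input": "What causes a loss spike during training?",
    "model_name": "assistant",
    "output": "Often a bad batch of data or a numerical instability.",
}
print(json.dumps(record))  # one line of train_completions.jsonl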

Model initialization

The script initializes the GPT-2 model and tokenizer. The model is configured with appropriate training arguments, including the number of epochs, batch size, learning rate, and logging steps. These configurations are essential for controlling the training process.

Custom callbacks for training

To simulate and handle training instabilities, the script defines two custom callbacks:
1. Noise injection callback: This callback injects noise into the model weights at a specific training step to simulate an exploding-loss scenario. This helps in testing the rewind feature by deliberately causing a training instability. In our script, we inject noise at step 40, which causes the loss to increase dramatically.
2. W&B logging callback: This callback logs training metrics to W&B. It also monitors the training loss and stops training if the loss more than doubles relative to its initial value, simulating an exploding-loss condition. The rewind itself is then performed by the following function:
def reload_model_and_resume():
    global model, run, explosion_step
    print(f"Explosion step: {explosion_step}")
    rwd_step = str(int(explosion_step))
    lowest_checkpoint_dir = f'checkpoint-{rwd_step}'
    model_dir = os.path.join(output_dir, lowest_checkpoint_dir)
    print("Loading model from", model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    run = wandb.init(project="rewind_article", resume_from=f"{run.id}?_step={rwd_step}")
    return model, model_dir

Handling rewind and resume

The core functionality of the script involves handling the rewind and resume of training runs. When an exploding loss is detected, the script reloads the model from the last saved checkpoint and resumes training. The resume_from_checkpoint parameter is used to load the optimizer state and other training parameters, ensuring a seamless continuation of training from the last checkpoint.
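Condensed, the recovery path pairs a Hugging Face checkpoint reload with a W&B rewind to the same step. The sketch below illustrates the idea; the checkpoint path and step value are illustrative, and it assumes the trainer and run objects from the full script above already exist:

import wandb
from transformers import AutoModelForCausalLM

# Assumes a checkpoint was saved at step 40 by the Trainer's save_steps schedule.
checkpoint_dir = "./results/checkpoint-40"  # illustrative path

# Reload the model weights from the checkpoint taken before the spike.
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)

# Rewind the W&B run to the same step so new logs line up with the checkpoint.
run = wandb.init(project="rewind_article", resume_from=f"{run.id}?_step=40")

# resume_from_checkpoint also restores the optimizer and scheduler state.
trainer.train(resume_from_checkpoint=checkpoint_dir)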

Training loop

The training loop continuously checks for the exploding-loss condition. If it is detected, the model is reloaded from the checkpoint and training is resumed. The script uses the SFTTrainer from Hugging Face's TRL library to manage the training process, incorporating the defined callbacks for noise injection and W&B logging. Because we rewound the run, the final loss curve is smooth and free of spikes: the spike was detected and the run was rolled back to a point before the large increase in loss.

[Chart: 'train_loss' vs. step for run skilled-jazz-62]


Conclusion

The "rewind a run" feature in Weights & Biases offers a powerful tool for managing and correcting machine learning experiments. By allowing users to reset a run to a specific step, this feature helps address common issues such as loss spikes, divergences, and unexpected behaviors during training. It enables researchers to fine-tune their experiments without losing valuable data, ensuring a more robust and efficient workflow.
One of the key benefits of the rewind feature is its ability to handle training instabilities. By incorporating custom callbacks, you can detect scenarios like an exploding loss and use the rewind functionality to recover from them.
The "rewind a run" feature in Weights & Biases is a great tool for experiment management, especially in applications involving LLM's. As machine learning workflows continue to evolve, features like rewind will play a pivotal role in enhancing the efficiency and effectiveness of experimental research. Feel free to check out the project repo here.