Visualize Predictions over Time
Log, explore, and dynamically query model predictions across training epochs. Made by Stacey Svetlichnaya using Weights & Biases
log metrics, images, text, etc. to a wandb.Table() during model training or evaluation
view, sort, filter, group, join, interactively query, and otherwise explore these tables
compare model predictions or results over time: dynamically across different epochs or validation steps
track changes at the level of predictions on specific examples in your dataset at a specific step during training, and aggregate over these dynamically
interactively analyze the predictions to understand patterns of errors and opportunities for improvement, within the same model over time or across different models
In this guide, I train a basic convnet on MNIST using PyTorch and explain how to log relevant fields for a prediction to a wandb.Table(), organize tables across time steps, and interact with the resulting tables to analyze model performance over time.
Compare a new model to the baseline, filter out correct predictions, and group by guess to see the patterns of misclassifications with visual examples for each label.
How to log a wandb.Table
Create a table
Add each prediction
Sync the artifact
Browse logged artifacts
Open the table
Histograms in tables
Compare model predictions
Compare across steps
Compare different models
1. How to log a wandb.Table
1.1 Create a Table: for each meaningful step
To compare model predictions over the course of training, decide what a meaningful step is in the context of your project. In the example Colab, one epoch consists of one training step (one pass through the MNIST training data) and one validation step (one pass through the MNIST test data to measure performance on unseen data). I log a wandb.Table() of predictions at the end of every validation step. For longer training runs (say, 200 epochs), you may only want to run validation and log predictions every 10 or 20 epochs. For finer detail, you could instead log more frequently, say by including the training predictions alongside the validation predictions.
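The cadence logic itself is simple; a minimal sketch (the `should_log_predictions` helper and its argument names are my own for illustration, not part of the wandb API):

```python
def should_log_predictions(epoch, total_epochs, every_n=10):
    # log on a fixed cadence, and always capture the final epoch
    return epoch % every_n == 0 or epoch == total_epochs

# for a 200-epoch run, this logs a predictions table 20 times
logged_epochs = [e for e in range(1, 201) if should_log_predictions(e, 200)]
```

Inside the training loop, you would guard the validation-and-logging step with this check.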
1.2 Define columns: what to log for each prediction row
When you create a wandb.Table, specify the columns you'll want to see for each row or prediction. These columns could be of many types (numeric, text, boolean, image, video, audio, etc). Some useful ones to consider are:
the image pixels: what is the model looking at when it makes the prediction?
the model's predicted label and the ground truth label: was the model correct?
the image id in your dataset: track a single example and compare performance across multiple model versions. Note that in most cases you will need some unique identifier as a join key across different Tables. If your dataset doesn't have meaningful ids, this could simply be the index of the image in a fixed dataset ordering.
confidence scores for all classes: this lets you visualize the distribution of predictions, generate a confusion matrix, and look for patterns of errors
any other metrics about the prediction or image that can't be derived from existing columns: it's easy to remove columns if you don't end up needing them :)
You can dynamically add new columns with derived values after logging the Table. For example, you could sum the confidence scores for the digits 0, 6, 8, and 9 to evaluate performance on "closed loops" of handwriting in MNIST and use this "score_loops" column to rank, filter, and compare predictions without needing to rerun your code.
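In plain Python, that derived value amounts to summing four of the logged columns; a toy sketch of the computation (the row dict here is fabricated for illustration; in the Table UI you would enter an expression like x["score_0"] + x["score_6"] + x["score_8"] + x["score_9"] instead of writing code):

```python
LOOP_DIGITS = (0, 6, 8, 9)  # digits whose handwritten forms contain closed loops

def score_loops(row):
    # sum the model's confidence scores for the closed-loop digits
    return sum(row["score_" + str(d)] for d in LOOP_DIGITS)

# a fabricated prediction row: uniform 0.1 everywhere except a confident 8
row = {"score_" + str(d): 0.1 for d in range(10)}
row["score_8"] = 0.5
```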
Sample code to create a wandb.Table
# create a wandb Artifact for each meaningful step
# I use the unique wandb run id to organize my artifacts
test_data_at = wandb.Artifact("test_samples_" + str(wandb.run.id), type="predictions")

# create a wandb.Table() with columns for the id, image pixels,
# guess (model's prediction), the ground truth, and confidence scores for all labels
columns = ["id", "image", "guess", "truth"]
for digit in range(10):
    columns.append("score_" + str(digit))
test_table = wandb.Table(columns=columns)
1.3 Add each prediction as a row
Add each prediction as a row to the table (likely inside your validation loop). In the example Colab, I keep the number of images/rows logged small for simplicity. You can log up to 50,000 rows per table, and further analysis operations, such as joining predictions from different models (and different tables) by the same id, efficiently reference the data instead of creating new copies.
test_table.add_data(img_id, wandb.Image(img), guess, truth, *scores)
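Here is a sketch of how the arguments to add_data might be produced inside the validation loop. It assumes the model returns raw logits; the id, truth, and logits below are fabricated stand-ins for a real forward pass, and the wandb call is shown as a comment:

```python
import math

def softmax(logits):
    # turn raw model outputs into confidence scores that sum to 1
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# stand-in for one example from the validation loader
img_id, truth = 17, 7
logits = [0.1, -0.3, 0.4, 0.2, -1.0, 0.0, -0.5, 2.5, 0.3, 1.1]

scores = softmax(logits)           # one confidence score per digit
guess = scores.index(max(scores))  # predicted label = argmax
# test_table.add_data(img_id, wandb.Image(img), guess, truth, *scores)
```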
1.4 Sync the artifact
Once the table for a given step is complete, add it to the artifact and log it to wandb:
# log predictions table to wandb, giving it a name
test_data_at.add(test_table, "predictions")
wandb.run.log_artifact(test_data_at)
2. Visualize tables
Once you've logged Tables, you can explore them interactively.
View the images, predicted label (guess), ground truth, and all confidence scores
2.1 Browse logged artifacts
To view the wandb.Table logged for a particular experiment run, open the run page and go to its "Artifacts" tab.
You will see a list of the input and output artifacts for the run. In this case, I log one artifact version (v0 through v4) for each of 5 validation steps, evaluating the model's performance after a matching number of training epochs (v0 after one epoch, v1 after two epochs, etc).
2.2 Open the table
Select one of these versions to see more information about the artifact version. You can also annotate each version with an alias (short tag) or longer notes. Select a table ("predictions" below, highlighted in green) to see the visual Table.
2.3 Table operations
Each prediction is logged as a row with the columns you've specified (see section 1.2 above). You can explore a Table interactively; try taking some of the actions below in this interactive example.
filter: where did the model guess wrong? Use the Filter button to enter an expression like x["guess"] != x["truth"] to see only the incorrect predictions.
sort on a column: what are the highest confidence errors for a particular class? Click on the three dots in a column header to sort by that column. Focus on a particular case, say 4s that look like 9s, with an expression like x["guess"] = 9 and x["truth"] = 4, then sort by "score_9" to see the most confusing fours (first image below).
group on a column: are there patterns in what the model gets wrong, or most frequent types of mistakes? After filtering out any correct guesses (x["guess"] != x["truth"] ), group by "guess" to scroll through examples organized by the predicted class.
add columns to see derived information: after grouping by "guess", add a column via the three dot menu in the header and edit it to show x["image"].count. Now you can sort by this derived column to see the top confused classes (second image below). You can also add a duplicate column as a temporary workaround for reordering the columns, so you can see the most relevant info side-by-side.
remove columns if you don't need them
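The filter-then-group-then-count workflow above has a direct plain-Python analogue, sketched here on fabricated rows (in the Table UI these steps are point-and-click rather than code):

```python
from collections import Counter

# fabricated prediction rows with a guess and a ground-truth label
rows = [
    {"guess": 9, "truth": 4},
    {"guess": 9, "truth": 4},
    {"guess": 3, "truth": 8},
    {"guess": 7, "truth": 7},  # a correct prediction, filtered out below
]

# filter: x["guess"] != x["truth"]
wrong = [r for r in rows if r["guess"] != r["truth"]]
# group by "guess" and count rows per group (the x["image"].count column)
by_guess = Counter(r["guess"] for r in wrong)
top_wrong_guess, n = by_guess.most_common(1)[0]
```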
Filter to 4s that look like 9s, add a "score_9" column closer to the images, sort by "score_9"
9s are the most frequent wrong guess for this model. The histogram of true labels shows that 4s are the most frequent true class confused for 9s, 5s the most frequent true class confused for 8s, 3s for 7s, 7s for 2s, and so on.
2.4 Histograms in tables
When a grouping operation aggregates numerical values in a Table, you'll see a histogram of their distribution by default. This is useful for visualizing patterns at a higher level and noticing correlations in your data. For example, after filtering out correct predictions, we see that this model is still most confident along the diagonal (broader distributions, higher scores) and generally has low confidence (peaks at zero) in the off-diagonal cells. This is expected because the model's guess is the maximum confidence score, but deviations from this—small blips of higher confidence for a different label—indicate areas or pairings of confusion for the model.
This model only guessed 4 incorrectly once (last row).
Zoomed in, without pixels: the truth distribution (second column) is most useful here, and the off-diagonal cells show where the model is confusing digits, e.g. the small high-confidence bar for guess 1, score 0 (second row, third column).
3. Compare model predictions
We can use a model prediction Table to deeply explore a snapshot of a model at a particular time. How can we evaluate model performance over time and understand how to improve it? All of these visualizations and table operations can be used in a comparative context.
3.1 Compare across steps
Select two tables from two different time steps of the same model. Pick the first time step by going to the "Artifacts" section for a given run, and navigating to the artifact version. Here I'm looking at the first validation step logged for my "baseline" model. You may notice my earlier prediction artifacts have run ids in their names (like "test_samples_3h9o3dsk")—this is a good default pattern to make sure I can trace which runs produced which artifact versions. Once I know which ones are important and which ones I want to keep, I can rename these from the UI for clarity and take notes about the changes I made.
Click on a logged Table ("predictions" here) to view the model predictions for this step. Next, select a different step for comparison: e.g. the final validation epoch for this model, "v4 latest", and click "Compare":
Your initial view will look like this:
This is useful for advanced queries, but I prefer the side-by-side rendering: change "WbTableFile" in the top dropdown to "SplitPanel → WbTableFile".
View side-by-side: Blue left, yellow right
Now you can compare model predictions across time:
observe shifts in predicted labels and confidence scores for the same examples over additional training epochs
filter and group both tables in tandem to view changes in the number of mistakes by class, the confusion matrix, most confused pair of labels, hardest negatives, confidence score distribution by label, and much more
add columns of derived information—say precision and recall—without needing to rerun any of your training
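Per-class precision and recall, for instance, are derivable from just the guess and truth columns; a minimal sketch in plain Python on fabricated labels (a derived-column expression in the Table UI would compute the same quantities):

```python
def per_class_precision_recall(guesses, truths, label):
    # precision: of all predictions of `label`, how many were correct?
    # recall: of all true `label` examples, how many did the model find?
    tp = sum(1 for g, t in zip(guesses, truths) if g == label and t == label)
    predicted = sum(1 for g in guesses if g == label)
    actual = sum(1 for t in truths if t == label)
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    return precision, recall

# fabricated predictions for illustration
guesses = [9, 9, 4, 9]
truths  = [9, 4, 4, 9]
p, r = per_class_precision_recall(guesses, truths, 9)
```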
3.2 Compare different models
Follow the same steps to compare results across different models. Here I compare predictions between the baseline and a new model variant, 2x_layers_2x_lr, in which I've doubled the size of my first convolutional layer from 32 to 64, the second from 128 to 256, and the learning rate from 0.001 to 0.002. In this live example, I use the "split panel" view and filter down to the incorrect predictions after just one epoch of training. Notice that the baseline on the left gets 152 examples wrong, while the double variant gets 238 wrong.
If I then group by "guess", add a column to count the number of items in each group, and sort by that column, I see that the baseline most frequently mistakes 8s for 3s, while the double variant most frequently mistakes 4s for 9s:
If I repeat this analysis comparing the two models' predictions after 5 epochs of training (live example), the double variant matches and slightly beats the baseline (90 vs. 95 incorrect), and for both models 5 and 7 become the most frequent wrong guesses.
This is a toy example of model comparison, but it illustrates the ease, flexibility, and depth of the exploratory analysis you can do with Tables—without rerunning any of your code, writing new one-off scripts, generating new charts, etc.
Table comparison operations and views
All of the single table view operations apply to the comparison view: filter, sort, group, and add or remove columns as described above. In the side-by-side "SplitPanel" view, any operations apply to both sides. In the default flattened "WbTableFile" view, you can refer to the individual tables using the alias 0 (blue highlight) or 1 (yellow highlight). In the flattened view, you can compare scores and other metrics in stacked histograms. Note that you need to specify a column on which to join for this view, and typically this is the example id column.
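The role of the id join key can be pictured with two dictionaries of per-example guesses keyed by id; a toy sketch of the idea, not the wandb join implementation:

```python
# fabricated per-example predictions from two runs, keyed by example id
baseline = {0: 8, 1: 3, 2: 5, 3: 9}
variant  = {0: 9, 1: 3, 2: 5, 3: 4}

# join on the shared ids, then keep examples where the models disagree
shared_ids = baseline.keys() & variant.keys()
disagreements = sorted(i for i in shared_ids if baseline[i] != variant[i])
```

Without a stable id per example, rows from the two tables could not be lined up for this kind of per-example comparison.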
Q & A
The Tables feature is brand new, and we'd absolutely love your questions and any feedback in the Comments section below as you try it! Check out more documentation and a growing list of example Colabs here.