Custom Bar Charts
Method: wandb.plot.bar()
Log a custom bar chart—a list of labeled values as bars—natively in a few lines:
import wandb

# labels and values are equal-length lists; wandb.init() must already have been called
data = [[label, val] for (label, val) in zip(labels, values)]
table = wandb.Table(data=data, columns=["label", "value"])
wandb.log({"my_bar_chart_id": wandb.plot.bar(table, "label", "value",
                                             title="Custom Bar Chart")})
You can use this to log arbitrary bar charts. The labels and values lists must have exactly the same length, so that each data point has both a label and a value. In the CNN example below, you can show or hide individual runs in the bar chart by clicking the "eye" icon to the left of each run name under "Toy CNN variants".
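Python's `zip` stops at the shorter of its inputs, so mismatched lists do not raise an error: the extra labels or values are silently dropped from the chart. A minimal sketch of a guard, using a hypothetical helper named `build_bar_rows` (not part of the wandb API), checks the lengths before building the rows you would pass to `wandb.Table`:

```python
def build_bar_rows(labels, values):
    """Pair each label with its value, refusing mismatched inputs.

    build_bar_rows is a hypothetical helper, not part of the wandb API.
    """
    labels = list(labels)
    values = list(values)
    if len(labels) != len(values):
        # zip() alone would silently truncate here and hide the mismatch
        raise ValueError(f"got {len(labels)} labels but {len(values)} values")
    # One row per bar: [label, value]
    return [[label, val] for label, val in zip(labels, values)]


rows = build_bar_rows(["cat", "dog"], [0.9, 0.7])
print(rows)  # [['cat', 0.9], ['dog', 0.7]]
```

The returned list can then be passed as `data=` to `wandb.Table` with `columns=["label", "value"]`. On Python 3.10 and later, `zip(labels, values, strict=True)` gives the same protection by raising `ValueError` on a length mismatch.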
Hovering over a bar shows more information about it, and you can change what is shown by editing the chart's Vega spec. To see a chart's full Vega spec, hover over its top right corner and click the "eye" icon.
Basic usage
I fine-tune a CNN to predict 10 classes of living things: plants, birds, insects, etc. I want to plot the final precision for each label in my validation step. I compute the precision using sklearn.metrics.precision_score with average=None, which returns val_precision, an array of 10 precision values, one per class. I then create a bar for each label:
data = [[name, prec] for (name, prec) in zip(self.class_names, val_precision)]
table = wandb.Table(data=data, columns=["class_name", "precision"])
wandb.log({"my_bar_chart_id" : wandb.plot.bar(table, "class_name",
"precision", title="Per Class Precision")})
Steps to follow:
- Create a data object: collect the (label, value) pairs as a 2D list or array, where each row is a bar, one column is its label, and the other column is its value. The default bar chart assumes two dimensions (two columns), but you can pass in more data and customize the plot further if you wish (e.g. use a third column to give each bar a different color).
- Pass data to a wandb.Table() object, naming the columns in order so you can refer to them later.
- Pass the table object and the column names, in (label, value) order, to wandb.plot.bar() with an optional title. This creates your custom plot under the key my_bar_chart_id. To visualize multiple runs on the same plot, keep this plot key constant. The table itself is also logged in the "Media" section of your workspace, under my_bar_chart_id_table.
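The first two steps can be sketched in plain Python. This sketch assumes a hypothetical third column named `group` that a customized Vega spec could map to bar color; the function name `make_bar_table_data` is also hypothetical. The default chart only uses the two columns you name in `wandb.plot.bar()`, so the extra column matters only once you edit the spec.

```python
def make_bar_table_data(labels, values, groups=None):
    """Return (columns, rows) shaped for wandb.Table(data=rows, columns=columns).

    make_bar_table_data and the "group" column are hypothetical; only the
    label and value columns are read by the default bar chart.
    """
    if len(labels) != len(values):
        raise ValueError("labels and values must have the same length")
    columns = ["label", "value"]
    rows = [[label, value] for label, value in zip(labels, values)]
    if groups is not None:
        if len(groups) != len(labels):
            raise ValueError("groups must have one entry per bar")
        columns.append("group")
        # Append the extra field to each row so row order matches column order
        rows = [row + [group] for row, group in zip(rows, groups)]
    return columns, rows


# Made-up values, for illustration only
columns, rows = make_bar_table_data(
    ["plants", "birds"], [0.81, 0.64], groups=["flora", "fauna"]
)
print(columns)  # ['label', 'value', 'group']
print(rows)     # [['plants', 0.81, 'flora'], ['birds', 0.64, 'fauna']]
```

With wandb installed, `wandb.Table(data=rows, columns=columns)` followed by `wandb.plot.bar(table, "label", "value")` would complete the third step.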
Customized usage
There are many ways to customize the bar chart using the Vega visualization grammar.
Here are some simple ones:
- Rename the axis titles for clarity: add "title": "Your Title" to the x and y fields under encoding.
- Change the orientation of the bars by swapping x and y.
- Change the stacking of the bars by setting stack to center or zero (instead of the overlapping bars shown by default).
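The three tweaks above can be sketched as edits to a spec held in a Python dict. This assumes a Vega-Lite-style spec with an `encoding` block whose `x` and `y` entries are objects, and a default layout with labels on `x` and values on `y`; the actual spec in your chart editor may be laid out differently, and `customize_bar_spec` is a hypothetical name.

```python
import copy


def customize_bar_spec(spec, x_title=None, y_title=None,
                       horizontal=False, stack=None):
    """Return a modified copy of a Vega-Lite-style bar chart spec.

    customize_bar_spec is hypothetical; it only touches spec["encoding"] and
    assumes labels start on the x channel and values on the y channel.
    """
    new_spec = copy.deepcopy(spec)  # leave the caller's spec untouched
    enc = new_spec["encoding"]
    if horizontal:
        # Swapping the x and y channels turns vertical bars horizontal
        enc["x"], enc["y"] = enc["y"], enc["x"]
    # Titles are applied after any swap, so they name the final axes
    if x_title is not None:
        enc["x"]["title"] = x_title
    if y_title is not None:
        enc["y"]["title"] = y_title
    if stack is not None:
        # Vega-Lite's stack property accepts values such as "zero" and "center"
        value_channel = "x" if horizontal else "y"
        enc[value_channel]["stack"] = stack
    return new_spec


base = {
    "mark": "bar",
    "encoding": {
        "x": {"field": "label", "type": "nominal"},
        "y": {"field": "value", "type": "quantitative"},
    },
}
custom = customize_bar_spec(base, x_title="Precision", y_title="Class",
                            horizontal=True, stack="zero")
print(custom["encoding"]["x"])
# {'field': 'value', 'type': 'quantitative', 'title': 'Precision', 'stack': 'zero'}
```

Working on a deep copy keeps the original spec available, so you can compare variants side by side before pasting one into the editor.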
See the full API reference for wandb.plot.bar().
P.S. Computing per-class precision for multi-class models
You can compute this whenever your code has access to:
- a model's predicted scores (val_predictions) on a set of examples
- the corresponding ground truth labels (ground_truth) for those examples
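from sklearn.metrics import precision_score
# ground_truth: one-hot labels, shape (n_examples, n_classes)
# val_predictions: predicted scores, same shape
ground_truth_class_ids = ground_truth.argmax(axis=1)
guessed_class_ids = val_predictions.argmax(axis=1)
# labels= keeps one score per class in class-id order, even for a class missing
# from both arrays, so the result lines up with the list of class names
val_precision = precision_score(ground_truth_class_ids, guessed_class_ids,
                                labels=list(range(ground_truth.shape[1])),
                                average=None)
# now you can log val_precision to a custom chart!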
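To make the per-class numbers concrete, here is a dependency-free sketch of per-class precision, TP / (TP + FP): for each class, the fraction of examples predicted as that class that truly belong to it. The helper names `argmax_rows` and `per_class_precision` are hypothetical. A class that is never predicted gets 0.0 here, which matches what scikit-learn reports for that case by default (it also emits a warning); when `labels` covers every class, the results should agree with `precision_score(..., average=None)`.

```python
def argmax_rows(matrix):
    """Index of the largest entry in each row (first one wins on ties, like ndarray.argmax)."""
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]


def per_class_precision(true_ids, pred_ids, num_classes):
    """Precision for each class id in range(num_classes).

    precision = TP / (TP + FP); a class that is never predicted gets 0.0.
    """
    true_pos = [0] * num_classes
    predicted = [0] * num_classes  # TP + FP: how often each class was guessed
    for t, p in zip(true_ids, pred_ids):
        predicted[p] += 1
        if t == p:
            true_pos[p] += 1
    return [tp / n if n else 0.0 for tp, n in zip(true_pos, predicted)]


# Toy inputs: one-hot ground truth and predicted scores for 3 classes
ground_truth = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
val_predictions = [[0.7, 0.2, 0.1], [0.3, 0.6, 0.1],
                   [0.5, 0.1, 0.4], [0.1, 0.8, 0.1]]
true_ids = argmax_rows(ground_truth)      # [0, 1, 2, 1]
pred_ids = argmax_rows(val_predictions)   # [0, 1, 0, 1]
print(per_class_precision(true_ids, pred_ids, 3))  # [0.5, 1.0, 0.0]
```

Class 0 is guessed twice but is correct only once (0.5), class 1 is guessed twice and is correct both times (1.0), and class 2 is never guessed (0.0).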