Skip to main content

创建混淆矩阵(Confusion Matrix)图

在Vega中从头开始构建混淆矩阵
Created on November 25|Last edited on July 13
本报告是作者Stacey Svetlichnaya所写的Confusion Matrix的翻译

将多分类混淆矩阵记录到W&B中

要在W&B中创建一个多分类混淆矩阵,请首先找到您的模型开发代码可以访问的同一组示例的预测标签和相应的真实标签(通常在验证步骤中)。 然后执行以下步骤:

  • 将它们传递给plot_confusion_matrix()Python函数(当前作为独立包装器提供,很快将添加到wandb API中)
  • 使用Confusion Matrix v0 Vega规范创建自定义图表(如下所示,很快将作为预设添加到wandb.plot)
  • 通过查询编辑器连接下拉菜单中的对应字段,并查看下面的图表,并根据需要进一步自定义!

在此简单示例中,我微调CNN以预测10种生物(植物、动物、昆虫)中的一种,同时更改训练时期epoch(Epretrain_epochs)和训练示例(NTnum_train)的数量。 在下面的运行设置中,您可以在每个运行旁边切换眼睛符号以显示/隐藏它。 您可以一目了然地看到每个模型的相对性能,并将鼠标悬停在不同的条形上可以查看确切的计数。

不出所料,示例/训练期太少的模型容易犯更多的错误(从最小的模型上的“ Aves”和“ Reptilia”的蓝色条,和第二个最小的模型上的“ Animalia”的红色条可以看出)。 随着训练期和例子数量的增加,模型倾向于做出更准确的预测(沿对角线更强)。 两栖类与爬行动物是这些(公认的噪音)模型中更常见的混淆类。




Vary num train and num epochs
9


步骤1:添加Python代码以绘制wandb.Table()

在验证步骤中,我可以访问所有验证示例的val_data和相应的val_labels,以及该模型可能的标签的完整列表:labels = [“Amphibia”,“ Animalia”,...“ Reptilia” ],这是整数类别标签0 = Amphibia, 1 = Animalia, ... 9 = Reptilia)。 引用我到目前为止在验证回调函数中训练的model,我回调:

val_predictions = model.predict(val_data)
ground_truth = val_labels.argmax(axis=1)
plot_confusion_matrix(ground_truth, val_predictions, labels)

其中plot_confusion_matrix()的定义如下。 您可以使用以下方法进一步自定义调用函数:

  • 调整true_labelspred_labels来减少要显示在矩阵中行或列的类
  • 调整normalize标志,以显示标准化计数(最大浮点数为1.0),而不是原始计数
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(y_true=None, y_pred=None, labels=None, true_labels=None,
                          pred_labels=None, normalize=False):
    """                   
    Computes the confusion matrix to evaluate the accuracy of a classification.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    cm = confusion_matrix(y_true, y_pred)
    if labels is None:
        classes = unique_labels(y_true, y_pred)
    else:
        classes = np.asarray(labels)
            
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        cm = np.around(cm, decimals=2)
        cm[np.isnan(cm)] = 0.0
            
    if true_labels is None:
        true_classes = classes
    else:
        true_label_indexes = np.in1d(classes, true_labels)
        true_classes = classes[true_label_indexes]
        cm = cm[true_label_indexes]
            
    if pred_labels is None:
        pred_classes = classes
    else:
        pred_label_indexes = np.in1d(classes, pred_labels)
        pred_classes = classes[pred_label_indexes]
        cm = cm[:, pred_label_indexes]
            
    data=[]
    count = 0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if labels is not None and (isinstance(pred_classes[i], int)
                                    or isinstance(pred_classes[0], np.integer)):
            pred_dict = labels[pred_classes[i]]
            true_dict = labels[true_classes[j]]
        else:
            pred_dict = pred_classes[i]
            true_dict = true_classes[j]
        data.append([pred_dict, true_dict, cm[i,j]])
        count+=1
    wandb.log({"confusion_matrix" : wandb.Table(
                columns=['Predicted', 'Actual', 'Count'],
                data=data)}


步骤2:为混淆矩阵创建一个自定义图表

W&B自定义图表使用 Vega(一种功能强大且灵活的可视化语言)编写。 您可以在线找到许多示例和演练,它可以帮助您从现有的最相似的预设开始,达到所需的自定义可视化。 您可以在我们的IDE中迭代一些细微的更改,从而在更改定义时渲染图表。这是此多分类混淆矩阵的完整Vega规范。

  • 从项目工作区或报告中,单击“添加可视化"Add a visualization",然后选择“自定义图表"Custom chart"

  • 选择任何现有预设并将其定义替换为以下Vega规范

  • 单击“另存为”为该预设命名,以方便参考(我建议使用"confusion_matrix" )

{
  "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
  "description": "Multi-class confusion matrix",
  "data": {
    "name": "wandb"
  },
  "width": 40,
  "height": {"step":6},
  "spacing": 5,
  "mark" : "bar",
  "encoding": {
    "y": {"field": "name", "type": "nominal", "axis" : {"labels" : false}, 
         "title" : null, "scale": {"zero": false}},
    "x": {
      "field": "${field:count}",
      "type": "quantitative",
      "axis" : null,
      "title" : null
    },
    "tooltip": [
      {"field": "${field:count}", "type": "quantitative", "title" : "Count"},
      {"field": "name", "type": "nominal", "title" : "Run name"}
    ],
    "color": {
      "field": "name",
      "type": "nominal",
      "legend": {"orient": "top", "titleOrient": "left"},
      "title": "Run name"
    },
    "row": {"field": "${field:actual}", "title": "Actual", 
            "header": {"labelAlign" : "left", "labelAngle": 0}},
    "column": {"field": "${field:predicted}", "title": "Predicted"}
  }
}

步骤3:将记录的运行中的相关数据字段映射到图表中

在可视化IDE的右侧,修改运行查询以将运行数据输入到混淆矩阵中:

  • 输入您的自定义表ID作为tableKeys中的第一项——这是您登录wandb.Table的key,在该示例中是confusion_matrix
  • 使用查询编辑器的下拉菜单来连接匹配的字段,比如将记录到wandb.Table的“ Count”列中的值读入Vega图表的“ count”中,将“ Actual”列读入“ actual”中 ,等等。最终查询应如下所示:



Run set
0


随意定制

通过编辑Vega规范,可以调整图表的高度,宽度和颜色方案。 例如,此图表在颜色上使用“ scale”:{“ scheme”:“ rainbow”}重新对运行进行着色。

这里越深的蓝色对应于更多的训练示例/更多的时期epoch,并且他们通常沿对角线显示出更好的表现。 软体动物Mollusks被分类为“动物animals”,两栖动物Amphibians被分类为爬行动物Reptiles,这是最大模型的两个最常见的错误(训练10,000个示例,历时10个epoch)。 有趣的是,在某些单元格中,即使数据量少了10倍,绿色的“ NT 1000,E 10”模型表现也比最大的蓝色“ NT 10000,E 10”模型好。




Vary num train and num epochs
6

Iterate on AI agents and models faster. Try Weights & Biases today.