自定义线图
wandb.plot.line()的用法与示例
Created on June 16|Last edited on January 27
Comment
方法:wandb.plot.line()
仅用几行代码即可在本地记录自定义线图——一对任意轴x和y上的一系列连接点/有序点(x, y):
data = [[x, y] for (x, y) in zip(x_values, y_values)]
table = wandb.Table(data=data, columns = ["x", "y"])
wandb.log({"my_custom_plot_id" : wandb.plot.line(table,
"x", "y", title="Custom Y vs X Line Plot")
你可以用这段代码记录所有的二维曲线。注意,如果你要绘制相互对应的两列值,两个列表中值的数量必须恰好匹配(例如,每个点都必须有一个x和一个y)。
Toy CNN variants
6
基本用法示例
我微调一个卷积神经网络来预测10类生物:植物、鸟类、昆虫等等。我要绘制二值化平均精度曲线(10个类都简化为一个二元的正/误标签)。在我的验证步骤中,我计算微平均精度并利用Sklearn回调,回调时遵循他们的多标签设置范例,从而产生两个相同长度的数组:precision_micro
和recall_micro
(可在本报告结尾查看详细信息)。我绘制的线图将用y轴表示precision_micro
值,用x轴表示recall_micro
值。
现在我就可以调用:
data = [[x, y] for (x, y) in zip(recall_micro, precision_micro)]
table = wandb.Table(data=data, columns = ["recall_micro", "precision_micro"])
wandb.log({"my_lineplot_id" : wandb.plot.line(table, "recall_micro", "precision_micro", stroke=None, title="Average Precision")})
按照以下步骤:
- 创建一个对象data:收集点并组成二维列表/数组,每一行是一个点,每一列是一个维度。这个线图呈现的是二维/两列,但你完全可以传入更多数据并进一步自定义图表。你还可以在
wandb.plot.line()
中使用stroke=自变量(可选),用来传入第三个字段,对应着线型(实线、虚线、点线……)或颜色。 - 将
data
传递给对象wandb.Table()
,在这个对象中命名列,以便于之后引用这些列。 - 把对象
table
以及相同的x列名称和y列名称按顺序传入wandb.plot.line()
,有个标题可选,这将在键my_lineplot_id
之下创建自定义图表。为了在同一个图表上可视化多个运行项,则保持该图表键不变。提示:表格本身也会被记录到工作空间的“Media”(媒体)部分,在my_lineplot_id_table
中。
自定义用法
利用Vega可视化语法,有多种方法自定义线图。
下面是几个简单的:
- 为了清晰起见,重命名轴标题:为
encoding
下面的字段x和y添加"title"
:"Your Title"
。 - 修改色谱以反映训练数据/周期的增加,曲线从紫色逐渐变为黄色:
"color": {
"type": "nominal",
"field" : "name",
"scale" : {"scheme" : "plasma"}
}
在每个图表中,都可以平移并缩放以查看细节,并且悬停鼠标可查看点的详细信息(还可以修改显示信息!)。还能在右上方悬停鼠标并点击“眼睛”图标,即可查看定义图表的完整Vega参数。接口wandb.plot.line
的定义在这里。
Toy CNN variants
6
补充:计算多类模型的平均精度
要计算该项,你的代码要有权限获取:
- 模型对一系列样本的预测分数(
val_predictions
); - 这些样本对应的真实标签(
ground_truth
)。
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize
# generate binary correctness labels across classes
binary_ground_truth = label_binarize(ground_truth,
classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# compute a PR curve with sklearn like you normally would
precision_micro, recall_micro, _ = precision_recall_curve(binary_ground_truth.ravel(),
val_predictions.ravel())
# now you can log these values to a custom chart!
Add a comment
Iterate on AI agents and models faster. Try Weights & Biases today.