训练神经网络的最佳批大小是多少?
简要研究批大小对测试准确性的影响。
Created on March 25|Last edited on July 12
Comment
导言
“最佳的批大小是多少?”尽管我们意识到这种问题的答案几乎总是一样(“看情况”),但我们今天的目标就是研究不同的批大小是如何影响准确性、训练时间和计算资源的。然后,我们将探讨一些解释这些差异的假设。
开始调查吧!
我们首先需要确定批大小对测试准确性和训练时间的影响。
为此,我们进行一个消融实验(ablation study)。我们将使用一个图像分类任务,并测试不同批大小时准确性如何变化。在这些测试中,我们将着重注意以下几点:
- 我们不会使用过度参数化的网络——这有助于避免过度拟合。
- 我们将使用不同的批大小对我们的模型进行25个时期的训练。
尝试在Google colab上进行消融实验
Run set
9
- 首先,从验证指标来看,以小批大小训练的模型在验证集上有很好的通用性。
- 批大小为32时,结果最好。批大小为2048时,结果最差。在我们的研究中,我们使用从8到2048的批大小来训练我们的模型,每个批大小是前一个批大小的两倍。
- 我们的平行坐标图也非常明显地显示了一个关键的权衡:批大小越大,训练所需的时间越少,但准确性就越低。
Run set
9
- 进一步深入研究,我们可以清楚地看到,随着我们从较高的批大小过渡到较低的批大小,测试的错误率呈指数下降。不过,请注意,批大小为32时,我们的错误率最低。
- 随着我们从较高的批大小过渡到较低的批大小,所需的训练时间呈指数增加。这是意料之中的!由于我们没有在模型开始过度拟合时使用早期停止,而是允许它训练25个时期,所以我们必然会看到训练时间的增加。,
为什么批大小越大,泛化能力越差呢?
- 基于梯度下降的优化对成本函数进行线性逼近。但是,如果成本函数是高度非线性(高度弯曲)的,则近似将不是很好,因此小批量反而是安全的。
- 将m个示例放入一个迷你批中时,需要进行O(m)计算并使用O(m)内存,但是您只是将梯度中的不确定性量减少了O(sqrt(m))。换句话说,在迷你批放置更多示例的边际收益正在递减。
- 即使使用整个训练集,您也不会获得真正的梯度。使用整个训练集其实就是使用一个非常大的迷你批。
- 与较大的批大小相比,较小的批大小的梯度振荡更多。这种振荡可以被认为是噪声,但是对于非凸损耗的场景(通常是这种情况),这种噪声有助于摆脱局部极小值。因此,批大小越大,搜索最优解的步骤就越少、越粗糙。由此,从构造上看,就不太可能收敛于最优解。
Weights & Biases
Weights & Biases可帮助您跟踪机器学习实验。使用我们的工具记录您运行中的超参数和输出指标,然后可视化和比较结果,并与您的同事快速共享发现。
Add a comment
Iterate on AI agents and models faster. Try Weights & Biases today.