Paper Title
A New Perspective for Understanding Generalization Gap of Deep Neural Networks Trained with Large Batch Sizes
Paper Authors
Paper Abstract
Deep neural networks (DNNs) are typically optimized using various forms of the mini-batch gradient descent algorithm. A major motivation for mini-batch gradient descent is that, with a suitably chosen batch size, available computing resources (including parallelization) can be used optimally for fast model training. However, many works report a progressive loss of model generalization when the training batch size is increased beyond certain limits, a scenario commonly referred to as the generalization gap. Although several works have proposed different methods for alleviating the generalization gap problem, a unanimous account of the generalization gap is still lacking in the literature. This is especially important given that recent works have observed that several proposed solutions to the generalization gap problem, such as learning rate scaling and an increased training budget, do not in fact resolve it. As such, the main aim of this paper is to investigate and provide a new perspective on the source of generalization loss for DNNs trained with a large batch size. Our analysis suggests that a large training batch size results in increased near-rank loss of units' activation (i.e. output) tensors, which consequently impacts model optimization and generalization. Extensive experiments are performed for validation on popular DNN models such as VGG-16, a residual network (ResNet-56) and LeNet-5 using the CIFAR-10, CIFAR-100, Fashion-MNIST and MNIST datasets.
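To make the notion of near-rank loss of activation tensors concrete, the following is a minimal sketch (not the authors' implementation) of how one might quantify it: collect a layer's activations over a batch of inputs, take the singular values of the resulting activation matrix, and count what fraction of them is negligible relative to the largest one. The function name `near_rank_loss`, the tolerance value, and the example shapes are illustrative assumptions.

```python
# Hypothetical sketch: estimating how close a layer's activation matrix is to
# rank deficiency, e.g. to compare models trained with small vs. large batches.
import numpy as np

def near_rank_loss(activations: np.ndarray, tol: float = 1e-3) -> float:
    """Fraction of singular values that are (near) zero relative to the largest.

    activations: 2-D array of shape (num_samples, num_units), i.e. one layer's
    outputs collected over a set of inputs (flattened if convolutional).
    tol: illustrative relative threshold below which a singular value is
    treated as "lost" rank.
    """
    # Singular values of the activation matrix, sorted in descending order.
    s = np.linalg.svd(activations, compute_uv=False)
    if s[0] == 0:
        return 1.0  # all-zero activations: complete rank collapse
    lost = np.sum(s < tol * s[0])
    return lost / s.size

# Example with random data standing in for real layer outputs.
rng = np.random.default_rng(0)
acts = rng.standard_normal((256, 128))                       # well-conditioned
acts_collapsed = acts @ np.diag([1.0] * 32 + [1e-6] * 96)    # mostly collapsed
print(near_rank_loss(acts))            # close to 0.0
print(near_rank_loss(acts_collapsed))  # close to 0.75
```

Under this kind of measure, the paper's claim would correspond to the metric growing as the training batch size increases, all else being equal.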