A widely believed explanation for the remarkable generalization capabilities of overparameterized neural networks is that the optimization algorithms used for training induce an implicit bias towards benign solutions. To understand this theoretically, recent works examine gradient descent and its variants in simplified training settings, often assuming vanishing learning rates. These studies reveal various forms of implicit regularization, such as $\ell_1$-norm minimizing parameters in regression and max-margin solutions in classification. Concurrently, empirical findings show that moderate to large learning rates exceeding standard stability thresholds lead to faster, albeit oscillatory, convergence in the so-called Edge-of-Stability regime, and induce an implicit bias towards minima of low sharpness (i.e., a small norm of the training loss Hessian). In this work, we argue that a comprehensive understanding of the generalization performance of gradient descent requires analyzing the interaction between these various forms of implicit regularization. We empirically demonstrate that the learning rate controls a trade-off between a low parameter norm and low sharpness of the trained model. Furthermore, we prove, for diagonal linear networks trained on a simple regression task, that neither implicit bias alone minimizes the generalization error. These findings demonstrate that focusing on a single implicit bias is insufficient to explain good generalization, and they motivate a broader view of implicit regularization that captures the dynamic trade-off between norm and sharpness induced by non-negligible learning rates.
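To make the two measured quantities concrete, the following NumPy sketch (an illustration under simplified assumptions, not the experimental setup of this work) trains a two-layer diagonal linear network $\beta = u \odot v$ with full-batch gradient descent on a synthetic sparse regression problem and reports, for a small and a large learning rate, the $\ell_1$ norm of the learned regressor together with a power-iteration estimate of the sharpness, i.e., the top eigenvalue of the training loss Hessian. All dimensions, step sizes, and the initialization scale are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sparse regression task; all sizes are illustrative choices.
n, d = 20, 50
X = rng.standard_normal((n, d)) / np.sqrt(n)
beta_star = np.zeros(d)
beta_star[:3] = 1.0           # sparse ground-truth regressor
y = X @ beta_star

def loss_grad(theta):
    """Squared loss and its gradient for the diagonal linear net beta = u * v."""
    u, v = theta[:d], theta[d:]
    r = X @ (u * v) - y                      # residuals
    g = X.T @ r / n                          # gradient w.r.t. beta
    return 0.5 * np.mean(r ** 2), np.concatenate([g * v, g * u])

def sharpness(theta, iters=100, eps=1e-4):
    """Estimate the top Hessian eigenvalue by power iteration on
    finite-difference Hessian-vector products."""
    _, g0 = loss_grad(theta)
    w = rng.standard_normal(theta.size)
    for _ in range(iters):
        _, g1 = loss_grad(theta + eps * w / np.linalg.norm(w))
        w = (g1 - g0) / eps                  # w ~ H @ (w / ||w||)
    return np.linalg.norm(w)

def train(lr, steps=20000, alpha=0.1):
    # Small "balanced" initialization: beta starts at exactly zero.
    theta = np.concatenate([np.full(d, alpha), np.zeros(d)])
    for _ in range(steps):
        _, g = loss_grad(theta)
        theta -= lr * g
    beta = theta[:d] * theta[d:]
    return np.abs(beta).sum(), sharpness(theta)

# Illustrative step sizes: the larger one is meant to leave the
# small-learning-rate regime and may need tuning to avoid divergence.
for lr in (0.02, 0.8):
    l1, sharp = train(lr)
    print(f"lr={lr:4}: ||beta||_1 = {l1:.3f}, sharpness ~ {sharp:.3f}")
```

Comparing the printed lines across step sizes exposes the norm-sharpness trade-off directly; the sketch makes no claim about which learning rate generalizes better on this toy instance.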