Pham Duy Khanh, Hoang-Chau Luong, Boris S. Mordukhovich, Dat Ba Tran
This repo implements the paper "Fundamental Convergence Analysis of Sharpness-Aware Minimization", published at NeurIPS 2024. The SAM implementation is based on davda54/sam.
The paper investigates the fundamental convergence properties of Sharpness-Aware Minimization (SAM), a recently proposed gradient-based optimization method [Foret et al., 2021] that significantly improves the generalization of deep neural networks. We establish several convergence properties for the method: the stationarity of accumulation points, the convergence of the sequence of gradients to the origin, the convergence of the sequence of function values to the optimal value, and the convergence of the sequence of iterates to the optimal solution. Because the analysis is built on inexact gradient descent frameworks [Khanh et al., 2023b], it is general enough to extend to efficient normalized variants of SAM such as F-SAM [Li et al., 2024], VaSSO [Li and Giannakis, 2023], and RSAM [Liu et al., 2022], as well as to unnormalized variants such as USAM [Andriushchenko and Flammarion, 2022]. Numerical experiments on classification tasks with deep learning models confirm the practical aspects of our analysis.
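The normalized/unnormalized distinction above comes down to how the ascent perturbation is scaled. As a minimal sketch (plain Python, not the repo's PyTorch code, which follows davda54/sam), standard SAM moves to x + rho * g / ||g|| before taking the descent step, while USAM uses x + rho * g. The function names `sam_step` and `usam_step`, the symbols `t` (step size) and `rho` (perturbation radius), and the quadratic test objective are illustrative assumptions.

```python
import math
from typing import Callable, List

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def sam_step(x: Vector, grad: GradFn, t: float, rho: float, eps: float = 1e-12) -> Vector:
    # Normalized SAM: move to x + rho * g / ||g|| (an approximate worst-case
    # point in a rho-ball), then descend from x with the gradient taken there.
    g = grad(x)
    norm = math.sqrt(sum(gi * gi for gi in g))
    x_adv = [xi + rho * gi / (norm + eps) for xi, gi in zip(x, g)]
    g_adv = grad(x_adv)
    return [xi - t * gi for xi, gi in zip(x, g_adv)]


def usam_step(x: Vector, grad: GradFn, t: float, rho: float) -> Vector:
    # Unnormalized SAM (USAM): same two-stage update, but the perturbation
    # is rho * g, without dividing by ||g||.
    g = grad(x)
    x_adv = [xi + rho * gi for xi, gi in zip(x, g)]
    g_adv = grad(x_adv)
    return [xi - t * gi for xi, gi in zip(x, g_adv)]


# Hypothetical test objective f(x) = 0.5 * sum(a_i * x_i**2) with curvatures a = (1, 10).
A = [1.0, 10.0]


def f(x: Vector) -> float:
    return 0.5 * sum(ai * xi * xi for ai, xi in zip(A, x))


def grad_f(x: Vector) -> Vector:
    return [ai * xi for ai, xi in zip(A, x)]


x0 = [1.0, 1.0]
x_sam = sam_step(x0, grad_f, t=0.05, rho=0.05)
x_usam = usam_step(x0, grad_f, t=0.05, rho=0.05)
print(f(x0), f(x_sam), f(x_usam))
```

Normalizing keeps the perturbation at exactly length rho regardless of the gradient's size, whereas the USAM perturbation shrinks as the gradient vanishes; that difference is one reason the two families are analyzed separately.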
In this paper, we establish the convergence of SAM to a stationary point. Previous studies assume a constant step size (the learning rate, in deep learning terms) and therefore show only near-stationary convergence; our work instead proves exact convergence by using a diminishing step size. To support the theory, the experiments in this codebase show that a diminishing step size yields better performance than a constant one.
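The constant vs. diminishing step-size contrast can be seen on a toy problem. The sketch below is a plain-Python illustration, not the training script: it assumes the objective f(x) = 0.5 * ||x||^2, a fixed `rho = 0.05`, and the hypothetical schedule t_k = 0.1 / sqrt(k + 1). The exact step-size conditions required by the theory are stated in the paper, and the repo's configs may use a different scheduler.

```python
import math
from typing import Callable, List

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def sam_step(x: Vector, grad: GradFn, t: float, rho: float, eps: float = 1e-12) -> Vector:
    # Normalized SAM step: perturb by rho * g / ||g||, then take a step of
    # size t along the gradient evaluated at the perturbed point.
    g = grad(x)
    norm = math.sqrt(sum(gi * gi for gi in g))
    x_adv = [xi + rho * gi / (norm + eps) for xi, gi in zip(x, g)]
    return [xi - t * gi for xi, gi in zip(x, grad(x_adv))]


def run_sam(x0: Vector, grad: GradFn, step_size: Callable[[int], float],
            rho: float, iters: int) -> Vector:
    # Run SAM for `iters` iterations, using step size t_k = step_size(k).
    x = list(x0)
    for k in range(iters):
        x = sam_step(x, grad, step_size(k), rho)
    return x


def grad_norm(x: Vector, grad: GradFn) -> float:
    return math.sqrt(sum(gi * gi for gi in grad(x)))


# Toy objective f(x) = 0.5 * ||x||^2, whose gradient is x itself.
def grad_f(x: Vector) -> Vector:
    return list(x)


def constant(k: int) -> float:
    return 0.1  # fixed step size


def diminishing(k: int) -> float:
    return 0.1 / math.sqrt(k + 1)  # t_k -> 0 while sum_k t_k diverges


x_const = run_sam([1.0, 1.0], grad_f, constant, rho=0.05, iters=500)
x_dim = run_sam([1.0, 1.0], grad_f, diminishing, rho=0.05, iters=500)
print(grad_norm(x_const, grad_f), grad_norm(x_dim, grad_f))
```

Near the minimizer the normalized perturbation keeps length rho, so each constant-size step still moves by about t * rho. The iterates then oscillate around x = 0, and the gradient norm stalls near t * rho / (2 - t), about 2.6e-3. With the diminishing schedule that oscillation shrinks along with t_k, and the final gradient norm ends below 3e-4.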
Anaconda must be installed on your machine. Install the dependencies with conda and pip:
conda env create -f environment.yml
conda activate fundamental-sam
pip install -e .
You also need to install WandB for experiment tracking:
pip install wandb
wandb login
Choose any configuration file in ./config and substitute its name for config/file/name in the command below:
python fundamental_sam/train_sam.py --experiment=config/file/name