In Sharpness-Aware Minimization (SAM), gradients are computed at two different positions within one update step: the current gradient $g(w_t) = \nabla L(w_t)$ and the perturbed gradient $g(w_t + \epsilon_t)$, where $\epsilon_t = \rho \, g(w_t) / \lVert g(w_t) \rVert$.

To probe the geometry of the loss landscape, we use the element-wise ratio of the perturbed gradient to the current gradient:

$$\text{ratio}_i = \frac{g_i(w_t + \epsilon_t)}{g_i(w_t)}.$$
The figure below illustrates how effectively this metric depicts locally convex minima.

In the figure, checkpoint1 and checkpoint2 are notable because their perturbed gradients share the same sign as the current gradient. Experiments show that, throughout training, checkpoint1 accounts for the majority of parameters, while checkpoint2 starts at a lower count and grows as training progresses.

In contrast, checkpoint3 and checkpoint4 depict the cases where the perturbed and current gradients have opposite signs. During training, the number of parameters in checkpoint3 and checkpoint4 is initially very low.
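A minimal PyTorch sketch of how the per-parameter ratio can be computed and the parameters bucketed into the four groups. The `ratio > 1` rule for checkpoint1 and the sign-based grouping come from these notes; the exact split between checkpoint3 and checkpoint4 (at ratio = -1) is an assumption for illustration.

```python
import torch

def classify_parameters(grad_perturbed: torch.Tensor,
                        grad_current: torch.Tensor,
                        eps: float = 1e-12) -> dict:
    """Count parameters in each checkpoint group based on the gradient ratio."""
    # Element-wise ratio of the perturbed gradient to the current gradient.
    ratio = grad_perturbed / (grad_current + eps)
    groups = {
        "checkpoint1": ratio > 1.0,                       # same sign, perturbed gradient larger
        "checkpoint2": (ratio > 0.0) & (ratio <= 1.0),    # same sign, perturbed gradient smaller
        "checkpoint3": (ratio < 0.0) & (ratio >= -1.0),   # opposite sign (assumed split)
        "checkpoint4": ratio < -1.0,                      # opposite sign (assumed split)
    }
    return {name: int(mask.sum().item()) for name, mask in groups.items()}
```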
- Loss Landscape Geometry:
  - The geometry is predominantly characterized by `checkpoint1`.
  - As training progresses, there is a transition through `checkpoint2`, `checkpoint3`, and `checkpoint4`.
  - In the late stages of training, the loss landscape shows multiple minima regions, indicated by an increase in `checkpoint3` and `checkpoint4`.
  - Surprisingly, the number of `checkpoint1` parameters does not decrease, suggesting that a significant number of parameters remain on a ridge of the landscape.
Hypothesis: SAM's parameters keep the same direction as SGD's, but their magnitudes change, and this magnitude change is responsible for increasing test accuracy.
Each cell in the tables below reports test accuracy (%), with the corresponding test loss in parentheses where recorded.

| rho | ori-SAM | ckpt12 |
|---|---|---|
| | 78.84 | 79.46 (0.778) |
| | 79.55 (0.7418) | 79.78 (0.7682) |
| | 79.21 (0.7235) | 79.79 (0.7525) |
| | 77.98 (0.7656) | 79.09 (0.7719) |
| | 74.26 (0.9172) | 78.60 (0.7787) |
Hypothesis: SAM's parameters take the opposite direction from SGD's, and this direction change is responsible for reducing test loss.
| rho | SGD | ori-SAM | ckpt3 | ckpt4 | ckpt34 |
|---|---|---|---|---|---|
| | 77.36 (0.9116) | 78.84 | 78.38 (0.8373) | 78.60 (0.8784) | 77.29 (0.7946) |
| | 77.36 (0.9116) | 79.55 (0.7418) | 78.15 (0.8151) | 78.30 (0.8722) | 78.14 (0.7650) |
| | 77.36 (0.9116) | 79.21 (0.7235) | 78.49 (0.7689) | 77.84 (0.8671) | 78.51 (0.7640) |
| | 77.36 (0.9116) | 77.98 (0.7656) | 77.72 (0.8189) | 77.48 (0.8797) | 78.60 (0.8784) |
| | 77.36 (0.9116) | 74.26 (0.9172) | 76.98 (0.8440) | 77.56 (0.8732) | 78.60 (0.8784) |
Key Question: Which checkpoint among the four contributes most to finding the flat minima in SAM?
| rho | SGD | ori-SAM | checkpoint1 | checkpoint2 | checkpoint3 | checkpoint4 |
|---|---|---|---|---|---|---|
| | 77.36 | 78.84 | 79.06 | 78.26 | 78.38 | 78.60 |
| | 77.36 | 79.55 | 79.31 | 78.27 | 78.15 | 78.30 |
| | 77.36 | 79.21 | 79.58 | 77.99 | 79.04 or 78.49 | 77.84 |
| | 77.36 | 77.98 | 78.58 | 78.16 | 77.72 | 77.48 |
| | 77.36 | 74.26 | 78.79 | 77.38 | 76.98 | 77.56 |
| rho | SGD | ori-SAM | ckpt12 | ckpt13 | ckpt14 | ckpt134 | ckpt123 | ckpt234 |
|---|---|---|---|---|---|---|---|---|
| | 77.36 | 78.84 | 79.46 | 79.64 | 78.58 | 79.41 | 79.44 | ??? |
| | 77.36 | 79.55 | 79.78 | 79.81 | 78.68 | 78.94 | 79.85 | 78.82 |
| | 77.36 | 79.21 | 79.79 | 79.29 | 78.91 | 78.86 | 79.73 | 78.40 |
| | 77.36 | 77.98 | 79.09 | 79.23 | 77.21 | 79.72 | 79.27 | 76.79 |
| | 77.36 | 74.26 | 78.60 | 79.36 | 75.86 | 78.72 | 78.18 | ??? |
Setup variants:
- SAMECKPT1: `mask = ratio > 1`
- `grad(SAMECKPT1) = grad(SAM) * mask * condition + grad(SAM) * not(mask)`
- SAMECKPT2, SAMECKPT3, and SAMECKPT4 are defined analogously with their respective masks (see the sketch below).
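A minimal PyTorch sketch of this masked rescaling, assuming element-wise gradients; `sameckpt_grad` and its argument names are illustrative, not taken from the notes.

```python
import torch

def sameckpt_grad(grad_sam: torch.Tensor,
                  grad_current: torch.Tensor,
                  condition: float = 0.5,
                  eps: float = 1e-12) -> torch.Tensor:
    # Element-wise ratio of the perturbed (SAM) gradient to the current gradient.
    ratio = grad_sam / (grad_current + eps)
    # SAMECKPT1-style mask: parameters with ratio > 1 (i.e. checkpoint1).
    mask = (ratio > 1.0).float()
    # Scale the masked entries by `condition`; leave the remaining entries unchanged.
    return grad_sam * mask * condition + grad_sam * (1.0 - mask)
```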
| rho\SAMECKPT1 | ori-SAM | condition=0.5 | condition=0.667 | condition=1.2 | condition=1.5 | condition=2 |
|---|---|---|---|---|---|---|
| | 78.75 (0.8063) | 78.74 (0.811) | 78.63 (0.8165) | 78.32 (0.8086) | 78.86 (0.8102) | 78.66 (0.8077) |
| | 79.35 (0.7526) | 78.98 (0.7813) | 79.17 (0.7713) | ?? | 79.33 (0.7637) | 78.78 (0.7778) |
| | 79.55 (0.7418) | 79.28 (0.744) | 79.36 (0.7455) | 79.29 (0.7388) | ?? | 78.50 (0.7758) |
| | 79.21 (0.7235) | 78.76 (0.7333) | ?? | 79.31 (0.7157) | 79.05 (0.7219) | 79.09 (0.7265) |
| rho\SAMECKPT2 | ori-SAM | condition=0.1 | condition=0.3 | condition=0.5 | condition=0.667 | condition=1.5 | condition=2 | condition=3 |
|---|---|---|---|---|---|---|---|---|
| | 78.75 (0.8063) | 78.93 (0.7553) | 79.65 (0.7291) | 79.47 (0.7531) | 79.57 (0.7714) | 78.15 (0.8229) | ?? | ?? |
| | 79.35 (0.7526) | 79.27 (0.722) | 80.01 (0.7081) | 79.94 (0.7263) | 79.73 (0.732) | 78.35 (0.7813) | ?? | ?? |
| | 79.55 (0.7418) | 78.58 (0.7368) | 79.78 (0.6941) | 80.02 (0.7111) | 79.58 (0.7283) | 79.23 (0.7446) | ?? | ?? |
| | 79.21 (0.7235) | ?? | ?? | ?? | 78.70 (0.7337) | 79.88 (0.7134) | 79.32 (0.7260) | 79.09 (0.7265) |
| | 77.98 (0.7656) | ?? | ?? | ?? | ?? | 79.33 (0.7171) | 78.87 (0.7099) | 78.48 (0.7712) |
| | 74.26 (0.9172) | ?? | ?? | ?? | ?? | ?? | 78.99 (0.7158) | 79.43 (0.7498) |
| rho\SAMECKPT234 | ori-SAM | condition=0.1 | condition=0.3 | condition=0.5 |
|---|---|---|---|---|
| | 78.75 (0.8063) | 79.66 (0.7179) | 79.92 (0.7325) | 79.59 (0.742) |
| | 79.35 (0.7526) | ?? | 79.61 (0.7091) | 79.95 (0.7227) |
| | 79.55 (0.7418) | ?? | 79.66 (0.7016) | 80.16 (0.697) |
| rho\SAMECKPT3-RN18 | ori-SAM | condition=0.1 | condition=0.5 | condition=1.5 | condition=2 |
|---|---|---|---|---|---|
| | 79.35 (0.7526) | 79.39 (0.7656) | 79.1 (0.7654) | 78.99 (0.7681) | 78.98 (0.7647) |
| | 79.55 (0.7418) | 79.76 (0.7307) | 79.41 (0.738) | 79.26 (0.7339) | 79.16 (0.7378) |
| | 79.21 (0.7235) | 80.46 (0.7026) | 79.57 (0.7198) | 79.54 (0.7172) | 78.81 (0.7335) |
| | 77.98 (0.7656) | 79.64 (0.7186) | ?? | ?? | ?? |
| rho\SAMECKPT4-RN18 | ori-SAM | condition=0.1 | condition=0.5 | condition=0.667 | condition=1.5 | condition=2 |
|---|---|---|---|---|---|---|
| | 79.35 (0.7526) | 79.01 (0.7682) | 79.4 (0.7605) | ?? | 79.64 (0.759) | 79.46 (0.7571) |
| | 79.55 (0.7418) | 79.17 (0.7485) | 79.54 (0.7332) | 79.27 (0.7303) | 79.82 (0.7315) | 78.9 (0.7458) |
| | 79.21 (0.7235) | ?? | ?? | 79.22 (0.7178) | 79.24 (0.7244) | 79.21 (0.7149) |
| rho\SAMECKPT3-RN34 | ori-SAM | condition=0.1 |
|---|---|---|
| | 80.4 (0.6995) | 81.63 (0.6792) |
| | 80.89 (0.6576) | 81.17 (0.6934) |
| | 79.65 (0.6935) | ?? |
| rho\SAMECKPT3-WRN28-10 | ori-SAM | condition=0.1 |
|---|---|---|
| | 83.91 (0.5886) | 83.16 (0.6265) |
| | 83.44 (0.5697) | 83.34 (0.6147) |
| | 83.23 (0.5791) | 83.18 (0.6016) |
- Hypothesis: Maintaining the magnitude of all parameters in `checkpoint1` while replacing the others with the magnitudes from SGD would still retain SAM's ability to find flat minima (see the sketch after this list).
- Results:
  - SAM's ability to find flat minima was maintained.
  - Repeating this experiment with the other checkpoints resulted in sharper minima.
- Conclusion: SAM's real ability comes from its effective learning rate, evidenced by `checkpoint1`, not from modifying the update direction.
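A minimal sketch of one possible reading of this magnitude swap, assuming element-wise gradients and that the update direction is taken from SAM; `mixed_magnitude_grad` and the exact mixing rule are illustrative assumptions, not the notes' implementation.

```python
import torch

def mixed_magnitude_grad(grad_sam: torch.Tensor,
                         grad_sgd: torch.Tensor,
                         ckpt1_mask: torch.Tensor) -> torch.Tensor:
    # Keep SAM's magnitude on checkpoint1 parameters and use SGD's magnitude elsewhere,
    # while taking the sign (direction) from the SAM gradient everywhere.
    magnitude = torch.where(ckpt1_mask, grad_sam.abs(), grad_sgd.abs())
    return torch.sign(grad_sam) * magnitude
```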
- Research Question: How long does the ratio of the perturbed gradient to the current gradient in `checkpoint1` remain consistent?
- Results:
  - At step 172, `checkpoint1` contained $6 \times 10^6$ parameters.
  - There was an overlap of about $4 \times 10^6$ parameters with later steps, which stayed at around $4 \times 10^6$ (or 40%) even at the final step (see the overlap sketch below).
  - This indicates that 40% of parameters require a higher learning rate than the initial one.
- Research Question: Are the same $4 \times 10^6$ parameters consistent over many steps?
- Answer: No. After just 5 steps, the overlap dropped to $10^5$ and continued to diminish.
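A minimal sketch of how such an overlap between `checkpoint1` masks at two training steps could be measured; the helper names are illustrative.

```python
import torch

def checkpoint1_mask(grad_perturbed: torch.Tensor,
                     grad_current: torch.Tensor,
                     eps: float = 1e-12) -> torch.Tensor:
    # checkpoint1: parameters whose perturbed/current gradient ratio exceeds 1.
    return grad_perturbed / (grad_current + eps) > 1.0

def checkpoint1_overlap(mask_step_a: torch.Tensor, mask_step_b: torch.Tensor) -> int:
    # Number of parameters that belong to checkpoint1 at both steps.
    return int((mask_step_a & mask_step_b).sum().item())
```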
- Research Question: Does large-batch training lower generalization due to low learning rates?
- Comparison:
  - ResNet18 on CIFAR100 with batch size 1024 and SGD learning rates of 0.1 and 0.2.
  - The gradient norm with batch size 1024 was lower than with batch size 256, hinting at slower learning with a larger batch size.
  - Increasing the learning rate in the batch-size-1024 run improved test accuracy.
- The effectiveness of SAM is attributed to its adaptive learning rate.
- 60% of parameters do not converge even after 200 epochs, potentially due to learning rate decay.
We classified the parameters from checkpoint1 into two types, as illustrated in the figure below:
- Initial Approach: Diagonal Hessian
  - We initially attempted to use the diagonal Hessian to determine parameter flatness. However, the shape of this approximation differed from the gradient, making it challenging to proceed with this metric.
- Current Approach: Gradient Magnitude
  - We decided to use the gradient magnitude as the metric to determine the flatness of each parameter. This approach provided a more straightforward and consistent method for our analysis.
- Data Preparation:
  - We stored the flattened tensors of `checkpoint1` and computed the absolute values of the gradient magnitudes.
- Threshold Determination:
  - Our hypothesis is that a higher gradient magnitude corresponds to higher flatness. To verify this, we sorted the `magnitude_gradient` tensor in decreasing order.
  - We then determined the `threshold` as the value of `magnitude_gradient` at the position corresponding to the length of `checkpoint1`.
- Calculation of Flat Minima:
  - We calculated the percentage of parameters from `checkpoint1` with gradient magnitudes greater than the determined threshold. This percentage represents the proportion of parameters in `checkpoint1` that are in flat minima.
By following this methodology, we can identify the flat minima among the parameters in checkpoint1 and gain insights into their distribution.
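A minimal PyTorch sketch of the threshold procedure described above, assuming a flattened gradient tensor over all parameters and a boolean `checkpoint1` mask; the function name and exact indexing convention are illustrative.

```python
import torch

def checkpoint1_flat_fraction(grads: torch.Tensor, ckpt1_mask: torch.Tensor) -> float:
    """Fraction of checkpoint1 parameters whose gradient magnitude exceeds the threshold."""
    magnitude_gradient = grads.abs()
    # Sort all gradient magnitudes in decreasing order.
    sorted_mags, _ = torch.sort(magnitude_gradient, descending=True)
    # Threshold = magnitude at the position given by the number of checkpoint1 parameters.
    k = int(ckpt1_mask.sum().item())
    threshold = sorted_mags[k - 1]
    # Fraction of checkpoint1 parameters above that threshold.
    return (magnitude_gradient[ckpt1_mask] > threshold).float().mean().item()
```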
| Epoch | mag_grad[len(checkpoint1)] | percent (%) |
|---|---|---|
| 5 | 0.00013771 | 63 |
| 10 | 0.00015848 | 64 |
| 15 | 0.00016051 | 65 |
| 20 | 0.00016005 | 67 |
| 25 | 0.00015288 | 67 |
| 30 | 0.00017219 | 66 |
| 35 | 0.00014082 | 69 |
| 40 | 0.00019160 | 66 |
| 45 | 0.00017323 | 69 |
| 50 | 0.00016461 | 69 |
| 55 | 0.00018168 | 66 |
| 60 | 0.00019630 | 67 |
| 65 | 0.00017031 | 68 |
| 70 | 0.00017464 | 66 |
| 75 | 0.00016619 | 70 |
| 80 | 0.00017488 | 71 |
| 85 | 0.00018003 | 69 |
| 90 | 0.00016166 | 70 |
| 95 | 0.00011792 | 75 |
| 100 | 0.00014122 | 75 |
| 105 | 0.00015415 | 74 |
| 110 | 0.00014154 | 74 |
| 115 | 0.00019645 | 71 |
| 120 | 0.00019847 | 71 |
| 125 | 0.00017283 | 72 |
| 130 | 0.00018389 | 72 |
| 135 | 0.00013861 | 73 |
| 140 | 0.00014639 | 76 |
| 145 | 0.00012558 | 73 |
| 150 | 0.00012297 | 72 |
| 155 | 0.00009237 | 76 |
| 160 | 0.00007134 | 78 |
| 165 | 0.00012038 | 73 |
| 170 | 0.00010611 | 67 |
| 175 | 0.00011405 | 67 |
| 180 | 0.00004605 | 80 |
| 185 | 0.00007793 | 66 |
| 190 | 0.00008220 | 64 |
| 195 | 0.00004291 | 77 |
| 200 | 0.00015874 | 59 |
- Checkpoint1 Gradient Magnitude:
  - In `checkpoint1`, over 60% of parameters have a gradient magnitude exceeding the threshold.
  - There is an observable trend of increasing gradient magnitude in the later stages.
- Sharp Region Identification:
  - The experiment suggests that most parameters in `checkpoint1` reside in a sharp region, which we denote as `checkpoint1.1`.
Statistics for `checkpoint1`
Method to validate: count how many parameters fall into each ratio bin `checkpoint1.1`, `checkpoint1.2`, `checkpoint1.3`, `checkpoint1.4`, `checkpoint1.5`, and `checkpoint1.6`, respectively (a counting sketch follows the results below). The results are shown below:
- At first, the parameters belonging to `checkpoint1.1`, which have ratio $< 1.2$, are the most numerous, but their count then decreases quickly.
- In the later stages, the ratios of most `checkpoint1` parameters increase quickly, as evidenced by the growth of `checkpoint1.3`, `checkpoint1.4`, `checkpoint1.5`, and `checkpoint1.6`.
- Idea: instead of increasing the magnitude in proportion to the ratio, increase it by a fixed constant.
- Why is preconditioning with $H(w)$ not effective for deep learning, especially in the over-parameterized setting?
- The results of this project give intuition for IRE, because IRE doubles the learning rate for $95\%$ of the parameters.
Check the hypothesis: the larger the number of parameters with `ratio > 1` (i.e., `checkpoint1`), the better the generalization of SAM?
- Method to validate: we expect that increasing the perturbation radius $\rho$ shrinks `checkpoint1`, because each parameter has its own threshold for not escaping its attractor.
- We run SAM with different perturbation radius values $\rho = 0.01$, $\rho = 0.05$, $\rho = 0.1$, collect the parameters with `ratio > 1`, and count the proportion of `checkpoint1`.
  - The worst case: the number of `checkpoint1` parameters does not affect generalization.
  - The best case: the number of `checkpoint1` parameters correlates with generalization.



