This note provides a comprehensive overview of the dynamical asymptotic analysis for online SGD in the high-dimensional limit.
While the replica method analyzes static properties (post-training), dynamical analysis tracks the entire learning trajectory. This enables understanding of:
- Convergence rates and learning curves
- Transient dynamics and phase transitions
- Optimal learning rate schedules
- Early stopping strategies
The key insight is that high-dimensional stochastic dynamics concentrate to deterministic ODEs in the macroscopic variables.
At each discrete time step $t$, a fresh sample $(\mathbf{x}^t, y^t)$ is drawn and the parameters $\theta^t$ are updated by a single stochastic gradient step:
$$\theta^{t+1} = \theta^t - \tau \nabla_\theta r(\theta^t; \mathbf{x}^t, y^t)$$
where:
- $\tau$: learning rate (step size)
- $r(\cdot)$: per-sample loss function
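The update above can be sketched for a noiseless linear model with streaming Gaussian data. This is a minimal illustration, not part of the package; the $1/\sqrt{d}$ input scaling matches the pre-activation convention used later in this note.

```python
import numpy as np

rng = np.random.default_rng(0)
d, tau = 500, 0.5
n_steps = 5 * d                      # 5 units of rescaled time t = steps / d

w_star = rng.standard_normal(d)      # teacher weights
w = np.zeros(d)                      # student initialization

for _ in range(n_steps):
    x = rng.standard_normal(d)               # fresh sample each step (streaming)
    y = w_star @ x / np.sqrt(d)              # noiseless teacher label
    gamma = w @ x / np.sqrt(d) - y           # prediction error on this sample
    w -= tau * gamma * x / np.sqrt(d)        # per-sample squared-loss gradient step

rho = w_star @ w_star / d            # teacher norm
m = w @ w_star / d                   # student-teacher overlap
q = w @ w / d                        # student self-overlap
eg = 0.5 * (rho - 2 * m + q)         # generalization error
print(m, eg)
```

Each sample is seen once and then discarded, which is exactly the infinite-data streaming setting analyzed below.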
Critical distinction from batch SGD: each sample is used exactly once, which corresponds to streaming from an effectively infinite dataset.
Define the macroscopic state $\boldsymbol{\phi}^t = \boldsymbol{\phi}(\theta^t) \in \mathbb{R}^K$, a low-dimensional summary statistic of the high-dimensional parameters. For neural networks, these are typically order parameters:
- Student-teacher overlaps
- Student self-overlaps
- Other sufficient statistics
The macroscopic state evolves as a stochastic process:
$$\boldsymbol{\phi}^{t+1} = \boldsymbol{\phi}^t + \frac{1}{d} \mathbf{g}(\boldsymbol{\phi}^t) + \boldsymbol{\xi}^t$$
where $\mathbf{g}$ is the drift vector and $\boldsymbol{\xi}^t$ is a zero-mean fluctuation term.
For the stochastic process $\{\boldsymbol{\phi}^t\}$ to concentrate to a deterministic ODE, three conditions must hold.
Condition 1: Mean-Increment Concentration
The conditional mean increment matches the ODE drift vector:
$$\mathbb{E}\left[\boldsymbol{\phi}^{t+1} - \boldsymbol{\phi}^t \,\middle|\, \boldsymbol{\phi}^t\right] = \frac{1}{d} \mathbf{g}(\boldsymbol{\phi}^t) + o(1/d)$$
Condition 2: Increment-Variance Concentration
The variance of the increments decays faster than $1/d$ (e.g., $O(1/d^2)$ per step), so that the fluctuations accumulated over $O(d)$ steps vanish as $d \to \infty$.
Condition 3: Initial State Concentration
The initial state concentrates to a deterministic value: $\boldsymbol{\phi}^0 \to \boldsymbol{\phi}_0$ in probability as $d \to \infty$.
Theorem (Concentration to ODE): Under Conditions 1-3, with $\mathbf{g}$ Lipschitz, the rescaled process tracks the solution of
$$\frac{d\boldsymbol{\phi}}{dt} = \mathbf{g}(\boldsymbol{\phi}), \qquad \boldsymbol{\phi}(0) = \boldsymbol{\phi}_0,$$
uniformly on compact time intervals:
$$\sup_{0 \le t \le T} \left\| \boldsymbol{\phi}^{\lfloor t d \rfloor} - \boldsymbol{\phi}(t) \right\| = O(1/\sqrt{d}) \quad \text{with high probability}$$
where:
- Time rescaling: discrete step $t$ maps to continuous time $t/d$
- Concentration rate: $O(1/\sqrt{d})$ error
- Deterministic dynamics: stochastic fluctuations vanish as $d \to \infty$
The derivation procedure:
- Identify macroscopic variables $\boldsymbol{\phi}$ (order parameters)
- Compute the conditional expectation $\mathbb{E}_t[\boldsymbol{\phi}^{t+1} \mid \boldsymbol{\phi}^t]$
- Extract the $O(1/d)$ drift: $\mathbf{g}(\boldsymbol{\phi}) = d \cdot \mathbb{E}_t[\boldsymbol{\phi}^{t+1} - \boldsymbol{\phi}^t]$
- Verify the concentration conditions
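The drift-extraction step can be checked numerically. The sketch below (self-contained, not part of the package) Monte Carlo estimates $d \cdot \mathbb{E}_t[m^{t+1} - m^t]$ for a scalar linear student-teacher model and compares it to the known linear-model drift $\tau(\rho - m)$, which is an assumption stated here rather than derived in this snippet.

```python
import numpy as np

rng = np.random.default_rng(1)
d, tau = 2000, 0.3

w_star = rng.standard_normal(d)
rho = w_star @ w_star / d
w = 0.4 * w_star + rng.standard_normal(d)    # a student state with overlap m ~ 0.4 rho
m = w @ w_star / d

# Monte Carlo estimate of the drift g(m) = d * E[m^{t+1} - m^t | current state]
n_mc = 20_000
drifts = np.empty(n_mc)
for i in range(n_mc):
    x = rng.standard_normal(d)
    gamma = (w - w_star) @ x / np.sqrt(d)    # prediction error (noiseless)
    w_new = w - tau * gamma * x / np.sqrt(d) # one SGD step from the same state
    drifts[i] = d * (w_new @ w_star / d - m)

g_mc = drifts.mean()
g_theory = tau * (rho - m)                   # analytic drift for the linear model
print(g_mc, g_theory)
```

The agreement between the empirical $O(1/d)$ increment (rescaled by $d$) and the analytic drift is exactly what Condition 1 requires.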
Gaussian Equivalence: For i.i.d. Gaussian inputs $\mathbf{x} \sim \mathcal{N}(\mathbf{0}, I_d)$, the student and teacher pre-activations are jointly Gaussian, with covariance determined entirely by the order parameters $(Q, M, P)$, so all expectations reduce to low-dimensional Gaussian integrals.
Stein's Lemma: For jointly Gaussian $(u, v)$ with zero mean,
$$\mathbb{E}[u\, f(v)] = \mathrm{Cov}(u, v)\, \mathbb{E}[f'(v)],$$
which is used to evaluate expectations involving the activation function and its derivative.
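Stein's lemma is easy to verify by sampling; this standalone check uses $f = \tanh$ as an arbitrary smooth test function.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 1_000_000
cov = np.array([[1.0, 0.6], [0.6, 2.0]])
u, v = rng.multivariate_normal([0.0, 0.0], cov, size=n).T

lhs = np.mean(u * np.tanh(v))                   # E[u f(v)]
rhs = cov[0, 1] * np.mean(1 - np.tanh(v) ** 2)  # Cov(u, v) E[f'(v)]
print(lhs, rhs)
```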
A narrow two-layer network with $k$ hidden units:
$$\hat{y} = f(\mathbf{x}; \theta) = \mathbf{a}^\top \sigma\!\left(\frac{1}{\sqrt{d}} W^\top \mathbf{x} + \mathbf{b}\right)$$
where:
- $\mathbf{a} \in \mathbb{R}^k$: second-layer weights (fixed or trained)
- $W = (\mathbf{w}_l)_{l \in [k]} \in \mathbb{R}^{d \times k}$: first-layer weights
- $\mathbf{b} \in \mathbb{R}^k$: bias terms
- $\sigma: \mathbb{R} \to \mathbb{R}$: activation function (applied element-wise)
Teacher network with $k^*$ hidden units:
$$f^*(\mathbf{x}) = (\mathbf{a}^*)^\top \sigma\!\left(\frac{1}{\sqrt{d}} (W^*)^\top \mathbf{x} + \mathbf{b}^*\right)$$
Observed label with noise:
$$y = f^*(\mathbf{x}) + \sigma_\epsilon \xi, \qquad \xi \sim \mathcal{N}(0, 1)$$
Student self-overlap matrix $Q \in \mathbb{R}^{k \times k}$: $$Q_{ll'} = \frac{1}{d} \mathbf{w}_l^\top \mathbf{w}_{l'}$$
Student-teacher overlap matrix $M \in \mathbb{R}^{k \times k^*}$: $$M_{lm} = \frac{1}{d} \mathbf{w}_l^\top \mathbf{w}_m^*$$
Teacher self-overlap matrix $P \in \mathbb{R}^{k^* \times k^*}$: $$P_{mn} = \frac{1}{d} (\mathbf{w}_m^*)^\top \mathbf{w}_n^*$$
For a data point $\mathbf{x} \sim \mathcal{N}(\mathbf{0}, I_d)$, define:
Student pre-activation: $$\boldsymbol{\Upsilon} = \frac{1}{\sqrt{d}} W^\top \mathbf{x} + \mathbf{b} \in \mathbb{R}^{k}$$
Teacher pre-activation: $$\boldsymbol{\Upsilon}^* = \frac{1}{\sqrt{d}} (W^*)^\top \mathbf{x} + \mathbf{b}^* \in \mathbb{R}^{k^*}$$
Joint distribution (LLMM data model): $$\begin{pmatrix} \boldsymbol{\Upsilon} \\ \boldsymbol{\Upsilon}^* \end{pmatrix} \sim \mathcal{N}\!\left(\begin{pmatrix} \mathbf{b} \\ \mathbf{b}^* \end{pmatrix}, \begin{pmatrix} Q & M \\ M^\top & P \end{pmatrix}\right)$$
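That the pre-activation covariance is given by $(Q, M, P)$ can be confirmed empirically. This standalone sketch (biases set to zero for simplicity) compares the empirical covariances of sampled pre-activations with the order parameters.

```python
import numpy as np

rng = np.random.default_rng(3)
d, k, k_star, n = 300, 3, 2, 20_000

W = rng.standard_normal((d, k))           # student first-layer weights
W_star = rng.standard_normal((d, k_star)) # teacher first-layer weights

Q = W.T @ W / d                           # order parameters
M = W.T @ W_star / d
P = W_star.T @ W_star / d

X = rng.standard_normal((n, d))           # n fresh Gaussian inputs
U = X @ W / np.sqrt(d)                    # student pre-activations (zero bias)
U_star = X @ W_star / np.sqrt(d)          # teacher pre-activations

Q_emp = U.T @ U / n                       # empirical covariances
M_emp = U.T @ U_star / n
P_emp = U_star.T @ U_star / n
print(np.abs(Q_emp - Q).max(), np.abs(M_emp - M).max(), np.abs(P_emp - P).max())
```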
Prediction error: $$\Gamma = \hat{y} - y = \mathbf{a}^\top \sigma(\boldsymbol{\Upsilon}) - (\mathbf{a}^*)^\top \sigma(\boldsymbol{\Upsilon}^*) - \sigma_\epsilon \xi$$
For squared loss $r = \frac{1}{2}\Gamma^2$, the gradient with respect to the first-layer weights is:
$$\nabla_{\mathbf{w}_l} r = \Gamma\, a_l\, \sigma'(\Upsilon_l)\, \frac{\mathbf{x}}{\sqrt{d}}$$
Update for $\mathbf{w}_l$ (with ridge regularization of strength $\lambda$):
$$\mathbf{w}_l^{t+1} = \mathbf{w}_l^t - \tau\, \Gamma\, a_l\, \sigma'(\Upsilon_l)\, \frac{\mathbf{x}}{\sqrt{d}} - \frac{\tau \lambda}{d} \mathbf{w}_l^t$$
Taking expectation and rescaling time as $t \mapsto t/d$ yields the ODE for the student-teacher overlap:
$$\frac{dM_{lm}}{dt} = -\tau\, \mathbb{E}\!\left[\Gamma\, a_l\, \sigma'(\Upsilon_l)\, (\Upsilon_m^* - b_m^*)\right] - \tau \lambda M_{lm}$$
where $\Upsilon_m^* - b_m^* = \mathbf{x}^\top \mathbf{w}_m^* / \sqrt{d}$. Similarly, for the student self-overlap:
$$\frac{dQ_{ll'}}{dt} = -\tau\, \mathbb{E}\!\left[\Gamma\, a_l\, \sigma'(\Upsilon_l)\, (\Upsilon_{l'} - b_{l'}) + \Gamma\, a_{l'}\, \sigma'(\Upsilon_{l'})\, (\Upsilon_l - b_l)\right] + \tau^2 a_l a_{l'}\, \mathbb{E}\!\left[\Gamma^2\, \sigma'(\Upsilon_l)\, \sigma'(\Upsilon_{l'})\right] - 2\tau \lambda Q_{ll'}$$
The expectations are Gaussian integrals over the joint distribution of $(\boldsymbol{\Upsilon}, \boldsymbol{\Upsilon}^*)$, which depends only on $(Q, M, P)$ and the biases.
These can be computed using:
- Gauss-Hermite quadrature for low-dimensional integrals
- Monte Carlo sampling for higher dimensions
- Analytical formulas for specific activations (linear, ReLU)
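The quadrature route can be illustrated with a one-dimensional example. The helper `gauss_hermite_1d` below is a hypothetical utility written for this note (not part of the package); it recovers the known closed form for the first ReLU moment $\mathbb{E}[\max(0, u)]$ with $u \sim \mathcal{N}(\mu, \sigma^2)$.

```python
import numpy as np
from math import erf, exp, pi, sqrt

def gauss_hermite_1d(f, mu, s2, n_nodes=80):
    # E[f(u)] for u ~ N(mu, s2), via the substitution u = mu + sqrt(2 s2) x
    x, w = np.polynomial.hermite.hermgauss(n_nodes)
    return np.sum(w * f(mu + np.sqrt(2 * s2) * x)) / np.sqrt(np.pi)

relu = lambda z: np.maximum(z, 0.0)
mu, s2 = 0.3, 1.5

gh = gauss_hermite_1d(relu, mu, s2)

# Closed form: E[max(0, u)] = mu * Phi(mu / sigma) + sigma * phi(mu / sigma)
sigma = sqrt(s2)
z = mu / sigma
analytic = mu * 0.5 * (1 + erf(z / sqrt(2))) + sigma * exp(-z**2 / 2) / sqrt(2 * pi)
print(gh, analytic)
```

Higher-dimensional expectations use tensor products of the same rule, or Monte Carlo when the number of correlated pre-activations grows.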
For linear models ($k = k^* = 1$, $\sigma(x) = x$, no bias), the order parameters reduce to scalars:
- $Q = q \in \mathbb{R}$ (student self-overlap)
- $M = m \in \mathbb{R}$ (student-teacher overlap)
- $P = \rho \in \mathbb{R}$ (teacher norm)
The population risk is
$$\mathcal{R}(m, q) = \frac{1}{2}\, \mathbb{E}[\Gamma^2] = \frac{1}{2}\left(q - 2m + \rho + \sigma_\epsilon^2\right).$$
This represents the expected per-sample loss (without regularization); in the streaming setting it coincides with the test loss.
ODE for $m$:
$$\frac{dm}{dt} = \tau(\rho - m) - \tau \lambda m$$
ODE for $q$:
$$\frac{dq}{dt} = -2\tau(q - m) - 2\tau \lambda q + \tau^2\left(q - 2m + \rho + \sigma_\epsilon^2\right)$$
Steady state ($t \to \infty$, taking $\lambda = 0$):
$$m_\infty = \rho, \qquad q_\infty = \rho + \frac{\tau \sigma_\epsilon^2}{2 - \tau}$$
Stability condition:
$$\tau < \tau_c = 2(1 + \lambda)$$
Generalization error dynamics:
$$e_g(t) = \frac{1}{2}\left(\rho - 2m(t) + q(t)\right)$$
Steady-state generalization error ($\lambda = 0$):
$$e_g^\infty = \frac{\tau \sigma_\epsilon^2}{2(2 - \tau)}$$
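These scalar dynamics can be integrated directly. The sketch below assumes the linear-model ODEs $\dot m = \tau(\rho - m)$ and $\dot q = -2\tau(q - m) + \tau^2(q - 2m + \rho + \sigma_\epsilon^2)$ with $\lambda = 0$, and checks the fixed point $m_\infty = \rho$, $q_\infty = \rho + \tau\sigma_\epsilon^2/(2 - \tau)$.

```python
import numpy as np
from scipy.integrate import solve_ivp

tau, rho, noise_var = 0.5, 1.0, 0.1   # lambda = 0

def odes(t, y):
    m, q = y
    dm = tau * (rho - m)
    dq = -2 * tau * (q - m) + tau**2 * (q - 2 * m + rho + noise_var)
    return [dm, dq]

sol = solve_ivp(odes, (0.0, 40.0), [0.0, 0.1], rtol=1e-9, atol=1e-12)
m_inf, q_inf = sol.y[0, -1], sol.y[1, -1]

q_pred = rho + tau * noise_var / (2 - tau)   # predicted fixed point for q
eg_inf = 0.5 * (rho - 2 * m_inf + q_inf)     # steady-state generalization error
eg_pred = tau * noise_var / (2 * (2 - tau))
print(m_inf, q_inf, q_pred, eg_inf, eg_pred)
```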
For ReLU activation $\sigma(x) = \max(0, x)$:
- $\sigma'(x) = \mathbf{1}_{x > 0}$ (Heaviside step function)
- $\sigma(x) \cdot \sigma'(x) = \sigma(x)$
For $(u, v)^\top \sim \mathcal{N}\!\left(\mathbf{0}, \begin{pmatrix} a & c \\ c & b \end{pmatrix}\right)$, the required correlated moments have closed forms; for example,
$$\mathbb{E}[\sigma(u)\,\sigma(v)] = \frac{\sqrt{ab}}{2\pi}\left(\sin\theta + (\pi - \theta)\cos\theta\right), \qquad \theta = \arccos\!\left(\frac{c}{\sqrt{ab}}\right).$$
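The correlated ReLU moment $\mathbb{E}[\sigma(u)\sigma(v)] = \frac{\sqrt{ab}}{2\pi}(\sin\theta + (\pi - \theta)\cos\theta)$ with $\theta = \arccos(c/\sqrt{ab})$ (the arc-cosine kernel formula) can be validated against Monte Carlo:

```python
import numpy as np

a, b, c = 1.0, 2.0, 0.7                # Cov[(u, v)] = [[a, c], [c, b]]
theta = np.arccos(c / np.sqrt(a * b))
closed = np.sqrt(a * b) / (2 * np.pi) * (np.sin(theta) + (np.pi - theta) * np.cos(theta))

rng = np.random.default_rng(5)
uv = rng.multivariate_normal([0.0, 0.0], [[a, c], [c, b]], size=2_000_000)
mc = np.mean(np.maximum(uv[:, 0], 0.0) * np.maximum(uv[:, 1], 0.0))
print(closed, mc)
```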
Solving the linear-model ODEs with the package:

```python
from statphys.theory.online import ODESolver, GaussianLinearMseEquations

# Define ODE equations
equations = GaussianLinearMseEquations(
    rho=1.0,         # Teacher norm ||w*||^2 / d
    eta_noise=0.1,   # Noise variance
    lr=0.5,          # Learning rate
    reg_param=0.01,  # Regularization
)

# Create solver
solver = ODESolver(
    equations=equations,
    order_params=["m", "q"],
)

# Solve ODE
result = solver.solve(
    t_span=(0, 10),
    init_values=(0.0, 0.1),  # Initial (m0, q0)
    n_points=100,
)

# Extract learning curve
t = result.t_values
m = result.order_params["m"]
q = result.order_params["q"]
eg = 0.5 * (equations.rho - 2 * m + q)

import matplotlib.pyplot as plt

plt.plot(t, eg)
plt.xlabel("Time t")
plt.ylabel("Generalization Error")
plt.title("Online Learning Dynamics")
plt.show()
```

Sketch of a custom equations class for a two-layer ReLU scenario:

```python
import numpy as np

from statphys.theory.online.scenario.base import OnlineEquations


class TwoLayerReLUEquations(OnlineEquations):
    def __init__(self, teacher_overlap_P, a_student, a_teacher, lr=0.1, reg_param=0.0):
        self.P = teacher_overlap_P
        self.a = a_student
        self.a_star = a_teacher
        self.lr = lr
        self.reg = reg_param
        self.k = len(a_student)
        self.k_star = len(a_teacher)

    def __call__(self, t, state):
        # Unpack state: Q (k x k), then M (k x k*)
        Q = state[:self.k**2].reshape(self.k, self.k)
        M = state[self.k**2:].reshape(self.k, self.k_star)
        # Compute expectations (using Gaussian integrals)
        dQ = self._compute_dQ(Q, M)
        dM = self._compute_dM(Q, M)
        return np.concatenate([dQ.flatten(), dM.flatten()])

    def _compute_dQ(self, Q, M):
        # Implement ODE for Q using numerical integration:
        # dQ/dt = -tau E[Gamma (a * sigma') Y^T + ...] + tau^2 E[...] - 2 tau lambda Q
        raise NotImplementedError

    def _compute_dM(self, Q, M):
        # Implement ODE for M:
        # dM/dt = -tau E[Gamma (a * sigma') (Y*)^T] - tau lambda M
        raise NotImplementedError
```

Comparing theory with finite-size simulations:

```python
from statphys.simulation import OnlineSimulation, SimulationConfig
from statphys.dataset import GaussianDataset
from statphys.model import LinearRegression
from statphys.loss import MSELoss
from statphys.vis import ComparisonPlotter
from statphys.theory.online import ODESolver, GaussianLinearMseEquations

# Setup
dataset = GaussianDataset(d=1000, rho=1.0, eta=0.1)
config = SimulationConfig.for_online(
    t_max=10.0,
    lr=0.1,
    n_seeds=5,
    use_theory=True,
)

# Theory solver
equations = GaussianLinearMseEquations(rho=1.0, eta_noise=0.1, lr=0.1)
theory_solver = ODESolver(equations=equations, order_params=["m", "q"])

# Run simulation
sim = OnlineSimulation(config)
results = sim.run(
    dataset=dataset,
    model_class=LinearRegression,
    loss_fn=MSELoss(),
    theory_solver=theory_solver,
)

# Plot comparison
plotter = ComparisonPlotter()
plotter.plot_theory_vs_experiment(results)
```

In the limit $d \to \infty$:
- Learning dynamics are described by ODEs with $O(k^2 + k k^*)$ variables
- Independent of the ambient dimension $d$
- Enables analysis of arbitrarily high-dimensional problems
The stochastic process concentrates to the ODE at rate $O(1/\sqrt{d})$:
- Finite-$d$ simulations converge to the theory as $d$ increases
- Provides quantitative predictions for practical systems
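The $O(1/\sqrt{d})$ rate can be observed empirically. This standalone sketch runs noiseless linear-model online SGD at two dimensions and measures the worst deviation of the simulated overlap $m$ from the closed-form ODE solution $m(t) = \rho(1 - e^{-\tau t})$ (an assumption of the linear model with $\lambda = 0$), averaged over seeds.

```python
import numpy as np

def max_deviation(d, tau=0.5, t_max=3.0, n_seeds=5):
    # Max |m_sim(t) - m_theory(t)| along a noiseless linear-model trajectory
    devs = []
    for seed in range(n_seeds):
        rng = np.random.default_rng(seed)
        w_star = rng.standard_normal(d)
        rho = w_star @ w_star / d
        w = np.zeros(d)
        worst = 0.0
        for step in range(int(t_max * d)):
            x = rng.standard_normal(d)
            gamma = (w - w_star) @ x / np.sqrt(d)
            w -= tau * gamma * x / np.sqrt(d)
            t = (step + 1) / d
            m_theory = rho * (1 - np.exp(-tau * t))
            worst = max(worst, abs(w @ w_star / d - m_theory))
        devs.append(worst)
    return float(np.mean(devs))

dev_small, dev_large = max_deviation(100), max_deviation(1600)
print(dev_small, dev_large)   # deviation shrinks as d grows
```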
Critical learning rate: for the linear model with $\lambda = 0$,
$$\tau_c = 2$$
- $\tau < \tau_c$: Convergent regime
- $\tau > \tau_c$: Divergent regime
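The two regimes are easy to exhibit in a finite-$d$ simulation; this standalone sketch of the noisy linear model (assuming $\tau_c = 2$ for that model) runs one learning rate on each side of the threshold.

```python
import numpy as np

def final_error(tau, d=400, t_max=5.0, noise_var=0.1, seed=0):
    rng = np.random.default_rng(seed)
    w_star = rng.standard_normal(d)
    w = np.zeros(d)
    for _ in range(int(t_max * d)):
        x = rng.standard_normal(d)
        y = w_star @ x / np.sqrt(d) + np.sqrt(noise_var) * rng.standard_normal()
        gamma = w @ x / np.sqrt(d) - y
        w -= tau * gamma * x / np.sqrt(d)
    rho = w_star @ w_star / d
    m, q = w @ w_star / d, w @ w / d
    return 0.5 * (rho - 2 * m + q)   # generalization error at t = t_max

eg_conv = final_error(0.5)   # below tau_c: settles near a small steady-state error
eg_div = final_error(2.5)    # above tau_c: the error blows up
print(eg_conv, eg_div)
```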
The steady-state error increases with the learning rate:
- Larger learning rate amplifies noise
- Trade-off between convergence speed and steady-state error
For multi-hidden-unit networks:
- Symmetric phase: All hidden units have similar weights
- Symmetry breaking: Units specialize to different features
- Convergence phase: Approach to optimal solution
For mini-batch size $B$, the gradient is averaged over $B$ fresh samples per step. The drift term is unchanged, while the noise term scales as $\tau^2 / B$: cross-sample contributions are $O(1/d)$, so averaging suppresses only the fluctuation part of the update.
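A quick way to see the effect is to divide the $\tau^2$ term of the scalar linear-model ODEs by $B$ (assuming those ODEs and that scaling, as stated above) and compare steady-state errors:

```python
import numpy as np
from scipy.integrate import solve_ivp

tau, rho, noise_var = 0.5, 1.0, 0.1

def steady_eg(B):
    # Scalar linear-model ODEs with the tau^2 (noise) term divided by B
    def odes(t, y):
        m, q = y
        dm = tau * (rho - m)
        dq = -2 * tau * (q - m) + (tau**2 / B) * (q - 2 * m + rho + noise_var)
        return [dm, dq]
    sol = solve_ivp(odes, (0.0, 80.0), [0.0, 0.1], rtol=1e-9, atol=1e-12)
    m, q = sol.y[0, -1], sol.y[1, -1]
    return 0.5 * (rho - 2 * m + q)

errors = {B: steady_eg(B) for B in (1, 4, 16)}
print(errors)   # steady-state error shrinks as B grows
```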
For a time-varying learning rate $\tau(t)$, the same ODEs hold with $\tau$ replaced by $\tau(t)$. Optimal schedules can be derived analytically.
Momentum requires additional order parameters that track the overlaps of the velocity variables with the student and teacher weights.
- Saad, D. & Solla, S.A. (1995). "On-line learning in soft committee machines." Phys. Rev. E.
- Biehl, M. & Schwarze, H. (1995). "Learning by on-line gradient descent." J. Phys. A.
- Werfel, J., Xie, X., & Seung, H.S. (2005). "Learning curves for stochastic gradient descent." Neural Computation.
- Goldt, S., Mézard, M., Krzakala, F., & Zdeborová, L. (2020). "Modeling the influence of data structure on learning in neural networks." Phys. Rev. X.
- Engel, A. & Van den Broeck, C. (2001). Statistical Mechanics of Learning. Cambridge University Press.