SDLAML/norm-transfer

Norm Transfer Toolkit

Code and notebooks for studying how parameter norms, learning rates, and batch sizes interact across long-horizon language model training sweeps. The repository provides a full pipeline: download raw metrics from Weights & Biases, standardise the CSVs, fit per-horizon optima, and generate publication-ready plots.

This toolkit was used extensively to analyse the data for our work "Optimal Scaling Needs Optimal Norm". See the accompanying W&B report and the abstract below:

Despite recent progress in optimal hyperparameter transfer under model and dataset scaling, no unifying explanatory principle has been established. Using the Scion optimizer, we discover that joint optimal scaling across model and dataset sizes is governed by a single invariant: the operator norm of the output layer. Across models with up to 1.3B parameters trained on up to 138B tokens, the optimal learning rate/batch size pair (η∗, B∗) consistently has the same operator norm value – a phenomenon we term norm transfer. This constant norm condition is necessary but not sufficient: while for each dataset size, multiple (η, B) reach the optimal norm, only a unique (η∗, B∗) achieves the best loss. As a sufficient condition, we provide the first measurement of (η∗, B∗) scaling with dataset size for Scion, and find that the scaling rules are consistent with those of the Adam optimizer. Tuning per-layer-group learning rates also improves model performance, with the output layer being the most sensitive and hidden layers benefiting from lower learning rates. We provide practical insights on norm-guided optimal scaling and release our Distributed Scion (Disco) implementation with logs from over two thousand runs to support research on LLM training dynamics at scale.

Also check out our Distributed Scion (Disco) repository, optimised for training at scale: https://github.com/SDLAML/disco/

Repository highlights

  • Command line scripts to pull sweep histories from W&B and to normalise the CSV layout used throughout the analysis.
  • Reusable fitting utilities that locate quadratic minima for each (horizon, batch size) pair, configured through a Config dataclass.
  • A rich plotting module covering scatter/line plots, per-model comparisons, regression fits, parallel coordinates, interactive Plotly views, and convenience legends.
  • Example notebooks that reproduce the figures used in the accompanying paper submissions.
  • Pre-generated artefacts (data/, plots/) that illustrate the expected directory structure and naming conventions.

Project layout

| Path | Purpose |
| --- | --- |
| extract_wandb_data.py | CLI helper to download selected metrics from a W&B project/group into wandb_logs/. |
| preprocess_data.py | Merges raw CSVs, cleans column names, filters run names, computes horizons, and can aggregate statistics across seeds. |
| utils/fitting.py | Analysis primitives: quadratic fitting around minima, horizon window averaging, and the Config dataclass that controls the workflow. |
| utils/plotting.py | Matplotlib and Plotly visualisations for minima, regression fits, per-model overlays, interactive 3D views, and hyperparameter sweeps. |
| plot.ipynb, plot_lr_layout.ipynb | End-to-end notebooks showing how to assemble datasets, call the helpers, and reproduce the final figures. |
| data/ | Example processed CSVs; place your own outputs here. data/raw/ is the staging area for unprocessed logs. |
| plots/ | Saved figures generated by the notebooks and scripts. |
| wandb_logs/ | Raw exports captured by extract_wandb_data.py. |

Environment setup

  1. Python: use Python 3.10 or newer.
  2. Virtual environment (recommended):
    python3 -m venv .venv
    source .venv/bin/activate
    pip install --upgrade pip
  3. Dependencies: install the scientific stack and plotting packages used throughout the repo.
    pip install numpy pandas tqdm wandb matplotlib statsmodels plotly
    pip install jupyterlab ipywidgets  # notebooks & interactive widgets
    Add any extra libraries you rely on inside your notebooks (for example seaborn).
  4. Weights & Biases: run wandb login once so that the API key is stored locally before using the download script.

Data workflow

1. Download runs from W&B

Update the constants near the top of extract_wandb_data.py (project, sweep group, metric names, momentum settings). Then fetch the histories:

python extract_wandb_data.py [--bs BATCH_SIZE] [--seed SEED]
  • The optional --bs or --seed flags restrict the runs included in the export.
  • Each CSV is written to wandb_logs/ and includes the run metadata (run_name, bs, seed, learning-rate components, etc.).
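
As a rough illustration of what such an export involves, the sketch below flattens a run's history into CSV-ready rows using the public W&B API (wandb.Api().runs and Run.scan_history). It is not the code in extract_wandb_data.py: the helper names (run_to_rows, write_csv, download_group), the config keys bs and seed, and the one-file-per-run layout are assumptions made for illustration.

```python
import csv
import os
from typing import Any, Dict, List


def run_to_rows(run: Any, metric_keys: List[str]) -> List[Dict[str, Any]]:
    """Flatten one W&B run into rows, one per logged step."""
    # Assumed metadata fields; the real script may read other config keys.
    meta = {
        "run_name": run.name,
        "bs": run.config.get("bs"),
        "seed": run.config.get("seed"),
    }
    rows = []
    # scan_history yields one dict per step that has all requested keys.
    for record in run.scan_history(keys=["_step", *metric_keys]):
        row = dict(meta, step=record["_step"])
        row.update({key: record[key] for key in metric_keys})
        rows.append(row)
    return rows


def write_csv(rows: List[Dict[str, Any]], path: str) -> None:
    """Write the rows to a CSV file; does nothing for an empty run."""
    if not rows:
        return
    with open(path, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


def download_group(project_path: str, group: str, metric_keys: List[str],
                   out_dir: str = "wandb_logs") -> None:
    """Export every run of a W&B group (requires `wandb login`)."""
    import wandb  # imported lazily so the helpers above work offline

    os.makedirs(out_dir, exist_ok=True)
    api = wandb.Api()
    # project_path has the form "entity/project".
    for run in api.runs(project_path, filters={"group": group}):
        rows = run_to_rows(run, metric_keys)
        write_csv(rows, os.path.join(out_dir, f"{run.name}.csv"))
```

Keeping run_to_rows free of network calls makes it easy to check the CSV layout against a stub run object before launching a long download.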

2. Stage raw files

Move or symlink the downloaded CSVs into data/raw/. The naming scheme is flexible, but preprocess_data.py expects run_name patterns such as lr-<value>-bs-<value>-seed-<value> (with optional -momentum- and -decayed fragments).
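
To check whether your own run names fit this pattern before preprocessing, a small parser along the following lines may help. It is a sketch, not the regex used in preprocess_data.py: the function name parse_run_name, the accepted number formats, and the assumption that fragments may appear in any order are all illustrative.

```python
import re
from typing import Any, Dict, Optional

# Decimal or scientific notation, e.g. 0.01, 256, 1e-3 (assumed formats).
_NUM = r"([0-9]*\.?[0-9]+(?:e-?[0-9]+)?)"
_INT = r"([0-9]+)"


def _find(key: str, pattern: str, name: str) -> Optional[str]:
    """Return the value following `key-` in the run name, if present."""
    match = re.search(rf"(?:^|-){key}-{pattern}(?=-|$)", name)
    return match.group(1) if match else None


def parse_run_name(name: str) -> Optional[Dict[str, Any]]:
    """Extract lr, bs, seed and the optional fragments from a run name.

    Returns None when one of the mandatory fields is missing, which is
    how a non-matching run would be filtered out.
    """
    lr = _find("lr", _NUM, name)
    bs = _find("bs", _INT, name)
    seed = _find("seed", _INT, name)
    if lr is None or bs is None or seed is None:
        return None
    momentum = _find("momentum", _NUM, name)
    return {
        "lr": float(lr),
        "bs": int(bs),
        "seed": int(seed),
        "momentum": float(momentum) if momentum is not None else None,
        # "decayed" is treated as a bare flag without a value.
        "decayed": re.search(r"(?:^|-)decayed(?=-|$)", name) is not None,
    }


print(parse_run_name("lr-0.01-bs-256-seed-30"))
print(parse_run_name("lr-1e-3-bs-64-momentum-0.9-decayed-seed-7"))
```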

3. Preprocess and engineer features

Open preprocess_data.py and adjust the constants in the __main__ block:

  • NORM_NAME, LOSS_NAME: choose which metrics become output_norm and train_loss.
  • SEQ_LEN: sequence length used to compute the horizon column as horizon = bs * step * seq_len.
  • MOMENTUM, DECAYED, BASE, SCALING_PREFIX, SEEDS, POSTFIX: mirror your file naming scheme.
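
The horizon formula is simple enough to sanity-check by hand. The snippet below evaluates it and also inverts it to find the step at which a run reaches a target horizon. The values 256, 1024 and 4096 are arbitrary examples, and steps_to_horizon is a hypothetical helper, not part of the repository. If bs counts sequences per step, the horizon is the number of tokens processed so far.

```python
def horizon(bs: int, step: int, seq_len: int) -> int:
    """horizon = bs * step * seq_len, as in the preprocessing step."""
    return bs * step * seq_len


def steps_to_horizon(target: int, bs: int, seq_len: int) -> int:
    """Smallest step count whose horizon is at least `target`."""
    per_step = bs * seq_len
    return -(-target // per_step)  # integer ceiling division


# Example values only: 2**8 sequences * 2**10 steps * 2**12 tokens = 2**30.
print(horizon(bs=256, step=1024, seq_len=4096) == 2**30)  # True
print(steps_to_horizon(2**30, bs=256, seq_len=4096))      # 1024
```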

Run the script from the repository root:

python preprocess_data.py

process_runs collects matching CSVs, merges them, renames the metrics, filters out runs whose names do not match the expected regex, casts numeric columns, computes horizon, and saves a consolidated CSV to data/.

To collapse duplicate rows and summarise metrics across seeds, use the helper at the bottom of the module:

import pandas as pd
from preprocess_data import deduplicate_and_aggregate

raw = pd.read_csv("data/lr-bs-scan-base-momentum-1.0-preprocessed-seed-30.csv")
agg = deduplicate_and_aggregate(raw)
agg.to_csv("data/lr-bs-scan-base-momentum-1.0-preprocessed-seed-30-agg.csv", index=False)

deduplicate_and_aggregate checks for conflicting horizon values, removes exact duplicates, and reports per-configuration means and standard deviations.
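
For intuition, the aggregation part can be pictured with plain Python as below. This is a simplified stand-in rather than the repository's implementation: it works on a list of dicts instead of a DataFrame, assumes a configuration is identified by (lr, bs, horizon), uses the sample standard deviation, and omits the horizon-conflict check.

```python
from collections import defaultdict
from statistics import mean, stdev
from typing import Any, Dict, List, Sequence


def aggregate_across_seeds(
    rows: List[Dict[str, Any]],
    keys: Sequence[str] = ("lr", "bs", "horizon"),  # assumed grouping columns
    metric: str = "train_loss",
) -> List[Dict[str, Any]]:
    """Drop exact duplicate rows, then summarise `metric` per configuration."""
    seen = set()
    groups = defaultdict(list)
    for row in rows:
        fingerprint = tuple(sorted(row.items()))
        if fingerprint in seen:
            continue  # exact duplicate of a row already counted
        seen.add(fingerprint)
        groups[tuple(row[k] for k in keys)].append(row[metric])

    summary = []
    for config, values in sorted(groups.items()):
        entry = dict(zip(keys, config))
        entry[f"{metric}_mean"] = mean(values)
        # Sample standard deviation; undefined (None) for a single seed.
        entry[f"{metric}_std"] = stdev(values) if len(values) > 1 else None
        entry["n"] = len(values)
        summary.append(entry)
    return summary


rows = [
    {"lr": 0.01, "bs": 256, "horizon": 2**28, "seed": 1, "train_loss": 3.0},
    {"lr": 0.01, "bs": 256, "horizon": 2**28, "seed": 2, "train_loss": 3.5},
    {"lr": 0.01, "bs": 256, "horizon": 2**28, "seed": 1, "train_loss": 3.0},
]
for entry in aggregate_across_seeds(rows):
    print(entry)
```

The third row is an exact duplicate of the first, so only two seeds contribute to the reported mean and standard deviation.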

4. Analyse minima and trends

utils.fitting.Config centralises the knobs for the downstream analysis:

  • Select horizons, toggle quadratic fitting (from_fit), enforce specific coefficients (c_fixed), or choose empirical minima.
  • Control the window used to average metrics around each horizon (avg_rel_from, avg_rel_to, average_h, average_bs).
  • Skip problematic fits via skip_fit, or refit using the closest k points (fit_k, fit_k_by).
  • Configure plot aesthetics (figure size, markers, legend placement, font sizes).

Example session:

import pandas as pd
from utils.fitting import Config, build_minima_df
from utils.plotting import plot_minima

cfg = Config(
    csv_path="data/lr-bs-scan-base-momentum-1.0-preprocessed-seeds.csv",
    horizons=[2**28, 2**30, 2**32, 2**33],
    avg_rel_from=-2,
    avg_rel_to=2,
    fit_k=5,
)

runs = pd.read_csv(cfg.csv_path)
minima = build_minima_df(runs, cfg)
fig, ax = plot_minima(minima, cfg)
fig.savefig("plots/minima-summary.png", dpi=200)
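
The idea behind the quadratic fit can be sketched independently of the toolkit. The example below fits a parabola to loss as a function of log2(lr), optionally restricted to the k points closest to the empirical minimum, and returns the vertex. It assumes the fit is done in log2(lr) and that "closest" means closest in log2(lr) to the best observed point; the function names are hypothetical and the real logic in utils/fitting.py may differ.

```python
import math
from typing import List, Optional, Sequence, Tuple


def _det3(m: List[List[float]]) -> float:
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)


def fit_parabola(xs: Sequence[float], ys: Sequence[float]) -> Tuple[float, float, float]:
    """Least-squares fit of y = a*x**2 + b*x + c via the normal equations."""
    s = [sum(x**p for x in xs) for p in range(5)]                # sums of x**p
    t = [sum(y * x**p for x, y in zip(xs, ys)) for p in range(3)]
    m = [[s[4], s[3], s[2]], [s[3], s[2], s[1]], [s[2], s[1], s[0]]]
    rhs = [t[2], t[1], t[0]]
    det = _det3(m)
    if abs(det) < 1e-12:
        raise ValueError("need at least three distinct x values")
    coeffs = []
    for col in range(3):  # Cramer's rule, one unknown per column
        mc = [row[:] for row in m]
        for r in range(3):
            mc[r][col] = rhs[r]
        coeffs.append(_det3(mc) / det)
    return coeffs[0], coeffs[1], coeffs[2]


def quadratic_minimum(lrs: Sequence[float], losses: Sequence[float],
                      k: Optional[int] = None) -> Tuple[float, float]:
    """Return (lr at the fitted minimum, fitted loss at that minimum)."""
    pts = sorted((math.log2(lr), loss) for lr, loss in zip(lrs, losses))
    if k is not None:
        x_best = min(pts, key=lambda p: p[1])[0]
        pts = sorted(pts, key=lambda p: abs(p[0] - x_best))[:k]
    x0 = sum(x for x, _ in pts) / len(pts)  # centre x for better conditioning
    a, b, c = fit_parabola([x - x0 for x, _ in pts], [y for _, y in pts])
    if a <= 0:
        raise ValueError("fitted parabola has no minimum")
    return 2.0 ** (x0 - b / (2 * a)), c - b * b / (4 * a)


# Synthetic data: an exact parabola in log2(lr) with its minimum at lr = 2**-9.
lrs = [2.0**e for e in range(-12, -5)]
losses = [0.05 * (math.log2(lr) + 9) ** 2 + 2.5 for lr in lrs]
print(quadratic_minimum(lrs, losses, k=5))
```

Restricting the fit to a few points near the minimum matters because loss curves are typically only approximately quadratic close to the optimum.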

5. Visualise results

utils.plotting exposes specialised tools that build on the minima dataframe or aggregated CSVs:

  • plot_minima and plot_minima_at_horizon_across_models: scatter plus trend-line views keyed by horizon and batch size, with automatic legends.
  • plot_parabola_grid: a grid of quadratic fits for every horizon and batch size, including optional error bars from standard deviations.
  • plot_lr_bs_fit: regression of log2(lr) against log2(bs) and horizon using statsmodels with configurable transparency and column names.
  • plot_global_two_param_fit: fits log2(lr) = A*log2(h) + B*log2(bs) + C, summarises coefficients (and optionally exports them) while visualising the selected data band.
  • plot_interactive_horizon_scatter: Plotly-based 3D scatter of loss vs output norm vs learning rate for a chosen horizon.
  • plot_parallel_coordinates: compares per-layer learning-rate triples, highlighting the top-performing combinations.
  • plot_top1_across_horizons: focuses on the best loss per horizon for a fixed batch size while comparing against equal-learning-rate baselines.
  • plot_norm_vs_horizon_by_lr_bs: line plots of norm evolution for selected (lr, bs) combinations.

Each function returns Matplotlib (or Plotly) figure objects so you can save, display, or further customise them.
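
Once coefficients for the relation log2(lr) = A*log2(h) + B*log2(bs) + C are available, they can be used to predict a learning rate for a new horizon and batch size. The snippet below only evaluates that relation; predict_lr is a hypothetical helper, and the coefficient values are arbitrary placeholders, not fitted results from the paper or the repository.

```python
import math


def predict_lr(horizon: float, bs: float, A: float, B: float, C: float) -> float:
    """Evaluate log2(lr) = A*log2(h) + B*log2(bs) + C and return lr."""
    return 2.0 ** (A * math.log2(horizon) + B * math.log2(bs) + C)


# Placeholder coefficients for illustration only (NOT fitted values).
A, B, C = -0.25, 0.5, 1.0

lr = predict_lr(horizon=2**32, bs=2**8, A=A, B=B, C=C)
print(lr)  # 0.125, since -0.25*32 + 0.5*8 + 1 = -3

# The same relation written as a power law: lr = 2**C * h**A * bs**B.
power_law = 2.0**C * (2**32) ** A * (2**8) ** B
print(math.isclose(lr, power_law))  # True
```

Written as a power law, A and B are the exponents with which the learning rate scales in horizon and batch size, which is often the easier form to compare against published scaling rules.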

Notebooks

  • plot.ipynb walks through the full workflow: loading processed CSVs, building minima, fitting regressions, and exporting plots for the paper.
  • plot_lr_layout.ipynb explores the geometry of per-layer learning-rate schedules and the resulting norms.

The notebooks expect the processed CSVs under data/ and the helper modules to be importable from utils/. Launch them with jupyter lab (or your preferred frontend) after activating the virtual environment.
