This repository contains the code and configuration files for the paper "Midtraining Bridges Pretraining and Posttraining Distributions" (2025).
It provides scripts and configuration templates for running large-scale language model pretraining, midtraining, and ablation experiments. All training is managed via YAML config files and SLURM batch scripts.
To try a new midtraining experiment:
- Create a new YAML config in `midtrain_configs/` (see existing examples for structure, and the sketch after this list).
- Edit the config to specify your desired data, model, and training parameters.
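A minimal sketch of what such a config might look like is below. Only `ShardedMixedDataset` and `mix_weights_str` come from this README; the surrounding field names (`model_name`, `out_dir`, `data`, `train`) follow litgpt's usual pretraining config schema, and the base-directory argument name is an assumption, so compare against an existing file in `midtrain_configs/`:

```yaml
# Hypothetical midtraining config -- check existing configs in
# midtrain_configs/ for the exact schema this repo uses.
model_name: pythia-410m                   # any architecture litgpt knows
out_dir: out/midtrain_example             # where checkpoints/logs are written

data:
  class_path: litgpt.data.ShardedMixedDataset
  init_args:
    data_path: /path/to/shards            # base shard directory (name assumed)
    mix_weights_str: "main:0.8,math:0.2"  # weighted mode: per-type weights

train:
  micro_batch_size: 4
  max_tokens: 1000000000                  # token budget for this phase
```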
The repo includes a flexible sharded mixed dataset implementation (`litgpt.data.ShardedMixedDataset`) that:
- Discovers shards under a base data directory by type (main numbered shards `1`, `2`, ..., and typed shards like `c1`, `m1`, `q1`, `w1`, etc.).
- Supports three mixing modes (illustrated below):
  - literal: specify exact shard weights via `literal_weights_str` (e.g. `main/1:0.8,q1:0.1,q2:0.1`).
  - weighted: give per-type weights via `mix_weights_str` (e.g. `main:0.8,math:0.2`), which are divided equally across shards of that type.
  - proportional: set `proportional_sampling` to compute weights proportional to shard sizes.
- Uses streaming loaders and `CombinedStreamingDataset` to sample from multiple shards with the computed weights.

Make sure your shard names match the config (literal weights require exact shard IDs like `w1` or `main/1`).
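The three modes map onto the config roughly as follows. The three parameter names are from this README; the `data:`/`init_args:` wrapper is an assumed litgpt-style layout:

```yaml
data:
  class_path: litgpt.data.ShardedMixedDataset
  init_args:
    # literal mode: exact weights keyed by shard ID
    literal_weights_str: "main/1:0.8,q1:0.1,q2:0.1"

    # weighted mode (use instead of the line above): per-type weights,
    # each split equally across that type's shards
    # mix_weights_str: "main:0.8,math:0.2"

    # proportional mode: weight each shard by its size
    # proportional_sampling: true
```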
- Prefixes are simple labels used to group shards. Examples: `main` (numbered folders `1`, `2`), `c` → `c1` (code), `m` → `m1` (math), `w` → `w1` (web).
- Use `mix_weights_str` for per-type weights (divided across shards), `literal_weights_str` for exact shard IDs (e.g. `w1` or `main/1`), or enable `proportional_sampling` to weight by shard size.
- To add a new type: pick a short prefix, add it to `DATASET_TYPE_CONFIGS` in `litgpt/litgpt/data/sharded_mixed_dataset.py`, name shards `p1`, `p2`, ..., and use the label in your config, as in the sketch below.
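For example, suppose you register a hypothetical prefix `p` with label `papers` in `DATASET_TYPE_CONFIGS` and create shards `p1`, `p2`, ...; the new label can then be mixed in like any built-in type. The label and weights here are invented for illustration:

```yaml
data:
  class_path: litgpt.data.ShardedMixedDataset
  init_args:
    # "papers" is the hypothetical label registered for the new p prefix;
    # its 0.3 weight is divided equally across p1, p2, ...
    mix_weights_str: "main:0.7,papers:0.3"
```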
Midtraining experiments are implemented simply by resuming from an intermediate checkpoint and changing the dataset blend or mixing config in the YAML. In practice you:
- Point `out_dir` at the checkpointed run you want to continue from and set `resume: true`.
- Update `mix_weights_str`, `literal_weights_str`, or `mix_config_path` in the new midtraining YAML to change the data blend.
- Launch training; the code will load the checkpoint and continue training with the new data mixture (see the sketch below).
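Put together, a midtraining YAML might differ from its pretraining parent only in these fields. This is a sketch under the same assumptions as above; only `out_dir`, `resume`, and the weight strings are documented in this README:

```yaml
# Resume the checkpointed run in out_dir and switch the data blend.
out_dir: out/pretrain_base                # run to continue from
resume: true                              # load its latest checkpoint

data:
  class_path: litgpt.data.ShardedMixedDataset
  init_args:
    # new blend for the midtraining phase (literal mode)
    literal_weights_str: "main/1:0.5,q1:0.25,q2:0.25"
```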
- `training_scripts/small_model_pretrain.sh`: Main SLURM script for launching pretraining or midtraining jobs. It supports array jobs and can be pointed at any config file.
- Other scripts in `training_scripts/` and `util_scripts/` provide evaluation, symlinking, and utility functions.
You can run the small model pretrain script with a specific config file by passing it as an environment variable:
```bash
export model_config_file=/path/to/your_config.yaml
sbatch training_scripts/small_model_pretrain.sh
```

Alternatively, you can edit the script to set `model_config_file` directly to your config path.
For more details, see the example configs in `midtrain_configs/` and the comments in each script.