
Add DeepSpeed 103B GPT pretraining benchmark and standardize containers for B200#1009

Merged
paragao merged 5 commits into main from feature/deepspeed-b200-pretraining on Mar 11, 2026

Conversation

@paragao
Contributor

@paragao paragao commented Mar 6, 2026

Summary

  • Add a 103B-parameter GPT pretraining benchmark using Megatron-DeepSpeed on NVIDIA B200 clusters with 3D parallelism (TP/PP/DP) and DeepSpeed ZeRO optimization
  • Rewrite the DeepSpeed README as a use-case-focused guide covering all three test cases (GPT-103B pretraining, QLoRA fine-tuning, Llama2 fine-tuning) with best-practice configuration recommendations
  • Standardize the QLoRA Dockerfile with the same infrastructure best practices as the pretraining container

Changes

New files

| File | Description |
|------|-------------|
| `pretrain_gpt_103b.sbatch` | Parameterized Slurm script for 103B GPT pretraining (TP, PP, ZeRO, fusions, seq length all configurable via env vars) |
| `parse_results.py` | Parses Megatron-DeepSpeed training logs into benchmark JSON |
| `configs/ds_config_103b_template.json` | Reference DeepSpeed config for the 103B model |

Modified files

| File | Change |
|------|--------|
| `0.deepspeed.dockerfile` | Rewritten: pytorch:25.04-py3 base (CUDA 12.9), EFA 1.47.0, NCCL 2.29.3, GDRCopy v2.5.1, OFI-NCCL symlinks, ld.so.conf.d library discovery |
| `1.build-image.sbatch` | Added `mkdir -p /fsx/apps`, fixed job name |
| `Makefile` | Rewritten: `train` target uses best config (TP=8, PP=8, fusions), `parse` target for log parsing, `help` with descriptions |
| `README.md` | Rewritten as use-case index with prerequisites, container setup, data prep, running instructions, best practices, and known issues |
| `qlora/Dockerfile` | Standardized: same pytorch:25.04-py3 base, EFA 1.47, NCCL 2.29.3, GDRCopy 2.5.1, proper NCCL/EFA env vars, `expandable_segments:True` |
| `qlora/requirements.txt` | Updated comment to reflect new CUDA 12.9 index URL |

Key best practices documented

  • Maximize pipeline parallelism (TP=8, PP=8) with kernel fusions enabled for best throughput
  • ZeRO-0 over ZeRO-1 when data-parallel group size is small
  • ZeRO-2/3 require --no-pipeline-parallel (Megatron-DeepSpeed PipelineEngine limitation)
  • Do not set NCCL_ALGO=Tree on EFA clusters (causes hangs)
  • Use expandable_segments:True (capital T) in pytorch:25.04 containers
  • Use python3 -m torch.distributed.run instead of torchrun (shebang compatibility)
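The practices above can be sketched as launch-time environment settings. The variable names `TP`/`PP`/`ZERO_STAGE` mirror the PR's documented sbatch parameters, but treat this as an illustrative baseline rather than the script's exact contents:

```shell
# Illustrative best-practice launch settings (assumed variable names,
# not the literal contents of pretrain_gpt_103b.sbatch).
export TP=8                     # tensor parallel size
export PP=8                     # pipeline parallel size: maximize PP
export ZERO_STAGE=0             # ZeRO-0 when the data-parallel group is small
# Capital T matters in pytorch:25.04 containers:
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# Do NOT set NCCL_ALGO=Tree on EFA clusters (known to hang):
unset NCCL_ALGO
# Prefer the module form over the torchrun wrapper (shebang compatibility):
LAUNCHER="python3 -m torch.distributed.run"
echo "launcher=${LAUNCHER} TP=${TP} PP=${PP} ZeRO=${ZERO_STAGE}"
```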

Testing

All configurations were validated on an 8-node B200 HyperPod cluster (64 GPUs total) with 50 training steps each.

paragao added 2 commits March 5, 2026 15:04
- Update Dockerfile: pytorch:25.04-py3 base, EFA 1.47, NCCL 2.29.3, GDRCopy 2.5.1
- Add 103B GPT pretraining sbatch script with parameterized parallelism, ZeRO stages, fusion ops, and correct NCCL/EFA flags
- Add sweep runners (v1: 20 configs, v2: 10 configs) covering TP/PP/ZeRO/fusion/memory variations
- Add results parser and S3 upload script with CloudWatch metric publishing
- Best result: 476.6 TFLOPS/GPU (TP=8, PP=8, ZeRO=0, fusions enabled) on 8x B200 nodes
- Rewrite README as use-case-focused guide: GPT-103B pretraining,
  QLoRA fine-tuning, and Llama2 fine-tuning with best practices
  and proper configuration docs (no benchmark numbers)
- Simplify Makefile: best-config train target, remove sweep/upload targets
- Standardize QLoRA Dockerfile: pytorch:25.04-py3 base, EFA 1.47,
  NCCL 2.29.3, GDRCopy 2.5.1, OFI-NCCL symlinks, proper NCCL/EFA env vars
- Remove sweep runners and upload script from tracked files (internal tooling)
Collaborator

@KeitaW KeitaW left a comment

Review Batch 1/4 — Structure & Repository Hygiene

New pretraining files should live in their own subdirectory with platform-specific layout

The repo convention requires test cases to follow 3.test_cases/<framework>/<library>/<test_case>/ with platform-specific subdirectories (slurm/, kubernetes/, hyperpod-eks/). Right now the three new pretraining files (pretrain_gpt_103b.sbatch, parse_results.py, configs/ds_config_103b_template.json) are placed directly in the deepspeed/ directory alongside shared infrastructure (0.deepspeed.dockerfile, 1.build-image.sbatch, Makefile). This mixes the shared container build with a specific benchmark, and doesn't match how the sibling qlora/ directory is organized.

I'd suggest moving the GPT-103B content into its own subdirectory, for example:

3.test_cases/pytorch/deepspeed/
├── 0.deepspeed.dockerfile          # shared container (stays here)
├── 1.build-image.sbatch            # shared build script (stays here)
├── Makefile                         # shared targets (stays here)
├── README.md                        # index page (stays here)
├── gpt/                             # NEW: GPT-103B benchmark
│   ├── README.md                    # benchmark-specific docs
│   ├── slurm/
│   │   └── pretrain_gpt_103b.sbatch
│   ├── configs/
│   │   └── ds_config_103b_template.json
│   └── parse_results.py
├── qlora/                           # existing
└── examples_megatron_deepspeed/     # existing

This keeps the directory layout consistent with qlora/ and with the broader repo convention. The Makefile train target would just need its path updated to gpt/slurm/pretrain_gpt_103b.sbatch.

Comment on lines +146 to +150
RUN pip3 install --no-cache-dir \
awscli pynvml \
transformers==${TRANSFORMERS_VERSION} \
sentencepiece python-etcd \
deepspeed accelerate
Collaborator

Unpinned Python packages

deepspeed and accelerate are installed without version pins. Per repo conventions, dependencies should have at least an upper-bound pin to prevent silent breakage from major version changes. For example:

Suggested change
RUN pip3 install --no-cache-dir \
awscli pynvml \
transformers==${TRANSFORMERS_VERSION} \
sentencepiece python-etcd \
deepspeed accelerate
RUN pip3 install --no-cache-dir \
awscli pynvml \
transformers==${TRANSFORMERS_VERSION} \
sentencepiece python-etcd \
deepspeed>=0.16,<1.0 accelerate>=1.0,<2.0

I'd also double-check the QLoRA requirements.txt for the same issue — any packages there without upper bounds should get them too.

Contributor Author

Fixed. Pinned deepspeed>=0.16,<1.0 and accelerate>=1.0,<2.0 in the Dockerfile. QLoRA requirements.txt already had proper upper bounds — no changes needed there.

Comment on lines +8 to +10
matching the existing benchmark-results schema at:
s3://paragao-new-nemo-squash-container/benchmark-results/b200/

Collaborator

Personal S3 bucket reference

This looks like a personal bucket path in the module docstring. Since this will be visible in the public repo, I'd suggest either removing it or replacing it with a generic path.

Suggested change
matching the existing benchmark-results schema at:
s3://paragao-new-nemo-squash-container/benchmark-results/b200/
matching the benchmark-results schema.
Usage:

Contributor Author

Fixed. Replaced with generic placeholder s3://<YOUR_BUCKET>/benchmark-results/<instance_type>/. Also parameterized cluster and instance_type in build_result_json() — they now accept CLI args (--cluster, --instance-type) and fall back to $CLUSTER_NAME/$INSTANCE_TYPE env vars.
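A minimal sketch of the CLI-arg-with-env-fallback pattern this reply describes; `parse_metadata` and its defaults are hypothetical stand-ins, not the actual `parse_results.py` code:

```python
# Hypothetical sketch of the --cluster/--instance-type fallback described
# above. The real parse_results.py may structure this differently.
import argparse
import os

def parse_metadata(argv=None):
    parser = argparse.ArgumentParser(description="Benchmark metadata")
    # CLI flag wins; otherwise fall back to the environment variable.
    parser.add_argument("--cluster",
                        default=os.environ.get("CLUSTER_NAME"))
    parser.add_argument("--instance-type",
                        default=os.environ.get("INSTANCE_TYPE"))
    return parser.parse_args(argv)
```

For example, `parse_metadata(["--cluster", "my-cluster"])` overrides any `CLUSTER_NAME` in the environment, while omitting the flag falls back to the env var (or `None` if neither is set).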

Collaborator

@KeitaW KeitaW left a comment

Review Batch 2/4 — Deployment Pipeline

Comment on lines +18 to +19
mkdir -p ${APPS_PATH}
mkdir -p logs
Collaborator

Unquoted variable expansions

The repo convention calls for quoting all variable expansions to guard against word splitting. While these are unlikely to contain spaces in practice, it's good hygiene and consistent with the rest of the codebase.

Suggested change
mkdir -p ${APPS_PATH}
mkdir -p logs
mkdir -p "${APPS_PATH}"
mkdir -p logs

The same applies throughout pretrain_gpt_103b.sbatch — several ${IMAGE}, ${HOSTFILE}, and other expansions are unquoted, especially in the srun arguments and the bash -c string interpolation at the bottom of the script.
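A quick demonstration of the word-splitting hazard, using a throwaway path rather than anything from the PR:

```shell
# Illustrative only: a path containing spaces shows why quoting matters.
APPS_PATH="/tmp/apps with spaces"

mkdir -p "${APPS_PATH}"            # quoted: one directory, spaces intact
QUOTED=$([ -d "/tmp/apps with spaces" ] && echo ok || echo missing)

mkdir -p ${APPS_PATH}              # unquoted: word-splits into THREE paths
UNQUOTED=$([ -d ./with ] && [ -d ./spaces ] && echo split || echo no-split)

echo "quoted=${QUOTED} unquoted=${UNQUOTED}"
rm -rf "/tmp/apps with spaces" ./with ./spaces /tmp/apps   # clean up
```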

Contributor Author

Fixed. Quoted all variable expansions in 1.build-image.sbatch and pretrain_gpt_103b.sbatch ($SLURM_JOB_NODELIST, ${HOSTFILE}, ${IMAGE}, ${APPS_PATH}, ${ENROOT_IMAGE}). For the srun bash -c line, variables are intentionally expanded on the host side before passing to the container — added a comment explaining this.

#SBATCH --job-name=deepspeed-pretrain-103b
#SBATCH --output=logs/%x_%j.out
#SBATCH --error=logs/%x_%j.err
#SBATCH --partition=b200
Collaborator

Hardcoded partition name

The b200 partition is specific to your cluster. Other users won't have this partition, and the #SBATCH directive takes precedence over sbatch --partition=... command-line overrides. Since the Makefile already passes --partition=$(PARTITION) at submission time, I'd suggest removing this directive entirely and letting users specify the partition when submitting. That way make train PARTITION=my-partition works without the #SBATCH line overriding it.

Alternatively, add a comment noting this must be changed for other clusters.

Contributor Author

Fixed. Removed the #SBATCH --partition=b200 directive entirely. The partition is now passed exclusively at submit time via --partition= in the Makefile. Also changed the Makefile default partition from b200 to dev.
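The resulting submit-time flow might look like the following; the `train` target and `PARTITION` variable names are assumptions based on the discussion, not the PR's exact Makefile:

```makefile
# Hypothetical Makefile fragment: partition is supplied only at submit
# time, so it never conflicts with an #SBATCH directive in the script.
PARTITION ?= dev    # override with: make train PARTITION=my-partition

train:
	sbatch --partition=$(PARTITION) gpt/slurm/pretrain_gpt_103b.sbatch
```

Because `#SBATCH` directives take precedence over `sbatch --partition=...`, keeping the partition out of the script entirely is what makes the Makefile override reliable.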

Collaborator

@KeitaW KeitaW left a comment

Review Batch 3/4 — Infrastructure & NCCL Configuration

export FI_PROVIDER=efa
export FI_EFA_USE_HUGE_PAGE=0
export NCCL_SOCKET_IFNAME=^docker,lo,veth
export NCCL_P2P_NET_CHUNKSIZE=2048576
Collaborator

NCCL_P2P_NET_CHUNKSIZE value looks off

This value is 2048576 (2,048,576). If 2 MB was intended, the conventional power-of-two value would be 2097152 (2×1024×1024). If it's an intentional non-power-of-two tuning choice that works well in your benchmarks, a brief comment explaining why would be helpful.

Suggested change
export NCCL_P2P_NET_CHUNKSIZE=2048576
export NCCL_P2P_NET_CHUNKSIZE=2097152

Contributor Author

Fixed. Changed from 2048576 to 2097152 (2 x 1024 x 1024). The original value was a typo, not an intentional tuning choice.

Collaborator

@KeitaW KeitaW left a comment

Review Batch 4/4 — Documentation Consistency

README environment variable table defaults vs recommended config

The README's env vars table correctly shows that PP defaults to 2 and ZERO_STAGE to 1 (matching the sbatch script), but the recommended "best config" uses TP=8, PP=8, ZERO_STAGE=0. Users who run the sbatch directly without make train would get the baseline config rather than the documented best. I'd suggest either aligning the sbatch defaults with the recommended config, or adding a note clarifying that the defaults are a safe baseline and make train uses the optimized settings.

Missing newline at end of README

The diff shows \ No newline at end of file on the last line of README.md. Per .editorconfig conventions (insert_final_newline = true), this should have a trailing newline.


Things That Look Great

  • Excellent Dockerfile modernization: The switch from manual aws-ofi-nccl source builds to the EFA 1.47 bundled plugin with proper symlinks is a significant improvement. The ld.so.conf.d approach for library discovery is more robust than relying solely on LD_LIBRARY_PATH.
  • Thorough best practices documentation: The README's best practices section is clearly informed by real experimentation — the ZeRO-2/3 + pipeline parallelism incompatibility, the NCCL_ALGO=Tree warning, and the expandable_segments:True case sensitivity are all valuable tribal knowledge that will save users hours of debugging.
  • Well-parameterized sbatch script: The pretrain_gpt_103b.sbatch script is cleanly organized with sensible defaults and full override capability via environment variables. The automatic ZeRO-stage-aware pipeline parallel handling is a nice touch.
  • Consistent infrastructure stack: Standardizing both Dockerfiles on the same EFA/NCCL/GDRCopy versions reduces maintenance burden and eliminates "works in one container but not the other" issues.
  • Dynamic DeepSpeed config generation: Generating the DS config at runtime from environment variables (rather than requiring users to edit JSON files) is much more ergonomic for parameter sweeps.
  • Clean log parser: parse_results.py is well-structured with clear regex patterns and proper warmup-step handling for steady-state metrics.
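The runtime config generation praised above can be sketched with a shell heredoc; the file path and the subset of DeepSpeed keys shown here are illustrative, not the template actually shipped in `configs/`:

```shell
# Illustrative sketch: generate a DeepSpeed config at runtime from env
# vars instead of hand-editing static JSON files. Defaults are assumed.
ZERO_STAGE="${ZERO_STAGE:-0}"
MICRO_BATCH="${MICRO_BATCH:-1}"
GLOBAL_BATCH="${GLOBAL_BATCH:-256}"
DS_CONFIG=/tmp/ds_config_runtime.json

cat > "${DS_CONFIG}" <<EOF
{
  "train_micro_batch_size_per_gpu": ${MICRO_BATCH},
  "train_batch_size": ${GLOBAL_BATCH},
  "zero_optimization": { "stage": ${ZERO_STAGE} },
  "bf16": { "enabled": true }
}
EOF

# Sanity-check that the generated file is valid JSON before launching.
python3 -c "import json; json.load(open('${DS_CONFIG}'))" && echo "valid JSON"
```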

| Variable | Default | Description |
|----------|---------|-------------|
| `TP` | 8 | Tensor parallel size |
| `PP` | 2 | Pipeline parallel size |
Collaborator

Default PP doesn't match recommended config

The table shows PP defaults to 2, but the recommended "best config" (in the Makefile and the make train example above) uses PP=8. Users running the sbatch directly without make train will get PP=2. Consider either changing the default to 8 or adding a note here that these are safe baselines and make train uses the optimized settings.

Suggested change
| `PP` | 2 | Pipeline parallel size |
| `PP` | 2 | Pipeline parallel size (use PP=8 for best throughput, see `make train`) |

Contributor Author

Fixed. Updated the README table row to: | PP | 2 | Pipeline parallel size (best throughput with PP=8, see make train) |

Collaborator

@KeitaW KeitaW left a comment

Good update, few comments.

paragao added 3 commits March 10, 2026 03:11
- Remove personal S3 bucket name from parse_results.py and upload_results.sh
- Parameterize cluster name and instance type in parse_results.py (via CLI
  args and env vars)
- Replace hardcoded S3 bucket/regions in upload_results.sh with required
  env vars (S3_BUCKET, S3_REGION, CW_REGION)
- Remove hardcoded --partition=b200 from sbatch script
- Make PARTITION overridable in sweep_runner.sh and sweep_runner_v2.sh
- Change default partition from 'b200' to 'dev' in Makefile and sweep scripts
- Add sweep_runner.sh, sweep_runner_v2.sh, and upload_results.sh to tracking
…e vars

- Move GPT-103B pretraining files into gpt/ subdirectory with slurm/
  and configs/ sub-dirs to match repo conventions
- Pin deepspeed>=0.16,<1.0 and accelerate>=1.0,<2.0 in Dockerfile
- Quote all variable expansions in build and training scripts
- Fix NCCL_P2P_NET_CHUNKSIZE from 2048576 to 2097152 (2MB power-of-two)
- Add PP note to README env vars table clarifying best config uses PP=8
- Add trailing newline to README.md
- Update all path references for new directory structure
@paragao
Contributor Author

paragao commented Mar 10, 2026

Addressed all review feedback

Directory restructure (Batch 1/4): Moved GPT-103B files into gpt/ subdirectory with slurm/ and configs/ sub-dirs to match repo conventions:

gpt/
├── configs/ds_config_103b_template.json
├── slurm/pretrain_gpt_103b.sbatch
└── parse_results.py

Updated all path references in Makefile and README.

Documentation (Batch 4/4): Added trailing newline to README.md. Updated PP row with note about make train.

Also fixed across these commits:

  • Removed personal S3 bucket references, parameterized cluster/instance metadata
  • Removed hardcoded --partition=b200, changed default to dev
  • Pinned deepspeed>=0.16,<1.0 and accelerate>=1.0,<2.0 in Dockerfile
  • Quoted all variable expansions in build and training scripts
  • Corrected NCCL_P2P_NET_CHUNKSIZE from 2048576 to 2097152 (power-of-two)

See individual inline replies for details on each item.

Collaborator

@KeitaW KeitaW left a comment

Review Batch 1/3 — Structure & Repository Hygiene

Re-review of PR #1009 (2026-03-11). Good progress since the initial review — the gpt/ subdirectory restructuring, version pins on deepspeed/accelerate, and NCCL chunk size fix all look great.

One cross-cutting note: Steps 1–7 of 0.deepspeed.dockerfile and qlora/Dockerfile are nearly identical (~80 lines of duplicated EFA/NCCL/GDRCopy infrastructure). This isn't blocking since self-contained Dockerfiles are the norm in this repo, but it's worth considering a shared base image in a follow-up to reduce maintenance burden when version bumps are needed.
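One possible shape for that follow-up, with placeholder image names:

```dockerfile
# Hypothetical follow-up sketch, not part of this PR. A shared base image
# would carry the duplicated steps 1-7 once:
#
#   base.dockerfile -> tagged e.g. deepspeed-infra-base:25.04
FROM nvcr.io/nvidia/pytorch:25.04-py3
# ... steps 1-7: EFA 1.47.0, NCCL 2.29.3, GDRCopy v2.5.1, OFI-NCCL
#     symlinks, ld.so.conf.d library discovery ...

# 0.deepspeed.dockerfile and qlora/Dockerfile would then each start from:
#   FROM deepspeed-infra-base:25.04
# and add only their framework-specific layers on top.
```

Version bumps would then touch a single file instead of two near-identical ones.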


```bash
enroot import -o ${ENROOT_IMAGE} dockerd://deepspeed:latest
git clone https://github.com/microsoft/Megatron-DeepSpeed /fsx/deepspeed/Megatron-DeepSpeed
```
Collaborator

Megatron-DeepSpeed clone is not pinned to a version

The previous Dockerfile had ARG MEGATRON_LM_VERSION=core_r0.8.0, but this was removed. Per the repo convention, external dependencies must be pinned to a version/tag/commit — never HEAD. I'd suggest pinning to whatever commit or tag was validated during testing on the B200 cluster.

Suggested change
git clone https://github.com/microsoft/Megatron-DeepSpeed /fsx/deepspeed/Megatron-DeepSpeed
git clone --branch <validated_tag> --depth 1 https://github.com/microsoft/Megatron-DeepSpeed /fsx/deepspeed/Megatron-DeepSpeed

Collaborator

@KeitaW KeitaW left a comment

Review Batch 2/3 — Deployment Pipeline

Two items in the container build pipeline.

enroot import -o ${ENROOT_IMAGE}.sqsh dockerd://${ENROOT_IMAGE}:latest
mv ${ENROOT_IMAGE}.sqsh ${IMAGE} No newline at end of file
enroot import -o "${ENROOT_IMAGE}.sqsh" "dockerd://${ENROOT_IMAGE}:latest"
mv "${ENROOT_IMAGE}.sqsh" "${IMAGE}" No newline at end of file
Collaborator

Missing final newline

The diff shows \ No newline at end of file here. Per .editorconfig (insert_final_newline = true), this should end with a trailing newline.

Comment on lines +140 to +144
# Install PyTorch with CUDA 12.9 support
# Note: torch 2.10+ has a breaking LR scheduler change (strict zip) that is
# incompatible with some DeepSpeed/transformers versions. Pin to <2.10 until
# upstream libraries catch up.
RUN pip install --no-cache-dir 'torch>=2.7.0,<2.10.0' --index-url https://download.pytorch.org/whl/cu128
RUN pip install --no-cache-dir 'torch>=2.7.0,<2.10.0' --index-url https://download.pytorch.org/whl/cu129
Collaborator

Reinstalling PyTorch on top of NGC base image

The base image nvcr.io/nvidia/pytorch:25.04-py3 already ships with a CUDA 12.9-optimized PyTorch. This line overwrites it with the generic PyPI wheel, which may lack NGC-specific optimizations (cuDNN auto-tuning, NVTX annotations, etc.) and increases image size.

The main Dockerfile (0.deepspeed.dockerfile) correctly relies on the base image's PyTorch. I'd suggest either removing this line (the base image's torch satisfies >=2.7.0,<2.10.0) or adding a comment explaining why the override is intentional (e.g., a specific version requirement for QLoRA compatibility).

Collaborator

@KeitaW KeitaW left a comment

Review Batch 3/3 — Documentation Consistency


Things That Look Great

  • Directory structure fixed: The gpt/ subdirectory with slurm/ and configs/ follows the repo convention nicely. Good restructuring since the last review.
  • ZeRO-2/3 pipeline parallelism guard: The automatic detection of ZERO_STAGE >= 2 to disable pipeline parallelism and add --no-pipeline-parallel prevents a confusing runtime assertion. The inline comment explaining the Megatron-DeepSpeed PipelineEngine constraint is exactly the kind of documentation that saves hours of debugging.
  • Dynamic DeepSpeed config generation: Generating the JSON config at runtime from env vars (with ZeRO-3 specific parameters) is much more maintainable than managing multiple static configs.
  • Comprehensive best practices from real experiments: The documented findings (ZeRO-0 > ZeRO-1 when DP=1, capital T in expandable_segments:True, NCCL_ALGO=Tree hangs on EFA, negligible impact of NCCL buffer tuning) are hard-won knowledge that will save users significant debugging time.
  • Proper NCCL_SOCKET_IFNAME exclusion pattern: Both Dockerfiles use ^docker,lo,veth — the correct exclusion-based approach for EFA clusters.
  • NCCL plugin symlink documentation: The comments explaining why symlinks are needed (EFA installer naming vs. what NCCL expects, causing silent TCP socket fallback) prevent a subtle and hard-to-diagnose failure mode.
  • Well-parameterized sbatch: Everything configurable via env vars with sensible defaults enables both quick benchmarking and detailed parameter sweeps.
  • Clean log parser: parse_results.py is well-structured with clear regex patterns, warmup exclusion, TFLOPS computation fallback, and both single-file and batch CSV modes.
  • Improved shell quoting: 1.build-image.sbatch now quotes all variable expansions properly.
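The guard described in the second bullet might look like the following sketch (with `ZERO_STAGE` hardcoded to 3 purely to exercise the branch); the actual logic in `pretrain_gpt_103b.sbatch` may differ in detail:

```shell
# Illustrative ZeRO-stage-aware guard. Megatron-DeepSpeed's PipelineEngine
# cannot combine ZeRO-2/3 with pipeline parallelism, so disable PP and
# pass --no-pipeline-parallel when ZERO_STAGE >= 2.
ZERO_STAGE=3        # demo value chosen to trigger the guard
PP=8
EXTRA_ARGS=""

if [ "${ZERO_STAGE}" -ge 2 ]; then
  PP=1
  EXTRA_ARGS="--no-pipeline-parallel"
fi

echo "ZeRO=${ZERO_STAGE} PP=${PP} ${EXTRA_ARGS}"
```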

Collaborator

@KeitaW KeitaW left a comment

Few minor comments. Generally looks good. Thanks for addressing the comments!

@paragao paragao merged commit f20676d into main Mar 11, 2026
4 checks passed
@paragao paragao deleted the feature/deepspeed-b200-pretraining branch March 11, 2026 13:52