From 6fa708e375e3b6fc183f628c5c591c8f26ef484a Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Mon, 14 Jul 2025 14:13:17 -0700
Subject: [PATCH 1/3] add muon and sgd_momentum optimizers

---
 calc/calc_transformer_mem.py | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/calc/calc_transformer_mem.py b/calc/calc_transformer_mem.py
index 18bdceb..a0a4323 100644
--- a/calc/calc_transformer_mem.py
+++ b/calc/calc_transformer_mem.py
@@ -192,6 +192,11 @@ def config_parser():
                         type=int,
                         default=None,
                         help='The precision of gradient elements as bytes per value')
+    parser.add_argument("--optimizer", "-opt",
+                        type=str,
+                        choices=["adamw", "muon", "sgd_momentum"],
+                        default=None,
+                        help='Which optimizer to estimate memory for (default adamw)')
     # MoE Settings
     parser.add_argument("--num-experts",
                         type=int,
@@ -238,6 +243,7 @@ def config_parser():
        "high_prec_bytes_per_val" : 4,
        "low_prec_bytes_per_val" : 2,
        "bytes_per_grad_ele" : 4,
+       "optimizer" : "adamw",
        # MoE Settings
        "num_experts" : 0,
        "expert_parallelism" : 1,
@@ -304,11 +310,23 @@ def calc_mem(args):
 
     # --- OPTIMIZER MEMORY ---
     # For mixed-precision Adam/AdamW, the optimizer must store fp32 copies of the parameters, momentum, and variance (4 + 4 + 4 = 12 bytes per optimizer parameter)
-    # Feel free to change the multiplier for your optimizer (examples include SGD (4 + 4 = 8) and 8-bit ADAM (2 + 2 + 2 = 6)
+    # Feel free to change the multiplier for your optimizer (examples include SGD w/ momentum (4 + 4 = 8) and 8-bit ADAM (2 + 2 + 2 = 6)
+    opt_type = args.optimizer.lower()
+    if opt_type == "sgd_momentum":
+        opt_multiplier = 8  # fp32 copy + m
+    elif opt_type == "muon":
+        opt_multiplier = 8  # fp32 copy + m (https://kellerjordan.github.io/posts/muon/#runtime-analysis)
+    else: # Adamw default
+        opt_multiplier = 12  # fp32 copy + m + v (https://arxiv.org/abs/1910.02054)
+
+
+
     if args.num_experts > 0:
-        optimizer_mem = EP_total_params * 12
+        optimizer_mem = EP_total_params * opt_multiplier
     else:
-        optimizer_mem = total_params * 12
+        optimizer_mem = total_params * opt_multiplier
+
+    per_gpu_optimizer_mem = optimizer_mem
 
     # ZeRO stage 3 shards the optimizer states across GPUs
     if args.zero_stage >= 1:
@@ -385,7 +403,7 @@ def calc_mem(args):
         print(f'Per-GPU KV Cache Memory: {per_gpu_kv_cache_mem_gib:.2f} GiB')
     else:
         print(f'Per-GPU Gradient Memory: {per_gpu_gradient_mem_gib:.2f} GiB')
-    print(f'Per-GPU Optimizer Memory: {per_gpu_optimizer_mem_gib:.2f} GiB')
+    print(f'Per-GPU Optimizer Memory ({opt_type}): {per_gpu_optimizer_mem_gib:.2f} GiB')
     print(f'Per-GPU Communication Memory: {per_gpu_communication_mem_gib:.2f} GiB')
     print(f'Per-GPU Miscellaneous Memory: {args.misc_mem_gib:.2f} GiB')
     # Aggregate Per-GPU Memory
@@ -403,7 +421,7 @@ def calc_mem(args):
         print(f'Total KV Cache Memory: {kv_cache_mem_gib:.2f} GiB')
     else:
         print(f'Total Gradient Memory: {gradient_mem_gib:.2f} GiB')
-    print(f'Total Optimizer Memory: {optimizer_mem_gib:.2f} GiB')
+    print(f'Total Optimizer Memory ({opt_type}): {optimizer_mem_gib:.2f} GiB')
     print(f'Total Miscellaneous Memory: {args.num_gpus*args.misc_mem_gib:.2f} GiB')
     # Aggregate GPU memory
     if args.infer:
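A quick sanity check of the per-parameter multipliers used above: AdamW keeps an fp32 master copy plus momentum and variance (4 + 4 + 4 = 12 bytes per parameter), while Muon and SGD with momentum keep an fp32 copy plus momentum only (4 + 4 = 8 bytes per parameter). The sketch below is standalone and not part of calc_transformer_mem.py; the 7B parameter count is an assumed example, and MoE and ZeRO sharding (handled later in calc_mem) are ignored.

    # Standalone sketch: bytes of optimizer state per parameter, mirroring the multipliers in the patch above.
    total_params = 7e9  # assumed dense 7B-parameter model, no MoE, no ZeRO sharding
    for opt_type, opt_multiplier in {"adamw": 12, "muon": 8, "sgd_momentum": 8}.items():
        optimizer_mem_gib = total_params * opt_multiplier / 1024**3
        print(f"{opt_type}: {optimizer_mem_gib:.2f} GiB")  # adamw ~78.2 GiB, muon/sgd_momentum ~52.2 GiB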
From 61c45b96a0d52ff15f118b1956b1c4e8b222af8a Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Tue, 29 Jul 2025 09:23:30 -0700
Subject: [PATCH 2/3] add bert and clear up flag

---
 calc/calc_transformer_flops.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py
index 89a4af4..b36e67d 100644
--- a/calc/calc_transformer_flops.py
+++ b/calc/calc_transformer_flops.py
@@ -65,7 +65,7 @@ def config_parser():
                         type=float,
                         default=300e9,
                         help='Number of tokens you are training over')
-    parser.add_argument("--no-checkpoint-activations", "-ca",
+    parser.add_argument("--no-checkpoint-activations", "-nca",
                         action='store_false',
                         help='Whether Megatron-style activation checkpointing is being used',
                         dest='checkpoint_activations')
@@ -85,6 +85,13 @@ def config_parser():
     parser.add_argument("--infer", "-i",
                         action='store_true',
                         help='Pass to calculate FLOPs for inference-only workload (no backward pass)')
+    parser.add_argument("--encoder-only", "-enc", action="store_true",
+                        help="Set if the model is encoder-only (e.g. BERT/ModernBERT)")
+
+    parser.add_argument("--mlm-ratio", "-mr", type=float, default=1.0,
+                        help="Fraction of tokens that receive a language-model head. "
+                             "Use 1.0 for autoregressive GPT-style training, "
+                             "0.15 for BERT-style masked-LM pre-training")
     return parser
 
 # calculates the flops of a model given its hparams
@@ -93,6 +100,8 @@ def calc_flops(args):
     assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers"
     assert not args.ffn_hidden_size or (args.ffn_expansion_factor == 4), "both '--ffn-hidden-size' and non-default '-ff' values were specified, these cannot conflict"
 
+    is_encoder = args.encoder_only
+
     # An A_(m x k) X B_(k x n) matrix multiplication requires 2m x k x n FLOPs (factor of 2 needed to account for multiplies and adds)
 
     # determine the flops factor.
@@ -114,7 +123,8 @@ def calc_flops(args):
     ffn_flops = int(iter_factor * 2 * args.num_mlp_linears * args.ffn_expansion_factor) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
 
     # no activation checkpointing for embeddings
-    embedding_flops = 6 * args.tokens * args.hidden_size * args.vocab_size
+    embedding_flops = (6 * args.tokens * args.hidden_size * args.vocab_size
+                       * (args.mlm_ratio if is_encoder else 1.0))
 
     if args.moe and args.topk > 1:
         ffn_flops += ffn_flops * args.topk / args.expert_interval
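The mlm_ratio scaling introduced in this patch only affects the embedding/LM-head term; attention and FFN FLOPs are unchanged. Below is a minimal standalone sketch of that term, with assumed hidden size, vocabulary size, and token count (example values, not taken from the script):

    # Standalone sketch of the patched term: 6 * tokens * hidden * vocab,
    # scaled by mlm_ratio when the model is encoder-only.
    def embedding_flops(tokens, hidden_size, vocab_size, encoder_only=False, mlm_ratio=1.0):
        return 6 * tokens * hidden_size * vocab_size * (mlm_ratio if encoder_only else 1.0)

    gpt_style = embedding_flops(300e9, 4096, 50257)                                      # ~3.7e20 FLOPs
    bert_style = embedding_flops(300e9, 4096, 50257, encoder_only=True, mlm_ratio=0.15)  # ~5.6e19 FLOPs (15% of the head FLOPs)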
" + "Use 1.0 for autoregressive GPT-style training, " + "0.15 for BERT-style masked-LM pre-training") return parser # calculates the flops of a model given its hparams @@ -93,6 +100,8 @@ def calc_flops(args): assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers" assert not args.ffn_hidden_size or (args.ffn_expansion_factor == 4), "both '--ffn-hidden-size' and non-default '-ff' values were specified, these cannot conflict" + is_encoder = args.encoder_only + # An A_(m x k) X B_(k x n) matrix multiplication requires 2m x k x n FLOPs (factor of 2 needed to account for multiplies and adds) # determine the flops factor. @@ -117,8 +126,8 @@ def calc_flops(args): ffn_flops = int(iter_factor * 2 * args.num_mlp_linears * args.ffn_expansion_factor) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size # no activation checkpointing for embeddings - # embedding (4*d_model) plus unembedding (2*d_model*vocab_size) - embedding_flops = args.tokens * (4 * args.hidden_size + 2 * args.hidden_size * args.vocab_size) + embedding_flops = (args.tokens * (4 * args.hidden_size + 2 * args.hidden_size * args.vocab_size) + * (args.mlm_ratio if is_encoder else 1.0)) if args.moe and args.topk > 1: ffn_flops *= args.topk / args.expert_interval