From 1219a263351c8e0c738e9ed1df3c7cfcfae44212 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 29 Oct 2025 02:46:14 -0700 Subject: [PATCH 01/74] renaming golden values Signed-off-by: Hongbin Liu --- ...ev_coreweave.json => golden_values_dev_dgxh100_coreweave.json} | 0 ...den_values_dev_eos.json => golden_values_dev_dgxh100_eos.json} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/{golden_values_dev_coreweave.json => golden_values_dev_dgxh100_coreweave.json} (100%) rename tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/{golden_values_dev_eos.json => golden_values_dev_dgxh100_eos.json} (100%) diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_coreweave.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgxh100_coreweave.json similarity index 100% rename from tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_coreweave.json rename to tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgxh100_coreweave.json diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_eos.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgxh100_eos.json similarity index 100% rename from tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_eos.json rename to tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgxh100_eos.json From ce6e6613b99465aa936ce8b80913c4eeeed337c0 Mon Sep 17 00:00:00 2001 
From: Hongbin Liu Date: Tue, 4 Nov 2025 00:26:16 -0800 Subject: [PATCH 02/74] fix bug: accuracy issu because of recomputing and offloading same module Signed-off-by: Hongbin Liu --- megatron/core/tensor_parallel/random.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 5a44c38713d..69e973142df 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -579,6 +579,14 @@ def _recompute(self, _): # Store the inputs for backward pass inputs = self.ctx.saved_tensors + def detach(t): + if isinstance(t, torch.Tensor): + requires_grad = t.requires_grad + t = t.detach() + t.requires_grad_(requires_grad) + return t + + inputs = tuple(detach(t) for t in inputs) with torch.enable_grad(), fp8_ctx, recompute_ctx: outputs = self.run_function(*inputs) From 2fe4aebf90dd0f82cc612281690fd59a55bf71b4 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 4 Nov 2025 00:33:06 -0800 Subject: [PATCH 03/74] format Signed-off-by: Hongbin Liu --- megatron/core/tensor_parallel/random.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 69e973142df..396e5c54a2d 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -579,6 +579,7 @@ def _recompute(self, _): # Store the inputs for backward pass inputs = self.ctx.saved_tensors + def detach(t): if isinstance(t, torch.Tensor): requires_grad = t.requires_grad From fb3f7c38b93ffb51d5a1da2e1c54c5e812ad408f Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 4 Nov 2025 19:49:47 -0800 Subject: [PATCH 04/74] update golden values Signed-off-by: Hongbin Liu --- .../golden_values_dev_dgx_h100.json | 592 +++++++++--------- .../golden_values_dev_dgx_h100.json | 392 ++++++------ 2 files changed, 492 insertions(+), 492 deletions(-) diff --git 
a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json index 0a6724a3e95..4b32d4256db 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json @@ -6,54 +6,54 @@ "values": { "1": 11.07559, "2": 11.03834, - "3": 9.66022, - "4": 9.91367, - "5": 9.3291, - "6": 9.13927, - "7": 9.13591, - "8": 8.65527, - "9": 8.51396, - "10": 8.84095, - "11": 8.29144, - "12": 8.34584, - "13": 8.25509, - "14": 7.73685, - "15": 7.86273, - "16": 7.93699, - "17": 7.89257, - "18": 7.63116, - "19": 7.99719, - "20": 7.7453, - "21": 7.44298, - "22": 7.42242, - "23": 7.29721, - "24": 7.27467, - "25": 7.54562, - "26": 6.96839, - "27": 7.50569, - "28": 7.22761, - "29": 7.36579, - "30": 7.52635, - "31": 7.27036, - "32": 7.45548, - "33": 7.50952, - "34": 7.55694, - "35": 7.10212, - "36": 6.96414, - "37": 7.28438, - "38": 7.08049, - "39": 7.40908, - "40": 7.4335, - "41": 7.38491, - "42": 7.15766, - "43": 7.15867, - "44": 7.28831, - "45": 7.16729, - "46": 6.78429, - "47": 7.40937, - "48": 7.00259, - "49": 7.46241, - "50": 6.92143 + "3": 9.72869, + "4": 9.61678, + "5": 10.63323, + "6": 9.1681, + "7": 9.35196, + "8": 9.05204, + "9": 8.84148, + "10": 9.00321, + "11": 8.49799, + "12": 8.5218, + "13": 8.41649, + "14": 7.9096, + "15": 8.00627, + "16": 8.05394, + "17": 8.0203, + "18": 7.73136, + "19": 8.11676, + "20": 7.83945, + "21": 7.52196, + "22": 7.5295, + "23": 7.38729, + "24": 7.3758, + "25": 7.65255, + "26": 7.04795, + "27": 7.591, + "28": 7.30023, + "29": 7.45656, + "30": 7.60935, + "31": 7.3713, + "32": 7.55298, + "33": 7.59738, + "34": 7.65764, + "35": 7.17916, + "36": 7.04913, + 
"37": 7.38022, + "38": 7.14883, + "39": 7.50321, + "40": 7.51595, + "41": 7.45139, + "42": 7.21197, + "43": 7.21131, + "44": 7.38058, + "45": 7.16397, + "46": 6.86108, + "47": 7.27247, + "48": 7.10862, + "49": 7.56398, + "50": 7.00523 } }, "num-zeros": { @@ -61,56 +61,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 911219392.0, - "2": 910960384.0, - "3": 911156288.0, - "4": 913253376.0, - "5": 921845056.0, - "6": 941436672.0, - "7": 993745472.0, - "8": 974360512.0, - "9": 999146112.0, - "10": 992706944.0, - "11": 991438144.0, - "12": 979442048.0, - "13": 1029190272.0, - "14": 1008214656.0, - "15": 988472000.0, - "16": 988861120.0, - "17": 979173312.0, - "18": 996164608.0, - "19": 979453440.0, - "20": 982914688.0, - "21": 975473344.0, - "22": 955037568.0, - "23": 969208128.0, - "24": 965840832.0, - "25": 953269440.0, - "26": 949025536.0, - "27": 948458304.0, - "28": 951741184.0, - "29": 943926272.0, - "30": 935020160.0, - "31": 933230336.0, - "32": 930086848.0, - "33": 922853952.0, - "34": 927140800.0, - "35": 925348224.0, - "36": 925295168.0, - "37": 922758272.0, - "38": 922930752.0, - "39": 922322880.0, - "40": 921856640.0, - "41": 920227776.0, - "42": 918353664.0, - "43": 919655616.0, - "44": 914948224.0, - "45": 916392512.0, - "46": 914344448.0, - "47": 911769536.0, - "48": 912013248.0, - "49": 910349376.0, - "50": 914351616.0 + "1": 38802120.0, + "2": 38543052.0, + "3": 38738396.0, + "4": 113220144.0, + "5": 344100160.0, + "6": 435062816.0, + "7": 579598912.0, + "8": 819195200.0, + "9": 604910464.0, + "10": 690749824.0, + "11": 744002496.0, + "12": 520212192.0, + "13": 547932992.0, + "14": 585659584.0, + "15": 614149184.0, + "16": 664915328.0, + "17": 592272320.0, + "18": 630225856.0, + "19": 579959808.0, + "20": 800470080.0, + "21": 573941056.0, + "22": 557652032.0, + "23": 797256640.0, + "24": 826380864.0, + "25": 814860160.0, + "26": 617708032.0, + "27": 715680384.0, + "28": 548045824.0, + "29": 736312064.0, + "30": 722163456.0, + "31": 
711986176.0, + "32": 674238208.0, + "33": 715239232.0, + "34": 677588288.0, + "35": 473423392.0, + "36": 451352800.0, + "37": 446739392.0, + "38": 567466304.0, + "39": 472519552.0, + "40": 434322048.0, + "41": 554276096.0, + "42": 526187424.0, + "43": 510713152.0, + "44": 522783808.0, + "45": 335511072.0, + "46": 450878784.0, + "47": 450397344.0, + "48": 321720704.0, + "49": 437443680.0, + "50": 419425088.0 } }, "mem-allocated-bytes": { @@ -118,56 +118,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 5498353152.0, - "2": 5499147776.0, - "3": 5499940352.0, - "4": 5500732928.0, - "5": 5501525504.0, - "6": 5502318080.0, - "7": 5503110656.0, - "8": 5503903232.0, - "9": 5497958912.0, - "10": 5498751488.0, - "11": 5499544064.0, - "12": 5500336640.0, - "13": 5501129216.0, - "14": 5501921792.0, - "15": 5502714368.0, - "16": 5503506944.0, - "17": 5504299520.0, - "18": 5505092096.0, - "19": 5505884672.0, - "20": 5506677248.0, - "21": 5507469824.0, - "22": 5508262400.0, - "23": 5509054976.0, - "24": 5509847552.0, - "25": 5510640128.0, - "26": 5511432704.0, - "27": 5512225280.0, - "28": 5513017856.0, - "29": 5513810432.0, - "30": 5514603008.0, - "31": 5515395584.0, - "32": 5516188160.0, - "33": 5516980736.0, - "34": 5517773312.0, - "35": 5518565888.0, - "36": 5519358464.0, - "37": 5520151040.0, - "38": 5520943616.0, - "39": 5521736192.0, - "40": 5522528768.0, - "41": 5523321344.0, - "42": 5524113920.0, - "43": 5524906496.0, - "44": 5525699072.0, - "45": 5526491648.0, - "46": 5527284224.0, - "47": 5528076800.0, - "48": 5528869376.0, - "49": 5529661952.0, - "50": 5530454528.0 + "1": 5498340864.0, + "2": 5499135488.0, + "3": 5499928064.0, + "4": 5500720640.0, + "5": 5501513216.0, + "6": 5502305792.0, + "7": 5497946624.0, + "8": 5498739200.0, + "9": 5499531776.0, + "10": 5500324352.0, + "11": 5501116928.0, + "12": 5498342912.0, + "13": 5499135488.0, + "14": 5499928064.0, + "15": 5500720640.0, + "16": 5501513216.0, + "17": 5502305792.0, + "18": 5503098368.0, + "19": 
5503890944.0, + "20": 5504683520.0, + "21": 5505476096.0, + "22": 5506268672.0, + "23": 5507061248.0, + "24": 5507853824.0, + "25": 5508646400.0, + "26": 5509438976.0, + "27": 5510231552.0, + "28": 5511024128.0, + "29": 5511816704.0, + "30": 5512609280.0, + "31": 5513401856.0, + "32": 5514194432.0, + "33": 5514987008.0, + "34": 5515779584.0, + "35": 5516572160.0, + "36": 5517364736.0, + "37": 5518157312.0, + "38": 5518949888.0, + "39": 5519742464.0, + "40": 5520535040.0, + "41": 5521327616.0, + "42": 5522120192.0, + "43": 5522912768.0, + "44": 5523705344.0, + "45": 5524497920.0, + "46": 5525290496.0, + "47": 5526083072.0, + "48": 5526875648.0, + "49": 5527668224.0, + "50": 5528460800.0 } }, "mem-max-allocated-bytes": { @@ -175,56 +175,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 41740259328.0, - "2": 43687292928.0, - "3": 43687292928.0, - "4": 43984064512.0, - "5": 43984064512.0, - "6": 43984064512.0, - "7": 43984064512.0, - "8": 44026380288.0, - "9": 44041506816.0, - "10": 44041506816.0, - "11": 44041506816.0, - "12": 44041506816.0, - "13": 44041506816.0, - "14": 44041506816.0, - "15": 44041506816.0, - "16": 44041506816.0, - "17": 44041506816.0, - "18": 44041506816.0, - "19": 44041506816.0, - "20": 44041506816.0, - "21": 44041506816.0, - "22": 44041506816.0, - "23": 44041506816.0, - "24": 44041506816.0, - "25": 44041506816.0, - "26": 44041506816.0, - "27": 44041506816.0, - "28": 44041506816.0, - "29": 44044173312.0, - "30": 44164231168.0, - "31": 44221079552.0, - "32": 44271415296.0, - "33": 44290232320.0, - "34": 44290232320.0, - "35": 44290232320.0, - "36": 44290232320.0, - "37": 44290232320.0, - "38": 44290232320.0, - "39": 44290232320.0, - "40": 44290232320.0, - "41": 44290232320.0, - "42": 44290232320.0, - "43": 44290232320.0, - "44": 44290232320.0, - "45": 44290232320.0, - "46": 44290232320.0, - "47": 44290232320.0, - "48": 44290232320.0, - "49": 44290232320.0, - "50": 44290232320.0 + "1": 41723441152.0, + "2": 43687280640.0, + "3": 
43916578816.0, + "4": 43916578816.0, + "5": 43916578816.0, + "6": 43916578816.0, + "7": 43916578816.0, + "8": 43916578816.0, + "9": 43916578816.0, + "10": 43916578816.0, + "11": 43916578816.0, + "12": 44028436480.0, + "13": 44028436480.0, + "14": 44028436480.0, + "15": 44028436480.0, + "16": 44028436480.0, + "17": 44028436480.0, + "18": 44028436480.0, + "19": 44028436480.0, + "20": 44028436480.0, + "21": 44028436480.0, + "22": 44028436480.0, + "23": 44028436480.0, + "24": 44028436480.0, + "25": 44028436480.0, + "26": 44028436480.0, + "27": 44028436480.0, + "28": 44028436480.0, + "29": 44028436480.0, + "30": 44028436480.0, + "31": 44028436480.0, + "32": 44028436480.0, + "33": 44028436480.0, + "34": 44028436480.0, + "35": 44028436480.0, + "36": 44028436480.0, + "37": 44028436480.0, + "38": 44028436480.0, + "39": 44028436480.0, + "40": 44028436480.0, + "41": 44028436480.0, + "42": 44028436480.0, + "43": 44028436480.0, + "44": 44028436480.0, + "45": 44028436480.0, + "46": 44028436480.0, + "47": 44028436480.0, + "48": 44028436480.0, + "49": 44028436480.0, + "50": 44028436480.0 } }, "mtp_1 loss": { @@ -234,54 +234,54 @@ "values": { "1": 11.08623, "2": 11.1047, - "3": 10.47999, - "4": 10.13471, - "5": 9.79045, - "6": 9.50607, - "7": 9.51139, - "8": 8.85331, - "9": 8.66688, - "10": 8.95867, - "11": 8.29318, - "12": 8.36986, - "13": 8.25545, - "14": 7.73323, - "15": 7.86639, - "16": 7.92438, - "17": 7.86276, - "18": 7.61004, - "19": 8.00261, - "20": 7.73004, - "21": 7.41636, - "22": 7.41466, - "23": 7.28656, - "24": 7.27882, - "25": 7.54458, - "26": 6.96533, - "27": 7.5053, - "28": 7.20603, - "29": 7.37687, - "30": 7.52783, - "31": 7.27097, - "32": 7.46043, - "33": 7.51419, - "34": 7.56879, - "35": 7.09276, - "36": 6.96019, - "37": 7.29843, - "38": 7.07417, - "39": 7.43338, - "40": 7.43134, - "41": 7.40946, - "42": 7.15527, - "43": 7.15684, - "44": 7.30429, - "45": 7.18917, - "46": 6.77286, - "47": 7.44985, - "48": 7.02383, - "49": 7.4572, - "50": 6.92645 + "3": 10.54469, + 
"4": 10.08474, + "5": 9.76549, + "6": 9.56242, + "7": 9.59473, + "8": 8.97686, + "9": 8.83293, + "10": 9.1193, + "11": 8.44318, + "12": 8.49593, + "13": 8.37985, + "14": 7.81516, + "15": 7.95146, + "16": 8.01718, + "17": 7.94503, + "18": 7.68603, + "19": 8.07501, + "20": 7.79558, + "21": 7.46867, + "22": 7.46603, + "23": 7.32734, + "24": 7.32819, + "25": 7.58465, + "26": 6.99257, + "27": 7.53486, + "28": 7.23432, + "29": 7.40501, + "30": 7.55005, + "31": 7.30085, + "32": 7.48028, + "33": 7.53593, + "34": 7.60112, + "35": 7.12344, + "36": 6.99007, + "37": 7.32578, + "38": 7.09623, + "39": 7.45759, + "40": 7.45018, + "41": 7.40101, + "42": 7.14459, + "43": 7.13995, + "44": 7.32066, + "45": 7.0966, + "46": 6.80106, + "47": 7.21219, + "48": 7.05021, + "49": 7.48165, + "50": 6.95118 } }, "iteration-time": { @@ -289,56 +289,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 85.92313, - "2": 1.99152, - "3": 3.91366, - "4": 1.68454, - "5": 2.53883, - "6": 2.55539, - "7": 1.60104, - "8": 1.70562, - "9": 1.72325, - "10": 1.4332, - "11": 1.07958, - "12": 1.399, - "13": 1.10259, - "14": 1.43922, - "15": 1.12046, - "16": 1.33695, - "17": 1.24765, - "18": 1.11257, - "19": 1.10335, - "20": 1.12919, - "21": 1.27711, - "22": 1.09482, - "23": 1.27635, - "24": 1.112, - "25": 1.17791, - "26": 1.10426, - "27": 1.09103, - "28": 1.08338, - "29": 1.07904, - "30": 1.08709, - "31": 1.2237, - "32": 1.18059, - "33": 1.07913, - "34": 1.17232, - "35": 1.09059, - "36": 1.09648, - "37": 1.12683, - "38": 1.10153, - "39": 1.09557, - "40": 1.07747, - "41": 1.12905, - "42": 1.09275, - "43": 1.08609, - "44": 1.08042, - "45": 1.08321, - "46": 1.0732, - "47": 1.08666, - "48": 1.08865, - "49": 1.08808, - "50": 1.08086 + "1": 87.66203, + "2": 2.04189, + "3": 3.34278, + "4": 3.72414, + "5": 3.23492, + "6": 1.94546, + "7": 2.14942, + "8": 1.78075, + "9": 1.06029, + "10": 2.13554, + "11": 1.42578, + "12": 1.80986, + "13": 1.06134, + "14": 1.087, + "15": 1.16687, + "16": 1.20412, + "17": 1.06984, + 
"18": 1.07557, + "19": 1.04081, + "20": 1.21763, + "21": 1.06196, + "22": 1.14038, + "23": 2.25761, + "24": 1.09161, + "25": 1.04319, + "26": 1.40025, + "27": 1.04974, + "28": 1.03984, + "29": 1.05293, + "30": 1.48942, + "31": 1.04785, + "32": 1.0529, + "33": 1.04366, + "34": 1.0633, + "35": 1.0713, + "36": 1.05711, + "37": 1.08085, + "38": 1.07006, + "39": 1.06498, + "40": 1.05913, + "41": 1.0697, + "42": 1.079, + "43": 1.14122, + "44": 1.06478, + "45": 1.04692, + "46": 1.08174, + "47": 1.07595, + "48": 1.10523, + "49": 1.0839, + "50": 1.07754 } } } \ No newline at end of file diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json index eca2cabacaf..f3ef4646971 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json @@ -6,54 +6,54 @@ "values": { "1": 11.04276, "2": 11.02298, - "3": 9.43542, - "4": 10.04672, - "5": 9.38572, - "6": 9.14547, - "7": 9.21155, - "8": 8.63445, - "9": 8.48944, - "10": 8.82764, - "11": 8.29479, - "12": 8.32819, - "13": 8.23003, - "14": 7.71724, - "15": 7.86963, - "16": 7.9228, - "17": 7.86049, - "18": 7.62035, - "19": 7.9851, - "20": 7.72027, - "21": 7.39754, - "22": 7.39767, - "23": 7.28334, - "24": 7.25057, - "25": 7.53131, - "26": 6.95335, - "27": 7.49421, - "28": 7.20415, - "29": 7.373, - "30": 7.50279, - "31": 7.25342, - "32": 7.43069, - "33": 7.48385, - "34": 7.53476, - "35": 7.10325, - "36": 6.94471, - "37": 7.26141, - "38": 7.07026, - "39": 7.40536, - "40": 7.42025, - "41": 7.34194, - "42": 7.11724, - "43": 7.11421, - "44": 
7.27077, - "45": 7.0701, - "46": 6.77811, - "47": 7.18895, - "48": 7.00013, - "49": 7.45875, - "50": 6.90988 + "3": 9.50921, + "4": 10.86244, + "5": 9.36127, + "6": 9.05636, + "7": 9.20064, + "8": 8.98909, + "9": 8.67001, + "10": 9.00892, + "11": 8.50716, + "12": 8.45579, + "13": 8.41197, + "14": 7.92802, + "15": 7.99663, + "16": 8.04156, + "17": 8.06453, + "18": 7.73746, + "19": 8.09946, + "20": 7.85555, + "21": 7.54063, + "22": 7.51142, + "23": 7.39766, + "24": 7.36551, + "25": 7.63399, + "26": 7.04934, + "27": 7.60084, + "28": 7.30223, + "29": 7.47164, + "30": 7.61428, + "31": 7.34981, + "32": 7.53935, + "33": 7.59164, + "34": 7.64951, + "35": 7.18657, + "36": 7.03804, + "37": 7.36778, + "38": 7.14613, + "39": 7.50644, + "40": 7.51103, + "41": 7.44582, + "42": 7.20666, + "43": 7.2123, + "44": 7.37723, + "45": 7.17293, + "46": 6.86188, + "47": 7.2648, + "48": 7.1069, + "49": 7.56115, + "50": 7.00113 } }, "num-zeros": { @@ -61,56 +61,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 844114112.0, - "2": 843855296.0, - "3": 844048640.0, - "4": 842998208.0, - "5": 855786112.0, - "6": 878524160.0, - "7": 924542976.0, - "8": 917741504.0, - "9": 932042112.0, - "10": 930847360.0, - "11": 954742400.0, - "12": 922824128.0, - "13": 968378816.0, - "14": 965228416.0, - "15": 951776640.0, - "16": 941679424.0, - "17": 929894336.0, - "18": 928011136.0, - "19": 955339264.0, - "20": 987111232.0, - "21": 924095488.0, - "22": 906805504.0, - "23": 895810432.0, - "24": 902927680.0, - "25": 927056960.0, - "26": 879821440.0, - "27": 911759744.0, - "28": 902460416.0, - "29": 872625216.0, - "30": 865815744.0, - "31": 868220352.0, - "32": 865076800.0, - "33": 864135552.0, - "34": 855839104.0, - "35": 854046784.0, - "36": 855042176.0, - "37": 850408192.0, - "38": 850580480.0, - "39": 849972608.0, - "40": 849505792.0, - "41": 845780352.0, - "42": 846003392.0, - "43": 848354688.0, - "44": 850986496.0, - "45": 848236160.0, - "46": 855625856.0, - "47": 843613312.0, - "48": 
851197312.0, - "49": 851630464.0, - "50": 846195968.0 + "1": 38808176.0, + "2": 38549232.0, + "3": 38741780.0, + "4": 78604016.0, + "5": 152229680.0, + "6": 299762016.0, + "7": 557587712.0, + "8": 589584384.0, + "9": 482229120.0, + "10": 517739584.0, + "11": 526962624.0, + "12": 476182528.0, + "13": 667453056.0, + "14": 563635200.0, + "15": 592125760.0, + "16": 589362048.0, + "17": 453879424.0, + "18": 444631456.0, + "19": 532791520.0, + "20": 677797248.0, + "21": 545577920.0, + "22": 494731040.0, + "23": 551928576.0, + "24": 489800928.0, + "25": 644993344.0, + "26": 441532864.0, + "27": 467175040.0, + "28": 431687840.0, + "29": 409185824.0, + "30": 583756032.0, + "31": 592451072.0, + "32": 416290048.0, + "33": 391230880.0, + "34": 325273120.0, + "35": 350756576.0, + "36": 331801376.0, + "37": 349196160.0, + "38": 312664800.0, + "39": 419015584.0, + "40": 299035872.0, + "41": 274307296.0, + "42": 296551584.0, + "43": 381740640.0, + "44": 308872480.0, + "45": 263141648.0, + "46": 353360864.0, + "47": 271093472.0, + "48": 346833600.0, + "49": 267589936.0, + "50": 252702768.0 } }, "mem-allocated-bytes": { @@ -177,54 +177,54 @@ "values": { "1": 37959917568.0, "2": 39578673152.0, - "3": 39580192768.0, - "4": 39580192768.0, - "5": 39583301632.0, - "6": 39583301632.0, - "7": 39583301632.0, - "8": 39583301632.0, - "9": 39583301632.0, - "10": 39583301632.0, - "11": 39583301632.0, - "12": 39583301632.0, - "13": 39583301632.0, - "14": 39583301632.0, - "15": 39583301632.0, - "16": 39583301632.0, - "17": 39583301632.0, - "18": 39583301632.0, - "19": 39583301632.0, - "20": 39583301632.0, - "21": 39583301632.0, - "22": 39583301632.0, - "23": 39583301632.0, - "24": 39583301632.0, - "25": 39583301632.0, - "26": 39583301632.0, - "27": 39583301632.0, - "28": 39583301632.0, - "29": 39583301632.0, - "30": 39583301632.0, - "31": 39583301632.0, - "32": 39583301632.0, - "33": 39583301632.0, - "34": 39583301632.0, - "35": 39583301632.0, - "36": 39583301632.0, - "37": 39583301632.0, - "38": 
39583301632.0, - "39": 39583301632.0, - "40": 39583301632.0, - "41": 39583301632.0, - "42": 39583301632.0, - "43": 39583301632.0, - "44": 39583301632.0, - "45": 39583301632.0, - "46": 39583301632.0, - "47": 39583301632.0, - "48": 39583301632.0, - "49": 39583301632.0, - "50": 39583301632.0 + "3": 39583825920.0, + "4": 39583825920.0, + "5": 39586181120.0, + "6": 39586181120.0, + "7": 39586181120.0, + "8": 39586181120.0, + "9": 39586181120.0, + "10": 39586181120.0, + "11": 39586181120.0, + "12": 39586181120.0, + "13": 39586181120.0, + "14": 39586181120.0, + "15": 39586181120.0, + "16": 39586181120.0, + "17": 39586181120.0, + "18": 39586181120.0, + "19": 39586181120.0, + "20": 39586181120.0, + "21": 39586181120.0, + "22": 39586181120.0, + "23": 39586181120.0, + "24": 39586181120.0, + "25": 39586181120.0, + "26": 39586181120.0, + "27": 39586181120.0, + "28": 39586181120.0, + "29": 39586181120.0, + "30": 39586181120.0, + "31": 39586181120.0, + "32": 39586181120.0, + "33": 39586181120.0, + "34": 39586181120.0, + "35": 39586181120.0, + "36": 39586181120.0, + "37": 39586181120.0, + "38": 39586181120.0, + "39": 39586181120.0, + "40": 39586181120.0, + "41": 39586181120.0, + "42": 39586181120.0, + "43": 39586181120.0, + "44": 39586181120.0, + "45": 39586181120.0, + "46": 39586181120.0, + "47": 39586181120.0, + "48": 39586181120.0, + "49": 39586181120.0, + "50": 39586181120.0 } }, "iteration-time": { @@ -232,56 +232,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 89.14162, - "2": 2.00665, - "3": 3.2832, - "4": 2.63833, - "5": 2.43073, - "6": 1.4868, - "7": 1.81732, - "8": 2.74562, - "9": 1.18286, - "10": 1.18542, - "11": 1.27273, - "12": 1.63885, - "13": 1.31323, - "14": 2.29007, - "15": 1.52021, - "16": 1.87975, - "17": 1.3507, - "18": 1.48627, - "19": 1.17842, - "20": 1.17004, - "21": 1.30369, - "22": 1.24781, - "23": 1.13565, - "24": 1.13418, - "25": 1.21915, - "26": 1.24288, - "27": 1.15052, - "28": 1.12573, - "29": 1.15398, - "30": 1.13143, - "31": 1.17104, - 
"32": 1.12919, - "33": 1.1286, - "34": 1.14327, - "35": 1.1721, - "36": 1.12494, - "37": 1.2626, - "38": 1.11425, - "39": 1.14594, - "40": 1.18189, - "41": 1.09297, - "42": 1.09247, - "43": 1.18621, - "44": 1.19564, - "45": 1.08252, - "46": 1.08511, - "47": 1.23319, - "48": 1.08249, - "49": 1.0979, - "50": 1.07182 + "1": 65.48328, + "2": 1.94615, + "3": 3.94539, + "4": 2.42699, + "5": 1.80319, + "6": 1.79395, + "7": 1.50546, + "8": 2.00251, + "9": 1.2172, + "10": 1.31071, + "11": 1.3171, + "12": 1.10351, + "13": 1.26314, + "14": 1.47608, + "15": 1.19001, + "16": 1.12949, + "17": 1.15105, + "18": 1.06698, + "19": 1.10069, + "20": 1.12463, + "21": 1.35075, + "22": 1.56258, + "23": 1.2368, + "24": 1.13707, + "25": 1.11826, + "26": 1.09445, + "27": 1.08857, + "28": 1.07964, + "29": 1.08505, + "30": 1.24068, + "31": 1.10419, + "32": 1.5164, + "33": 1.10245, + "34": 1.37977, + "35": 1.1165, + "36": 1.1457, + "37": 1.10487, + "38": 1.08491, + "39": 1.08901, + "40": 1.08968, + "41": 1.13702, + "42": 1.09805, + "43": 1.06669, + "44": 1.07791, + "45": 1.08898, + "46": 1.10717, + "47": 1.13008, + "48": 1.05745, + "49": 1.08268, + "50": 1.05678 } } } \ No newline at end of file From 993789010373c1e1844fb07b80c7ff726bc1c8ad Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 5 Nov 2025 01:04:35 -0800 Subject: [PATCH 05/74] update golden values Signed-off-by: Hongbin Liu --- .../golden_values_dev_dgx_h100.json | 392 +++++++++--------- 1 file changed, 196 insertions(+), 196 deletions(-) diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json index f3ef4646971..150ba70462f 100644 --- 
a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json @@ -6,54 +6,54 @@ "values": { "1": 11.04276, "2": 11.02298, - "3": 9.50921, - "4": 10.86244, - "5": 9.36127, - "6": 9.05636, - "7": 9.20064, - "8": 8.98909, - "9": 8.67001, - "10": 9.00892, - "11": 8.50716, - "12": 8.45579, - "13": 8.41197, - "14": 7.92802, - "15": 7.99663, - "16": 8.04156, - "17": 8.06453, - "18": 7.73746, - "19": 8.09946, - "20": 7.85555, - "21": 7.54063, - "22": 7.51142, - "23": 7.39766, - "24": 7.36551, - "25": 7.63399, - "26": 7.04934, - "27": 7.60084, - "28": 7.30223, - "29": 7.47164, - "30": 7.61428, - "31": 7.34981, - "32": 7.53935, - "33": 7.59164, - "34": 7.64951, - "35": 7.18657, - "36": 7.03804, - "37": 7.36778, - "38": 7.14613, - "39": 7.50644, - "40": 7.51103, - "41": 7.44582, - "42": 7.20666, - "43": 7.2123, - "44": 7.37723, - "45": 7.17293, - "46": 6.86188, - "47": 7.2648, - "48": 7.1069, - "49": 7.56115, - "50": 7.00113 + "3": 9.50907, + "4": 10.86145, + "5": 9.36104, + "6": 9.05664, + "7": 9.20646, + "8": 9.00188, + "9": 8.69791, + "10": 8.97535, + "11": 8.48206, + "12": 8.44961, + "13": 8.38916, + "14": 7.90422, + "15": 7.98559, + "16": 8.02787, + "17": 8.04894, + "18": 7.72163, + "19": 8.0935, + "20": 7.85609, + "21": 7.53372, + "22": 7.50495, + "23": 7.39733, + "24": 7.36369, + "25": 7.62993, + "26": 7.04703, + "27": 7.59839, + "28": 7.29807, + "29": 7.46826, + "30": 7.60613, + "31": 7.34795, + "32": 7.53766, + "33": 7.58939, + "34": 7.64431, + "35": 7.18358, + "36": 7.036, + "37": 7.36506, + "38": 7.14525, + "39": 7.50347, + "40": 7.50925, + "41": 7.44415, + "42": 7.20526, + "43": 7.21039, + "44": 7.37585, + "45": 7.1698, + "46": 6.8612, + "47": 7.26258, + "48": 7.1033, + "49": 7.55974, + "50": 6.99878 } }, "num-zeros": { @@ 
-61,56 +61,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 38808176.0, - "2": 38549232.0, - "3": 38741780.0, - "4": 78604016.0, - "5": 152229680.0, - "6": 299762016.0, - "7": 557587712.0, - "8": 589584384.0, - "9": 482229120.0, - "10": 517739584.0, - "11": 526962624.0, - "12": 476182528.0, - "13": 667453056.0, - "14": 563635200.0, - "15": 592125760.0, - "16": 589362048.0, - "17": 453879424.0, - "18": 444631456.0, - "19": 532791520.0, - "20": 677797248.0, - "21": 545577920.0, - "22": 494731040.0, - "23": 551928576.0, - "24": 489800928.0, - "25": 644993344.0, - "26": 441532864.0, - "27": 467175040.0, - "28": 431687840.0, - "29": 409185824.0, - "30": 583756032.0, - "31": 592451072.0, - "32": 416290048.0, - "33": 391230880.0, - "34": 325273120.0, - "35": 350756576.0, - "36": 331801376.0, - "37": 349196160.0, - "38": 312664800.0, - "39": 419015584.0, - "40": 299035872.0, - "41": 274307296.0, - "42": 296551584.0, - "43": 381740640.0, - "44": 308872480.0, - "45": 263141648.0, - "46": 353360864.0, - "47": 271093472.0, - "48": 346833600.0, - "49": 267589936.0, - "50": 252702768.0 + "1": 38808152.0, + "2": 38549168.0, + "3": 38741680.0, + "4": 81738424.0, + "5": 161659808.0, + "6": 296608000.0, + "7": 557581568.0, + "8": 592711744.0, + "9": 479088640.0, + "10": 520896096.0, + "11": 555256320.0, + "12": 444724480.0, + "13": 658029440.0, + "14": 585665280.0, + "15": 588986240.0, + "16": 479280192.0, + "17": 494748608.0, + "18": 504398944.0, + "19": 601982144.0, + "20": 787884160.0, + "21": 536156160.0, + "22": 513609344.0, + "23": 577056256.0, + "24": 549563712.0, + "25": 648153280.0, + "26": 498150784.0, + "27": 501770816.0, + "28": 522921920.0, + "29": 462644416.0, + "30": 612066112.0, + "31": 605029312.0, + "32": 454036160.0, + "33": 419547936.0, + "34": 378748896.0, + "35": 385339904.0, + "36": 350676768.0, + "37": 478164480.0, + "38": 337833600.0, + "39": 450472544.0, + "40": 267556496.0, + "41": 280614912.0, + "42": 305998368.0, + "43": 372298848.0, + "44": 
261697280.0, + "45": 225394720.0, + "46": 268431392.0, + "47": 217617888.0, + "48": 261904016.0, + "49": 229846288.0, + "50": 214954112.0 } }, "mem-allocated-bytes": { @@ -177,54 +177,54 @@ "values": { "1": 37959917568.0, "2": 39578673152.0, - "3": 39583825920.0, - "4": 39583825920.0, - "5": 39586181120.0, - "6": 39586181120.0, - "7": 39586181120.0, - "8": 39586181120.0, - "9": 39586181120.0, - "10": 39586181120.0, - "11": 39586181120.0, - "12": 39586181120.0, - "13": 39586181120.0, - "14": 39586181120.0, - "15": 39586181120.0, - "16": 39586181120.0, - "17": 39586181120.0, - "18": 39586181120.0, - "19": 39586181120.0, - "20": 39586181120.0, - "21": 39586181120.0, - "22": 39586181120.0, - "23": 39586181120.0, - "24": 39586181120.0, - "25": 39586181120.0, - "26": 39586181120.0, - "27": 39586181120.0, - "28": 39586181120.0, - "29": 39586181120.0, - "30": 39586181120.0, - "31": 39586181120.0, - "32": 39586181120.0, - "33": 39586181120.0, - "34": 39586181120.0, - "35": 39586181120.0, - "36": 39586181120.0, - "37": 39586181120.0, - "38": 39586181120.0, - "39": 39586181120.0, - "40": 39586181120.0, - "41": 39586181120.0, - "42": 39586181120.0, - "43": 39586181120.0, - "44": 39586181120.0, - "45": 39586181120.0, - "46": 39586181120.0, - "47": 39586181120.0, - "48": 39586181120.0, - "49": 39586181120.0, - "50": 39586181120.0 + "3": 39583842304.0, + "4": 39583842304.0, + "5": 39584591872.0, + "6": 39584591872.0, + "7": 39584591872.0, + "8": 39584591872.0, + "9": 39584591872.0, + "10": 39584591872.0, + "11": 39584591872.0, + "12": 39584591872.0, + "13": 39584591872.0, + "14": 39584591872.0, + "15": 39584591872.0, + "16": 39584591872.0, + "17": 39584591872.0, + "18": 39584591872.0, + "19": 39584591872.0, + "20": 39584591872.0, + "21": 39584591872.0, + "22": 39584591872.0, + "23": 39584591872.0, + "24": 39584591872.0, + "25": 39584591872.0, + "26": 39584591872.0, + "27": 39584591872.0, + "28": 39584591872.0, + "29": 39584591872.0, + "30": 39584591872.0, + "31": 39584591872.0, + 
"32": 39584591872.0, + "33": 39584591872.0, + "34": 39584591872.0, + "35": 39584591872.0, + "36": 39584591872.0, + "37": 39584591872.0, + "38": 39584591872.0, + "39": 39584591872.0, + "40": 39584591872.0, + "41": 39584591872.0, + "42": 39584591872.0, + "43": 39584591872.0, + "44": 39584591872.0, + "45": 39584591872.0, + "46": 39584591872.0, + "47": 39584591872.0, + "48": 39584591872.0, + "49": 39584591872.0, + "50": 39584591872.0 } }, "iteration-time": { @@ -232,56 +232,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 65.48328, - "2": 1.94615, - "3": 3.94539, - "4": 2.42699, - "5": 1.80319, - "6": 1.79395, - "7": 1.50546, - "8": 2.00251, - "9": 1.2172, - "10": 1.31071, - "11": 1.3171, - "12": 1.10351, - "13": 1.26314, - "14": 1.47608, - "15": 1.19001, - "16": 1.12949, - "17": 1.15105, - "18": 1.06698, - "19": 1.10069, - "20": 1.12463, - "21": 1.35075, - "22": 1.56258, - "23": 1.2368, - "24": 1.13707, - "25": 1.11826, - "26": 1.09445, - "27": 1.08857, - "28": 1.07964, - "29": 1.08505, - "30": 1.24068, - "31": 1.10419, - "32": 1.5164, - "33": 1.10245, - "34": 1.37977, - "35": 1.1165, - "36": 1.1457, - "37": 1.10487, - "38": 1.08491, - "39": 1.08901, - "40": 1.08968, - "41": 1.13702, - "42": 1.09805, - "43": 1.06669, - "44": 1.07791, - "45": 1.08898, - "46": 1.10717, - "47": 1.13008, - "48": 1.05745, - "49": 1.08268, - "50": 1.05678 + "1": 65.95827, + "2": 1.9924, + "3": 3.92592, + "4": 2.4652, + "5": 1.84842, + "6": 1.80402, + "7": 1.67822, + "8": 1.88485, + "9": 1.32993, + "10": 1.37648, + "11": 1.18596, + "12": 1.16521, + "13": 1.14524, + "14": 1.34968, + "15": 1.22798, + "16": 1.10709, + "17": 1.2737, + "18": 1.12048, + "19": 1.44431, + "20": 1.22659, + "21": 1.23111, + "22": 1.27597, + "23": 1.25479, + "24": 1.12437, + "25": 1.28457, + "26": 1.26411, + "27": 1.16703, + "28": 1.13595, + "29": 1.24774, + "30": 1.10985, + "31": 1.3919, + "32": 1.10386, + "33": 1.20402, + "34": 1.08667, + "35": 1.10247, + "36": 1.09087, + "37": 1.16339, + "38": 1.12236, + 
"39": 1.10519, + "40": 1.20224, + "41": 1.11719, + "42": 1.18432, + "43": 1.11065, + "44": 1.14205, + "45": 1.12352, + "46": 1.09449, + "47": 1.10298, + "48": 1.10504, + "49": 1.09853, + "50": 1.0939 } } } \ No newline at end of file From 6c83118d55f537feabb1d934ef7437bcd8ed673f Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 6 Nov 2025 02:31:18 -0800 Subject: [PATCH 06/74] update model_config and golden values Signed-off-by: Hongbin Liu --- .../transformer/multi_token_prediction.py | 3 + megatron/training/arguments.py | 2 +- .../golden_values_dev_dgx_h100.json | 600 +++++++++--------- .../model_config.yaml | 6 +- .../golden_values_dev_dgx_h100.json | 500 +++++++-------- .../model_config.yaml | 6 +- 6 files changed, 560 insertions(+), 557 deletions(-) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 945682741d4..80f72a91ff2 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -29,6 +29,9 @@ make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor, ) +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_set_last_layer +) if is_torch_min_version("1.13.0"): dist_all_gather_func = torch.distributed.all_gather_into_tensor diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 3413d1e1547..c91bb536fea 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2346,7 +2346,7 @@ def _add_training_args(parser): group.add_argument('--fine-grained-activation-offloading', action='store_true', help='Enable fine-grained activation offloading.') group.add_argument('--offload-modules', nargs='*', type=str, default=[], - help='The submodules to offload its input. Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') + help='The submodules to offload its input. 
Choices: "attn_norm", "qkv_linear", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') group.add_argument('--min-offloaded-tensor-size', type=int, default=1024*1024, help='The minimum size of the tensor to be offloaded.') return parser diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json index 4b32d4256db..e7f62bbe4af 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json @@ -4,56 +4,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 11.07559, - "2": 11.03834, - "3": 9.72869, - "4": 9.61678, - "5": 10.63323, - "6": 9.1681, - "7": 9.35196, - "8": 9.05204, - "9": 8.84148, - "10": 9.00321, - "11": 8.49799, - "12": 8.5218, - "13": 8.41649, - "14": 7.9096, - "15": 8.00627, - "16": 8.05394, - "17": 8.0203, - "18": 7.73136, - "19": 8.11676, - "20": 7.83945, - "21": 7.52196, - "22": 7.5295, - "23": 7.38729, - "24": 7.3758, - "25": 7.65255, - "26": 7.04795, - "27": 7.591, - "28": 7.30023, - "29": 7.45656, - "30": 7.60935, - "31": 7.3713, - "32": 7.55298, - "33": 7.59738, - "34": 7.65764, - "35": 7.17916, - "36": 7.04913, - "37": 7.38022, - "38": 7.14883, - "39": 7.50321, - "40": 7.51595, - "41": 7.45139, - "42": 7.21197, - "43": 7.21131, - "44": 7.38058, - "45": 7.16397, - "46": 6.86108, - "47": 7.27247, - "48": 7.10862, - "49": 7.56398, - "50": 7.00523 + "1": 11.06715, + "2": 11.06051, + "3": 10.21154, + "4": 9.95175, + "5": 10.12622, + "6": 8.82146, + "7": 9.52879, + "8": 8.442, + "9": 7.84738, + "10": 7.07075, + "11": 9.31042, + "12": 9.16013, + "13": 7.87292, + "14": 8.2102, + "15": 8.22483, + "16": 8.17879, + 
"17": 8.21121, + "18": 7.50325, + "19": 8.08274, + "20": 7.62562, + "21": 7.95058, + "22": 7.29789, + "23": 7.93775, + "24": 7.44169, + "25": 8.23817, + "26": 7.74959, + "27": 7.69344, + "28": 7.65487, + "29": 7.75173, + "30": 7.56007, + "31": 7.81567, + "32": 6.46589, + "33": 7.20401, + "34": 7.77921, + "35": 7.72944, + "36": 6.71776, + "37": 8.08311, + "38": 7.6137, + "39": 7.96476, + "40": 7.50072, + "41": 7.50304, + "42": 6.11349, + "43": 7.59404, + "44": 7.91361, + "45": 6.83615, + "46": 7.41293, + "47": 7.79226, + "48": 7.87549, + "49": 7.58763, + "50": 6.84525 } }, "num-zeros": { @@ -61,56 +61,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 38802120.0, - "2": 38543052.0, - "3": 38738396.0, - "4": 113220144.0, - "5": 344100160.0, - "6": 435062816.0, - "7": 579598912.0, - "8": 819195200.0, - "9": 604910464.0, - "10": 690749824.0, - "11": 744002496.0, - "12": 520212192.0, - "13": 547932992.0, - "14": 585659584.0, - "15": 614149184.0, - "16": 664915328.0, - "17": 592272320.0, - "18": 630225856.0, - "19": 579959808.0, - "20": 800470080.0, - "21": 573941056.0, - "22": 557652032.0, - "23": 797256640.0, - "24": 826380864.0, - "25": 814860160.0, - "26": 617708032.0, - "27": 715680384.0, - "28": 548045824.0, - "29": 736312064.0, - "30": 722163456.0, - "31": 711986176.0, - "32": 674238208.0, - "33": 715239232.0, - "34": 677588288.0, - "35": 473423392.0, - "36": 451352800.0, - "37": 446739392.0, - "38": 567466304.0, - "39": 472519552.0, - "40": 434322048.0, - "41": 554276096.0, - "42": 526187424.0, - "43": 510713152.0, - "44": 522783808.0, - "45": 335511072.0, - "46": 450878784.0, - "47": 450397344.0, - "48": 321720704.0, - "49": 437443680.0, - "50": 419425088.0 + "1": 47165192.0, + "2": 46897912.0, + "3": 52684456.0, + "4": 297127552.0, + "5": 562950784.0, + "6": 668142144.0, + "7": 1027449536.0, + "8": 752259328.0, + "9": 830947776.0, + "10": 718307136.0, + "11": 823731840.0, + "12": 804867840.0, + "13": 639461056.0, + "14": 625408576.0, + "15": 
716256960.0, + "16": 870866752.0, + "17": 673817856.0, + "18": 811900096.0, + "19": 892689024.0, + "20": 878114112.0, + "21": 666859968.0, + "22": 792718848.0, + "23": 783683200.0, + "24": 770686976.0, + "25": 651376640.0, + "26": 780070272.0, + "27": 801722496.0, + "28": 670273664.0, + "29": 647960768.0, + "30": 789867776.0, + "31": 801385856.0, + "32": 787688640.0, + "33": 783506816.0, + "34": 792837760.0, + "35": 776103936.0, + "36": 761920512.0, + "37": 775085824.0, + "38": 752868608.0, + "39": 754997184.0, + "40": 745075072.0, + "41": 713941440.0, + "42": 689968512.0, + "43": 663461824.0, + "44": 680285632.0, + "45": 644628992.0, + "46": 641672704.0, + "47": 642439616.0, + "48": 597700608.0, + "49": 603523520.0, + "50": 601014528.0 } }, "mem-allocated-bytes": { @@ -118,56 +118,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 5498340864.0, - "2": 5499135488.0, - "3": 5499928064.0, - "4": 5500720640.0, - "5": 5501513216.0, - "6": 5502305792.0, - "7": 5497946624.0, - "8": 5498739200.0, - "9": 5499531776.0, - "10": 5500324352.0, - "11": 5501116928.0, - "12": 5498342912.0, - "13": 5499135488.0, - "14": 5499928064.0, - "15": 5500720640.0, - "16": 5501513216.0, - "17": 5502305792.0, - "18": 5503098368.0, - "19": 5503890944.0, - "20": 5504683520.0, - "21": 5505476096.0, - "22": 5506268672.0, - "23": 5507061248.0, - "24": 5507853824.0, - "25": 5508646400.0, - "26": 5509438976.0, - "27": 5510231552.0, - "28": 5511024128.0, - "29": 5511816704.0, - "30": 5512609280.0, - "31": 5513401856.0, - "32": 5514194432.0, - "33": 5514987008.0, - "34": 5515779584.0, - "35": 5516572160.0, - "36": 5517364736.0, - "37": 5518157312.0, - "38": 5518949888.0, - "39": 5519742464.0, - "40": 5520535040.0, - "41": 5521327616.0, - "42": 5522120192.0, - "43": 5522912768.0, - "44": 5523705344.0, - "45": 5524497920.0, - "46": 5525290496.0, - "47": 5526083072.0, - "48": 5526875648.0, - "49": 5527668224.0, - "50": 5528460800.0 + "1": 5290944000.0, + "2": 5291148800.0, + "3": 5291351552.0, 
+ "4": 5290946048.0, + "5": 5291148800.0, + "6": 5291351552.0, + "7": 5291554304.0, + "8": 5291757056.0, + "9": 5291959808.0, + "10": 5292162560.0, + "11": 5292365312.0, + "12": 5292568064.0, + "13": 5292770816.0, + "14": 5292973568.0, + "15": 5293176320.0, + "16": 5293379072.0, + "17": 5293581824.0, + "18": 5293784576.0, + "19": 5293987328.0, + "20": 5294190080.0, + "21": 5294392832.0, + "22": 5294595584.0, + "23": 5294798336.0, + "24": 5295001088.0, + "25": 5295203840.0, + "26": 5295406592.0, + "27": 5295609344.0, + "28": 5295812096.0, + "29": 5296014848.0, + "30": 5296217600.0, + "31": 5296420352.0, + "32": 5296623104.0, + "33": 5296825856.0, + "34": 5297028608.0, + "35": 5297231360.0, + "36": 5297434112.0, + "37": 5297636864.0, + "38": 5297839616.0, + "39": 5298042368.0, + "40": 5298245120.0, + "41": 5298447872.0, + "42": 5298650624.0, + "43": 5298853376.0, + "44": 5299056128.0, + "45": 5299258880.0, + "46": 5299461632.0, + "47": 5299664384.0, + "48": 5299867136.0, + "49": 5300069888.0, + "50": 5300272640.0 } }, "mem-max-allocated-bytes": { @@ -175,56 +175,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 41723441152.0, - "2": 43687280640.0, - "3": 43916578816.0, - "4": 43916578816.0, - "5": 43916578816.0, - "6": 43916578816.0, - "7": 43916578816.0, - "8": 43916578816.0, - "9": 43916578816.0, - "10": 43916578816.0, - "11": 43916578816.0, - "12": 44028436480.0, - "13": 44028436480.0, - "14": 44028436480.0, - "15": 44028436480.0, - "16": 44028436480.0, - "17": 44028436480.0, - "18": 44028436480.0, - "19": 44028436480.0, - "20": 44028436480.0, - "21": 44028436480.0, - "22": 44028436480.0, - "23": 44028436480.0, - "24": 44028436480.0, - "25": 44028436480.0, - "26": 44028436480.0, - "27": 44028436480.0, - "28": 44028436480.0, - "29": 44028436480.0, - "30": 44028436480.0, - "31": 44028436480.0, - "32": 44028436480.0, - "33": 44028436480.0, - "34": 44028436480.0, - "35": 44028436480.0, - "36": 44028436480.0, - "37": 44028436480.0, - "38": 44028436480.0, - 
"39": 44028436480.0, - "40": 44028436480.0, - "41": 44028436480.0, - "42": 44028436480.0, - "43": 44028436480.0, - "44": 44028436480.0, - "45": 44028436480.0, - "46": 44028436480.0, - "47": 44028436480.0, - "48": 44028436480.0, - "49": 44028436480.0, - "50": 44028436480.0 + "1": 6180783616.0, + "2": 8225679872.0, + "3": 8225679872.0, + "4": 8225679872.0, + "5": 8225679872.0, + "6": 8225679872.0, + "7": 8225679872.0, + "8": 8225679872.0, + "9": 8225679872.0, + "10": 8225679872.0, + "11": 8239991296.0, + "12": 8239991296.0, + "13": 8239991296.0, + "14": 8239991296.0, + "15": 8239991296.0, + "16": 8239991296.0, + "17": 8244914688.0, + "18": 8244914688.0, + "19": 8244914688.0, + "20": 8265598464.0, + "21": 8265598464.0, + "22": 8265598464.0, + "23": 8265598464.0, + "24": 8265598464.0, + "25": 8265598464.0, + "26": 8265598464.0, + "27": 8265598464.0, + "28": 8265598464.0, + "29": 8271664640.0, + "30": 8316803584.0, + "31": 8316803584.0, + "32": 8316803584.0, + "33": 8316803584.0, + "34": 8316803584.0, + "35": 8316803584.0, + "36": 8316803584.0, + "37": 8316803584.0, + "38": 8316803584.0, + "39": 8318923264.0, + "40": 8318923264.0, + "41": 8318923264.0, + "42": 8318923264.0, + "43": 8318923264.0, + "44": 8318923264.0, + "45": 8318923264.0, + "46": 8318923264.0, + "47": 8318923264.0, + "48": 8318923264.0, + "49": 8318923264.0, + "50": 8318923264.0 } }, "mtp_1 loss": { @@ -232,56 +232,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 11.08623, - "2": 11.1047, - "3": 10.54469, - "4": 10.08474, - "5": 9.76549, - "6": 9.56242, - "7": 9.59473, - "8": 8.97686, - "9": 8.83293, - "10": 9.1193, - "11": 8.44318, - "12": 8.49593, - "13": 8.37985, - "14": 7.81516, - "15": 7.95146, - "16": 8.01718, - "17": 7.94503, - "18": 7.68603, - "19": 8.07501, - "20": 7.79558, - "21": 7.46867, - "22": 7.46603, - "23": 7.32734, - "24": 7.32819, - "25": 7.58465, - "26": 6.99257, - "27": 7.53486, - "28": 7.23432, - "29": 7.40501, - "30": 7.55005, - "31": 7.30085, - "32": 7.48028, - "33": 
7.53593, - "34": 7.60112, - "35": 7.12344, - "36": 6.99007, - "37": 7.32578, - "38": 7.09623, - "39": 7.45759, - "40": 7.45018, - "41": 7.40101, - "42": 7.14459, - "43": 7.13995, - "44": 7.32066, - "45": 7.0966, - "46": 6.80106, - "47": 7.21219, - "48": 7.05021, - "49": 7.48165, - "50": 6.95118 + "1": 11.07395, + "2": 11.0927, + "3": 10.82648, + "4": 10.27524, + "5": 10.45343, + "6": 8.32789, + "7": 9.82687, + "8": 8.01561, + "9": 7.47686, + "10": 6.75778, + "11": 8.92977, + "12": 8.98867, + "13": 7.80263, + "14": 8.02637, + "15": 8.11184, + "16": 8.13967, + "17": 8.13444, + "18": 7.44744, + "19": 8.03657, + "20": 7.53993, + "21": 7.90129, + "22": 7.27518, + "23": 7.88304, + "24": 7.37567, + "25": 8.16836, + "26": 7.69935, + "27": 7.6262, + "28": 7.61271, + "29": 7.69819, + "30": 7.4848, + "31": 7.73967, + "32": 6.36884, + "33": 7.14295, + "34": 7.71844, + "35": 7.63485, + "36": 6.61195, + "37": 8.02821, + "38": 7.57841, + "39": 7.89473, + "40": 7.41461, + "41": 7.42116, + "42": 6.01344, + "43": 7.4906, + "44": 7.86418, + "45": 6.74814, + "46": 7.30484, + "47": 7.72617, + "48": 7.79074, + "49": 7.49049, + "50": 6.75504 } }, "iteration-time": { @@ -289,56 +289,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 87.66203, - "2": 2.04189, - "3": 3.34278, - "4": 3.72414, - "5": 3.23492, - "6": 1.94546, - "7": 2.14942, - "8": 1.78075, - "9": 1.06029, - "10": 2.13554, - "11": 1.42578, - "12": 1.80986, - "13": 1.06134, - "14": 1.087, - "15": 1.16687, - "16": 1.20412, - "17": 1.06984, - "18": 1.07557, - "19": 1.04081, - "20": 1.21763, - "21": 1.06196, - "22": 1.14038, - "23": 2.25761, - "24": 1.09161, - "25": 1.04319, - "26": 1.40025, - "27": 1.04974, - "28": 1.03984, - "29": 1.05293, - "30": 1.48942, - "31": 1.04785, - "32": 1.0529, - "33": 1.04366, - "34": 1.0633, - "35": 1.0713, - "36": 1.05711, - "37": 1.08085, - "38": 1.07006, - "39": 1.06498, - "40": 1.05913, - "41": 1.0697, - "42": 1.079, - "43": 1.14122, - "44": 1.06478, - "45": 1.04692, - "46": 1.08174, - 
"47": 1.07595, - "48": 1.10523, - "49": 1.0839, - "50": 1.07754 + "1": 90.97535, + "2": 4.15413, + "3": 4.25282, + "4": 5.50314, + "5": 4.36528, + "6": 4.16016, + "7": 4.60989, + "8": 3.68392, + "9": 3.70951, + "10": 3.66417, + "11": 3.64904, + "12": 3.66094, + "13": 3.68824, + "14": 3.64996, + "15": 3.64159, + "16": 3.68269, + "17": 3.66905, + "18": 4.10783, + "19": 3.63362, + "20": 3.65129, + "21": 3.6431, + "22": 3.64946, + "23": 3.6411, + "24": 3.59707, + "25": 3.55364, + "26": 3.61478, + "27": 3.59779, + "28": 3.58741, + "29": 3.62545, + "30": 3.63538, + "31": 3.58264, + "32": 3.65914, + "33": 3.62764, + "34": 3.61962, + "35": 3.57076, + "36": 3.59244, + "37": 3.68499, + "38": 3.6803, + "39": 3.5849, + "40": 3.59019, + "41": 3.62068, + "42": 3.69144, + "43": 3.71863, + "44": 3.67193, + "45": 3.65673, + "46": 3.66919, + "47": 3.58334, + "48": 3.57229, + "49": 3.66195, + "50": 3.64157 } } } \ No newline at end of file diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml index 487382042b7..c657b9087e7 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml @@ -23,8 +23,8 @@ MODEL_ARGS: --use-mcore-models: true --sequence-parallel: true --disable-bias-linear: true - --micro-batch-size: 4 - --global-batch-size: 32 + --micro-batch-size: 1 + --global-batch-size: 8 --train-iters: 50 --exit-duration-in-mins: 230 --no-check-for-nan-in-loss-and-grad: true @@ -36,7 +36,7 @@ MODEL_ARGS: --recompute-granularity: selective --recompute-modules: "[layernorm mla_up_proj mlp moe_act]" --fine-grained-activation-offloading: true - --offload-modules: "[expert_fc1 moe_act attn_norm mlp_norm]" + 
--offload-modules: "[expert_fc1 moe_act attn_norm mlp_norm qkv_linear core_attn attn_proj]" # Transformer Engine args --transformer-impl: transformer_engine # Data args diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json index 150ba70462f..1483224813a 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json @@ -4,56 +4,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 11.04276, - "2": 11.02298, - "3": 9.50907, - "4": 10.86145, - "5": 9.36104, - "6": 9.05664, - "7": 9.20646, - "8": 9.00188, - "9": 8.69791, - "10": 8.97535, - "11": 8.48206, - "12": 8.44961, - "13": 8.38916, - "14": 7.90422, - "15": 7.98559, - "16": 8.02787, - "17": 8.04894, - "18": 7.72163, - "19": 8.0935, - "20": 7.85609, - "21": 7.53372, - "22": 7.50495, - "23": 7.39733, - "24": 7.36369, - "25": 7.62993, - "26": 7.04703, - "27": 7.59839, - "28": 7.29807, - "29": 7.46826, - "30": 7.60613, - "31": 7.34795, - "32": 7.53766, - "33": 7.58939, - "34": 7.64431, - "35": 7.18358, - "36": 7.036, - "37": 7.36506, - "38": 7.14525, - "39": 7.50347, - "40": 7.50925, - "41": 7.44415, - "42": 7.20526, - "43": 7.21039, - "44": 7.37585, - "45": 7.1698, - "46": 6.8612, - "47": 7.26258, - "48": 7.1033, - "49": 7.55974, - "50": 6.99878 + "1": 11.01686, + "2": 11.06264, + "3": 10.17771, + "4": 10.86294, + "5": 9.81711, + "6": 9.10377, + "7": 9.61048, + "8": 8.39441, + "9": 7.79453, + "10": 7.15206, + "11": 9.06579, + "12": 12.40166, + "13": 8.04847, + "14": 8.24594, + "15": 8.24907, + "16": 8.32751, + 
"17": 8.35488, + "18": 7.58028, + "19": 8.18771, + "20": 7.71954, + "21": 8.00698, + "22": 7.35089, + "23": 7.95479, + "24": 7.51289, + "25": 8.32529, + "26": 7.78885, + "27": 7.72725, + "28": 7.71319, + "29": 7.77361, + "30": 7.56799, + "31": 7.85271, + "32": 6.52658, + "33": 7.24362, + "34": 7.80331, + "35": 7.74511, + "36": 6.73702, + "37": 8.15605, + "38": 7.62885, + "39": 7.97707, + "40": 7.52037, + "41": 7.52443, + "42": 6.12689, + "43": 7.60467, + "44": 7.96883, + "45": 6.84543, + "46": 7.42548, + "47": 7.82723, + "48": 7.87988, + "49": 7.59963, + "50": 6.85112 } }, "num-zeros": { @@ -61,56 +61,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 38808152.0, - "2": 38549168.0, - "3": 38741680.0, - "4": 81738424.0, - "5": 161659808.0, - "6": 296608000.0, - "7": 557581568.0, - "8": 592711744.0, - "9": 479088640.0, - "10": 520896096.0, - "11": 555256320.0, - "12": 444724480.0, - "13": 658029440.0, - "14": 585665280.0, - "15": 588986240.0, - "16": 479280192.0, - "17": 494748608.0, - "18": 504398944.0, - "19": 601982144.0, - "20": 787884160.0, - "21": 536156160.0, - "22": 513609344.0, - "23": 577056256.0, - "24": 549563712.0, - "25": 648153280.0, - "26": 498150784.0, - "27": 501770816.0, - "28": 522921920.0, - "29": 462644416.0, - "30": 612066112.0, - "31": 605029312.0, - "32": 454036160.0, - "33": 419547936.0, - "34": 378748896.0, - "35": 385339904.0, - "36": 350676768.0, - "37": 478164480.0, - "38": 337833600.0, - "39": 450472544.0, - "40": 267556496.0, - "41": 280614912.0, - "42": 305998368.0, - "43": 372298848.0, - "44": 261697280.0, - "45": 225394720.0, - "46": 268431392.0, - "47": 217617888.0, - "48": 261904016.0, - "49": 229846288.0, - "50": 214954112.0 + "1": 47167840.0, + "2": 46900628.0, + "3": 81003512.0, + "4": 243621808.0, + "5": 468555040.0, + "6": 561181184.0, + "7": 958267392.0, + "8": 720794112.0, + "9": 771164224.0, + "10": 718302016.0, + "11": 669618304.0, + "12": 559500096.0, + "13": 642601344.0, + "14": 754397952.0, + "15": 
766531584.0, + "16": 697850240.0, + "17": 654906240.0, + "18": 745861824.0, + "19": 738620928.0, + "20": 887555328.0, + "21": 729800064.0, + "22": 666937216.0, + "23": 777389312.0, + "24": 607175552.0, + "25": 855782784.0, + "26": 846129152.0, + "27": 666477056.0, + "28": 830677504.0, + "29": 811523712.0, + "30": 657771072.0, + "31": 609501440.0, + "32": 784538816.0, + "33": 755198720.0, + "34": 729929280.0, + "35": 719482368.0, + "36": 699006208.0, + "37": 727900096.0, + "38": 711973824.0, + "39": 701515264.0, + "40": 682162752.0, + "41": 534678112.0, + "42": 655361792.0, + "43": 663463424.0, + "44": 642541952.0, + "45": 455907168.0, + "46": 613359936.0, + "47": 592108160.0, + "48": 585115008.0, + "49": 559483008.0, + "50": 544390208.0 } }, "mem-allocated-bytes": { @@ -118,56 +118,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 4419107328.0, - "2": 4419108864.0, - "3": 4419108864.0, - "4": 4419108864.0, - "5": 4419108864.0, - "6": 4419108864.0, - "7": 4419108864.0, - "8": 4419108864.0, - "9": 4419108864.0, - "10": 4419108864.0, - "11": 4419108864.0, - "12": 4419108864.0, - "13": 4419108864.0, - "14": 4419108864.0, - "15": 4419108864.0, - "16": 4419108864.0, - "17": 4419108864.0, - "18": 4419108864.0, - "19": 4419108864.0, - "20": 4419108864.0, - "21": 4419108864.0, - "22": 4419108864.0, - "23": 4419108864.0, - "24": 4419108864.0, - "25": 4419108864.0, - "26": 4419108864.0, - "27": 4419108864.0, - "28": 4419108864.0, - "29": 4419108864.0, - "30": 4419108864.0, - "31": 4419108864.0, - "32": 4419108864.0, - "33": 4419108864.0, - "34": 4419108864.0, - "35": 4419108864.0, - "36": 4419108864.0, - "37": 4419108864.0, - "38": 4419108864.0, - "39": 4419108864.0, - "40": 4419108864.0, - "41": 4419108864.0, - "42": 4419108864.0, - "43": 4419108864.0, - "44": 4419108864.0, - "45": 4419108864.0, - "46": 4419108864.0, - "47": 4419108864.0, - "48": 4419108864.0, - "49": 4419108864.0, - "50": 4419108864.0 + "1": 4315544064.0, + "2": 4315545600.0, + "3": 4315545600.0, 
+ "4": 4315545600.0, + "5": 4315545600.0, + "6": 4315545600.0, + "7": 4315545600.0, + "8": 4315545600.0, + "9": 4315545600.0, + "10": 4315545600.0, + "11": 4315545600.0, + "12": 4315545600.0, + "13": 4315545600.0, + "14": 4315545600.0, + "15": 4315545600.0, + "16": 4315545600.0, + "17": 4315545600.0, + "18": 4315545600.0, + "19": 4315545600.0, + "20": 4315545600.0, + "21": 4315545600.0, + "22": 4315545600.0, + "23": 4315545600.0, + "24": 4315545600.0, + "25": 4315545600.0, + "26": 4315545600.0, + "27": 4315545600.0, + "28": 4315545600.0, + "29": 4315545600.0, + "30": 4315545600.0, + "31": 4315545600.0, + "32": 4315545600.0, + "33": 4315545600.0, + "34": 4315545600.0, + "35": 4315545600.0, + "36": 4315545600.0, + "37": 4315545600.0, + "38": 4315545600.0, + "39": 4315545600.0, + "40": 4315545600.0, + "41": 4315545600.0, + "42": 4315545600.0, + "43": 4315545600.0, + "44": 4315545600.0, + "45": 4315545600.0, + "46": 4315545600.0, + "47": 4315545600.0, + "48": 4315545600.0, + "49": 4315545600.0, + "50": 4315545600.0 } }, "mem-max-allocated-bytes": { @@ -175,56 +175,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 37959917568.0, - "2": 39578673152.0, - "3": 39583842304.0, - "4": 39583842304.0, - "5": 39584591872.0, - "6": 39584591872.0, - "7": 39584591872.0, - "8": 39584591872.0, - "9": 39584591872.0, - "10": 39584591872.0, - "11": 39584591872.0, - "12": 39584591872.0, - "13": 39584591872.0, - "14": 39584591872.0, - "15": 39584591872.0, - "16": 39584591872.0, - "17": 39584591872.0, - "18": 39584591872.0, - "19": 39584591872.0, - "20": 39584591872.0, - "21": 39584591872.0, - "22": 39584591872.0, - "23": 39584591872.0, - "24": 39584591872.0, - "25": 39584591872.0, - "26": 39584591872.0, - "27": 39584591872.0, - "28": 39584591872.0, - "29": 39584591872.0, - "30": 39584591872.0, - "31": 39584591872.0, - "32": 39584591872.0, - "33": 39584591872.0, - "34": 39584591872.0, - "35": 39584591872.0, - "36": 39584591872.0, - "37": 39584591872.0, - "38": 39584591872.0, - 
"39": 39584591872.0, - "40": 39584591872.0, - "41": 39584591872.0, - "42": 39584591872.0, - "43": 39584591872.0, - "44": 39584591872.0, - "45": 39584591872.0, - "46": 39584591872.0, - "47": 39584591872.0, - "48": 39584591872.0, - "49": 39584591872.0, - "50": 39584591872.0 + "1": 4919527424.0, + "2": 5861408768.0, + "3": 5861408768.0, + "4": 5863651328.0, + "5": 5863651328.0, + "6": 5863651328.0, + "7": 5863651328.0, + "8": 5863651328.0, + "9": 5863651328.0, + "10": 5863651328.0, + "11": 5863651328.0, + "12": 5863651328.0, + "13": 5863651328.0, + "14": 5863651328.0, + "15": 5863986176.0, + "16": 5865795072.0, + "17": 5865795072.0, + "18": 5865795072.0, + "19": 5865795072.0, + "20": 5865795072.0, + "21": 5866987520.0, + "22": 5866987520.0, + "23": 5866987520.0, + "24": 5866987520.0, + "25": 5866987520.0, + "26": 5866987520.0, + "27": 5866987520.0, + "28": 5866987520.0, + "29": 5866987520.0, + "30": 5866987520.0, + "31": 5866987520.0, + "32": 5866987520.0, + "33": 5866987520.0, + "34": 5866987520.0, + "35": 5866987520.0, + "36": 5866987520.0, + "37": 5866987520.0, + "38": 5866987520.0, + "39": 5866987520.0, + "40": 5866987520.0, + "41": 5866987520.0, + "42": 5866987520.0, + "43": 5866987520.0, + "44": 5866987520.0, + "45": 5866987520.0, + "46": 5866987520.0, + "47": 5866987520.0, + "48": 5866987520.0, + "49": 5866987520.0, + "50": 5866987520.0 } }, "iteration-time": { @@ -232,56 +232,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 65.95827, - "2": 1.9924, - "3": 3.92592, - "4": 2.4652, - "5": 1.84842, - "6": 1.80402, - "7": 1.67822, - "8": 1.88485, - "9": 1.32993, - "10": 1.37648, - "11": 1.18596, - "12": 1.16521, - "13": 1.14524, - "14": 1.34968, - "15": 1.22798, - "16": 1.10709, - "17": 1.2737, - "18": 1.12048, - "19": 1.44431, - "20": 1.22659, - "21": 1.23111, - "22": 1.27597, - "23": 1.25479, - "24": 1.12437, - "25": 1.28457, - "26": 1.26411, - "27": 1.16703, - "28": 1.13595, - "29": 1.24774, - "30": 1.10985, - "31": 1.3919, - "32": 1.10386, - "33": 
1.20402, - "34": 1.08667, - "35": 1.10247, - "36": 1.09087, - "37": 1.16339, - "38": 1.12236, - "39": 1.10519, - "40": 1.20224, - "41": 1.11719, - "42": 1.18432, - "43": 1.11065, - "44": 1.14205, - "45": 1.12352, - "46": 1.09449, - "47": 1.10298, - "48": 1.10504, - "49": 1.09853, - "50": 1.0939 + "1": 72.699, + "2": 4.27015, + "3": 3.87365, + "4": 3.67041, + "5": 3.65964, + "6": 3.48532, + "7": 3.47679, + "8": 3.47349, + "9": 3.43879, + "10": 3.47441, + "11": 3.45737, + "12": 3.48691, + "13": 3.54474, + "14": 3.44102, + "15": 3.42127, + "16": 3.45795, + "17": 3.49717, + "18": 3.51293, + "19": 3.5617, + "20": 3.49733, + "21": 3.50336, + "22": 3.62308, + "23": 3.50166, + "24": 3.49075, + "25": 3.50996, + "26": 3.44423, + "27": 3.47323, + "28": 3.53784, + "29": 3.51989, + "30": 3.49211, + "31": 3.49945, + "32": 3.4419, + "33": 3.50458, + "34": 3.47663, + "35": 3.45702, + "36": 3.50281, + "37": 3.44136, + "38": 3.45165, + "39": 3.50095, + "40": 3.50126, + "41": 3.50863, + "42": 3.46684, + "43": 3.55122, + "44": 3.48372, + "45": 3.46903, + "46": 3.47654, + "47": 3.51574, + "48": 3.4895, + "49": 3.49404, + "50": 3.45824 } } } \ No newline at end of file diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml index 28ad106f522..5b177ed116d 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml @@ -23,8 +23,8 @@ MODEL_ARGS: --use-mcore-models: true --sequence-parallel: true --disable-bias-linear: true - --micro-batch-size: 4 - --global-batch-size: 32 + --micro-batch-size: 1 + --global-batch-size: 8 --train-iters: 50 
--exit-duration-in-mins: 230 --no-check-for-nan-in-loss-and-grad: true @@ -36,7 +36,7 @@ MODEL_ARGS: --recompute-granularity: selective --recompute-modules: "[layernorm mla_up_proj mlp moe_act]" --fine-grained-activation-offloading: true - --offload-modules: "[expert_fc1 moe_act attn_norm mlp_norm]" + --offload-modules: "[expert_fc1 moe_act attn_norm mlp_norm qkv_linear core_attn attn_proj]" # Transformer Engine args --transformer-impl: transformer_engine # Data args From 33a38f51c735048c1df92a0ea39e289aba6a85de Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 6 Nov 2025 02:33:31 -0800 Subject: [PATCH 07/74] format Signed-off-by: Hongbin Liu --- megatron/core/transformer/multi_token_prediction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 80f72a91ff2..d8d20039e45 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -13,6 +13,9 @@ from megatron.core.fp8_utils import get_fp8_context from megatron.core.models.backends import BackendSpecProvider, LocalSpecProvider from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_set_last_layer, +) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import ( gather_from_tensor_model_parallel_region, @@ -29,9 +32,6 @@ make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor, ) -from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_set_last_layer -) if is_torch_min_version("1.13.0"): dist_all_gather_func = torch.distributed.all_gather_into_tensor From 6c76b07a07d86e961c896834e29fdd4b02b135b2 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 6 Nov 2025 06:15:51 -0800 Subject: [PATCH 08/74] update 
golden values Signed-off-by: Hongbin Liu --- .../golden_values_dev_dgx_h100.json | 390 +++++++++--------- 1 file changed, 195 insertions(+), 195 deletions(-) diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json index 1483224813a..f31e8584055 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json @@ -6,54 +6,54 @@ "values": { "1": 11.01686, "2": 11.06264, - "3": 10.17771, - "4": 10.86294, - "5": 9.81711, - "6": 9.10377, - "7": 9.61048, - "8": 8.39441, - "9": 7.79453, - "10": 7.15206, - "11": 9.06579, - "12": 12.40166, - "13": 8.04847, - "14": 8.24594, - "15": 8.24907, - "16": 8.32751, - "17": 8.35488, - "18": 7.58028, - "19": 8.18771, - "20": 7.71954, - "21": 8.00698, - "22": 7.35089, - "23": 7.95479, - "24": 7.51289, - "25": 8.32529, - "26": 7.78885, - "27": 7.72725, - "28": 7.71319, - "29": 7.77361, - "30": 7.56799, - "31": 7.85271, - "32": 6.52658, - "33": 7.24362, - "34": 7.80331, - "35": 7.74511, - "36": 6.73702, - "37": 8.15605, - "38": 7.62885, - "39": 7.97707, - "40": 7.52037, - "41": 7.52443, - "42": 6.12689, - "43": 7.60467, - "44": 7.96883, - "45": 6.84543, - "46": 7.42548, - "47": 7.82723, - "48": 7.87988, - "49": 7.59963, - "50": 6.85112 + "3": 10.17793, + "4": 10.86283, + "5": 9.81719, + "6": 9.10416, + "7": 9.61067, + "8": 8.39543, + "9": 7.79835, + "10": 7.15295, + "11": 9.06686, + "12": 12.40969, + "13": 8.05055, + "14": 8.2476, + "15": 8.25138, + "16": 8.32761, + "17": 8.33769, + "18": 7.57521, + "19": 8.18843, + "20": 7.70464, + "21": 
8.00008, + "22": 7.35567, + "23": 7.9428, + "24": 7.49828, + "25": 8.31989, + "26": 7.79139, + "27": 7.72813, + "28": 7.70354, + "29": 7.77157, + "30": 7.56925, + "31": 7.85097, + "32": 6.53309, + "33": 7.24762, + "34": 7.79993, + "35": 7.74601, + "36": 6.74083, + "37": 8.15463, + "38": 7.62637, + "39": 7.97973, + "40": 7.52426, + "41": 7.52118, + "42": 6.11695, + "43": 7.60509, + "44": 7.96979, + "45": 6.84567, + "46": 7.4309, + "47": 7.82486, + "48": 7.87887, + "49": 7.59924, + "50": 6.85064 } }, "num-zeros": { @@ -61,56 +61,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 47167840.0, - "2": 46900628.0, - "3": 81003512.0, - "4": 243621808.0, - "5": 468555040.0, - "6": 561181184.0, - "7": 958267392.0, - "8": 720794112.0, - "9": 771164224.0, - "10": 718302016.0, - "11": 669618304.0, - "12": 559500096.0, - "13": 642601344.0, - "14": 754397952.0, - "15": 766531584.0, - "16": 697850240.0, - "17": 654906240.0, - "18": 745861824.0, - "19": 738620928.0, - "20": 887555328.0, - "21": 729800064.0, - "22": 666937216.0, - "23": 777389312.0, - "24": 607175552.0, - "25": 855782784.0, - "26": 846129152.0, - "27": 666477056.0, - "28": 830677504.0, - "29": 811523712.0, - "30": 657771072.0, - "31": 609501440.0, - "32": 784538816.0, - "33": 755198720.0, - "34": 729929280.0, - "35": 719482368.0, - "36": 699006208.0, - "37": 727900096.0, - "38": 711973824.0, - "39": 701515264.0, - "40": 682162752.0, - "41": 534678112.0, - "42": 655361792.0, - "43": 663463424.0, - "44": 642541952.0, - "45": 455907168.0, - "46": 613359936.0, - "47": 592108160.0, - "48": 585115008.0, - "49": 559483008.0, - "50": 544390208.0 + "1": 47167816.0, + "2": 46900776.0, + "3": 77860808.0, + "4": 237329376.0, + "5": 471709792.0, + "6": 558041536.0, + "7": 948826176.0, + "8": 723939584.0, + "9": 786891776.0, + "10": 734021888.0, + "11": 688478400.0, + "12": 553228736.0, + "13": 608009792.0, + "14": 741806976.0, + "15": 766532736.0, + "16": 685280512.0, + "17": 654899648.0, + "18": 730146112.0, + "19": 
751163904.0, + "20": 884406592.0, + "21": 723541120.0, + "22": 805299648.0, + "23": 789975808.0, + "24": 610294016.0, + "25": 830610048.0, + "26": 824111232.0, + "27": 757678144.0, + "28": 774057088.0, + "29": 805232640.0, + "30": 770995712.0, + "31": 801384640.0, + "32": 790830656.0, + "33": 758341184.0, + "34": 726777280.0, + "35": 750934144.0, + "36": 717880064.0, + "37": 740480704.0, + "38": 724556544.0, + "39": 710957376.0, + "40": 716765760.0, + "41": 531516928.0, + "42": 658507328.0, + "43": 676045888.0, + "44": 680286208.0, + "45": 606880576.0, + "46": 641672384.0, + "47": 633002368.0, + "48": 607136576.0, + "49": 430551968.0, + "50": 563263808.0 } }, "mem-allocated-bytes": { @@ -178,53 +178,53 @@ "1": 4919527424.0, "2": 5861408768.0, "3": 5861408768.0, - "4": 5863651328.0, - "5": 5863651328.0, - "6": 5863651328.0, - "7": 5863651328.0, - "8": 5863651328.0, - "9": 5863651328.0, - "10": 5863651328.0, - "11": 5863651328.0, - "12": 5863651328.0, - "13": 5863651328.0, - "14": 5863651328.0, - "15": 5863986176.0, - "16": 5865795072.0, - "17": 5865795072.0, - "18": 5865795072.0, - "19": 5865795072.0, - "20": 5865795072.0, - "21": 5866987520.0, - "22": 5866987520.0, - "23": 5866987520.0, - "24": 5866987520.0, - "25": 5866987520.0, - "26": 5866987520.0, - "27": 5866987520.0, - "28": 5866987520.0, - "29": 5866987520.0, - "30": 5866987520.0, - "31": 5866987520.0, - "32": 5866987520.0, - "33": 5866987520.0, - "34": 5866987520.0, - "35": 5866987520.0, - "36": 5866987520.0, - "37": 5866987520.0, - "38": 5866987520.0, - "39": 5866987520.0, - "40": 5866987520.0, - "41": 5866987520.0, - "42": 5866987520.0, - "43": 5866987520.0, - "44": 5866987520.0, - "45": 5866987520.0, - "46": 5866987520.0, - "47": 5866987520.0, - "48": 5866987520.0, - "49": 5866987520.0, - "50": 5866987520.0 + "4": 5865549824.0, + "5": 5865549824.0, + "6": 5865549824.0, + "7": 5865549824.0, + "8": 5865549824.0, + "9": 5865549824.0, + "10": 5865549824.0, + "11": 5865549824.0, + "12": 5865549824.0, + "13": 
5865549824.0, + "14": 5865549824.0, + "15": 5865549824.0, + "16": 5865549824.0, + "17": 5865549824.0, + "18": 5865549824.0, + "19": 5866154496.0, + "20": 5866154496.0, + "21": 5866154496.0, + "22": 5866154496.0, + "23": 5866154496.0, + "24": 5866154496.0, + "25": 5866154496.0, + "26": 5866154496.0, + "27": 5866154496.0, + "28": 5866154496.0, + "29": 5866154496.0, + "30": 5866154496.0, + "31": 5866154496.0, + "32": 5866154496.0, + "33": 5866154496.0, + "34": 5866154496.0, + "35": 5866154496.0, + "36": 5866154496.0, + "37": 5866154496.0, + "38": 5866154496.0, + "39": 5866154496.0, + "40": 5866154496.0, + "41": 5866154496.0, + "42": 5866154496.0, + "43": 5866154496.0, + "44": 5866154496.0, + "45": 5866154496.0, + "46": 5866154496.0, + "47": 5866154496.0, + "48": 5866154496.0, + "49": 5866154496.0, + "50": 5866154496.0 } }, "iteration-time": { @@ -232,56 +232,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 72.699, - "2": 4.27015, - "3": 3.87365, - "4": 3.67041, - "5": 3.65964, - "6": 3.48532, - "7": 3.47679, - "8": 3.47349, - "9": 3.43879, - "10": 3.47441, - "11": 3.45737, - "12": 3.48691, - "13": 3.54474, - "14": 3.44102, - "15": 3.42127, - "16": 3.45795, - "17": 3.49717, - "18": 3.51293, - "19": 3.5617, - "20": 3.49733, - "21": 3.50336, - "22": 3.62308, - "23": 3.50166, - "24": 3.49075, - "25": 3.50996, - "26": 3.44423, - "27": 3.47323, - "28": 3.53784, - "29": 3.51989, - "30": 3.49211, - "31": 3.49945, - "32": 3.4419, - "33": 3.50458, - "34": 3.47663, - "35": 3.45702, - "36": 3.50281, - "37": 3.44136, - "38": 3.45165, - "39": 3.50095, - "40": 3.50126, - "41": 3.50863, - "42": 3.46684, - "43": 3.55122, - "44": 3.48372, - "45": 3.46903, - "46": 3.47654, - "47": 3.51574, - "48": 3.4895, - "49": 3.49404, - "50": 3.45824 + "1": 86.37903, + "2": 4.30499, + "3": 5.51749, + "4": 4.16842, + "5": 5.35652, + "6": 3.7018, + "7": 3.68633, + "8": 3.75304, + "9": 3.67596, + "10": 3.70408, + "11": 3.70621, + "12": 3.71713, + "13": 3.73785, + "14": 3.64923, + "15": 
3.63825, + "16": 3.64129, + "17": 3.71791, + "18": 3.69956, + "19": 4.27786, + "20": 4.04035, + "21": 3.67423, + "22": 3.66455, + "23": 3.67758, + "24": 4.16675, + "25": 3.71546, + "26": 3.71205, + "27": 3.71193, + "28": 3.60188, + "29": 3.69233, + "30": 3.68235, + "31": 3.69734, + "32": 3.69173, + "33": 3.64974, + "34": 3.73647, + "35": 3.68627, + "36": 3.70357, + "37": 3.71094, + "38": 3.72508, + "39": 3.70553, + "40": 3.6995, + "41": 3.61312, + "42": 3.63624, + "43": 3.68714, + "44": 3.70371, + "45": 3.67257, + "46": 3.73701, + "47": 3.69639, + "48": 3.65815, + "49": 3.63754, + "50": 3.71569 } } } \ No newline at end of file From 8e72b44a66dd7e36e58758239bdf8d3dd7335e24 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 17 Nov 2025 17:50:27 -0800 Subject: [PATCH 09/74] temp save Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 380 +++++++++++++++++- .../core/transformer/transformer_layer.py | 20 +- 2 files changed, 378 insertions(+), 22 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 1e280a09d35..bd18275e309 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -3,7 +3,7 @@ import warnings from collections import deque from contextlib import nullcontext -from typing import Any +from typing import Any, Dict, List, Tuple import torch @@ -22,6 +22,250 @@ def debug_rank(message): print(message) +class GPUTensorPool: + """ + GPU memory pool for efficient allocation and deallocation of tensors. 
+ + Features: + - Supports multiple tensor shapes and dtypes, each with its own pool + - Dynamic allocation: tensors are created on-demand during allocation + - Efficient reuse: freed tensors are returned to the pool for reuse + - Uses queue-based management for O(1) allocation and deallocation + + Example: + pool = GPUTensorPool(device='cuda:0') + tensor = pool.allocate((128, 512), dtype=torch.float32) + # ... use tensor ... + pool.free(tensor, (128, 512), dtype=torch.float32) + """ + + def __init__( + self, + device: str = 'cuda', + pin_memory: bool = False + ): + """ + Initialize GPU tensor pool. + + Args: + device: GPU device, default 'cuda' + pin_memory: Whether to use pinned memory (mainly for CPU tensors) + """ + self.device = torch.device(device) + self.pin_memory = pin_memory + + # Maintain a separate pool for each (shape, dtype) combination + # Structure: {(shape, dtype): {'free': deque, 'all': list, 'allocated_count': int}} + self._pools: Dict[Tuple, Dict[str, Any]] = {} + + # Statistics + self._stats = { + 'total_allocated': 0, # Total number of tensors ever allocated + 'current_in_use': 0, # Number of tensors currently in use + 'allocation_requests': 0, # Number of allocation requests + 'free_requests': 0, # Number of free requests + 'pool_hits': 0, # Number of times a tensor was reused from pool + 'pool_misses': 0, # Number of times a new tensor was created + } + + debug_rank("GPUTensorPool: Initialized with dynamic allocation") + + def _get_pool_key(self, shape: Tuple, dtype: torch.dtype) -> Tuple: + """Generate a unique key for the pool based on shape and dtype.""" + return (shape, dtype) + + @staticmethod + def _calculate_memory_size(shape: Tuple, dtype: torch.dtype) -> int: + """Calculate memory size in bytes.""" + element_size = torch.tensor([], dtype=dtype).element_size() + numel = 1 + for dim in shape: + numel *= dim + return numel * element_size + + def allocate(self, shape: Tuple, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """ + 
Allocate a tensor with the specified shape and dtype. + + Args: + shape: Shape of the tensor + dtype: Data type of the tensor, default torch.float32 + + Returns: + Allocated tensor + """ + self._stats['allocation_requests'] += 1 + + pool_key = self._get_pool_key(shape, dtype) + + # Create pool for this (shape, dtype) if it doesn't exist + if pool_key not in self._pools: + self._pools[pool_key] = { + 'free': deque(), # Queue of available tensors + 'all': [], # List of all tensors (for tracking) + 'allocated_count': 0, # Number of allocated tensors + } + + pool = self._pools[pool_key] + + # Try to reuse a tensor from the pool + if len(pool['free']) > 0: + tensor = pool['free'].popleft() + self._stats['pool_hits'] += 1 + debug_rank( + f"GPUTensorPool.allocate: Reused tensor from pool, " + f"shape={shape}, dtype={dtype}, " + f"remaining in pool={len(pool['free'])}" + ) + else: + # Allocate a new tensor + tensor = torch.empty( + shape, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory + ) + pool['all'].append(tensor) + self._stats['total_allocated'] += 1 + self._stats['pool_misses'] += 1 + + memory_mb = self._calculate_memory_size(shape, dtype) / (1024 ** 2) + debug_rank( + f"GPUTensorPool.allocate: Created new tensor, " + f"shape={shape}, dtype={dtype}, " + f"memory={memory_mb:.2f} MB, " + f"total_created={len(pool['all'])}" + ) + + pool['allocated_count'] += 1 + self._stats['current_in_use'] += 1 + + return tensor + + def free(self, tensor: torch.Tensor): + """ + Return a tensor to the pool for reuse. + + Args: + tensor: Tensor to free + + Raises: + ValueError: If tensor doesn't belong to this pool + """ + self._stats['free_requests'] += 1 + + shape = tensor.shape + dtype = tensor.dtype + + pool_key = self._get_pool_key(shape, dtype) + + if pool_key not in self._pools: + raise ValueError( + f"No pool exists for shape={shape}, dtype={dtype}. 
" + f"Available pools: {list(self._pools.keys())}" + ) + + pool = self._pools[pool_key] + + # Verify tensor belongs to this pool (use identity check, not value comparison) + tensor_found = any(tensor is t for t in pool['all']) + if not tensor_found: + raise ValueError( + f"Attempting to free a tensor that doesn't belong to this pool " + f"(shape={shape}, dtype={dtype})" + ) + + # Return tensor to the free queue + pool['free'].append(tensor) + pool['allocated_count'] -= 1 + self._stats['current_in_use'] -= 1 + + debug_rank( + f"GPUTensorPool.free: shape={shape}, dtype={dtype}, " + f"available in pool={len(pool['free'])}" + ) + + def get_pool_status(self, shape: Tuple = None, dtype: torch.dtype = None) -> Dict[str, Any]: + """ + Get the status of the memory pool. + + Args: + shape: If specified along with dtype, return status for that specific pool + dtype: Data type (required if shape is specified) + + Returns: + Dictionary containing status information + """ + if shape is not None: + if dtype is None: + raise ValueError("dtype must be specified when shape is provided") + + pool_key = self._get_pool_key(shape, dtype) + + if pool_key not in self._pools: + raise ValueError(f"No pool exists for shape={shape}, dtype={dtype}") + + pool = self._pools[pool_key] + total_count = len(pool['all']) + + return { + 'shape': shape, + 'dtype': dtype, + 'total_count': total_count, + 'allocated_count': pool['allocated_count'], + 'free_count': len(pool['free']), + 'utilization': pool['allocated_count'] / total_count * 100 if total_count > 0 else 0, + } + else: + # Return status for all pools + status = { + 'global_stats': self._stats.copy(), + 'pools': {} + } + + for pool_key in self._pools: + shape, dtype = pool_key + status['pools'][pool_key] = self.get_pool_status(shape, dtype) + + return status + + def reset(self): + """Reset the pool, marking all tensors as available.""" + debug_rank("GPUTensorPool: Resetting pool...") + + for pool_key, pool in self._pools.items(): + # Clear and 
refill the free queue + pool['free'].clear() + for tensor in pool['all']: + pool['free'].append(tensor) + pool['allocated_count'] = 0 + + self._stats['current_in_use'] = 0 + debug_rank("GPUTensorPool: Reset complete") + + def clear(self): + """Clear the pool and release all GPU memory.""" + debug_rank("GPUTensorPool: Clearing pool...") + + for pool_key, pool in self._pools.items(): + # Clear all references, allowing PyTorch GC to reclaim memory + pool['free'].clear() + pool['all'].clear() + + self._pools.clear() + self._stats['current_in_use'] = 0 + + # Trigger GPU cache cleanup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + debug_rank("GPUTensorPool: Clear complete") + + def __del__(self): + """Destructor to ensure resources are released.""" + self.clear() + + def set_ideal_affinity_for_current_gpu(): """Set CPU affinity for the current GPU to optimize host-device transfers.""" import uuid @@ -80,6 +324,8 @@ def __init__(self): # allocate streams and events for synchronization self._d2h_stream = torch.cuda.Stream() self._h2d_stream = torch.cuda.Stream() + # Shared CPU tensor pool for all chunks to improve reuse efficiency + self._cpu_tensor_pool = GPUTensorPool(device="cpu", pin_memory=True) self.reset() @property @@ -92,6 +338,11 @@ def h2d_stream(self): """Get the host-to-device (CPU to GPU) transfer stream.""" return self._h2d_stream + @property + def cpu_tensor_pool(self): + """Get the shared CPU tensor pool.""" + return self._cpu_tensor_pool + def reset(self): """Reset manager state for a new training iteration.""" set_ideal_affinity_for_current_gpu() @@ -100,6 +351,9 @@ def reset(self): self._cur_backward_chunk = None # Track the first microbatch of the last virtual pipeline stage self._is_first_last_vpp_chunk = True + # Reset CPU tensor pool to reuse all CPU tensors for next iteration + if hasattr(self, '_cpu_tensor_pool'): + self._cpu_tensor_pool.reset() def flush(self): """Flush all staged chunks to the backward queue in reverse order.""" 
@@ -171,7 +425,10 @@ def init_model_chunk_offload_handler( # Determine if this is the first microbatch of the last virtual pipeline stage is_first_last_vpp_chunk = is_first_last_vpp_chunk and (cur_vpp_rank == self._vpp - 1) - cur_chunk = ChunkOffloadHandler(is_first_last_vpp_chunk, min_offloaded_tensor_size) + # Use shared CPU tensor pool for better reuse across chunks + cur_chunk = ChunkOffloadHandler( + is_first_last_vpp_chunk, min_offloaded_tensor_size, self._cpu_tensor_pool + ) self._stages[cur_vpp_rank].append(cur_chunk) # For the last stage, push immediately and flush if cur_vpp_rank == self._vpp - 1: @@ -240,42 +497,50 @@ class ChunkOffloadHandler: Manages tensor groups, coordinates asynchronous GPU-CPU transfers, and handles synchronization. """ - @staticmethod - def offload(src_tensor, pin_memory=True): + def offload(self, src_tensor, pin_memory=True): """Offload.""" debug_rank("--------offload") from megatron.core.extensions.transformer_engine import Float8Tensor - fp8_offload = isinstance(src_tensor, Float8Tensor) if Float8Tensor is not None else False + # fp8_offload = isinstance(src_tensor, Float8Tensor) if Float8Tensor is not None else False if not src_tensor.is_contiguous(): src_tensor = src_tensor.contiguous() - cpu_backup = torch.empty( - src_tensor.size(), - dtype=torch.uint8 if fp8_offload else src_tensor.dtype, - layout=src_tensor.layout, - device="cpu", - pin_memory=pin_memory, - ) + # cpu_backup = torch.empty( + # src_tensor.size(), + # dtype=torch.uint8 if fp8_offload else src_tensor.dtype, + # layout=src_tensor.layout, + # device="cpu", + # pin_memory=pin_memory, + # ) + + cpu_backup = self.cpu_tensor_pool.allocate(src_tensor.shape, dtype=src_tensor.dtype) - if fp8_offload: - cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) + # if fp8_offload: + # cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) cpu_backup.copy_(src_tensor, non_blocking=pin_memory) state = (src_tensor.device, cpu_backup) return state - 
@staticmethod - def reload(state, non_blocking=None): + def reload(self, state, non_blocking=None): """Reload.""" debug_rank("------reload") dev, cpu_backup = state if non_blocking is None: non_blocking = cpu_backup.is_pinned() - return cpu_backup.to(dev, non_blocking=non_blocking) + gpu_tensor = torch.empty( + cpu_backup.size(), + dtype=cpu_backup.dtype, + layout=cpu_backup.layout, + device=torch.cuda.current_device(), + ) + gpu_tensor.copy_(cpu_backup, non_blocking=non_blocking) + self.cpu_tensor_pool.free(cpu_backup) + return gpu_tensor - def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size): + def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size, cpu_tensor_pool): # Data Structure to maintain reference to activation tensors self._tensor_tag_to_state = {} # Mark the first microbatch of the last virtual pipeline stage @@ -295,6 +560,11 @@ def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size): self._reload_events = {} self.min_offloaded_tensor_size = min_offloaded_tensor_size self.is_last_layer = False + self.cpu_tensor_pool = cpu_tensor_pool + + self.delay_offload_and_reload = True + self.delay_offload_group = [] + self.delay_reload_group = 0 def is_empty_chunk(self): """Check if this chunk has no tensors to manage.""" @@ -456,7 +726,17 @@ def on_group_commit_forward(self, forced_released_tensors): debug_rank("--on_group_commit_forward") # Wait for compute to finish before starting offload self.d2h_stream.wait_stream(torch.cuda.current_stream()) + if self.delay_offload_and_reload: + self.delay_offload_group.append(forced_released_tensors) self.bulk_offload(forced_released_tensors) + torch.cuda.current_stream().wait_stream(self.d2h_stream) + + def flush_delay_offload_groups(self): + """Flush the delay offload groups.""" + debug_rank("--flush_delay_offload_groups") + # for group in self.delay_offload_group: + # self.bulk_offload(group) + # self.delay_offload_group = [] def bulk_reload(self): """Reload the next 
group of tensors from CPU to GPU.""" @@ -508,7 +788,18 @@ def on_group_start_backward(self): debug_rank("--on_group_start_backward") # Wait for compute to finish before starting reload self.h2d_stream.wait_stream(torch.cuda.current_stream()) - self.bulk_reload() + if self.delay_offload_and_reload: + self.delay_reload_group += 1 + else: + self.bulk_reload() + torch.cuda.current_stream().wait_stream(self.h2d_stream) + + def flush_delay_reload_groups(self): + """Flush the delay reload groups.""" + debug_rank("--flush_delay_reload_groups") + for i in range(self.delay_reload_group): + self.bulk_reload() + self.delay_reload_group = 0 class FineGrainedOffloadingGroupCommitFunction(torch.autograd.Function): @@ -586,6 +877,53 @@ def fine_grained_offloading_group_start(tensor, name=None): cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) +class FineGrainedOffloadingGroupDelayOffloadFunction(torch.autograd.Function): + """ + Identity operation that marks the end of a layer group for offload synchronization. + Triggers offload during forward and synchronizes reload during backward. 
+ """ + + @staticmethod + def forward(ctx, tensor, cur_forward_chunk): + if DEBUG and torch.distributed.get_rank() == 0: + print("flush_delay_offload_groups") + print(cur_forward_chunk.delay_offload_group) + breakpoint() + if DEBUG: + torch.cuda.synchronize() + torch.distributed.barrier() + cur_forward_chunk.flush_delay_offload_groups() + return tensor + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + +def fine_grained_offloading_flush_delay_offload_groups(tensor): + """Flush the delay offload groups.""" + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return FineGrainedOffloadingGroupDelayOffloadFunction.apply(tensor, cur_forward_chunk) +class FineGrainedOffloadingGroupDelayReloadFunction(torch.autograd.Function): + """ + Identity operation that marks the end of a layer group for offload synchronization. + Triggers offload during forward and synchronizes reload during backward. + """ + + @staticmethod + def forward(ctx, tensor, cur_forward_chunk): + ctx.cur_forward_chunk = cur_forward_chunk + return tensor + + @staticmethod + def backward(ctx, grad_output): + cur_forward_chunk = ctx.cur_forward_chunk + cur_forward_chunk.flush_delay_reload_groups() + return grad_output, None + +def fine_grained_offloading_flush_delay_reload_groups(tensor): + """Flush the delay reload groups.""" + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return FineGrainedOffloadingGroupDelayReloadFunction.apply(tensor, cur_forward_chunk) def get_fine_grained_offloading_context(flag): """Get the fine-grained offload context""" @@ -594,7 +932,8 @@ def get_fine_grained_offloading_context(flag): def fine_grained_offloading_set_last_layer(is_last_layer): """Set the last layer flag.""" - PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) + pass + # PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) def fine_grained_offloading_init_chunk_handler(vp_size, vp_stage, 
min_offloaded_tensor_size): @@ -603,7 +942,6 @@ def fine_grained_offloading_init_chunk_handler(vp_size, vp_stage, min_offloaded_ vp_size, vp_stage, min_offloaded_tensor_size ) - def fine_grained_offloading_reset(): """Reset the chunk handler, called at the start of a training iteration.""" PipelineOffloadManager.get_instance().reset() diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index cacfb9d01b8..28ac576c913 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -558,7 +558,7 @@ def _forward_attention( if self.offload_attn_norm: (hidden_states,) = fine_grained_offloading_group_commit( - hidden_states, name="attn_norm", forced_released_tensors=[residual] + hidden_states, name="attn_norm", forced_released_tensors=[] ) # Residual connection. @@ -601,7 +601,10 @@ def _forward_mlp(self, hidden_states, inference_context=None): from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_start, get_fine_grained_offloading_context, + fine_grained_offloading_flush_delay_offload_groups, ) + if self.config.fine_grained_activation_offloading: + hidden_states = fine_grained_offloading_flush_delay_offload_groups(hidden_states) # Residual connection. residual = hidden_states @@ -803,6 +806,10 @@ def _te_cuda_graph_capture(self, *args, **kwargs): attribute can be set to control the scope of the CUDA graph. 2. If context is None, it cannot be returned as output. """ + from megatron.core.pipeline_parallel import ( + fine_grained_activation_offload, + ) + fine_grained_activation_offload.DEBUG = True context = None if not self.config.cuda_graph_scope or 'attn' in self.config.cuda_graph_scope: hidden_states, context = self._forward_attention(*args, **kwargs) @@ -839,6 +846,15 @@ def _te_cuda_graph_replay(self, *args, **kwargs): However, CUDA graph accepts only Tensor inputs. 
Hence, `inference_context` and `packed_seq_params` are excluded from input list. """ + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_flush_delay_offload_groups, + fine_grained_offloading_flush_delay_reload_groups, + ) + # if torch.distributed.get_rank() == 0: + # print("te_cuda_graph_replay") + # breakpoint() + # torch.cuda.synchronize() + # torch.distributed.barrier() context = None if self.config.cuda_graph_scope and 'attn' not in self.config.cuda_graph_scope: hidden_states, context = self._forward_attention(*args, **kwargs) @@ -908,6 +924,8 @@ def _te_cuda_graph_replay(self, *args, **kwargs): residual=residual, shared_expert_output=shared_expert_output, ) + if self.config.fine_grained_activation_offloading: + hidden_states = fine_grained_offloading_flush_delay_offload_groups(hidden_states) mlp_output_with_bias = self.mlp(hidden_states) self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") From 1646f040466d5a4c4db1ea6be948062c3cb729c1 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 24 Nov 2025 20:53:33 -0800 Subject: [PATCH 10/74] support offloading+cuda graph Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 124 +++++++----------- megatron/core/pipeline_parallel/schedules.py | 12 +- megatron/core/transformer/module.py | 3 + megatron/core/transformer/moe/moe_layer.py | 2 + .../core/transformer/transformer_layer.py | 41 +++--- 5 files changed, 80 insertions(+), 102 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index bd18275e309..a40b23d1b82 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -562,9 +562,6 @@ def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size, cpu_tenso self.is_last_layer = False self.cpu_tensor_pool = 
cpu_tensor_pool - self.delay_offload_and_reload = True - self.delay_offload_group = [] - self.delay_reload_group = 0 def is_empty_chunk(self): """Check if this chunk has no tensors to manage.""" @@ -671,7 +668,7 @@ def bulk_reload_group(self, group_to_reload): # Only reload if tensor was offloaded (stored as tuple) if isinstance(state, tuple): # Wait for offload to complete before reloading - torch.cuda.current_stream().wait_event(event) + # torch.cuda.current_stream().wait_event(event) recovered_tensor = self.reload(state) event.record(self.h2d_stream) self._reload_events[name] = event @@ -710,6 +707,9 @@ def bulk_offload(self, forced_released_tensors): debug_rank("----bulk_offload") if self.should_bulk_offload(): group_to_offload = self._groups_to_offload.pop() + if group_to_offload[0] == 8: + print("rank", torch.distributed.get_rank(), "group_to_offload", group_to_offload) + return self._groups_to_reload.append(group_to_offload) self.bulk_offload_group(group_to_offload) # Manually release tensors not auto-freed by torch GC @@ -726,17 +726,8 @@ def on_group_commit_forward(self, forced_released_tensors): debug_rank("--on_group_commit_forward") # Wait for compute to finish before starting offload self.d2h_stream.wait_stream(torch.cuda.current_stream()) - if self.delay_offload_and_reload: - self.delay_offload_group.append(forced_released_tensors) self.bulk_offload(forced_released_tensors) - torch.cuda.current_stream().wait_stream(self.d2h_stream) - - def flush_delay_offload_groups(self): - """Flush the delay offload groups.""" - debug_rank("--flush_delay_offload_groups") - # for group in self.delay_offload_group: - # self.bulk_offload(group) - # self.delay_offload_group = [] + # torch.cuda.current_stream().wait_stream(self.d2h_stream) def bulk_reload(self): """Reload the next group of tensors from CPU to GPU.""" @@ -766,8 +757,8 @@ def on_group_commit_backward(self, name): assert cur_backward_chunk is self, "Chunk mismatch" # Wait for reload to complete before using 
tensors event = self.get_reload_event(name) - if event is not None: - torch.cuda.current_stream().wait_event(event) + # if event is not None: + # torch.cuda.current_stream().wait_event(event) self._offloaded_group_index = self._offloaded_group_index - 1 def on_group_start_forward(self, name): @@ -788,19 +779,9 @@ def on_group_start_backward(self): debug_rank("--on_group_start_backward") # Wait for compute to finish before starting reload self.h2d_stream.wait_stream(torch.cuda.current_stream()) - if self.delay_offload_and_reload: - self.delay_reload_group += 1 - else: - self.bulk_reload() - torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.bulk_reload() + # torch.cuda.current_stream().wait_stream(self.h2d_stream) - def flush_delay_reload_groups(self): - """Flush the delay reload groups.""" - debug_rank("--flush_delay_reload_groups") - for i in range(self.delay_reload_group): - self.bulk_reload() - self.delay_reload_group = 0 - class FineGrainedOffloadingGroupCommitFunction(torch.autograd.Function): """ @@ -877,53 +858,6 @@ def fine_grained_offloading_group_start(tensor, name=None): cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) -class FineGrainedOffloadingGroupDelayOffloadFunction(torch.autograd.Function): - """ - Identity operation that marks the end of a layer group for offload synchronization. - Triggers offload during forward and synchronizes reload during backward. 
- """ - - @staticmethod - def forward(ctx, tensor, cur_forward_chunk): - if DEBUG and torch.distributed.get_rank() == 0: - print("flush_delay_offload_groups") - print(cur_forward_chunk.delay_offload_group) - breakpoint() - if DEBUG: - torch.cuda.synchronize() - torch.distributed.barrier() - cur_forward_chunk.flush_delay_offload_groups() - return tensor - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - -def fine_grained_offloading_flush_delay_offload_groups(tensor): - """Flush the delay offload groups.""" - cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return FineGrainedOffloadingGroupDelayOffloadFunction.apply(tensor, cur_forward_chunk) -class FineGrainedOffloadingGroupDelayReloadFunction(torch.autograd.Function): - """ - Identity operation that marks the end of a layer group for offload synchronization. - Triggers offload during forward and synchronizes reload during backward. - """ - - @staticmethod - def forward(ctx, tensor, cur_forward_chunk): - ctx.cur_forward_chunk = cur_forward_chunk - return tensor - - @staticmethod - def backward(ctx, grad_output): - cur_forward_chunk = ctx.cur_forward_chunk - cur_forward_chunk.flush_delay_reload_groups() - return grad_output, None - -def fine_grained_offloading_flush_delay_reload_groups(tensor): - """Flush the delay reload groups.""" - cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return FineGrainedOffloadingGroupDelayReloadFunction.apply(tensor, cur_forward_chunk) def get_fine_grained_offloading_context(flag): """Get the fine-grained offload context""" @@ -932,8 +866,9 @@ def get_fine_grained_offloading_context(flag): def fine_grained_offloading_set_last_layer(is_last_layer): """Set the last layer flag.""" - pass - # PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) + # pass + # print("set_last_layer", is_last_layer) + PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) def 
fine_grained_offloading_init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size): @@ -945,3 +880,38 @@ def fine_grained_offloading_init_chunk_handler(vp_size, vp_stage, min_offloaded_ def fine_grained_offloading_reset(): """Reset the chunk handler, called at the start of a training iteration.""" PipelineOffloadManager.get_instance().reset() + +class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): + """ + Identity operation that marks the end of a layer group for offload synchronization. + Triggers offload during forward and synchronizes reload during backward. + """ + + @staticmethod + def forward(ctx, tensor, event): + ctx.event = event + return tensor + + @staticmethod + def backward(ctx, grad_output): + h2d_stream = PipelineOffloadManager.get_instance().h2d_stream + torch.cuda.current_stream().record_event(ctx.event) + torch.cuda.current_stream().wait_stream(h2d_stream) + return grad_output, None + +def fine_grained_offloading_backward_record(tensor, event): + return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) + +class FineGrainedOffloadingBackwardSyncFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, stream): + ctx.stream = stream + return tensor + + @staticmethod + def backward(ctx, grad_output): + torch.cuda.current_stream().wait_stream(ctx.stream) + return grad_output, None + +def fine_grained_offloading_backward_sync(tensor, stream): + return FineGrainedOffloadingBackwardSyncFunction.apply(tensor, stream) \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 11e54e0fa53..6ee29b38bd1 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -425,7 +425,7 @@ def forward_step( return [output_tensor], num_tokens -def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config): +def backward_step(model, input_tensor, 
output_tensor, output_tensor_grad, model_type, config): """Backward step through passed-in output tensor. If last stage, output_tensor_grad is None, otherwise gradient of loss @@ -614,7 +614,7 @@ def forward_backward_no_pipelining( total_num_tokens += num_tokens if not forward_only: backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config + model, input_tensor, output_tensor, output_tensor_grad, model_type, config ) # Run computation for last microbatch out of context handler (want to # synchronize gradients). @@ -637,7 +637,7 @@ def forward_backward_no_pipelining( total_num_tokens += num_tokens if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + backward_step(model, input_tensor, output_tensor, output_tensor_grad, model_type, config) if config.finalize_model_grads_func is not None and not forward_only: # Finalize model grads (perform full grad all-reduce / reduce-scatter for @@ -1289,7 +1289,7 @@ def backward_step_helper(virtual_microbatch_id): ) input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config + None, input_tensor, output_tensor, output_tensor_grad, model_type, config ) backward_step_helper_postprocess(virtual_microbatch_id) @@ -2236,7 +2236,7 @@ def enable_grad_sync(): enable_grad_sync() input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config + None, input_tensor, output_tensor, output_tensor_grad, model_type, config ) if last_iteration: @@ -2272,7 +2272,7 @@ def enable_grad_sync(): ) input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config + None, input_tensor, output_tensor, output_tensor_grad, model_type, config ) p2p_communicator.send_backward( diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 4fdcacb791b..f6c8cad9c96 100644 --- a/megatron/core/transformer/module.py +++ 
b/megatron/core/transformer/module.py @@ -267,6 +267,9 @@ def _get_te_cuda_graph_replay_args(self, *args, **kwargs): cudagraph_kwargs = kwargs.copy() cudagraph_kwargs['is_first_microbatch'] = getattr(self, 'current_microbatch', 0) == 0 + from megatron.core.transformer.transformer_layer import TransformerLayer + cudagraph_kwargs['cuda_graph_stream'] = TransformerLayer.cuda_graph_stream + cudagraph_kwargs['cuda_graph_event'] = TransformerLayer.cuda_graph_event return cudagraph_args, cudagraph_kwargs def _should_call_local_cudagraph(self, *args, **kwargs): diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 893b2e7b99a..94bb2bf20b2 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -293,6 +293,8 @@ def forward(self, hidden_states: torch.Tensor): # MoE forward: route -> dispatch -> compute -> combine def custom_forward(hidden_states): + from megatron.core.pipeline_parallel.fine_grained_activation_offload import PipelineOffloadManager + d2h_stream = PipelineOffloadManager.get_instance().d2h_stream try: shared_expert_output = self.shared_experts_compute(hidden_states) probs, routing_map = self.route(hidden_states) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 28ac576c913..5cea830dc1c 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -258,7 +258,8 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): Transformer layer takes input with size [s, b, h] and returns an output of the same size. 
""" - + cuda_graph_stream = None + cuda_graph_event = None def __init__( self, config: TransformerConfig, @@ -427,6 +428,11 @@ def __init__( # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad self.bias_dropout_add_exec_handler = torch.enable_grad + if TransformerLayer.cuda_graph_stream is None: + TransformerLayer.cuda_graph_stream = torch.cuda.Stream() + if TransformerLayer.cuda_graph_event is None: + TransformerLayer.cuda_graph_event = torch.cuda.Event(external=True) + @staticmethod def _get_layer_offset(config: TransformerConfig): """ @@ -504,8 +510,11 @@ def _forward_attention( fine_grained_offloading_group_commit, fine_grained_offloading_group_start, get_fine_grained_offloading_context, + fine_grained_offloading_backward_record, ) + hidden_states = fine_grained_offloading_backward_record(hidden_states, TransformerLayer.cuda_graph_event) + inference_context = deprecate_inference_params(inference_context, inference_params) # Residual connection. @@ -601,10 +610,9 @@ def _forward_mlp(self, hidden_states, inference_context=None): from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_start, get_fine_grained_offloading_context, - fine_grained_offloading_flush_delay_offload_groups, + PipelineOffloadManager, ) - if self.config.fine_grained_activation_offloading: - hidden_states = fine_grained_offloading_flush_delay_offload_groups(hidden_states) + d2h_stream = PipelineOffloadManager.get_instance().d2h_stream # Residual connection. residual = hidden_states @@ -642,6 +650,8 @@ def _forward_mlp(self, hidden_states, inference_context=None): not self.recompute_pre_mlp_layernorm ), "Recomputation is not supported for CUDA graph." 
cudagraph_outputs = self.mlp(pre_mlp_layernorm_output) + TransformerLayer.cuda_graph_event.record(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(d2h_stream) return cudagraph_outputs + [residual] elif self.recompute_mlp: if self.config.fp8: @@ -806,10 +816,6 @@ def _te_cuda_graph_capture(self, *args, **kwargs): attribute can be set to control the scope of the CUDA graph. 2. If context is None, it cannot be returned as output. """ - from megatron.core.pipeline_parallel import ( - fine_grained_activation_offload, - ) - fine_grained_activation_offload.DEBUG = True context = None if not self.config.cuda_graph_scope or 'attn' in self.config.cuda_graph_scope: hidden_states, context = self._forward_attention(*args, **kwargs) @@ -846,15 +852,6 @@ def _te_cuda_graph_replay(self, *args, **kwargs): However, CUDA graph accepts only Tensor inputs. Hence, `inference_context` and `packed_seq_params` are excluded from input list. """ - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_flush_delay_offload_groups, - fine_grained_offloading_flush_delay_reload_groups, - ) - # if torch.distributed.get_rank() == 0: - # print("te_cuda_graph_replay") - # breakpoint() - # torch.cuda.synchronize() - # torch.distributed.barrier() context = None if self.config.cuda_graph_scope and 'attn' not in self.config.cuda_graph_scope: hidden_states, context = self._forward_attention(*args, **kwargs) @@ -871,6 +868,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) + torch.cuda.current_stream().wait_event(TransformerLayer.cuda_graph_event) if kwargs.get('context') is not None: context = cuda_graph_output.pop() @@ -924,8 +922,13 @@ def _te_cuda_graph_replay(self, *args, **kwargs): residual=residual, shared_expert_output=shared_expert_output, ) - if self.config.fine_grained_activation_offloading: - hidden_states = 
fine_grained_offloading_flush_delay_offload_groups(hidden_states) + # if torch.distributed.get_rank() == 0 and not is_graph_capturing(): + # print(f"hidden_states before mlp: {hidden_states}") + # print(f"shared_expert_output: {shared_expert_output}") + # print(f"probs: {probs}") + # print(f"routing_map: {routing_map}") + # print(f"residual: {residual}") + # breakpoint() mlp_output_with_bias = self.mlp(hidden_states) self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") From a177cf5db103a9d543de67251b5cf4e362fe4eab Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 27 Nov 2025 00:18:22 -0800 Subject: [PATCH 11/74] support PP=1 Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 81 ++++++++----------- megatron/core/transformer/cuda_graphs.py | 9 ++- megatron/core/transformer/moe/moe_layer.py | 2 - .../core/transformer/transformer_block.py | 6 +- .../core/transformer/transformer_config.py | 19 +++++ .../core/transformer/transformer_layer.py | 39 +++++---- 6 files changed, 86 insertions(+), 70 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index a40b23d1b82..0e1b8a87242 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -11,6 +11,8 @@ DEBUG = False DEBUG_RANK = 0 +from megatron.core.transformer.cuda_graphs import is_graph_capturing + def debug_rank(message): """Print debug message for a specific rank when DEBUG is enabled.""" @@ -408,8 +410,8 @@ def init_model_chunk_offload_handler( vp_stage: Virtual pipeline stage index (None means stage 0) min_offloaded_tensor_size: Minimum tensor size (in elements) to offload """ + vp_size = 1 if vp_size is None else vp_size if self._stages is None: - vp_size = 1 if vp_size is None else vp_size self._vpp = vp_size self._stages = [[] for _ in range(vp_size)] @@ -432,7 +434,8 @@ 
def init_model_chunk_offload_handler( self._stages[cur_vpp_rank].append(cur_chunk) # For the last stage, push immediately and flush if cur_vpp_rank == self._vpp - 1: - self._is_first_last_vpp_chunk = False + if vp_size > 1: + self._is_first_last_vpp_chunk = False self.push(cur_chunk) self.flush() self._cur_forward_chunk = cur_chunk @@ -500,26 +503,12 @@ class ChunkOffloadHandler: def offload(self, src_tensor, pin_memory=True): """Offload.""" debug_rank("--------offload") - from megatron.core.extensions.transformer_engine import Float8Tensor - - # fp8_offload = isinstance(src_tensor, Float8Tensor) if Float8Tensor is not None else False if not src_tensor.is_contiguous(): src_tensor = src_tensor.contiguous() - # cpu_backup = torch.empty( - # src_tensor.size(), - # dtype=torch.uint8 if fp8_offload else src_tensor.dtype, - # layout=src_tensor.layout, - # device="cpu", - # pin_memory=pin_memory, - # ) - cpu_backup = self.cpu_tensor_pool.allocate(src_tensor.shape, dtype=src_tensor.dtype) - # if fp8_offload: - # cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) - cpu_backup.copy_(src_tensor, non_blocking=pin_memory) state = (src_tensor.device, cpu_backup) return state @@ -534,13 +523,14 @@ def reload(self, state, non_blocking=None): cpu_backup.size(), dtype=cpu_backup.dtype, layout=cpu_backup.layout, - device=torch.cuda.current_device(), + device=dev, ) gpu_tensor.copy_(cpu_backup, non_blocking=non_blocking) self.cpu_tensor_pool.free(cpu_backup) return gpu_tensor def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size, cpu_tensor_pool): + self.do_offload = True # Data Structure to maintain reference to activation tensors self._tensor_tag_to_state = {} # Mark the first microbatch of the last virtual pipeline stage @@ -668,7 +658,8 @@ def bulk_reload_group(self, group_to_reload): # Only reload if tensor was offloaded (stored as tuple) if isinstance(state, tuple): # Wait for offload to complete before reloading - # 
torch.cuda.current_stream().wait_event(event) + if not is_graph_capturing(): + torch.cuda.current_stream().wait_event(event) recovered_tensor = self.reload(state) event.record(self.h2d_stream) self._reload_events[name] = event @@ -689,6 +680,8 @@ def pre_reload_last_layer(self): def should_bulk_offload(self): """Determine if the current group should be offloaded.""" + if not self.do_offload: + return False # Don't offload the first backward chunk's last layer if self.is_first_last_layer(): return False @@ -707,9 +700,9 @@ def bulk_offload(self, forced_released_tensors): debug_rank("----bulk_offload") if self.should_bulk_offload(): group_to_offload = self._groups_to_offload.pop() - if group_to_offload[0] == 8: - print("rank", torch.distributed.get_rank(), "group_to_offload", group_to_offload) - return + # if group_to_offload[0] == 8: + # # print("rank", torch.distributed.get_rank(), "group_to_offload", group_to_offload) + # return self._groups_to_reload.append(group_to_offload) self.bulk_offload_group(group_to_offload) # Manually release tensors not auto-freed by torch GC @@ -727,7 +720,6 @@ def on_group_commit_forward(self, forced_released_tensors): # Wait for compute to finish before starting offload self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.bulk_offload(forced_released_tensors) - # torch.cuda.current_stream().wait_stream(self.d2h_stream) def bulk_reload(self): """Reload the next group of tensors from CPU to GPU.""" @@ -757,8 +749,8 @@ def on_group_commit_backward(self, name): assert cur_backward_chunk is self, "Chunk mismatch" # Wait for reload to complete before using tensors event = self.get_reload_event(name) - # if event is not None: - # torch.cuda.current_stream().wait_event(event) + if event is not None and not is_graph_capturing(): + torch.cuda.current_stream().wait_event(event) self._offloaded_group_index = self._offloaded_group_index - 1 def on_group_start_forward(self, name): @@ -780,8 +772,16 @@ def on_group_start_backward(self): # 
Wait for compute to finish before starting reload self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() - # torch.cuda.current_stream().wait_stream(self.h2d_stream) - + +def fine_grained_offloading_disable_offload(): + """Disable the offload.""" + debug_rank("fine_grained_offloading_disable_offload") + PipelineOffloadManager.get_instance().cur_forward_chunk().do_offload = False + +def fine_grained_offloading_enable_offload(): + """Enable the offload.""" + debug_rank("fine_grained_offloading_enable_offload") + PipelineOffloadManager.get_instance().cur_forward_chunk().do_offload = True class FineGrainedOffloadingGroupCommitFunction(torch.autograd.Function): """ @@ -866,8 +866,6 @@ def get_fine_grained_offloading_context(flag): def fine_grained_offloading_set_last_layer(is_last_layer): """Set the last layer flag.""" - # pass - # print("set_last_layer", is_last_layer) PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) @@ -881,6 +879,11 @@ def fine_grained_offloading_reset(): """Reset the chunk handler, called at the start of a training iteration.""" PipelineOffloadManager.get_instance().reset() +def fine_grained_offloading_forward_record(event: torch.cuda.Event) -> None: + d2h_stream = PipelineOffloadManager.get_instance().d2h_stream + torch.cuda.current_stream().record_event(event) + torch.cuda.current_stream().wait_stream(d2h_stream) + class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): """ Identity operation that marks the end of a layer group for offload synchronization. 
@@ -888,10 +891,10 @@ class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, tensor, event): + def forward(ctx, tensor, event: torch.cuda.Event) -> torch.Tensor: ctx.event = event return tensor - + @staticmethod def backward(ctx, grad_output): h2d_stream = PipelineOffloadManager.get_instance().h2d_stream @@ -899,19 +902,5 @@ def backward(ctx, grad_output): torch.cuda.current_stream().wait_stream(h2d_stream) return grad_output, None -def fine_grained_offloading_backward_record(tensor, event): - return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) - -class FineGrainedOffloadingBackwardSyncFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, tensor, stream): - ctx.stream = stream - return tensor - - @staticmethod - def backward(ctx, grad_output): - torch.cuda.current_stream().wait_stream(ctx.stream) - return grad_output, None - -def fine_grained_offloading_backward_sync(tensor, stream): - return FineGrainedOffloadingBackwardSyncFunction.apply(tensor, stream) \ No newline at end of file +def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: + return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) \ No newline at end of file diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 12f15ee980a..4c9005e5b28 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1581,7 +1581,7 @@ def get_rotary_pos_emb(transformer_module, transformer_input): ) def get_make_graphed_callables_kwargs(): - kwargs = {'num_warmup_iters': 11, 'allow_unused_input': True, '_order': order} + kwargs = {'num_warmup_iters': 2, 'allow_unused_input': True, '_order': order} if is_te_min_version("2.6.0"): # Starting from TE 2.6.0, make_graphed_callables() accepts different number @@ -1630,6 +1630,13 @@ def _get_fp8_enabled(): ) else: kwargs['fp8_enabled'] = False 
+ + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_disable_offload, + fine_grained_offloading_enable_offload, + ) + kwargs['pre_warmup_hook'] = fine_grained_offloading_disable_offload + kwargs['post_warmup_hook'] = fine_grained_offloading_enable_offload return kwargs kwargs = get_make_graphed_callables_kwargs() diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index a0780a59471..095e6526934 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -293,8 +293,6 @@ def forward(self, hidden_states: torch.Tensor): # MoE forward: route -> dispatch -> compute -> combine def custom_forward(hidden_states): - from megatron.core.pipeline_parallel.fine_grained_activation_offload import PipelineOffloadManager - d2h_stream = PipelineOffloadManager.get_instance().d2h_stream try: shared_expert_output = self.shared_experts_compute(hidden_states) probs, routing_map = self.route(hidden_states) diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 6f69927e9e8..0a05ca31182 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -327,6 +327,7 @@ def __init__( self._build_layers() self.num_layers_per_pipeline_rank = len(self.layers) + self.layers[-1].is_last_layer = True def _build_layers(self): # Transformer layers. 
@@ -726,11 +727,6 @@ def forward( else: inner_quantization_context = nullcontext() - if self.config.fine_grained_activation_offloading: - fine_grained_offloading_set_last_layer( - l_no == self.num_layers_per_pipeline_rank - 1 - ) - with self.offload_context, inner_quantization_context: hidden_states, context = layer( hidden_states=hidden_states, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 895aef978e2..0b5b3d80c09 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -801,6 +801,9 @@ class TransformerConfig(ModelParallelConfig): """ min_offloaded_tensor_size: int = 1024 * 1024 """The minimum size of the tensor to be offloaded.""" + offload_module_in_cuda_graph: bool = False + """The flag is derived from the fine_grained_activation_offloading flag. + If True, mark the module in the cuda graph will be offloaded.""" def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. @@ -1188,6 +1191,22 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) + if self.enable_cuda_graph or self.cuda_graph_impl == "local": + raise ValueError("Fine-grained activation offloading does not support local implementation of CUDA graph.") + if self.external_cuda_graph or self.cuda_graph_impl == "transformer_engine": + assert "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope, "attn and moe_router must be in cuda_graph_scope when enabling offloading." + assert "attn_norm" not in self.offload_modules, "input of attn_norm is exactly the entry point of cuda graph, which cannot be offloaded." 
+ self.offload_module_in_cuda_graph = \ + "attn_proj" in self.offload_modules \ + or "core_attn" in self.offload_modules \ + or "mlp_norm" in self.offload_modules \ + or "qkv_linear" in self.offload_modules + if self.offload_module_in_cuda_graph: + assert is_torch_min_version("2.9.0a0"), \ + "Fine-grained activation offloading needs torch>=2.9.0 to support cuda graph. " \ + f"Current torch version is {torch.__version__}." + assert self.cuda_graph_warmup_steps > 0, \ + "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." if ( self.num_layers_in_first_pipeline_stage is not None diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index d92d3b0756a..8dde5271f4a 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -276,6 +276,8 @@ def __init__( self.pg_collection = pg_collection self.tp_group = pg_collection.tp + self.is_last_layer = False + self.submodules_config = submodules self.layer_number = layer_number + get_transformer_layer_offset( self.config, vp_stage, get_pg_rank(pg_collection.pp) @@ -430,9 +432,15 @@ def __init__( self.bias_dropout_add_exec_handler = torch.enable_grad if TransformerLayer.cuda_graph_stream is None: - TransformerLayer.cuda_graph_stream = torch.cuda.Stream() + if self.config.offload_module_in_cuda_graph: + TransformerLayer.cuda_graph_stream = torch.cuda.Stream() + else: + TransformerLayer.cuda_graph_stream = torch.cuda.current_stream() if TransformerLayer.cuda_graph_event is None: - TransformerLayer.cuda_graph_event = torch.cuda.Event(external=True) + if self.config.offload_module_in_cuda_graph: + TransformerLayer.cuda_graph_event = torch.cuda.Event(external=True) + else: + TransformerLayer.cuda_graph_event = torch.cuda.Event() @staticmethod def _get_layer_offset(config: TransformerConfig): @@ -511,10 +519,14 @@ def _forward_attention( fine_grained_offloading_group_commit, 
fine_grained_offloading_group_start, get_fine_grained_offloading_context, - fine_grained_offloading_backward_record, ) + if self.config.fine_grained_activation_offloading: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_set_last_layer + fine_grained_offloading_set_last_layer(self.is_last_layer) - hidden_states = fine_grained_offloading_backward_record(hidden_states, TransformerLayer.cuda_graph_event) + if self.config.offload_module_in_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_backward_record + hidden_states = fine_grained_offloading_backward_record(hidden_states, TransformerLayer.cuda_graph_event) inference_context = deprecate_inference_params(inference_context, inference_params) @@ -568,7 +580,7 @@ def _forward_attention( if self.offload_attn_norm: (hidden_states,) = fine_grained_offloading_group_commit( - hidden_states, name="attn_norm", forced_released_tensors=[] + hidden_states, name="attn_norm", forced_released_tensors=[residual] ) # Residual connection. @@ -612,8 +624,8 @@ def _forward_mlp(self, hidden_states, inference_context=None): fine_grained_offloading_group_start, get_fine_grained_offloading_context, PipelineOffloadManager, + fine_grained_offloading_forward_record, ) - d2h_stream = PipelineOffloadManager.get_instance().d2h_stream # Residual connection. residual = hidden_states @@ -651,8 +663,6 @@ def _forward_mlp(self, hidden_states, inference_context=None): not self.recompute_pre_mlp_layernorm ), "Recomputation is not supported for CUDA graph." 
cudagraph_outputs = self.mlp(pre_mlp_layernorm_output) - TransformerLayer.cuda_graph_event.record(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(d2h_stream) return cudagraph_outputs + [residual] elif self.recompute_mlp: if self.config.fp8: @@ -844,6 +854,9 @@ def _te_cuda_graph_capture(self, *args, **kwargs): cuda_graph_outputs = list(hidden_states) if context is not None: cuda_graph_outputs.append(context) + if self.config.offload_module_in_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_forward_record + fine_grained_offloading_forward_record(TransformerLayer.cuda_graph_event) return tuple(cuda_graph_outputs) def _te_cuda_graph_replay(self, *args, **kwargs): @@ -869,7 +882,8 @@ def _te_cuda_graph_replay(self, *args, **kwargs): cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) - torch.cuda.current_stream().wait_event(TransformerLayer.cuda_graph_event) + # if self.config.offload_module_in_cuda_graph: + # torch.cuda.current_stream().wait_event(TransformerLayer.cuda_graph_event) if kwargs.get('context') is not None: context = cuda_graph_output.pop() @@ -923,13 +937,6 @@ def _te_cuda_graph_replay(self, *args, **kwargs): residual=residual, shared_expert_output=shared_expert_output, ) - # if torch.distributed.get_rank() == 0 and not is_graph_capturing(): - # print(f"hidden_states before mlp: {hidden_states}") - # print(f"shared_expert_output: {shared_expert_output}") - # print(f"probs: {probs}") - # print(f"routing_map: {routing_map}") - # print(f"residual: {residual}") - # breakpoint() mlp_output_with_bias = self.mlp(hidden_states) self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") From f7cfbba15edf0483ce3cf3dbcfbdd7b158e09d17 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 30 Nov 2025 21:18:11 -0800 Subject: [PATCH 12/74] support VPP Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 71 ++++++++--- 
megatron/core/transformer/cuda_graphs.py | 13 +- .../transformer/multi_latent_attention.py | 21 ++-- .../transformer/multi_token_prediction.py | 3 +- .../core/transformer/transformer_config.py | 17 +-- .../core/transformer/transformer_layer.py | 115 ++++++++++++------ 6 files changed, 160 insertions(+), 80 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 0e1b8a87242..c58d7582f65 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -431,6 +431,7 @@ def init_model_chunk_offload_handler( cur_chunk = ChunkOffloadHandler( is_first_last_vpp_chunk, min_offloaded_tensor_size, self._cpu_tensor_pool ) + debug_rank(f"init_model_chunk_offload_handler {cur_chunk}") self._stages[cur_vpp_rank].append(cur_chunk) # For the last stage, push immediately and flush if cur_vpp_rank == self._vpp - 1: @@ -443,6 +444,7 @@ def init_model_chunk_offload_handler( def set_last_layer(self, is_last_layer): """Mark whether the current forward chunk is processing the last layer.""" + debug_rank(f"set_last_layer {is_last_layer}") self._cur_forward_chunk.is_last_layer = is_last_layer def cur_forward_chunk(self): @@ -456,6 +458,8 @@ def cur_backward_chunk(self): def __enter__(self): """Enter context manager to enable activation offloading hooks.""" debug_rank("----__enter__") + if not self.cur_forward_chunk().do_offload: + return from megatron.core.extensions.transformer_engine import cpu_offload if cpu_offload is not None: @@ -469,6 +473,8 @@ def __enter__(self): def __exit__(self, *args: Any): """Exit context manager and restore original tensor saving behavior.""" debug_rank("----__exit__") + if not self.cur_forward_chunk().do_offload: + return from megatron.core.extensions.transformer_engine import cpu_offload if cpu_offload is not None: @@ -500,23 +506,26 @@ class 
ChunkOffloadHandler: Manages tensor groups, coordinates asynchronous GPU-CPU transfers, and handles synchronization. """ - def offload(self, src_tensor, pin_memory=True): + def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): """Offload.""" debug_rank("--------offload") if not src_tensor.is_contiguous(): src_tensor = src_tensor.contiguous() - cpu_backup = self.cpu_tensor_pool.allocate(src_tensor.shape, dtype=src_tensor.dtype) + if use_cpu_pool: + cpu_backup = self.cpu_tensor_pool.allocate(src_tensor.shape, dtype=src_tensor.dtype) + else: + cpu_backup = torch.empty(src_tensor.shape, dtype=src_tensor.dtype, device="cpu", pin_memory=pin_memory) cpu_backup.copy_(src_tensor, non_blocking=pin_memory) - state = (src_tensor.device, cpu_backup) + state = (src_tensor.device, cpu_backup, use_cpu_pool) return state def reload(self, state, non_blocking=None): """Reload.""" debug_rank("------reload") - dev, cpu_backup = state + dev, cpu_backup, use_cpu_pool = state if non_blocking is None: non_blocking = cpu_backup.is_pinned() gpu_tensor = torch.empty( @@ -526,7 +535,8 @@ def reload(self, state, non_blocking=None): device=dev, ) gpu_tensor.copy_(cpu_backup, non_blocking=non_blocking) - self.cpu_tensor_pool.free(cpu_backup) + if use_cpu_pool: + self.cpu_tensor_pool.free(cpu_backup) return gpu_tensor def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size, cpu_tensor_pool): @@ -615,8 +625,13 @@ def bulk_offload_group(self, group_to_offload): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") assert not self.is_first_last_layer(), "Should not offload first-last layer" - group_id_to_offload, name = group_to_offload + group_id_to_offload, name, fake_offload = group_to_offload + if fake_offload: + return False torch.cuda.nvtx.range_push("activation offloading " + name) + use_cpu_pool = True + if name == "expert_fc1" or name == "moe_act": + use_cpu_pool = False with torch.cuda.stream(self.d2h_stream): for 
tensor_tag, state in self._tensor_tag_to_state.items(): group_id, _ = tensor_tag @@ -626,13 +641,14 @@ def bulk_offload_group(self, group_to_offload): assert not isinstance(state, tuple), "Tensor already offloaded" tensor_on_device = state if self.tensor_need_offloading_checker(tensor_on_device): - state = self.offload(tensor_on_device) + state = self.offload(tensor_on_device, use_cpu_pool=use_cpu_pool) event = torch.cuda.Event() event.record(self.d2h_stream) self._offload_events[name] = event tensor_on_device.record_stream(self.d2h_stream) self._tensor_tag_to_state[tensor_tag] = state torch.cuda.nvtx.range_pop() + return True def get_offload_event(self, name): """Get the CUDA event for a named offload operation.""" @@ -645,8 +661,10 @@ def get_reload_event(self, name): def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" debug_rank("----bulk_reload_group") + group_id_to_reload, name, fake_reload = group_to_reload + if fake_reload: + return True found_reload_group = False - group_id_to_reload, name = group_to_reload torch.cuda.nvtx.range_push("activation reloading " + name) with torch.cuda.stream(self.h2d_stream): for tensor_label, state in self._tensor_tag_to_state.items(): @@ -680,8 +698,6 @@ def pre_reload_last_layer(self): def should_bulk_offload(self): """Determine if the current group should be offloaded.""" - if not self.do_offload: - return False # Don't offload the first backward chunk's last layer if self.is_first_last_layer(): return False @@ -704,7 +720,9 @@ def bulk_offload(self, forced_released_tensors): # # print("rank", torch.distributed.get_rank(), "group_to_offload", group_to_offload) # return self._groups_to_reload.append(group_to_offload) - self.bulk_offload_group(group_to_offload) + ret = self.bulk_offload_group(group_to_offload) + if not ret: + return # Manually release tensors not auto-freed by torch GC if len(forced_released_tensors) > 0: cur_stream = torch.cuda.current_stream() @@ -716,6 +734,8 @@ def bulk_offload(self, 
forced_released_tensors): def on_group_commit_forward(self, forced_released_tensors): """Called at the end of a layer group's forward pass to trigger offloading.""" + if not self.do_offload: + return debug_rank("--on_group_commit_forward") # Wait for compute to finish before starting offload self.d2h_stream.wait_stream(torch.cuda.current_stream()) @@ -740,6 +760,8 @@ def on_group_commit_backward(self, name): Called at the end of a layer group's backward pass. Ensures correct chunk is active and synchronizes reloads. """ + if not self.do_offload: + return debug_rank("--on_group_commit_backward") cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() # Switch to this chunk if it's not already current @@ -753,21 +775,25 @@ def on_group_commit_backward(self, name): torch.cuda.current_stream().wait_event(event) self._offloaded_group_index = self._offloaded_group_index - 1 - def on_group_start_forward(self, name): + def on_group_start_forward(self, name, fake_offload=False): """ Called at the start of a layer group's forward pass. Increments group index and prepares for offloading. """ + if not self.do_offload: + return debug_rank(f"--on_group_start_forward") self._offloaded_group_index = self._offloaded_group_index + 1 self._tensor_count_current_group = 0 - self._groups_to_offload.append((self._offloaded_group_index, name)) + self._groups_to_offload.append((self._offloaded_group_index, name, fake_offload)) def on_group_start_backward(self): """ Called at the start of a layer group's backward pass. Triggers reloading of tensors from CPU. 
""" + if not self.do_offload: + return debug_rank("--on_group_start_backward") # Wait for compute to finish before starting reload self.h2d_stream.wait_stream(torch.cuda.current_stream()) @@ -835,12 +861,12 @@ class FineGrainedOffloadingGroupStartFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, tensor, cpu_offload_handler, name): + def forward(ctx, tensor, cpu_offload_handler, name, fake_offload): # pylint: disable=missing-function-docstring ctx.cpu_offload_handler = cpu_offload_handler debug_rank("FineGrainedOffloadingGroupStartFunction forward") - cpu_offload_handler.on_group_start_forward(name) + cpu_offload_handler.on_group_start_forward(name, fake_offload) # return the identical tensor return tensor @@ -850,13 +876,13 @@ def backward(ctx, grad_output): debug_rank("FineGrainedOffloadingGroupStartFunction backward") cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler.on_group_start_backward() - return grad_output, None, None + return grad_output, None, None, None -def fine_grained_offloading_group_start(tensor, name=None): +def fine_grained_offloading_group_start(tensor, name=None, fake_offload=False): """Mark the start of a layer group and prepare for offload/reload.""" cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) + return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name, fake_offload) def get_fine_grained_offloading_context(flag): @@ -903,4 +929,11 @@ def backward(ctx, grad_output): return grad_output, None def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: - return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) \ No newline at end of file + return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) + +def fine_grained_offloading_fake_offload(tensor): + """Fake offload.""" + tensor = 
fine_grained_offloading_group_start(tensor, name="fake_offload", fake_offload=True) + (tensor,) = fine_grained_offloading_group_commit(tensor, name="fake_offload", forced_released_tensors=[]) + return tensor + \ No newline at end of file diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 4c9005e5b28..efbc7a768f6 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1634,9 +1634,18 @@ def _get_fp8_enabled(): from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_disable_offload, fine_grained_offloading_enable_offload, + fine_grained_offloading_init_chunk_handler, ) - kwargs['pre_warmup_hook'] = fine_grained_offloading_disable_offload - kwargs['post_warmup_hook'] = fine_grained_offloading_enable_offload + from functools import partial + # if self.config.offload_module_in_cuda_graph: + if self.config.fine_grained_activation_offloading: + kwargs['pre_warmup_hook'] = fine_grained_offloading_disable_offload + kwargs['post_warmup_hook'] = fine_grained_offloading_enable_offload + kwargs['init_chunk_handler'] = partial( + fine_grained_offloading_init_chunk_handler, + vp_size=self.config.virtual_pipeline_model_parallel_size, + min_offloaded_tensor_size=self.config.min_offloaded_tensor_size + ) return kwargs kwargs = get_make_graphed_callables_kwargs() diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 5d3f16c1041..7bd88919052 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -238,13 +238,20 @@ def forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. 
# query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128] - query, key, value = self.get_query_key_value_tensors( - hidden_states, - key_value_states, - position_ids, - packed_seq_params, - inference_context=inference_context, - ) + if self.offload_qkv_linear: + hidden_states = fine_grained_offloading_group_start(hidden_states, name="qkv_linear") + with get_fine_grained_offloading_context(self.offload_qkv_linear): + query, key, value = self.get_query_key_value_tensors( + hidden_states, + key_value_states, + position_ids, + packed_seq_params, + inference_context=inference_context, + ) + if self.offload_qkv_linear: + (query, key, value) = fine_grained_offloading_group_commit( + query, key, value, name="qkv_linear", forced_released_tensors=[hidden_states] + ) # =================================================== # Adjust key, value for inference diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 94fbfb23677..69f4b3d1264 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -949,6 +949,7 @@ def __init__( self._build_layers(pg_collection) assert len(self.layers) > 0, "MultiTokenPredictionBlock must have at least one layer." 
+ self.layers[-1].is_last_layer = True self.cp_group = pg_collection.cp def _build_layers(self, pg_collection): @@ -1006,8 +1007,6 @@ def forward( hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0)) hidden_states = hidden_states_list[offset] for layer_number in range(len(self.layers)): - if self.config.fine_grained_activation_offloading: - fine_grained_offloading_set_last_layer(layer_number == len(self.layers) - 1) (hidden_states, input_ids, position_ids) = self.layers[layer_number]( input_ids=input_ids, position_ids=position_ids, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 0b5b3d80c09..9f7afe7ccd4 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -801,9 +801,6 @@ class TransformerConfig(ModelParallelConfig): """ min_offloaded_tensor_size: int = 1024 * 1024 """The minimum size of the tensor to be offloaded.""" - offload_module_in_cuda_graph: bool = False - """The flag is derived from the fine_grained_activation_offloading flag. - If True, mark the module in the cuda graph will be offloaded.""" def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. @@ -1179,6 +1176,7 @@ def __post_init__(self): "attn_norm", "mlp_norm", "qkv_linear", + "dense_mlp", } invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( @@ -1194,19 +1192,10 @@ def __post_init__(self): if self.enable_cuda_graph or self.cuda_graph_impl == "local": raise ValueError("Fine-grained activation offloading does not support local implementation of CUDA graph.") if self.external_cuda_graph or self.cuda_graph_impl == "transformer_engine": + assert self.cuda_graph_scope is not None, "cuda_graph_scope must be set when enabling offloading." 
assert "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope, "attn and moe_router must be in cuda_graph_scope when enabling offloading." assert "attn_norm" not in self.offload_modules, "input of attn_norm is exactly the entry point of cuda graph, which cannot be offloaded." - self.offload_module_in_cuda_graph = \ - "attn_proj" in self.offload_modules \ - or "core_attn" in self.offload_modules \ - or "mlp_norm" in self.offload_modules \ - or "qkv_linear" in self.offload_modules - if self.offload_module_in_cuda_graph: - assert is_torch_min_version("2.9.0a0"), \ - "Fine-grained activation offloading needs torch>=2.9.0 to support cuda graph. " \ - f"Current torch version is {torch.__version__}." - assert self.cuda_graph_warmup_steps > 0, \ - "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." + assert "mlp_norm" not in self.offload_modules, "offloading mlp_norm goes through the boundary of the cuda graph, which cannot be offloaded." if ( self.num_layers_in_first_pipeline_stage is not None diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 8dde5271f4a..f422b07dff9 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -26,6 +26,7 @@ deprecate_inference_params, get_pg_rank, is_te_min_version, + is_torch_min_version, log_single_rank, make_viewless_tensor, nvtx_range_pop, @@ -412,17 +413,8 @@ def __init__( if "mlp" in self.config.recompute_modules: if not self.is_moe_layer: self.recompute_mlp = True - self.offload_attn_norm = ( - self.config.fine_grained_activation_offloading - and "attn_norm" in self.config.offload_modules - and not isinstance(self.input_layernorm, IdentityOp) - ) - self.offload_mlp_norm = ( - self.config.fine_grained_activation_offloading - and "mlp_norm" in self.config.offload_modules - and not isinstance(self.pre_mlp_layernorm, IdentityOp) - ) + self._set_offload_modules() # @jcasper how 
should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. # TORCH_MAJOR = int(torch.__version__.split('.')[0]) @@ -431,17 +423,6 @@ def __init__( # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad self.bias_dropout_add_exec_handler = torch.enable_grad - if TransformerLayer.cuda_graph_stream is None: - if self.config.offload_module_in_cuda_graph: - TransformerLayer.cuda_graph_stream = torch.cuda.Stream() - else: - TransformerLayer.cuda_graph_stream = torch.cuda.current_stream() - if TransformerLayer.cuda_graph_event is None: - if self.config.offload_module_in_cuda_graph: - TransformerLayer.cuda_graph_event = torch.cuda.Event(external=True) - else: - TransformerLayer.cuda_graph_event = torch.cuda.Event() - @staticmethod def _get_layer_offset(config: TransformerConfig): """ @@ -524,7 +505,7 @@ def _forward_attention( from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_set_last_layer fine_grained_offloading_set_last_layer(self.is_last_layer) - if self.config.offload_module_in_cuda_graph: + if self.offload_module_in_cuda_graph: from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_backward_record hidden_states = fine_grained_offloading_backward_record(hidden_states, TransformerLayer.cuda_graph_event) @@ -533,17 +514,17 @@ def _forward_attention( # Residual connection. 
residual = hidden_states - if self.offload_attn_norm: + if self.offload_modules["attn_norm"]: hidden_states = fine_grained_offloading_group_start(hidden_states, name="attn_norm") # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with get_fine_grained_offloading_context(self.offload_attn_norm): + with get_fine_grained_offloading_context(self.offload_modules["attn_norm"]): input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( self.input_layernorm, hidden_states ) else: - with get_fine_grained_offloading_context(self.offload_attn_norm): + with get_fine_grained_offloading_context(self.offload_modules["attn_norm"]): input_layernorm_output = self.input_layernorm(hidden_states) # Self attention. @@ -578,7 +559,7 @@ def _forward_attention( ) nvtx_range_pop(suffix="self_attn_bda") - if self.offload_attn_norm: + if self.offload_modules["attn_norm"]: (hidden_states,) = fine_grained_offloading_group_commit( hidden_states, name="attn_norm", forced_released_tensors=[residual] ) @@ -623,24 +604,23 @@ def _forward_mlp(self, hidden_states, inference_context=None): from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_start, get_fine_grained_offloading_context, - PipelineOffloadManager, - fine_grained_offloading_forward_record, + fine_grained_offloading_group_commit, ) # Residual connection. residual = hidden_states - if self.offload_mlp_norm: + if self.offload_modules["mlp_norm"]: hidden_states = fine_grained_offloading_group_start(hidden_states, name="mlp_norm") # Optional Layer norm post the cross-attention. 
if self.recompute_pre_mlp_layernorm: self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with get_fine_grained_offloading_context(self.offload_mlp_norm): + with get_fine_grained_offloading_context(self.offload_modules["mlp_norm"]): pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( self.pre_mlp_layernorm, hidden_states ) else: - with get_fine_grained_offloading_context(self.offload_mlp_norm): + with get_fine_grained_offloading_context(self.offload_modules["mlp_norm"]): pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) nvtx_range_push(suffix="mlp") @@ -694,7 +674,17 @@ def _forward_mlp(self, hidden_states, inference_context=None): bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None mlp_output_with_bias = (mlp_output, bias_output) else: - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + if self.offload_modules["dense_mlp"]: + pre_mlp_layernorm_output = fine_grained_offloading_group_start(pre_mlp_layernorm_output, name="dense_mlp") + with get_fine_grained_offloading_context(self.offload_modules["dense_mlp"]): + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + if self.offload_modules["dense_mlp"]: + (mlp_output,) = fine_grained_offloading_group_commit( + mlp_output_with_bias[0], name="dense_mlp", forced_released_tensors=[] + ) + mlp_output_with_bias = (mlp_output, mlp_output_with_bias[1]) + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) if self.recompute_pre_mlp_layernorm: # discard the output of the pre-mlp layernorm and register the recompute @@ -729,7 +719,7 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") - if self.offload_mlp_norm: + if self.offload_modules["mlp_norm"]: (hidden_states,) = fine_grained_offloading_group_commit( hidden_states, name="mlp_norm", forced_released_tensors=[residual] ) @@ -854,7 +844,7 @@ def 
_te_cuda_graph_capture(self, *args, **kwargs): cuda_graph_outputs = list(hidden_states) if context is not None: cuda_graph_outputs.append(context) - if self.config.offload_module_in_cuda_graph: + if self.offload_module_in_cuda_graph: from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_forward_record fine_grained_offloading_forward_record(TransformerLayer.cuda_graph_event) return tuple(cuda_graph_outputs) @@ -866,6 +856,9 @@ def _te_cuda_graph_replay(self, *args, **kwargs): However, CUDA graph accepts only Tensor inputs. Hence, `inference_context` and `packed_seq_params` are excluded from input list. """ + if self.config.fine_grained_activation_offloading: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_set_last_layer + fine_grained_offloading_set_last_layer(self.is_last_layer) context = None if self.config.cuda_graph_scope and 'attn' not in self.config.cuda_graph_scope: hidden_states, context = self._forward_attention(*args, **kwargs) @@ -882,8 +875,6 @@ def _te_cuda_graph_replay(self, *args, **kwargs): cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) - # if self.config.offload_module_in_cuda_graph: - # torch.cuda.current_stream().wait_event(TransformerLayer.cuda_graph_event) if kwargs.get('context') is not None: context = cuda_graph_output.pop() @@ -1041,3 +1032,55 @@ def __call__(self, *args, **kwargs): 'inference_context' ].is_decode_only() return super().__call__(*args, **kwargs) + + def _set_offload_modules(self): + """Set the offload modules for the transformer layer.""" + self.offload_modules = { + "attn_norm": False, + "qkv_linear": False, + "core_attn": False, + "attn_proj": False, + "mlp_norm": False, + "expert_fc1": False, + "moe_act": False, + "dense_mlp": False, + } + if self.config.fine_grained_activation_offloading: + if "attn_norm" in self.config.offload_modules and not isinstance(self.input_layernorm, IdentityOp): + 
self.offload_modules["attn_norm"] = True + if "qkv_linear" in self.config.offload_modules: + self.offload_modules["qkv_linear"] = True + if "core_attn" in self.config.offload_modules: + self.offload_modules["core_attn"] = True + if "attn_proj" in self.config.offload_modules: + self.offload_modules["attn_proj"] = True + if "mlp_norm" in self.config.offload_modules and not isinstance(self.pre_mlp_layernorm, IdentityOp): + self.offload_modules["mlp_norm"] = True + if "expert_fc1" in self.config.offload_modules: + self.offload_modules["expert_fc1"] = True + if "moe_act" in self.config.offload_modules: + self.offload_modules["moe_act"] = True + if "dense_mlp" in self.config.offload_modules and not self.is_moe_layer: + self.offload_modules["dense_mlp"] = True + # Set the offload module in cuda graph flag. + self.offload_module_in_cuda_graph = False + if "attn" in self.config.cuda_graph_scope: + if self.offload_modules["core_attn"] or self.offload_modules["attn_proj"] or self.offload_modules["qkv_linear"]: + self.offload_module_in_cuda_graph = True + if (not self.is_moe_layer and 'mlp' in self.config.cuda_graph_scope): + if self.offload_modules["mlp_norm"] or self.offload_modules["dense_mlp"]: + self.offload_module_in_cuda_graph = True + if self.offload_module_in_cuda_graph: + assert is_torch_min_version("2.9.0a0"), "Fine-grained activation offloading needs torch>=2.9.0 to support cuda graph." + assert self.config.cuda_graph_warmup_steps > 0, "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." + # Set the cuda graph stream and event for the transformer layer. 
+ if TransformerLayer.cuda_graph_stream is None: + if self.offload_module_in_cuda_graph: + TransformerLayer.cuda_graph_stream = torch.cuda.Stream() + else: + TransformerLayer.cuda_graph_stream = torch.cuda.current_stream() + if TransformerLayer.cuda_graph_event is None: + if self.offload_module_in_cuda_graph: + TransformerLayer.cuda_graph_event = torch.cuda.Event(external=True) + else: + TransformerLayer.cuda_graph_event = torch.cuda.Event() From 6d475ad2d4bb7de490015c1158017f8db8177b44 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 1 Dec 2025 21:20:39 -0800 Subject: [PATCH 13/74] bug fix Signed-off-by: Hongbin Liu --- .../pipeline_parallel/fine_grained_activation_offload.py | 2 +- megatron/core/transformer/multi_token_prediction.py | 2 +- megatron/core/transformer/transformer_block.py | 3 ++- megatron/training/training.py | 6 ++++++ 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index c58d7582f65..2a627f1bce3 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -347,7 +347,6 @@ def cpu_tensor_pool(self): def reset(self): """Reset manager state for a new training iteration.""" - set_ideal_affinity_for_current_gpu() self._inside_context = False self._cur_forward_chunk = None self._cur_backward_chunk = None @@ -607,6 +606,7 @@ def tensor_pop(self, tensor_tag): assert tensor_tag in self._tensor_tag_to_state, f"Tag {tensor_tag} not found" tensor = self._tensor_tag_to_state.pop(tensor_tag) # If tensor is offloaded (stored as tuple), reload it + # assert isinstance(tensor, torch.Tensor), "Tensor is not a tensor" if isinstance(tensor, tuple): tensor = self.reload(tensor) debug_rank(f"--------tensor_pop {tensor.shape}") diff --git a/megatron/core/transformer/multi_token_prediction.py 
b/megatron/core/transformer/multi_token_prediction.py index 69f4b3d1264..2b03b44505b 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -949,7 +949,7 @@ def __init__( self._build_layers(pg_collection) assert len(self.layers) > 0, "MultiTokenPredictionBlock must have at least one layer." - self.layers[-1].is_last_layer = True + self.layers[-1].transformer_layer.is_last_layer = True self.cp_group = pg_collection.cp def _build_layers(self, pg_collection): diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 0a05ca31182..b2b01ce8ede 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -327,7 +327,8 @@ def __init__( self._build_layers() self.num_layers_per_pipeline_rank = len(self.layers) - self.layers[-1].is_last_layer = True + if len(self.layers) > 0: + self.layers[-1].is_last_layer = True def _build_layers(self): # Transformer layers. 
diff --git a/megatron/training/training.py b/megatron/training/training.py index 967397bec10..916a3c18d1c 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -684,6 +684,12 @@ def pretrain( args = get_args() timers = get_timers() + if args.fine_grained_activation_offloading: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + set_ideal_affinity_for_current_gpu + ) + set_ideal_affinity_for_current_gpu() + if args.log_progress: append_to_progress_log("Starting job") From 089da6cbe1ca3b0db39579e494c1b762b881c5b2 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 7 Dec 2025 18:00:33 -0800 Subject: [PATCH 14/74] support VPP Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 315 +++++++++++++----- megatron/core/transformer/cuda_graphs.py | 2 + megatron/training/arguments.py | 2 +- 3 files changed, 242 insertions(+), 77 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 2a627f1bce3..bf648ca8958 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -267,6 +267,22 @@ def __del__(self): """Destructor to ensure resources are released.""" self.clear() +class OffloadTensorGroup: + """ + A group of tensors to be offloaded together. 
+ """ + def __init__(self, name): + self._name = name + self._tensors = {} + self._events = [] + self._aux = {} + self.offload = True + + def push_tensor(self, tag, tensor): + self._tensors[tag] = tensor + + def pop_tensor(self, tag): + return self._tensors.pop(tag) def set_ideal_affinity_for_current_gpu(): """Set CPU affinity for the current GPU to optimize host-device transfers.""" @@ -275,19 +291,13 @@ def set_ideal_affinity_for_current_gpu(): try: import cuda.bindings.driver as cuda_driver import cuda.bindings.runtime as cuda_runtime - except ImportError: + except: try: import cuda.cuda as cuda_driver import cuda.cudart as cuda_runtime - except ImportError: - # print("cuda-python may not be installed, skipping GPU affinity setting") - warnings.warn("cuda-python may not be installed, skipping GPU affinity setting") - return - try: - import pynvml - except ImportError: - warnings.warn("pynvml is not installed, skipping GPU affinity setting") - return + except: + raise RuntimeError("Please install cuda-python to enable GPU affinity setting") + import pynvml # Get current CUDA device ID err, device_id = cuda_runtime.cudaGetDevice() @@ -328,6 +338,17 @@ def __init__(self): self._h2d_stream = torch.cuda.Stream() # Shared CPU tensor pool for all chunks to improve reuse efficiency self._cpu_tensor_pool = GPUTensorPool(device="cpu", pin_memory=True) + + self._is_warmup = True + self._cached_chunks_forward = [] + self._cached_chunks_backward = [] + self._cached_chunks_index_backward = 0 + self._cached_chunks_index_forward = 0 + + self.do_offload = True + + # Margin to avoid offloading too many groups so that + self._offload_margin = 0 self.reset() @property @@ -355,6 +376,14 @@ def reset(self): # Reset CPU tensor pool to reuse all CPU tensors for next iteration if hasattr(self, '_cpu_tensor_pool'): self._cpu_tensor_pool.reset() + + if self._is_warmup and len(self._cached_chunks_forward) > 0: + self.post_warmup_callback() + self._cached_chunks_index_backward = 0 + 
self._cached_chunks_index_forward = 0 + + for chunk in self._cached_chunks_forward: + chunk.reset() def flush(self): """Flush all staged chunks to the backward queue in reverse order.""" @@ -371,27 +400,84 @@ def flush(self): for i in range(self._vpp): self._stages[i] = [] + def disable_offload(self): + """Disable the offload.""" + debug_rank("disable_offload") + self.do_offload = False + for chunk in self._cached_chunks_forward: + chunk.do_offload = False + + def enable_offload(self): + """Enable the offload.""" + debug_rank("enable_offload") + self.do_offload = True + for chunk in self._cached_chunks_forward: + chunk.do_offload = True + + def post_warmup_callback(self): + """Callback after warmup.""" + debug_rank("post_warmup_callback") + self._is_warmup = False + assert len(self._cached_chunks_forward) == len(self._cached_chunks_backward), \ + "Cached chunks forward and backward must have the same length" + for chunk in self._cached_chunks_forward: + chunk.is_warmup = False + assert chunk in self._cached_chunks_backward, "Chunk not found in cached chunks backward" + # Update the offload margin to the maximum number of deduplicated groups + self._offload_margin = max(self._offload_margin, chunk.get_max_deduplicated_groups()) + debug_rank(f"offload margin {self._offload_margin}") + last_group_with_same_name = {} + for chunk_idx, chunk in enumerate(reversed(self._cached_chunks_backward)): + for group in chunk.offload_groups: + last_group_with_same_name[group._name] = group + for name, group in last_group_with_same_name.items(): + if self._offload_margin > 0: + group.offload = False + self._offload_margin -= 1 + debug_rank(f"setting offload to false for group {name} at chunk index {chunk_idx}") + else: + break + + def push(self, handler): """Add a chunk handler to the backward queue.""" debug_rank(f"pushing handler {handler}") self._queue.append(handler) + if self._is_warmup: + self._cached_chunks_backward.append(handler) - def pop(self): + def pop(self, name=None): 
"""Remove and set the next non-empty chunk as the current backward chunk.""" - assert self.size(), "Cannot pop from empty queue" - while self._queue: - self._cur_backward_chunk = self._queue.popleft() - if not self._cur_backward_chunk.is_empty_chunk(): - break - debug_rank(f"popping handler {self._cur_backward_chunk}") - - def front(self): + # assert self.size(), "Cannot pop from empty queue" + # while self._queue: + # self._cur_backward_chunk = self._queue.popleft() + # if not self._cur_backward_chunk.is_empty_chunk(): + # break + # debug_rank(f"popping handler {self._cur_backward_chunk}") + self._cur_backward_chunk = None + debug_rank(f"popping backward chunk {self._cached_chunks_index_backward}") + debug_rank(f"cached chunks backward {self._cached_chunks_backward}") + for idx, handler in enumerate(self._cached_chunks_backward[self._cached_chunks_index_backward:]): + self._cached_chunks_index_backward += 1 + if handler.is_empty_chunk(name): + debug_rank(f"handler {handler} at index {idx} is empty") + continue + self._cur_backward_chunk = handler # set the first non-empty chunk as the current backward chunk + break + assert self._cur_backward_chunk is not None, "No non-empty chunk found" + debug_rank(f"popped backward chunk {self._cur_backward_chunk} cached chunks index backward {self._cached_chunks_index_backward}") + + def front(self, name=None): """Get the first non-empty chunk handler without removing it from the queue.""" - if not self.size(): - return None - for chunk_handler in self._queue: - if not chunk_handler.is_empty_chunk(): - return chunk_handler + # if not self.size(): + # return None + # for chunk_handler in self._queue: + # if not chunk_handler.is_empty_chunk(): + # return chunk_handler + for idx, handler in enumerate(self._cached_chunks_backward[self._cached_chunks_index_backward:]): + if not handler.is_empty_chunk(name): + debug_rank(f"front handler {handler} at index {idx}") + return handler return None def size(self): @@ -409,6 +495,9 @@ def 
init_model_chunk_offload_handler( vp_stage: Virtual pipeline stage index (None means stage 0) min_offloaded_tensor_size: Minimum tensor size (in elements) to offload """ + if not self._is_warmup: + return + vp_size = 1 if vp_size is None else vp_size if self._stages is None: self._vpp = vp_size @@ -440,11 +529,26 @@ def init_model_chunk_offload_handler( self.flush() self._cur_forward_chunk = cur_chunk cur_chunk.vpp_rank = cur_vpp_rank + self._cached_chunks_forward.append(cur_chunk) def set_last_layer(self, is_last_layer): """Mark whether the current forward chunk is processing the last layer.""" debug_rank(f"set_last_layer {is_last_layer}") - self._cur_forward_chunk.is_last_layer = is_last_layer + # self._cur_forward_chunk.is_last_layer = is_last_layer + + def pop_forward_chunk(self, name=None): + """Get the current forward pass chunk handler.""" + debug_rank(f"pop_forward_chunk {self._cur_forward_chunk}") + if not self.do_offload: + return self._cur_forward_chunk + while (not self._is_warmup + and (self._cur_forward_chunk is None or self._cur_forward_chunk.finish_all_groups(name))): + self._cur_forward_chunk = self._cached_chunks_forward[self._cached_chunks_index_forward] + self._cached_chunks_index_forward += 1 + # if self._cached_chunks_index_forward == len(self._cached_chunks_forward): + # self._cached_chunks_index_forward = 0 + debug_rank(f"new cur_forward_chunk {self._cur_forward_chunk}") + return self._cur_forward_chunk def cur_forward_chunk(self): """Get the current forward pass chunk handler.""" @@ -550,6 +654,7 @@ def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size, cpu_tenso self._groups_to_offload = [] self._groups_to_reload = [] self._tensor_count_current_group = 0 + self._max_group_size = 0 # Counter for special torch tensor types (FakeTensor, FunctionalTensor) self.torch_tensor_count = 0 @@ -560,11 +665,49 @@ def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size, cpu_tenso self.min_offloaded_tensor_size = 
min_offloaded_tensor_size self.is_last_layer = False self.cpu_tensor_pool = cpu_tensor_pool + self.offload_groups = [] + self.is_warmup = True + def reset(self): + """Reset the chunk offload handler.""" + self._offloaded_group_index = 0 + self._groups_to_offload = [] + self._groups_to_reload = [] + self._tensor_count_current_group = 0 + self._offload_events = {} + self._reload_events = {} - def is_empty_chunk(self): + def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" - return len(self._tensor_tag_to_state) == 0 + debug_rank(f"------is_empty_chunk {self} {self._max_group_size}") + # return len(self._tensor_tag_to_state) == 0 + if name is not None: + for group in self.offload_groups: + debug_rank(f"group name {group._name} need name {name}") + if group._name == name: + return False + return True + return self._max_group_size == 0 + + def finish_all_groups(self, name=None) -> bool: + """Finish all groups.""" + debug_rank(f"------finish_all_groups {self} {self._max_group_size} {self._offloaded_group_index}") + #TODO: check if this is correct + if len(self._groups_to_reload) == 0 and len(self._offload_events) > 0: + return True + assert name is not None, "Name is required" + for group in self.offload_groups[self._offloaded_group_index:]: + if group._name == name: + return False + return True + + def find_next_group(self, name=None): + """Find the next group with the given name.""" + assert name is not None, "Name is required" + for group in self.offload_groups[self._offloaded_group_index:]: + if group._name == name: + return group + return None def is_first_last_layer(self): """ @@ -585,26 +728,30 @@ def tensor_push(self, tensor): torch._subclasses.functional_tensor.FunctionalTensor, ), ) + assert not torch_stray_tensor, "Stray tensor should not be offloaded" if not torch_stray_tensor: # Assign unique tag based on group index and position within group tensor_tag = (self._offloaded_group_index, self._tensor_count_current_group) 
self._tensor_count_current_group += 1 - assert tensor_tag not in self._tensor_tag_to_state, "Duplicate tensor tag" - self._tensor_tag_to_state[tensor_tag] = tensor + # assert tensor_tag not in self._tensor_tag_to_state, "Duplicate tensor tag" + # self._tensor_tag_to_state[tensor_tag] = tensor + self.offload_groups[self._offloaded_group_index-1].push_tensor(tensor_tag, tensor) else: # Use negative group ID for special tensor types tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 - self._tensor_tag_to_state[tensor_tag] = tensor + # self._tensor_tag_to_state[tensor_tag] = tensor debug_rank(f"--------tensor_push {tensor_tag}") return tensor_tag def tensor_pop(self, tensor_tag): """Pop tensor from the offload handler.""" debug_rank(f"--------tensor_pop {tensor_tag}") - assert tensor_tag in self._tensor_tag_to_state, f"Tag {tensor_tag} not found" - tensor = self._tensor_tag_to_state.pop(tensor_tag) + # assert tensor_tag in self._tensor_tag_to_state, f"Tag {tensor_tag} not found" + # tensor = self._tensor_tag_to_state.pop(tensor_tag) + group_id, idx = tensor_tag + tensor = self.offload_groups[group_id-1].pop_tensor(tensor_tag) # If tensor is offloaded (stored as tuple), reload it # assert isinstance(tensor, torch.Tensor), "Tensor is not a tensor" if isinstance(tensor, tuple): @@ -624,8 +771,9 @@ def tensor_need_offloading_checker(self, tensor): def bulk_offload_group(self, group_to_offload): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") - assert not self.is_first_last_layer(), "Should not offload first-last layer" + # assert not self.is_first_last_layer(), "Should not offload first-last layer" group_id_to_offload, name, fake_offload = group_to_offload + offload_group = self.offload_groups[group_id_to_offload-1] if fake_offload: return False torch.cuda.nvtx.range_push("activation offloading " + name) @@ -633,20 +781,16 @@ def bulk_offload_group(self, group_to_offload): if name == "expert_fc1" or 
name == "moe_act": use_cpu_pool = False with torch.cuda.stream(self.d2h_stream): - for tensor_tag, state in self._tensor_tag_to_state.items(): - group_id, _ = tensor_tag - if group_id == group_id_to_offload: - debug_rank(f"------tensor_tag {tensor_tag}") - debug_rank(f"------group_to_offload {group_to_offload}") - assert not isinstance(state, tuple), "Tensor already offloaded" - tensor_on_device = state - if self.tensor_need_offloading_checker(tensor_on_device): - state = self.offload(tensor_on_device, use_cpu_pool=use_cpu_pool) - event = torch.cuda.Event() - event.record(self.d2h_stream) - self._offload_events[name] = event - tensor_on_device.record_stream(self.d2h_stream) - self._tensor_tag_to_state[tensor_tag] = state + # for tensor_tag, state in self._tensor_tag_to_state.items(): + for tensor_tag, tensor_on_device in offload_group._tensors.items(): + if self.tensor_need_offloading_checker(tensor_on_device): + state = self.offload(tensor_on_device, use_cpu_pool=use_cpu_pool) + event = torch.cuda.Event() + event.record(self.d2h_stream) + self._offload_events[name] = event + tensor_on_device.record_stream(self.d2h_stream) + # self._tensor_tag_to_state[tensor_tag] = state + offload_group.push_tensor(tensor_tag, state) torch.cuda.nvtx.range_pop() return True @@ -657,56 +801,62 @@ def get_offload_event(self, name): def get_reload_event(self, name): """Get the CUDA event for a named reload operation.""" return self._reload_events.get(name, None) + + def get_max_deduplicated_groups(self): + """Get the maximum number of deduplicated groups.""" + return len(self._offload_events) def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" debug_rank("----bulk_reload_group") group_id_to_reload, name, fake_reload = group_to_reload + offload_group = self.offload_groups[group_id_to_reload-1] if fake_reload: return True found_reload_group = False torch.cuda.nvtx.range_push("activation reloading " + name) with torch.cuda.stream(self.h2d_stream): - for tensor_label, 
state in self._tensor_tag_to_state.items(): - group_id, _ = tensor_label - if group_id == group_id_to_reload: - debug_rank(f"----tensor_label {tensor_label}") - found_reload_group = True - event = self.get_offload_event(name) - # Only reload if tensor was offloaded (stored as tuple) - if isinstance(state, tuple): - # Wait for offload to complete before reloading - if not is_graph_capturing(): - torch.cuda.current_stream().wait_event(event) - recovered_tensor = self.reload(state) - event.record(self.h2d_stream) - self._reload_events[name] = event - debug_rank(f"----recovered_tensor {recovered_tensor.shape}") - self._tensor_tag_to_state[tensor_label] = recovered_tensor + event = self.get_offload_event(name) + for tensor_tag, state in offload_group._tensors.items(): + found_reload_group = True + # Only reload if tensor was offloaded (stored as tuple) + if isinstance(state, tuple): + # Wait for offload to complete before reloading + if not is_graph_capturing(): + torch.cuda.current_stream().wait_event(event) + recovered_tensor = self.reload(state) + event.record(self.h2d_stream) + self._reload_events[name] = event + debug_rank(f"----recovered_tensor {recovered_tensor.shape}") + # self._tensor_tag_to_state[tensor_tag] = recovered_tensor + offload_group.push_tensor(tensor_tag, recovered_tensor) torch.cuda.nvtx.range_pop() return found_reload_group def pre_reload_last_layer(self): """Pre-reload the last layer of this chunk to hide reload latency.""" debug_rank("pre_reload_last_layer") - assert not self._is_first_last_vpp_chunk, "Should not pre-reload first chunk" + # assert not self._is_first_last_vpp_chunk, "Should not pre-reload first chunk" debug_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}") if len(self._groups_to_reload) > 0: # Reload the last group (last layer) early if self.bulk_reload_group(self._groups_to_reload[-1]): self._groups_to_reload.pop() - def should_bulk_offload(self): + def should_bulk_offload(self, group_to_offload): """Determine 
if the current group should be offloaded.""" # Don't offload the first backward chunk's last layer - if self.is_first_last_layer(): + # if self.is_first_last_layer(): + group_id, name, fake_offload = group_to_offload + if not PipelineOffloadManager.get_instance()._is_warmup and not self.offload_groups[group_id-1].offload: return False # Check if next backward chunk is this chunk (for last pipeline stage) - next_backward_chunk = PipelineOffloadManager.get_instance().front() + next_backward_chunk = PipelineOffloadManager.get_instance().front(name=name) if next_backward_chunk is not None and next_backward_chunk is self: # Don't offload last layer if it's about to be used immediately - if self.is_last_layer: + if self.find_next_group(name) is None: + debug_rank(f"next group {name} is not found") return False return True @@ -714,7 +864,8 @@ def should_bulk_offload(self): def bulk_offload(self, forced_released_tensors): """Offload a group of tensors and optionally release their GPU memory.""" debug_rank("----bulk_offload") - if self.should_bulk_offload(): + group_to_offload = self._groups_to_offload[-1] + if self.should_bulk_offload(group_to_offload): group_to_offload = self._groups_to_offload.pop() # if group_to_offload[0] == 8: # # print("rank", torch.distributed.get_rank(), "group_to_offload", group_to_offload) @@ -766,9 +917,9 @@ def on_group_commit_backward(self, name): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() # Switch to this chunk if it's not already current if cur_backward_chunk is not self: - PipelineOffloadManager.get_instance().pop() + PipelineOffloadManager.get_instance().pop(name) cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() - assert cur_backward_chunk is self, "Chunk mismatch" + assert cur_backward_chunk is self, f"Chunk mismatch {cur_backward_chunk} {self}" # Wait for reload to complete before using tensors event = self.get_reload_event(name) if event is not None and not 
is_graph_capturing(): @@ -782,10 +933,22 @@ def on_group_start_forward(self, name, fake_offload=False): """ if not self.do_offload: return - debug_rank(f"--on_group_start_forward") - self._offloaded_group_index = self._offloaded_group_index + 1 + debug_rank(f"--on_group_start_forward {name}") + if self.is_warmup: + self._offloaded_group_index = self._offloaded_group_index + 1 + self.offload_groups.append(OffloadTensorGroup(name)) + self._max_group_size = max(self._max_group_size, self._offloaded_group_index) + debug_rank(f"max group size {self._max_group_size}") + else: + self._offloaded_group_index = self._offloaded_group_index + 1 + for group in self.offload_groups[self._offloaded_group_index-1:]: + debug_rank(f"offloaded group index {self._offloaded_group_index} for group {group._name}") + if group._name == name: + break + self._offloaded_group_index = self._offloaded_group_index + 1 self._tensor_count_current_group = 0 self._groups_to_offload.append((self._offloaded_group_index, name, fake_offload)) + debug_rank(f"groups to offload {self._groups_to_offload}") def on_group_start_backward(self): """ @@ -802,12 +965,12 @@ def on_group_start_backward(self): def fine_grained_offloading_disable_offload(): """Disable the offload.""" debug_rank("fine_grained_offloading_disable_offload") - PipelineOffloadManager.get_instance().cur_forward_chunk().do_offload = False + PipelineOffloadManager.get_instance().disable_offload() def fine_grained_offloading_enable_offload(): """Enable the offload.""" debug_rank("fine_grained_offloading_enable_offload") - PipelineOffloadManager.get_instance().cur_forward_chunk().do_offload = True + PipelineOffloadManager.get_instance().enable_offload() class FineGrainedOffloadingGroupCommitFunction(torch.autograd.Function): """ @@ -881,7 +1044,7 @@ def backward(ctx, grad_output): def fine_grained_offloading_group_start(tensor, name=None, fake_offload=False): """Mark the start of a layer group and prepare for offload/reload.""" - 
cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + cur_forward_chunk = PipelineOffloadManager.get_instance().pop_forward_chunk(name=name) return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name, fake_offload) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index efbc7a768f6..ca01658fa51 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1635,6 +1635,7 @@ def _get_fp8_enabled(): fine_grained_offloading_disable_offload, fine_grained_offloading_enable_offload, fine_grained_offloading_init_chunk_handler, + fine_grained_offloading_reset, ) from functools import partial # if self.config.offload_module_in_cuda_graph: @@ -1646,6 +1647,7 @@ def _get_fp8_enabled(): vp_size=self.config.virtual_pipeline_model_parallel_size, min_offloaded_tensor_size=self.config.min_offloaded_tensor_size ) + kwargs['reset_hook'] = fine_grained_offloading_reset return kwargs kwargs = get_make_graphed_callables_kwargs() diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 494c82f7873..9cc1aad75f7 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2367,7 +2367,7 @@ def _add_training_args(parser): help='Enable fine-grained activation offloading.') group.add_argument('--offload-modules', nargs='*', type=str, default=[], help='The submodules to offload its input. 
Choices: "attn_norm", "qkv_linear", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') - group.add_argument('--min-offloaded-tensor-size', type=int, default=1024*1024, + group.add_argument('--min-offloaded-tensor-size', type=int, default=10*1024*1024, help='The minimum size of the tensor to be offloaded.') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') From 35b0f970404ff139a5de5066bbaf74dac2352956 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 7 Dec 2025 23:35:31 -0800 Subject: [PATCH 15/74] code refactor Signed-off-by: Hongbin Liu --- .../common/model_chunk_schedule_plan.py | 7 - .../fine_grained_activation_offload.py | 133 ++++++++---------- .../transformer/multi_token_prediction.py | 4 - .../core/transformer/transformer_block.py | 5 - .../core/transformer/transformer_layer.py | 7 - 5 files changed, 62 insertions(+), 94 deletions(-) diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 401d9a81a97..db40d93e63e 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -8,9 +8,6 @@ from megatron.core.enums import Fp8Recipe from megatron.core.fp8_utils import get_fp8_context -from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_set_last_layer, -) from megatron.core.pipeline_parallel.utils import ( AbstractSchedulePlan, NoopScheduleNode, @@ -455,8 +452,6 @@ def run( f_layer = f_schedule_plan.get_layer(i) b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i) torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b") - if f_layer.layer.config.fine_grained_activation_offloading: - fine_grained_offloading_set_last_layer(i == f_num_layers - 1) f_input, b_grad = TransformerLayerSchedulePlan.run( f_layer, b_layer, @@ -479,8 +474,6 @@ def run( for i in range(overlapped_layers, 
f_num_layers): f_layer = f_schedule_plan.get_layer(i) torch.cuda.nvtx.range_push(f"layer_{i}f") - if f_layer.layer.config.fine_grained_activation_offloading: - fine_grained_offloading_set_last_layer(i == f_num_layers - 1) f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input) torch.cuda.nvtx.range_pop() diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index bf648ca8958..ca87d3db0d9 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -274,7 +274,8 @@ class OffloadTensorGroup: def __init__(self, name): self._name = name self._tensors = {} - self._events = [] + self._offload_event = torch.cuda.Event() + self._reload_event = torch.cuda.Event() self._aux = {} self.offload = True @@ -284,6 +285,19 @@ def push_tensor(self, tag, tensor): def pop_tensor(self, tag): return self._tensors.pop(tag) + def record_offload_event(self, stream): + self._offload_event.record(stream) + + def wait_offload_event(self, stream): + stream.wait_event(self._offload_event) + + def record_reload_event(self, stream): + self._reload_event.record(stream) + + def wait_reload_event(self, stream): + stream.wait_event(self._reload_event) + + def set_ideal_affinity_for_current_gpu(): """Set CPU affinity for the current GPU to optimize host-device transfers.""" import uuid @@ -531,10 +545,6 @@ def init_model_chunk_offload_handler( cur_chunk.vpp_rank = cur_vpp_rank self._cached_chunks_forward.append(cur_chunk) - def set_last_layer(self, is_last_layer): - """Mark whether the current forward chunk is processing the last layer.""" - debug_rank(f"set_last_layer {is_last_layer}") - # self._cur_forward_chunk.is_last_layer = is_last_layer def pop_forward_chunk(self, name=None): """Get the current forward pass chunk handler.""" @@ -660,10 +670,9 @@ def __init__(self, 
is_first_last_vpp_chunk, min_offloaded_tensor_size, cpu_tenso self.torch_tensor_count = 0 self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream - self._offload_events = {} - self._reload_events = {} + # self._offload_events = {} + # self._reload_events = {} self.min_offloaded_tensor_size = min_offloaded_tensor_size - self.is_last_layer = False self.cpu_tensor_pool = cpu_tensor_pool self.offload_groups = [] self.is_warmup = True @@ -674,8 +683,8 @@ def reset(self): self._groups_to_offload = [] self._groups_to_reload = [] self._tensor_count_current_group = 0 - self._offload_events = {} - self._reload_events = {} + # self._offload_events = {} + # self._reload_events = {} def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" @@ -693,7 +702,7 @@ def finish_all_groups(self, name=None) -> bool: """Finish all groups.""" debug_rank(f"------finish_all_groups {self} {self._max_group_size} {self._offloaded_group_index}") #TODO: check if this is correct - if len(self._groups_to_reload) == 0 and len(self._offload_events) > 0: + if len(self._groups_to_reload) == 0 and self._offloaded_group_index > 0: return True assert name is not None, "Name is required" for group in self.offload_groups[self._offloaded_group_index:]: @@ -709,15 +718,6 @@ def find_next_group(self, name=None): return group return None - def is_first_last_layer(self): - """ - Check if this is the last layer of the first microbatch of the last vp stage. - These tensors should not be offloaded to avoid unnecessary overhead. 
- """ - debug_rank( - f"------is_first_last_layer {self._is_first_last_vpp_chunk} {self.is_last_layer}" - ) - return self._is_first_last_vpp_chunk and self.is_last_layer def tensor_push(self, tensor): """Push tensor to the offload handler.""" @@ -771,11 +771,8 @@ def tensor_need_offloading_checker(self, tensor): def bulk_offload_group(self, group_to_offload): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") - # assert not self.is_first_last_layer(), "Should not offload first-last layer" - group_id_to_offload, name, fake_offload = group_to_offload + group_id_to_offload, name = group_to_offload offload_group = self.offload_groups[group_id_to_offload-1] - if fake_offload: - return False torch.cuda.nvtx.range_push("activation offloading " + name) use_cpu_pool = True if name == "expert_fc1" or name == "moe_act": @@ -785,51 +782,56 @@ def bulk_offload_group(self, group_to_offload): for tensor_tag, tensor_on_device in offload_group._tensors.items(): if self.tensor_need_offloading_checker(tensor_on_device): state = self.offload(tensor_on_device, use_cpu_pool=use_cpu_pool) - event = torch.cuda.Event() - event.record(self.d2h_stream) - self._offload_events[name] = event + # event = torch.cuda.Event() + # event.record(self.d2h_stream) + # self._offload_events[name] = event tensor_on_device.record_stream(self.d2h_stream) # self._tensor_tag_to_state[tensor_tag] = state offload_group.push_tensor(tensor_tag, state) + offload_group.record_offload_event(self.d2h_stream) torch.cuda.nvtx.range_pop() - return True - def get_offload_event(self, name): - """Get the CUDA event for a named offload operation.""" - return self._offload_events.get(name, None) + # def get_offload_event(self, name): + # """Get the CUDA event for a named offload operation.""" + # return self._offload_events.get(name, None) - def get_reload_event(self, name): - """Get the CUDA event for a named reload operation.""" - return self._reload_events.get(name, None) + # 
def get_reload_event(self, name): + # """Get the CUDA event for a named reload operation.""" + # return self._reload_events.get(name, None) def get_max_deduplicated_groups(self): """Get the maximum number of deduplicated groups.""" - return len(self._offload_events) + count_modules = [] + for group in self.offload_groups: + if group._name not in count_modules: + count_modules.append(group._name) + return len(count_modules) def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" debug_rank("----bulk_reload_group") - group_id_to_reload, name, fake_reload = group_to_reload + group_id_to_reload, name = group_to_reload offload_group = self.offload_groups[group_id_to_reload-1] - if fake_reload: - return True found_reload_group = False torch.cuda.nvtx.range_push("activation reloading " + name) with torch.cuda.stream(self.h2d_stream): - event = self.get_offload_event(name) + # event = self.get_offload_event(name) + if not is_graph_capturing(): + offload_group.wait_offload_event(self.h2d_stream) for tensor_tag, state in offload_group._tensors.items(): found_reload_group = True # Only reload if tensor was offloaded (stored as tuple) if isinstance(state, tuple): # Wait for offload to complete before reloading - if not is_graph_capturing(): - torch.cuda.current_stream().wait_event(event) + # if not is_graph_capturing(): + # torch.cuda.current_stream().wait_event(event) recovered_tensor = self.reload(state) - event.record(self.h2d_stream) - self._reload_events[name] = event + # event.record(self.h2d_stream) + # self._reload_events[name] = event debug_rank(f"----recovered_tensor {recovered_tensor.shape}") # self._tensor_tag_to_state[tensor_tag] = recovered_tensor offload_group.push_tensor(tensor_tag, recovered_tensor) + offload_group.record_reload_event(self.h2d_stream) torch.cuda.nvtx.range_pop() return found_reload_group @@ -846,8 +848,7 @@ def pre_reload_last_layer(self): def should_bulk_offload(self, group_to_offload): """Determine if the current group should 
be offloaded.""" # Don't offload the first backward chunk's last layer - # if self.is_first_last_layer(): - group_id, name, fake_offload = group_to_offload + group_id, name = group_to_offload if not PipelineOffloadManager.get_instance()._is_warmup and not self.offload_groups[group_id-1].offload: return False @@ -867,13 +868,8 @@ def bulk_offload(self, forced_released_tensors): group_to_offload = self._groups_to_offload[-1] if self.should_bulk_offload(group_to_offload): group_to_offload = self._groups_to_offload.pop() - # if group_to_offload[0] == 8: - # # print("rank", torch.distributed.get_rank(), "group_to_offload", group_to_offload) - # return self._groups_to_reload.append(group_to_offload) - ret = self.bulk_offload_group(group_to_offload) - if not ret: - return + self.bulk_offload_group(group_to_offload) # Manually release tensors not auto-freed by torch GC if len(forced_released_tensors) > 0: cur_stream = torch.cuda.current_stream() @@ -921,12 +917,17 @@ def on_group_commit_backward(self, name): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() assert cur_backward_chunk is self, f"Chunk mismatch {cur_backward_chunk} {self}" # Wait for reload to complete before using tensors - event = self.get_reload_event(name) - if event is not None and not is_graph_capturing(): - torch.cuda.current_stream().wait_event(event) - self._offloaded_group_index = self._offloaded_group_index - 1 + # event = self.get_reload_event(name) + # if event is not None and not is_graph_capturing(): + # torch.cuda.current_stream().wait_event(event) + if len(self._groups_to_reload) > 0: + group_to_reload = self._groups_to_reload[-1] + offload_group = self.offload_groups[group_to_reload[0]-1] + if not is_graph_capturing(): + offload_group.wait_reload_event(torch.cuda.current_stream()) + # self._offloaded_group_index = self._offloaded_group_index - 1 - def on_group_start_forward(self, name, fake_offload=False): + def on_group_start_forward(self, name): """ Called at 
the start of a layer group's forward pass. Increments group index and prepares for offloading. @@ -947,7 +948,7 @@ def on_group_start_forward(self, name, fake_offload=False): break self._offloaded_group_index = self._offloaded_group_index + 1 self._tensor_count_current_group = 0 - self._groups_to_offload.append((self._offloaded_group_index, name, fake_offload)) + self._groups_to_offload.append((self._offloaded_group_index, name)) debug_rank(f"groups to offload {self._groups_to_offload}") def on_group_start_backward(self): @@ -1024,12 +1025,12 @@ class FineGrainedOffloadingGroupStartFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, tensor, cpu_offload_handler, name, fake_offload): + def forward(ctx, tensor, cpu_offload_handler, name): # pylint: disable=missing-function-docstring ctx.cpu_offload_handler = cpu_offload_handler debug_rank("FineGrainedOffloadingGroupStartFunction forward") - cpu_offload_handler.on_group_start_forward(name, fake_offload) + cpu_offload_handler.on_group_start_forward(name) # return the identical tensor return tensor @@ -1042,10 +1043,10 @@ def backward(ctx, grad_output): return grad_output, None, None, None -def fine_grained_offloading_group_start(tensor, name=None, fake_offload=False): +def fine_grained_offloading_group_start(tensor, name=None): """Mark the start of a layer group and prepare for offload/reload.""" cur_forward_chunk = PipelineOffloadManager.get_instance().pop_forward_chunk(name=name) - return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name, fake_offload) + return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) def get_fine_grained_offloading_context(flag): @@ -1053,11 +1054,6 @@ def get_fine_grained_offloading_context(flag): return PipelineOffloadManager.get_instance() if flag else nullcontext() -def fine_grained_offloading_set_last_layer(is_last_layer): - """Set the last layer flag.""" - 
PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) - - def fine_grained_offloading_init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" PipelineOffloadManager.get_instance().init_model_chunk_offload_handler( @@ -1094,9 +1090,4 @@ def backward(ctx, grad_output): def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) -def fine_grained_offloading_fake_offload(tensor): - """Fake offload.""" - tensor = fine_grained_offloading_group_start(tensor, name="fake_offload", fake_offload=True) - (tensor,) = fine_grained_offloading_group_commit(tensor, name="fake_offload", forced_released_tensors=[]) - return tensor \ No newline at end of file diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 2b03b44505b..dbff1834619 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -13,9 +13,6 @@ from megatron.core.fp8_utils import get_fp8_context from megatron.core.models.backends import BackendSpecProvider, LocalSpecProvider from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_set_last_layer, -) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import ( gather_from_tensor_model_parallel_region, @@ -949,7 +946,6 @@ def __init__( self._build_layers(pg_collection) assert len(self.layers) > 0, "MultiTokenPredictionBlock must have at least one layer." 
- self.layers[-1].transformer_layer.is_last_layer = True self.cp_group = pg_collection.cp def _build_layers(self, pg_collection): diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index b2b01ce8ede..d85258903e2 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -16,9 +16,6 @@ from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_set_last_layer, -) from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.enums import LayerType @@ -327,8 +324,6 @@ def __init__( self._build_layers() self.num_layers_per_pipeline_rank = len(self.layers) - if len(self.layers) > 0: - self.layers[-1].is_last_layer = True def _build_layers(self): # Transformer layers. 
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index f422b07dff9..be19a728566 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -277,7 +277,6 @@ def __init__( self.pg_collection = pg_collection self.tp_group = pg_collection.tp - self.is_last_layer = False self.submodules_config = submodules self.layer_number = layer_number + get_transformer_layer_offset( @@ -501,9 +500,6 @@ def _forward_attention( fine_grained_offloading_group_start, get_fine_grained_offloading_context, ) - if self.config.fine_grained_activation_offloading: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_set_last_layer - fine_grained_offloading_set_last_layer(self.is_last_layer) if self.offload_module_in_cuda_graph: from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_backward_record @@ -856,9 +852,6 @@ def _te_cuda_graph_replay(self, *args, **kwargs): However, CUDA graph accepts only Tensor inputs. Hence, `inference_context` and `packed_seq_params` are excluded from input list. 
""" - if self.config.fine_grained_activation_offloading: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_set_last_layer - fine_grained_offloading_set_last_layer(self.is_last_layer) context = None if self.config.cuda_graph_scope and 'attn' not in self.config.cuda_graph_scope: hidden_states, context = self._forward_attention(*args, **kwargs) From df09b85ace1980af25ff8532f2d64c1568b61d75 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 01:35:32 -0800 Subject: [PATCH 16/74] big code refactor and format Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 406 ++++++++---------- megatron/core/pipeline_parallel/schedules.py | 29 +- megatron/core/transformer/cuda_graphs.py | 17 +- megatron/core/transformer/module.py | 1 + .../core/transformer/transformer_config.py | 18 +- .../core/transformer/transformer_layer.py | 49 ++- 6 files changed, 257 insertions(+), 263 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index ca87d3db0d9..e093fe5412c 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1,9 +1,8 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -import warnings from collections import deque from contextlib import nullcontext -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, Tuple import torch @@ -27,55 +26,51 @@ def debug_rank(message): class GPUTensorPool: """ GPU memory pool for efficient allocation and deallocation of tensors. 
- + Features: - Supports multiple tensor shapes and dtypes, each with its own pool - Dynamic allocation: tensors are created on-demand during allocation - Efficient reuse: freed tensors are returned to the pool for reuse - Uses queue-based management for O(1) allocation and deallocation - + Example: pool = GPUTensorPool(device='cuda:0') tensor = pool.allocate((128, 512), dtype=torch.float32) # ... use tensor ... pool.free(tensor, (128, 512), dtype=torch.float32) """ - - def __init__( - self, - device: str = 'cuda', - pin_memory: bool = False - ): + + def __init__(self, device: str = 'cuda', pin_memory: bool = False): """ Initialize GPU tensor pool. - + Args: device: GPU device, default 'cuda' pin_memory: Whether to use pinned memory (mainly for CPU tensors) """ self.device = torch.device(device) self.pin_memory = pin_memory - + # Maintain a separate pool for each (shape, dtype) combination # Structure: {(shape, dtype): {'free': deque, 'all': list, 'allocated_count': int}} self._pools: Dict[Tuple, Dict[str, Any]] = {} - + # Statistics self._stats = { - 'total_allocated': 0, # Total number of tensors ever allocated - 'current_in_use': 0, # Number of tensors currently in use - 'allocation_requests': 0, # Number of allocation requests - 'free_requests': 0, # Number of free requests - 'pool_hits': 0, # Number of times a tensor was reused from pool - 'pool_misses': 0, # Number of times a new tensor was created + 'total_allocated': 0, # Total number of tensors ever allocated + 'current_in_use': 0, # Number of tensors currently in use + 'allocation_requests': 0, # Number of allocation requests + 'free_requests': 0, # Number of free requests + 'pool_hits': 0, # Number of times a tensor was reused from pool + 'pool_misses': 0, # Number of times a new tensor was created } - + debug_rank("GPUTensorPool: Initialized with dynamic allocation") - + def _get_pool_key(self, shape: Tuple, dtype: torch.dtype) -> Tuple: """Generate a unique key for the pool based on shape and dtype.""" 
return (shape, dtype) - + @staticmethod def _calculate_memory_size(shape: Tuple, dtype: torch.dtype) -> int: """Calculate memory size in bytes.""" @@ -84,32 +79,32 @@ def _calculate_memory_size(shape: Tuple, dtype: torch.dtype) -> int: for dim in shape: numel *= dim return numel * element_size - + def allocate(self, shape: Tuple, dtype: torch.dtype = torch.float32) -> torch.Tensor: """ Allocate a tensor with the specified shape and dtype. - + Args: shape: Shape of the tensor dtype: Data type of the tensor, default torch.float32 - + Returns: Allocated tensor """ self._stats['allocation_requests'] += 1 - + pool_key = self._get_pool_key(shape, dtype) - + # Create pool for this (shape, dtype) if it doesn't exist if pool_key not in self._pools: self._pools[pool_key] = { - 'free': deque(), # Queue of available tensors - 'all': [], # List of all tensors (for tracking) - 'allocated_count': 0, # Number of allocated tensors + 'free': deque(), # Queue of available tensors + 'all': [], # List of all tensors (for tracking) + 'allocated_count': 0, # Number of allocated tensors } - + pool = self._pools[pool_key] - + # Try to reuse a tensor from the pool if len(pool['free']) > 0: tensor = pool['free'].popleft() @@ -121,36 +116,31 @@ def allocate(self, shape: Tuple, dtype: torch.dtype = torch.float32) -> torch.Te ) else: # Allocate a new tensor - tensor = torch.empty( - shape, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory - ) + tensor = torch.empty(shape, dtype=dtype, device=self.device, pin_memory=self.pin_memory) pool['all'].append(tensor) self._stats['total_allocated'] += 1 self._stats['pool_misses'] += 1 - - memory_mb = self._calculate_memory_size(shape, dtype) / (1024 ** 2) + + memory_mb = self._calculate_memory_size(shape, dtype) / (1024**2) debug_rank( f"GPUTensorPool.allocate: Created new tensor, " f"shape={shape}, dtype={dtype}, " f"memory={memory_mb:.2f} MB, " f"total_created={len(pool['all'])}" ) - + pool['allocated_count'] += 1 
self._stats['current_in_use'] += 1 - + return tensor - + def free(self, tensor: torch.Tensor): """ Return a tensor to the pool for reuse. - + Args: tensor: Tensor to free - + Raises: ValueError: If tensor doesn't belong to this pool """ @@ -158,17 +148,17 @@ def free(self, tensor: torch.Tensor): shape = tensor.shape dtype = tensor.dtype - + pool_key = self._get_pool_key(shape, dtype) - + if pool_key not in self._pools: raise ValueError( f"No pool exists for shape={shape}, dtype={dtype}. " f"Available pools: {list(self._pools.keys())}" ) - + pool = self._pools[pool_key] - + # Verify tensor belongs to this pool (use identity check, not value comparison) tensor_found = any(tensor is t for t in pool['all']) if not tensor_found: @@ -176,127 +166,138 @@ def free(self, tensor: torch.Tensor): f"Attempting to free a tensor that doesn't belong to this pool " f"(shape={shape}, dtype={dtype})" ) - + # Return tensor to the free queue pool['free'].append(tensor) pool['allocated_count'] -= 1 self._stats['current_in_use'] -= 1 - + debug_rank( f"GPUTensorPool.free: shape={shape}, dtype={dtype}, " f"available in pool={len(pool['free'])}" ) - + def get_pool_status(self, shape: Tuple = None, dtype: torch.dtype = None) -> Dict[str, Any]: """ Get the status of the memory pool. 
- + Args: shape: If specified along with dtype, return status for that specific pool dtype: Data type (required if shape is specified) - + Returns: Dictionary containing status information """ if shape is not None: if dtype is None: raise ValueError("dtype must be specified when shape is provided") - + pool_key = self._get_pool_key(shape, dtype) - + if pool_key not in self._pools: raise ValueError(f"No pool exists for shape={shape}, dtype={dtype}") - + pool = self._pools[pool_key] total_count = len(pool['all']) - + return { 'shape': shape, 'dtype': dtype, 'total_count': total_count, 'allocated_count': pool['allocated_count'], 'free_count': len(pool['free']), - 'utilization': pool['allocated_count'] / total_count * 100 if total_count > 0 else 0, + 'utilization': ( + pool['allocated_count'] / total_count * 100 if total_count > 0 else 0 + ), } else: # Return status for all pools - status = { - 'global_stats': self._stats.copy(), - 'pools': {} - } - + status = {'global_stats': self._stats.copy(), 'pools': {}} + for pool_key in self._pools: shape, dtype = pool_key status['pools'][pool_key] = self.get_pool_status(shape, dtype) - + return status - + def reset(self): """Reset the pool, marking all tensors as available.""" debug_rank("GPUTensorPool: Resetting pool...") - + for pool_key, pool in self._pools.items(): # Clear and refill the free queue pool['free'].clear() for tensor in pool['all']: pool['free'].append(tensor) pool['allocated_count'] = 0 - + self._stats['current_in_use'] = 0 debug_rank("GPUTensorPool: Reset complete") - + def clear(self): """Clear the pool and release all GPU memory.""" debug_rank("GPUTensorPool: Clearing pool...") - + for pool_key, pool in self._pools.items(): # Clear all references, allowing PyTorch GC to reclaim memory pool['free'].clear() pool['all'].clear() - + self._pools.clear() self._stats['current_in_use'] = 0 - + # Trigger GPU cache cleanup if torch.cuda.is_available(): torch.cuda.empty_cache() - + debug_rank("GPUTensorPool: Clear 
complete") - + def __del__(self): """Destructor to ensure resources are released.""" self.clear() + class OffloadTensorGroup: """ A group of tensors to be offloaded together. """ + def __init__(self, name): self._name = name self._tensors = {} self._offload_event = torch.cuda.Event() self._reload_event = torch.cuda.Event() - self._aux = {} self.offload = True - + + if name == "expert_fc1" or name == "moe_act": + self.use_cpu_pool = False + else: + self.use_cpu_pool = True + def push_tensor(self, tag, tensor): + """Push a tensor to the group.""" self._tensors[tag] = tensor - + def pop_tensor(self, tag): + """Pop a tensor from the group.""" return self._tensors.pop(tag) def record_offload_event(self, stream): + """Record the offload event.""" self._offload_event.record(stream) - + def wait_offload_event(self, stream): + """Wait for the offload event.""" stream.wait_event(self._offload_event) - + def record_reload_event(self, stream): + """Record the reload event.""" self._reload_event.record(stream) - + def wait_reload_event(self, stream): + """Wait for the reload event.""" stream.wait_event(self._reload_event) - + def set_ideal_affinity_for_current_gpu(): """Set CPU affinity for the current GPU to optimize host-device transfers.""" @@ -385,12 +386,10 @@ def reset(self): self._inside_context = False self._cur_forward_chunk = None self._cur_backward_chunk = None - # Track the first microbatch of the last virtual pipeline stage - self._is_first_last_vpp_chunk = True # Reset CPU tensor pool to reuse all CPU tensors for next iteration if hasattr(self, '_cpu_tensor_pool'): self._cpu_tensor_pool.reset() - + if self._is_warmup and len(self._cached_chunks_forward) > 0: self.post_warmup_callback() self._cached_chunks_index_backward = 0 @@ -432,18 +431,24 @@ def post_warmup_callback(self): """Callback after warmup.""" debug_rank("post_warmup_callback") self._is_warmup = False - assert len(self._cached_chunks_forward) == len(self._cached_chunks_backward), \ - "Cached chunks 
forward and backward must have the same length" + assert len(self._cached_chunks_forward) == len( + self._cached_chunks_backward + ), "Cached chunks forward and backward must have the same length" for chunk in self._cached_chunks_forward: chunk.is_warmup = False - assert chunk in self._cached_chunks_backward, "Chunk not found in cached chunks backward" + assert ( + chunk in self._cached_chunks_backward + ), "Chunk not found in cached chunks backward" # Update the offload margin to the maximum number of deduplicated groups self._offload_margin = max(self._offload_margin, chunk.get_max_deduplicated_groups()) debug_rank(f"offload margin {self._offload_margin}") + # Fine the last group with the same name in the cached chunks backward last_group_with_same_name = {} for chunk_idx, chunk in enumerate(reversed(self._cached_chunks_backward)): for group in chunk.offload_groups: last_group_with_same_name[group._name] = group + # Mark the last group with the same name as not offloadable to make sure + # the reloading won't block the main stream. 
for name, group in last_group_with_same_name.items(): if self._offload_margin > 0: group.offload = False @@ -451,7 +456,8 @@ def post_warmup_callback(self): debug_rank(f"setting offload to false for group {name} at chunk index {chunk_idx}") else: break - + debug_rank(f"offload margin {self._offload_margin}") + assert self._offload_margin == 0, "Offload margin is not 0" def push(self, handler): """Add a chunk handler to the backward queue.""" @@ -462,42 +468,31 @@ def push(self, handler): def pop(self, name=None): """Remove and set the next non-empty chunk as the current backward chunk.""" - # assert self.size(), "Cannot pop from empty queue" - # while self._queue: - # self._cur_backward_chunk = self._queue.popleft() - # if not self._cur_backward_chunk.is_empty_chunk(): - # break - # debug_rank(f"popping handler {self._cur_backward_chunk}") self._cur_backward_chunk = None debug_rank(f"popping backward chunk {self._cached_chunks_index_backward}") debug_rank(f"cached chunks backward {self._cached_chunks_backward}") - for idx, handler in enumerate(self._cached_chunks_backward[self._cached_chunks_index_backward:]): + for idx, handler in enumerate( + self._cached_chunks_backward[self._cached_chunks_index_backward :] + ): self._cached_chunks_index_backward += 1 - if handler.is_empty_chunk(name): - debug_rank(f"handler {handler} at index {idx} is empty") - continue - self._cur_backward_chunk = handler # set the first non-empty chunk as the current backward chunk - break + if not handler.is_empty_chunk(name): + self._cur_backward_chunk = ( + handler # set the first non-empty chunk as the current backward chunk + ) + debug_rank(f"handler {handler} at index {idx} is not empty") + break assert self._cur_backward_chunk is not None, "No non-empty chunk found" - debug_rank(f"popped backward chunk {self._cur_backward_chunk} cached chunks index backward {self._cached_chunks_index_backward}") def front(self, name=None): """Get the first non-empty chunk handler without removing it 
from the queue.""" - # if not self.size(): - # return None - # for chunk_handler in self._queue: - # if not chunk_handler.is_empty_chunk(): - # return chunk_handler - for idx, handler in enumerate(self._cached_chunks_backward[self._cached_chunks_index_backward:]): + for idx, handler in enumerate( + self._cached_chunks_backward[self._cached_chunks_index_backward :] + ): if not handler.is_empty_chunk(name): debug_rank(f"front handler {handler} at index {idx}") return handler return None - def size(self): - """Return the number of chunk handlers in the queue.""" - return len(self._queue) - def init_model_chunk_offload_handler( self, vp_size, vp_stage, min_offloaded_tensor_size=1024 * 1024 ): @@ -511,7 +506,7 @@ def init_model_chunk_offload_handler( """ if not self._is_warmup: return - + vp_size = 1 if vp_size is None else vp_size if self._stages is None: self._vpp = vp_size @@ -522,41 +517,32 @@ def init_model_chunk_offload_handler( else: cur_vpp_rank = vp_stage - is_first_last_vpp_chunk = self._is_first_last_vpp_chunk # Flush staged chunks when reaching the last virtual pipeline stage if cur_vpp_rank == self._vpp - 1: self.flush() - # Determine if this is the first microbatch of the last virtual pipeline stage - is_first_last_vpp_chunk = is_first_last_vpp_chunk and (cur_vpp_rank == self._vpp - 1) # Use shared CPU tensor pool for better reuse across chunks - cur_chunk = ChunkOffloadHandler( - is_first_last_vpp_chunk, min_offloaded_tensor_size, self._cpu_tensor_pool - ) + cur_chunk = ChunkOffloadHandler(min_offloaded_tensor_size, self._cpu_tensor_pool) debug_rank(f"init_model_chunk_offload_handler {cur_chunk}") self._stages[cur_vpp_rank].append(cur_chunk) # For the last stage, push immediately and flush if cur_vpp_rank == self._vpp - 1: - if vp_size > 1: - self._is_first_last_vpp_chunk = False self.push(cur_chunk) self.flush() self._cur_forward_chunk = cur_chunk cur_chunk.vpp_rank = cur_vpp_rank self._cached_chunks_forward.append(cur_chunk) - def 
pop_forward_chunk(self, name=None): - """Get the current forward pass chunk handler.""" + """Get the next forward pass chunk handler.""" debug_rank(f"pop_forward_chunk {self._cur_forward_chunk}") if not self.do_offload: return self._cur_forward_chunk - while (not self._is_warmup - and (self._cur_forward_chunk is None or self._cur_forward_chunk.finish_all_groups(name))): + while not self._is_warmup and ( + self._cur_forward_chunk is None or self._cur_forward_chunk.finish_all_groups(name) + ): self._cur_forward_chunk = self._cached_chunks_forward[self._cached_chunks_index_forward] self._cached_chunks_index_forward += 1 - # if self._cached_chunks_index_forward == len(self._cached_chunks_forward): - # self._cached_chunks_index_forward = 0 debug_rank(f"new cur_forward_chunk {self._cur_forward_chunk}") return self._cur_forward_chunk @@ -571,7 +557,7 @@ def cur_backward_chunk(self): def __enter__(self): """Enter context manager to enable activation offloading hooks.""" debug_rank("----__enter__") - if not self.cur_forward_chunk().do_offload: + if self._cur_forward_chunk is None or not self.cur_forward_chunk().do_offload: return from megatron.core.extensions.transformer_engine import cpu_offload @@ -586,7 +572,7 @@ def __enter__(self): def __exit__(self, *args: Any): """Exit context manager and restore original tensor saving behavior.""" debug_rank("----__exit__") - if not self.cur_forward_chunk().do_offload: + if self._cur_forward_chunk is None or not self.cur_forward_chunk().do_offload: return from megatron.core.extensions.transformer_engine import cpu_offload @@ -629,7 +615,9 @@ def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): if use_cpu_pool: cpu_backup = self.cpu_tensor_pool.allocate(src_tensor.shape, dtype=src_tensor.dtype) else: - cpu_backup = torch.empty(src_tensor.shape, dtype=src_tensor.dtype, device="cpu", pin_memory=pin_memory) + cpu_backup = torch.empty( + src_tensor.shape, dtype=src_tensor.dtype, device="cpu", pin_memory=pin_memory + ) 
cpu_backup.copy_(src_tensor, non_blocking=pin_memory) state = (src_tensor.device, cpu_backup, use_cpu_pool) @@ -642,22 +630,19 @@ def reload(self, state, non_blocking=None): if non_blocking is None: non_blocking = cpu_backup.is_pinned() gpu_tensor = torch.empty( - cpu_backup.size(), - dtype=cpu_backup.dtype, - layout=cpu_backup.layout, - device=dev, + cpu_backup.size(), dtype=cpu_backup.dtype, layout=cpu_backup.layout, device=dev ) gpu_tensor.copy_(cpu_backup, non_blocking=non_blocking) if use_cpu_pool: self.cpu_tensor_pool.free(cpu_backup) return gpu_tensor - def __init__(self, is_first_last_vpp_chunk, min_offloaded_tensor_size, cpu_tensor_pool): + def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool): self.do_offload = True # Data Structure to maintain reference to activation tensors self._tensor_tag_to_state = {} # Mark the first microbatch of the last virtual pipeline stage - self._is_first_last_vpp_chunk = is_first_last_vpp_chunk + # self._is_first_last_vpp_chunk = is_first_last_vpp_chunk # Group management for batching offload/reload operations self._offloaded_group_index = 0 @@ -688,7 +673,7 @@ def reset(self): def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" - debug_rank(f"------is_empty_chunk {self} {self._max_group_size}") + debug_rank(f"------is_empty_chunk {self._max_group_size}") # return len(self._tensor_tag_to_state) == 0 if name is not None: for group in self.offload_groups: @@ -697,28 +682,29 @@ def is_empty_chunk(self, name=None): return False return True return self._max_group_size == 0 - + def finish_all_groups(self, name=None) -> bool: """Finish all groups.""" - debug_rank(f"------finish_all_groups {self} {self._max_group_size} {self._offloaded_group_index}") - #TODO: check if this is correct + debug_rank( + f"------finish_all_groups {self} {self._max_group_size} {self._offloaded_group_index}" + ) + # TODO: check if this is correct if len(self._groups_to_reload) == 0 and 
self._offloaded_group_index > 0: return True assert name is not None, "Name is required" - for group in self.offload_groups[self._offloaded_group_index:]: + for group in self.offload_groups[self._offloaded_group_index :]: if group._name == name: return False return True - + def find_next_group(self, name=None): """Find the next group with the given name.""" assert name is not None, "Name is required" - for group in self.offload_groups[self._offloaded_group_index:]: + for group in self.offload_groups[self._offloaded_group_index :]: if group._name == name: return group return None - def tensor_push(self, tensor): """Push tensor to the offload handler.""" torch_stray_tensor = isinstance( @@ -736,7 +722,7 @@ def tensor_push(self, tensor): self._tensor_count_current_group += 1 # assert tensor_tag not in self._tensor_tag_to_state, "Duplicate tensor tag" # self._tensor_tag_to_state[tensor_tag] = tensor - self.offload_groups[self._offloaded_group_index-1].push_tensor(tensor_tag, tensor) + self.offload_groups[self._offloaded_group_index - 1].push_tensor(tensor_tag, tensor) else: # Use negative group ID for special tensor types tensor_tag = (-1, self.torch_tensor_count) @@ -751,7 +737,7 @@ def tensor_pop(self, tensor_tag): # assert tensor_tag in self._tensor_tag_to_state, f"Tag {tensor_tag} not found" # tensor = self._tensor_tag_to_state.pop(tensor_tag) group_id, idx = tensor_tag - tensor = self.offload_groups[group_id-1].pop_tensor(tensor_tag) + tensor = self.offload_groups[group_id - 1].pop_tensor(tensor_tag) # If tensor is offloaded (stored as tuple), reload it # assert isinstance(tensor, torch.Tensor), "Tensor is not a tensor" if isinstance(tensor, tuple): @@ -768,37 +754,24 @@ def tensor_need_offloading_checker(self, tensor): return False return True - def bulk_offload_group(self, group_to_offload): + def bulk_offload_group(self): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") - group_id_to_offload, name = 
group_to_offload - offload_group = self.offload_groups[group_id_to_offload-1] - torch.cuda.nvtx.range_push("activation offloading " + name) - use_cpu_pool = True - if name == "expert_fc1" or name == "moe_act": - use_cpu_pool = False + group_to_offload = self._groups_to_offload[-1] + torch.cuda.nvtx.range_push("activation offloading " + group_to_offload._name) with torch.cuda.stream(self.d2h_stream): # for tensor_tag, state in self._tensor_tag_to_state.items(): - for tensor_tag, tensor_on_device in offload_group._tensors.items(): + for tensor_tag, tensor_on_device in group_to_offload._tensors.items(): if self.tensor_need_offloading_checker(tensor_on_device): - state = self.offload(tensor_on_device, use_cpu_pool=use_cpu_pool) - # event = torch.cuda.Event() - # event.record(self.d2h_stream) - # self._offload_events[name] = event + state = self.offload( + tensor_on_device, use_cpu_pool=group_to_offload.use_cpu_pool + ) tensor_on_device.record_stream(self.d2h_stream) - # self._tensor_tag_to_state[tensor_tag] = state - offload_group.push_tensor(tensor_tag, state) - offload_group.record_offload_event(self.d2h_stream) + group_to_offload.push_tensor(tensor_tag, state) + group_to_offload.record_offload_event(self.d2h_stream) + self._groups_to_offload.pop() torch.cuda.nvtx.range_pop() - # def get_offload_event(self, name): - # """Get the CUDA event for a named offload operation.""" - # return self._offload_events.get(name, None) - - # def get_reload_event(self, name): - # """Get the CUDA event for a named reload operation.""" - # return self._reload_events.get(name, None) - def get_max_deduplicated_groups(self): """Get the maximum number of deduplicated groups.""" count_modules = [] @@ -807,57 +780,47 @@ def get_max_deduplicated_groups(self): count_modules.append(group._name) return len(count_modules) - def bulk_reload_group(self, group_to_reload): + def bulk_reload_group(self): """Bulk reload group.""" debug_rank("----bulk_reload_group") - group_id_to_reload, name = 
group_to_reload - offload_group = self.offload_groups[group_id_to_reload-1] - found_reload_group = False - torch.cuda.nvtx.range_push("activation reloading " + name) + group_to_reload = self._groups_to_reload[-1] + torch.cuda.nvtx.range_push("activation reloading " + group_to_reload._name) with torch.cuda.stream(self.h2d_stream): - # event = self.get_offload_event(name) + # Wait for offload to complete before reloading if not is_graph_capturing(): - offload_group.wait_offload_event(self.h2d_stream) - for tensor_tag, state in offload_group._tensors.items(): - found_reload_group = True + group_to_reload.wait_offload_event(self.h2d_stream) + for tensor_tag, state in group_to_reload._tensors.items(): # Only reload if tensor was offloaded (stored as tuple) if isinstance(state, tuple): - # Wait for offload to complete before reloading - # if not is_graph_capturing(): - # torch.cuda.current_stream().wait_event(event) recovered_tensor = self.reload(state) - # event.record(self.h2d_stream) - # self._reload_events[name] = event debug_rank(f"----recovered_tensor {recovered_tensor.shape}") - # self._tensor_tag_to_state[tensor_tag] = recovered_tensor - offload_group.push_tensor(tensor_tag, recovered_tensor) - offload_group.record_reload_event(self.h2d_stream) + group_to_reload.push_tensor(tensor_tag, recovered_tensor) + group_to_reload.record_reload_event(self.h2d_stream) + self._groups_to_reload.pop() torch.cuda.nvtx.range_pop() - return found_reload_group def pre_reload_last_layer(self): """Pre-reload the last layer of this chunk to hide reload latency.""" debug_rank("pre_reload_last_layer") - # assert not self._is_first_last_vpp_chunk, "Should not pre-reload first chunk" debug_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}") if len(self._groups_to_reload) > 0: # Reload the last group (last layer) early - if self.bulk_reload_group(self._groups_to_reload[-1]): - self._groups_to_reload.pop() + self.bulk_reload_group() - def should_bulk_offload(self, 
group_to_offload): + def should_bulk_offload(self): """Determine if the current group should be offloaded.""" - # Don't offload the first backward chunk's last layer - group_id, name = group_to_offload - if not PipelineOffloadManager.get_instance()._is_warmup and not self.offload_groups[group_id-1].offload: + group = self._groups_to_offload[-1] + # Don't offload if the chunk is not in warmup stage + # and the group is marked as not offloadable + if not PipelineOffloadManager.get_instance()._is_warmup and not group.offload: return False # Check if next backward chunk is this chunk (for last pipeline stage) - next_backward_chunk = PipelineOffloadManager.get_instance().front(name=name) + next_backward_chunk = PipelineOffloadManager.get_instance().front(name=group._name) if next_backward_chunk is not None and next_backward_chunk is self: - # Don't offload last layer if it's about to be used immediately - if self.find_next_group(name) is None: - debug_rank(f"next group {name} is not found") + # Don't offload the last group with the same name if it's about to be used immediately + if self.find_next_group(group._name) is None: + debug_rank(f"next group {group._name} is not found") return False return True @@ -865,11 +828,9 @@ def should_bulk_offload(self, group_to_offload): def bulk_offload(self, forced_released_tensors): """Offload a group of tensors and optionally release their GPU memory.""" debug_rank("----bulk_offload") - group_to_offload = self._groups_to_offload[-1] - if self.should_bulk_offload(group_to_offload): - group_to_offload = self._groups_to_offload.pop() - self._groups_to_reload.append(group_to_offload) - self.bulk_offload_group(group_to_offload) + if self.should_bulk_offload(): + self._groups_to_reload.append(self._groups_to_offload[-1]) + self.bulk_offload_group() # Manually release tensors not auto-freed by torch GC if len(forced_released_tensors) > 0: cur_stream = torch.cuda.current_stream() @@ -893,9 +854,7 @@ def bulk_reload(self): 
debug_rank("--bulk_reload") if len(self._groups_to_reload) > 0: # Reload the next layer group - if self.bulk_reload_group(self._groups_to_reload[-1]): - debug_rank(f"--bulk_reload_group {self._groups_to_reload}") - self._groups_to_reload.pop() + self.bulk_reload_group() else: # Pre-load the last layer of the next backward chunk to hide latency next_backward_chunk = PipelineOffloadManager.get_instance().front() @@ -917,15 +876,9 @@ def on_group_commit_backward(self, name): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() assert cur_backward_chunk is self, f"Chunk mismatch {cur_backward_chunk} {self}" # Wait for reload to complete before using tensors - # event = self.get_reload_event(name) - # if event is not None and not is_graph_capturing(): - # torch.cuda.current_stream().wait_event(event) - if len(self._groups_to_reload) > 0: + if not is_graph_capturing() and len(self._groups_to_reload) > 0: group_to_reload = self._groups_to_reload[-1] - offload_group = self.offload_groups[group_to_reload[0]-1] - if not is_graph_capturing(): - offload_group.wait_reload_event(torch.cuda.current_stream()) - # self._offloaded_group_index = self._offloaded_group_index - 1 + group_to_reload.wait_reload_event(torch.cuda.current_stream()) def on_group_start_forward(self, name): """ @@ -942,13 +895,15 @@ def on_group_start_forward(self, name): debug_rank(f"max group size {self._max_group_size}") else: self._offloaded_group_index = self._offloaded_group_index + 1 - for group in self.offload_groups[self._offloaded_group_index-1:]: - debug_rank(f"offloaded group index {self._offloaded_group_index} for group {group._name}") + for group in self.offload_groups[self._offloaded_group_index - 1 :]: + debug_rank( + f"offloaded group index {self._offloaded_group_index} for group {group._name}" + ) if group._name == name: break self._offloaded_group_index = self._offloaded_group_index + 1 self._tensor_count_current_group = 0 - 
self._groups_to_offload.append((self._offloaded_group_index, name)) + self._groups_to_offload.append(self.offload_groups[self._offloaded_group_index - 1]) debug_rank(f"groups to offload {self._groups_to_offload}") def on_group_start_backward(self): @@ -963,16 +918,19 @@ def on_group_start_backward(self): self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() + def fine_grained_offloading_disable_offload(): """Disable the offload.""" debug_rank("fine_grained_offloading_disable_offload") PipelineOffloadManager.get_instance().disable_offload() + def fine_grained_offloading_enable_offload(): """Enable the offload.""" debug_rank("fine_grained_offloading_enable_offload") PipelineOffloadManager.get_instance().enable_offload() + class FineGrainedOffloadingGroupCommitFunction(torch.autograd.Function): """ Identity operation that marks the end of a layer group for offload synchronization. @@ -1013,6 +971,8 @@ def fine_grained_offloading_group_commit(*tensor, name, forced_released_tensors= Note: specify the tensors only when they are not automatically released by torch gc. 
""" cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + if cur_forward_chunk is None: + return tensor return FineGrainedOffloadingGroupCommitFunction.apply( *tensor, cur_forward_chunk, name, forced_released_tensors ) @@ -1046,6 +1006,8 @@ def backward(ctx, grad_output): def fine_grained_offloading_group_start(tensor, name=None): """Mark the start of a layer group and prepare for offload/reload.""" cur_forward_chunk = PipelineOffloadManager.get_instance().pop_forward_chunk(name=name) + if cur_forward_chunk is None: + return tensor return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) @@ -1060,15 +1022,19 @@ def fine_grained_offloading_init_chunk_handler(vp_size, vp_stage, min_offloaded_ vp_size, vp_stage, min_offloaded_tensor_size ) + def fine_grained_offloading_reset(): """Reset the chunk handler, called at the start of a training iteration.""" PipelineOffloadManager.get_instance().reset() + def fine_grained_offloading_forward_record(event: torch.cuda.Event) -> None: + """Record the forward event for cuda graph capture.""" d2h_stream = PipelineOffloadManager.get_instance().d2h_stream torch.cuda.current_stream().record_event(event) torch.cuda.current_stream().wait_stream(d2h_stream) + class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): """ Identity operation that marks the end of a layer group for offload synchronization. 
@@ -1077,17 +1043,19 @@ class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): @staticmethod def forward(ctx, tensor, event: torch.cuda.Event) -> torch.Tensor: + """Forward pass for cuda graph capture.""" ctx.event = event return tensor - + @staticmethod def backward(ctx, grad_output): + """Record the backward event and wait for the h2d stream on cuda graph stream.""" h2d_stream = PipelineOffloadManager.get_instance().h2d_stream torch.cuda.current_stream().record_event(ctx.event) torch.cuda.current_stream().wait_stream(h2d_stream) return grad_output, None + def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: + """Record the backward event for cuda graph capture.""" return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) - - \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 6ee29b38bd1..a85ef572014 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -425,7 +425,7 @@ def forward_step( return [output_tensor], num_tokens -def backward_step(model, input_tensor, output_tensor, output_tensor_grad, model_type, config): +def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config): """Backward step through passed-in output tensor. 
If last stage, output_tensor_grad is None, otherwise gradient of loss @@ -565,9 +565,6 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - if not forward_only and config.fine_grained_activation_offloading: - fine_grained_offloading_reset() - no_sync_func = config.no_sync_func if no_sync_func is None: no_sync_func = contextlib.nullcontext @@ -614,7 +611,7 @@ def forward_backward_no_pipelining( total_num_tokens += num_tokens if not forward_only: backward_step( - model, input_tensor, output_tensor, output_tensor_grad, model_type, config + input_tensor, output_tensor, output_tensor_grad, model_type, config ) # Run computation for last microbatch out of context handler (want to # synchronize gradients). @@ -637,7 +634,7 @@ def forward_backward_no_pipelining( total_num_tokens += num_tokens if not forward_only: - backward_step(model, input_tensor, output_tensor, output_tensor_grad, model_type, config) + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) if config.finalize_model_grads_func is not None and not forward_only: # Finalize model grads (perform full grad all-reduce / reduce-scatter for @@ -648,6 +645,9 @@ def forward_backward_no_pipelining( pg_collection=pg_collection, ) + if not forward_only and config.fine_grained_activation_offloading: + fine_grained_offloading_reset() + if config.timers is not None: config.timers('forward-backward').stop() @@ -904,9 +904,6 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" - if not forward_only and config.fine_grained_activation_offloading: - fine_grained_offloading_reset() - if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -1289,7 +1286,7 @@ def backward_step_helper(virtual_microbatch_id): ) 
input_tensor_grad = backward_step( - None, input_tensor, output_tensor, output_tensor_grad, model_type, config + input_tensor, output_tensor, output_tensor_grad, model_type, config ) backward_step_helper_postprocess(virtual_microbatch_id) @@ -1911,6 +1908,8 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): pg_collection=pg_collection, ) + if not forward_only and config.fine_grained_activation_offloading: + fine_grained_offloading_reset() # Restore config.grad_sync_func and config.param_sync_func. if forward_only: config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func @@ -2052,9 +2051,6 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - if not forward_only and config.fine_grained_activation_offloading: - fine_grained_offloading_reset() - # Disable async grad reductions no_sync_func = config.no_sync_func if no_sync_func is None: @@ -2236,7 +2232,7 @@ def enable_grad_sync(): enable_grad_sync() input_tensor_grad = backward_step( - None, input_tensor, output_tensor, output_tensor_grad, model_type, config + input_tensor, output_tensor, output_tensor_grad, model_type, config ) if last_iteration: @@ -2272,7 +2268,7 @@ def enable_grad_sync(): ) input_tensor_grad = backward_step( - None, input_tensor, output_tensor, output_tensor_grad, model_type, config + input_tensor, output_tensor, output_tensor_grad, model_type, config ) p2p_communicator.send_backward( @@ -2302,6 +2298,9 @@ def enable_grad_sync(): pg_collection=pg_collection, ) + if not forward_only and config.fine_grained_activation_offloading: + fine_grained_offloading_reset() + if config.timers is not None: config.timers('forward-backward').stop() diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index ca01658fa51..6a8090854ec 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ 
b/megatron/core/transformer/cuda_graphs.py @@ -1630,24 +1630,16 @@ def _get_fp8_enabled(): ) else: kwargs['fp8_enabled'] = False - + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_disable_offload, fine_grained_offloading_enable_offload, - fine_grained_offloading_init_chunk_handler, - fine_grained_offloading_reset, ) - from functools import partial + # if self.config.offload_module_in_cuda_graph: if self.config.fine_grained_activation_offloading: kwargs['pre_warmup_hook'] = fine_grained_offloading_disable_offload kwargs['post_warmup_hook'] = fine_grained_offloading_enable_offload - kwargs['init_chunk_handler'] = partial( - fine_grained_offloading_init_chunk_handler, - vp_size=self.config.virtual_pipeline_model_parallel_size, - min_offloaded_tensor_size=self.config.min_offloaded_tensor_size - ) - kwargs['reset_hook'] = fine_grained_offloading_reset return kwargs kwargs = get_make_graphed_callables_kwargs() @@ -1680,8 +1672,13 @@ def _finish_capturing(self, start_time): _set_capture_end() from megatron.core.distributed.finalize_model_grads import reset_model_temporary_tensors + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_reset, + ) from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker + fine_grained_offloading_reset() + torch.distributed.barrier() for model_chunk in self.model: model_chunk.zero_grad_buffer() diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 701a2bd4c8c..96b5ded5991 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -282,6 +282,7 @@ def _get_te_cuda_graph_replay_args(self, *args, **kwargs): cudagraph_kwargs = kwargs.copy() cudagraph_kwargs['is_first_microbatch'] = getattr(self, 'current_microbatch', 0) == 0 from megatron.core.transformer.transformer_layer import TransformerLayer + cudagraph_kwargs['cuda_graph_stream'] = 
TransformerLayer.cuda_graph_stream cudagraph_kwargs['cuda_graph_event'] = TransformerLayer.cuda_graph_event return cudagraph_args, cudagraph_kwargs diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 9f7afe7ccd4..badaa9032a5 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1189,13 +1189,19 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) - if self.enable_cuda_graph or self.cuda_graph_impl == "local": - raise ValueError("Fine-grained activation offloading does not support local implementation of CUDA graph.") if self.external_cuda_graph or self.cuda_graph_impl == "transformer_engine": - assert self.cuda_graph_scope is not None, "cuda_graph_scope must be set when enabling offloading." - assert "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope, "attn and moe_router must be in cuda_graph_scope when enabling offloading." - assert "attn_norm" not in self.offload_modules, "input of attn_norm is exactly the entry point of cuda graph, which cannot be offloaded." - assert "mlp_norm" not in self.offload_modules, "offloading mlp_norm goes through the boundary of the cuda graph, which cannot be offloaded." + assert ( + self.cuda_graph_scope is not None + ), "cuda_graph_scope must be set when enabling offloading." + assert ( + "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope + ), "attn and moe_router must be in cuda_graph_scope when enabling offloading." + assert ( + "attn_norm" not in self.offload_modules + ), "input of attn_norm is the start point of cuda graph, which can't be offloaded." + assert ( + "mlp_norm" not in self.offload_modules + ), "mlp_norm goes through the boundary of cuda graph, which can't be offloaded." 
if ( self.num_layers_in_first_pipeline_stage is not None diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index be19a728566..7dd4550b0db 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -259,8 +259,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): Transformer layer takes input with size [s, b, h] and returns an output of the same size. """ + cuda_graph_stream = None cuda_graph_event = None + def __init__( self, config: TransformerConfig, @@ -277,7 +279,6 @@ def __init__( self.pg_collection = pg_collection self.tp_group = pg_collection.tp - self.submodules_config = submodules self.layer_number = layer_number + get_transformer_layer_offset( self.config, vp_stage, get_pg_rank(pg_collection.pp) @@ -502,8 +503,13 @@ def _forward_attention( ) if self.offload_module_in_cuda_graph: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_backward_record - hidden_states = fine_grained_offloading_backward_record(hidden_states, TransformerLayer.cuda_graph_event) + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_backward_record, + ) + + hidden_states = fine_grained_offloading_backward_record( + hidden_states, TransformerLayer.cuda_graph_event + ) inference_context = deprecate_inference_params(inference_context, inference_params) @@ -598,9 +604,9 @@ def _forward_mlp(self, hidden_states, inference_context=None): """ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_group_commit, fine_grained_offloading_group_start, get_fine_grained_offloading_context, - fine_grained_offloading_group_commit, ) # Residual connection. 
@@ -671,7 +677,9 @@ def _forward_mlp(self, hidden_states, inference_context=None): mlp_output_with_bias = (mlp_output, bias_output) else: if self.offload_modules["dense_mlp"]: - pre_mlp_layernorm_output = fine_grained_offloading_group_start(pre_mlp_layernorm_output, name="dense_mlp") + pre_mlp_layernorm_output = fine_grained_offloading_group_start( + pre_mlp_layernorm_output, name="dense_mlp" + ) with get_fine_grained_offloading_context(self.offload_modules["dense_mlp"]): mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) if self.offload_modules["dense_mlp"]: @@ -841,7 +849,10 @@ def _te_cuda_graph_capture(self, *args, **kwargs): if context is not None: cuda_graph_outputs.append(context) if self.offload_module_in_cuda_graph: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import fine_grained_offloading_forward_record + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_forward_record, + ) + fine_grained_offloading_forward_record(TransformerLayer.cuda_graph_event) return tuple(cuda_graph_outputs) @@ -1025,7 +1036,7 @@ def __call__(self, *args, **kwargs): 'inference_context' ].is_decode_only() return super().__call__(*args, **kwargs) - + def _set_offload_modules(self): """Set the offload modules for the transformer layer.""" self.offload_modules = { @@ -1039,7 +1050,9 @@ def _set_offload_modules(self): "dense_mlp": False, } if self.config.fine_grained_activation_offloading: - if "attn_norm" in self.config.offload_modules and not isinstance(self.input_layernorm, IdentityOp): + if "attn_norm" in self.config.offload_modules and not isinstance( + self.input_layernorm, IdentityOp + ): self.offload_modules["attn_norm"] = True if "qkv_linear" in self.config.offload_modules: self.offload_modules["qkv_linear"] = True @@ -1047,7 +1060,9 @@ def _set_offload_modules(self): self.offload_modules["core_attn"] = True if "attn_proj" in self.config.offload_modules: self.offload_modules["attn_proj"] = 
True - if "mlp_norm" in self.config.offload_modules and not isinstance(self.pre_mlp_layernorm, IdentityOp): + if "mlp_norm" in self.config.offload_modules and not isinstance( + self.pre_mlp_layernorm, IdentityOp + ): self.offload_modules["mlp_norm"] = True if "expert_fc1" in self.config.offload_modules: self.offload_modules["expert_fc1"] = True @@ -1058,14 +1073,22 @@ def _set_offload_modules(self): # Set the offload module in cuda graph flag. self.offload_module_in_cuda_graph = False if "attn" in self.config.cuda_graph_scope: - if self.offload_modules["core_attn"] or self.offload_modules["attn_proj"] or self.offload_modules["qkv_linear"]: + if ( + self.offload_modules["core_attn"] + or self.offload_modules["attn_proj"] + or self.offload_modules["qkv_linear"] + ): self.offload_module_in_cuda_graph = True - if (not self.is_moe_layer and 'mlp' in self.config.cuda_graph_scope): + if not self.is_moe_layer and 'mlp' in self.config.cuda_graph_scope: if self.offload_modules["mlp_norm"] or self.offload_modules["dense_mlp"]: self.offload_module_in_cuda_graph = True if self.offload_module_in_cuda_graph: - assert is_torch_min_version("2.9.0a0"), "Fine-grained activation offloading needs torch>=2.9.0 to support cuda graph." - assert self.config.cuda_graph_warmup_steps > 0, "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." + assert is_torch_min_version( + "2.9.0a0" + ), "Fine-grained activation offloading needs torch>=2.9.0 to support cuda graph." + assert ( + self.config.cuda_graph_warmup_steps > 0 + ), "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." # Set the cuda graph stream and event for the transformer layer. 
if TransformerLayer.cuda_graph_stream is None: if self.offload_module_in_cuda_graph: From 12cb8de829156c10dc4efff8010606973b75ae14 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 05:57:00 -0800 Subject: [PATCH 17/74] minor fix Signed-off-by: Hongbin Liu --- megatron/core/transformer/transformer_config.py | 4 +++- megatron/core/transformer/transformer_layer.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index bfa642acba0..9fbaa3a6940 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1248,7 +1248,9 @@ def __post_init__(self): self.cuda_graph_scope is not None ), "cuda_graph_scope must be set when enabling offloading." assert ( - "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope + ("attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope) + or (CudaGraphScope.attn in self.cuda_graph_scope + and CudaGraphScope.moe_router in self.cuda_graph_scope) ), "attn and moe_router must be in cuda_graph_scope when enabling offloading." assert ( "attn_norm" not in self.offload_modules diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 6c7a40586d3..d2b7c585a38 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1077,14 +1077,14 @@ def _set_offload_modules(self): self.offload_modules["dense_mlp"] = True # Set the offload module in cuda graph flag. 
self.offload_module_in_cuda_graph = False - if "attn" in self.config.cuda_graph_scope: + if CudaGraphScope.attn in self.config.cuda_graph_scope: if ( self.offload_modules["core_attn"] or self.offload_modules["attn_proj"] or self.offload_modules["qkv_linear"] ): self.offload_module_in_cuda_graph = True - if not self.is_moe_layer and 'mlp' in self.config.cuda_graph_scope: + if not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope: if self.offload_modules["mlp_norm"] or self.offload_modules["dense_mlp"]: self.offload_module_in_cuda_graph = True if self.offload_module_in_cuda_graph: From 3cf19b7c517023ac07bea50a9c3971b3f430b9d6 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 06:03:40 -0800 Subject: [PATCH 18/74] minor fix Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 29 ------------------- megatron/core/pipeline_parallel/utils.py | 14 +++------ .../transformer/multi_latent_attention.py | 2 +- .../core/transformer/transformer_config.py | 7 +++-- megatron/training/training.py | 2 +- 5 files changed, 10 insertions(+), 44 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 32d0d2f0c1d..69047637322 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -6,8 +6,6 @@ import torch -from megatron.core.pipeline_parallel.utils import set_ideal_affinity_for_current_gpu - # CPU offload implementation for pipeline parallelism DEBUG = False DEBUG_RANK = 0 @@ -301,33 +299,6 @@ def wait_reload_event(self, stream): stream.wait_event(self._reload_event) -def set_ideal_affinity_for_current_gpu(): - """Set CPU affinity for the current GPU to optimize host-device transfers.""" - import uuid - - try: - import cuda.bindings.driver as cuda_driver - import cuda.bindings.runtime as cuda_runtime - except: - try: - 
import cuda.cuda as cuda_driver - import cuda.cudart as cuda_runtime - except: - raise RuntimeError("Please install cuda-python to enable GPU affinity setting") - import pynvml - - # Get current CUDA device ID - err, device_id = cuda_runtime.cudaGetDevice() - assert err == cuda_runtime.cudaError_t.cudaSuccess - # Get device UUID - err, device_uuid = cuda_driver.cuDeviceGetUuid(device_id) - assert err == cuda_driver.CUresult.CUDA_SUCCESS - # Set CPU affinity based on GPU's NUMA node - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes))) - pynvml.nvmlDeviceSetCpuAffinity(handle) - - class PipelineOffloadManager: """ Singleton manager for coordinating activation offloading across pipeline stages. diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py index c50c6ac7964..814de85ae70 100644 --- a/megatron/core/pipeline_parallel/utils.py +++ b/megatron/core/pipeline_parallel/utils.py @@ -87,19 +87,13 @@ def set_ideal_affinity_for_current_gpu(): try: import cuda.bindings.driver as cuda_driver import cuda.bindings.runtime as cuda_runtime - except ImportError: + except: try: import cuda.cuda as cuda_driver import cuda.cudart as cuda_runtime - except ImportError: - # print("cuda-python may not be installed, skipping GPU affinity setting") - warnings.warn("cuda-python may not be installed, skipping GPU affinity setting") - return - try: - import pynvml - except ImportError: - warnings.warn("pynvml is not installed, skipping GPU affinity setting") - return + except: + raise RuntimeError("Please install cuda-python to enable GPU affinity setting") + import pynvml # Get current CUDA device ID err, device_id = cuda_runtime.cudaGetDevice() diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 57438603e3f..d0b588bc94a 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ 
b/megatron/core/transformer/multi_latent_attention.py @@ -268,7 +268,7 @@ def forward( raise ValueError( f"Unsupported experimental attention variant: " f"{self.config.experimental_attention_variant}" - ) + ) if self.offload_qkv_linear: (query, key, value) = fine_grained_offloading_group_commit( query, key, value, name="qkv_linear", forced_released_tensors=[hidden_states] diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 9fbaa3a6940..2846b08ae12 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1248,9 +1248,10 @@ def __post_init__(self): self.cuda_graph_scope is not None ), "cuda_graph_scope must be set when enabling offloading." assert ( - ("attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope) - or (CudaGraphScope.attn in self.cuda_graph_scope - and CudaGraphScope.moe_router in self.cuda_graph_scope) + "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope + ) or ( + CudaGraphScope.attn in self.cuda_graph_scope + and CudaGraphScope.moe_router in self.cuda_graph_scope ), "attn and moe_router must be in cuda_graph_scope when enabling offloading." 
assert ( "attn_norm" not in self.offload_modules diff --git a/megatron/training/training.py b/megatron/training/training.py index de8dfd498ba..009305bbf59 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -696,7 +696,7 @@ def pretrain( timers = get_timers() if args.fine_grained_activation_offloading: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + from megatron.core.pipeline_parallel.utils import ( set_ideal_affinity_for_current_gpu ) set_ideal_affinity_for_current_gpu() From d0fc888f0c2ab835cd37724a32fee4cbbb132b97 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 06:26:47 -0800 Subject: [PATCH 19/74] dump offloading information Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 69047637322..01d9128e14d 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -268,7 +268,8 @@ def __init__(self, name): self._offload_event = torch.cuda.Event() self._reload_event = torch.cuda.Event() self.offload = True - + self.total_offload_bytes = 0 + self.total_tensor_count = 0 if name == "expert_fc1" or name == "moe_act": self.use_cpu_pool = False else: @@ -298,6 +299,11 @@ def wait_reload_event(self, stream): """Wait for the reload event.""" stream.wait_event(self._reload_event) + def update_offload_info(self, tensor): + """Update the offload information.""" + self.total_offload_bytes += tensor.numel() * tensor.element_size() + self.total_tensor_count += 1 + class PipelineOffloadManager: """ @@ -402,6 +408,7 @@ def enable_offload(self): def post_warmup_callback(self): """Callback after warmup.""" + # pylint: disable=bad-builtin 
debug_rank("post_warmup_callback") self._is_warmup = False assert len(self._cached_chunks_forward) == len( @@ -431,6 +438,24 @@ def post_warmup_callback(self): break debug_rank(f"offload margin {self._offload_margin}") assert self._offload_margin == 0, "Offload margin is not 0" + # Dump the offload information + total_tensor_count = {} + total_offload_bytes = {} + for chunk in self._cached_chunks_backward: + for group in chunk.offload_groups: + if group.offload: + if group._name not in total_tensor_count: + total_tensor_count[group._name] = 0 + total_tensor_count[group._name] += group.total_tensor_count + if group._name not in total_offload_bytes: + total_offload_bytes[group._name] = 0 + total_offload_bytes[group._name] += group.total_offload_bytes + assert torch.distributed.is_initialized() + rank = torch.distributed.get_rank() + for name, tensor_count in total_tensor_count.items(): + print(f"rank {rank} total offloaded tensor count for group {name} {tensor_count}") + for name, offload_bytes in total_offload_bytes.items(): + print(f"rank {rank} total offloaded bytes for group {name} {offload_bytes}") def push(self, handler): """Add a chunk handler to the backward queue.""" @@ -632,8 +657,6 @@ def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool): self.torch_tensor_count = 0 self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream - # self._offload_events = {} - # self._reload_events = {} self.min_offloaded_tensor_size = min_offloaded_tensor_size self.cpu_tensor_pool = cpu_tensor_pool self.offload_groups = [] @@ -645,13 +668,10 @@ def reset(self): self._groups_to_offload = [] self._groups_to_reload = [] self._tensor_count_current_group = 0 - # self._offload_events = {} - # self._reload_events = {} def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" debug_rank(f"------is_empty_chunk {self._max_group_size}") - # return 
len(self._tensor_tag_to_state) == 0 if name is not None: for group in self.offload_groups: debug_rank(f"group name {group._name} need name {name}") @@ -743,6 +763,8 @@ def bulk_offload_group(self): state = self.offload( tensor_on_device, use_cpu_pool=group_to_offload.use_cpu_pool ) + if self.is_warmup: + group_to_offload.update_offload_info(tensor_on_device) tensor_on_device.record_stream(self.d2h_stream) group_to_offload.push_tensor(tensor_tag, state) group_to_offload.record_offload_event(self.d2h_stream) From b7c0fbacc5279204a002c24bec7f96b84ec5ccf5 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 07:07:41 -0800 Subject: [PATCH 20/74] fix ut Signed-off-by: Hongbin Liu --- .../core/transformer/transformer_layer.py | 76 ++++++++----------- ...test_fine_grained_activation_offloading.py | 1 - 2 files changed, 31 insertions(+), 46 deletions(-) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index d2b7c585a38..603a8a4c89f 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -519,17 +519,17 @@ def _forward_attention( # Residual connection. residual = hidden_states - if self.offload_modules["attn_norm"]: + if self.offload_attn_norm: hidden_states = fine_grained_offloading_group_start(hidden_states, name="attn_norm") # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with get_fine_grained_offloading_context(self.offload_modules["attn_norm"]): + with get_fine_grained_offloading_context(self.offload_attn_norm): input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( self.input_layernorm, hidden_states ) else: - with get_fine_grained_offloading_context(self.offload_modules["attn_norm"]): + with get_fine_grained_offloading_context(self.offload_attn_norm): input_layernorm_output = self.input_layernorm(hidden_states) # Self attention. 
@@ -564,7 +564,7 @@ def _forward_attention( ) nvtx_range_pop(suffix="self_attn_bda") - if self.offload_modules["attn_norm"]: + if self.offload_attn_norm: (hidden_states,) = fine_grained_offloading_group_commit( hidden_states, name="attn_norm", forced_released_tensors=[residual] ) @@ -615,17 +615,17 @@ def _forward_mlp(self, hidden_states, inference_context=None): # Residual connection. residual = hidden_states - if self.offload_modules["mlp_norm"]: + if self.offload_mlp_norm: hidden_states = fine_grained_offloading_group_start(hidden_states, name="mlp_norm") # Optional Layer norm post the cross-attention. if self.recompute_pre_mlp_layernorm: self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with get_fine_grained_offloading_context(self.offload_modules["mlp_norm"]): + with get_fine_grained_offloading_context(self.offload_mlp_norm): pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( self.pre_mlp_layernorm, hidden_states ) else: - with get_fine_grained_offloading_context(self.offload_modules["mlp_norm"]): + with get_fine_grained_offloading_context(self.offload_mlp_norm): pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) nvtx_range_push(suffix="mlp") @@ -680,13 +680,13 @@ def _forward_mlp(self, hidden_states, inference_context=None): bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None mlp_output_with_bias = (mlp_output, bias_output) else: - if self.offload_modules["dense_mlp"]: + if self.offload_dense_mlp: pre_mlp_layernorm_output = fine_grained_offloading_group_start( pre_mlp_layernorm_output, name="dense_mlp" ) - with get_fine_grained_offloading_context(self.offload_modules["dense_mlp"]): + with get_fine_grained_offloading_context(self.offload_dense_mlp): mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - if self.offload_modules["dense_mlp"]: + if self.offload_dense_mlp: (mlp_output,) = fine_grained_offloading_group_commit( mlp_output_with_bias[0], name="dense_mlp", 
forced_released_tensors=[] ) @@ -728,7 +728,7 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") - if self.offload_modules["mlp_norm"]: + if self.offload_mlp_norm: (hidden_states,) = fine_grained_offloading_group_commit( hidden_states, name="mlp_norm", forced_released_tensors=[residual] ) @@ -1044,48 +1044,34 @@ def __call__(self, *args, **kwargs): def _set_offload_modules(self): """Set the offload modules for the transformer layer.""" - self.offload_modules = { - "attn_norm": False, - "qkv_linear": False, - "core_attn": False, - "attn_proj": False, - "mlp_norm": False, - "expert_fc1": False, - "moe_act": False, - "dense_mlp": False, - } if self.config.fine_grained_activation_offloading: - if "attn_norm" in self.config.offload_modules and not isinstance( - self.input_layernorm, IdentityOp - ): - self.offload_modules["attn_norm"] = True - if "qkv_linear" in self.config.offload_modules: - self.offload_modules["qkv_linear"] = True - if "core_attn" in self.config.offload_modules: - self.offload_modules["core_attn"] = True - if "attn_proj" in self.config.offload_modules: - self.offload_modules["attn_proj"] = True - if "mlp_norm" in self.config.offload_modules and not isinstance( - self.pre_mlp_layernorm, IdentityOp - ): - self.offload_modules["mlp_norm"] = True - if "expert_fc1" in self.config.offload_modules: - self.offload_modules["expert_fc1"] = True - if "moe_act" in self.config.offload_modules: - self.offload_modules["moe_act"] = True - if "dense_mlp" in self.config.offload_modules and not self.is_moe_layer: - self.offload_modules["dense_mlp"] = True + self.offload_attn_norm = ( + "attn_norm" in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp) + ) + self.offload_qkv_linear = "qkv_linear" in self.config.offload_modules + self.offload_core_attn = "core_attn" in self.config.offload_modules + self.offload_attn_proj = "attn_proj" in 
self.config.offload_modules + self.offload_mlp_norm = ( + "mlp_norm" in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp) + ) + self.offload_expert_fc1 = "expert_fc1" in self.config.offload_modules + self.offload_moe_act = "moe_act" in self.config.offload_modules + self.offload_dense_mlp = ( + "dense_mlp" in self.config.offload_modules and not self.is_moe_layer + ) # Set the offload module in cuda graph flag. self.offload_module_in_cuda_graph = False if CudaGraphScope.attn in self.config.cuda_graph_scope: if ( - self.offload_modules["core_attn"] - or self.offload_modules["attn_proj"] - or self.offload_modules["qkv_linear"] + self.offload_core_attn + or self.offload_attn_proj + or self.offload_qkv_linear ): self.offload_module_in_cuda_graph = True if not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope: - if self.offload_modules["mlp_norm"] or self.offload_modules["dense_mlp"]: + if self.offload_mlp_norm or self.offload_dense_mlp: self.offload_module_in_cuda_graph = True if self.offload_module_in_cuda_graph: assert is_torch_min_version( diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 7c1b7f1fe4b..d43839074f4 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -49,7 +49,6 @@ def forward(self, x, use_offload: bool = False): ) for i, layer in enumerate(self.net): # Group by module; with this linear-only model, each group corresponds to a layer. 
- off.fine_grained_offloading_set_last_layer(i == len(self.net) - 1) x = off.fine_grained_offloading_group_start(x, name=f"layer_{i}") x = layer(x) # Commit the group; returns a tuple of tensors From ae4e2b52c851ed3b70e9c5863064dd1fae65860a Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 07:12:36 -0800 Subject: [PATCH 21/74] format Signed-off-by: Hongbin Liu --- megatron/core/transformer/transformer_layer.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 603a8a4c89f..0aad9edc01e 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1045,16 +1045,14 @@ def __call__(self, *args, **kwargs): def _set_offload_modules(self): """Set the offload modules for the transformer layer.""" if self.config.fine_grained_activation_offloading: - self.offload_attn_norm = ( - "attn_norm" in self.config.offload_modules - and not isinstance(self.input_layernorm, IdentityOp) + self.offload_attn_norm = "attn_norm" in self.config.offload_modules and not isinstance( + self.input_layernorm, IdentityOp ) self.offload_qkv_linear = "qkv_linear" in self.config.offload_modules self.offload_core_attn = "core_attn" in self.config.offload_modules self.offload_attn_proj = "attn_proj" in self.config.offload_modules - self.offload_mlp_norm = ( - "mlp_norm" in self.config.offload_modules - and not isinstance(self.pre_mlp_layernorm, IdentityOp) + self.offload_mlp_norm = "mlp_norm" in self.config.offload_modules and not isinstance( + self.pre_mlp_layernorm, IdentityOp ) self.offload_expert_fc1 = "expert_fc1" in self.config.offload_modules self.offload_moe_act = "moe_act" in self.config.offload_modules @@ -1064,11 +1062,7 @@ def _set_offload_modules(self): # Set the offload module in cuda graph flag. 
self.offload_module_in_cuda_graph = False if CudaGraphScope.attn in self.config.cuda_graph_scope: - if ( - self.offload_core_attn - or self.offload_attn_proj - or self.offload_qkv_linear - ): + if self.offload_core_attn or self.offload_attn_proj or self.offload_qkv_linear: self.offload_module_in_cuda_graph = True if not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope: if self.offload_mlp_norm or self.offload_dense_mlp: From b797438deb666aaf7d6b745c26c5ea8239cbe744 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 07:33:46 -0800 Subject: [PATCH 22/74] fit ut Signed-off-by: Hongbin Liu --- megatron/core/transformer/transformer_layer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 0aad9edc01e..da552fa37f4 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1059,6 +1059,15 @@ def _set_offload_modules(self): self.offload_dense_mlp = ( "dense_mlp" in self.config.offload_modules and not self.is_moe_layer ) + else: + self.offload_attn_norm = False + self.offload_qkv_linear = False + self.offload_core_attn = False + self.offload_attn_proj = False + self.offload_mlp_norm = False + self.offload_expert_fc1 = False + self.offload_moe_act = False + self.offload_dense_mlp = False # Set the offload module in cuda graph flag. 
self.offload_module_in_cuda_graph = False if CudaGraphScope.attn in self.config.cuda_graph_scope: From 6cec22f6bfeb67a448f3e45c158a8f36efb33bd5 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 17 Dec 2025 00:00:46 -0800 Subject: [PATCH 23/74] delay d2h copies until finishing cuda graph Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 43 ++++++++++++++----- megatron/core/pipeline_parallel/utils.py | 1 + megatron/core/transformer/moe/experts.py | 4 +- .../core/transformer/transformer_config.py | 3 ++ .../core/transformer/transformer_layer.py | 11 +++++ megatron/training/arguments.py | 2 + 6 files changed, 53 insertions(+), 11 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 01d9128e14d..453066b33df 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -343,6 +343,7 @@ def __init__(self): # Margin to avoid offloading too many groups so that self._offload_margin = 0 + self._delayed_offload_groups = [] self.reset() @property @@ -360,6 +361,18 @@ def cpu_tensor_pool(self): """Get the shared CPU tensor pool.""" return self._cpu_tensor_pool + def push_offload_groups(self, group_hook, forced_released_tensors): + """Push the offload groups to the delayed queue.""" + debug_rank(f"pushing offload groups to the delayed queue") + self._delayed_offload_groups.append((group_hook, forced_released_tensors)) + + def flush_delayed_groups(self): + """Flush the delayed groups.""" + debug_rank("flushing delayed groups") + for group_hook, forced_released_tensors in self._delayed_offload_groups: + group_hook(forced_released_tensors) + self._delayed_offload_groups = [] + def reset(self): """Reset manager state for a new training iteration.""" self._inside_context = False @@ -376,6 +389,7 @@ def reset(self): for chunk in 
self._cached_chunks_forward: chunk.reset() + self._delayed_offload_groups = [] def flush(self): """Flush all staged chunks to the backward queue in reverse order.""" @@ -941,15 +955,20 @@ def forward(ctx, *args): # pylint: disable=missing-function-docstring debug_rank("FineGrainedOffloadingGroupCommitFunction forward") - forced_released_tensors = args[-1] - name = args[-2] - cpu_offload_handler = args[-3] - tensor = args[:-3] - cpu_offload_handler.on_group_commit_forward(forced_released_tensors) + delay_offload = args[-1] + forced_released_tensors = args[-2] + name = args[-3] + cpu_offload_handler = args[-4] + tensor = args[:-4] + if delay_offload: + PipelineOffloadManager.get_instance().push_offload_groups( + cpu_offload_handler.on_group_commit_forward, + forced_released_tensors + ) + else: + cpu_offload_handler.on_group_commit_forward(forced_released_tensors) ctx.cpu_offload_handler = cpu_offload_handler ctx.name = name - - # return the identical tensor return tensor @staticmethod @@ -959,10 +978,10 @@ def backward(ctx, *grad_output): cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler.on_group_commit_backward(ctx.name) - return grad_output + (None, None, None) + return grad_output + (None, None, None, None) -def fine_grained_offloading_group_commit(*tensor, name, forced_released_tensors=[]): +def fine_grained_offloading_group_commit(*tensor, name, forced_released_tensors=[], delay_offload=False): """ Specify the tensors to be released after offloading. forced_released_tensors is a list of tensors to be released after offloading. 
@@ -973,9 +992,13 @@ def fine_grained_offloading_group_commit(*tensor, name, forced_released_tensors= if cur_forward_chunk is None: return tensor return FineGrainedOffloadingGroupCommitFunction.apply( - *tensor, cur_forward_chunk, name, forced_released_tensors + *tensor, cur_forward_chunk, name, forced_released_tensors, delay_offload ) +def fine_grained_offloading_group_flush_delayed_groups(): + """Flush the delayed groups.""" + debug_rank("fine_grained_offloading_group_flush_delayed_groups") + PipelineOffloadManager.get_instance().flush_delayed_groups() class FineGrainedOffloadingGroupStartFunction(torch.autograd.Function): """ diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py index 814de85ae70..d0c88e629d2 100644 --- a/megatron/core/pipeline_parallel/utils.py +++ b/megatron/core/pipeline_parallel/utils.py @@ -5,6 +5,7 @@ from typing import Callable, Optional import torch +import warnings from torch.autograd import Variable from megatron.core.utils import get_pg_rank, get_pg_size, make_viewless_tensor diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 5eeafdd8d1d..cc0c8e13e66 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -896,6 +896,7 @@ def forward( bias_parallel, name="expert_fc1", forced_released_tensors=[permuted_local_hidden_states], + delay_offload=self.config.delay_offload_until_cuda_graph, ) def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): @@ -974,7 +975,8 @@ def glu(x): self.activation_checkpoint.discard_output_and_register_recompute(output) if self.offload_moe_act: (output,) = fine_grained_offloading_group_commit( - output, name="moe_act", forced_released_tensors=[fc1_output] + output, name="moe_act", forced_released_tensors=[fc1_output], + delay_offload=self.config.delay_offload_until_cuda_graph ) # upad and concat the output diff --git 
a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 2846b08ae12..dfe422e3a5e 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -847,6 +847,9 @@ class TransformerConfig(ModelParallelConfig): min_offloaded_tensor_size: int = 1024 * 1024 """The minimum size of the tensor to be offloaded.""" + delay_offload_until_cuda_graph: bool = False + """If True, delay the offload until the CUDA graph is executed for minimal CPU overhead.""" + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index da552fa37f4..1329789bdd0 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -693,6 +693,11 @@ def _forward_mlp(self, hidden_states, inference_context=None): mlp_output_with_bias = (mlp_output, mlp_output_with_bias[1]) else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + if self.config.delay_offload_until_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_group_flush_delayed_groups, + ) + fine_grained_offloading_group_flush_delayed_groups() if self.recompute_pre_mlp_layernorm: # discard the output of the pre-mlp layernorm and register the recompute @@ -884,6 +889,12 @@ def _te_cuda_graph_replay(self, *args, **kwargs): cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) + if self.config.delay_offload_until_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_group_flush_delayed_groups, + ) + fine_grained_offloading_group_flush_delayed_groups() + if kwargs.get('context') is not None: context = 
cuda_graph_output.pop() diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index be147150cf3..f04bd51c94b 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2426,6 +2426,8 @@ def _add_training_args(parser): help='The submodules to offload its input. Choices: "attn_norm", "qkv_linear", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') group.add_argument('--min-offloaded-tensor-size', type=int, default=10*1024*1024, help='The minimum size of the tensor to be offloaded.') + group.add_argument('--delay-offload-until-cuda-graph', action='store_true', + help='Delay the offload until the CUDA graph is executed for minimal CPU overhead.') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') return parser From 5a150c78f03c062ecea7a0c22e302a290497f4bb Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 17 Dec 2025 02:13:42 -0800 Subject: [PATCH 24/74] minor fix Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 79 +++++++++++++++++-- .../core/transformer/transformer_layer.py | 24 +++--- 2 files changed, 85 insertions(+), 18 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 453066b33df..75f52bdb98b 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -23,6 +23,76 @@ def debug_rank(message): print(message) +def print_offload_summary_table(total_offload_bytes: Dict[str, int]): + """ + Print an ASCII table summarizing offload bytes across all ranks. + + Gathers offload data from all ranks and prints a formatted table on rank 0, + with rows representing ranks and columns representing groups. + + Args: + total_offload_bytes: Dict mapping group names to offload bytes for this rank. 
+ """ + assert torch.distributed.is_initialized() + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # Gather all group names across ranks + local_names = list(total_offload_bytes.keys()) + all_names_list = [None] * world_size + torch.distributed.all_gather_object(all_names_list, local_names) + all_group_names = sorted(set(name for names in all_names_list for name in names)) + + # Gather offload bytes from all ranks: each rank sends a list of bytes per group + local_bytes = [total_offload_bytes.get(name, 0) for name in all_group_names] + all_bytes_list = [None] * world_size + torch.distributed.all_gather_object(all_bytes_list, local_bytes) + + # Print ASCII table on rank 0 + if rank == 0: + # Calculate column widths + col_width = max(12, max((len(name) for name in all_group_names), default=8) + 2) + rank_col_width = max(6, len(f"Rank {world_size - 1}") + 2) + + # Build header + header = "Rank".ljust(rank_col_width) + header += "".join(name.rjust(col_width) for name in all_group_names) + header += "Total".rjust(col_width) + separator = "-" * len(header) + + print("\n" + "=" * len(header)) + print("Activation Offload Summary (MB)".center(len(header))) + print("=" * len(header)) + print(header) + print(separator) + + # Build rows for each rank + grand_total = 0 + col_totals = [0] * len(all_group_names) + for r in range(world_size): + row_bytes = all_bytes_list[r] + row_total = sum(row_bytes) + grand_total += row_total + for i, b in enumerate(row_bytes): + col_totals[i] += b + row_str = f"Rank {r}".ljust(rank_col_width) + for b in row_bytes: + row_str += f"{b / (1024 * 1024):.2f}".rjust(col_width) + row_str += f"{row_total / (1024 * 1024):.2f}".rjust(col_width) + print(row_str) + + # Print totals row + print(separator) + totals_row = "Total".ljust(rank_col_width) + for ct in col_totals: + totals_row += f"{ct / (1024 * 1024):.2f}".rjust(col_width) + totals_row += f"{grand_total / (1024 * 1024):.2f}".rjust(col_width) + 
print(totals_row) + print("=" * len(header) + "\n") + + torch.distributed.barrier() + + class GPUTensorPool: """ GPU memory pool for efficient allocation and deallocation of tensors. @@ -365,7 +435,7 @@ def push_offload_groups(self, group_hook, forced_released_tensors): """Push the offload groups to the delayed queue.""" debug_rank(f"pushing offload groups to the delayed queue") self._delayed_offload_groups.append((group_hook, forced_released_tensors)) - + def flush_delayed_groups(self): """Flush the delayed groups.""" debug_rank("flushing delayed groups") @@ -464,12 +534,7 @@ def post_warmup_callback(self): if group._name not in total_offload_bytes: total_offload_bytes[group._name] = 0 total_offload_bytes[group._name] += group.total_offload_bytes - assert torch.distributed.is_initialized() - rank = torch.distributed.get_rank() - for name, tensor_count in total_tensor_count.items(): - print(f"rank {rank} total offloaded tensor count for group {name} {tensor_count}") - for name, offload_bytes in total_offload_bytes.items(): - print(f"rank {rank} total offloaded bytes for group {name} {offload_bytes}") + print_offload_summary_table(total_offload_bytes) def push(self, handler): """Add a chunk handler to the backward queue.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 1329789bdd0..bce6b8ebcee 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -595,7 +595,7 @@ def _forward_attention( return hidden_states, context - def _forward_mlp(self, hidden_states, inference_context=None): + def _forward_mlp(self, hidden_states, inference_context=None, flush_delayed_groups=True): """ Perform a forward pass through the feed-forward layer. 
@@ -688,16 +688,13 @@ def _forward_mlp(self, hidden_states, inference_context=None): mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) if self.offload_dense_mlp: (mlp_output,) = fine_grained_offloading_group_commit( - mlp_output_with_bias[0], name="dense_mlp", forced_released_tensors=[] + mlp_output_with_bias[0], name="dense_mlp", + forced_released_tensors=[], + delay_offload=self.config.delay_offload_until_cuda_graph ) mlp_output_with_bias = (mlp_output, mlp_output_with_bias[1]) else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - if self.config.delay_offload_until_cuda_graph: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_flush_delayed_groups, - ) - fine_grained_offloading_group_flush_delayed_groups() if self.recompute_pre_mlp_layernorm: # discard the output of the pre-mlp layernorm and register the recompute @@ -707,9 +704,9 @@ def _forward_mlp(self, hidden_states, inference_context=None): ) nvtx_range_pop(suffix="mlp") - return self._forward_post_mlp(mlp_output_with_bias, residual) + return self._forward_post_mlp(mlp_output_with_bias, residual, flush_delayed_groups) - def _forward_post_mlp(self, mlp_output_with_bias, residual): + def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups=True): """ Perform operations after the MLP computation. 
@@ -748,6 +745,11 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) + if self.config.delay_offload_until_cuda_graph and flush_delayed_groups: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_group_flush_delayed_groups, + ) + fine_grained_offloading_group_flush_delayed_groups() return output def sharded_state_dict( @@ -952,10 +954,10 @@ def _te_cuda_graph_replay(self, *args, **kwargs): self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") - output = self._forward_post_mlp(mlp_output_with_bias, mlp_residual) + output = self._forward_post_mlp(mlp_output_with_bias, mlp_residual, flush_delayed_groups=False) else: # CUDA Graph does not capture the MLP/MoE part at all. - output = self._forward_mlp(*cuda_graph_output) + output = self._forward_mlp(*cuda_graph_output, flush_delayed_groups=False) return output, context def _get_te_cuda_graph_replay_args(self, *args, **kwargs): From 60e30823329d8fa84576591c2a0eca258e003ff0 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 17 Dec 2025 02:21:34 -0800 Subject: [PATCH 25/74] format Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 10 +++++++--- megatron/core/pipeline_parallel/utils.py | 1 - megatron/core/transformer/moe/experts.py | 6 ++++-- megatron/core/transformer/transformer_layer.py | 11 ++++++++--- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 75f52bdb98b..4d34ef37f3b 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -33,6 +33,7 @@ def print_offload_summary_table(total_offload_bytes: Dict[str, int]): Args: total_offload_bytes: Dict mapping group names to offload bytes for this 
rank. """ + # pylint: disable=bad-builtin assert torch.distributed.is_initialized() rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() @@ -1027,8 +1028,7 @@ def forward(ctx, *args): tensor = args[:-4] if delay_offload: PipelineOffloadManager.get_instance().push_offload_groups( - cpu_offload_handler.on_group_commit_forward, - forced_released_tensors + cpu_offload_handler.on_group_commit_forward, forced_released_tensors ) else: cpu_offload_handler.on_group_commit_forward(forced_released_tensors) @@ -1046,7 +1046,9 @@ def backward(ctx, *grad_output): return grad_output + (None, None, None, None) -def fine_grained_offloading_group_commit(*tensor, name, forced_released_tensors=[], delay_offload=False): +def fine_grained_offloading_group_commit( + *tensor, name, forced_released_tensors=[], delay_offload=False +): """ Specify the tensors to be released after offloading. forced_released_tensors is a list of tensors to be released after offloading. @@ -1060,11 +1062,13 @@ def fine_grained_offloading_group_commit(*tensor, name, forced_released_tensors= *tensor, cur_forward_chunk, name, forced_released_tensors, delay_offload ) + def fine_grained_offloading_group_flush_delayed_groups(): """Flush the delayed groups.""" debug_rank("fine_grained_offloading_group_flush_delayed_groups") PipelineOffloadManager.get_instance().flush_delayed_groups() + class FineGrainedOffloadingGroupStartFunction(torch.autograd.Function): """ Identity operation that marks the start of a layer group for offload/reload. 
diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py index b8dc64ff9f4..acdc3bb27ad 100644 --- a/megatron/core/pipeline_parallel/utils.py +++ b/megatron/core/pipeline_parallel/utils.py @@ -5,7 +5,6 @@ from typing import Callable, Optional import torch -import warnings from torch.autograd import Variable from megatron.core.utils import get_pg_rank, get_pg_size, make_viewless_tensor diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index cc0c8e13e66..544e12fa4a4 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -975,8 +975,10 @@ def glu(x): self.activation_checkpoint.discard_output_and_register_recompute(output) if self.offload_moe_act: (output,) = fine_grained_offloading_group_commit( - output, name="moe_act", forced_released_tensors=[fc1_output], - delay_offload=self.config.delay_offload_until_cuda_graph + output, + name="moe_act", + forced_released_tensors=[fc1_output], + delay_offload=self.config.delay_offload_until_cuda_graph, ) # upad and concat the output diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index bce6b8ebcee..310d9d2a927 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -688,9 +688,10 @@ def _forward_mlp(self, hidden_states, inference_context=None, flush_delayed_grou mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) if self.offload_dense_mlp: (mlp_output,) = fine_grained_offloading_group_commit( - mlp_output_with_bias[0], name="dense_mlp", + mlp_output_with_bias[0], + name="dense_mlp", forced_released_tensors=[], - delay_offload=self.config.delay_offload_until_cuda_graph + delay_offload=self.config.delay_offload_until_cuda_graph, ) mlp_output_with_bias = (mlp_output, mlp_output_with_bias[1]) else: @@ -749,6 +750,7 @@ def _forward_post_mlp(self, mlp_output_with_bias, 
residual, flush_delayed_groups from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_flush_delayed_groups, ) + fine_grained_offloading_group_flush_delayed_groups() return output @@ -895,6 +897,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_flush_delayed_groups, ) + fine_grained_offloading_group_flush_delayed_groups() if kwargs.get('context') is not None: @@ -954,7 +957,9 @@ def _te_cuda_graph_replay(self, *args, **kwargs): self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") - output = self._forward_post_mlp(mlp_output_with_bias, mlp_residual, flush_delayed_groups=False) + output = self._forward_post_mlp( + mlp_output_with_bias, mlp_residual, flush_delayed_groups=False + ) else: # CUDA Graph does not capture the MLP/MoE part at all. output = self._forward_mlp(*cuda_graph_output, flush_delayed_groups=False) From 256d79d67b5ae04b8749b475e2d1885ad938758c Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 17 Dec 2025 06:43:49 -0800 Subject: [PATCH 26/74] fix ut Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 8 ++- ...test_fine_grained_activation_offloading.py | 67 +++++++++---------- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 4d34ef37f3b..96a5629250f 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -824,6 +824,7 @@ def tensor_pop(self, tensor_tag): def tensor_need_offloading_checker(self, tensor): """Check if the tensor needs to be offloaded.""" + debug_rank(f"tensor_need_offloading_checker {tensor.numel()} {getattr(tensor, 'offloading_activation', None)}") if tensor.numel() < 
self.min_offloaded_tensor_size: return False # Respect tensor's offload preference if specified @@ -889,9 +890,12 @@ def pre_reload_last_layer(self): def should_bulk_offload(self): """Determine if the current group should be offloaded.""" group = self._groups_to_offload[-1] + debug_rank(f"should_bulk_offload {self.is_warmup} {group.offload}") # Don't offload if the chunk is not in warmup stage - # and the group is marked as not offloadable - if not PipelineOffloadManager.get_instance()._is_warmup and not group.offload: + if self.is_warmup: + return True + # Don't offload if the group is marked as not offloadable + if not group.offload: return False # Check if next backward chunk is this chunk (for last pipeline stage) diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index d43839074f4..8807670af33 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -1,10 +1,12 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
import gc +import os import pytest import torch + EPSILON = 0.1 # Skip all tests if CUDA is not available @@ -16,46 +18,40 @@ def _reset_cuda_memory(): if cuda_available: torch.cuda.empty_cache() - class ToyModel(torch.nn.Module): def __init__(self, hidden_size: int = 2048, num_layers: int = 4, dtype=torch.bfloat16): + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl") + if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) super().__init__() layers = [] for _ in range(num_layers): - layers.append( - torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device="cuda") - ) + linear = torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device="cuda") + layers.append(linear) self.net = torch.nn.Sequential(*layers).to(device="cuda", dtype=dtype) self.hidden_size = hidden_size self.num_layers = num_layers self.dtype = dtype - # Prevent weights/bias from being considered activation tensors for offload; - # ensure we only count activation tensors (inputs x) in memory accounting. - for p in self.parameters(): - try: - setattr(p, "offloading_activation", False) - except Exception: - pass - def forward(self, x, use_offload: bool = False): from megatron.core.pipeline_parallel import fine_grained_activation_offload as off if use_offload: # Initialize a new chunk (microbatch) and enable offload context. - with off.get_fine_grained_offloading_context(True): - off.fine_grained_offloading_init_chunk_handler( - vp_size=1, vp_stage=None, min_offloaded_tensor_size=1 - ) - for i, layer in enumerate(self.net): - # Group by module; with this linear-only model, each group corresponds to a layer. - x = off.fine_grained_offloading_group_start(x, name=f"layer_{i}") + off.fine_grained_offloading_init_chunk_handler( + vp_size=1, vp_stage=None, min_offloaded_tensor_size=1 + ) + for layer in self.net: + # Group by module; with this linear-only model, each group corresponds to a layer. 
+ x = off.fine_grained_offloading_group_start(x, name=f"linear_layer") + with off.get_fine_grained_offloading_context(True): x = layer(x) - # Commit the group; returns a tuple of tensors - (x,) = off.fine_grained_offloading_group_commit( - x, name=f"layer_{i}", forced_released_tensors=[] - ) - return x + # Commit the group; returns a tuple of tensors + (x,) = off.fine_grained_offloading_group_commit( + x, name=f"linear_layer", forced_released_tensors=[] + ) + return x # Baseline path (no offload hooks) with ( torch.autocast(device_type="cuda", dtype=self.dtype) @@ -67,19 +63,6 @@ def forward(self, x, use_offload: bool = False): return x -@pytest.fixture(autouse=True) -def _monkeypatch_offload_deps(monkeypatch): - # Avoid requiring torch.distributed initialization and NVML in tests - import megatron.core.pipeline_parallel.fine_grained_activation_offload as off - - monkeypatch.setattr(off, "debug_rank", lambda *args, **kwargs: None, raising=False) - monkeypatch.setattr(off, "set_ideal_affinity_for_current_gpu", lambda: None, raising=False) - # Ensure a clean state each test - off.fine_grained_offloading_reset() - yield - off.fine_grained_offloading_reset() - - def test_fine_grained_activation_offload_memory_reduction(): torch.manual_seed(1234) # Use a linear-only stack so theoretical saved memory equals sum of per-layer input x bytes. 
@@ -111,7 +94,17 @@ def test_fine_grained_activation_offload_memory_reduction(): from megatron.core.pipeline_parallel import fine_grained_activation_offload as off off.fine_grained_offloading_reset() + # warmup + inp_off = inp.detach().clone().requires_grad_(True) + out_off = model(inp_off, use_offload=True) + (out_off.sum()).backward() + torch.cuda.synchronize() + off.fine_grained_offloading_reset() + del inp_off + del out_off _reset_cuda_memory() + torch.cuda.synchronize() + inp_off = inp.detach().clone().requires_grad_(True) offload_mem_before = torch.cuda.memory_allocated() / (1024**2) out_off = model(inp_off, use_offload=True) From 9475e3da720271d17945c0aff2642c981eab2b4a Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 17 Dec 2025 06:45:57 -0800 Subject: [PATCH 27/74] format Signed-off-by: Hongbin Liu --- .../pipeline_parallel/fine_grained_activation_offload.py | 4 +++- .../test_fine_grained_activation_offloading.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 96a5629250f..0deeaf63ac9 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -824,7 +824,9 @@ def tensor_pop(self, tensor_tag): def tensor_need_offloading_checker(self, tensor): """Check if the tensor needs to be offloaded.""" - debug_rank(f"tensor_need_offloading_checker {tensor.numel()} {getattr(tensor, 'offloading_activation', None)}") + debug_rank( + f"tensor_need_offloading_checker {getattr(tensor, 'offloading_activation', None)}" + ) if tensor.numel() < self.min_offloaded_tensor_size: return False # Respect tensor's offload preference if specified diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 
8807670af33..e984b3d38ad 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -6,7 +6,6 @@ import pytest import torch - EPSILON = 0.1 # Skip all tests if CUDA is not available @@ -18,6 +17,7 @@ def _reset_cuda_memory(): if cuda_available: torch.cuda.empty_cache() + class ToyModel(torch.nn.Module): def __init__(self, hidden_size: int = 2048, num_layers: int = 4, dtype=torch.bfloat16): if not torch.distributed.is_initialized(): @@ -27,7 +27,9 @@ def __init__(self, hidden_size: int = 2048, num_layers: int = 4, dtype=torch.bfl super().__init__() layers = [] for _ in range(num_layers): - linear = torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device="cuda") + linear = torch.nn.Linear( + hidden_size, hidden_size, bias=True, dtype=dtype, device="cuda" + ) layers.append(linear) self.net = torch.nn.Sequential(*layers).to(device="cuda", dtype=dtype) self.hidden_size = hidden_size From 93c0827737c1003c0a67da02297e6f03f59b8e64 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 5 Jan 2026 18:26:05 -0800 Subject: [PATCH 28/74] minor fix Signed-off-by: Hongbin Liu --- .../pipeline_parallel/fine_grained_activation_offload.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 0deeaf63ac9..164ed3b14a0 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -440,7 +440,7 @@ def push_offload_groups(self, group_hook, forced_released_tensors): def flush_delayed_groups(self): """Flush the delayed groups.""" debug_rank("flushing delayed groups") - for group_hook, forced_released_tensors in self._delayed_offload_groups: + for group_hook, forced_released_tensors in 
reversed(self._delayed_offload_groups): group_hook(forced_released_tensors) self._delayed_offload_groups = [] @@ -766,7 +766,8 @@ def finish_all_groups(self, name=None) -> bool: f"------finish_all_groups {self} {self._max_group_size} {self._offloaded_group_index}" ) # TODO: check if this is correct - if len(self._groups_to_reload) == 0 and self._offloaded_group_index > 0: + # Mark it as finished when all groups are finished and there are no groups to offload or reload + if len(self._groups_to_reload) == 0 and len(self._groups_to_offload) == 0 and self._offloaded_group_index > 0: return True assert name is not None, "Name is required" for group in self.offload_groups[self._offloaded_group_index :]: From f22a1947399ab91d01f605976088e9e1d81b7d03 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 5 Jan 2026 18:37:51 -0800 Subject: [PATCH 29/74] remove changes for cuda graph Signed-off-by: Hongbin Liu --- megatron/core/transformer/cuda_graphs.py | 15 --- megatron/core/transformer/module.py | 4 - megatron/core/transformer/moe/experts.py | 2 - .../core/transformer/transformer_config.py | 20 --- .../core/transformer/transformer_layer.py | 126 +++--------------- megatron/training/arguments.py | 2 - 6 files changed, 16 insertions(+), 153 deletions(-) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 3da122d90f1..27e6c65c738 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1848,16 +1848,6 @@ def _get_fp8_enabled(): ) else: kwargs['fp8_enabled'] = False - - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_disable_offload, - fine_grained_offloading_enable_offload, - ) - - # if self.config.offload_module_in_cuda_graph: - if self.config.fine_grained_activation_offloading: - kwargs['pre_warmup_hook'] = fine_grained_offloading_disable_offload - kwargs['post_warmup_hook'] = fine_grained_offloading_enable_offload return 
kwargs kwargs = get_make_graphed_callables_kwargs() @@ -1892,13 +1882,8 @@ def _finish_capturing(self, start_time): _set_capture_end() from megatron.core.distributed.finalize_model_grads import reset_model_temporary_tensors - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_reset, - ) from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker - fine_grained_offloading_reset() - torch.distributed.barrier() for model_chunk in self.model: model_chunk.zero_grad_buffer() diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 20da74bf67c..2330df91b52 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -281,10 +281,6 @@ def _get_te_cuda_graph_replay_args(self, *args, **kwargs): cudagraph_kwargs = kwargs.copy() cudagraph_kwargs['is_first_microbatch'] = getattr(self, 'current_microbatch', 0) == 0 - from megatron.core.transformer.transformer_layer import TransformerLayer - - cudagraph_kwargs['cuda_graph_stream'] = TransformerLayer.cuda_graph_stream - cudagraph_kwargs['cuda_graph_event'] = TransformerLayer.cuda_graph_event return cudagraph_args, cudagraph_kwargs def _should_call_local_cudagraph(self, *args, **kwargs): diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 544e12fa4a4..2c30fbbf1ab 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -896,7 +896,6 @@ def forward( bias_parallel, name="expert_fc1", forced_released_tensors=[permuted_local_hidden_states], - delay_offload=self.config.delay_offload_until_cuda_graph, ) def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): @@ -978,7 +977,6 @@ def glu(x): output, name="moe_act", forced_released_tensors=[fc1_output], - delay_offload=self.config.delay_offload_until_cuda_graph, ) # upad and concat the output diff --git 
a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 11611d8ce15..e2705bd9f51 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -850,9 +850,6 @@ class TransformerConfig(ModelParallelConfig): min_offloaded_tensor_size: int = 1024 * 1024 """The minimum size of the tensor to be offloaded.""" - delay_offload_until_cuda_graph: bool = False - """If True, delay the offload until the CUDA graph is executed for minimal CPU overhead.""" - def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more @@ -1236,7 +1233,6 @@ def __post_init__(self): "attn_norm", "mlp_norm", "qkv_linear", - "dense_mlp", } invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( @@ -1249,22 +1245,6 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) - if self.external_cuda_graph or self.cuda_graph_impl == "transformer_engine": - assert ( - self.cuda_graph_scope is not None - ), "cuda_graph_scope must be set when enabling offloading." - assert ( - "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope - ) or ( - CudaGraphScope.attn in self.cuda_graph_scope - and CudaGraphScope.moe_router in self.cuda_graph_scope - ), "attn and moe_router must be in cuda_graph_scope when enabling offloading." - assert ( - "attn_norm" not in self.offload_modules - ), "input of attn_norm is the start point of cuda graph, which can't be offloaded." - assert ( - "mlp_norm" not in self.offload_modules - ), "mlp_norm goes through the boundary of cuda graph, which can't be offloaded." 
if ( self.num_layers_in_first_pipeline_stage is not None diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 310d9d2a927..3ea40577009 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -26,7 +26,6 @@ deprecate_inference_params, get_pg_rank, is_te_min_version, - is_torch_min_version, log_single_rank, make_viewless_tensor, nvtx_range_pop, @@ -260,9 +259,6 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): output of the same size. """ - cuda_graph_stream = None - cuda_graph_event = None - def __init__( self, config: TransformerConfig, @@ -416,8 +412,17 @@ def __init__( if "mlp" in self.config.recompute_modules: if not self.is_moe_layer: self.recompute_mlp = True + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading + and "attn_norm" in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp) + ) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading + and "mlp_norm" in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp) + ) - self._set_offload_modules() # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. # TORCH_MAJOR = int(torch.__version__.split('.')[0]) @@ -505,15 +510,6 @@ def _forward_attention( get_fine_grained_offloading_context, ) - if self.offload_module_in_cuda_graph: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_backward_record, - ) - - hidden_states = fine_grained_offloading_backward_record( - hidden_states, TransformerLayer.cuda_graph_event - ) - inference_context = deprecate_inference_params(inference_context, inference_params) # Residual connection. 
@@ -595,7 +591,7 @@ def _forward_attention( return hidden_states, context - def _forward_mlp(self, hidden_states, inference_context=None, flush_delayed_groups=True): + def _forward_mlp(self, hidden_states, inference_context=None): """ Perform a forward pass through the feed-forward layer. @@ -607,7 +603,6 @@ def _forward_mlp(self, hidden_states, inference_context=None, flush_delayed_grou """ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_commit, fine_grained_offloading_group_start, get_fine_grained_offloading_context, ) @@ -680,22 +675,7 @@ def _forward_mlp(self, hidden_states, inference_context=None, flush_delayed_grou bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None mlp_output_with_bias = (mlp_output, bias_output) else: - if self.offload_dense_mlp: - pre_mlp_layernorm_output = fine_grained_offloading_group_start( - pre_mlp_layernorm_output, name="dense_mlp" - ) - with get_fine_grained_offloading_context(self.offload_dense_mlp): - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - if self.offload_dense_mlp: - (mlp_output,) = fine_grained_offloading_group_commit( - mlp_output_with_bias[0], - name="dense_mlp", - forced_released_tensors=[], - delay_offload=self.config.delay_offload_until_cuda_graph, - ) - mlp_output_with_bias = (mlp_output, mlp_output_with_bias[1]) - else: - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) if self.recompute_pre_mlp_layernorm: # discard the output of the pre-mlp layernorm and register the recompute @@ -705,9 +685,9 @@ def _forward_mlp(self, hidden_states, inference_context=None, flush_delayed_grou ) nvtx_range_pop(suffix="mlp") - return self._forward_post_mlp(mlp_output_with_bias, residual, flush_delayed_groups) + return self._forward_post_mlp(mlp_output_with_bias, residual) - def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups=True): 
+ def _forward_post_mlp(self, mlp_output_with_bias, residual): """ Perform operations after the MLP computation. @@ -746,12 +726,6 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) - if self.config.delay_offload_until_cuda_graph and flush_delayed_groups: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_flush_delayed_groups, - ) - - fine_grained_offloading_group_flush_delayed_groups() return output def sharded_state_dict( @@ -862,12 +836,6 @@ def _te_cuda_graph_capture(self, *args, **kwargs): cuda_graph_outputs = list(hidden_states) if context is not None: cuda_graph_outputs.append(context) - if self.offload_module_in_cuda_graph: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_forward_record, - ) - - fine_grained_offloading_forward_record(TransformerLayer.cuda_graph_event) return tuple(cuda_graph_outputs) def _te_cuda_graph_replay(self, *args, **kwargs): @@ -893,13 +861,6 @@ def _te_cuda_graph_replay(self, *args, **kwargs): cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) - if self.config.delay_offload_until_cuda_graph: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_flush_delayed_groups, - ) - - fine_grained_offloading_group_flush_delayed_groups() - if kwargs.get('context') is not None: context = cuda_graph_output.pop() @@ -957,12 +918,10 @@ def _te_cuda_graph_replay(self, *args, **kwargs): self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") - output = self._forward_post_mlp( - mlp_output_with_bias, mlp_residual, flush_delayed_groups=False - ) + output = self._forward_post_mlp(mlp_output_with_bias, mlp_residual) else: # CUDA Graph does not capture the MLP/MoE part at all. 
- output = self._forward_mlp(*cuda_graph_output, flush_delayed_groups=False) + output = self._forward_mlp(*cuda_graph_output) return output, context def _get_te_cuda_graph_replay_args(self, *args, **kwargs): @@ -1059,56 +1018,3 @@ def __call__(self, *args, **kwargs): 'inference_context' ].is_decode_only() return super().__call__(*args, **kwargs) - - def _set_offload_modules(self): - """Set the offload modules for the transformer layer.""" - if self.config.fine_grained_activation_offloading: - self.offload_attn_norm = "attn_norm" in self.config.offload_modules and not isinstance( - self.input_layernorm, IdentityOp - ) - self.offload_qkv_linear = "qkv_linear" in self.config.offload_modules - self.offload_core_attn = "core_attn" in self.config.offload_modules - self.offload_attn_proj = "attn_proj" in self.config.offload_modules - self.offload_mlp_norm = "mlp_norm" in self.config.offload_modules and not isinstance( - self.pre_mlp_layernorm, IdentityOp - ) - self.offload_expert_fc1 = "expert_fc1" in self.config.offload_modules - self.offload_moe_act = "moe_act" in self.config.offload_modules - self.offload_dense_mlp = ( - "dense_mlp" in self.config.offload_modules and not self.is_moe_layer - ) - else: - self.offload_attn_norm = False - self.offload_qkv_linear = False - self.offload_core_attn = False - self.offload_attn_proj = False - self.offload_mlp_norm = False - self.offload_expert_fc1 = False - self.offload_moe_act = False - self.offload_dense_mlp = False - # Set the offload module in cuda graph flag. 
- self.offload_module_in_cuda_graph = False - if CudaGraphScope.attn in self.config.cuda_graph_scope: - if self.offload_core_attn or self.offload_attn_proj or self.offload_qkv_linear: - self.offload_module_in_cuda_graph = True - if not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope: - if self.offload_mlp_norm or self.offload_dense_mlp: - self.offload_module_in_cuda_graph = True - if self.offload_module_in_cuda_graph: - assert is_torch_min_version( - "2.9.0a0" - ), "Fine-grained activation offloading needs torch>=2.9.0 to support cuda graph." - assert ( - self.config.cuda_graph_warmup_steps > 0 - ), "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." - # Set the cuda graph stream and event for the transformer layer. - if TransformerLayer.cuda_graph_stream is None: - if self.offload_module_in_cuda_graph: - TransformerLayer.cuda_graph_stream = torch.cuda.Stream() - else: - TransformerLayer.cuda_graph_stream = torch.cuda.current_stream() - if TransformerLayer.cuda_graph_event is None: - if self.offload_module_in_cuda_graph: - TransformerLayer.cuda_graph_event = torch.cuda.Event(external=True) - else: - TransformerLayer.cuda_graph_event = torch.cuda.Event() diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 4cbd7b12eb5..91292f1ed6e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2366,8 +2366,6 @@ def _add_training_args(parser): help='The submodules to offload its input. 
Choices: "attn_norm", "qkv_linear", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') group.add_argument('--min-offloaded-tensor-size', type=int, default=10*1024*1024, help='The minimum size of the tensor to be offloaded.') - group.add_argument('--delay-offload-until-cuda-graph', action='store_true', - help='Delay the offload until the CUDA graph is executed for minimal CPU overhead.') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') return parser From a9d6633f7e54430133e1215a90a37df38e1a5302 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 6 Jan 2026 19:12:14 -0800 Subject: [PATCH 30/74] bug fix when cuda graph is disabled and fix for dumping offloading info Signed-off-by: root --- .../fine_grained_activation_offload.py | 19 ++++++++++++++----- megatron/core/transformer/moe/experts.py | 2 +- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 164ed3b14a0..50fbdfdcb2a 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -526,7 +526,7 @@ def post_warmup_callback(self): # Dump the offload information total_tensor_count = {} total_offload_bytes = {} - for chunk in self._cached_chunks_backward: + for chunk in self._cached_chunks_forward: for group in chunk.offload_groups: if group.offload: if group._name not in total_tensor_count: @@ -535,6 +535,10 @@ def post_warmup_callback(self): if group._name not in total_offload_bytes: total_offload_bytes[group._name] = 0 total_offload_bytes[group._name] += group.total_offload_bytes + # Stop statistics at the first backward chunk after which 1F1B is running, + # where the memory cost will not increase anymore. 
+ if chunk is self._cached_chunks_backward[0]: + break print_offload_summary_table(total_offload_bytes) def push(self, handler): @@ -732,7 +736,7 @@ def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool): self._groups_to_reload = [] self._tensor_count_current_group = 0 self._max_group_size = 0 - + self._reloading_group = [] # Counter for special torch tensor types (FakeTensor, FunctionalTensor) self.torch_tensor_count = 0 self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream @@ -748,6 +752,7 @@ def reset(self): self._groups_to_offload = [] self._groups_to_reload = [] self._tensor_count_current_group = 0 + self._reloading_group = [] def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" @@ -880,6 +885,7 @@ def bulk_reload_group(self): group_to_reload.push_tensor(tensor_tag, recovered_tensor) group_to_reload.record_reload_event(self.h2d_stream) self._groups_to_reload.pop() + self._reloading_group.append(group_to_reload) torch.cuda.nvtx.range_pop() def pre_reload_last_layer(self): @@ -962,9 +968,12 @@ def on_group_commit_backward(self, name): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() assert cur_backward_chunk is self, f"Chunk mismatch {cur_backward_chunk} {self}" # Wait for reload to complete before using tensors - if not is_graph_capturing() and len(self._groups_to_reload) > 0: - group_to_reload = self._groups_to_reload[-1] - group_to_reload.wait_reload_event(torch.cuda.current_stream()) + if not is_graph_capturing() and len(self._reloading_group) > 0: + for reloading_group in self._reloading_group: + if reloading_group._name == name: + reloading_group.wait_reload_event(torch.cuda.current_stream()) + self._reloading_group.remove(reloading_group) + break def on_group_start_forward(self, name): """ diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 2c30fbbf1ab..643563e9cb7 100644 --- a/megatron/core/transformer/moe/experts.py 
+++ b/megatron/core/transformer/moe/experts.py @@ -813,7 +813,7 @@ def __init__( set_save_original_input(self.linear_fc2) # This is to avoid the CPU overhead of multiple d2h copies - if self.offload_expert_fc1 and not (self.config.fp8 or self.config.fp4): + if self.offload_expert_fc1: from megatron.core.extensions.transformer_engine import set_save_original_input set_save_original_input(self.linear_fc1) From 9d766e9f466eab872b7ea205c5441b2ee37b2270 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 7 Jan 2026 22:25:10 -0800 Subject: [PATCH 31/74] refactor and update ut Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 2 +- .../fine_grained_activation_offload.py | 66 ++- megatron/core/transformer/attention.py | 9 +- megatron/core/transformer/moe/experts.py | 13 +- .../transformer/multi_latent_attention.py | 10 +- .../core/transformer/transformer_layer.py | 4 +- ...test_fine_grained_activation_offloading.py | 445 ++++++++++++------ 7 files changed, 360 insertions(+), 189 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 60094976a9a..f6ced7ce2b4 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -464,7 +464,7 @@ def submodule_combine_forward( mlp_output_with_bias, residual, layer.hidden_dropout ) if layer.offload_mlp_norm: - (hidden_states,) = fine_grained_offloading_group_commit( + hidden_states = fine_grained_offloading_group_commit( hidden_states, name="mlp_norm", forced_released_tensors=[residual] ) output = make_viewless_tensor( diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 50fbdfdcb2a..a84ba6de327 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -392,6 +392,12 @@ 
def get_instance(cls): cls.OFFLOAD_MGR = PipelineOffloadManager() return cls.OFFLOAD_MGR + @classmethod + def reset_instance(cls): + """Reset the singleton instance of PipelineOffloadManager.""" + cls.OFFLOAD_MGR = None + cls.OFFLOAD_MGR = PipelineOffloadManager() + def __init__(self): """Initialize the manager with queues and dedicated CUDA streams.""" # Queue to store chunk handlers for backward pass @@ -462,6 +468,16 @@ def reset(self): chunk.reset() self._delayed_offload_groups = [] + @property + def offload_summary_bytes(self) -> Dict[str, int]: + """Offload summary bytes per group collected after warmup.""" + return self._offload_summary_bytes + + @property + def offload_summary_total_bytes(self) -> int: + """Total offloaded bytes collected after warmup.""" + return self._offload_summary_total_bytes + def flush(self): """Flush all staged chunks to the backward queue in reverse order.""" # Ensure all virtual pipeline stages have the same number of chunks @@ -539,6 +555,9 @@ def post_warmup_callback(self): # where the memory cost will not increase anymore. if chunk is self._cached_chunks_backward[0]: break + # Cache summary for downstream consumers (e.g., unit tests). 
+ self._offload_summary_bytes = dict(total_offload_bytes) + self._offload_summary_total_bytes = int(sum(total_offload_bytes.values())) print_offload_summary_table(total_offload_bytes) def push(self, handler): @@ -1033,22 +1052,17 @@ class FineGrainedOffloadingGroupCommitFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, *args): + def forward(ctx, tensor, cur_forward_chunk, name, forced_released_tensors, delay_offload): # pylint: disable=missing-function-docstring debug_rank("FineGrainedOffloadingGroupCommitFunction forward") - delay_offload = args[-1] - forced_released_tensors = args[-2] - name = args[-3] - cpu_offload_handler = args[-4] - tensor = args[:-4] if delay_offload: PipelineOffloadManager.get_instance().push_offload_groups( - cpu_offload_handler.on_group_commit_forward, forced_released_tensors + cur_forward_chunk.on_group_commit_forward, forced_released_tensors ) else: - cpu_offload_handler.on_group_commit_forward(forced_released_tensors) - ctx.cpu_offload_handler = cpu_offload_handler + cur_forward_chunk.on_group_commit_forward(forced_released_tensors) + ctx.cpu_offload_handler = cur_forward_chunk ctx.name = name return tensor @@ -1063,7 +1077,7 @@ def backward(ctx, *grad_output): def fine_grained_offloading_group_commit( - *tensor, name, forced_released_tensors=[], delay_offload=False + tensor, name, forced_released_tensors=None, delay_offload=False ): """ Specify the tensors to be released after offloading. @@ -1071,11 +1085,37 @@ def fine_grained_offloading_group_commit( The tensors will be untyped_storage().resize_(0) after offloading. Note: specify the tensors only when they are not automatically released by torch gc. """ + # Be permissive: callers may pass a tuple/list of outputs (e.g., (q, k, v)). + # We only need to insert a single identity op into the autograd graph; applying + # it to the first tensor output is sufficient and keeps callers' code minimal. 
+ if forced_released_tensors is None: + forced_released_tensors = [] + if isinstance(tensor, tuple): + if len(tensor) == 0: + return tensor + committed0 = fine_grained_offloading_group_commit( + tensor[0], + name=name, + forced_released_tensors=forced_released_tensors, + delay_offload=delay_offload, + ) + return (committed0,) + tensor[1:] + if isinstance(tensor, list): + if len(tensor) == 0: + return tensor + committed0 = fine_grained_offloading_group_commit( + tensor[0], + name=name, + forced_released_tensors=forced_released_tensors, + delay_offload=delay_offload, + ) + return [committed0] + tensor[1:] + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() if cur_forward_chunk is None: return tensor return FineGrainedOffloadingGroupCommitFunction.apply( - *tensor, cur_forward_chunk, name, forced_released_tensors, delay_offload + tensor, cur_forward_chunk, name, forced_released_tensors, delay_offload ) @@ -1166,3 +1206,7 @@ def backward(ctx, grad_output): def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: """Record the backward event for cuda graph capture.""" return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) + +def fine_grained_offloading_reset_instance(): + """Reset the singleton instance of PipelineOffloadManager.""" + PipelineOffloadManager.reset_instance() diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 80e9ec6fc92..7a6a4859d5f 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -805,7 +805,8 @@ def forward( hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv ) if self.offload_qkv_linear: - (qkv_output,) = fine_grained_offloading_group_commit( + # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. 
+ qkv_output = fine_grained_offloading_group_commit( qkv_output, name="qkv_linear", forced_released_tensors=[] ) @@ -991,7 +992,7 @@ def forward( ) core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - (core_attn_out,) = fine_grained_offloading_group_commit( + core_attn_out = fine_grained_offloading_group_commit( core_attn_out, name="core_attn", forced_released_tensors=[query, key, value] ) @@ -1019,8 +1020,8 @@ def forward( with get_fine_grained_offloading_context(self.offload_attn_proj): output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output, bias = fine_grained_offloading_group_commit( - output, bias, name="attn_proj", forced_released_tensors=[core_attn_out] + output = fine_grained_offloading_group_commit( + output, name="attn_proj", forced_released_tensors=[core_attn_out] ) nvtx_range_pop(suffix="linear_proj") diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 643563e9cb7..d0cccb287aa 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -891,11 +891,8 @@ def forward( permuted_local_hidden_states, tokens_per_expert ) if self.offload_expert_fc1: - fc1_output, bias_parallel = fine_grained_offloading_group_commit( - fc1_output, - bias_parallel, - name="expert_fc1", - forced_released_tensors=[permuted_local_hidden_states], + fc1_output = fine_grained_offloading_group_commit( + fc1_output, name="expert_fc1", forced_released_tensors=[permuted_local_hidden_states] ) def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): @@ -973,10 +970,8 @@ def glu(x): if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) if self.offload_moe_act: - (output,) = fine_grained_offloading_group_commit( - output, - name="moe_act", - forced_released_tensors=[fc1_output], + output = fine_grained_offloading_group_commit( + output, 
name="moe_act", forced_released_tensors=[fc1_output] ) # upad and concat the output diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index d0b588bc94a..03c31d70686 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -270,8 +270,8 @@ def forward( f"{self.config.experimental_attention_variant}" ) if self.offload_qkv_linear: - (query, key, value) = fine_grained_offloading_group_commit( - query, key, value, name="qkv_linear", forced_released_tensors=[hidden_states] + query = fine_grained_offloading_group_commit( + query, name="qkv_linear", forced_released_tensors=[hidden_states] ) # =================================================== @@ -353,7 +353,7 @@ def forward( if not inference_context.is_decode_only(): core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - (core_attn_out,) = fine_grained_offloading_group_commit( + core_attn_out = fine_grained_offloading_group_commit( core_attn_out, name="core_attn", forced_released_tensors=[query, key, value] ) @@ -386,8 +386,8 @@ def forward( with get_fine_grained_offloading_context(self.offload_attn_proj): output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output, bias = fine_grained_offloading_group_commit( - output, bias, name="attn_proj", forced_released_tensors=[core_attn_out] + output = fine_grained_offloading_group_commit( + output, name="attn_proj", forced_released_tensors=[core_attn_out] ) return output, bias diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 3ea40577009..e00c767739f 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -561,7 +561,7 @@ def _forward_attention( nvtx_range_pop(suffix="self_attn_bda") if self.offload_attn_norm: - (hidden_states,) = 
fine_grained_offloading_group_commit( + hidden_states = fine_grained_offloading_group_commit( hidden_states, name="attn_norm", forced_released_tensors=[residual] ) @@ -712,7 +712,7 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): ) nvtx_range_pop(suffix="mlp_bda") if self.offload_mlp_norm: - (hidden_states,) = fine_grained_offloading_group_commit( + hidden_states = fine_grained_offloading_group_commit( hidden_states, name="mlp_norm", forced_released_tensors=[residual] ) diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index e984b3d38ad..7dbb0d77d9e 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -2,180 +2,311 @@ import gc import os +from typing import Dict, List, Optional, Tuple import pytest import torch +from contextlib import nullcontext -EPSILON = 0.1 +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig +from megatron.core.utils import is_te_min_version +from tests.unit_tests.test_utilities import Utils +from megatron.core.transformer.enums import AttnBackend -# Skip all tests if CUDA is not available -cuda_available = torch.cuda.is_available() +# Tolerance for memory expectation check (GPU allocator jitter etc). 
+EPSILON = 0.30 -def _reset_cuda_memory(): + +def _reset_cuda_memory() -> None: gc.collect() - if cuda_available: + if torch.cuda.is_available(): torch.cuda.empty_cache() + torch.cuda.synchronize() -class ToyModel(torch.nn.Module): - def __init__(self, hidden_size: int = 2048, num_layers: int = 4, dtype=torch.bfloat16): - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") - if torch.cuda.is_available(): - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - super().__init__() - layers = [] - for _ in range(num_layers): - linear = torch.nn.Linear( - hidden_size, hidden_size, bias=True, dtype=dtype, device="cuda" - ) - layers.append(linear) - self.net = torch.nn.Sequential(*layers).to(device="cuda", dtype=dtype) - self.hidden_size = hidden_size - self.num_layers = num_layers - self.dtype = dtype - - def forward(self, x, use_offload: bool = False): - from megatron.core.pipeline_parallel import fine_grained_activation_offload as off - - if use_offload: - # Initialize a new chunk (microbatch) and enable offload context. - off.fine_grained_offloading_init_chunk_handler( - vp_size=1, vp_stage=None, min_offloaded_tensor_size=1 - ) - for layer in self.net: - # Group by module; with this linear-only model, each group corresponds to a layer. 
- x = off.fine_grained_offloading_group_start(x, name=f"linear_layer") - with off.get_fine_grained_offloading_context(True): - x = layer(x) - # Commit the group; returns a tuple of tensors - (x,) = off.fine_grained_offloading_group_commit( - x, name=f"linear_layer", forced_released_tensors=[] - ) - return x - # Baseline path (no offload hooks) - with ( - torch.autocast(device_type="cuda", dtype=self.dtype) - if self.dtype in (torch.float16, torch.bfloat16) - else torch.cuda.amp.autocast(enabled=False) - ): - for layer in self.net: - x = layer(x) - return x - - -def test_fine_grained_activation_offload_memory_reduction(): - torch.manual_seed(1234) - # Use a linear-only stack so theoretical saved memory equals sum of per-layer input x bytes. - model = ToyModel(hidden_size=2048, num_layers=8, dtype=torch.bfloat16).eval() - - # Create input - inp = torch.randn( - (2048, model.hidden_size), device="cuda", dtype=torch.bfloat16, requires_grad=True +def _build_gpt_model( + *, + seed: int, + num_layers: int, + hidden_size: int, + num_attention_heads: int, + vocab_size: int, + seq_length: int, + num_experts: Optional[int], + fine_grained_activation_offloading: bool, + offload_modules: Optional[List[str]], + min_offloaded_tensor_size: int, + is_mla: bool, +) -> GPTModel: + """Build a GPTModel that uses TE-based transformer layer spec.""" + model_parallel_cuda_manual_seed(seed) + torch.manual_seed(seed) + ConfigClass = MLATransformerConfig if is_mla else TransformerConfig + transformer_config = ConfigClass( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + attention_backend=AttnBackend.unfused, + # Make sure model weights / activations are BF16 so TE fused attention isn't disabled. 
+ bf16=True, + # params_dtype=torch.bfloat16, + # enable_autocast=True, + # autocast_dtype=torch.bfloat16, + # MoE + num_moe_experts=num_experts, + moe_grouped_gemm=(num_experts is not None), + # Fine-grained activation offloading + fine_grained_activation_offloading=fine_grained_activation_offloading, + offload_modules=offload_modules, + min_offloaded_tensor_size=min_offloaded_tensor_size, ) + gpt_model = GPTModel( + config=transformer_config, + transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, + moe_grouped_gemm=num_experts is not None, + moe_use_legacy_grouped_gemm=False, + multi_latent_attention=is_mla, + ), + vocab_size=vocab_size, + max_sequence_length=seq_length, + ).bfloat16() + return gpt_model + + +def _make_gpt_inputs( + *, + seq_length: int, + micro_batch_size: int, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + data = list(range(seq_length)) + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).to(device) + position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).to(device) + attention_mask = torch.ones((micro_batch_size, 1, seq_length, seq_length), dtype=bool).to( + device + ) + return input_ids, position_ids, attention_mask - # Warmup to stabilize allocator behavior - _reset_cuda_memory() - out = model(inp, use_offload=False) - (out.sum()).backward() - torch.cuda.synchronize() - _reset_cuda_memory() - - # Baseline memory measurement (no offload) - _reset_cuda_memory() - inp_baseline = inp.detach().clone().requires_grad_(True) - baseline_mem_before = torch.cuda.memory_allocated() / (1024**2) - out_base = model(inp_baseline, use_offload=False) - baseline_mem_after = (torch.cuda.memory_allocated() - out_base.nbytes) / (1024**2) - (out_base.sum()).backward() - torch.cuda.synchronize() - baseline_delta = baseline_mem_after - baseline_mem_before - # Offload memory measurement +def _run_one_iter_and_capture( + model: GPTModel, 
+ *, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + enable_offload_reset: bool, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], int]: + """ + Run a single forward+backward iteration. + + Returns: + - logits (CPU float32) + - selected grads (CPU float32) + - peak_memory_allocated (bytes) during the iteration + """ from megatron.core.pipeline_parallel import fine_grained_activation_offload as off - off.fine_grained_offloading_reset() - # warmup - inp_off = inp.detach().clone().requires_grad_(True) - out_off = model(inp_off, use_offload=True) - (out_off.sum()).backward() - torch.cuda.synchronize() - off.fine_grained_offloading_reset() - del inp_off - del out_off - _reset_cuda_memory() - torch.cuda.synchronize() + if enable_offload_reset: + off.fine_grained_offloading_reset() + + # for p in model.parameters(): + # if p.grad is not None: + # p.grad = None - inp_off = inp.detach().clone().requires_grad_(True) - offload_mem_before = torch.cuda.memory_allocated() / (1024**2) - out_off = model(inp_off, use_offload=True) - offload_mem_after = (torch.cuda.memory_allocated() - out_off.nbytes) / (1024**2) - (out_off.sum()).backward() + torch.cuda.reset_peak_memory_stats() + logits = model( + input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask + ) + loss = logits.float().sum() + loss.backward() torch.cuda.synchronize() - offload_delta = offload_mem_after - offload_mem_before - - # Offload should reduce peak cached memory usage after forward - assert ( - offload_delta < baseline_delta - ), f"offload did not reduce memory: off={offload_delta:.2f}MiB base={baseline_delta:.2f}MiB" - - # Theoretical savings: storing per-layer input x (same shape each layer). 
- bytes_per_elem = inp.element_size() # 2 for bfloat16 - input_bytes = inp.numel() * bytes_per_elem - # -2 because the first and last activations are not offloaded - expected_saved_mib = (model.num_layers - 2) * (input_bytes / (1024**2)) - - # Actual savings ≈ baseline_delta - offload_delta (both exclude output tensor memory). - actual_saved_mib = baseline_delta - offload_delta - - # Allow slack for allocator jitter and extra intermediates; magnitudes should match. - rel_err = abs(actual_saved_mib - expected_saved_mib) / max(expected_saved_mib, 1e-6) - assert ( - rel_err <= EPSILON - ), f"saved mismatch: actual={actual_saved_mib:.2f}MiB expected~={expected_saved_mib:.2f}MiB (rel_err={rel_err:.2f})" - - -def test_fine_grained_activation_offload_output_and_grad_consistency(): - torch.manual_seed(2025) - hidden = 1024 - layers = 3 - - # Create identical models by resetting seed - torch.manual_seed(2025) - model_base = ToyModel(hidden_size=hidden, num_layers=layers, dtype=torch.bfloat16).train() - torch.manual_seed(2025) - model_off = ToyModel(hidden_size=hidden, num_layers=layers, dtype=torch.bfloat16).train() - - # Same input and target - inp = torch.randn((32, hidden), device="cuda", dtype=torch.bfloat16, requires_grad=True) - target = torch.randn_like(inp) - - # Baseline forward/backward - out_base = model_base(inp, use_offload=False) - loss_base = torch.nn.functional.mse_loss(out_base, target) - loss_base.backward() - grads_base = [ - p.grad.detach().clone() if p.grad is not None else None for p in model_base.parameters() - ] - - # Offload forward/backward + peak_bytes = int(torch.cuda.max_memory_allocated()) + + # capture all gradients for correctness + grads: Dict[str, torch.Tensor] = {} + for name, p in model.named_parameters(): + grads[name] = (p.grad.detach().float().cpu() if p.grad is not None else None) + + return logits.detach().float().cpu(), grads, peak_bytes + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for offloading 
tests.") +@pytest.mark.skipif( + not is_te_min_version("1.13.0"), + reason="Fine-grained activation offloading requires TE-based GPT layer spec (TE 1.13+ in this repo's tests).", +) +@pytest.mark.parametrize( + "is_moe, is_mla, offload_modules", + [ + # Dense GPT modules + (False, True, ["attn_norm"]), + (True, False, ["qkv_linear"]), + (True, False, ["core_attn"]), + # # attn_proj depends on core_attn (validated in TransformerConfig.__post_init__) + (True, True, ["core_attn", "attn_proj"]), + (True, False, ["mlp_norm"]), + (True, False, ["expert_fc1"]), + (True, False, ["moe_act"]), + ], +) +def test_gpt_fine_grained_activation_offloading_correctness_and_memory( + is_moe: bool, is_mla: bool, offload_modules: List[str] +): + """ + Initialize a GPTModel and verify: + - forward output correctness under each offload_modules setting + - backward gradient correctness (subset) + - peak GPU memory is reduced roughly as expected (based on recorded offload bytes) + """ + # setup distributed/model-parallel (same pattern as other UTs) + os.environ.pop("NVTE_FUSED_ATTN", None) + os.environ.pop("NVTE_FLASH_ATTN", None) + os.environ.pop("NVTE_UNFUSED_ATTN", None) + # os.environ["NVTE_FLASH_ATTN"] = "1" + Utils.initialize_model_parallel(1, 1) + torch.cuda.memory._record_memory_history(max_entries=100000) + + seed = 123 + # Choose shapes large enough to make memory deltas stable but still fast. 
+ num_experts = 4 if is_moe else None + num_layers = 8 + hidden_size = 2048 if num_experts is None else 1024 + num_attention_heads = 16 if hidden_size >= 2048 else 8 + vocab_size = 512 + seq_length = 512 + micro_batch_size = 2 + device = torch.device("cuda") + + input_ids, position_ids, attention_mask = _make_gpt_inputs( + seq_length=seq_length, micro_batch_size=micro_batch_size, device=device + ) + from megatron.core.pipeline_parallel import fine_grained_activation_offload as off + off.fine_grained_offloading_reset_instance() + + try: + # 1) Baseline run (no offloading) + _reset_cuda_memory() + base_model = _build_gpt_model( + seed=seed, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + vocab_size=vocab_size, + seq_length=seq_length, + num_experts=num_experts, + fine_grained_activation_offloading=False, + offload_modules=None, + min_offloaded_tensor_size=1024 * 1024, + is_mla=is_mla, + ).cuda() + base_model.train() + + # Warmup baseline once for allocator stability + _run_one_iter_and_capture( + base_model, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + enable_offload_reset=False, + ) + _reset_cuda_memory() + base_logits, base_grads, base_peak = _run_one_iter_and_capture( + base_model, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + enable_offload_reset=False, + ) + # Free baseline model GPU memory before offload path + del base_model + _reset_cuda_memory() + + # 2) Offload run (warmup to record bytes + steady-state measurement) + off_model = _build_gpt_model( + seed=seed, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + vocab_size=vocab_size, + seq_length=seq_length, + num_experts=num_experts, + fine_grained_activation_offloading=True, + offload_modules=offload_modules, + min_offloaded_tensor_size=1024, # force offloading for UT determinism + is_mla=is_mla, + ).cuda() + off_model.train() - 
off.fine_grained_offloading_reset() - out_off = model_off(inp.detach().clone().requires_grad_(True), use_offload=True) - loss_off = torch.nn.functional.mse_loss(out_off, target) - loss_off.backward() - grads_off = [ - p.grad.detach().clone() if p.grad is not None else None for p in model_off.parameters() - ] - - # Compare outputs - assert torch.allclose(out_off.float(), out_base.float(), rtol=1e-3, atol=1e-3) - - # Compare gradients parameter-wise - for gb, go in zip(grads_base, grads_off): - if gb is None and go is None: - continue - assert gb is not None and go is not None - assert torch.allclose(go.float(), gb.float(), rtol=1e-3, atol=1e-3) + # Warmup 1 iter to populate cached chunks, then reset to finish warmup bookkeeping. + _run_one_iter_and_capture( + off_model, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + enable_offload_reset=True, + ) + # Reset once more to trigger post_warmup_callback and apply steady-state offload decisions. + off.fine_grained_offloading_reset() + + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + PipelineOffloadManager, + ) + + mgr = PipelineOffloadManager.get_instance() + expected_offload_bytes = int( + sum(mgr.offload_summary_bytes.get(k, 0) for k in offload_modules) + ) + expected_offload_mib = expected_offload_bytes / (1024**2) + + _reset_cuda_memory() + off_logits, off_grads, off_peak = _run_one_iter_and_capture( + off_model, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + enable_offload_reset=True, + ) + del off_model + _reset_cuda_memory() + + torch.cuda.memory._dump_snapshot(f"/workspace/pyt_profile/memory_snapshot.pickle") + print(f"Captured memory snapshot at /workspace/pyt_profile/memory_snapshot.pickle") + torch.cuda.memory._record_memory_history(enabled=False) + + # 3) Correctness checks (forward + selected grads) + assert torch.allclose(off_logits, base_logits, rtol=1e-3, atol=1e-3) + assert set(off_grads.keys()) 
== set(base_grads.keys()) + for name, gb in base_grads.items(): + go = off_grads[name] + if gb is None or go is None: + assert gb is None and go is None, f"Grad None mismatch for {name}" + continue + assert torch.allclose(go, gb, rtol=1e-3, atol=1e-3), f"Grad mismatch for {name}" + + # 4) Memory checks (peak allocated over forward+backward) + saved_mib = (base_peak - off_peak) / (1024**2) + assert saved_mib > 0.0, ( + f"Expected GPU peak memory reduction for offload_modules={offload_modules}, " + f"but got saved={saved_mib:.2f}MiB (base={base_peak/(1024**2):.2f}MiB, " + f"off={off_peak/(1024**2):.2f}MiB)" + ) + + # If expectation is large enough, enforce approximate match. + # For tiny expectations, allocator noise may dominate; we only require a positive reduction. + if expected_offload_mib >= 2.0: + rel_err = abs(saved_mib - expected_offload_mib) / max(expected_offload_mib, 1e-6) + assert rel_err <= EPSILON, ( + f"Memory saving mismatch for offload_modules={offload_modules}: " + f"saved={saved_mib:.2f}MiB expected~={expected_offload_mib:.2f}MiB " + f"(rel_err={rel_err:.2f})" + ) + print(f"Rank {torch.distributed.get_rank()}: Saved {saved_mib:.2f}MiB, expected {expected_offload_mib:.2f}MiB") + finally: + Utils.destroy_model_parallel() From 9d1fe341c4019f8d8ccea7395705ee97fca0ee0e Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 7 Jan 2026 23:59:19 -0800 Subject: [PATCH 32/74] format Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 9 +++++-- megatron/core/transformer/moe/experts.py | 4 ++- ...test_fine_grained_activation_offloading.py | 25 ++++++++----------- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index a84ba6de327..baabe29838a 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py 
@@ -790,8 +790,12 @@ def finish_all_groups(self, name=None) -> bool: f"------finish_all_groups {self} {self._max_group_size} {self._offloaded_group_index}" ) # TODO: check if this is correct - # Mark it as finished when all groups are finished and there are no groups to offload or reload - if len(self._groups_to_reload) == 0 and len(self._groups_to_offload) == 0 and self._offloaded_group_index > 0: + # Mark it as finished when there are no groups to offload or reload + if ( + len(self._groups_to_reload) == 0 + and len(self._groups_to_offload) == 0 + and self._offloaded_group_index > 0 + ): return True assert name is not None, "Name is required" for group in self.offload_groups[self._offloaded_group_index :]: @@ -1207,6 +1211,7 @@ def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> """Record the backward event for cuda graph capture.""" return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) + def fine_grained_offloading_reset_instance(): """Reset the singleton instance of PipelineOffloadManager.""" PipelineOffloadManager.reset_instance() diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index d0cccb287aa..471008489c9 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -892,7 +892,9 @@ def forward( ) if self.offload_expert_fc1: fc1_output = fine_grained_offloading_group_commit( - fc1_output, name="expert_fc1", forced_released_tensors=[permuted_local_hidden_states] + fc1_output, + name="expert_fc1", + forced_released_tensors=[permuted_local_hidden_states], ) def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 7dbb0d77d9e..302af7a2c27 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ 
b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -2,20 +2,19 @@ import gc import os +from contextlib import nullcontext from typing import Dict, List, Optional, Tuple import pytest import torch -from contextlib import nullcontext from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig +from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig from megatron.core.utils import is_te_min_version from tests.unit_tests.test_utilities import Utils -from megatron.core.transformer.enums import AttnBackend - # Tolerance for memory expectation check (GPU allocator jitter etc). EPSILON = 0.30 @@ -72,7 +71,7 @@ def _build_gpt_model( moe_grouped_gemm=num_experts is not None, moe_use_legacy_grouped_gemm=False, multi_latent_attention=is_mla, - ), + ), vocab_size=vocab_size, max_sequence_length=seq_length, ).bfloat16() @@ -80,10 +79,7 @@ def _build_gpt_model( def _make_gpt_inputs( - *, - seq_length: int, - micro_batch_size: int, - device: torch.device, + *, seq_length: int, micro_batch_size: int, device: torch.device ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: data = list(range(seq_length)) input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).to(device) @@ -120,9 +116,7 @@ def _run_one_iter_and_capture( # p.grad = None torch.cuda.reset_peak_memory_stats() - logits = model( - input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask - ) + logits = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask) loss = logits.float().sum() loss.backward() torch.cuda.synchronize() @@ -131,7 +125,7 @@ def 
_run_one_iter_and_capture( # capture all gradients for correctness grads: Dict[str, torch.Tensor] = {} for name, p in model.named_parameters(): - grads[name] = (p.grad.detach().float().cpu() if p.grad is not None else None) + grads[name] = p.grad.detach().float().cpu() if p.grad is not None else None return logits.detach().float().cpu(), grads, peak_bytes @@ -188,6 +182,7 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( ) from megatron.core.pipeline_parallel import fine_grained_activation_offload as off + off.fine_grained_offloading_reset_instance() try: @@ -307,6 +302,8 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( f"saved={saved_mib:.2f}MiB expected~={expected_offload_mib:.2f}MiB " f"(rel_err={rel_err:.2f})" ) - print(f"Rank {torch.distributed.get_rank()}: Saved {saved_mib:.2f}MiB, expected {expected_offload_mib:.2f}MiB") + print( + f"Rank {torch.distributed.get_rank()}: Saved {saved_mib:.2f}MiB, expected {expected_offload_mib:.2f}MiB" + ) finally: Utils.destroy_model_parallel() From 08b46aaf62238066bd573303aac5278b865db72d Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 8 Jan 2026 00:22:16 -0800 Subject: [PATCH 33/74] fix ut Signed-off-by: Hongbin Liu --- .../test_fine_grained_activation_offloading.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 302af7a2c27..34c1e523dcb 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -164,7 +164,6 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( os.environ.pop("NVTE_UNFUSED_ATTN", None) # os.environ["NVTE_FLASH_ATTN"] = "1" Utils.initialize_model_parallel(1, 1) - torch.cuda.memory._record_memory_history(max_entries=100000) seed = 123 # Choose 
shapes large enough to make memory deltas stable but still fast. @@ -271,10 +270,6 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( del off_model _reset_cuda_memory() - torch.cuda.memory._dump_snapshot(f"/workspace/pyt_profile/memory_snapshot.pickle") - print(f"Captured memory snapshot at /workspace/pyt_profile/memory_snapshot.pickle") - torch.cuda.memory._record_memory_history(enabled=False) - # 3) Correctness checks (forward + selected grads) assert torch.allclose(off_logits, base_logits, rtol=1e-3, atol=1e-3) assert set(off_grads.keys()) == set(base_grads.keys()) From d33b3c461b34b690c74597fbc28cd0e03e20c684 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 12 Jan 2026 01:25:17 -0800 Subject: [PATCH 34/74] update ut Signed-off-by: Hongbin Liu --- ...test_fine_grained_activation_offloading.py | 293 +++++++++++++++++- 1 file changed, 280 insertions(+), 13 deletions(-) diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 34c1e523dcb..88d76bf5d80 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -18,6 +18,8 @@ # Tolerance for memory expectation check (GPU allocator jitter etc). EPSILON = 0.30 +EPSILON_A2A = 0.30 +DELTA = 20 # MiB def _reset_cuda_memory() -> None: @@ -51,11 +53,10 @@ def _build_gpt_model( num_attention_heads=num_attention_heads, use_cpu_initialization=True, attention_backend=AttnBackend.unfused, - # Make sure model weights / activations are BF16 so TE fused attention isn't disabled. 
bf16=True, - # params_dtype=torch.bfloat16, - # enable_autocast=True, - # autocast_dtype=torch.bfloat16, + # Recompute + recompute_modules=["layernorm", "moe_act"] if num_experts is not None else ["layernorm"], + recompute_granularity="selective", # MoE num_moe_experts=num_experts, moe_grouped_gemm=(num_experts is not None), @@ -131,10 +132,6 @@ def _run_one_iter_and_capture( @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for offloading tests.") -@pytest.mark.skipif( - not is_te_min_version("1.13.0"), - reason="Fine-grained activation offloading requires TE-based GPT layer spec (TE 1.13+ in this repo's tests).", -) @pytest.mark.parametrize( "is_moe, is_mla, offload_modules", [ @@ -163,7 +160,7 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( os.environ.pop("NVTE_FLASH_ATTN", None) os.environ.pop("NVTE_UNFUSED_ATTN", None) # os.environ["NVTE_FLASH_ATTN"] = "1" - Utils.initialize_model_parallel(1, 1) + Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) seed = 123 # Choose shapes large enough to make memory deltas stable but still fast. @@ -171,8 +168,8 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( num_layers = 8 hidden_size = 2048 if num_experts is None else 1024 num_attention_heads = 16 if hidden_size >= 2048 else 8 - vocab_size = 512 - seq_length = 512 + vocab_size = 1024 + seq_length = 1024 micro_batch_size = 2 device = torch.device("cuda") @@ -292,13 +289,283 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( # For tiny expectations, allocator noise may dominate; we only require a positive reduction. 
if expected_offload_mib >= 2.0: rel_err = abs(saved_mib - expected_offload_mib) / max(expected_offload_mib, 1e-6) - assert rel_err <= EPSILON, ( + abs_err = abs(saved_mib - expected_offload_mib) + assert rel_err <= EPSILON and abs_err <= DELTA, ( f"Memory saving mismatch for offload_modules={offload_modules}: " f"saved={saved_mib:.2f}MiB expected~={expected_offload_mib:.2f}MiB " - f"(rel_err={rel_err:.2f})" + f"(rel_err={rel_err:.2f}, abs_err={abs_err:.2f})" ) print( f"Rank {torch.distributed.get_rank()}: Saved {saved_mib:.2f}MiB, expected {expected_offload_mib:.2f}MiB" ) finally: Utils.destroy_model_parallel() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for offloading tests.") +@pytest.mark.skipif( + not is_te_min_version("1.9.0.dev0"), + reason="EP A2A overlap requires TE 1.9.0.dev0+ in this repo's tests.", +) +@pytest.mark.parametrize( + "dispatcher_backend, is_mla, offload_modules", + [ + ("alltoall", True, ["attn_norm"]), + ("alltoall", True, ["core_attn"]), + ("alltoall", True, ["attn_norm", "core_attn", "attn_proj"]), + ("alltoall", True, ["mlp_norm"]), + ("alltoall", False, ["expert_fc1"]), + ("alltoall", False, ["moe_act"]), + ("alltoall", False, ["mlp_norm", "expert_fc1", "moe_act"]), + ( + "alltoall", + True, + ["attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"], + ), + ( + "alltoall", + False, + ["attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"], + ), + ], +) +def test_fine_grained_activation_offload_with_ep_a2a_overlap_compatibility( + dispatcher_backend: str, is_mla: bool, offload_modules: List[str] +): + """ + Compatibility test for: + - fine-grained activation offloading + - EP all-to-all overlap (overlap_moe_expert_parallel_comm) + - memory saving roughly matches expected offload bytes (when expectation is large enough) + + The EP A2A overlap initialization pattern is aligned with + `tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py`. 
+ """ + from megatron.core.models.common.model_chunk_schedule_plan import ( + TransformerModelChunkSchedulePlan, + ) + from megatron.core.pipeline_parallel.utils import set_streams + from tests.unit_tests.a2a_overlap.utils import deterministic_mode + + # EP overlap requires distributed initialization with EP groups. + ep_size = 4 + if Utils.world_size % ep_size != 0: + pytest.skip( + f"Skipping: WORLD_SIZE={Utils.world_size} must be divisible by ep_size={ep_size}." + ) + + seed = 123 + num_experts = 8 # must be divisible by ep_size + if num_experts % ep_size != 0: + pytest.skip( + f"Skipping: num_moe_experts={num_experts} must be divisible by ep_size={ep_size}." + ) + + # Small shapes to keep this compatibility test fast. + num_layers = 8 + hidden_size = 1024 + num_attention_heads = 16 + vocab_size = 1024 + seq_length = 1024 + micro_batch_size = 2 + device = torch.device("cuda") + + from megatron.core.pipeline_parallel import fine_grained_activation_offload as off + + def _make_schedule_inputs() -> Dict[str, torch.Tensor]: + data = list(range(seq_length)) + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).to(device) + position_ids = ( + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).to(device) + ) + attention_mask = torch.ones((micro_batch_size, 1, seq_length, seq_length), dtype=bool).to( + device + ) + labels = input_ids.clone() + return { + "input_ids": input_ids, + "labels": labels, + "position_ids": position_ids, + "attention_mask": attention_mask, + } + + def _capture_params(model: torch.nn.Module) -> Dict[str, torch.Tensor]: + params: Dict[str, torch.Tensor] = {} + for name, p in model.named_parameters(): + params[name] = p.detach().clone() + return params + + def _restore_params(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> None: + for name, p in model.named_parameters(): + p.data.copy_(params[name]) + + def _build_overlap_moe_gpt( + *, enable_offload: bool, is_mla: bool, dispatcher_backend: 
str + ) -> GPTModel: + model_parallel_cuda_manual_seed(seed) + torch.manual_seed(seed) + ConfigClass = MLATransformerConfig if is_mla else TransformerConfig + transformer_config = ConfigClass( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + attention_backend=AttnBackend.unfused, + # Recompute + recompute_modules=["layernorm", "moe_act"], + recompute_granularity="selective", + bf16=True, + # MoE + EP overlap + num_moe_experts=num_experts, + moe_grouped_gemm=True, + expert_model_parallel_size=ep_size, + moe_token_dispatcher_type="alltoall" if dispatcher_backend == "alltoall" else "flex", + moe_flex_dispatcher_backend=dispatcher_backend, + moe_router_dtype="fp32" if dispatcher_backend == "hybridep" else "fp64", + overlap_moe_expert_parallel_comm=True, + delay_wgrad_compute=True, + # Fine-grained activation offloading + fine_grained_activation_offloading=enable_offload, + offload_modules=offload_modules if enable_offload else None, + min_offloaded_tensor_size=1024, # force offloading to exercise the code path + ) + return ( + GPTModel( + config=transformer_config, + transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, + moe_grouped_gemm=True, + moe_use_legacy_grouped_gemm=False, + multi_latent_attention=is_mla, + ), + vocab_size=vocab_size, + max_sequence_length=seq_length, + ) + .bfloat16() + .cuda() + ) + + def _run_schedule_1f1b_two_microbatches( + model: GPTModel, *, enable_offload_reset: bool + ) -> Tuple[List[torch.Tensor], Dict[str, torch.Tensor], int]: + """ + Run a minimal 1F1B schedule (2 microbatches) using ModelChunkSchedulePlan.run(). + This is the execution path that exercises EP A2A overlap scheduling. 
+ """ + if enable_offload_reset: + off.fine_grained_offloading_reset() + + data0 = _make_schedule_inputs() + data1 = _make_schedule_inputs() + plan0 = model.build_schedule_plan(**data0) + + torch.cuda.reset_peak_memory_stats() + out0 = TransformerModelChunkSchedulePlan.run(plan0, None) + plan1 = model.build_schedule_plan(**data1) + out1 = TransformerModelChunkSchedulePlan.run(plan1, plan0, b_grad=torch.ones_like(out0)) + TransformerModelChunkSchedulePlan.run(None, plan1, b_grad=torch.ones_like(out1)) + torch.cuda.synchronize() + peak_bytes = int(torch.cuda.max_memory_allocated()) + + # capture outputs and grads + outputs = [out0.detach().float().cpu(), out1.detach().float().cpu()] + grads: Dict[str, torch.Tensor] = {} + for name, p in model.named_parameters(): + grads[name] = p.grad.detach().float().cpu() if p.grad is not None else None + return outputs, grads, peak_bytes + + # setup distributed/model-parallel + os.environ.pop("NVTE_FUSED_ATTN", None) + os.environ.pop("NVTE_FLASH_ATTN", None) + os.environ.pop("NVTE_UNFUSED_ATTN", None) + + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=ep_size, + ) + set_streams() + + off.fine_grained_offloading_reset_instance() + + try: + with deterministic_mode(): + # Baseline: EP overlap on, offload off. + _reset_cuda_memory() + base_model = _build_overlap_moe_gpt( + enable_offload=False, is_mla=is_mla, dispatcher_backend=dispatcher_backend + ) + base_model.train() + base_params = _capture_params(base_model) + # Warmup once for allocator stability / graph caching + _run_schedule_1f1b_two_microbatches(base_model, enable_offload_reset=False) + _reset_cuda_memory() + base_outs, base_grads, base_peak = _run_schedule_1f1b_two_microbatches( + base_model, enable_offload_reset=False + ) + del base_model + _reset_cuda_memory() + + # Offload: EP overlap on, fine-grained offload on. 
+ off_model = _build_overlap_moe_gpt( + enable_offload=True, is_mla=is_mla, dispatcher_backend=dispatcher_backend + ) + _restore_params(off_model, base_params) + off_model.train() + # Warmup once to populate cached chunks, then reset to apply steady-state offload decisions. + off.fine_grained_offloading_reset() + _run_schedule_1f1b_two_microbatches(off_model, enable_offload_reset=False) + off.fine_grained_offloading_reset() + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + PipelineOffloadManager, + ) + + mgr = PipelineOffloadManager.get_instance() + expected_offload_bytes = int( + sum(mgr.offload_summary_bytes.get(k, 0) for k in offload_modules) + ) + expected_offload_mib = expected_offload_bytes / (1024**2) + + _reset_cuda_memory() + off_outs, off_grads, off_peak = _run_schedule_1f1b_two_microbatches( + off_model, enable_offload_reset=True + ) + del off_model + _reset_cuda_memory() + + # Correctness (forward outputs + all grads) + assert len(off_outs) == len(base_outs) == 2 + for i in range(2): + assert torch.allclose(off_outs[i], base_outs[i], rtol=1e-3, atol=1e-3) + assert set(off_grads.keys()) == set(base_grads.keys()) + for name, gb in base_grads.items(): + go = off_grads[name] + if gb is None or go is None: + assert gb is None and go is None, f"Grad None mismatch for {name}" + continue + assert torch.allclose( + go, gb, rtol=1e-3, atol=1e-3 + ), f"Rank {torch.distributed.get_rank()}: Grad mismatch for {name}" + + # Memory checks (peak allocated during the scheduled 1F1B run) + saved_mib = (base_peak - off_peak) / (1024**2) + assert saved_mib > 0.0, ( + f"Expected GPU peak memory reduction for offload_modules={offload_modules}, " + f"but got saved={saved_mib:.2f}MiB (base={base_peak/(1024**2):.2f}MiB, " + f"off={off_peak/(1024**2):.2f}MiB)" + ) + # If expectation is large enough, enforce approximate match. 
+ if expected_offload_mib >= 2.0: + rel_err = abs(saved_mib - expected_offload_mib) / max(expected_offload_mib, 1e-6) + abs_err = abs(saved_mib - expected_offload_mib) + print( + f"Rank {torch.distributed.get_rank()}: Saved {saved_mib:.2f}MiB, expected {expected_offload_mib:.2f}MiB" + ) + if abs_err > DELTA: + assert rel_err <= EPSILON_A2A, ( + f"Memory saving mismatch for offload_modules={offload_modules}: " + f"saved={saved_mib:.2f}MiB expected~={expected_offload_mib:.2f}MiB " + f"(rel_err={rel_err:.2f}, abs_err={abs_err:.2f})" + ) + finally: + Utils.destroy_model_parallel() From 884c335de2697eae4a2f9dbc6dd198ba145cd599 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 9 Jan 2026 02:32:11 -0800 Subject: [PATCH 35/74] update ut Signed-off-by: Hongbin Liu --- .../pipeline_parallel/fine_grained_activation_offload.py | 5 +++-- .../test_fine_grained_activation_offloading.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index baabe29838a..b0f8c0e4cf0 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -973,7 +973,8 @@ def bulk_reload(self): else: # Pre-load the last layer of the next backward chunk to hide latency next_backward_chunk = PipelineOffloadManager.get_instance().front() - if next_backward_chunk is not None: + if next_backward_chunk is not None \ + and next_backward_chunk._offloaded_group_index == next_backward_chunk._max_group_size: next_backward_chunk.pre_reload_last_layer() def on_group_commit_backward(self, name): @@ -1031,7 +1032,7 @@ def on_group_start_backward(self): """ if not self.do_offload: return - debug_rank("--on_group_start_backward") + debug_rank(f"--on_group_start_backward {self}") # Wait for compute to finish before starting reload 
self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 88d76bf5d80..4284ae834e0 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -19,7 +19,7 @@ # Tolerance for memory expectation check (GPU allocator jitter etc). EPSILON = 0.30 EPSILON_A2A = 0.30 -DELTA = 20 # MiB +DELTA = 20 # MiB def _reset_cuda_memory() -> None: @@ -144,6 +144,8 @@ def _run_one_iter_and_capture( (True, False, ["mlp_norm"]), (True, False, ["expert_fc1"]), (True, False, ["moe_act"]), + (True, True, ["attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"]), + (True, False, ["core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"]), ], ) def test_gpt_fine_grained_activation_offloading_correctness_and_memory( From df2e8397fa2bb40715366afc0c4992b2fab23ef4 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 9 Jan 2026 04:07:04 -0800 Subject: [PATCH 36/74] fix ut Signed-off-by: Hongbin Liu --- .../test_fine_grained_activation_offloading.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 4284ae834e0..692bf8773a7 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -144,8 +144,6 @@ def _run_one_iter_and_capture( (True, False, ["mlp_norm"]), (True, False, ["expert_fc1"]), (True, False, ["moe_act"]), - (True, True, ["attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"]), - (True, False, ["core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"]), 
], ) def test_gpt_fine_grained_activation_offloading_correctness_and_memory( From 569f347bb3a3e83a25bf8c01edb7a86aff8aa649 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 11 Jan 2026 19:18:25 -0800 Subject: [PATCH 37/74] add version check Signed-off-by: Hongbin Liu --- megatron/training/arguments.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 91292f1ed6e..233d86c774d 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1264,6 +1264,9 @@ def validate_args(args, defaults={}): if args.fine_grained_activation_offloading: assert args.transformer_impl == 'transformer_engine', \ "Fine-grained activation offloading is only supported with transformer_engine implementation" + if is_te_min_version("2.10.0"): + assert os.getenv("NVTE_CPU_OFFLOAD_V1", "0") == "1", \ + "For fine-grained activation offloading with TE >= 2.10.0, NVTE_CPU_OFFLOAD_V1 should be set to 1 to avoid offloading weights." if args.mtp_num_layers: assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." From 7f8109a1979534c159c23348d2a6bc30e137423c Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 12 Jan 2026 01:15:16 -0800 Subject: [PATCH 38/74] minor refactor for fine_grained_activation_offload.py Signed-off-by: Hongbin Liu --- .../common/model_chunk_schedule_plan.py | 2 +- .../fine_grained_activation_offload.py | 101 +++++++++--------- 2 files changed, 52 insertions(+), 51 deletions(-) diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index ed23e6ad391..02312c1c59b 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
from contextlib import nullcontext from typing import Optional diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index b0f8c0e4cf0..a8cf2207e89 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -341,6 +341,9 @@ def __init__(self, name): self.offload = True self.total_offload_bytes = 0 self.total_tensor_count = 0 + # Using memory pool is for the compatibility with cuda graph. + # Shapes of tensors for expert_fc1 and moe_act are not known in advance, + # so we do not use CPU pool for them. if name == "expert_fc1" or name == "moe_act": self.use_cpu_pool = False else: @@ -410,16 +413,23 @@ def __init__(self): # Shared CPU tensor pool for all chunks to improve reuse efficiency self._cpu_tensor_pool = GPUTensorPool(device="cpu", pin_memory=True) + # Whether the manager is in warmup phase. self._is_warmup = True + # Cache OffloadChunkHandler objects for each virtual pipeline stage and each forward pass. self._cached_chunks_forward = [] + # Cache OffloadChunkHandler objects for each virtual pipeline stage and each backward pass. self._cached_chunks_backward = [] + # Index of the current backward chunk in the cached chunks backward. self._cached_chunks_index_backward = 0 + # Index of the current forward chunk in the cached chunks forward. self._cached_chunks_index_forward = 0 self.do_offload = True - # Margin to avoid offloading too many groups so that + # Do not offload the last X groups so that the reloading won't block the computing stream. self._offload_margin = 0 + # Sometimes we need to delay the offloading and launch it later. + # The delayed offload groups are stored in a queue. 
self._delayed_offload_groups = [] self.reset() @@ -446,6 +456,7 @@ def push_offload_groups(self, group_hook, forced_released_tensors): def flush_delayed_groups(self): """Flush the delayed groups.""" debug_rank("flushing delayed groups") + # Flush the delayed groups in reverse order to maintain the order of the groups. for group_hook, forced_released_tensors in reversed(self._delayed_offload_groups): group_hook(forced_released_tensors) self._delayed_offload_groups = [] @@ -459,6 +470,7 @@ def reset(self): if hasattr(self, '_cpu_tensor_pool'): self._cpu_tensor_pool.reset() + # Call post_warmup_callback after warmup to collect the offload information. if self._is_warmup and len(self._cached_chunks_forward) > 0: self.post_warmup_callback() self._cached_chunks_index_backward = 0 @@ -523,7 +535,7 @@ def post_warmup_callback(self): # Update the offload margin to the maximum number of deduplicated groups self._offload_margin = max(self._offload_margin, chunk.get_max_deduplicated_groups()) debug_rank(f"offload margin {self._offload_margin}") - # Fine the last group with the same name in the cached chunks backward + # Find the last group with the same name in the cached chunks backward last_group_with_same_name = {} for chunk_idx, chunk in enumerate(reversed(self._cached_chunks_backward)): for group in chunk.offload_groups: @@ -567,8 +579,8 @@ def push(self, handler): if self._is_warmup: self._cached_chunks_backward.append(handler) - def pop(self, name=None): - """Remove and set the next non-empty chunk as the current backward chunk.""" + def pop_backward_chunk(self, name=None): + """Get the next non-empty backward chunk containing the group with the given name.""" self._cur_backward_chunk = None debug_rank(f"popping backward chunk {self._cached_chunks_index_backward}") debug_rank(f"cached chunks backward {self._cached_chunks_backward}") @@ -584,8 +596,8 @@ def pop(self, name=None): break assert self._cur_backward_chunk is not None, "No non-empty chunk found" - def 
front(self, name=None): - """Get the first non-empty chunk handler without removing it from the queue.""" + def front_backward_chunk(self, name=None): + """Get the first non-empty backward chunk containing the group with the given name.""" for idx, handler in enumerate( self._cached_chunks_backward[self._cached_chunks_index_backward :] ): @@ -744,17 +756,19 @@ def reload(self, state, non_blocking=None): def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool): self.do_offload = True - # Data Structure to maintain reference to activation tensors - self._tensor_tag_to_state = {} - # Mark the first microbatch of the last virtual pipeline stage - # self._is_first_last_vpp_chunk = is_first_last_vpp_chunk # Group management for batching offload/reload operations + self.offload_groups = [] self._offloaded_group_index = 0 + # Groups to be offloaded. self._groups_to_offload = [] + # Groups to be reloaded. self._groups_to_reload = [] + # Tensor count for the current group. self._tensor_count_current_group = 0 + # Maximum number of groups to offload or reload. self._max_group_size = 0 + # Groups being reloaded. 
self._reloading_group = [] # Counter for special torch tensor types (FakeTensor, FunctionalTensor) self.torch_tensor_count = 0 @@ -762,7 +776,6 @@ def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool): self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream self.min_offloaded_tensor_size = min_offloaded_tensor_size self.cpu_tensor_pool = cpu_tensor_pool - self.offload_groups = [] self.is_warmup = True def reset(self): @@ -773,15 +786,18 @@ def reset(self): self._tensor_count_current_group = 0 self._reloading_group = [] + def find_group_with_name(self, name: str, start_index: int = 0): + """Find the group with the given name starting from the given index.""" + return next( + (group for group in self.offload_groups[start_index:] if group._name == name), + None, + ) + def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" debug_rank(f"------is_empty_chunk {self._max_group_size}") if name is not None: - for group in self.offload_groups: - debug_rank(f"group name {group._name} need name {name}") - if group._name == name: - return False - return True + return self.find_group_with_name(name) is None return self._max_group_size == 0 def finish_all_groups(self, name=None) -> bool: @@ -798,18 +814,12 @@ def finish_all_groups(self, name=None) -> bool: ): return True assert name is not None, "Name is required" - for group in self.offload_groups[self._offloaded_group_index :]: - if group._name == name: - return False - return True + return self.find_group_with_name(name, self._offloaded_group_index) is None def find_next_group(self, name=None): """Find the next group with the given name.""" assert name is not None, "Name is required" - for group in self.offload_groups[self._offloaded_group_index :]: - if group._name == name: - return group - return None + return self.find_group_with_name(name, self._offloaded_group_index) def tensor_push(self, tensor): """Push tensor to the offload handler.""" @@ -822,30 +832,19 @@ def 
tensor_push(self, tensor): ) assert not torch_stray_tensor, "Stray tensor should not be offloaded" - if not torch_stray_tensor: - # Assign unique tag based on group index and position within group - tensor_tag = (self._offloaded_group_index, self._tensor_count_current_group) - self._tensor_count_current_group += 1 - # assert tensor_tag not in self._tensor_tag_to_state, "Duplicate tensor tag" - # self._tensor_tag_to_state[tensor_tag] = tensor - self.offload_groups[self._offloaded_group_index - 1].push_tensor(tensor_tag, tensor) - else: - # Use negative group ID for special tensor types - tensor_tag = (-1, self.torch_tensor_count) - self.torch_tensor_count += 1 - # self._tensor_tag_to_state[tensor_tag] = tensor + # Assign unique tag based on group index and position within group + tensor_tag = (self._offloaded_group_index, self._tensor_count_current_group) + self._tensor_count_current_group += 1 + self.offload_groups[self._offloaded_group_index - 1].push_tensor(tensor_tag, tensor) debug_rank(f"--------tensor_push {tensor_tag}") return tensor_tag def tensor_pop(self, tensor_tag): """Pop tensor from the offload handler.""" debug_rank(f"--------tensor_pop {tensor_tag}") - # assert tensor_tag in self._tensor_tag_to_state, f"Tag {tensor_tag} not found" - # tensor = self._tensor_tag_to_state.pop(tensor_tag) group_id, idx = tensor_tag tensor = self.offload_groups[group_id - 1].pop_tensor(tensor_tag) # If tensor is offloaded (stored as tuple), reload it - # assert isinstance(tensor, torch.Tensor), "Tensor is not a tensor" if isinstance(tensor, tuple): tensor = self.reload(tensor) debug_rank(f"--------tensor_pop {tensor.shape}") @@ -869,7 +868,6 @@ def bulk_offload_group(self): group_to_offload = self._groups_to_offload[-1] torch.cuda.nvtx.range_push("activation offloading " + group_to_offload._name) with torch.cuda.stream(self.d2h_stream): - # for tensor_tag, state in self._tensor_tag_to_state.items(): for tensor_tag, tensor_on_device in group_to_offload._tensors.items(): if 
self.tensor_need_offloading_checker(tensor_on_device): state = self.offload( @@ -908,6 +906,7 @@ def bulk_reload_group(self): group_to_reload.push_tensor(tensor_tag, recovered_tensor) group_to_reload.record_reload_event(self.h2d_stream) self._groups_to_reload.pop() + # Add the group to the reloading group to wait for the reload event. self._reloading_group.append(group_to_reload) torch.cuda.nvtx.range_pop() @@ -921,6 +920,7 @@ def pre_reload_last_layer(self): def should_bulk_offload(self): """Determine if the current group should be offloaded.""" + assert len(self._groups_to_offload) > 0, "No groups to offload" group = self._groups_to_offload[-1] debug_rank(f"should_bulk_offload {self.is_warmup} {group.offload}") # Don't offload if the chunk is not in warmup stage @@ -931,7 +931,8 @@ def should_bulk_offload(self): return False # Check if next backward chunk is this chunk (for last pipeline stage) - next_backward_chunk = PipelineOffloadManager.get_instance().front(name=group._name) + next_backward_chunk = \ + PipelineOffloadManager.get_instance().front_backward_chunk(group._name) if next_backward_chunk is not None and next_backward_chunk is self: # Don't offload the last group with the same name if it's about to be used immediately if self.find_next_group(group._name) is None: @@ -972,9 +973,13 @@ def bulk_reload(self): self.bulk_reload_group() else: # Pre-load the last layer of the next backward chunk to hide latency - next_backward_chunk = PipelineOffloadManager.get_instance().front() - if next_backward_chunk is not None \ - and next_backward_chunk._offloaded_group_index == next_backward_chunk._max_group_size: + next_backward_chunk = PipelineOffloadManager.get_instance().front_backward_chunk() + # Don't pre-reload the last layer if the next backward chunk hasn't finished fprop yet. 
+ if ( + next_backward_chunk is not None + and next_backward_chunk._offloaded_group_index + == next_backward_chunk._max_group_size + ): next_backward_chunk.pre_reload_last_layer() def on_group_commit_backward(self, name): @@ -988,7 +993,7 @@ def on_group_commit_backward(self, name): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() # Switch to this chunk if it's not already current if cur_backward_chunk is not self: - PipelineOffloadManager.get_instance().pop(name) + PipelineOffloadManager.get_instance().pop_backward_chunk(name) cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() assert cur_backward_chunk is self, f"Chunk mismatch {cur_backward_chunk} {self}" # Wait for reload to complete before using tensors @@ -1007,17 +1012,13 @@ def on_group_start_forward(self, name): if not self.do_offload: return debug_rank(f"--on_group_start_forward {name}") + self._offloaded_group_index = self._offloaded_group_index + 1 if self.is_warmup: - self._offloaded_group_index = self._offloaded_group_index + 1 self.offload_groups.append(OffloadTensorGroup(name)) self._max_group_size = max(self._max_group_size, self._offloaded_group_index) debug_rank(f"max group size {self._max_group_size}") else: - self._offloaded_group_index = self._offloaded_group_index + 1 for group in self.offload_groups[self._offloaded_group_index - 1 :]: - debug_rank( - f"offloaded group index {self._offloaded_group_index} for group {group._name}" - ) if group._name == name: break self._offloaded_group_index = self._offloaded_group_index + 1 From 19b35c317ee4be2d35a6bd40ed52161412296552 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 12 Jan 2026 01:32:28 -0800 Subject: [PATCH 39/74] format Signed-off-by: Hongbin Liu --- .../pipeline_parallel/fine_grained_activation_offload.py | 8 ++++---- .../test_fine_grained_activation_offloading.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git 
a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index a8cf2207e89..fa23dd2417f 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -789,8 +789,7 @@ def reset(self): def find_group_with_name(self, name: str, start_index: int = 0): """Find the group with the given name starting from the given index.""" return next( - (group for group in self.offload_groups[start_index:] if group._name == name), - None, + (group for group in self.offload_groups[start_index:] if group._name == name), None ) def is_empty_chunk(self, name=None): @@ -931,8 +930,9 @@ def should_bulk_offload(self): return False # Check if next backward chunk is this chunk (for last pipeline stage) - next_backward_chunk = \ - PipelineOffloadManager.get_instance().front_backward_chunk(group._name) + next_backward_chunk = PipelineOffloadManager.get_instance().front_backward_chunk( + group._name + ) if next_backward_chunk is not None and next_backward_chunk is self: # Don't offload the last group with the same name if it's about to be used immediately if self.find_next_group(group._name) is None: diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 692bf8773a7..88d76bf5d80 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -19,7 +19,7 @@ # Tolerance for memory expectation check (GPU allocator jitter etc). 
EPSILON = 0.30 EPSILON_A2A = 0.30 -DELTA = 20 # MiB +DELTA = 20 # MiB def _reset_cuda_memory() -> None: From dd874a92c70dcdb6be5088e89cf6eb6b7d303e8b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 12 Jan 2026 18:07:51 -0800 Subject: [PATCH 40/74] support partial cuda graph Signed-off-by: Hongbin Liu --- megatron/core/transformer/cuda_graphs.py | 15 +++ megatron/core/transformer/module.py | 4 + megatron/core/transformer/moe/experts.py | 6 +- .../core/transformer/transformer_config.py | 20 ++++ .../core/transformer/transformer_layer.py | 111 +++++++++++++++--- megatron/training/arguments.py | 2 + 6 files changed, 142 insertions(+), 16 deletions(-) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index b566c1830dc..691b129d8bf 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1898,6 +1898,16 @@ def _get_fp8_enabled(): ) else: kwargs['fp8_enabled'] = False + + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_disable_offload, + fine_grained_offloading_enable_offload, + ) + + # if self.config.offload_module_in_cuda_graph: + if self.config.fine_grained_activation_offloading: + kwargs['pre_warmup_hook'] = fine_grained_offloading_disable_offload + kwargs['post_warmup_hook'] = fine_grained_offloading_enable_offload return kwargs kwargs = get_make_graphed_callables_kwargs() @@ -1932,8 +1942,13 @@ def _finish_capturing(self, start_time): _set_capture_end() from megatron.core.distributed.finalize_model_grads import reset_model_temporary_tensors + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_reset, + ) from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker + fine_grained_offloading_reset() + torch.distributed.barrier() for model_chunk in self.model: model_chunk.zero_grad_buffer() diff --git a/megatron/core/transformer/module.py 
b/megatron/core/transformer/module.py index d68f34ffd0b..4fed8177a91 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -285,6 +285,10 @@ def _get_te_cuda_graph_replay_args(self, *args, **kwargs): cudagraph_kwargs = kwargs.copy() cudagraph_kwargs['is_first_microbatch'] = getattr(self, 'current_microbatch', 0) == 0 + from megatron.core.transformer.transformer_layer import TransformerLayer + + cudagraph_kwargs['cuda_graph_stream'] = TransformerLayer.cuda_graph_stream + cudagraph_kwargs['cuda_graph_event'] = TransformerLayer.cuda_graph_event return cudagraph_args, cudagraph_kwargs def _should_call_local_cudagraph(self, *args, **kwargs): diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index fa2653a030a..9b8d111b153 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -744,6 +744,7 @@ def forward( fc1_output, name="expert_fc1", forced_released_tensors=[permuted_local_hidden_states], + delay_offload=self.config.delay_offload_until_cuda_graph, ) def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): @@ -822,7 +823,10 @@ def glu(x): self.activation_checkpoint.discard_output_and_register_recompute(output) if self.offload_moe_act: output = fine_grained_offloading_group_commit( - output, name="moe_act", forced_released_tensors=[fc1_output] + output, + name="moe_act", + forced_released_tensors=[fc1_output], + delay_offload=self.config.delay_offload_until_cuda_graph, ) # upad and concat the output diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 3a57f09f6cf..e98c9a69993 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -867,6 +867,9 @@ class TransformerConfig(ModelParallelConfig): min_offloaded_tensor_size: int = 1024 * 1024 """The minimum size of the tensor to be offloaded.""" + 
delay_offload_until_cuda_graph: bool = False + """If True, delay the offload until the CUDA graph is executed for minimal CPU overhead.""" + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more @@ -1250,6 +1253,7 @@ def __post_init__(self): "attn_norm", "mlp_norm", "qkv_linear", + "dense_mlp", } invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( @@ -1262,6 +1266,22 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) + if self.external_cuda_graph or self.cuda_graph_impl == "transformer_engine": + assert ( + self.cuda_graph_scope is not None + ), "cuda_graph_scope must be set when enabling offloading." + assert ( + "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope + ) or ( + CudaGraphScope.attn in self.cuda_graph_scope + and CudaGraphScope.moe_router in self.cuda_graph_scope + ), "attn and moe_router must be in cuda_graph_scope when enabling offloading." + assert ( + "attn_norm" not in self.offload_modules + ), "input of attn_norm is the start point of cuda graph, which can't be offloaded." + assert ( + "mlp_norm" not in self.offload_modules + ), "mlp_norm goes through the boundary of cuda graph, which can't be offloaded." 
if ( self.num_layers_in_first_pipeline_stage is not None diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 6941963f6b0..41f82f9d80c 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -27,6 +27,7 @@ deprecate_inference_params, get_pg_rank, is_te_min_version, + is_torch_min_version, log_single_rank, make_viewless_tensor, nvtx_range_pop, @@ -260,6 +261,9 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): output of the same size. """ + cuda_graph_stream = None + cuda_graph_event = None + def __init__( self, config: TransformerConfig, @@ -413,17 +417,8 @@ def __init__( if "mlp" in self.config.recompute_modules: if not self.is_moe_layer: self.recompute_mlp = True - self.offload_attn_norm = ( - self.config.fine_grained_activation_offloading - and "attn_norm" in self.config.offload_modules - and not isinstance(self.input_layernorm, IdentityOp) - ) - self.offload_mlp_norm = ( - self.config.fine_grained_activation_offloading - and "mlp_norm" in self.config.offload_modules - and not isinstance(self.pre_mlp_layernorm, IdentityOp) - ) + self._set_offload_modules() # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. # TORCH_MAJOR = int(torch.__version__.split('.')[0]) @@ -517,6 +512,15 @@ def _forward_attention( get_fine_grained_offloading_context, ) + if self.offload_module_in_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_backward_record, + ) + + hidden_states = fine_grained_offloading_backward_record( + hidden_states, TransformerLayer.cuda_graph_event + ) + inference_context = deprecate_inference_params(inference_context, inference_params) # Residual connection. 
@@ -598,7 +602,9 @@ def _forward_attention( return hidden_states, context - def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None): + def _forward_mlp( + self, hidden_states, inference_context=None, padding_mask=None, flush_delayed_groups=True + ): """ Perform a forward pass through the feed-forward layer. @@ -616,6 +622,7 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) """ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_group_commit, fine_grained_offloading_group_start, get_fine_grained_offloading_context, ) @@ -701,9 +708,9 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) ) nvtx_range_pop(suffix="mlp") - return self._forward_post_mlp(mlp_output_with_bias, residual) + return self._forward_post_mlp(mlp_output_with_bias, residual, flush_delayed_groups) - def _forward_post_mlp(self, mlp_output_with_bias, residual): + def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups=True): """ Perform operations after the MLP computation. 
@@ -742,6 +749,12 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) + if self.config.delay_offload_until_cuda_graph and flush_delayed_groups: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_group_flush_delayed_groups, + ) + + fine_grained_offloading_group_flush_delayed_groups() return output def sharded_state_dict( @@ -852,6 +865,12 @@ def _te_cuda_graph_capture(self, *args, **kwargs): cuda_graph_outputs = list(hidden_states) if context is not None: cuda_graph_outputs.append(context) + if self.offload_module_in_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_forward_record, + ) + + fine_grained_offloading_forward_record(TransformerLayer.cuda_graph_event) return tuple(cuda_graph_outputs) def _te_cuda_graph_replay(self, *args, **kwargs): @@ -877,6 +896,13 @@ def _te_cuda_graph_replay(self, *args, **kwargs): cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) + if self.config.delay_offload_until_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_group_flush_delayed_groups, + ) + + fine_grained_offloading_group_flush_delayed_groups() + if kwargs.get('context') is not None: context = cuda_graph_output.pop() @@ -947,7 +973,9 @@ def _te_cuda_graph_replay(self, *args, **kwargs): self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") - output = self._forward_post_mlp(mlp_output_with_bias, mlp_residual) + output = self._forward_post_mlp( + mlp_output_with_bias, mlp_residual, flush_delayed_groups=False + ) else: # If EP overlap is enabled, needs to return same outputs as submodule.attn if self.config.overlap_moe_expert_parallel_comm: @@ -964,7 +992,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): return mlp_residual, hidden_states, probs, 
shared_expert_output # CUDA Graph does not capture the MLP/MoE part at all. - output = self._forward_mlp(*cuda_graph_output) + output = self._forward_mlp(*cuda_graph_output, flush_delayed_groups=False) return output, context def _get_te_cuda_graph_replay_args(self, *args, **kwargs): @@ -1070,3 +1098,56 @@ def __call__(self, *args, **kwargs): 'inference_context' ].is_decode_only() return super().__call__(*args, **kwargs) + + def _set_offload_modules(self): + """Set the offload modules for the transformer layer.""" + if self.config.fine_grained_activation_offloading: + self.offload_attn_norm = "attn_norm" in self.config.offload_modules and not isinstance( + self.input_layernorm, IdentityOp + ) + self.offload_qkv_linear = "qkv_linear" in self.config.offload_modules + self.offload_core_attn = "core_attn" in self.config.offload_modules + self.offload_attn_proj = "attn_proj" in self.config.offload_modules + self.offload_mlp_norm = "mlp_norm" in self.config.offload_modules and not isinstance( + self.pre_mlp_layernorm, IdentityOp + ) + self.offload_expert_fc1 = "expert_fc1" in self.config.offload_modules + self.offload_moe_act = "moe_act" in self.config.offload_modules + self.offload_dense_mlp = ( + "dense_mlp" in self.config.offload_modules and not self.is_moe_layer + ) + else: + self.offload_attn_norm = False + self.offload_qkv_linear = False + self.offload_core_attn = False + self.offload_attn_proj = False + self.offload_mlp_norm = False + self.offload_expert_fc1 = False + self.offload_moe_act = False + self.offload_dense_mlp = False + # Set the offload module in cuda graph flag. 
+ self.offload_module_in_cuda_graph = False + if CudaGraphScope.attn in self.config.cuda_graph_scope: + if self.offload_core_attn or self.offload_attn_proj or self.offload_qkv_linear: + self.offload_module_in_cuda_graph = True + if not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope: + if self.offload_mlp_norm or self.offload_dense_mlp: + self.offload_module_in_cuda_graph = True + if self.offload_module_in_cuda_graph: + assert is_torch_min_version( + "2.9.0a0" + ), "Fine-grained activation offloading needs torch>=2.9.0 to support cuda graph." + assert ( + self.config.cuda_graph_warmup_steps > 0 + ), "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." + # Set the cuda graph stream and event for the transformer layer. + if TransformerLayer.cuda_graph_stream is None: + if self.offload_module_in_cuda_graph: + TransformerLayer.cuda_graph_stream = torch.cuda.Stream() + else: + TransformerLayer.cuda_graph_stream = torch.cuda.current_stream() + if TransformerLayer.cuda_graph_event is None: + if self.offload_module_in_cuda_graph: + TransformerLayer.cuda_graph_event = torch.cuda.Event(external=True) + else: + TransformerLayer.cuda_graph_event = torch.cuda.Event() diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 13f6e041a36..fc7b80f7b8d 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2438,6 +2438,8 @@ def _add_training_args(parser): help='The submodules to offload its input. 
Choices: "attn_norm", "qkv_linear", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') group.add_argument('--min-offloaded-tensor-size', type=int, default=10*1024*1024, help='The minimum size of the tensor to be offloaded.') + group.add_argument('--delay-offload-until-cuda-graph', action='store_true', + help='Delay the offload until the CUDA graph is executed for minimal CPU overhead.') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') group.add_argument('--batch-invariant-mode', action='store_true', From b8c0b790aaef9e3c8c3650fdbf446cc3daac74dd Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 12 Jan 2026 23:58:48 -0800 Subject: [PATCH 41/74] fix bug when working with a2a overlap and cuda graph Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/fine_grained_callables.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index dc3c15ab269..6ba9974c033 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -428,7 +428,9 @@ def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): attn_backward_dw_wrapper.set_graphed_backward_dw_callable( partial(layer.backward_dw_cudagraph, layer.current_microbatch) ) + node.chunk_state.flush_delayed_groups = False else: + node.chunk_state.flush_delayed_groups = True # wrapper function that keeps consistent api with cuda graph replay def forward_func( hidden_states: Tensor, @@ -560,6 +562,12 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) + if node.chunk_state.flush_delayed_groups: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_group_flush_delayed_groups, + ) + 
fine_grained_offloading_group_flush_delayed_groups() + # Need to record residual to comm stream, since it's created on comp stream node.layer_state.residual.record_stream(torch.cuda.current_stream()) From 994fc5a1db09948a0cf6f6c8adf44da0d150cd85 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 13 Jan 2026 00:57:31 -0800 Subject: [PATCH 42/74] support offloading less for large pp rank Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/gpt_model.py | 2 ++ .../fine_grained_activation_offload.py | 23 +++++++++++++++---- .../core/transformer/transformer_config.py | 3 +++ megatron/training/arguments.py | 2 ++ 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 9e70c677226..ea2f00ccf5e 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -432,9 +432,11 @@ def _preprocess( def preprocess_for_fine_grained_offloading(self): """Preprocess for fine-grained activation offloading.""" fine_grained_offloading_init_chunk_handler( + pp_rank=self.pg_collection.pp.rank(), vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage, min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, + delta_offload_bytes_across_pp_ranks=self.config.delta_offload_bytes_across_pp_ranks, ) if self.disable_param_offloading: for param in self.decoder.parameters(): diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index fa23dd2417f..6224adf1a2e 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -549,8 +549,15 @@ def post_warmup_callback(self): debug_rank(f"setting offload to false for group {name} at chunk index {chunk_idx}") else: break - debug_rank(f"offload margin {self._offload_margin}") assert self._offload_margin == 0, 
"Offload margin is not 0" + keep_on_gpu_bytes = self._pp_rank * self._delta_offload_bytes_across_pp_ranks + for chunk in self._cached_chunks_backward: + for group in chunk.offload_groups: + if group.offload and keep_on_gpu_bytes > 0: + debug_rank(f"group {group._name} offload {group.offload} \ + keep_on_gpu_bytes {keep_on_gpu_bytes}") + keep_on_gpu_bytes -= group.total_offload_bytes + group.offload = False # Dump the offload information total_tensor_count = {} total_offload_bytes = {} @@ -607,15 +614,19 @@ def front_backward_chunk(self, name=None): return None def init_model_chunk_offload_handler( - self, vp_size, vp_stage, min_offloaded_tensor_size=1024 * 1024 + self, pp_rank, vp_size, vp_stage, min_offloaded_tensor_size=1024 * 1024, + delta_offload_bytes_across_pp_ranks=0 ): """ Initialize a chunk offload handler for a model chunk (microbatch). Args: + pp_rank: Pipeline parallel rank vp_size: Virtual pipeline size vp_stage: Virtual pipeline stage index (None means stage 0) min_offloaded_tensor_size: Minimum tensor size (in elements) to offload + delta_offload_bytes_across_pp_ranks: + Difference of offload bytes across PP ranks to balance the offload load. 
""" if not self._is_warmup: return @@ -625,6 +636,9 @@ def init_model_chunk_offload_handler( self._vpp = vp_size self._stages = [[] for _ in range(vp_size)] + self._delta_offload_bytes_across_pp_ranks = delta_offload_bytes_across_pp_ranks + self._pp_rank = pp_rank + if vp_stage is None: cur_vpp_rank = 0 else: @@ -1169,10 +1183,11 @@ def get_fine_grained_offloading_context(flag): return PipelineOffloadManager.get_instance() if flag else nullcontext() -def fine_grained_offloading_init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size): +def fine_grained_offloading_init_chunk_handler(pp_rank, vp_size, vp_stage, \ + min_offloaded_tensor_size, delta_offload_bytes_across_pp_ranks): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" PipelineOffloadManager.get_instance().init_model_chunk_offload_handler( - vp_size, vp_stage, min_offloaded_tensor_size + pp_rank, vp_size, vp_stage, min_offloaded_tensor_size, delta_offload_bytes_across_pp_ranks ) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index e98c9a69993..9b26af1b75f 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -870,6 +870,9 @@ class TransformerConfig(ModelParallelConfig): delay_offload_until_cuda_graph: bool = False """If True, delay the offload until the CUDA graph is executed for minimal CPU overhead.""" + delta_offload_bytes_across_pp_ranks: int = 0 + """Difference of offload bytes across PP ranks to balance the offload load.""" + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. 
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index fc7b80f7b8d..936f7af4089 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2440,6 +2440,8 @@ def _add_training_args(parser): help='The minimum size of the tensor to be offloaded.') group.add_argument('--delay-offload-until-cuda-graph', action='store_true', help='Delay the offload until the CUDA graph is executed for minimal CPU overhead.') + group.add_argument('--delta-offload-bytes-across-pp-ranks', type=int, default=0, + help='Difference of offload bytes across PP ranks to balance the offload load.') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') group.add_argument('--batch-invariant-mode', action='store_true', From 2688c7e898c41f48ecd95cfe01c135ada38b3eb5 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 13 Jan 2026 17:58:39 -0800 Subject: [PATCH 43/74] fix doc Signed-off-by: Hongbin Liu --- .../api-guide/fine_grained_activation_offloading.md | 2 +- .../offloading_and_recomputing.png | Bin megatron/core/transformer/moe/README.md | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename docs/{source => }/images/fine_grained_activation_offloading/offloading_and_recomputing.png (100%) diff --git a/docs/api-guide/fine_grained_activation_offloading.md b/docs/api-guide/fine_grained_activation_offloading.md index 969098263fc..53211d1d06c 100644 --- a/docs/api-guide/fine_grained_activation_offloading.md +++ b/docs/api-guide/fine_grained_activation_offloading.md @@ -28,4 +28,4 @@ Currently, the supported offloading modules are `"attn_norm", "core_attn", "attn - For other modules, use offloading to reduce memory footprint; - Make sure the offloading/reloading could be overlapped with computing; -![Fine-grained Activation Offloading and Fine-grained 
Recomputation](../images/fine_grained_activation_offloading/offloading_and_recomputing.png) +![Fine-grained Activation Offloading and Fine-grained Recomputation](../../images/fine_grained_activation_offloading/offloading_and_recomputing.png) diff --git a/docs/source/images/fine_grained_activation_offloading/offloading_and_recomputing.png b/docs/images/fine_grained_activation_offloading/offloading_and_recomputing.png similarity index 100% rename from docs/source/images/fine_grained_activation_offloading/offloading_and_recomputing.png rename to docs/images/fine_grained_activation_offloading/offloading_and_recomputing.png diff --git a/megatron/core/transformer/moe/README.md b/megatron/core/transformer/moe/README.md index a44daea38e2..1e5b49fdcc4 100644 --- a/megatron/core/transformer/moe/README.md +++ b/megatron/core/transformer/moe/README.md @@ -222,7 +222,7 @@ Offload the input activation at the granularity of modules # Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act". 
--offload-modules expert_fc1 ``` -For more details, please refer to the ```docs/source/api-guide/fine_grained_activation_offloading.md``` +For more details, please refer to the ```docs/user-guide/features/fine_grained_activation_offloading.md``` ### MoE Related Arguments | Item | Description | From aab845559696675be41815a6421cd3e0612eff15 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 13 Jan 2026 18:47:12 -0800 Subject: [PATCH 44/74] code refactor Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 27 +++++++- megatron/core/models/gpt/gpt_model.py | 10 +-- .../fine_grained_activation_offload.py | 64 +++++++++++++------ megatron/core/pipeline_parallel/schedules.py | 8 +-- megatron/core/transformer/attention.py | 10 +-- megatron/core/transformer/moe/experts.py | 10 +-- .../transformer/multi_latent_attention.py | 16 ++--- .../core/transformer/transformer_layer.py | 16 +++-- ...test_fine_grained_activation_offloading.py | 18 +++--- 9 files changed, 115 insertions(+), 64 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index dc3c15ab269..4ddbab70e8b 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -10,10 +10,12 @@ from megatron.core import tensor_parallel from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, fine_grained_offloading_group_start, - get_fine_grained_offloading_context, ) from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless from megatron.core.transformer.enums import CudaGraphScope @@ -480,8 +482,27 @@ def forward_func( packed_seq_params=node.chunk_state.packed_seq_params, 
sequence_len_offset=node.chunk_state.sequence_len_offset, ) - if not isinstance(layer.mlp, MoELayer): - return hidden_states + return hidden_states + + def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): + """ + Run forward pass for computations between attention and dispatch: + pre mlp layernorm->router->dispatch preprocess + """ + if layer.offload_mlp_norm: + hidden_states = fine_grained_offloading_group_start(hidden_states, name="mlp_norm") + if layer.recompute_pre_mlp_layernorm: + layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() + with off_interface.get_context(layer.offload_mlp_norm): + pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( + layer.pre_mlp_layernorm, hidden_states + ) + else: + with off_interface.get_context(layer.offload_mlp_norm): + pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) + + probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output) + local_tokens, probs = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map) # Detach here for mlp_bda residual connection node.layer_state.residual = node.detach(hidden_states) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 9e70c677226..16462d6e426 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -19,7 +19,7 @@ from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_init_chunk_handler, + FineGrainedActivationOffloadingInterface as off_interface, ) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.quantization.utils import get_quant_config_or_none @@ -431,20 +431,20 @@ def _preprocess( def preprocess_for_fine_grained_offloading(self): """Preprocess for fine-grained 
activation offloading.""" - fine_grained_offloading_init_chunk_handler( + off_interface.init_chunk_handler( vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage, min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, ) if self.disable_param_offloading: for param in self.decoder.parameters(): - param.offloading_activation = False + off_interface.mark_not_offloadable(param) if self.mtp_process: for param in self.mtp.parameters(): - param.offloading_activation = False + off_interface.mark_not_offloadable(param) if self.post_process: for param in self.output_layer.parameters(): - param.offloading_activation = False + off_interface.mark_not_offloadable(param) self.disable_param_offloading = False def forward( diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index fa23dd2417f..4e6c8219ee2 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -667,6 +667,11 @@ def cur_backward_chunk(self): """Get the current backward pass chunk handler.""" return self._cur_backward_chunk + def mark_not_offloadable(self, tensor: torch.Tensor): + """Mark the current forward chunk as not offloadable.""" + if tensor is not None: + tensor.offloading_activation = False + def __enter__(self): """Enter context manager to enable activation offloading hooks.""" debug_rank("----__enter__") @@ -1164,23 +1169,6 @@ def fine_grained_offloading_group_start(tensor, name=None): return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) -def get_fine_grained_offloading_context(flag): - """Get the fine-grained offload context""" - return PipelineOffloadManager.get_instance() if flag else nullcontext() - - -def fine_grained_offloading_init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size): - """Initialize the chunk handler, called at the 
start of a microbatch forward pass.""" - PipelineOffloadManager.get_instance().init_model_chunk_offload_handler( - vp_size, vp_stage, min_offloaded_tensor_size - ) - - -def fine_grained_offloading_reset(): - """Reset the chunk handler, called at the start of a training iteration.""" - PipelineOffloadManager.get_instance().reset() - - def fine_grained_offloading_forward_record(event: torch.cuda.Event) -> None: """Record the forward event for cuda graph capture.""" d2h_stream = PipelineOffloadManager.get_instance().d2h_stream @@ -1214,6 +1202,42 @@ def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) -def fine_grained_offloading_reset_instance(): - """Reset the singleton instance of PipelineOffloadManager.""" - PipelineOffloadManager.reset_instance() +class FineGrainedActivationOffloadingInterface: + """Interface for fine-grained activation offloading.""" + + def __init__(self): + pass + + @staticmethod + def init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size): + """Initialize the chunk handler, called at the start of a microbatch forward pass.""" + PipelineOffloadManager.get_instance().init_model_chunk_offload_handler( + vp_size, vp_stage, min_offloaded_tensor_size + ) + + @staticmethod + def get_context(flag): + """Get the fine-grained offload context""" + return PipelineOffloadManager.get_instance() if flag else nullcontext() + + @staticmethod + def mark_not_offloadable(tensor: torch.Tensor): + """Mark the tensor as not offloadable.""" + PipelineOffloadManager.get_instance().mark_not_offloadable(tensor) + + @staticmethod + def forward_record(event: torch.cuda.Event) -> None: + """Record the forward event for cuda graph capture.""" + d2h_stream = PipelineOffloadManager.get_instance().d2h_stream + torch.cuda.current_stream().record_event(event) + torch.cuda.current_stream().wait_stream(d2h_stream) + + @staticmethod + def reset(): + """Reset the chunk 
handler.""" + PipelineOffloadManager.get_instance().reset() + + @staticmethod + def reset_instance(): + """Reset the singleton instance.""" + PipelineOffloadManager.reset_instance() diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 4fa48b05e11..dadbd199ab7 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -11,7 +11,7 @@ from megatron.core import parallel_state from megatron.core.enums import ModelType from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_reset, + FineGrainedActivationOffloadingInterface as off_interface, ) from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import ( @@ -680,7 +680,7 @@ def forward_backward_no_pipelining( ) if not forward_only and config.fine_grained_activation_offloading: - fine_grained_offloading_reset() + off_interface.reset() if config.timers is not None: config.timers('forward-backward').stop() @@ -2047,7 +2047,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): ) if not forward_only and config.fine_grained_activation_offloading: - fine_grained_offloading_reset() + off_interface.reset() # Restore config.grad_sync_func and config.param_sync_func. 
if forward_only: config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func @@ -2437,7 +2437,7 @@ def enable_grad_sync(): ) if not forward_only and config.fine_grained_activation_offloading: - fine_grained_offloading_reset() + off_interface.reset() if config.timers is not None: config.timers('forward-backward').stop() diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 3ca9c4b3531..613b583b716 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -24,10 +24,12 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, fine_grained_offloading_group_start, - get_fine_grained_offloading_context, ) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region @@ -832,7 +834,7 @@ def forward( if self.offload_qkv_linear: hidden_states = fine_grained_offloading_group_start(hidden_states, name="qkv_linear") - with get_fine_grained_offloading_context(self.offload_qkv_linear): + with off_interface.get_context(self.offload_qkv_linear): qkv_output = self.get_query_key_value_tensors( hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv ) @@ -994,7 +996,7 @@ def forward( query = fine_grained_offloading_group_start(query, name="core_attn") if inference_context is None or inference_context.is_static_batching(): # Static batching attention kernel. 
- with get_fine_grained_offloading_context(self.offload_core_attention): + with off_interface.get_context(self.offload_core_attention): core_attn_out = self.core_attention( query, key, @@ -1049,7 +1051,7 @@ def forward( nvtx_range_push(suffix="linear_proj") if self.offload_attn_proj: core_attn_out = fine_grained_offloading_group_start(core_attn_out, name="attn_proj") - with get_fine_grained_offloading_context(self.offload_attn_proj): + with off_interface.get_context(self.offload_attn_proj): output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: output = fine_grained_offloading_group_commit( diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index fa2653a030a..abbe459cb7d 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -25,10 +25,12 @@ from megatron.core.fusions.fused_bias_swiglu import weighted_bias_swiglu_impl from megatron.core.fusions.fused_weighted_squared_relu import weighted_squared_relu_impl from megatron.core.jit import jit_fuser +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, fine_grained_offloading_group_start, - get_fine_grained_offloading_context, ) from megatron.core.tensor_parallel.layers import ( _initialize_affine_weight_cpu, @@ -735,7 +737,7 @@ def forward( permuted_local_hidden_states = fine_grained_offloading_group_start( permuted_local_hidden_states, name="expert_fc1" ) - with get_fine_grained_offloading_context(self.offload_expert_fc1): + with off_interface.get_context(self.offload_expert_fc1): fc1_output, bias_parallel = self.linear_fc1( permuted_local_hidden_states, tokens_per_expert ) @@ -809,12 +811,12 @@ def glu(x): if self.activation_recompute: self.activation_checkpoint = 
tensor_parallel.CheckpointWithoutOutput() - with get_fine_grained_offloading_context(self.offload_moe_act): + with off_interface.get_context(self.offload_moe_act): bias_act_output = self.activation_checkpoint.checkpoint( bias_act_func, fc1_output, bias_parallel, permuted_probs ) else: - with get_fine_grained_offloading_context(self.offload_moe_act): + with off_interface.get_context(self.offload_moe_act): bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) output, output_bias = self.linear_fc2(bias_act_output, tokens_per_expert) diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 03c31d70686..349269640e8 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -22,10 +22,11 @@ _yarn_get_mscale, apply_rotary_pos_emb, ) +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, - fine_grained_offloading_group_start, - get_fine_grained_offloading_context, ) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.layers import ColumnParallelLinear @@ -244,9 +245,7 @@ def forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. 
# query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128] - if self.offload_qkv_linear: - hidden_states = fine_grained_offloading_group_start(hidden_states, name="qkv_linear") - with get_fine_grained_offloading_context(self.offload_qkv_linear): + with off_interface.get_context(self.offload_qkv_linear): if self.config.experimental_attention_variant is None: query, key, value = self.get_query_key_value_tensors( hidden_states, @@ -299,11 +298,8 @@ def forward( query, key, value, attention_mask, packed_seq_params=packed_seq_params ) else: - if self.offload_core_attention and self.training: - query = fine_grained_offloading_group_start(query, name="core_attn") - if inference_context is None or inference_context.is_static_batching(): - with get_fine_grained_offloading_context(self.offload_core_attention): + with off_interface.get_context(self.offload_core_attention): if self.config.experimental_attention_variant is None: core_attn_out = self.core_attention( query, @@ -383,7 +379,7 @@ def forward( # ================= if self.offload_attn_proj: core_attn_out = fine_grained_offloading_group_start(core_attn_out, name="attn_proj") - with get_fine_grained_offloading_context(self.offload_attn_proj): + with off_interface.get_context(self.offload_attn_proj): output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: output = fine_grained_offloading_group_commit( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 6941963f6b0..ff92d3e0753 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -511,10 +511,12 @@ def _forward_attention( context (Tensor): Updated context tensor if cross-attention is used, otherwise None. 
""" + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, fine_grained_offloading_group_start, - get_fine_grained_offloading_context, ) inference_context = deprecate_inference_params(inference_context, inference_params) @@ -527,12 +529,12 @@ def _forward_attention( # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with get_fine_grained_offloading_context(self.offload_attn_norm): + with off_interface.get_context(self.offload_attn_norm): input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( self.input_layernorm, hidden_states ) else: - with get_fine_grained_offloading_context(self.offload_attn_norm): + with off_interface.get_context(self.offload_attn_norm): input_layernorm_output = self.input_layernorm(hidden_states) # Self attention. @@ -615,9 +617,11 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) output (Tensor): Transformed hidden states of shape [s, b, h]. """ + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_start, - get_fine_grained_offloading_context, ) # Residual connection. @@ -628,12 +632,12 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) # Optional Layer norm post the cross-attention. 
if self.recompute_pre_mlp_layernorm: self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with get_fine_grained_offloading_context(self.offload_mlp_norm): + with off_interface.get_context(self.offload_mlp_norm): pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( self.pre_mlp_layernorm, hidden_states ) else: - with get_fine_grained_offloading_context(self.offload_mlp_norm): + with off_interface.get_context(self.offload_mlp_norm): pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) nvtx_range_push(suffix="mlp") diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 88d76bf5d80..558c6934a0c 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -10,6 +10,9 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig @@ -107,10 +110,9 @@ def _run_one_iter_and_capture( - selected grads (CPU float32) - peak_memory_allocated (bytes) during the iteration """ - from megatron.core.pipeline_parallel import fine_grained_activation_offload as off if enable_offload_reset: - off.fine_grained_offloading_reset() + off_interface.reset() # for p in model.parameters(): # if p.grad is not None: @@ -179,7 +181,7 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( from megatron.core.pipeline_parallel import 
fine_grained_activation_offload as off - off.fine_grained_offloading_reset_instance() + off_interface.reset_instance() try: # 1) Baseline run (no offloading) @@ -244,7 +246,7 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( enable_offload_reset=True, ) # Reset once more to trigger post_warmup_callback and apply steady-state offload decisions. - off.fine_grained_offloading_reset() + off_interface.reset() from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( PipelineOffloadManager, @@ -453,7 +455,7 @@ def _run_schedule_1f1b_two_microbatches( This is the execution path that exercises EP A2A overlap scheduling. """ if enable_offload_reset: - off.fine_grained_offloading_reset() + off_interface.reset() data0 = _make_schedule_inputs() data1 = _make_schedule_inputs() @@ -486,7 +488,7 @@ def _run_schedule_1f1b_two_microbatches( ) set_streams() - off.fine_grained_offloading_reset_instance() + off_interface.reset_instance() try: with deterministic_mode(): @@ -513,9 +515,9 @@ def _run_schedule_1f1b_two_microbatches( _restore_params(off_model, base_params) off_model.train() # Warmup once to populate cached chunks, then reset to apply steady-state offload decisions. 
- off.fine_grained_offloading_reset() + off_interface.reset() _run_schedule_1f1b_two_microbatches(off_model, enable_offload_reset=False) - off.fine_grained_offloading_reset() + off_interface.reset() from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( PipelineOffloadManager, ) From 0a92566a8a4ee0ef1e60a70eea471420d81d0fdb Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 14 Jan 2026 01:48:42 -0800 Subject: [PATCH 45/74] remove group_start() calls Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 18 ++++++++++++++++-- megatron/core/transformer/attention.py | 14 +++++--------- megatron/core/transformer/moe/experts.py | 14 ++++---------- .../transformer/multi_latent_attention.py | 10 +++++----- .../core/transformer/transformer_layer.py | 19 ++++++++----------- 5 files changed, 38 insertions(+), 37 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 4e6c8219ee2..47bfc0518c4 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1205,8 +1205,22 @@ def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> class FineGrainedActivationOffloadingInterface: """Interface for fine-grained activation offloading.""" - def __init__(self): - pass + def __init__(self, offload: bool, tensor: torch.Tensor, name: str): + self.offload = offload + self.tensor = tensor + self.name = name + + def __enter__(self): + """Enter context manager to enable activation offloading hooks.""" + if self.offload: + self.tensor = fine_grained_offloading_group_start(self.tensor, self.name) + PipelineOffloadManager.get_instance().__enter__() + return self.tensor + + def __exit__(self, *args: Any): + """Exit context manager to disable activation offloading hooks.""" + if self.offload: + 
PipelineOffloadManager.get_instance().__exit__() @staticmethod def init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size): diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 613b583b716..361c5d64590 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -832,9 +832,8 @@ def forward( if output_gate: assert split_qkv, "output_gate is not supported for unsplit mixed_qkv tensor." - if self.offload_qkv_linear: - hidden_states = fine_grained_offloading_group_start(hidden_states, name="qkv_linear") - with off_interface.get_context(self.offload_qkv_linear): + with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") \ + as hidden_states: qkv_output = self.get_query_key_value_tensors( hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv ) @@ -992,11 +991,10 @@ def forward( packed_seq_params=packed_seq_params, ) else: - if self.offload_core_attention and self.training: - query = fine_grained_offloading_group_start(query, name="core_attn") if inference_context is None or inference_context.is_static_batching(): # Static batching attention kernel. 
- with off_interface.get_context(self.offload_core_attention): + with off_interface(self.offload_core_attention and self.training, query, \ + "core_attn") as query: core_attn_out = self.core_attention( query, key, @@ -1049,9 +1047,7 @@ def forward( # ================= nvtx_range_push(suffix="linear_proj") - if self.offload_attn_proj: - core_attn_out = fine_grained_offloading_group_start(core_attn_out, name="attn_proj") - with off_interface.get_context(self.offload_attn_proj): + with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: output = fine_grained_offloading_group_commit( diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index abbe459cb7d..bf9ebdb0ab3 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -733,11 +733,8 @@ def forward( # Probs already applied, so reset to 1. permuted_probs = torch.ones_like(permuted_probs) - if self.offload_expert_fc1: - permuted_local_hidden_states = fine_grained_offloading_group_start( - permuted_local_hidden_states, name="expert_fc1" - ) - with off_interface.get_context(self.offload_expert_fc1): + with off_interface(self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1") \ + as permuted_local_hidden_states: fc1_output, bias_parallel = self.linear_fc1( permuted_local_hidden_states, tokens_per_expert ) @@ -806,17 +803,14 @@ def glu(x): intermediate_parallel = intermediate_parallel.to(original_dtype) return intermediate_parallel - if self.offload_moe_act: - fc1_output = fine_grained_offloading_group_start(fc1_output, name="moe_act") - if self.activation_recompute: self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface.get_context(self.offload_moe_act): + with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: bias_act_output = 
self.activation_checkpoint.checkpoint( bias_act_func, fc1_output, bias_parallel, permuted_probs ) else: - with off_interface.get_context(self.offload_moe_act): + with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) output, output_bias = self.linear_fc2(bias_act_output, tokens_per_expert) diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 349269640e8..cc82c66dbb1 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -245,7 +245,8 @@ def forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. # query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128] - with off_interface.get_context(self.offload_qkv_linear): + with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") \ + as hidden_states: if self.config.experimental_attention_variant is None: query, key, value = self.get_query_key_value_tensors( hidden_states, @@ -299,7 +300,8 @@ def forward( ) else: if inference_context is None or inference_context.is_static_batching(): - with off_interface.get_context(self.offload_core_attention): + with off_interface(self.offload_core_attention and self.training, query, \ + "core_attn") as query: if self.config.experimental_attention_variant is None: core_attn_out = self.core_attention( query, @@ -377,9 +379,7 @@ def forward( # ================= # Output. 
[sq, b, h] # ================= - if self.offload_attn_proj: - core_attn_out = fine_grained_offloading_group_start(core_attn_out, name="attn_proj") - with off_interface.get_context(self.offload_attn_proj): + with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: output = fine_grained_offloading_group_commit( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index ff92d3e0753..c7c26de922b 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -524,17 +524,17 @@ def _forward_attention( # Residual connection. residual = hidden_states - if self.offload_attn_norm: - hidden_states = fine_grained_offloading_group_start(hidden_states, name="attn_norm") # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface.get_context(self.offload_attn_norm): + with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") \ + as hidden_states: input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( self.input_layernorm, hidden_states ) else: - with off_interface.get_context(self.offload_attn_norm): + with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") \ + as hidden_states: input_layernorm_output = self.input_layernorm(hidden_states) # Self attention. @@ -620,24 +620,21 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_start, - ) # Residual connection. 
residual = hidden_states - if self.offload_mlp_norm: - hidden_states = fine_grained_offloading_group_start(hidden_states, name="mlp_norm") # Optional Layer norm post the cross-attention. if self.recompute_pre_mlp_layernorm: self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface.get_context(self.offload_mlp_norm): + with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") \ + as hidden_states: pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( self.pre_mlp_layernorm, hidden_states ) else: - with off_interface.get_context(self.offload_mlp_norm): + with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") \ + as hidden_states: pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) nvtx_range_push(suffix="mlp") From 0df5134ae72e7ddafddb7b67bcf66c677aed0b35 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 14 Jan 2026 02:05:47 -0800 Subject: [PATCH 46/74] add comments Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/fine_grained_callables.py | 9 ++++----- megatron/core/pipeline_parallel/utils.py | 11 ++++++++++- megatron/core/transformer/moe/experts.py | 3 +++ megatron/core/transformer/transformer_layer.py | 5 ++++- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 4ddbab70e8b..b395fef6df6 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -15,7 +15,6 @@ ) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, - fine_grained_offloading_group_start, ) from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless from megatron.core.transformer.enums import CudaGraphScope @@ -489,16 +488,14 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor) Run forward pass for computations 
between attention and dispatch: pre mlp layernorm->router->dispatch preprocess """ - if layer.offload_mlp_norm: - hidden_states = fine_grained_offloading_group_start(hidden_states, name="mlp_norm") if layer.recompute_pre_mlp_layernorm: layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface.get_context(layer.offload_mlp_norm): + with off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( layer.pre_mlp_layernorm, hidden_states ) else: - with off_interface.get_context(layer.offload_mlp_norm): + with off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output) @@ -573,6 +570,8 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): hidden_states = layer.mlp_bda(layer.training, layer.config.bias_dropout_fusion)( mlp_output_with_bias, residual, layer.hidden_dropout ) + # Delay the offload of the mlp norm until after the mlp_bda has been computed + # because the residual is needed in the mlp_bda. if layer.offload_mlp_norm: hidden_states = fine_grained_offloading_group_commit( hidden_states, name="mlp_norm", forced_released_tensors=[residual] diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py index 7c399ef951b..bda6334fc4b 100644 --- a/megatron/core/pipeline_parallel/utils.py +++ b/megatron/core/pipeline_parallel/utils.py @@ -1,5 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+import logging from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Callable, Optional @@ -7,7 +8,9 @@ import torch from torch.autograd import Variable -from megatron.core.utils import get_pg_rank, get_pg_size, make_viewless_tensor +from megatron.core.utils import get_pg_rank, get_pg_size, log_single_rank, make_viewless_tensor + +logger = logging.getLogger(__name__) def is_pp_first_stage(pp_group: torch.distributed.ProcessGroup): @@ -106,6 +109,12 @@ def set_ideal_affinity_for_current_gpu(): handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes))) pynvml.nvmlDeviceSetCpuAffinity(handle) + log_single_rank( + logger, + logging.WARNING, + f"Set CPU affinity for all GPUs for optimal host-device transfer performance", + ) + @contextmanager def stream_acquire_context(stream, event): diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index bf9ebdb0ab3..c4fb3a96bef 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -816,6 +816,9 @@ def glu(x): output, output_bias = self.linear_fc2(bias_act_output, tokens_per_expert) if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) + + # Delay the offload of the moe act until after the linear_fc2 has been computed + # to make sure the fc1_output is reloaded to GPU before recomputing moe_act. 
if self.offload_moe_act: output = fine_grained_offloading_group_commit( output, name="moe_act", forced_released_tensors=[fc1_output] diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index c7c26de922b..1b8ce94682a 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -516,7 +516,6 @@ def _forward_attention( ) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, - fine_grained_offloading_group_start, ) inference_context = deprecate_inference_params(inference_context, inference_params) @@ -569,6 +568,8 @@ def _forward_attention( ) nvtx_range_pop(suffix="self_attn_bda") + # Delay the offload of the attention norm until after the self_attn_bda has been computed + # because the residual is needed in the self_attn_bda. if self.offload_attn_norm: hidden_states = fine_grained_offloading_group_commit( hidden_states, name="attn_norm", forced_released_tensors=[residual] @@ -728,6 +729,8 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") + # Delay the offload of the mlp norm until after the mlp_bda has been computed + # because the residual is needed in the mlp_bda. 
if self.offload_mlp_norm: hidden_states = fine_grained_offloading_group_commit( hidden_states, name="mlp_norm", forced_released_tensors=[residual] From 62d36f2b19dc23b0307ada9326344e675b20f17b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 14 Jan 2026 02:51:24 -0800 Subject: [PATCH 47/74] fix min_offload_size and update golden values Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 34 +- megatron/core/transformer/moe/experts.py | 1 - megatron/training/arguments.py | 2 +- .../golden_values_dev_dgx_h100.json | 594 +++++++++--------- .../model_config.yaml | 2 +- .../golden_values_dev_dgx_h100.json | 498 +++++++-------- .../model_config.yaml | 7 +- .../unit_tests/models/test_mamba_moe_model.py | 552 ++++++++++++++++ 8 files changed, 1106 insertions(+), 584 deletions(-) create mode 100644 tests/unit_tests/models/test_mamba_moe_model.py diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index b395fef6df6..6dbd65f7d99 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -451,18 +451,14 @@ def forward_func( ) if not isinstance(layer.mlp, MoELayer): return hidden_states, None, None, None - if layer.offload_mlp_norm: - hidden_states = fine_grained_offloading_group_start( - hidden_states, name="mlp_norm" - ) if layer.recompute_pre_mlp_layernorm: layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with get_fine_grained_offloading_context(layer.offload_mlp_norm): + with off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( layer.pre_mlp_layernorm, hidden_states ) else: - with get_fine_grained_offloading_context(layer.offload_mlp_norm): + with off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: pre_mlp_layernorm_output = 
layer.pre_mlp_layernorm(hidden_states) shared_expert_output = layer.mlp.shared_experts_compute(pre_mlp_layernorm_output) @@ -483,32 +479,6 @@ def forward_func( ) return hidden_states - def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): - """ - Run forward pass for computations between attention and dispatch: - pre mlp layernorm->router->dispatch preprocess - """ - if layer.recompute_pre_mlp_layernorm: - layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: - pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( - layer.pre_mlp_layernorm, hidden_states - ) - else: - with off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: - pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) - - probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output) - local_tokens, probs = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map) - - # Detach here for mlp_bda residual connection - node.layer_state.residual = node.detach(hidden_states) - if layer.mlp.use_shared_expert and not layer.mlp.shared_expert_overlap: - # Detach here for shared expert connection in moe_combine - node.layer_state.shared_expert_output = node.detach(shared_expert_output) - - return local_tokens, probs - def submodule_dispatch_forward( node: ScheduleNode, local_tokens: torch.Tensor, probs: torch.Tensor ): diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index c4fb3a96bef..eca82b04570 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -30,7 +30,6 @@ ) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, - fine_grained_offloading_group_start, ) from megatron.core.tensor_parallel.layers import ( _initialize_affine_weight_cpu, diff --git 
a/megatron/training/arguments.py b/megatron/training/arguments.py index 13f6e041a36..ccf1fcd8534 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2436,7 +2436,7 @@ def _add_training_args(parser): help='Enable fine-grained activation offloading.') group.add_argument('--offload-modules', nargs='*', type=str, default=[], help='The submodules to offload its input. Choices: "attn_norm", "qkv_linear", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') - group.add_argument('--min-offloaded-tensor-size', type=int, default=10*1024*1024, + group.add_argument('--min-offloaded-tensor-size', type=int, default=1024*1024, help='The minimum size of the tensor to be offloaded.') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json index 51e9d7154c9..8fbe219530d 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json @@ -6,54 +6,54 @@ "values": { "1": 11.06693, "2": 11.0602, - "3": 10.21167, - "4": 9.95277, - "5": 10.12388, - "6": 8.82369, - "7": 9.52785, - "8": 8.44289, - "9": 7.85041, - "10": 7.07093, - "11": 9.28562, - "12": 9.13324, - "13": 7.86224, - "14": 8.19705, - "15": 8.22932, - "16": 8.17783, - "17": 8.2161, - "18": 7.50358, - "19": 8.08893, - "20": 7.64905, - "21": 7.95183, - "22": 7.29849, - "23": 7.93348, - "24": 7.43565, - "25": 8.2385, - "26": 7.75634, - "27": 7.70075, - "28": 7.66089, - "29": 7.75606, - "30": 7.56072, - "31": 7.81859, - "32": 6.46861, - "33": 7.20532, - "34": 7.77706, - "35": 
7.73113, - "36": 6.72448, - "37": 8.09344, - "38": 7.62008, - "39": 7.96872, - "40": 7.4992, - "41": 7.49916, - "42": 6.11993, - "43": 7.59389, - "44": 7.91482, - "45": 6.83633, - "46": 7.41335, - "47": 7.78887, - "48": 7.87666, - "49": 7.58746, - "50": 6.84352 + "3": 10.16141, + "4": 10.11145, + "5": 10.47957, + "6": 10.21751, + "7": 10.56153, + "8": 12.79501, + "9": 12.96949, + "10": 13.32223, + "11": 11.63359, + "12": 11.4938, + "13": 12.46292, + "14": 12.13415, + "15": 11.90295, + "16": 12.01307, + "17": 12.17443, + "18": 12.64978, + "19": 11.81295, + "20": 12.18673, + "21": 11.24306, + "22": 11.54156, + "23": 10.98412, + "24": 11.01925, + "25": 10.73001, + "26": 10.72806, + "27": 10.79039, + "28": 10.714, + "29": 10.73974, + "30": 10.75246, + "31": 10.68874, + "32": 10.65791, + "33": 10.81137, + "34": 10.79058, + "35": 10.75368, + "36": 10.64393, + "37": 10.87492, + "38": 10.90591, + "39": 10.78825, + "40": 10.75548, + "41": 10.8955, + "42": 10.70411, + "43": 10.66907, + "44": 10.72512, + "45": 10.54927, + "46": 10.46973, + "47": 10.66311, + "48": 10.62453, + "49": 10.61656, + "50": 10.21176 } }, "num-zeros": { @@ -61,56 +61,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 47165160.0, - "2": 46897928.0, - "3": 52684380.0, - "4": 297108064.0, - "5": 556667648.0, - "6": 661861120.0, - "7": 1027446592.0, - "8": 742822528.0, - "9": 846651648.0, - "10": 693167680.0, - "11": 826875520.0, - "12": 814304768.0, - "13": 642608768.0, - "14": 606554752.0, - "15": 728814528.0, - "16": 845696384.0, - "17": 667529728.0, - "18": 673504384.0, - "19": 889544960.0, - "20": 890696768.0, - "21": 676302464.0, - "22": 688965120.0, - "23": 789972480.0, - "24": 761249536.0, - "25": 648185280.0, - "26": 789507392.0, - "27": 641355648.0, - "28": 805511168.0, - "29": 773780224.0, - "30": 811888960.0, - "31": 688167744.0, - "32": 834871424.0, - "33": 792944256.0, - "34": 777109568.0, - "35": 763515136.0, - "36": 733607744.0, - "37": 743626240.0, - "38": 746577024.0, - "39": 
732972864.0, - "40": 735645696.0, - "41": 556711680.0, - "42": 680528384.0, - "43": 669752960.0, - "44": 667702912.0, - "45": 635197248.0, - "46": 629093120.0, - "47": 626713344.0, - "48": 600843456.0, - "49": 581506752.0, - "50": 572705728.0 + "1": 47165216.0, + "2": 46897552.0, + "3": 52682736.0, + "4": 70585808.0, + "5": 1850183680.0, + "6": 171098656.0, + "7": 436105120.0, + "8": 1850183680.0, + "9": 1850183680.0, + "10": 1850183680.0, + "11": 1850183680.0, + "12": 1850183680.0, + "13": 1850183680.0, + "14": 1850183680.0, + "15": 555857088.0, + "16": 1850183680.0, + "17": 1850183680.0, + "18": 1850183680.0, + "19": 886404992.0, + "20": 654826944.0, + "21": 603993664.0, + "22": 726709632.0, + "23": 566656896.0, + "24": 1850183680.0, + "25": 799245696.0, + "26": 978252032.0, + "27": 1850183680.0, + "28": 906183104.0, + "29": 1850183680.0, + "30": 1850183680.0, + "31": 810874112.0, + "32": 1850183680.0, + "33": 1850183680.0, + "34": 553779584.0, + "35": 565382400.0, + "36": 585787712.0, + "37": 627284160.0, + "38": 331368192.0, + "39": 638619264.0, + "40": 1850183680.0, + "41": 1850183680.0, + "42": 1850183680.0, + "43": 1850183680.0, + "44": 1850183680.0, + "45": 1850183680.0, + "46": 1850183680.0, + "47": 434842944.0, + "48": 1850183680.0, + "49": 575219328.0, + "50": 1850183680.0 } }, "mem-allocated-bytes": { @@ -118,56 +118,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 5275215360.0, - "2": 5275420160.0, - "3": 5275622912.0, - "4": 5275217408.0, - "5": 5275420160.0, - "6": 5275622912.0, - "7": 5275825664.0, - "8": 5276028416.0, - "9": 5276231168.0, - "10": 5276433920.0, - "11": 5276636672.0, - "12": 5276839424.0, - "13": 5277042176.0, - "14": 5277244928.0, - "15": 5277447680.0, - "16": 5277650432.0, - "17": 5277853184.0, - "18": 5278055936.0, - "19": 5278258688.0, - "20": 5278461440.0, - "21": 5278664192.0, - "22": 5278866944.0, - "23": 5279069696.0, - "24": 5279272448.0, - "25": 5279475200.0, - "26": 5279677952.0, - "27": 5279880704.0, - "28": 
5280083456.0, - "29": 5280286208.0, - "30": 5280488960.0, - "31": 5280691712.0, - "32": 5280894464.0, - "33": 5281097216.0, - "34": 5281299968.0, - "35": 5281502720.0, - "36": 5281705472.0, - "37": 5281908224.0, - "38": 5282110976.0, - "39": 5282313728.0, - "40": 5282516480.0, - "41": 5282719232.0, - "42": 5282921984.0, - "43": 5283124736.0, - "44": 5283327488.0, - "45": 5283530240.0, - "46": 5283732992.0, - "47": 5283935744.0, - "48": 5284138496.0, - "49": 5284341248.0, - "50": 5284544000.0 + "1": 5283616256.0, + "2": 5288015360.0, + "3": 5288218112.0, + "4": 5288420864.0, + "5": 5288623616.0, + "6": 5287812608.0, + "7": 5288015360.0, + "8": 5288218112.0, + "9": 5287711232.0, + "10": 5287913984.0, + "11": 5288116736.0, + "12": 5288319488.0, + "13": 5288522240.0, + "14": 5288724992.0, + "15": 5288927744.0, + "16": 5289130496.0, + "17": 5289333248.0, + "18": 5289536000.0, + "19": 5289738752.0, + "20": 5289941504.0, + "21": 5290144256.0, + "22": 5290347008.0, + "23": 5290549760.0, + "24": 5290752512.0, + "25": 5290955264.0, + "26": 5291158016.0, + "27": 5291360768.0, + "28": 5291563520.0, + "29": 5291766272.0, + "30": 5291969024.0, + "31": 5292171776.0, + "32": 5292374528.0, + "33": 5292577280.0, + "34": 5292780032.0, + "35": 5292982784.0, + "36": 5293185536.0, + "37": 5293388288.0, + "38": 5293591040.0, + "39": 5293793792.0, + "40": 5293996544.0, + "41": 5294199296.0, + "42": 5294402048.0, + "43": 5294604800.0, + "44": 5294807552.0, + "45": 5295010304.0, + "46": 5295213056.0, + "47": 5295415808.0, + "48": 5295618560.0, + "49": 5295821312.0, + "50": 5296024064.0 } }, "mem-max-allocated-bytes": { @@ -175,56 +175,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 6208857600.0, - "2": 8233667072.0, - "3": 8233667072.0, - "4": 8233667072.0, - "5": 8233667072.0, - "6": 8233667072.0, - "7": 8233667072.0, - "8": 8233667072.0, - "9": 8233667072.0, - "10": 8233667072.0, - "11": 8262715904.0, - "12": 8262715904.0, - "13": 8262715904.0, - "14": 8262715904.0, - "15": 
8262715904.0, - "16": 8268117504.0, - "17": 8288236032.0, - "18": 8288236032.0, - "19": 8288236032.0, - "20": 8288236032.0, - "21": 8288236032.0, - "22": 8299924992.0, - "23": 8302176768.0, - "24": 8302176768.0, - "25": 8302176768.0, - "26": 8302176768.0, - "27": 8302176768.0, - "28": 8302176768.0, - "29": 8302176768.0, - "30": 8302176768.0, - "31": 8302176768.0, - "32": 8302176768.0, - "33": 8302176768.0, - "34": 8302176768.0, - "35": 8302176768.0, - "36": 8302176768.0, - "37": 8302176768.0, - "38": 8313753088.0, - "39": 8313753088.0, - "40": 8313753088.0, - "41": 8313753088.0, - "42": 8313753088.0, - "43": 8313753088.0, - "44": 8313753088.0, - "45": 8313753088.0, - "46": 8313753088.0, - "47": 8313753088.0, - "48": 8313753088.0, - "49": 8313753088.0, - "50": 8313753088.0 + "1": 5283618816.0, + "2": 8185453056.0, + "3": 8185453056.0, + "4": 8185453056.0, + "5": 8195318272.0, + "6": 8195318272.0, + "7": 8195318272.0, + "8": 8195318272.0, + "9": 8195318272.0, + "10": 8195318272.0, + "11": 8195318272.0, + "12": 8195318272.0, + "13": 8195318272.0, + "14": 8195318272.0, + "15": 8195318272.0, + "16": 8199233024.0, + "17": 8199233024.0, + "18": 8199233024.0, + "19": 8199233024.0, + "20": 8199233024.0, + "21": 8238446080.0, + "22": 8238446080.0, + "23": 8238446080.0, + "24": 8238446080.0, + "25": 8247293440.0, + "26": 8247293440.0, + "27": 8247293440.0, + "28": 8250185216.0, + "29": 8255527424.0, + "30": 8255527424.0, + "31": 8255527424.0, + "32": 8255527424.0, + "33": 8255527424.0, + "34": 8255527424.0, + "35": 8255527424.0, + "36": 8255527424.0, + "37": 8255527424.0, + "38": 8255527424.0, + "39": 8255527424.0, + "40": 8255527424.0, + "41": 8255527424.0, + "42": 8255527424.0, + "43": 8255527424.0, + "44": 8255527424.0, + "45": 8255527424.0, + "46": 8255527424.0, + "47": 8255527424.0, + "48": 8255527424.0, + "49": 8255527424.0, + "50": 8255527424.0 } }, "mtp_1 loss": { @@ -234,54 +234,54 @@ "values": { "1": 11.07401, "2": 11.0927, - "3": 10.8262, - "4": 10.27574, - "5": 
10.45324, - "6": 8.32758, - "7": 9.82629, - "8": 8.01538, - "9": 7.47611, - "10": 6.75851, - "11": 8.92961, - "12": 8.98772, - "13": 7.80203, - "14": 8.02221, - "15": 8.11372, - "16": 8.14498, - "17": 8.13435, - "18": 7.45035, - "19": 8.03784, - "20": 7.54246, - "21": 7.90269, - "22": 7.28093, - "23": 7.88727, - "24": 7.37587, - "25": 8.17289, - "26": 7.70083, - "27": 7.62668, - "28": 7.61747, - "29": 7.69888, - "30": 7.48586, - "31": 7.74301, - "32": 6.37542, - "33": 7.13919, - "34": 7.7198, - "35": 7.63387, - "36": 6.6127, - "37": 8.03449, - "38": 7.58334, - "39": 7.89887, - "40": 7.41168, - "41": 7.42316, - "42": 6.01689, - "43": 7.48867, - "44": 7.86976, - "45": 6.75113, - "46": 7.3054, - "47": 7.73281, - "48": 7.79017, - "49": 7.48985, - "50": 6.75753 + "3": 10.83159, + "4": 10.61397, + "5": 10.85768, + "6": 9.79263, + "7": 10.90607, + "8": 10.19798, + "9": 9.82717, + "10": 9.23805, + "11": 11.0712, + "12": 11.11709, + "13": 10.03407, + "14": 10.27606, + "15": 10.73067, + "16": 10.91485, + "17": 10.76886, + "18": 10.49659, + "19": 10.96955, + "20": 10.45905, + "21": 10.91629, + "22": 10.05081, + "23": 10.44411, + "24": 9.74826, + "25": 10.81497, + "26": 10.38519, + "27": 10.31999, + "28": 10.27887, + "29": 10.40945, + "30": 10.20684, + "31": 10.54594, + "32": 8.85942, + "33": 9.75619, + "34": 10.56214, + "35": 10.59167, + "36": 9.30537, + "37": 10.59407, + "38": 10.2994, + "39": 10.69954, + "40": 10.37003, + "41": 10.248, + "42": 8.56376, + "43": 10.49224, + "44": 10.57211, + "45": 9.36238, + "46": 10.2179, + "47": 10.63449, + "48": 10.56697, + "49": 10.44093, + "50": 9.49252 } }, "iteration-time": { @@ -289,56 +289,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 64.76466, - "2": 2.42359, - "3": 2.56054, - "4": 2.61199, - "5": 2.3272, - "6": 2.19806, - "7": 2.16133, - "8": 1.97339, - "9": 2.14238, - "10": 2.05512, - "11": 2.00856, - "12": 1.96198, - "13": 2.08656, - "14": 1.96948, - "15": 1.96059, - "16": 1.97248, - "17": 1.97639, - "18": 2.01386, 
- "19": 1.9606, - "20": 1.94716, - "21": 2.00286, - "22": 1.965, - "23": 2.03401, - "24": 2.00528, - "25": 2.03321, - "26": 1.95999, - "27": 1.96395, - "28": 1.98191, - "29": 1.99346, - "30": 1.97579, - "31": 1.95097, - "32": 1.95726, - "33": 1.9399, - "34": 1.99177, - "35": 1.91153, - "36": 1.97534, - "37": 1.95691, - "38": 1.96206, - "39": 1.9414, - "40": 1.96027, - "41": 1.97807, - "42": 1.98861, - "43": 1.94856, - "44": 1.96339, - "45": 1.96835, - "46": 1.99733, - "47": 1.9716, - "48": 1.96591, - "49": 1.93865, - "50": 1.95198 + "1": 71.30157, + "2": 2.34464, + "3": 2.38747, + "4": 2.10322, + "5": 2.12945, + "6": 2.0424, + "7": 2.12036, + "8": 2.0147, + "9": 2.04925, + "10": 2.02797, + "11": 1.95087, + "12": 2.04985, + "13": 1.94106, + "14": 1.90425, + "15": 1.89051, + "16": 1.89398, + "17": 1.94082, + "18": 1.93176, + "19": 1.94027, + "20": 1.90271, + "21": 1.91097, + "22": 1.90382, + "23": 1.93889, + "24": 1.90551, + "25": 1.90947, + "26": 1.92126, + "27": 1.89917, + "28": 1.89866, + "29": 1.93981, + "30": 1.90782, + "31": 1.91244, + "32": 1.93864, + "33": 1.93947, + "34": 1.96882, + "35": 1.89751, + "36": 1.94038, + "37": 1.90603, + "38": 1.94988, + "39": 1.89874, + "40": 1.90233, + "41": 1.92861, + "42": 1.93931, + "43": 1.91212, + "44": 1.92615, + "45": 1.89555, + "46": 1.94522, + "47": 1.9103, + "48": 1.94689, + "49": 1.9355, + "50": 1.89832 } } -} \ No newline at end of file +} diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml index be34eb9aec5..38528836659 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/model_config.yaml @@ -5,6 +5,7 @@ ENV_VARS: NCCL_NVLS_ENABLE: 0 PYTHONWARNINGS: 
ignore NCCL_DEBUG: VERSION + NVTE_CPU_OFFLOAD_V1: 1 NVTE_FUSED_ATTN: 0 NCCL_ALGO: ^NVLS CUBLAS_WORKSPACE_CONFIG: ':4096:8' @@ -134,7 +135,6 @@ TEST_TYPE: regular # Usually ckpt-resume, but as a WAR to #513 set to regular METRICS: # - "iteration-time" - "lm loss" - - "num-zeros" - "mem-allocated-bytes" - "mem-max-allocated-bytes" - "mtp_1 loss" diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json index 162edd4f113..d80f6baf9e8 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json @@ -6,54 +6,54 @@ "values": { "1": 11.01693, "2": 11.06263, - "3": 10.1782, - "4": 10.86126, - "5": 9.81699, - "6": 9.10047, - "7": 9.6123, - "8": 8.39574, - "9": 7.79397, - "10": 7.15194, - "11": 9.06709, - "12": 12.4321, - "13": 8.58689, - "14": 8.37208, - "15": 8.32207, - "16": 8.28873, - "17": 8.33948, - "18": 7.62098, - "19": 8.20737, - "20": 7.71874, - "21": 8.02566, - "22": 7.37552, - "23": 7.97218, - "24": 7.52837, - "25": 8.3433, - "26": 7.79595, - "27": 7.73606, - "28": 7.71545, - "29": 7.78466, - "30": 7.57814, - "31": 7.86251, - "32": 6.53514, - "33": 7.24722, - "34": 7.81689, - "35": 7.75181, - "36": 6.74644, - "37": 8.15937, - "38": 7.62962, - "39": 7.9886, - "40": 7.53058, - "41": 7.54209, - "42": 6.14029, - "43": 7.61626, - "44": 7.97638, - "45": 6.85528, - "46": 7.44245, - "47": 7.84386, - "48": 7.89235, - "49": 7.61461, - "50": 6.86695 + "3": 10.0893, + "4": 9.64622, + "5": 10.351, + "6": 8.80033, + "7": 10.38861, + "8": 9.08827, + "9": 9.2025, + "10": 8.69816, + 
"11": 10.71757, + "12": 10.72938, + "13": 9.8103, + "14": 10.16776, + "15": 10.34088, + "16": 10.39001, + "17": 10.33746, + "18": 10.03205, + "19": 10.40886, + "20": 10.17433, + "21": 10.41968, + "22": 9.93423, + "23": 10.27377, + "24": 10.01972, + "25": 10.6325, + "26": 10.27763, + "27": 10.2853, + "28": 10.30143, + "29": 10.34056, + "30": 10.23697, + "31": 10.43065, + "32": 9.49143, + "33": 9.97924, + "34": 10.48027, + "35": 10.40919, + "36": 9.84774, + "37": 10.5738, + "38": 10.35817, + "39": 10.53096, + "40": 10.45813, + "41": 11.01492, + "42": 11.30727, + "43": 10.54763, + "44": 10.72116, + "45": 11.32983, + "46": 10.88386, + "47": 10.6974, + "48": 10.6521, + "49": 10.74413, + "50": 11.16561 } }, "num-zeros": { @@ -61,56 +61,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 47167904.0, - "2": 46900672.0, - "3": 81004512.0, - "4": 231040016.0, - "5": 477984896.0, - "6": 558059904.0, - "7": 958271680.0, - "8": 723959296.0, - "9": 802607040.0, - "10": 715176064.0, - "11": 657024320.0, - "12": 565795136.0, - "13": 541943680.0, - "14": 773290880.0, - "15": 810566400.0, - "16": 748195712.0, - "17": 730395008.0, - "18": 733261760.0, - "19": 729119744.0, - "20": 859242112.0, - "21": 846155136.0, - "22": 648056832.0, - "23": 774244288.0, - "24": 629192960.0, - "25": 843192448.0, - "26": 846129280.0, - "27": 804864512.0, - "28": 789783424.0, - "29": 817814656.0, - "30": 808743168.0, - "31": 662987648.0, - "32": 841163840.0, - "33": 676597440.0, - "34": 808569792.0, - "35": 804410048.0, - "36": 749336000.0, - "37": 759355904.0, - "38": 768597888.0, - "39": 758146688.0, - "40": 767096448.0, - "41": 735961920.0, - "42": 705693632.0, - "43": 694921152.0, - "44": 692872768.0, - "45": 638337792.0, - "46": 654254336.0, - "47": 655022208.0, - "48": 648030848.0, - "49": 622397184.0, - "50": 582138304.0 + "1": 47167796.0, + "2": 46899648.0, + "3": 58981488.0, + "4": 1722086400.0, + "5": 172862128.0, + "6": 167948304.0, + "7": 461255808.0, + "8": 1722086400.0, + "9": 
1722086400.0, + "10": 155206176.0, + "11": 276392000.0, + "12": 449419776.0, + "13": 1722086400.0, + "14": 704085952.0, + "15": 684765632.0, + "16": 568911360.0, + "17": 1722086400.0, + "18": 944021952.0, + "19": 663072256.0, + "20": 1722086400.0, + "21": 802140416.0, + "22": 1722086400.0, + "23": 1722086400.0, + "24": 1722086400.0, + "25": 802352640.0, + "26": 1006560192.0, + "27": 710512832.0, + "28": 959659776.0, + "29": 1722086400.0, + "30": 962886400.0, + "31": 1722086400.0, + "32": 630419136.0, + "33": 1722086400.0, + "34": 585257664.0, + "35": 857957504.0, + "36": 1722086400.0, + "37": 674456896.0, + "38": 1722086400.0, + "39": 635463424.0, + "40": 559496704.0, + "41": 1722086400.0, + "42": 1722086400.0, + "43": 581694144.0, + "44": 592225088.0, + "45": 484204448.0, + "46": 493839392.0, + "47": 1722086400.0, + "48": 500215200.0, + "49": 499718016.0, + "50": 1722086400.0 } }, "mem-allocated-bytes": { @@ -118,56 +118,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 4305058304.0, - "2": 4305059840.0, - "3": 4305059840.0, - "4": 4305059840.0, - "5": 4305059840.0, - "6": 4305059840.0, - "7": 4305059840.0, - "8": 4305059840.0, - "9": 4305059840.0, - "10": 4305059840.0, - "11": 4305059840.0, - "12": 4305059840.0, - "13": 4305059840.0, - "14": 4305059840.0, - "15": 4305059840.0, - "16": 4305059840.0, - "17": 4305059840.0, - "18": 4305059840.0, - "19": 4305059840.0, - "20": 4305059840.0, - "21": 4305059840.0, - "22": 4305059840.0, - "23": 4305059840.0, - "24": 4305059840.0, - "25": 4305059840.0, - "26": 4305059840.0, - "27": 4305059840.0, - "28": 4305059840.0, - "29": 4305059840.0, - "30": 4305059840.0, - "31": 4305059840.0, - "32": 4305059840.0, - "33": 4305059840.0, - "34": 4305059840.0, - "35": 4305059840.0, - "36": 4305059840.0, - "37": 4305059840.0, - "38": 4305059840.0, - "39": 4305059840.0, - "40": 4305059840.0, - "41": 4305059840.0, - "42": 4305059840.0, - "43": 4305059840.0, - "44": 4305059840.0, - "45": 4305059840.0, - "46": 4305059840.0, - 
"47": 4305059840.0, - "48": 4305059840.0, - "49": 4305059840.0, - "50": 4305059840.0 + "1": 4313446912.0, + "2": 4313448448.0, + "3": 4313448448.0, + "4": 4313448448.0, + "5": 4313448448.0, + "6": 4313448448.0, + "7": 4313448448.0, + "8": 4313448448.0, + "9": 4313448448.0, + "10": 4313448448.0, + "11": 4313448448.0, + "12": 4313448448.0, + "13": 4313448448.0, + "14": 4313448448.0, + "15": 4313448448.0, + "16": 4313448448.0, + "17": 4313448448.0, + "18": 4313448448.0, + "19": 4313448448.0, + "20": 4313448448.0, + "21": 4313448448.0, + "22": 4313448448.0, + "23": 4313448448.0, + "24": 4313448448.0, + "25": 4313448448.0, + "26": 4313448448.0, + "27": 4313448448.0, + "28": 4313448448.0, + "29": 4313448448.0, + "30": 4313448448.0, + "31": 4313448448.0, + "32": 4313448448.0, + "33": 4313448448.0, + "34": 4313448448.0, + "35": 4313448448.0, + "36": 4313448448.0, + "37": 4313448448.0, + "38": 4313448448.0, + "39": 4313448448.0, + "40": 4313448448.0, + "41": 4313448448.0, + "42": 4313448448.0, + "43": 4313448448.0, + "44": 4313448448.0, + "45": 4313448448.0, + "46": 4313448448.0, + "47": 4313448448.0, + "48": 4313448448.0, + "49": 4313448448.0, + "50": 4313448448.0 } }, "mem-max-allocated-bytes": { @@ -175,56 +175,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 4305060864.0, - "2": 5850929152.0, - "3": 5850929152.0, - "4": 5857025536.0, - "5": 5857025536.0, - "6": 5857025536.0, - "7": 5857025536.0, - "8": 5857025536.0, - "9": 5857025536.0, - "10": 5857025536.0, - "11": 5857025536.0, - "12": 5857025536.0, - "13": 5857025536.0, - "14": 5857025536.0, - "15": 5857025536.0, - "16": 5857025536.0, - "17": 5857025536.0, - "18": 5857025536.0, - "19": 5857025536.0, - "20": 5857025536.0, - "21": 5857025536.0, - "22": 5857025536.0, - "23": 5857025536.0, - "24": 5857025536.0, - "25": 5857025536.0, - "26": 5857025536.0, - "27": 5857025536.0, - "28": 5857025536.0, - "29": 5857025536.0, - "30": 5857025536.0, - "31": 5857025536.0, - "32": 5857025536.0, - "33": 5857025536.0, - 
"34": 5857025536.0, - "35": 5857025536.0, - "36": 5857025536.0, - "37": 5857025536.0, - "38": 5857025536.0, - "39": 5857025536.0, - "40": 5857025536.0, - "41": 5857025536.0, - "42": 5857025536.0, - "43": 5857025536.0, - "44": 5857025536.0, - "45": 5857025536.0, - "46": 5857025536.0, - "47": 5857025536.0, - "48": 5857025536.0, - "49": 5857025536.0, - "50": 5860186112.0 + "1": 4313449472.0, + "2": 7108272640.0, + "3": 7108272640.0, + "4": 7108272640.0, + "5": 7128937984.0, + "6": 7128937984.0, + "7": 7128937984.0, + "8": 7128937984.0, + "9": 7128937984.0, + "10": 7128937984.0, + "11": 7128937984.0, + "12": 7128937984.0, + "13": 7128937984.0, + "14": 7128937984.0, + "15": 7128937984.0, + "16": 7128937984.0, + "17": 7128937984.0, + "18": 7128937984.0, + "19": 7128937984.0, + "20": 7128937984.0, + "21": 7128937984.0, + "22": 7128937984.0, + "23": 7128937984.0, + "24": 7128937984.0, + "25": 7128937984.0, + "26": 7128937984.0, + "27": 7128937984.0, + "28": 7128937984.0, + "29": 7128937984.0, + "30": 7128937984.0, + "31": 7128937984.0, + "32": 7128937984.0, + "33": 7128937984.0, + "34": 7128937984.0, + "35": 7128937984.0, + "36": 7128937984.0, + "37": 7128937984.0, + "38": 7128937984.0, + "39": 7128937984.0, + "40": 7128937984.0, + "41": 7128937984.0, + "42": 7128937984.0, + "43": 7129990656.0, + "44": 7129990656.0, + "45": 7129990656.0, + "46": 7129990656.0, + "47": 7129990656.0, + "48": 7129990656.0, + "49": 7129990656.0, + "50": 7129990656.0 } }, "iteration-time": { @@ -232,56 +232,56 @@ "end_step": 50, "step_interval": 1, "values": { - "1": 89.57975, - "2": 3.08398, - "3": 3.39072, - "4": 2.95563, - "5": 3.89951, - "6": 1.99592, - "7": 2.70541, - "8": 1.95431, - "9": 1.95178, - "10": 1.95311, - "11": 2.53128, - "12": 2.03561, - "13": 2.63986, - "14": 1.9956, - "15": 1.94751, - "16": 1.94319, - "17": 1.96972, - "18": 2.07225, - "19": 1.94281, - "20": 1.9489, - "21": 1.94199, - "22": 1.95565, - "23": 1.94632, - "24": 1.94485, - "25": 1.94325, - "26": 1.96685, - "27": 
2.00745, - "28": 1.94741, - "29": 1.95606, - "30": 1.95414, - "31": 2.57092, - "32": 1.95172, - "33": 1.94952, - "34": 1.95519, - "35": 1.95735, - "36": 1.94985, - "37": 1.95117, - "38": 1.96384, - "39": 1.98373, - "40": 1.98071, - "41": 1.96168, - "42": 1.97892, - "43": 1.97654, - "44": 1.95705, - "45": 1.95269, - "46": 2.02666, - "47": 1.96138, - "48": 1.9657, - "49": 1.96155, - "50": 1.96872 + "1": 71.71892, + "2": 2.43398, + "3": 2.53134, + "4": 2.58837, + "5": 2.45297, + "6": 2.23546, + "7": 2.09269, + "8": 2.17361, + "9": 2.21504, + "10": 2.09975, + "11": 2.0732, + "12": 2.12061, + "13": 2.19105, + "14": 2.07416, + "15": 2.05962, + "16": 2.06449, + "17": 2.06534, + "18": 2.06832, + "19": 2.09788, + "20": 2.06188, + "21": 2.06072, + "22": 2.0677, + "23": 2.05993, + "24": 2.06692, + "25": 2.10922, + "26": 2.06561, + "27": 2.06369, + "28": 2.08584, + "29": 2.0623, + "30": 2.06367, + "31": 2.06523, + "32": 2.06471, + "33": 2.06243, + "34": 2.05839, + "35": 2.0663, + "36": 2.07558, + "37": 2.08622, + "38": 2.07519, + "39": 2.07009, + "40": 2.07146, + "41": 2.09338, + "42": 2.08324, + "43": 2.08632, + "44": 2.07644, + "45": 2.0922, + "46": 2.07436, + "47": 2.07246, + "48": 2.07957, + "49": 2.08348, + "50": 2.09287 } } -} \ No newline at end of file +} diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml index 5b177ed116d..d1fcd8fd4b7 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/model_config.yaml @@ -5,6 +5,10 @@ ENV_VARS: NCCL_NVLS_ENABLE: 0 PYTHONWARNINGS: ignore NCCL_DEBUG: VERSION + NVTE_CPU_OFFLOAD_V1: 1 + 
NVTE_FUSED_ATTN: 0 + NCCL_ALGO: ^NVLS + CUBLAS_WORKSPACE_CONFIG: ':4096:8' MODEL_ARGS: # Distributed args --distributed-timeout-minutes: 60 @@ -29,8 +33,6 @@ MODEL_ARGS: --exit-duration-in-mins: 230 --no-check-for-nan-in-loss-and-grad: true --no-rope-fusion: true - --cross-entropy-loss-fusion: true - --cross-entropy-fusion-impl: native --manual-gc: true --manual-gc-interval: 100 --recompute-granularity: selective @@ -129,6 +131,5 @@ TEST_TYPE: regular # Usually ckpt-resume, but as a WAR to #513 set to regular METRICS: # - "iteration-time" - "lm loss" - - "num-zeros" - "mem-allocated-bytes" - "mem-max-allocated-bytes" diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py new file mode 100644 index 00000000000..770bc312aeb --- /dev/null +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -0,0 +1,552 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import hashlib +import inspect +import json +import os +import sys +from typing import Any, Dict, Mapping, Tuple + +import pytest # type: ignore[import] +import torch + +from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec +from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.enums import AttnBackend +from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args +from megatron.training.global_vars import ( + destroy_global_vars, + get_args, + set_args, + set_global_variables, +) +from tests.unit_tests.test_utilities import Utils + +GOLDEN_CONFIG: Dict[str, Any] = { + "_cpu_offloading_context": None, + "account_for_embedding_in_pipeline_split": False, + "account_for_loss_in_pipeline_split": False, + "activation_func": 
"megatron.core.activations.squared_relu", + "activation_func_clamp_value": None, + "activation_func_fp8_input_store": False, + "add_bias_linear": False, + "add_qkv_bias": False, + "apply_query_key_layer_scaling": False, + "apply_residual_connection_post_layernorm": False, + "apply_rope_fusion": False, + "async_tensor_model_parallel_allreduce": True, + "attention_backend": { + "__objclass__": "megatron.core.transformer.enums.AttnBackend", + "_name_": "flash", + "_sort_order_": 0, + "_value_": 1, + }, + "attention_dropout": 0.0, + "attention_output_gate": False, + "attention_softmax_in_fp32": False, + "autocast_dtype": "torch.bfloat16", + "barrier_with_L1_time": True, + "batch_invariant_mode": False, + "batch_p2p_comm": True, + "batch_p2p_sync": True, + "bf16": True, + "bias_activation_fusion": False, + "bias_dropout_fusion": True, + "calculate_per_token_loss": False, + "clone_scatter_output_in_embedding": True, + "config_logger_dir": "", + "context_parallel_size": 1, + "cp_comm_type": "p2p", + "cpu_offloading": False, + "cpu_offloading_activations": True, + "cpu_offloading_double_buffering": False, + "cpu_offloading_num_layers": 0, + "cpu_offloading_weights": False, + "cross_entropy_fusion_impl": "native", + "cross_entropy_loss_fusion": True, + "cuda_graph_impl": "none", + "cuda_graph_retain_backward_graph": False, + "cuda_graph_scope": [], + "cuda_graph_use_single_mempool": False, + "cuda_graph_warmup_steps": 3, + "deallocate_pipeline_outputs": True, + "defer_embedding_wgrad_compute": False, + "delay_wgrad_compute": False, + "deterministic_mode": False, + "disable_bf16_reduced_precision_matmul": False, + "disable_parameter_transpose_cache": False, + "distribute_saved_activations": False, + "embedding_init_method": {}, + "embedding_init_method_std": 0.014, + "enable_autocast": False, + "enable_cuda_graph": False, + "ep_overlap_early_attn_memory_release": False, + "expert_model_parallel_size": 4, + "expert_tensor_parallel_size": 1, + "external_cuda_graph": False, + 
"ffn_hidden_size": 1856, + "finalize_model_grads_func": None, + "first_last_layers_bf16": False, + "flash_decode": False, + "fp16": False, + "fp32_residual_connection": False, + "fp4": None, + "fp4_param": False, + "fp4_quantizer_factory": None, + "fp4_recipe": "nvfp4", + "fp8": None, + "fp8_amax_compute_algo": "most_recent", + "fp8_amax_history_len": 1, + "fp8_dot_product_attention": False, + "fp8_interval": 1, + "fp8_margin": 0, + "fp8_multi_head_attention": False, + "fp8_param": False, + "fp8_quantizer_factory": None, + "fp8_recipe": "delayed", + "fp8_wgrad": True, + "fused_single_qkv_rope": False, + "gated_linear_unit": False, + "glu_linear_offset": 0.0, + "grad_scale_func": None, + "grad_sync_func": None, + "gradient_accumulation_fusion": True, + "hetereogenous_dist_checkpoint": False, + "heterogeneous_block_specs": False, + "hidden_dropout": 0.0, + "hidden_size": 2688, + "hierarchical_context_parallel_sizes": None, + "inference_fuse_tp_communication": False, + "inference_rng_tracker": False, + "inference_sampling_seed": 42, + "init_method": {}, + "init_method_std": 0.014, + "init_model_with_meta_device": False, + "is_hybrid_model": True, + "kitchen_attention_backend": "sdpa", + "kv_channels": 128, + "layernorm_epsilon": 1e-05, + "layernorm_zero_centered_gamma": False, + "log_max_attention_logit": False, + "mamba_head_dim": 64, + "mamba_num_groups": 8, + "mamba_num_heads": 64, + "mamba_state_dim": 128, + "masked_softmax_fusion": True, + "memory_efficient_layer_norm": False, + "microbatch_group_size_per_vp_stage": 1, + "mlp_chunks_for_prefill": 1, + "moe_apply_probs_on_input": False, + "moe_aux_loss_coeff": 0.0, + "moe_deepep_num_sms": 20, + "moe_enable_deepep": False, + "moe_expert_capacity_factor": None, + "moe_extended_tp": False, + "moe_ffn_hidden_size": 1856, + "moe_flex_dispatcher_backend": "deepep", + "moe_grouped_gemm": True, + "moe_hybridep_num_sms": 16, + "moe_input_jitter_eps": None, + "moe_latent_size": None, + "moe_layer_freq": 1, + 
"moe_layer_recompute": False, + "moe_pad_expert_input_to_capacity": False, + "moe_per_layer_logging": False, + "moe_permute_fusion": False, + "moe_router_bias_update_rate": 0.001, + "moe_router_dtype": "fp64", + "moe_router_enable_expert_bias": True, + "moe_router_force_load_balancing": False, + "moe_router_fusion": False, + "moe_router_group_topk": None, + "moe_router_load_balancing_type": "aux_loss", + "moe_router_num_groups": None, + "moe_router_padding_for_fp8": False, + "moe_router_padding_for_quantization": False, + "moe_router_pre_softmax": False, + "moe_router_score_function": "sigmoid", + "moe_router_topk": 6, + "moe_router_topk_limited_devices": None, + "moe_router_topk_scaling_factor": 2.5, + "moe_shared_expert_gate": False, + "moe_shared_expert_intermediate_size": 3712, + "moe_shared_expert_overlap": False, + "moe_token_dispatcher_type": "alltoall", + "moe_token_drop_policy": "probs", + "moe_token_dropping": False, + "moe_use_legacy_grouped_gemm": False, + "moe_z_loss_coeff": None, + "mrope_section": None, + "mtp_loss_scaling_factor": 0.1, + "mtp_num_layers": None, + "multi_latent_attention": False, + "no_rope_freq": None, + "no_sync_func": None, + "normalization": "RMSNorm", + "num_attention_heads": 32, + "num_layers": 52, + "num_layers_at_end_in_bf16": 1, + "num_layers_at_start_in_bf16": 1, + "num_layers_in_first_pipeline_stage": None, + "num_layers_in_last_pipeline_stage": None, + "num_microbatches_with_partial_activation_checkpoints": None, + "num_moe_experts": 128, + "num_query_groups": 2, + "output_layer_init_method": {}, + "overlap_moe_expert_parallel_comm": False, + "overlap_p2p_comm": False, + "overlap_p2p_comm_warmup_flush": False, + "param_sync_func": None, + "params_dtype": "torch.bfloat16", + "perform_initialization": True, + "persist_layer_norm": True, + "pipeline_dtype": "torch.bfloat16", + "pipeline_model_parallel_comm_backend": None, + "pipeline_model_parallel_layout": None, + "pipeline_model_parallel_size": 1, + "qk_clip": False, + 
"qk_clip_alpha": 0.5, + "qk_clip_threshold": 100, + "qk_layernorm": False, + "quant_recipe": None, + "recompute_granularity": None, + "recompute_method": None, + "recompute_modules": ["core_attn"], + "recompute_num_layers": None, + "rotary_interleaved": False, + "sequence_parallel": True, + "softmax_scale": None, + "softmax_type": "vanilla", + "symmetric_ar_type": None, + "tensor_model_parallel_size": 2, + "test_mode": False, + "timers": None, + "tp_comm_atomic_ag": False, + "tp_comm_atomic_rs": False, + "tp_comm_bootstrap_backend": "nccl", + "tp_comm_bulk_dgrad": True, + "tp_comm_bulk_wgrad": True, + "tp_comm_overlap": False, + "tp_comm_overlap_ag": True, + "tp_comm_overlap_disable_fc1": False, + "tp_comm_overlap_disable_qkv": False, + "tp_comm_overlap_rs": True, + "tp_comm_overlap_rs_dgrad": False, + "tp_comm_split_ag": True, + "tp_comm_split_rs": True, + "tp_only_amax_red": False, + "transformer_impl": "transformer_engine", + "use_cpu_initialization": None, + "use_fused_weighted_squared_relu": False, + "use_inference_optimized_layers": False, + "use_kitchen": False, + "use_kitchen_attention": False, + "use_mamba_mem_eff_path": True, + "use_ring_exchange_p2p": False, + "use_te_activation_func": False, + "use_te_rng_tracker": False, + "variable_seq_lengths": False, + "virtual_pipeline_model_parallel_size": None, + "wgrad_deferral_limit": 0, + "window_attn_skip_freq": None, + "window_size": None, + "fine_grained_activation_offloading": False, + "min_offloaded_tensor_size": 1024 * 1024, + "offload_modules": [], +} +# Fields to ignore entirely (ephemeral, environment-specific, very large). +SKIP_FIELDS = set() +# Fields that are allowed to appear in the live config even if not yet in the golden. 
+ALLOW_ADDED_FIELDS = set() + + +def serialize_config(cfg: Any) -> Dict[str, Any]: + """Normalize a config object into a JSON-serializable dict.""" + data = {k: v for k, v in vars(cfg).items() if k not in SKIP_FIELDS} + return _ser(data) + + +def assert_config_matches_golden(cfg: Any) -> None: + """Compare live config to golden snapshot with readable diffs.""" + current = serialize_config(cfg) + golden = GOLDEN_CONFIG + + added, removed, changed = _diff_configs(golden, current) + + # Ignore added fields that are explicitly allowed. + added = [k for k in added if k not in ALLOW_ADDED_FIELDS] + + if added or removed or changed: + # Build actionable guidance for each type of drift + guidance_parts = [] + + if added: + guidance_parts.append( + f"\n\n[ADDED ARGS]: {sorted(added)}\n" + " → Update GOLDEN_CONFIG in this test file to include the new arg(s) with " + "their default value(s).\n" + " ⚠️ CAUTION: Review any logic associated with new args to ensure it doesn't " + "silently affect downstream model configs or behavior.\n" + ) + + if changed: + guidance_parts.append( + f"\n\n[CHANGED DEFAULTS]: {sorted(changed)}\n" + " → Please don't change the default values of existing args unless " + "it is absolutely necessary for a bug fix.\n" + " → If you must change the default value, please update the GOLDEN_CONFIG " + "in this test file to reflect the new default value.\n" + ) + + if removed: + guidance_parts.append( + f"\n\n[REMOVED ARGS]: {sorted(removed)}\n" + " → Do NOT remove args directly. 
Instead, deprecate them with a warning message " + "to maintain backwards compatibility.\n" + ) + + guidance_parts.append( + "Please contact NV-username @jbarker if you are unsure how to proceed.\n" + ) + + header = "Mamba MoE config drift detected!\n" "═" * 60 + "".join(guidance_parts) + parts = [header] + if changed: + formatted = {k: {"expected": golden[k], "actual": current[k]} for k in sorted(changed)} + parts.append( + f"Changed field details:\n{json.dumps(formatted, indent=2, sort_keys=True)}" + ) + pytest.fail("\n".join(parts)) + + +def regenerate_mamba_moe_golden(cfg: Any) -> Dict[str, Any]: + """Helper to regenerate the golden config; copy/paste into GOLDEN_CONFIG.""" + serialized = serialize_config(cfg) + return serialized + + +def _ser(obj: Any) -> Any: + """Recursively convert objects to JSON-friendly structures.""" + if obj is None or isinstance(obj, (bool, int, float, str)): + return obj + if isinstance(obj, dict): + return {k: _ser(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_ser(v) for v in obj] + if inspect.isfunction(obj) or inspect.ismethod(obj): + return f"{obj.__module__}.{obj.__name__}" + if inspect.isclass(obj): + return f"{obj.__module__}.{obj.__name__}" + if hasattr(obj, "__dict__"): + return {k: _ser(v) for k, v in vars(obj).items()} + try: + return str(obj) + except Exception: + return f"" + + +def _diff_configs(expected: Mapping[str, Any], actual: Mapping[str, Any]) -> Tuple[set, set, set]: + """Return added, removed, and changed top-level keys between dicts.""" + expected_keys = set(expected) + actual_keys = set(actual) + added = actual_keys - expected_keys + removed = expected_keys - actual_keys + changed = {k for k in expected_keys & actual_keys if expected[k] != actual[k]} + return added, removed, changed + + +class TestMambaMoEModel: + """Test the initialization and use of an MoE Mamba model.""" + + def create_test_args(self): + destroy_global_vars() + destroy_num_microbatches_calculator() + + 
sys.argv = ['test_mamba_moe_model.py'] + args = parse_args() + + # The following args would be set from the nano v3 checkpoint. + args.num_layers = 52 + args.hidden_size = 2688 + args.ffn_hidden_size = 1856 + args.num_attention_heads = 32 + args.num_query_groups = 2 + args.group_query_attention = True + args.kv_channels = 128 + args.position_embedding_type = 'none' + args.add_position_embedding = True + args.use_rotary_position_embeddings = False + args.rotary_base = 10000 + args.rotary_percent = 1.0 + args.rotary_interleaved = False + args.add_bias_linear = False + args.add_qkv_bias = False + args.squared_relu = True + args.swiglu = False + args.untie_embeddings_and_output_weights = True + args.apply_layernorm_1p = False + args.normalization = "RMSNorm" + args.apply_query_key_layer_scaling = False + args.attention_dropout = 0.0 + args.hidden_dropout = 0.0 + args.hybrid_override_pattern = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME" + args.spec = ["megatron.core.models.mamba.mamba_layer_specs", "mamba_stack_spec"] + args.hybrid_attention_ratio = 0.0 + args.hybrid_mlp_ratio = 0.0 + args.num_experts = 128 + args.moe_layer_freq = 1 + args.moe_ffn_hidden_size = 1856 + args.moe_router_topk = 6 + args.moe_router_pre_softmax = False + args.moe_grouped_gemm = True + args.moe_shared_expert_intermediate_size = 3712 + args.moe_router_score_function = "sigmoid" + args.moe_router_enable_expert_bias = True + args.moe_router_topk_scaling_factor = 2.5 + args.mamba_state_dim = 128 + args.mamba_head_dim = 64 + args.mamba_num_groups = 8 + args.mamba_num_heads = 64 + args.is_hybrid_model = True + args.tokenizer_type = "TikTokenizer" + args.tiktoken_pattern = "v2" + args.tokenizer_model = "/mnt/artifacts/model/nemotron6/tokenizers/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json" + args.padded_vocab_size = 131072 + + # The following args would be set in the user's nano v3 config. 
+ args.async_tensor_model_parallel_allreduce = True + args.attention_backend = AttnBackend.flash + args.bf16 = True + args.ckpt_format = 'torch_dist' + args.cross_entropy_loss_fusion = True + args.cuda_graph_impl = "none" + args.embedding_init_method_std = 0.014 + args.expert_model_parallel_size = 4 + args.expert_tensor_parallel_size = 1 + args.init_method_std = 0.014 + args.lr = 3e-5 + args.max_position_embeddings = 1024 + args.micro_batch_size = 2 + args.moe_aux_loss_coeff = 0.0 + args.moe_grouped_gemm = True + args.moe_route_load_balancing_type = "aux_loss" + args.moe_router_dtype = "fp64" + args.moe_router_pre_softmax = False + args.moe_token_dispatcher_type = "alltoall" + args.no_load_optim = True + args.no_load_rng = True + args.no_save_optim = True + args.pipeline_model_parallel_size = 1 + args.position_embedding_type = None + args.recompute_granularity = None + args.seed = 42 + args.seq_length = 1024 + args.sequence_parallel = True + args.te_rng_tracker = True + args.tensor_model_parallel_size = 2 + args.vocab_size = 131072 + + validate_args(args) + set_global_variables(args, False) + return args + + def setup_method(self, method): + + os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + args = self.create_test_args() + set_args(args) + + Utils.initialize_model_parallel( + tensor_model_parallel_size=args.tensor_model_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + expert_model_parallel_size=args.expert_model_parallel_size, + expert_tensor_parallel_size=args.expert_tensor_parallel_size, + ) + model_parallel_cuda_manual_seed(123) + + model_config = core_transformer_config_from_args(args, TransformerConfig) + + self.model = MambaModel( + config=model_config, + mamba_stack_spec=mamba_stack_spec, + vocab_size=args.vocab_size, + max_sequence_length=args.seq_length, + hybrid_attention_ratio=args.hybrid_attention_ratio, + hybrid_mlp_ratio=args.hybrid_mlp_ratio, + hybrid_override_pattern=args.hybrid_override_pattern, + 
position_embedding_type=args.position_embedding_type, + rotary_base=args.rotary_base, + rotary_percent=args.rotary_percent, + ) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_constructor(self): + """Sanity check for the constructor of the Mamba MoE model.""" + + args = get_args() + + assert_config_matches_golden(self.model.config) + + assert self.model.pre_process is True, "pre_process should be True" + assert self.model.post_process is True, "post_process should be True" + assert self.model.hybrid_attention_ratio == 0.0, "hybrid_attention_ratio should be 0.0" + assert self.model.hybrid_mlp_ratio == 0.0, "hybrid_mlp_ratio should be 0.0" + assert ( + self.model.hybrid_override_pattern == args.hybrid_override_pattern + ), f"hybrid_override_pattern should be {args.hybrid_override_pattern}" + num_weights = sum([p.numel() for p in self.model.parameters()]) + assert num_weights == 8449294624, f"Expected 8449294624 parameters, got {num_weights}" + + def test_set_input_tensor(self): + + args = get_args() + + config: TransformerConfig = self.model.config + sequence_length = self.model.max_sequence_length + micro_batch_size = args.micro_batch_size + + # [sequence length, batch size, hidden size] + input_tensor = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) + + self.model.set_input_tensor(input_tensor) + + assert self.model.decoder.input_tensor.shape[0] == sequence_length + assert self.model.decoder.input_tensor.shape[1] == micro_batch_size + assert self.model.decoder.input_tensor.shape[2] == config.hidden_size + + def test_forward(self): + """Basic smoke test for the forward pass of the Mamba MoE model.""" + + args = get_args() + + # we must override this to avoid the need to initialize the optimizer + for param in self.model.parameters(): + param.requires_grad = False + + sequence_length = self.model.max_sequence_length + micro_batch_size = args.micro_batch_size + + self.model.cuda() + + data = 
list(range(sequence_length)) + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool + ).cuda() + + logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + runtime_gather_output=True, + ) + + assert logits.shape[0] == micro_batch_size + assert logits.shape[1] == sequence_length + assert logits.shape[2] == self.model.vocab_size From 17c0eb92c05ee2daa8cc0910fc305c761fbad73b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 15 Jan 2026 00:49:00 -0800 Subject: [PATCH 48/74] minor fix and format Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 19 +- megatron/core/transformer/attention.py | 9 +- megatron/core/transformer/moe/experts.py | 5 +- .../transformer/multi_latent_attention.py | 8 +- .../core/transformer/transformer_layer.py | 12 +- .../unit_tests/models/test_mamba_moe_model.py | 552 ------------------ 6 files changed, 31 insertions(+), 574 deletions(-) delete mode 100644 tests/unit_tests/models/test_mamba_moe_model.py diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 6dbd65f7d99..0a081a767c2 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -453,12 +453,16 @@ def forward_func( return hidden_states, None, None, None if layer.recompute_pre_mlp_layernorm: layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: + with off_interface( + layer.offload_mlp_norm, hidden_states, "mlp_norm" + ) as hidden_states: pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( layer.pre_mlp_layernorm, 
hidden_states ) else: - with off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: + with off_interface( + layer.offload_mlp_norm, hidden_states, "mlp_norm" + ) as hidden_states: pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) shared_expert_output = layer.mlp.shared_experts_compute(pre_mlp_layernorm_output) @@ -477,7 +481,16 @@ def forward_func( packed_seq_params=node.chunk_state.packed_seq_params, sequence_len_offset=node.chunk_state.sequence_len_offset, ) - return hidden_states + if not isinstance(layer.mlp, MoELayer): + return hidden_states + + # Detach here for mlp_bda residual connection + node.layer_state.residual = node.detach(hidden_states) + if layer.mlp.use_shared_expert and not layer.mlp.shared_expert_overlap: + # Detach here for shared expert connection in moe_combine + node.layer_state.shared_expert_output = node.detach(shared_expert_output) + + return local_tokens, probs def submodule_dispatch_forward( node: ScheduleNode, local_tokens: torch.Tensor, probs: torch.Tensor diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 361c5d64590..3cb1a5ee4a4 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -29,7 +29,6 @@ ) from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_group_commit, - fine_grained_offloading_group_start, ) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region @@ -832,8 +831,7 @@ def forward( if output_gate: assert split_qkv, "output_gate is not supported for unsplit mixed_qkv tensor." 
- with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") \ - as hidden_states: + with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") as hidden_states: qkv_output = self.get_query_key_value_tensors( hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv ) @@ -993,8 +991,9 @@ def forward( else: if inference_context is None or inference_context.is_static_batching(): # Static batching attention kernel. - with off_interface(self.offload_core_attention and self.training, query, \ - "core_attn") as query: + with off_interface( + self.offload_core_attention and self.training, query, "core_attn" + ) as query: core_attn_out = self.core_attention( query, key, diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index eca82b04570..b43ddbb893a 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -732,8 +732,9 @@ def forward( # Probs already applied, so reset to 1. permuted_probs = torch.ones_like(permuted_probs) - with off_interface(self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1") \ - as permuted_local_hidden_states: + with off_interface( + self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" + ) as permuted_local_hidden_states: fc1_output, bias_parallel = self.linear_fc1( permuted_local_hidden_states, tokens_per_expert ) diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index cc82c66dbb1..2e50d9f2169 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -245,8 +245,7 @@ def forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. 
# query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128] - with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") \ - as hidden_states: + with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") as hidden_states: if self.config.experimental_attention_variant is None: query, key, value = self.get_query_key_value_tensors( hidden_states, @@ -300,8 +299,9 @@ def forward( ) else: if inference_context is None or inference_context.is_static_batching(): - with off_interface(self.offload_core_attention and self.training, query, \ - "core_attn") as query: + with off_interface( + self.offload_core_attention and self.training, query, "core_attn" + ) as query: if self.config.experimental_attention_variant is None: core_attn_out = self.core_attention( query, diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 1b8ce94682a..6a5d1e7576f 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -526,14 +526,12 @@ def _forward_attention( # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") \ - as hidden_states: + with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( self.input_layernorm, hidden_states ) else: - with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") \ - as hidden_states: + with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: input_layernorm_output = self.input_layernorm(hidden_states) # Self attention. @@ -628,14 +626,12 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) # Optional Layer norm post the cross-attention. 
if self.recompute_pre_mlp_layernorm: self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") \ - as hidden_states: + with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( self.pre_mlp_layernorm, hidden_states ) else: - with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") \ - as hidden_states: + with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) nvtx_range_push(suffix="mlp") diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py deleted file mode 100644 index 770bc312aeb..00000000000 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ /dev/null @@ -1,552 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import hashlib -import inspect -import json -import os -import sys -from typing import Any, Dict, Mapping, Tuple - -import pytest # type: ignore[import] -import torch - -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.models.mamba.mamba_model import MambaModel -from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.enums import AttnBackend -from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args -from megatron.training.global_vars import ( - destroy_global_vars, - get_args, - set_args, - set_global_variables, -) -from tests.unit_tests.test_utilities import Utils - -GOLDEN_CONFIG: Dict[str, Any] = { - "_cpu_offloading_context": None, - "account_for_embedding_in_pipeline_split": False, - 
"account_for_loss_in_pipeline_split": False, - "activation_func": "megatron.core.activations.squared_relu", - "activation_func_clamp_value": None, - "activation_func_fp8_input_store": False, - "add_bias_linear": False, - "add_qkv_bias": False, - "apply_query_key_layer_scaling": False, - "apply_residual_connection_post_layernorm": False, - "apply_rope_fusion": False, - "async_tensor_model_parallel_allreduce": True, - "attention_backend": { - "__objclass__": "megatron.core.transformer.enums.AttnBackend", - "_name_": "flash", - "_sort_order_": 0, - "_value_": 1, - }, - "attention_dropout": 0.0, - "attention_output_gate": False, - "attention_softmax_in_fp32": False, - "autocast_dtype": "torch.bfloat16", - "barrier_with_L1_time": True, - "batch_invariant_mode": False, - "batch_p2p_comm": True, - "batch_p2p_sync": True, - "bf16": True, - "bias_activation_fusion": False, - "bias_dropout_fusion": True, - "calculate_per_token_loss": False, - "clone_scatter_output_in_embedding": True, - "config_logger_dir": "", - "context_parallel_size": 1, - "cp_comm_type": "p2p", - "cpu_offloading": False, - "cpu_offloading_activations": True, - "cpu_offloading_double_buffering": False, - "cpu_offloading_num_layers": 0, - "cpu_offloading_weights": False, - "cross_entropy_fusion_impl": "native", - "cross_entropy_loss_fusion": True, - "cuda_graph_impl": "none", - "cuda_graph_retain_backward_graph": False, - "cuda_graph_scope": [], - "cuda_graph_use_single_mempool": False, - "cuda_graph_warmup_steps": 3, - "deallocate_pipeline_outputs": True, - "defer_embedding_wgrad_compute": False, - "delay_wgrad_compute": False, - "deterministic_mode": False, - "disable_bf16_reduced_precision_matmul": False, - "disable_parameter_transpose_cache": False, - "distribute_saved_activations": False, - "embedding_init_method": {}, - "embedding_init_method_std": 0.014, - "enable_autocast": False, - "enable_cuda_graph": False, - "ep_overlap_early_attn_memory_release": False, - "expert_model_parallel_size": 4, - 
"expert_tensor_parallel_size": 1, - "external_cuda_graph": False, - "ffn_hidden_size": 1856, - "finalize_model_grads_func": None, - "first_last_layers_bf16": False, - "flash_decode": False, - "fp16": False, - "fp32_residual_connection": False, - "fp4": None, - "fp4_param": False, - "fp4_quantizer_factory": None, - "fp4_recipe": "nvfp4", - "fp8": None, - "fp8_amax_compute_algo": "most_recent", - "fp8_amax_history_len": 1, - "fp8_dot_product_attention": False, - "fp8_interval": 1, - "fp8_margin": 0, - "fp8_multi_head_attention": False, - "fp8_param": False, - "fp8_quantizer_factory": None, - "fp8_recipe": "delayed", - "fp8_wgrad": True, - "fused_single_qkv_rope": False, - "gated_linear_unit": False, - "glu_linear_offset": 0.0, - "grad_scale_func": None, - "grad_sync_func": None, - "gradient_accumulation_fusion": True, - "hetereogenous_dist_checkpoint": False, - "heterogeneous_block_specs": False, - "hidden_dropout": 0.0, - "hidden_size": 2688, - "hierarchical_context_parallel_sizes": None, - "inference_fuse_tp_communication": False, - "inference_rng_tracker": False, - "inference_sampling_seed": 42, - "init_method": {}, - "init_method_std": 0.014, - "init_model_with_meta_device": False, - "is_hybrid_model": True, - "kitchen_attention_backend": "sdpa", - "kv_channels": 128, - "layernorm_epsilon": 1e-05, - "layernorm_zero_centered_gamma": False, - "log_max_attention_logit": False, - "mamba_head_dim": 64, - "mamba_num_groups": 8, - "mamba_num_heads": 64, - "mamba_state_dim": 128, - "masked_softmax_fusion": True, - "memory_efficient_layer_norm": False, - "microbatch_group_size_per_vp_stage": 1, - "mlp_chunks_for_prefill": 1, - "moe_apply_probs_on_input": False, - "moe_aux_loss_coeff": 0.0, - "moe_deepep_num_sms": 20, - "moe_enable_deepep": False, - "moe_expert_capacity_factor": None, - "moe_extended_tp": False, - "moe_ffn_hidden_size": 1856, - "moe_flex_dispatcher_backend": "deepep", - "moe_grouped_gemm": True, - "moe_hybridep_num_sms": 16, - "moe_input_jitter_eps": None, 
- "moe_latent_size": None, - "moe_layer_freq": 1, - "moe_layer_recompute": False, - "moe_pad_expert_input_to_capacity": False, - "moe_per_layer_logging": False, - "moe_permute_fusion": False, - "moe_router_bias_update_rate": 0.001, - "moe_router_dtype": "fp64", - "moe_router_enable_expert_bias": True, - "moe_router_force_load_balancing": False, - "moe_router_fusion": False, - "moe_router_group_topk": None, - "moe_router_load_balancing_type": "aux_loss", - "moe_router_num_groups": None, - "moe_router_padding_for_fp8": False, - "moe_router_padding_for_quantization": False, - "moe_router_pre_softmax": False, - "moe_router_score_function": "sigmoid", - "moe_router_topk": 6, - "moe_router_topk_limited_devices": None, - "moe_router_topk_scaling_factor": 2.5, - "moe_shared_expert_gate": False, - "moe_shared_expert_intermediate_size": 3712, - "moe_shared_expert_overlap": False, - "moe_token_dispatcher_type": "alltoall", - "moe_token_drop_policy": "probs", - "moe_token_dropping": False, - "moe_use_legacy_grouped_gemm": False, - "moe_z_loss_coeff": None, - "mrope_section": None, - "mtp_loss_scaling_factor": 0.1, - "mtp_num_layers": None, - "multi_latent_attention": False, - "no_rope_freq": None, - "no_sync_func": None, - "normalization": "RMSNorm", - "num_attention_heads": 32, - "num_layers": 52, - "num_layers_at_end_in_bf16": 1, - "num_layers_at_start_in_bf16": 1, - "num_layers_in_first_pipeline_stage": None, - "num_layers_in_last_pipeline_stage": None, - "num_microbatches_with_partial_activation_checkpoints": None, - "num_moe_experts": 128, - "num_query_groups": 2, - "output_layer_init_method": {}, - "overlap_moe_expert_parallel_comm": False, - "overlap_p2p_comm": False, - "overlap_p2p_comm_warmup_flush": False, - "param_sync_func": None, - "params_dtype": "torch.bfloat16", - "perform_initialization": True, - "persist_layer_norm": True, - "pipeline_dtype": "torch.bfloat16", - "pipeline_model_parallel_comm_backend": None, - "pipeline_model_parallel_layout": None, - 
"pipeline_model_parallel_size": 1, - "qk_clip": False, - "qk_clip_alpha": 0.5, - "qk_clip_threshold": 100, - "qk_layernorm": False, - "quant_recipe": None, - "recompute_granularity": None, - "recompute_method": None, - "recompute_modules": ["core_attn"], - "recompute_num_layers": None, - "rotary_interleaved": False, - "sequence_parallel": True, - "softmax_scale": None, - "softmax_type": "vanilla", - "symmetric_ar_type": None, - "tensor_model_parallel_size": 2, - "test_mode": False, - "timers": None, - "tp_comm_atomic_ag": False, - "tp_comm_atomic_rs": False, - "tp_comm_bootstrap_backend": "nccl", - "tp_comm_bulk_dgrad": True, - "tp_comm_bulk_wgrad": True, - "tp_comm_overlap": False, - "tp_comm_overlap_ag": True, - "tp_comm_overlap_disable_fc1": False, - "tp_comm_overlap_disable_qkv": False, - "tp_comm_overlap_rs": True, - "tp_comm_overlap_rs_dgrad": False, - "tp_comm_split_ag": True, - "tp_comm_split_rs": True, - "tp_only_amax_red": False, - "transformer_impl": "transformer_engine", - "use_cpu_initialization": None, - "use_fused_weighted_squared_relu": False, - "use_inference_optimized_layers": False, - "use_kitchen": False, - "use_kitchen_attention": False, - "use_mamba_mem_eff_path": True, - "use_ring_exchange_p2p": False, - "use_te_activation_func": False, - "use_te_rng_tracker": False, - "variable_seq_lengths": False, - "virtual_pipeline_model_parallel_size": None, - "wgrad_deferral_limit": 0, - "window_attn_skip_freq": None, - "window_size": None, - "fine_grained_activation_offloading": False, - "min_offloaded_tensor_size": 1024 * 1024, - "offload_modules": [], -} -# Fields to ignore entirely (ephemeral, environment-specific, very large). -SKIP_FIELDS = set() -# Fields that are allowed to appear in the live config even if not yet in the golden. 
-ALLOW_ADDED_FIELDS = set() - - -def serialize_config(cfg: Any) -> Dict[str, Any]: - """Normalize a config object into a JSON-serializable dict.""" - data = {k: v for k, v in vars(cfg).items() if k not in SKIP_FIELDS} - return _ser(data) - - -def assert_config_matches_golden(cfg: Any) -> None: - """Compare live config to golden snapshot with readable diffs.""" - current = serialize_config(cfg) - golden = GOLDEN_CONFIG - - added, removed, changed = _diff_configs(golden, current) - - # Ignore added fields that are explicitly allowed. - added = [k for k in added if k not in ALLOW_ADDED_FIELDS] - - if added or removed or changed: - # Build actionable guidance for each type of drift - guidance_parts = [] - - if added: - guidance_parts.append( - f"\n\n[ADDED ARGS]: {sorted(added)}\n" - " → Update GOLDEN_CONFIG in this test file to include the new arg(s) with " - "their default value(s).\n" - " ⚠️ CAUTION: Review any logic associated with new args to ensure it doesn't " - "silently affect downstream model configs or behavior.\n" - ) - - if changed: - guidance_parts.append( - f"\n\n[CHANGED DEFAULTS]: {sorted(changed)}\n" - " → Please don't change the default values of existing args unless " - "it is absolutely necessary for a bug fix.\n" - " → If you must change the default value, please update the GOLDEN_CONFIG " - "in this test file to reflect the new default value.\n" - ) - - if removed: - guidance_parts.append( - f"\n\n[REMOVED ARGS]: {sorted(removed)}\n" - " → Do NOT remove args directly. 
Instead, deprecate them with a warning message " - "to maintain backwards compatibility.\n" - ) - - guidance_parts.append( - "Please contact NV-username @jbarker if you are unsure how to proceed.\n" - ) - - header = "Mamba MoE config drift detected!\n" "═" * 60 + "".join(guidance_parts) - parts = [header] - if changed: - formatted = {k: {"expected": golden[k], "actual": current[k]} for k in sorted(changed)} - parts.append( - f"Changed field details:\n{json.dumps(formatted, indent=2, sort_keys=True)}" - ) - pytest.fail("\n".join(parts)) - - -def regenerate_mamba_moe_golden(cfg: Any) -> Dict[str, Any]: - """Helper to regenerate the golden config; copy/paste into GOLDEN_CONFIG.""" - serialized = serialize_config(cfg) - return serialized - - -def _ser(obj: Any) -> Any: - """Recursively convert objects to JSON-friendly structures.""" - if obj is None or isinstance(obj, (bool, int, float, str)): - return obj - if isinstance(obj, dict): - return {k: _ser(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - return [_ser(v) for v in obj] - if inspect.isfunction(obj) or inspect.ismethod(obj): - return f"{obj.__module__}.{obj.__name__}" - if inspect.isclass(obj): - return f"{obj.__module__}.{obj.__name__}" - if hasattr(obj, "__dict__"): - return {k: _ser(v) for k, v in vars(obj).items()} - try: - return str(obj) - except Exception: - return f"" - - -def _diff_configs(expected: Mapping[str, Any], actual: Mapping[str, Any]) -> Tuple[set, set, set]: - """Return added, removed, and changed top-level keys between dicts.""" - expected_keys = set(expected) - actual_keys = set(actual) - added = actual_keys - expected_keys - removed = expected_keys - actual_keys - changed = {k for k in expected_keys & actual_keys if expected[k] != actual[k]} - return added, removed, changed - - -class TestMambaMoEModel: - """Test the initialization and use of an MoE Mamba model.""" - - def create_test_args(self): - destroy_global_vars() - destroy_num_microbatches_calculator() - - 
sys.argv = ['test_mamba_moe_model.py'] - args = parse_args() - - # The following args would be set from the nano v3 checkpoint. - args.num_layers = 52 - args.hidden_size = 2688 - args.ffn_hidden_size = 1856 - args.num_attention_heads = 32 - args.num_query_groups = 2 - args.group_query_attention = True - args.kv_channels = 128 - args.position_embedding_type = 'none' - args.add_position_embedding = True - args.use_rotary_position_embeddings = False - args.rotary_base = 10000 - args.rotary_percent = 1.0 - args.rotary_interleaved = False - args.add_bias_linear = False - args.add_qkv_bias = False - args.squared_relu = True - args.swiglu = False - args.untie_embeddings_and_output_weights = True - args.apply_layernorm_1p = False - args.normalization = "RMSNorm" - args.apply_query_key_layer_scaling = False - args.attention_dropout = 0.0 - args.hidden_dropout = 0.0 - args.hybrid_override_pattern = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME" - args.spec = ["megatron.core.models.mamba.mamba_layer_specs", "mamba_stack_spec"] - args.hybrid_attention_ratio = 0.0 - args.hybrid_mlp_ratio = 0.0 - args.num_experts = 128 - args.moe_layer_freq = 1 - args.moe_ffn_hidden_size = 1856 - args.moe_router_topk = 6 - args.moe_router_pre_softmax = False - args.moe_grouped_gemm = True - args.moe_shared_expert_intermediate_size = 3712 - args.moe_router_score_function = "sigmoid" - args.moe_router_enable_expert_bias = True - args.moe_router_topk_scaling_factor = 2.5 - args.mamba_state_dim = 128 - args.mamba_head_dim = 64 - args.mamba_num_groups = 8 - args.mamba_num_heads = 64 - args.is_hybrid_model = True - args.tokenizer_type = "TikTokenizer" - args.tiktoken_pattern = "v2" - args.tokenizer_model = "/mnt/artifacts/model/nemotron6/tokenizers/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json" - args.padded_vocab_size = 131072 - - # The following args would be set in the user's nano v3 config. 
- args.async_tensor_model_parallel_allreduce = True - args.attention_backend = AttnBackend.flash - args.bf16 = True - args.ckpt_format = 'torch_dist' - args.cross_entropy_loss_fusion = True - args.cuda_graph_impl = "none" - args.embedding_init_method_std = 0.014 - args.expert_model_parallel_size = 4 - args.expert_tensor_parallel_size = 1 - args.init_method_std = 0.014 - args.lr = 3e-5 - args.max_position_embeddings = 1024 - args.micro_batch_size = 2 - args.moe_aux_loss_coeff = 0.0 - args.moe_grouped_gemm = True - args.moe_route_load_balancing_type = "aux_loss" - args.moe_router_dtype = "fp64" - args.moe_router_pre_softmax = False - args.moe_token_dispatcher_type = "alltoall" - args.no_load_optim = True - args.no_load_rng = True - args.no_save_optim = True - args.pipeline_model_parallel_size = 1 - args.position_embedding_type = None - args.recompute_granularity = None - args.seed = 42 - args.seq_length = 1024 - args.sequence_parallel = True - args.te_rng_tracker = True - args.tensor_model_parallel_size = 2 - args.vocab_size = 131072 - - validate_args(args) - set_global_variables(args, False) - return args - - def setup_method(self, method): - - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - args = self.create_test_args() - set_args(args) - - Utils.initialize_model_parallel( - tensor_model_parallel_size=args.tensor_model_parallel_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - expert_model_parallel_size=args.expert_model_parallel_size, - expert_tensor_parallel_size=args.expert_tensor_parallel_size, - ) - model_parallel_cuda_manual_seed(123) - - model_config = core_transformer_config_from_args(args, TransformerConfig) - - self.model = MambaModel( - config=model_config, - mamba_stack_spec=mamba_stack_spec, - vocab_size=args.vocab_size, - max_sequence_length=args.seq_length, - hybrid_attention_ratio=args.hybrid_attention_ratio, - hybrid_mlp_ratio=args.hybrid_mlp_ratio, - hybrid_override_pattern=args.hybrid_override_pattern, - 
position_embedding_type=args.position_embedding_type, - rotary_base=args.rotary_base, - rotary_percent=args.rotary_percent, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - def test_constructor(self): - """Sanity check for the constructor of the Mamba MoE model.""" - - args = get_args() - - assert_config_matches_golden(self.model.config) - - assert self.model.pre_process is True, "pre_process should be True" - assert self.model.post_process is True, "post_process should be True" - assert self.model.hybrid_attention_ratio == 0.0, "hybrid_attention_ratio should be 0.0" - assert self.model.hybrid_mlp_ratio == 0.0, "hybrid_mlp_ratio should be 0.0" - assert ( - self.model.hybrid_override_pattern == args.hybrid_override_pattern - ), f"hybrid_override_pattern should be {args.hybrid_override_pattern}" - num_weights = sum([p.numel() for p in self.model.parameters()]) - assert num_weights == 8449294624, f"Expected 8449294624 parameters, got {num_weights}" - - def test_set_input_tensor(self): - - args = get_args() - - config: TransformerConfig = self.model.config - sequence_length = self.model.max_sequence_length - micro_batch_size = args.micro_batch_size - - # [sequence length, batch size, hidden size] - input_tensor = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) - - self.model.set_input_tensor(input_tensor) - - assert self.model.decoder.input_tensor.shape[0] == sequence_length - assert self.model.decoder.input_tensor.shape[1] == micro_batch_size - assert self.model.decoder.input_tensor.shape[2] == config.hidden_size - - def test_forward(self): - """Basic smoke test for the forward pass of the Mamba MoE model.""" - - args = get_args() - - # we must override this to avoid the need to initialize the optimizer - for param in self.model.parameters(): - param.requires_grad = False - - sequence_length = self.model.max_sequence_length - micro_batch_size = args.micro_batch_size - - self.model.cuda() - - data = 
list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool - ).cuda() - - logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - runtime_gather_output=True, - ) - - assert logits.shape[0] == micro_batch_size - assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.model.vocab_size From 53a21e12af35663d1add2ed11e2043c4fb120431 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 15 Jan 2026 01:08:06 -0800 Subject: [PATCH 49/74] rename group_commit Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 11 ++--- megatron/core/models/gpt/gpt_model.py | 5 -- .../fine_grained_activation_offload.py | 49 ++++++++++--------- megatron/core/transformer/attention.py | 9 ++-- megatron/core/transformer/moe/experts.py | 7 +-- .../transformer/multi_latent_attention.py | 9 ++-- .../core/transformer/transformer_layer.py | 27 ++++------ 7 files changed, 48 insertions(+), 69 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 980a5d424d2..5ae5192e5aa 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -13,9 +13,6 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_commit, -) from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.module import float16_to_fp32 @@ -432,6 
+429,7 @@ def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): node.chunk_state.flush_delayed_groups = False else: node.chunk_state.flush_delayed_groups = True + # wrapper function that keeps consistent api with cuda graph replay def forward_func( hidden_states: Tensor, @@ -558,7 +556,7 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. if layer.offload_mlp_norm: - hidden_states = fine_grained_offloading_group_commit( + hidden_states = off_interface.group_commit( hidden_states, name="mlp_norm", forced_released_tensors=[residual] ) output = make_viewless_tensor( @@ -566,10 +564,7 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): ) if node.chunk_state.flush_delayed_groups: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_flush_delayed_groups, - ) - fine_grained_offloading_group_flush_delayed_groups() + off_interface.flush_delayed_groups() # Need to record residual to comm stream, since it's created on comp stream node.layer_state.residual.record_stream(torch.cuda.current_stream()) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index ae35eeb85a5..3c3219a5d3e 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -431,12 +431,7 @@ def _preprocess( def preprocess_for_fine_grained_offloading(self): """Preprocess for fine-grained activation offloading.""" -<<<<<<< HEAD - fine_grained_offloading_init_chunk_handler( - pp_rank=self.pg_collection.pp.rank(), -======= off_interface.init_chunk_handler( ->>>>>>> hongbinl/activation_offloading_refactor vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage, min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, diff --git 
a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 5402c0de69d..7ccf66d01ea 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -554,8 +554,10 @@ def post_warmup_callback(self): for chunk in self._cached_chunks_backward: for group in chunk.offload_groups: if group.offload and keep_on_gpu_bytes > 0: - debug_rank(f"group {group._name} offload {group.offload} \ - keep_on_gpu_bytes {keep_on_gpu_bytes}") + debug_rank( + f"group {group._name} offload {group.offload} \ + keep_on_gpu_bytes {keep_on_gpu_bytes}" + ) keep_on_gpu_bytes -= group.total_offload_bytes group.offload = False # Dump the offload information @@ -614,8 +616,12 @@ def front_backward_chunk(self, name=None): return None def init_model_chunk_offload_handler( - self, pp_rank, vp_size, vp_stage, min_offloaded_tensor_size=1024 * 1024, - delta_offload_bytes_across_pp_ranks=0 + self, + pp_rank, + vp_size, + vp_stage, + min_offloaded_tensor_size=1024 * 1024, + delta_offload_bytes_across_pp_ranks=0, ): """ Initialize a chunk offload handler for a model chunk (microbatch). @@ -1144,12 +1150,6 @@ def fine_grained_offloading_group_commit( ) -def fine_grained_offloading_group_flush_delayed_groups(): - """Flush the delayed groups.""" - debug_rank("fine_grained_offloading_group_flush_delayed_groups") - PipelineOffloadManager.get_instance().flush_delayed_groups() - - class FineGrainedOffloadingGroupStartFunction(torch.autograd.Function): """ Identity operation that marks the start of a layer group for offload/reload. 
@@ -1183,13 +1183,6 @@ def fine_grained_offloading_group_start(tensor, name=None): return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) -def fine_grained_offloading_forward_record(event: torch.cuda.Event) -> None: - """Record the forward event for cuda graph capture.""" - d2h_stream = PipelineOffloadManager.get_instance().d2h_stream - torch.cuda.current_stream().record_event(event) - torch.cuda.current_stream().wait_stream(d2h_stream) - - class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): """ Identity operation that marks the end of a layer group for offload synchronization. @@ -1211,11 +1204,6 @@ def backward(ctx, grad_output): return grad_output, None -def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: - """Record the backward event for cuda graph capture.""" - return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) - - class FineGrainedActivationOffloadingInterface: """Interface for fine-grained activation offloading.""" @@ -1248,6 +1236,13 @@ def get_context(flag): """Get the fine-grained offload context""" return PipelineOffloadManager.get_instance() if flag else nullcontext() + @staticmethod + def group_commit(tensor, name, forced_released_tensors=None, delay_offload=False): + """Group commit the tensors.""" + return fine_grained_offloading_group_commit( + tensor, name, forced_released_tensors, delay_offload + ) + @staticmethod def mark_not_offloadable(tensor: torch.Tensor): """Mark the tensor as not offloadable.""" @@ -1260,6 +1255,11 @@ def forward_record(event: torch.cuda.Event) -> None: torch.cuda.current_stream().record_event(event) torch.cuda.current_stream().wait_stream(d2h_stream) + @staticmethod + def backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: + """Record the backward event for cuda graph capture.""" + return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) + @staticmethod def reset(): """Reset 
the chunk handler.""" @@ -1269,3 +1269,8 @@ def reset(): def reset_instance(): """Reset the singleton instance.""" PipelineOffloadManager.reset_instance() + + @staticmethod + def flush_delayed_groups(): + """Flush the delayed groups.""" + PipelineOffloadManager.get_instance().flush_delayed_groups() diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 3cb1a5ee4a4..c3c7dad250a 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -27,9 +27,6 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_commit, -) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region from megatron.core.transformer.identity_op import IdentityOp @@ -837,7 +834,7 @@ def forward( ) if self.offload_qkv_linear: # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. 
- qkv_output = fine_grained_offloading_group_commit( + qkv_output = off_interface.group_commit( qkv_output, name="qkv_linear", forced_released_tensors=[] ) @@ -1023,7 +1020,7 @@ def forward( ) core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - core_attn_out = fine_grained_offloading_group_commit( + core_attn_out = off_interface.group_commit( core_attn_out, name="core_attn", forced_released_tensors=[query, key, value] ) @@ -1049,7 +1046,7 @@ def forward( with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output = fine_grained_offloading_group_commit( + output = off_interface.group_commit( output, name="attn_proj", forced_released_tensors=[core_attn_out] ) nvtx_range_pop(suffix="linear_proj") diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index ab32f3a448d..7027cd28d23 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -28,9 +28,6 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_commit, -) from megatron.core.tensor_parallel.layers import ( _initialize_affine_weight_cpu, _initialize_affine_weight_gpu, @@ -739,7 +736,7 @@ def forward( permuted_local_hidden_states, tokens_per_expert ) if self.offload_expert_fc1: - fc1_output = fine_grained_offloading_group_commit( + fc1_output = off_interface.group_commit( fc1_output, name="expert_fc1", forced_released_tensors=[permuted_local_hidden_states], @@ -821,7 +818,7 @@ def glu(x): # Delay the offload of the moe act until after the linear_fc2 has been computed # to make sure the fc1_output is reloaded to GPU before recomputing moe_act. 
if self.offload_moe_act: - output = fine_grained_offloading_group_commit( + output = off_interface.group_commit( output, name="moe_act", forced_released_tensors=[fc1_output], diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 2e50d9f2169..9689056e325 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -25,9 +25,6 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_commit, -) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.layers import ColumnParallelLinear from megatron.core.tensor_parallel.mappings import ( @@ -269,7 +266,7 @@ def forward( f"{self.config.experimental_attention_variant}" ) if self.offload_qkv_linear: - query = fine_grained_offloading_group_commit( + query = off_interface.group_commit( query, name="qkv_linear", forced_released_tensors=[hidden_states] ) @@ -351,7 +348,7 @@ def forward( if not inference_context.is_decode_only(): core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - core_attn_out = fine_grained_offloading_group_commit( + core_attn_out = off_interface.group_commit( core_attn_out, name="core_attn", forced_released_tensors=[query, key, value] ) @@ -382,7 +379,7 @@ def forward( with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output = fine_grained_offloading_group_commit( + output = off_interface.group_commit( output, name="attn_proj", forced_released_tensors=[core_attn_out] ) diff --git a/megatron/core/transformer/transformer_layer.py 
b/megatron/core/transformer/transformer_layer.py index ecd1e8335a4..bbcfe8027c6 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -509,16 +509,9 @@ def _forward_attention( from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_commit, - ) if self.offload_module_in_cuda_graph: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_backward_record, - ) - - hidden_states = fine_grained_offloading_backward_record( + hidden_states = off_interface.backward_record( hidden_states, TransformerLayer.cuda_graph_event ) @@ -573,7 +566,7 @@ def _forward_attention( # Delay the offload of the attention norm until after the self_attn_bda has been computed # because the residual is needed in the self_attn_bda. if self.offload_attn_norm: - hidden_states = fine_grained_offloading_group_commit( + hidden_states = off_interface.group_commit( hidden_states, name="attn_norm", forced_released_tensors=[residual] ) @@ -720,7 +713,7 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups """ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_commit, + FineGrainedActivationOffloadingInterface as off_interface, ) # TODO: could we move `bias_dropout_add_exec_handler` itself @@ -734,7 +727,7 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. 
if self.offload_mlp_norm: - hidden_states = fine_grained_offloading_group_commit( + hidden_states = off_interface.group_commit( hidden_states, name="mlp_norm", forced_released_tensors=[residual] ) @@ -750,10 +743,10 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups if self.config.delay_offload_until_cuda_graph and flush_delayed_groups: from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_flush_delayed_groups, + FineGrainedActivationOffloadingInterface as off_interface, ) - fine_grained_offloading_group_flush_delayed_groups() + off_interface.flush_delayed_groups() return output def sharded_state_dict( @@ -866,10 +859,10 @@ def _te_cuda_graph_capture(self, *args, **kwargs): cuda_graph_outputs.append(context) if self.offload_module_in_cuda_graph: from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_forward_record, + FineGrainedActivationOffloadingInterface as off_interface, ) - fine_grained_offloading_forward_record(TransformerLayer.cuda_graph_event) + off_interface.forward_record(TransformerLayer.cuda_graph_event) return tuple(cuda_graph_outputs) def _te_cuda_graph_replay(self, *args, **kwargs): @@ -897,10 +890,10 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if self.config.delay_offload_until_cuda_graph: from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_group_flush_delayed_groups, + FineGrainedActivationOffloadingInterface as off_interface, ) - fine_grained_offloading_group_flush_delayed_groups() + off_interface.flush_delayed_groups() if kwargs.get('context') is not None: context = cuda_graph_output.pop() From 600cfe745397b8d5028cd67c60d60cfb47cd8688 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 18 Jan 2026 17:57:20 -0800 Subject: [PATCH 50/74] fix for graph support Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/gpt_model.py | 1 + 
.../fine_grained_activation_offload.py | 80 +++++++++---------- megatron/core/transformer/cuda_graphs.py | 11 ++- megatron/core/transformer/moe/experts.py | 2 +- 4 files changed, 47 insertions(+), 47 deletions(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 3c3219a5d3e..d8400cb3d16 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -432,6 +432,7 @@ def _preprocess( def preprocess_for_fine_grained_offloading(self): """Preprocess for fine-grained activation offloading.""" off_interface.init_chunk_handler( + pp_rank=self.pg_collection.pp.rank(), vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage, min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 7ccf66d01ea..3a18193d542 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -448,17 +448,17 @@ def cpu_tensor_pool(self): """Get the shared CPU tensor pool.""" return self._cpu_tensor_pool - def push_offload_groups(self, group_hook, forced_released_tensors): + def push_offload_groups(self, group_hook, name, forced_released_tensors): """Push the offload groups to the delayed queue.""" debug_rank(f"pushing offload groups to the delayed queue") - self._delayed_offload_groups.append((group_hook, forced_released_tensors)) + self._delayed_offload_groups.append((group_hook, name, forced_released_tensors)) def flush_delayed_groups(self): """Flush the delayed groups.""" debug_rank("flushing delayed groups") # Flush the delayed groups in reverse order to maintain the order of the groups. 
- for group_hook, forced_released_tensors in reversed(self._delayed_offload_groups): - group_hook(forced_released_tensors) + for group_hook, name, forced_released_tensors in self._delayed_offload_groups: + group_hook(name, forced_released_tensors) self._delayed_offload_groups = [] def reset(self): @@ -811,17 +811,17 @@ def reset(self): self._tensor_count_current_group = 0 self._reloading_group = [] - def find_group_with_name(self, name: str, start_index: int = 0): + def find_group_with_name(self, groups: list[OffloadTensorGroup], name: str, start_index: int = 0): """Find the group with the given name starting from the given index.""" return next( - (group for group in self.offload_groups[start_index:] if group._name == name), None + (group for group in groups[start_index:] if group._name == name), None ) def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" debug_rank(f"------is_empty_chunk {self._max_group_size}") if name is not None: - return self.find_group_with_name(name) is None + return self.find_group_with_name(self.offload_groups, name) is None return self._max_group_size == 0 def finish_all_groups(self, name=None) -> bool: @@ -838,12 +838,12 @@ def finish_all_groups(self, name=None) -> bool: ): return True assert name is not None, "Name is required" - return self.find_group_with_name(name, self._offloaded_group_index) is None + return self.find_group_with_name(self.offload_groups, name, self._offloaded_group_index) is None def find_next_group(self, name=None): """Find the next group with the given name.""" assert name is not None, "Name is required" - return self.find_group_with_name(name, self._offloaded_group_index) + return self.find_group_with_name(self.offload_groups, name, self._offloaded_group_index) def tensor_push(self, tensor): """Push tensor to the offload handler.""" @@ -876,9 +876,7 @@ def tensor_pop(self, tensor_tag): def tensor_need_offloading_checker(self, tensor): """Check if the tensor needs to be 
offloaded.""" - debug_rank( - f"tensor_need_offloading_checker {getattr(tensor, 'offloading_activation', None)}" - ) + debug_rank("tensor_need_offloading_checker") if tensor.numel() < self.min_offloaded_tensor_size: return False # Respect tensor's offload preference if specified @@ -886,10 +884,9 @@ def tensor_need_offloading_checker(self, tensor): return False return True - def bulk_offload_group(self): + def bulk_offload_group(self, group_to_offload): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") - group_to_offload = self._groups_to_offload[-1] torch.cuda.nvtx.range_push("activation offloading " + group_to_offload._name) with torch.cuda.stream(self.d2h_stream): for tensor_tag, tensor_on_device in group_to_offload._tensors.items(): @@ -902,7 +899,6 @@ def bulk_offload_group(self): tensor_on_device.record_stream(self.d2h_stream) group_to_offload.push_tensor(tensor_tag, state) group_to_offload.record_offload_event(self.d2h_stream) - self._groups_to_offload.pop() torch.cuda.nvtx.range_pop() def get_max_deduplicated_groups(self): @@ -931,6 +927,7 @@ def bulk_reload_group(self): group_to_reload.record_reload_event(self.h2d_stream) self._groups_to_reload.pop() # Add the group to the reloading group to wait for the reload event. 
+ debug_rank(f"add group to reloading group {group_to_reload}") self._reloading_group.append(group_to_reload) torch.cuda.nvtx.range_pop() @@ -942,10 +939,11 @@ def pre_reload_last_layer(self): # Reload the last group (last layer) early self.bulk_reload_group() - def should_bulk_offload(self): + def should_bulk_offload(self, name): """Determine if the current group should be offloaded.""" assert len(self._groups_to_offload) > 0, "No groups to offload" - group = self._groups_to_offload[-1] + group = self.find_group_with_name(self._groups_to_offload, name) + assert group is not None, f"Group {name} not found in {self._groups_to_offload}" debug_rank(f"should_bulk_offload {self.is_warmup} {group.offload}") # Don't offload if the chunk is not in warmup stage if self.is_warmup: @@ -966,12 +964,16 @@ def should_bulk_offload(self): return True - def bulk_offload(self, forced_released_tensors): + def bulk_offload(self, name, forced_released_tensors): """Offload a group of tensors and optionally release their GPU memory.""" debug_rank("----bulk_offload") - if self.should_bulk_offload(): - self._groups_to_reload.append(self._groups_to_offload[-1]) - self.bulk_offload_group() + if self.should_bulk_offload(name): + group_to_offload = self.find_group_with_name(self._groups_to_offload, name) + assert group_to_offload is not None, \ + f"Group {name} not found in {self._groups_to_offload}" + self._groups_to_reload.append(group_to_offload) + self.bulk_offload_group(group_to_offload) + self._groups_to_offload.remove(group_to_offload) # Manually release tensors not auto-freed by torch GC if len(forced_released_tensors) > 0: cur_stream = torch.cuda.current_stream() @@ -981,14 +983,14 @@ def bulk_offload(self, forced_released_tensors): release_tensor.record_stream(cur_stream) release_tensor.untyped_storage().resize_(0) - def on_group_commit_forward(self, forced_released_tensors): + def on_group_commit_forward(self, name, forced_released_tensors): """Called at the end of a layer group's 
forward pass to trigger offloading.""" if not self.do_offload: return - debug_rank("--on_group_commit_forward") + debug_rank(f"--on_group_commit_forward {name}") # Wait for compute to finish before starting offload self.d2h_stream.wait_stream(torch.cuda.current_stream()) - self.bulk_offload(forced_released_tensors) + self.bulk_offload(name, forced_released_tensors) def bulk_reload(self): """Reload the next group of tensors from CPU to GPU.""" @@ -1064,18 +1066,6 @@ def on_group_start_backward(self): self.bulk_reload() -def fine_grained_offloading_disable_offload(): - """Disable the offload.""" - debug_rank("fine_grained_offloading_disable_offload") - PipelineOffloadManager.get_instance().disable_offload() - - -def fine_grained_offloading_enable_offload(): - """Enable the offload.""" - debug_rank("fine_grained_offloading_enable_offload") - PipelineOffloadManager.get_instance().enable_offload() - - class FineGrainedOffloadingGroupCommitFunction(torch.autograd.Function): """ Identity operation that marks the end of a layer group for offload synchronization. 
@@ -1089,10 +1079,10 @@ def forward(ctx, tensor, cur_forward_chunk, name, forced_released_tensors, delay if delay_offload: PipelineOffloadManager.get_instance().push_offload_groups( - cur_forward_chunk.on_group_commit_forward, forced_released_tensors + cur_forward_chunk.on_group_commit_forward, name, forced_released_tensors ) else: - cur_forward_chunk.on_group_commit_forward(forced_released_tensors) + cur_forward_chunk.on_group_commit_forward(name, forced_released_tensors) ctx.cpu_offload_handler = cur_forward_chunk ctx.name = name return tensor @@ -1225,10 +1215,10 @@ def __exit__(self, *args: Any): PipelineOffloadManager.get_instance().__exit__() @staticmethod - def init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size): + def init_chunk_handler(pp_rank, vp_size, vp_stage, min_offloaded_tensor_size, delta_offload_bytes_across_pp_ranks): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" PipelineOffloadManager.get_instance().init_model_chunk_offload_handler( - vp_size, vp_stage, min_offloaded_tensor_size + pp_rank, vp_size, vp_stage, min_offloaded_tensor_size, delta_offload_bytes_across_pp_ranks ) @staticmethod @@ -1274,3 +1264,13 @@ def reset_instance(): def flush_delayed_groups(): """Flush the delayed groups.""" PipelineOffloadManager.get_instance().flush_delayed_groups() + + @staticmethod + def disable_offload(): + """Disable the offload.""" + PipelineOffloadManager.get_instance().disable_offload() + + @staticmethod + def enable_offload(): + """Enable the offload.""" + PipelineOffloadManager.get_instance().enable_offload() diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 691b129d8bf..b3f1dbfdaa2 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1900,14 +1900,13 @@ def _get_fp8_enabled(): kwargs['fp8_enabled'] = False from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - 
fine_grained_offloading_disable_offload, - fine_grained_offloading_enable_offload, + FineGrainedActivationOffloadingInterface as off_interface ) # if self.config.offload_module_in_cuda_graph: if self.config.fine_grained_activation_offloading: - kwargs['pre_warmup_hook'] = fine_grained_offloading_disable_offload - kwargs['post_warmup_hook'] = fine_grained_offloading_enable_offload + kwargs['pre_warmup_hook'] = off_interface.disable_offload + kwargs['post_warmup_hook'] = off_interface.enable_offload return kwargs kwargs = get_make_graphed_callables_kwargs() @@ -1943,11 +1942,11 @@ def _finish_capturing(self, start_time): from megatron.core.distributed.finalize_model_grads import reset_model_temporary_tensors from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - fine_grained_offloading_reset, + FineGrainedActivationOffloadingInterface as off_interface ) from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker - fine_grained_offloading_reset() + off_interface.reset() torch.distributed.barrier() for model_chunk in self.model: diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 7027cd28d23..1aa6b95c48d 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -660,7 +660,7 @@ def __init__( set_save_original_input(self.linear_fc2) # This is to avoid the CPU overhead of multiple d2h copies - if self.offload_expert_fc1: + if self.offload_expert_fc1 and not self.config.fp8: from megatron.core.extensions.transformer_engine import set_save_original_input set_save_original_input(self.linear_fc1) From 09956df633903b372439e886bc93a4633f4d67b2 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 20 Jan 2026 05:01:56 -0800 Subject: [PATCH 51/74] refine offloading strategy Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/gpt_model.py | 6 +++--- .../fine_grained_activation_offload.py | 11 ++++++----- 2 files changed, 9 insertions(+), 8 
deletions(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index d8400cb3d16..d4cbf5cdc05 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -440,13 +440,13 @@ def preprocess_for_fine_grained_offloading(self): ) if self.disable_param_offloading: for param in self.decoder.parameters(): - off_interface.mark_not_offloadable(param) + off_interface.mark_not_offload(param) if self.mtp_process: for param in self.mtp.parameters(): - off_interface.mark_not_offloadable(param) + off_interface.mark_not_offload(param) if self.post_process: for param in self.output_layer.parameters(): - off_interface.mark_not_offloadable(param) + off_interface.mark_not_offload(param) self.disable_param_offloading = False def forward( diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 3a18193d542..5cc6d7042f8 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -687,10 +687,10 @@ def cur_backward_chunk(self): """Get the current backward pass chunk handler.""" return self._cur_backward_chunk - def mark_not_offloadable(self, tensor: torch.Tensor): + def mark_not_offload(self, tensor: torch.Tensor): """Mark the current forward chunk as not offloadable.""" if tensor is not None: - tensor.offloading_activation = False + tensor._do_not_offload = True def __enter__(self): """Enter context manager to enable activation offloading hooks.""" @@ -880,7 +880,8 @@ def tensor_need_offloading_checker(self, tensor): if tensor.numel() < self.min_offloaded_tensor_size: return False # Respect tensor's offload preference if specified - if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: + if getattr(tensor, "_TE_do_not_offload", False) \ + or getattr(tensor, "_do_not_offload", False): return 
False return True @@ -1234,9 +1235,9 @@ def group_commit(tensor, name, forced_released_tensors=None, delay_offload=False ) @staticmethod - def mark_not_offloadable(tensor: torch.Tensor): + def mark_not_offload(tensor: torch.Tensor): """Mark the tensor as not offloadable.""" - PipelineOffloadManager.get_instance().mark_not_offloadable(tensor) + PipelineOffloadManager.get_instance().mark_not_offload(tensor) @staticmethod def forward_record(event: torch.cuda.Event) -> None: From 18a2d9e6ee6fba898fce8c3dda9a794f7efde2c5 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 22 Jan 2026 06:49:55 -0800 Subject: [PATCH 52/74] temp fix for mxfp8 Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 52 +++++++++++++++++-- .../core/transformer/transformer_layer.py | 19 ++++--- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 5cc6d7042f8..c2a2f63a0a1 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -410,6 +410,9 @@ def __init__(self): # allocate streams and events for synchronization self._d2h_stream = torch.cuda.Stream() self._h2d_stream = torch.cuda.Stream() + # CUDA graph stream and event for offloading modules in cuda graph + self._cuda_graph_stream = torch.cuda.Stream() + self._cuda_graph_event = torch.cuda.Event(external=True) # Shared CPU tensor pool for all chunks to improve reuse efficiency self._cpu_tensor_pool = GPUTensorPool(device="cpu", pin_memory=True) @@ -442,6 +445,16 @@ def d2h_stream(self): def h2d_stream(self): """Get the host-to-device (CPU to GPU) transfer stream.""" return self._h2d_stream + + @property + def cuda_graph_stream(self): + """Get the CUDA graph stream.""" + return self._cuda_graph_stream + + @property + def cuda_graph_event(self): + """Get the CUDA graph event.""" 
+ return self._cuda_graph_event @property def cpu_tensor_pool(self): @@ -751,6 +764,13 @@ def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): """Offload.""" debug_rank("--------offload") + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor + is_mxfp8_tensor = isinstance(src_tensor, MXFP8Tensor) + if is_mxfp8_tensor: + mxfp8_tensor = src_tensor + src_tensor = src_tensor._columnwise_data + else: + mxfp8_tensor = None if not src_tensor.is_contiguous(): src_tensor = src_tensor.contiguous() @@ -762,13 +782,16 @@ def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): ) cpu_backup.copy_(src_tensor, non_blocking=pin_memory) - state = (src_tensor.device, cpu_backup, use_cpu_pool) + state = (src_tensor.device, cpu_backup, use_cpu_pool, mxfp8_tensor) + if is_mxfp8_tensor: + src_tensor.record_stream(torch.cuda.current_stream()) + src_tensor.untyped_storage().resize_(0) return state def reload(self, state, non_blocking=None): """Reload.""" debug_rank("------reload") - dev, cpu_backup, use_cpu_pool = state + dev, cpu_backup, use_cpu_pool, mxfp8_tensor = state if non_blocking is None: non_blocking = cpu_backup.is_pinned() gpu_tensor = torch.empty( @@ -777,6 +800,9 @@ def reload(self, state, non_blocking=None): gpu_tensor.copy_(cpu_backup, non_blocking=non_blocking) if use_cpu_pool: self.cpu_tensor_pool.free(cpu_backup) + if mxfp8_tensor is not None: + mxfp8_tensor._columnwise_data = gpu_tensor + return mxfp8_tensor return gpu_tensor def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool): @@ -803,6 +829,8 @@ def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool): self.cpu_tensor_pool = cpu_tensor_pool self.is_warmup = True + self.mxfp8_tensors = [] + def reset(self): """Reset the chunk offload handler.""" self._offloaded_group_index = 0 @@ -810,7 +838,7 @@ def reset(self): self._groups_to_reload = [] self._tensor_count_current_group = 0 self._reloading_group = [] - + self.mxfp8_tensors = [] def 
find_group_with_name(self, groups: list[OffloadTensorGroup], name: str, start_index: int = 0): """Find the group with the given name starting from the given index.""" return next( @@ -889,6 +917,7 @@ def bulk_offload_group(self, group_to_offload): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") torch.cuda.nvtx.range_push("activation offloading " + group_to_offload._name) + released_tensors = [] with torch.cuda.stream(self.d2h_stream): for tensor_tag, tensor_on_device in group_to_offload._tensors.items(): if self.tensor_need_offloading_checker(tensor_on_device): @@ -897,9 +926,14 @@ def bulk_offload_group(self, group_to_offload): ) if self.is_warmup: group_to_offload.update_offload_info(tensor_on_device) - tensor_on_device.record_stream(self.d2h_stream) + if state[3] is None: + tensor_on_device.record_stream(self.d2h_stream) group_to_offload.push_tensor(tensor_tag, state) + released_tensors.append(tensor_on_device) group_to_offload.record_offload_event(self.d2h_stream) + # for tensor in released_tensors: + # tensor.record_stream(torch.cuda.current_stream()) + # tensor.untyped_storage().resize_(0) torch.cuda.nvtx.range_pop() def get_max_deduplicated_groups(self): @@ -1215,6 +1249,16 @@ def __exit__(self, *args: Any): if self.offload: PipelineOffloadManager.get_instance().__exit__() + @staticmethod + def cuda_graph_stream(): + """Get the CUDA graph stream.""" + return PipelineOffloadManager.get_instance().cuda_graph_stream + + @staticmethod + def cuda_graph_event(): + """Get the CUDA graph event.""" + return PipelineOffloadManager.get_instance().cuda_graph_event + @staticmethod def init_chunk_handler(pp_rank, vp_size, vp_stage, min_offloaded_tensor_size, delta_offload_bytes_across_pp_ranks): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index bbcfe8027c6..694ebd304bd 
100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1133,13 +1133,12 @@ def _set_offload_modules(self): self.config.cuda_graph_warmup_steps > 0 ), "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." # Set the cuda graph stream and event for the transformer layer. - if TransformerLayer.cuda_graph_stream is None: - if self.offload_module_in_cuda_graph: - TransformerLayer.cuda_graph_stream = torch.cuda.Stream() - else: - TransformerLayer.cuda_graph_stream = torch.cuda.current_stream() - if TransformerLayer.cuda_graph_event is None: - if self.offload_module_in_cuda_graph: - TransformerLayer.cuda_graph_event = torch.cuda.Event(external=True) - else: - TransformerLayer.cuda_graph_event = torch.cuda.Event() + if self.offload_module_in_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + TransformerLayer.cuda_graph_stream = off_interface.cuda_graph_stream() + TransformerLayer.cuda_graph_event = off_interface.cuda_graph_event() + else: + TransformerLayer.cuda_graph_stream = torch.cuda.current_stream() + TransformerLayer.cuda_graph_event = torch.cuda.Event() From 885164bedf46f5b323bef03fa493970c0d436e47 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 23 Jan 2026 21:52:24 -0800 Subject: [PATCH 53/74] minor fix Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offload.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index c2a2f63a0a1..31cd453a3e5 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Tuple import torch +from torch.autograd.graph import 
saved_tensors_hooks # CPU offload implementation for pipeline parallelism DEBUG = False @@ -436,6 +437,9 @@ def __init__(self): self._delayed_offload_groups = [] self.reset() + self._saved_tensors_hooks = saved_tensors_hooks( + self.on_save_for_backward, self.on_get_saved_tensor) + @property def d2h_stream(self): """Get the device-to-host (GPU to CPU) transfer stream.""" @@ -717,10 +721,8 @@ def __enter__(self): else: raise RuntimeError("TE CPU offload is not available") self.inside_context = True + self._saved_tensors_hooks.__enter__() - torch._C._autograd._push_saved_tensors_default_hooks( - self.on_save_for_backward, self.on_get_saved_tensor - ) def __exit__(self, *args: Any): """Exit context manager and restore original tensor saving behavior.""" @@ -734,7 +736,8 @@ def __exit__(self, *args: Any): else: raise RuntimeError("TE CPU offload is not available") self.inside_context = False - torch._C._autograd._pop_saved_tensors_default_hooks() + # torch._C._autograd._pop_saved_tensors_default_hooks() + self._saved_tensors_hooks.__exit__() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: """ @@ -917,7 +920,6 @@ def bulk_offload_group(self, group_to_offload): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") torch.cuda.nvtx.range_push("activation offloading " + group_to_offload._name) - released_tensors = [] with torch.cuda.stream(self.d2h_stream): for tensor_tag, tensor_on_device in group_to_offload._tensors.items(): if self.tensor_need_offloading_checker(tensor_on_device): @@ -929,11 +931,7 @@ def bulk_offload_group(self, group_to_offload): if state[3] is None: tensor_on_device.record_stream(self.d2h_stream) group_to_offload.push_tensor(tensor_tag, state) - released_tensors.append(tensor_on_device) group_to_offload.record_offload_event(self.d2h_stream) - # for tensor in released_tensors: - # tensor.record_stream(torch.cuda.current_stream()) - # tensor.untyped_storage().resize_(0) 
torch.cuda.nvtx.range_pop() def get_max_deduplicated_groups(self): From 688c2abf64dcd2437225ec2e41a28bf96ba65f56 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sat, 24 Jan 2026 22:50:45 -0800 Subject: [PATCH 54/74] support offloading fraction Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/gpt_model.py | 1 + .../fine_grained_activation_offload.py | 74 ++++++++++++++----- megatron/core/transformer/cuda_graphs.py | 4 +- .../core/transformer/transformer_config.py | 3 + .../core/transformer/transformer_layer.py | 1 + megatron/training/arguments.py | 2 + 6 files changed, 66 insertions(+), 19 deletions(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index d4cbf5cdc05..dbb75f1b4ed 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -437,6 +437,7 @@ def preprocess_for_fine_grained_offloading(self): vp_stage=self.vp_stage, min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, delta_offload_bytes_across_pp_ranks=self.config.delta_offload_bytes_across_pp_ranks, + activation_offload_fraction=self.config.activation_offload_fraction, ) if self.disable_param_offloading: for param in self.decoder.parameters(): diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 31cd453a3e5..669e3de89ea 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -438,7 +438,8 @@ def __init__(self): self.reset() self._saved_tensors_hooks = saved_tensors_hooks( - self.on_save_for_backward, self.on_get_saved_tensor) + self.on_save_for_backward, self.on_get_saved_tensor + ) @property def d2h_stream(self): @@ -449,12 +450,12 @@ def d2h_stream(self): def h2d_stream(self): """Get the host-to-device (CPU to GPU) transfer stream.""" return self._h2d_stream - + @property def cuda_graph_stream(self): 
"""Get the CUDA graph stream.""" return self._cuda_graph_stream - + @property def cuda_graph_event(self): """Get the CUDA graph event.""" @@ -567,6 +568,7 @@ def post_warmup_callback(self): else: break assert self._offload_margin == 0, "Offload margin is not 0" + # Disable the groups to meet the delta offload bytes across PP ranks. keep_on_gpu_bytes = self._pp_rank * self._delta_offload_bytes_across_pp_ranks for chunk in self._cached_chunks_backward: for group in chunk.offload_groups: @@ -577,11 +579,27 @@ def post_warmup_callback(self): ) keep_on_gpu_bytes -= group.total_offload_bytes group.offload = False + # Disable the groups to meet the activation offload fraction. + for chunk in self._cached_chunks_backward: + offloaded_groups_count = 0 + for group in chunk.offload_groups: + if group.offload: + offloaded_groups_count += 1 + disabled_groups_count = offloaded_groups_count * (1 - self._activation_offload_fraction) + debug_rank(f"Disabled {disabled_groups_count}/{offloaded_groups_count} groups") + for group in reversed(chunk.offload_groups): + if group.offload: + if disabled_groups_count > 0: + disabled_groups_count -= 1 + group.offload = False + else: + break # Dump the offload information total_tensor_count = {} total_offload_bytes = {} for chunk in self._cached_chunks_forward: for group in chunk.offload_groups: + debug_rank(f"chunk {chunk} group {group} offload {group.offload}") if group.offload: if group._name not in total_tensor_count: total_tensor_count[group._name] = 0 @@ -593,6 +611,8 @@ def post_warmup_callback(self): # where the memory cost will not increase anymore. if chunk is self._cached_chunks_backward[0]: break + debug_rank(f"total_tensor_count {total_tensor_count}") + debug_rank(f"total_offload_bytes {total_offload_bytes}") # Cache summary for downstream consumers (e.g., unit tests). 
self._offload_summary_bytes = dict(total_offload_bytes) self._offload_summary_total_bytes = int(sum(total_offload_bytes.values())) @@ -639,6 +659,7 @@ def init_model_chunk_offload_handler( vp_stage, min_offloaded_tensor_size=1024 * 1024, delta_offload_bytes_across_pp_ranks=0, + activation_offload_fraction: float = 1.0, ): """ Initialize a chunk offload handler for a model chunk (microbatch). @@ -661,6 +682,7 @@ def init_model_chunk_offload_handler( self._delta_offload_bytes_across_pp_ranks = delta_offload_bytes_across_pp_ranks self._pp_rank = pp_rank + self._activation_offload_fraction = activation_offload_fraction if vp_stage is None: cur_vpp_rank = 0 @@ -723,7 +745,6 @@ def __enter__(self): self.inside_context = True self._saved_tensors_hooks.__enter__() - def __exit__(self, *args: Any): """Exit context manager and restore original tensor saving behavior.""" debug_rank("----__exit__") @@ -768,6 +789,7 @@ def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): debug_rank("--------offload") from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor + is_mxfp8_tensor = isinstance(src_tensor, MXFP8Tensor) if is_mxfp8_tensor: mxfp8_tensor = src_tensor @@ -842,11 +864,12 @@ def reset(self): self._tensor_count_current_group = 0 self._reloading_group = [] self.mxfp8_tensors = [] - def find_group_with_name(self, groups: list[OffloadTensorGroup], name: str, start_index: int = 0): + + def find_group_with_name( + self, groups: list[OffloadTensorGroup], name: str, start_index: int = 0 + ): """Find the group with the given name starting from the given index.""" - return next( - (group for group in groups[start_index:] if group._name == name), None - ) + return next((group for group in groups[start_index:] if group._name == name), None) def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" @@ -869,7 +892,10 @@ def finish_all_groups(self, name=None) -> bool: ): return True assert name is not None, "Name is required" - 
return self.find_group_with_name(self.offload_groups, name, self._offloaded_group_index) is None + return ( + self.find_group_with_name(self.offload_groups, name, self._offloaded_group_index) + is None + ) def find_next_group(self, name=None): """Find the next group with the given name.""" @@ -911,8 +937,9 @@ def tensor_need_offloading_checker(self, tensor): if tensor.numel() < self.min_offloaded_tensor_size: return False # Respect tensor's offload preference if specified - if getattr(tensor, "_TE_do_not_offload", False) \ - or getattr(tensor, "_do_not_offload", False): + if getattr(tensor, "_TE_do_not_offload", False) or getattr( + tensor, "_do_not_offload", False + ): return False return True @@ -1002,8 +1029,9 @@ def bulk_offload(self, name, forced_released_tensors): debug_rank("----bulk_offload") if self.should_bulk_offload(name): group_to_offload = self.find_group_with_name(self._groups_to_offload, name) - assert group_to_offload is not None, \ - f"Group {name} not found in {self._groups_to_offload}" + assert ( + group_to_offload is not None + ), f"Group {name} not found in {self._groups_to_offload}" self._groups_to_reload.append(group_to_offload) self.bulk_offload_group(group_to_offload) self._groups_to_offload.remove(group_to_offload) @@ -1251,17 +1279,29 @@ def __exit__(self, *args: Any): def cuda_graph_stream(): """Get the CUDA graph stream.""" return PipelineOffloadManager.get_instance().cuda_graph_stream - + @staticmethod def cuda_graph_event(): """Get the CUDA graph event.""" return PipelineOffloadManager.get_instance().cuda_graph_event @staticmethod - def init_chunk_handler(pp_rank, vp_size, vp_stage, min_offloaded_tensor_size, delta_offload_bytes_across_pp_ranks): + def init_chunk_handler( + pp_rank, + vp_size, + vp_stage, + min_offloaded_tensor_size, + delta_offload_bytes_across_pp_ranks, + activation_offload_fraction, + ): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" 
PipelineOffloadManager.get_instance().init_model_chunk_offload_handler( - pp_rank, vp_size, vp_stage, min_offloaded_tensor_size, delta_offload_bytes_across_pp_ranks + pp_rank, + vp_size, + vp_stage, + min_offloaded_tensor_size, + delta_offload_bytes_across_pp_ranks, + activation_offload_fraction, ) @staticmethod @@ -1307,7 +1347,7 @@ def reset_instance(): def flush_delayed_groups(): """Flush the delayed groups.""" PipelineOffloadManager.get_instance().flush_delayed_groups() - + @staticmethod def disable_offload(): """Disable the offload.""" diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index b3f1dbfdaa2..e3692dad921 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1900,7 +1900,7 @@ def _get_fp8_enabled(): kwargs['fp8_enabled'] = False from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - FineGrainedActivationOffloadingInterface as off_interface + FineGrainedActivationOffloadingInterface as off_interface, ) # if self.config.offload_module_in_cuda_graph: @@ -1942,7 +1942,7 @@ def _finish_capturing(self, start_time): from megatron.core.distributed.finalize_model_grads import reset_model_temporary_tensors from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - FineGrainedActivationOffloadingInterface as off_interface + FineGrainedActivationOffloadingInterface as off_interface, ) from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 9b26af1b75f..3eb4350a0cc 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -873,6 +873,9 @@ class TransformerConfig(ModelParallelConfig): delta_offload_bytes_across_pp_ranks: int = 0 """Difference of offload bytes across PP ranks to balance the offload load.""" + 
activation_offload_fraction: float = 1.0 + """The fraction of the activation to be offloaded, which should be in range [0, 1].""" + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 694ebd304bd..25d334abf6f 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1137,6 +1137,7 @@ def _set_offload_modules(self): from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) + TransformerLayer.cuda_graph_stream = off_interface.cuda_graph_stream() TransformerLayer.cuda_graph_event = off_interface.cuda_graph_event() else: diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index cfa1a7f03f6..90af8e949d0 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2442,6 +2442,8 @@ def _add_training_args(parser): help='Delay the offload until the CUDA graph is executed for minimal CPU overhead.') group.add_argument('--delta-offload-bytes-across-pp-ranks', type=int, default=0, help='Difference of offload bytes across PP ranks to balance the offload load.') + group.add_argument('--activation-offload-fraction', type=float, default=1.0, + help='The fraction of the activation to be offloaded for each module, which should be in range [0, 1].') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') group.add_argument('--batch-invariant-mode', action='store_true', From 2b0276cdc4263935e85cf251bb756388631650d7 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 2 Feb 2026 19:25:32 -0800 Subject: [PATCH 55/74] free input of mlp when fp8 Signed-off-by: Hongbin Liu --- 
megatron/core/models/gpt/fine_grained_callables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 5ae5192e5aa..95e2e04aad2 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -71,7 +71,7 @@ def should_free_input(name, is_moe, config): # The input and output of A2A are not needed anymore after the forward pass, # so we can free the input memory after the forward pass. free_input_nodes = { - "mlp": not enable_hybridep, + "mlp": not (enable_hybridep and config.fp8 is None), "moe_combine": True, # For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched tokens # and probs before dispatch A2A and it's not needed anymore after the forward pass From e606424384590d62b15beb1d0c70f4be4d16a1b0 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 2 Feb 2026 20:23:46 -0800 Subject: [PATCH 56/74] minor fix and format Signed-off-by: Hongbin Liu --- .../core/pipeline_parallel/fine_grained_activation_offload.py | 3 +-- megatron/core/transformer/transformer_layer.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 06080f46d28..efcbd3f8995 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -937,7 +937,7 @@ def tensor_need_offloading_checker(self, tensor): return False return True - def bulk_offload_group(self): + def bulk_offload_group(self, group_to_offload): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") torch.cuda.nvtx.range_push("activation offloading " + group_to_offload._name) @@ -952,7 +952,6 @@ def bulk_offload_group(self): 
tensor_on_device.record_stream(self.d2h_stream) group_to_offload.push_tensor(tensor_tag, state) group_to_offload.record_offload_event(self.d2h_stream) - self._groups_to_offload.pop() torch.cuda.nvtx.range_pop() def get_max_deduplicated_groups(self): diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 04f26bf4fdf..7835787d8ae 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1094,7 +1094,8 @@ def _te_cuda_graph_replay(self, *args, **kwargs): nvtx_range_pop(suffix="mlp") output = self._forward_post_mlp( - mlp_output_with_bias, residual, flush_delayed_groups=False) + mlp_output_with_bias, residual, flush_delayed_groups=False + ) else: # If EP overlap is enabled, needs to return same outputs as submodule.attn if self.config.overlap_moe_expert_parallel_comm: From 0f0d1ed44ad27f0ff61b0e85ecd9f936c365f050 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 2 Feb 2026 21:21:52 -0800 Subject: [PATCH 57/74] fix ut Signed-off-by: Hongbin Liu --- tests/unit_tests/models/test_mamba_moe_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index f21cfffe4ba..98bab2e1867 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -270,6 +270,9 @@ "fine_grained_activation_offloading": False, "min_offloaded_tensor_size": 1024 * 1024, "offload_modules": [], + "delay_offload_until_cuda_graph": False, + "delta_offload_bytes_across_pp_ranks": 0, + "activation_offload_fraction": 1.0, "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, "enable_routing_replay": False, From e585f6036a7e325310aefdff57db4f4e48bb590e Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 4 Feb 2026 20:09:12 +0800 Subject: [PATCH 58/74] Update arguments.py --- megatron/training/arguments.py | 12 
------------ 1 file changed, 12 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 1b949edacaa..1af066a8207 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2127,18 +2127,6 @@ def _add_training_args(parser): help='Use the legacy Megatron models, not Megatron-Core models.') group.add_argument('--high-priority-stream-groups', nargs='*', type=str, default=[], help='The communicator group names to use high priority streams.') - group.add_argument('--fine-grained-activation-offloading', action='store_true', - help='Enable fine-grained activation offloading.') - group.add_argument('--offload-modules', nargs='*', type=str, default=[], - help='The submodules to offload its input. Choices: "attn_norm", "qkv_linear", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') - group.add_argument('--min-offloaded-tensor-size', type=int, default=1024*1024, - help='The minimum size of the tensor to be offloaded.') - group.add_argument('--delay-offload-until-cuda-graph', action='store_true', - help='Delay the offload until the CUDA graph is executed for minimal CPU overhead.') - group.add_argument('--delta-offload-bytes-across-pp-ranks', type=int, default=0, - help='Difference of offload bytes across PP ranks to balance the offload load.') - group.add_argument('--activation-offload-fraction', type=float, default=1.0, - help='The fraction of the activation to be offloaded for each module, which should be in range [0, 1].') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') From 1b8050fbf71068c57a0c85dbfa2f62f821c1fc0e Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 4 Feb 2026 06:28:34 -0800 Subject: [PATCH 59/74] fix ut Signed-off-by: Hongbin Liu --- .../test_fine_grained_activation_offloading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py 
b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 558c6934a0c..3439f4a3dd4 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -318,7 +318,7 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( ("alltoall", True, ["mlp_norm"]), ("alltoall", False, ["expert_fc1"]), ("alltoall", False, ["moe_act"]), - ("alltoall", False, ["mlp_norm", "expert_fc1", "moe_act"]), + # ("alltoall", False, ["mlp_norm", "expert_fc1", "moe_act"]), ( "alltoall", True, From 504b3a406fbe5ccc759084546bb3b990823575c5 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 4 Feb 2026 18:54:41 -0800 Subject: [PATCH 60/74] update ut and minor refactor Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offloading.md | 15 + .../fine_grained_activation_offload.py | 9 - .../core/transformer/transformer_config.py | 5 +- .../core/transformer/transformer_layer.py | 5 +- ...test_fine_grained_activation_offloading.py | 335 ++++++++++++++++++ 5 files changed, 358 insertions(+), 11 deletions(-) diff --git a/docs/api-guide/fine_grained_activation_offloading.md b/docs/api-guide/fine_grained_activation_offloading.md index 53211d1d06c..eee1eb8445e 100644 --- a/docs/api-guide/fine_grained_activation_offloading.md +++ b/docs/api-guide/fine_grained_activation_offloading.md @@ -22,6 +22,21 @@ Currently, the supported offloading modules are `"attn_norm", "core_attn", "attn # Specify which modules are going to offload its input # Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act". --offload-modules expert_fc1 + +# Specify the minimum tensor size to be offloaded +# This is to avoid scattered offloading of small tensors +--min-offloaded-tensor-size 1048576 # 1M elements + +# When enabling cuda graph, delay the offloading outside graph until the graph launch. 
+# This preserves the CPU run-ahead advantage that cuda graph provides +--delay-offload-until-cuda-graph + +# Difference of offload bytes across PP ranks to balance the offload load. +# Larger PP ranks offload fewer bytes to reduce the overhead. +--delta-offload-bytes-across-pp-ranks 1073741824 # 1GB + +# The fraction of the activation to be offloaded, which should be in range [0, 1]. +--activation-offload-fraction 0.8 ``` **Compatible with Fine-grained Recomputation** - For modules with minor perf overhead like layernorm or moe_act, use recomputing to reduce memory footprint; diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index efcbd3f8995..03508877e1a 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -760,7 +760,6 @@ def __exit__(self, *args: Any): else: raise RuntimeError("TE CPU offload is not available") self.inside_context = False - # torch._C._autograd._pop_saved_tensors_default_hooks() self._saved_tensors_hooks.__exit__() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: @@ -791,14 +790,6 @@ def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): """Offload.""" debug_rank("--------offload") - from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor - - is_mxfp8_tensor = isinstance(src_tensor, MXFP8Tensor) - if is_mxfp8_tensor: - mxfp8_tensor = src_tensor - src_tensor = src_tensor._columnwise_data - else: - mxfp8_tensor = None if not src_tensor.is_contiguous(): src_tensor = src_tensor.contiguous() diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 5ee96aabeca..5f6c16ee6e8 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1326,7 +1326,10 @@ def __post_init__(self): "because the input of 
attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) - if self.external_cuda_graph or self.cuda_graph_impl == "transformer_engine": + if self.external_cuda_graph or self.enable_cuda_graph: + assert ( + self.cuda_graph_impl == "transformer_engine" + ), "cuda_graph_impl must be transformer_engine when enabling offloading." assert ( self.cuda_graph_scope is not None ), "cuda_graph_scope must be set when enabling offloading." diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index b0bdf5962dc..7f72d91b450 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1285,7 +1285,10 @@ def _set_offload_modules(self): if self.offload_module_in_cuda_graph: assert is_torch_min_version( "2.9.0a0" - ), "Fine-grained activation offloading needs torch>=2.9.0 to support cuda graph." + ), "Offloading modules captured in cuda graph requires torch>=2.9.0." + assert is_te_min_version( + "2.13.0" + ), "Offloading modules captured in cuda graph requires TE>=2.13.0." assert ( self.config.cuda_graph_warmup_steps > 0 ), "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." 
diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 3439f4a3dd4..d26ff8e128a 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -571,3 +571,338 @@ def _run_schedule_1f1b_two_microbatches( ) finally: Utils.destroy_model_parallel() + + +# ============================================================================= +# CUDA Graph + Fine-grained Activation Offloading Tests +# ============================================================================= + + +def _build_gpt_model_with_cuda_graph( + *, + seed: int, + num_layers: int, + hidden_size: int, + num_attention_heads: int, + vocab_size: int, + seq_length: int, + num_experts: Optional[int], + fine_grained_activation_offloading: bool, + offload_modules: Optional[List[str]], + min_offloaded_tensor_size: int, + is_mla: bool, + cuda_graph_impl: str, + cuda_graph_scope: Optional[List[str]], + cuda_graph_warmup_steps: int, + delay_offload_until_cuda_graph: bool = False, + activation_offload_fraction: float = 1.0, +) -> GPTModel: + """Build a GPTModel with CUDA Graph support and fine-grained activation offloading.""" + model_parallel_cuda_manual_seed(seed) + torch.manual_seed(seed) + ConfigClass = MLATransformerConfig if is_mla else TransformerConfig + transformer_config = ConfigClass( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + attention_backend=AttnBackend.unfused, + bf16=True, + # Recompute + recompute_modules=["layernorm", "moe_act"] if num_experts is not None else ["layernorm"], + recompute_granularity="selective", + # MoE + num_moe_experts=num_experts, + moe_grouped_gemm=(num_experts is not None), + # Fine-grained activation offloading + 
fine_grained_activation_offloading=fine_grained_activation_offloading, + offload_modules=offload_modules, + min_offloaded_tensor_size=min_offloaded_tensor_size, + delay_offload_until_cuda_graph=delay_offload_until_cuda_graph, + activation_offload_fraction=activation_offload_fraction, + # CUDA Graph settings + cuda_graph_impl=cuda_graph_impl, + cuda_graph_scope=cuda_graph_scope, + cuda_graph_warmup_steps=cuda_graph_warmup_steps, + use_te_rng_tracker=True, + ) + gpt_model = GPTModel( + config=transformer_config, + transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, + moe_grouped_gemm=num_experts is not None, + moe_use_legacy_grouped_gemm=False, + multi_latent_attention=is_mla, + ), + vocab_size=vocab_size, + max_sequence_length=seq_length, + ).bfloat16() + return gpt_model + + +def _run_iters_with_cuda_graph( + model: GPTModel, + *, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + num_warmup_iters: int, + num_measure_iters: int, + enable_offload_reset: bool, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], int]: + """ + Run multiple forward+backward iterations with CUDA graph capture. 
+ + Returns: + - logits from last iteration (CPU float32) + - selected grads from last iteration (CPU float32) + - peak_memory_allocated (bytes) during measurement iterations + """ + from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord, delete_cuda_graphs + + if enable_offload_reset: + off_interface.reset() + + # Warmup iterations (before CUDA graph capture) + for _ in range(num_warmup_iters): + if enable_offload_reset: + off_interface.reset() + logits = model( + input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask + ) + loss = logits.float().sum() + loss.backward() + # Zero grads for next iteration + for p in model.parameters(): + if p.grad is not None: + p.grad.zero_() + + # Trigger post-warmup offload decisions + if enable_offload_reset: + off_interface.reset() + + # Create CUDA graphs after warmup + _CudagraphGlobalRecord.create_cudagraphs() + + # Measurement iterations (with CUDA graph replay) + torch.cuda.reset_peak_memory_stats() + for i in range(num_measure_iters): + if enable_offload_reset: + off_interface.reset() + logits = model( + input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask + ) + loss = logits.float().sum() + loss.backward() + if i < num_measure_iters - 1: + for p in model.parameters(): + if p.grad is not None: + p.grad.zero_() + + torch.cuda.synchronize() + peak_bytes = int(torch.cuda.max_memory_allocated()) + + # Capture grads from last iteration + grads: Dict[str, torch.Tensor] = {} + for name, p in model.named_parameters(): + grads[name] = p.grad.detach().float().cpu() if p.grad is not None else None + + # Cleanup CUDA graphs + delete_cuda_graphs() + + return logits.detach().float().cpu(), grads, peak_bytes + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for offloading tests.") +@pytest.mark.skipif( + not is_te_min_version("2.13.0"), reason="CUDA Graph with TE RNG tracker requires TE >= 2.13.0" +) +@pytest.mark.parametrize( + "is_mla, 
offload_modules, cuda_graph_scope, activation_offload_fraction, delay_offload", + [ + # MoE model with attention CUDA graph + attn offloading + (False, ["core_attn", "attn_proj"], ["attn", "moe_router"], 1.0, True), + (False, ["expert_fc1", "moe_act"], ["attn", "moe_router", "moe_preprocess"], 1.0, True), + (False, ["core_attn", "attn_proj", "expert_fc1"], ["attn", "moe_router"], 1.0, True), + ( + False, + ["core_attn", "attn_proj", "expert_fc1", "moe_act"], + ["attn", "moe_router"], + 1.0, + True, + ), + ( + False, + ["core_attn", "expert_fc1", "moe_act"], + ["attn", "moe_router", "moe_preprocess"], + 1.0, + True, + ), + ( + True, + ["core_attn", "attn_proj", "expert_fc1", "moe_act"], + ["attn", "moe_router", "moe_preprocess"], + 1.0, + True, + ), + # Test activation_offload_fraction parameter + (False, ["core_attn", "attn_proj", "expert_fc1"], ["attn", "moe_router"], 0.0, True), + (False, ["core_attn", "attn_proj", "expert_fc1"], ["attn", "moe_router"], 0.5, True), + # Test delay_offload_until_cuda_graph parameter + (False, ["core_attn", "attn_proj", "expert_fc1"], ["attn", "moe_router"], 1.0, False), + ], +) +def test_fine_grained_activation_offloading_with_cuda_graph( + is_mla: bool, + offload_modules: List[str], + cuda_graph_scope: List[str], + activation_offload_fraction: float, + delay_offload: bool, +): + """ + Test fine-grained activation offloading combined with CUDA graph capture. 
+ + Verifies: + - Forward output correctness with CUDA graph + offloading + - Backward gradient correctness + - Memory savings from offloading are preserved with CUDA graphs + - Different activation_offload_fraction values work correctly + - Both delay_offload_until_cuda_graph=True/False produce correct results + """ + from megatron.core.tensor_parallel.random import initialize_rng_tracker + + os.environ.pop("NVTE_FUSED_ATTN", None) + os.environ.pop("NVTE_FLASH_ATTN", None) + os.environ.pop("NVTE_UNFUSED_ATTN", None) + + initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) + Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + seed = 123 + num_experts = 4 # Always MoE model + num_layers = 4 # Smaller for faster test with CUDA graphs + hidden_size = 1024 + num_attention_heads = 8 + vocab_size = 512 + seq_length = 512 + micro_batch_size = 2 + device = torch.device("cuda") + cuda_graph_warmup_steps = 3 + + input_ids, position_ids, attention_mask = _make_gpt_inputs( + seq_length=seq_length, micro_batch_size=micro_batch_size, device=device + ) + + off_interface.reset_instance() + + try: + # 1) Baseline: CUDA graph enabled, offloading disabled + _reset_cuda_memory() + base_model = _build_gpt_model_with_cuda_graph( + seed=seed, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + vocab_size=vocab_size, + seq_length=seq_length, + num_experts=num_experts, + fine_grained_activation_offloading=False, + offload_modules=None, + min_offloaded_tensor_size=1024 * 1024, + is_mla=is_mla, + cuda_graph_impl="transformer_engine", + cuda_graph_scope=cuda_graph_scope, + cuda_graph_warmup_steps=cuda_graph_warmup_steps, + ).cuda() + base_model.train() + + base_logits, base_grads, base_peak = _run_iters_with_cuda_graph( + base_model, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + num_warmup_iters=cuda_graph_warmup_steps, + num_measure_iters=2, + 
enable_offload_reset=False, + ) + del base_model + _reset_cuda_memory() + + # 2) Test: CUDA graph enabled + offloading enabled + off_interface.reset_instance() + + off_model = _build_gpt_model_with_cuda_graph( + seed=seed, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + vocab_size=vocab_size, + seq_length=seq_length, + num_experts=num_experts, + fine_grained_activation_offloading=True, + offload_modules=offload_modules, + min_offloaded_tensor_size=1024, # Force offloading for determinism + is_mla=is_mla, + cuda_graph_impl="transformer_engine", + cuda_graph_scope=cuda_graph_scope, + cuda_graph_warmup_steps=cuda_graph_warmup_steps, + delay_offload_until_cuda_graph=delay_offload, + activation_offload_fraction=activation_offload_fraction, + ).cuda() + off_model.train() + + off_logits, off_grads, off_peak = _run_iters_with_cuda_graph( + off_model, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + num_warmup_iters=cuda_graph_warmup_steps, + num_measure_iters=2, + enable_offload_reset=True, + ) + del off_model + _reset_cuda_memory() + + # 3) Correctness checks + assert torch.allclose( + off_logits, base_logits, rtol=1e-2, atol=1e-2 + ), f"Logits mismatch: max_diff={torch.max(torch.abs(off_logits - base_logits))}" + assert set(off_grads.keys()) == set(base_grads.keys()) + for name, gb in base_grads.items(): + go = off_grads[name] + if gb is None or go is None: + assert gb is None and go is None, f"Grad None mismatch for {name}" + continue + assert torch.allclose( + go, gb, rtol=1e-2, atol=1e-2 + ), f"Grad mismatch for {name}: max_diff={torch.max(torch.abs(go - gb))}" + + # 4) Memory checks - offloading should still reduce memory with CUDA graphs + saved_mib = (base_peak - off_peak) / (1024**2) + print( + f"CUDA Graph + Offload test (fraction={activation_offload_fraction}, delay={delay_offload}): " + f"base_peak={base_peak/(1024**2):.2f}MiB, " + f"off_peak={off_peak/(1024**2):.2f}MiB, " + 
f"saved={saved_mib:.2f}MiB" + ) + + # Basic sanity checks + assert not torch.isnan(off_logits).any(), "NaN detected in logits" + assert not torch.isinf(off_logits).any(), "Inf detected in logits" + + # Check gradients are valid + for name, g in off_grads.items(): + if g is not None: + assert not torch.isnan(g).any(), f"NaN detected in grad for {name}" + assert not torch.isinf(g).any(), f"Inf detected in grad for {name}" + + # Note: With CUDA graphs, memory behavior may differ from eager mode. + # We check that offloading doesn't significantly increase memory. + # In some cases, graph capture overhead may offset offload savings. + assert saved_mib >= -DELTA, ( + f"Offloading with CUDA graph significantly increased memory: " + f"saved={saved_mib:.2f}MiB (negative means increase)" + ) + + finally: + Utils.destroy_model_parallel() From 7640bf386980fccb8b26ea59ba3994967444921e Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 25 Feb 2026 01:24:07 -0800 Subject: [PATCH 61/74] minor refactor Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 17 ++++----- .../fine_grained_activation_offload.py | 15 ++++---- megatron/core/transformer/attention.py | 36 +++++++++--------- megatron/core/transformer/moe/experts.py | 38 +++++++++---------- .../transformer/multi_latent_attention.py | 28 ++++++-------- .../core/transformer/transformer_layer.py | 25 +++++------- 6 files changed, 74 insertions(+), 85 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index b9d96f39e9a..d60b05ba478 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -478,18 +478,16 @@ def forward_func( ) if not isinstance(layer.mlp, MoELayer): return hidden_states, None, None, None + mlp_norm_manager = off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm") + node.layer_state.mlp_norm_manager = mlp_norm_manager if 
layer.recompute_pre_mlp_layernorm: layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface( - layer.offload_mlp_norm, hidden_states, "mlp_norm" - ) as hidden_states: + with mlp_norm_manager as hidden_states: pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( layer.pre_mlp_layernorm, hidden_states ) else: - with off_interface( - layer.offload_mlp_norm, hidden_states, "mlp_norm" - ) as hidden_states: + with mlp_norm_manager as hidden_states: pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) shared_expert_output = layer.mlp.shared_experts_compute(pre_mlp_layernorm_output) @@ -589,10 +587,11 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): ) # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. - if layer.offload_mlp_norm: - hidden_states = off_interface.group_commit( - hidden_states, name="mlp_norm", forced_released_tensors=[residual] + if hasattr(node.layer_state, 'mlp_norm_manager'): + hidden_states = node.layer_state.mlp_norm_manager.group_offload( + hidden_states, forced_released_tensors=[residual] ) + delattr(node.layer_state, 'mlp_norm_manager') output = make_viewless_tensor( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 03508877e1a..2778b05ebb0 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1309,13 +1309,14 @@ def init_chunk_handler( def get_context(flag): """Get the fine-grained offload context""" return PipelineOffloadManager.get_instance() if flag else nullcontext() - - @staticmethod - def group_commit(tensor, name, forced_released_tensors=None, delay_offload=False): - """Group commit the 
tensors.""" - return fine_grained_offloading_group_commit( - tensor, name, forced_released_tensors, delay_offload - ) + + def group_offload(self, tensor, forced_released_tensors=None, delay_offload=False): + """Group offload the tensors.""" + if self.offload: + return fine_grained_offloading_group_commit( + tensor, self.name, forced_released_tensors, delay_offload + ) + return tensor @staticmethod def mark_not_offload(tensor: torch.Tensor): diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index bc5e4e2ee0d..57ee0837522 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -981,18 +981,18 @@ def forward( if output_gate: assert split_qkv, "output_gate is not supported for unsplit mixed_qkv tensor." - with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") as hidden_states: + qkv_linear_manager = off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") + with qkv_linear_manager as hidden_states: qkv_output = self.get_query_key_value_tensors( hidden_states, key_value_states, split_qkv=split_qkv, output_gate=self.config.attention_output_gate, ) - if self.offload_qkv_linear: - # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. - qkv_output = off_interface.group_commit( - qkv_output, name="qkv_linear", forced_released_tensors=[] - ) + # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. 
+ qkv_output = qkv_linear_manager.group_offload( + qkv_output, forced_released_tensors=[hidden_states] + ) attn_mask_type = self.attn_mask_type block_table = None gate = None @@ -1135,6 +1135,11 @@ def forward( # ================================== nvtx_range_push(suffix="core_attention") + core_attn_manager = off_interface( + self.offload_core_attention and self.training, + query, + "core_attn", + ) if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( query, @@ -1148,9 +1153,7 @@ def forward( else: if inference_context is None or inference_context.is_static_batching(): # Static batching attention kernel. - with off_interface( - self.offload_core_attention and self.training, query, "core_attn" - ) as query: + with core_attn_manager as query: core_attn_out = apply_module(self.core_attention)( query, key, @@ -1186,10 +1189,9 @@ def forward( if is_using_quantization_scales(self.config): core_attn_out[inference_context.padding_slice] = 0.0 - if self.offload_core_attention and self.training: - core_attn_out = off_interface.group_commit( - core_attn_out, name="core_attn", forced_released_tensors=[query, key, value] - ) + core_attn_out = core_attn_manager.group_offload( + core_attn_out, forced_released_tensors=[query, key, value] + ) if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': # reshape to same output shape as unpacked case @@ -1209,12 +1211,10 @@ def forward( # Output. 
[sq, b, h] # ================= nvtx_range_push(suffix="linear_proj") - with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out: + attn_proj_manager = off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") + with attn_proj_manager as core_attn_out: output, bias = self.linear_proj(core_attn_out) - if self.offload_attn_proj: - output = off_interface.group_commit( - output, name="attn_proj", forced_released_tensors=[core_attn_out] - ) + output = attn_proj_manager.group_offload(output, forced_released_tensors=[core_attn_out]) nvtx_range_pop(suffix="linear_proj") return output, bias diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index caf820f9741..28b75299675 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -681,19 +681,20 @@ def forward( # Probs already applied, so reset to 1. permuted_probs = torch.ones_like(permuted_probs) - with off_interface( - self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" - ) as permuted_local_hidden_states: + expert_fc1_manager = off_interface( + self.offload_expert_fc1, + permuted_local_hidden_states, + "expert_fc1", + ) + with expert_fc1_manager as permuted_local_hidden_states: fc1_output, bias_parallel = self.linear_fc1( permuted_local_hidden_states, tokens_per_expert ) - if self.offload_expert_fc1: - fc1_output = off_interface.group_commit( - fc1_output, - name="expert_fc1", - forced_released_tensors=[permuted_local_hidden_states], - delay_offload=self.config.delay_offload_until_cuda_graph, - ) + fc1_output = expert_fc1_manager.group_offload( + fc1_output, + forced_released_tensors=[permuted_local_hidden_states], + delay_offload=self.config.delay_offload_until_cuda_graph, + ) def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): if self.config.use_te_activation_func: @@ -753,14 +754,15 @@ def glu(x): intermediate_parallel = 
intermediate_parallel.to(original_dtype) return intermediate_parallel + moe_act_manager = off_interface(self.offload_moe_act, fc1_output, "moe_act") if self.activation_recompute: self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: + with moe_act_manager as fc1_output: bias_act_output = self.activation_checkpoint.checkpoint( bias_act_func, fc1_output, bias_parallel, permuted_probs ) else: - with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: + with moe_act_manager as fc1_output: bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) output, output_bias = self.linear_fc2(bias_act_output, tokens_per_expert) @@ -769,13 +771,11 @@ def glu(x): # Delay the offload of the moe act until after the linear_fc2 has been computed # to make sure the fc1_output is reloaded to GPU before recomputing moe_act. - if self.offload_moe_act: - output = off_interface.group_commit( - output, - name="moe_act", - forced_released_tensors=[fc1_output], - delay_offload=self.config.delay_offload_until_cuda_graph, - ) + output = moe_act_manager.group_offload( + output, + forced_released_tensors=[fc1_output], + delay_offload=self.config.delay_offload_until_cuda_graph, + ) output = self._apply_bias(output, output_bias, tokens_per_expert, permuted_probs) # upad and concat the output diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index cd3db50a35b..b41117f4adc 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -243,7 +243,8 @@ def forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. 
# query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128] - with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") as hidden_states: + qkv_linear_manager = off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") + with qkv_linear_manager as hidden_states: query, key, value, q_compressed, kv_compressed = self.get_query_key_value_tensors( hidden_states, key_value_states, @@ -251,10 +252,7 @@ def forward( packed_seq_params, inference_context=inference_context, ) - if self.offload_qkv_linear: - query = off_interface.group_commit( - query, name="qkv_linear", forced_released_tensors=[hidden_states] - ) + query = qkv_linear_manager.group_offload(query, forced_released_tensors=[hidden_states]) # =================================================== # Adjust key, value for inference @@ -276,6 +274,7 @@ def forward( # core attention computation # ================================== # Need corresponding TE change + core_attn_manager = off_interface(self.offload_core_attention and self.training, query, "core_attn") if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( query, key, value, attention_mask, packed_seq_params=packed_seq_params @@ -288,9 +287,7 @@ def forward( # query representation. 
extra_kwargs["x"] = hidden_states extra_kwargs["qr"] = q_compressed - with off_interface( - self.offload_core_attention and self.training, query, "core_attn" - ) as query: + with core_attn_manager as query: core_attn_out = self.core_attention( query, key, @@ -320,10 +317,9 @@ def forward( # Only rearrange if not in absorption mode (Flash MLA handles format correctly) if not inference_context.is_decode_only(): core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') - if self.offload_core_attention and self.training: - core_attn_out = off_interface.group_commit( - core_attn_out, name="core_attn", forced_released_tensors=[query, key, value] - ) + core_attn_out = core_attn_manager.group_offload( + core_attn_out, forced_released_tensors=[query, key, value] + ) # We are doing absorption with cache mla latents and decode mode. if self.cache_mla_latents and inference_context.is_decode_only(): @@ -349,12 +345,10 @@ def forward( # ================= # Output. [sq, b, h] # ================= - with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out: + attn_proj_manager = off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") + with attn_proj_manager as core_attn_out: output, bias = self.linear_proj(core_attn_out) - if self.offload_attn_proj: - output = off_interface.group_commit( - output, name="attn_proj", forced_released_tensors=[core_attn_out] - ) + output = attn_proj_manager.group_offload(output, forced_released_tensors=[core_attn_out]) return output, bias diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 7f72d91b450..ebed4d87f80 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -573,14 +573,15 @@ def _forward_attention( residual = hidden_states # Optional Input Layer norm + attn_norm_manager = off_interface(self.offload_attn_norm, hidden_states, "attn_norm") if self.recompute_input_layernorm: 
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: + with attn_norm_manager as hidden_states: input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( self.input_layernorm, hidden_states ) else: - with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: + with attn_norm_manager as hidden_states: input_layernorm_output = self.input_layernorm(hidden_states) using_fused_tp_inference_kernel = (not self.training) and ( @@ -632,10 +633,7 @@ def _forward_attention( # Delay the offload of the attention norm until after the self_attn_bda has been computed # because the residual is needed in the self_attn_bda. - if self.offload_attn_norm: - hidden_states = off_interface.group_commit( - hidden_states, name="attn_norm", forced_released_tensors=[residual] - ) + hidden_states = attn_norm_manager.group_offload(hidden_states, forced_released_tensors=[residual]) # Residual connection. 
residual = hidden_states @@ -668,14 +666,15 @@ def _forward_pre_mlp_layernorm(self, hidden_states): FineGrainedActivationOffloadingInterface as off_interface, ) + self.mlp_norm_manager = off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") if self.recompute_pre_mlp_layernorm: self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: + with self.mlp_norm_manager as hidden_states: pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( self.pre_mlp_layernorm, hidden_states ) else: - with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states: + with self.mlp_norm_manager as hidden_states: pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) return pre_mlp_layernorm_output @@ -788,9 +787,6 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups Returns: output (Tensor): Transformed hidden states of shape [s, b, h]. """ - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - FineGrainedActivationOffloadingInterface as off_interface, - ) using_fused_tp_inference_kernel = (not self.training) and ( self.config.inference_fuse_tp_communication @@ -819,10 +815,9 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups nvtx_range_pop(suffix="mlp_bda") # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. - if self.offload_mlp_norm: - hidden_states = off_interface.group_commit( - hidden_states, name="mlp_norm", forced_released_tensors=[residual] - ) + if hasattr(self, 'mlp_norm_manager'): + hidden_states = self.mlp_norm_manager.group_offload(hidden_states, forced_released_tensors=[residual]) + delattr(self, 'mlp_norm_manager') # Jit compiled function creates 'view' tensor. 
This tensor # potentially gets saved in the MPU checkpoint function context, From dfc8161cad03380e1731ba0d40d8db0cfdba7844 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 25 Feb 2026 02:28:08 -0800 Subject: [PATCH 62/74] format Signed-off-by: Hongbin Liu --- .../pipeline_parallel/fine_grained_activation_offload.py | 2 +- megatron/core/transformer/attention.py | 4 +--- megatron/core/transformer/moe/experts.py | 4 +--- megatron/core/transformer/multi_latent_attention.py | 4 +++- megatron/core/transformer/transformer_layer.py | 8 ++++++-- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 2778b05ebb0..cbdfe713712 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1309,7 +1309,7 @@ def init_chunk_handler( def get_context(flag): """Get the fine-grained offload context""" return PipelineOffloadManager.get_instance() if flag else nullcontext() - + def group_offload(self, tensor, forced_released_tensors=None, delay_offload=False): """Group offload the tensors.""" if self.offload: diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 57ee0837522..15379d2f678 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -1136,9 +1136,7 @@ def forward( nvtx_range_push(suffix="core_attention") core_attn_manager = off_interface( - self.offload_core_attention and self.training, - query, - "core_attn", + self.offload_core_attention and self.training, query, "core_attn" ) if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 28b75299675..33b3b20818b 100644 --- 
a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -682,9 +682,7 @@ def forward( permuted_probs = torch.ones_like(permuted_probs) expert_fc1_manager = off_interface( - self.offload_expert_fc1, - permuted_local_hidden_states, - "expert_fc1", + self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" ) with expert_fc1_manager as permuted_local_hidden_states: fc1_output, bias_parallel = self.linear_fc1( diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index b41117f4adc..6ef70c8b0f3 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -274,7 +274,9 @@ def forward( # core attention computation # ================================== # Need corresponding TE change - core_attn_manager = off_interface(self.offload_core_attention and self.training, query, "core_attn") + core_attn_manager = off_interface( + self.offload_core_attention and self.training, query, "core_attn" + ) if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( query, key, value, attention_mask, packed_seq_params=packed_seq_params diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index ebed4d87f80..ce548af6533 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -633,7 +633,9 @@ def _forward_attention( # Delay the offload of the attention norm until after the self_attn_bda has been computed # because the residual is needed in the self_attn_bda. - hidden_states = attn_norm_manager.group_offload(hidden_states, forced_released_tensors=[residual]) + hidden_states = attn_norm_manager.group_offload( + hidden_states, forced_released_tensors=[residual] + ) # Residual connection. 
residual = hidden_states @@ -816,7 +818,9 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. if hasattr(self, 'mlp_norm_manager'): - hidden_states = self.mlp_norm_manager.group_offload(hidden_states, forced_released_tensors=[residual]) + hidden_states = self.mlp_norm_manager.group_offload( + hidden_states, forced_released_tensors=[residual] + ) delattr(self, 'mlp_norm_manager') # Jit compiled function creates 'view' tensor. This tensor From e362f79af7268e02d601bb670b8e7b697698c180 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 25 Feb 2026 18:09:10 -0800 Subject: [PATCH 63/74] fix ut Signed-off-by: Hongbin Liu --- megatron/core/transformer/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 15379d2f678..00d2503dd00 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -991,7 +991,7 @@ def forward( ) # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. 
qkv_output = qkv_linear_manager.group_offload( - qkv_output, forced_released_tensors=[hidden_states] + qkv_output, forced_released_tensors=[] ) attn_mask_type = self.attn_mask_type block_table = None From 2da15b2590ea22401bc3ead27b59ed30d42de91b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 25 Feb 2026 18:18:23 -0800 Subject: [PATCH 64/74] format Signed-off-by: Hongbin Liu --- megatron/core/transformer/attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 00d2503dd00..bc65caba568 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -990,9 +990,7 @@ def forward( output_gate=self.config.attention_output_gate, ) # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. - qkv_output = qkv_linear_manager.group_offload( - qkv_output, forced_released_tensors=[] - ) + qkv_output = qkv_linear_manager.group_offload(qkv_output, forced_released_tensors=[]) attn_mask_type = self.attn_mask_type block_table = None gate = None From 221cc16bd2ab47424a3d17f48e0d684060ede71c Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 26 Feb 2026 01:24:46 -0800 Subject: [PATCH 65/74] minor fix Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 2 ++ megatron/core/transformer/cuda_graphs.py | 2 +- .../transformer/multi_latent_attention.py | 2 +- .../core/transformer/transformer_config.py | 11 +++++++++- .../core/transformer/transformer_layer.py | 21 ++++++++++++------- 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index d60b05ba478..4e82e721936 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -596,6 +596,8 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): 
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) + # Flush the delayed groups. + # This process happens only during the warmup steps of cuda graph. if node.chunk_state.flush_delayed_groups: off_interface.flush_delayed_groups() diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index d391c7fe039..3304d7462ef 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -2179,7 +2179,7 @@ def _get_fp8_enabled(): FineGrainedActivationOffloadingInterface as off_interface, ) - # if self.config.offload_module_in_cuda_graph: + # Disable and enable offloading before and after the warmup stage of cuda graph. if self.config.fine_grained_activation_offloading: kwargs['pre_warmup_hook'] = off_interface.disable_offload kwargs['post_warmup_hook'] = off_interface.enable_offload diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 6ef70c8b0f3..126be90804e 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -252,7 +252,7 @@ def forward( packed_seq_params, inference_context=inference_context, ) - query = qkv_linear_manager.group_offload(query, forced_released_tensors=[hidden_states]) + query = qkv_linear_manager.group_offload(query, forced_released_tensors=[]) # =================================================== # Adjust key, value for inference diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 5f6c16ee6e8..1e013b37c63 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1313,7 +1313,6 @@ def __post_init__(self): "attn_norm", "mlp_norm", "qkv_linear", - "dense_mlp", } invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( @@ -1326,6 +1325,16 @@ def 
__post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) + if self.delay_offload_until_cuda_graph: + assert (self.external_cuda_graph or self.enable_cuda_graph, + "delay_offload_until_cuda_graph must be used with cuda graph." + ) + assert self.min_offloaded_tensor_size >= 0, \ + "min_offloaded_tensor_size must be non-negative." + assert self.activation_offload_fraction >= 0 and self.activation_offload_fraction <= 1, \ + "activation_offload_fraction must be in range [0, 1]." + assert self.delta_offload_bytes_across_pp_ranks >= 0, \ + "delta_offload_bytes_across_pp_ranks must be non-negative." if self.external_cuda_graph or self.enable_cuda_graph: assert ( self.cuda_graph_impl == "transformer_engine" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index ce548af6533..805cd5f0eaa 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -562,6 +562,9 @@ def _forward_attention( FineGrainedActivationOffloadingInterface as off_interface, ) + # Record the backward event on cuda graph stream in backward pass. + # This is to ensure the main stream waits for computing on cuda graph stream to complete, + # and overlaps with the H2D transfer on reload stream. if self.offload_module_in_cuda_graph: hidden_states = off_interface.backward_record( hidden_states, TransformerLayer.cuda_graph_event @@ -833,6 +836,8 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual, flush_delayed_groups inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) + # Flush the delayed groups. + # This process happens only during the warmup steps of cuda graph. 
if self.config.delay_offload_until_cuda_graph and flush_delayed_groups: from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, @@ -1009,6 +1014,9 @@ def _te_cuda_graph_capture(self, *args, **kwargs): cuda_graph_outputs = list(hidden_states) if context is not None: cuda_graph_outputs.append(context) + # Record the forward event on cuda graph stream for cuda graph capture. + # This is to ensure the main stream waits for computing on cuda graph stream to complete, + # and overlaps with the D2H transfer on offloading stream. if self.offload_module_in_cuda_graph: from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, @@ -1040,6 +1048,10 @@ def _te_cuda_graph_replay(self, *args, **kwargs): cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) + # Flush the delayed groups after the cuda graph replay. + # This is to reduce or eliminate the cpu overhead of offloading because + # there exists a synchronization between the cuda graph replay and the a2a comm + # in moe layer, at that point the host thread is blocked and idle. if self.config.delay_offload_until_cuda_graph: from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, @@ -1261,9 +1273,6 @@ def _set_offload_modules(self): ) self.offload_expert_fc1 = "expert_fc1" in self.config.offload_modules self.offload_moe_act = "moe_act" in self.config.offload_modules - self.offload_dense_mlp = ( - "dense_mlp" in self.config.offload_modules and not self.is_moe_layer - ) else: self.offload_attn_norm = False self.offload_qkv_linear = False @@ -1272,19 +1281,15 @@ def _set_offload_modules(self): self.offload_mlp_norm = False self.offload_expert_fc1 = False self.offload_moe_act = False - self.offload_dense_mlp = False # Set the offload module in cuda graph flag. 
self.offload_module_in_cuda_graph = False if CudaGraphScope.attn in self.config.cuda_graph_scope: if self.offload_core_attn or self.offload_attn_proj or self.offload_qkv_linear: self.offload_module_in_cuda_graph = True if not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope: - if self.offload_mlp_norm or self.offload_dense_mlp: + if self.offload_mlp_norm: self.offload_module_in_cuda_graph = True if self.offload_module_in_cuda_graph: - assert is_torch_min_version( - "2.9.0a0" - ), "Offloading modules captured in cuda graph requires torch>=2.9.0." assert is_te_min_version( "2.13.0" ), "Offloading modules captured in cuda graph requires TE>=2.13.0." From 726c52666337ab37ad45fd53e12d179e8cca235d Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 26 Feb 2026 01:29:40 -0800 Subject: [PATCH 66/74] format and minor fix Signed-off-by: Hongbin Liu --- .../core/transformer/transformer_config.py | 20 +++++++++++-------- .../core/transformer/transformer_layer.py | 3 +++ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 1e013b37c63..a2f0c67d215 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1326,15 +1326,19 @@ def __post_init__(self): "which is needed in core_attn.backward()." ) if self.delay_offload_until_cuda_graph: - assert (self.external_cuda_graph or self.enable_cuda_graph, - "delay_offload_until_cuda_graph must be used with cuda graph." + assert ( + self.external_cuda_graph or self.enable_cuda_graph, + "delay_offload_until_cuda_graph must be used with cuda graph.", ) - assert self.min_offloaded_tensor_size >= 0, \ - "min_offloaded_tensor_size must be non-negative." - assert self.activation_offload_fraction >= 0 and self.activation_offload_fraction <= 1, \ - "activation_offload_fraction must be in range [0, 1]." 
- assert self.delta_offload_bytes_across_pp_ranks >= 0, \ - "delta_offload_bytes_across_pp_ranks must be non-negative." + assert ( + self.min_offloaded_tensor_size >= 0 + ), "min_offloaded_tensor_size must be non-negative." + assert ( + self.activation_offload_fraction >= 0 and self.activation_offload_fraction <= 1 + ), "activation_offload_fraction must be in range [0, 1]." + assert ( + self.delta_offload_bytes_across_pp_ranks >= 0 + ), "delta_offload_bytes_across_pp_ranks must be non-negative." if self.external_cuda_graph or self.enable_cuda_graph: assert ( self.cuda_graph_impl == "transformer_engine" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 805cd5f0eaa..96f266bd60d 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1290,6 +1290,9 @@ def _set_offload_modules(self): if self.offload_mlp_norm: self.offload_module_in_cuda_graph = True if self.offload_module_in_cuda_graph: + assert is_torch_min_version( + "2.9.0a0" + ), "Offloading modules captured in cuda graph requires torch>=2.9.0." assert is_te_min_version( "2.13.0" ), "Offloading modules captured in cuda graph requires TE>=2.13.0." From 998d1b06f9dd3be9f245965634f914c264a71890 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 27 Feb 2026 06:17:34 -0800 Subject: [PATCH 67/74] 1. replace hasattr+delattr with None; 2. refine offloading docs; 3. 
remove TransformerLayer.cuda_graph_stream and cuda_graph_event Signed-off-by: Hongbin Liu --- .../fine_grained_activation_offloading.md | 155 ++++++++++++++---- .../core/models/gpt/fine_grained_callables.py | 7 +- .../fine_grained_activation_offload.py | 23 ++- megatron/core/transformer/module.py | 9 +- .../core/transformer/transformer_config.py | 5 +- .../core/transformer/transformer_layer.py | 25 +-- ...test_fine_grained_activation_offloading.py | 1 - 7 files changed, 153 insertions(+), 72 deletions(-) diff --git a/docs/api-guide/fine_grained_activation_offloading.md b/docs/api-guide/fine_grained_activation_offloading.md index eee1eb8445e..91edec48d68 100644 --- a/docs/api-guide/fine_grained_activation_offloading.md +++ b/docs/api-guide/fine_grained_activation_offloading.md @@ -1,46 +1,141 @@ -# Fine-grained Activation Offloading (collaborated with rednote) +# Fine-Grained Activation Offloading -Memory capacity is more and more important with the rising of extreme sparse MoE models like DeepSeek-V3 and Qwen3-235B. Fine-grained recomputing reduces the memory footprint at the cost of extra recomputation, while offloading could utilize the host-device bandwidth to achieve nearly zero-overhead. Fine-grained Activation Offloading targets at offloading the activation at the granularity of specific modules, so that we can calibrate the amount of offloading activation to maximize the training throughput. +Fine-grained activation offloading reduces GPU memory by asynchronously transferring activations to CPU at the granularity of individual submodules within a transformer layer. Unlike layer-level offloading, it allows precise control over which activations to offload, enabling a tradeoff between memory savings and PCIe bandwidth overhead. -Currently, the supported offloading modules are `"attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"`, which could work with fine-grained recomputation to release almost all activations of a transformer layer. 
+## User Guide -**Features** -* Support PP=1/PP/Interleaved PP -* Compatible with fine-grained recomputation -* Support FP8 -* Support MTP -* Support mixed dense & moe layer -* Support A2A Overlap -* Support CUDA Graph - * (Temporary) cuda graph scope cannot contains the offloading modules +### Basic Usage -**Usage** ```bash # Enable fine-grained activation offloading --fine-grained-activation-offloading -# Specify which modules are going to offload its input -# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act". ---offload-modules expert_fc1 +# Specify which modules to offload (can combine multiple) +# Choices: attn_norm, qkv_linear, core_attn, attn_proj, mlp_norm, expert_fc1, moe_act +--offload-modules core_attn attn_proj expert_fc1 +``` -# Specify the minimum tensor shape to be offloaded -# This is to avoid scattered offloading of small tensors ---min-offloaded-tensor-size 1048576 # 1M elements +### Offloadable Modules -# When enabling cuda graph, delay the offloading outside graph until the graph launch. -# This is to utilize the leading advantages of CPU by cuda graph ---delay-offload-until-cuda-graph +Each module offloads its **input** activation to CPU during forward and reloads it before backward: + +| Module | Description | Notes | +|---|---|---| +| `attn_norm` | Input layernorm of attention | Skipped if using `IdentityOp` | +| `qkv_linear` | QKV linear projection | | +| `core_attn` | Core attention (softmax + matmul) | | +| `attn_proj` | Output projection of attention | Must be used together with `core_attn` | +| `mlp_norm` | Pre-MLP layernorm | Skipped if using `IdentityOp` | +| `expert_fc1` | First FC layer in MoE experts | MoE models only | +| `moe_act` | Activation function in MoE experts | MoE models only | -# Difference of offload bytes across PP ranks to balance the offload load. -# Larger PP ranks offload less bytes to reduce the overhead. 
-delta_offload_bytes_across_pp_ranks 1073741824 # 1GB +### Tuning Parameters + +```bash +# Minimum tensor size (in elements) to offload. Smaller tensors are skipped. +# Default: 1048576 (1M elements) +--min-offloaded-tensor-size 1048576 -# The fraction of the activation to be offloaded, which should be in range [0, 1]. +# Fraction of activations to offload, range [0, 1]. Default: 1.0 +# Useful for partial offloading when PCIe bandwidth is a bottleneck. --activation-offload-fraction 0.8 + +# Reduce offload amount on higher PP ranks (in bytes). Default: 0 +# Higher PP ranks have fewer microbatches in flight, so offloading less +# reduces overhead without increasing peak memory. +--delta-offload-bytes-across-pp-ranks 1073741824 +``` + +### CUDA Graph Integration + +Fine-grained offloading is compatible with CUDA graphs. When CUDA graph is enabled, the following constraints apply: + +- `attn_norm` and `mlp_norm` **cannot** be offloaded (they cross CUDA graph boundaries). +- `cuda_graph_scope` must include `attn` and `moe_router`. +- `cuda_graph_impl` must be `transformer_engine`. +- Requires `torch >= 2.9.0` and `transformer_engine >= 2.13.0`. + +```bash +# Delay offloading until CUDA graph launch to hide CPU overhead +--delay-offload-until-cuda-graph +``` + +### Combining with Fine-Grained Recomputation + +Offloading and recomputation are complementary: +- Use **recomputation** for lightweight modules (e.g., layernorm, activation functions) with negligible compute overhead. +- Use **offloading** for heavy modules (e.g., core_attn, expert_fc1) where recomputation would be too costly. 
+ +```bash +--recompute-granularity selective +--recompute-modules layernorm moe_act +--fine-grained-activation-offloading +--offload-modules core_attn attn_proj expert_fc1 ``` -**Compatible with Fine-grained Recomputation** -- For modules with minor perf overhead like layernorm or moe_act, use recomputing to reduce memory footprint; -- For other modules, use offloading to reduce memory footprint; -- Make sure the offloading/reloading could be overlapped with computing; ![Fine-grained Activation Offloading and Fine-grained Recomputation](../../images/fine_grained_activation_offloading/offloading_and_recomputing.png) + + +### Compatibility + +| Feature | Supported | +|---|---| +| PP / Interleaved PP / PP=1 | Yes | +| Fine-grained recomputation | Yes | +| FP8 training | Yes | +| MTP (Multi-Token Prediction) | Yes | +| Mixed dense & MoE layers | Yes | +| A2A overlap (EP) | Yes | +| CUDA Graph (TE impl) | Yes | + +--- + +## How It Works + +### Architecture Overview + +The implementation consists of three layers: + +1. **`PipelineOffloadManager`** (singleton): Global coordinator that manages CUDA streams, CPU tensor pools, and chunk lifecycle across pipeline stages. +2. **`ChunkOffloadHandler`**: Per-microbatch handler that tracks tensor groups, executes D2H/H2D transfers, and decides which groups to actually offload. +3. **`FineGrainedActivationOffloadingInterface`**: Lightweight interface used by transformer modules (attention, MoE, etc.) to mark offload boundaries. + +### Offload/Reload Flow + +``` +Forward pass (Layer N): Backward pass (Layer N): +┌─────────────────────┐ ┌───────────────────────┐ +│ group_start(input) │─── register ──► │ │ +│ │ tensor group │ group_commit_backward │ +│ module.forward() │ │ wait H2D complete │ +│ │ │ pop tensors from │ +│ group_offload(out) │─── D2H async ──► │ CPU → GPU │ +│ on d2h_stream │ to pinned CPU │ on h2d_stream │ +└─────────────────────┘ └───────────────────────┘ +``` + +1. 
**`group_start`**: Registers a new tensor group and hooks into `saved_tensors_hooks` to intercept `save_for_backward`. +2. **Forward execution**: All tensors saved by autograd within the group are captured. +3. **`group_offload`**: Triggers asynchronous D2H copy on a dedicated CUDA stream (`d2h_stream`), optionally releases GPU storage of input tensors. +4. **Backward**: Before the group's backward, tensors are reloaded from CPU to GPU on `h2d_stream`, and the compute stream waits for the transfer to complete. + +### Warmup and Adaptive Offloading + +The first training iteration serves as a **warmup phase** where the manager records tensor groups, their sizes, and the execution order. After warmup, a `post_warmup_callback` runs to: + +1. **Reserve margin**: The last N groups (by deduplication count) are kept on GPU to avoid reload blocking the compute stream. +2. **Apply PP rank delta**: Higher PP ranks offload fewer bytes (controlled by `delta_offload_bytes_across_pp_ranks`). +3. **Apply fraction**: Only a fraction of eligible groups are actually offloaded (controlled by `activation_offload_fraction`). +4. **Print summary table**: An ASCII table of per-rank offload bytes is printed for debugging. + +### CPU Tensor Pool + +A `GPUTensorPool` (on CPU with pinned memory) caches allocated tensors by `(shape, dtype)`. This avoids repeated `cudaMallocHost` / `cudaFreeHost` calls and reduces D2H latency after the first iteration. + +### CUDA Graph Support + +When offloading modules captured inside a CUDA graph: + +- A dedicated `cuda_graph_stream` runs the captured computation, while `d2h_stream` overlaps D2H transfers. +- During CUDA graph **warmup**, offloading is disabled (`pre_warmup_hook` / `post_warmup_hook`). +- The `delay_offload_until_cuda_graph` option defers D2H launches until graph replay, utilizing the CPU idle time during `cudaGraphLaunch` to issue offload commands with near-zero CPU overhead. 
diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 966b088de39..c13fbeab229 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -590,11 +590,12 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): ) # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. - if hasattr(node.layer_state, 'mlp_norm_manager'): - hidden_states = node.layer_state.mlp_norm_manager.group_offload( + mlp_norm_manager = getattr(node.layer_state, 'mlp_norm_manager', None) + if mlp_norm_manager is not None: + hidden_states = mlp_norm_manager.group_offload( hidden_states, forced_released_tensors=[residual] ) - delattr(node.layer_state, 'mlp_norm_manager') + node.layer_state.mlp_norm_manager = None output = make_viewless_tensor( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index cbdfe713712..58771d18609 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1242,18 +1242,17 @@ class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, tensor, event: torch.cuda.Event) -> torch.Tensor: + def forward(ctx, tensor) -> torch.Tensor: """Forward pass for cuda graph capture.""" - ctx.event = event return tensor @staticmethod def backward(ctx, grad_output): """Record the backward event and wait for the h2d stream on cuda graph stream.""" - h2d_stream = PipelineOffloadManager.get_instance().h2d_stream - torch.cuda.current_stream().record_event(ctx.event) - torch.cuda.current_stream().wait_stream(h2d_stream) - return grad_output, None + mgr = 
PipelineOffloadManager.get_instance() + torch.cuda.current_stream().record_event(mgr.cuda_graph_event) + torch.cuda.current_stream().wait_stream(mgr.h2d_stream) + return (grad_output,) class FineGrainedActivationOffloadingInterface: @@ -1324,16 +1323,16 @@ def mark_not_offload(tensor: torch.Tensor): PipelineOffloadManager.get_instance().mark_not_offload(tensor) @staticmethod - def forward_record(event: torch.cuda.Event) -> None: + def forward_record() -> None: """Record the forward event for cuda graph capture.""" - d2h_stream = PipelineOffloadManager.get_instance().d2h_stream - torch.cuda.current_stream().record_event(event) - torch.cuda.current_stream().wait_stream(d2h_stream) + mgr = PipelineOffloadManager.get_instance() + torch.cuda.current_stream().record_event(mgr.cuda_graph_event) + torch.cuda.current_stream().wait_stream(mgr.d2h_stream) @staticmethod - def backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: + def backward_record(tensor) -> torch.Tensor: """Record the backward event for cuda graph capture.""" - return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) + return FineGrainedOffloadingBackwardRecordFunction.apply(tensor) @staticmethod def reset(): diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 855976df844..c837d7b4c9e 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -322,10 +322,13 @@ def _get_te_cuda_graph_replay_args(self, *args, **kwargs): cudagraph_kwargs = kwargs.copy() cudagraph_kwargs['is_first_microbatch'] = getattr(self, 'current_microbatch', 0) == 0 - from megatron.core.transformer.transformer_layer import TransformerLayer + if self.config.fine_grained_activation_offloading: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) - cudagraph_kwargs['cuda_graph_stream'] = TransformerLayer.cuda_graph_stream - 
cudagraph_kwargs['cuda_graph_event'] = TransformerLayer.cuda_graph_event + cudagraph_kwargs['cuda_graph_stream'] = off_interface.cuda_graph_stream() + cudagraph_kwargs['cuda_graph_event'] = off_interface.cuda_graph_event() return cudagraph_args, cudagraph_kwargs def _should_call_local_cudagraph(self, *args, **kwargs): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 51e9c622a1d..d6654cd117a 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1362,9 +1362,8 @@ def __post_init__(self): "which is needed in core_attn.backward()." ) if self.delay_offload_until_cuda_graph: - assert ( - self.external_cuda_graph or self.enable_cuda_graph, - "delay_offload_until_cuda_graph must be used with cuda graph.", + assert self.external_cuda_graph or self.enable_cuda_graph, ( + "delay_offload_until_cuda_graph must be used with cuda graph." ) assert ( self.min_offloaded_tensor_size >= 0 diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 1a751aaf78b..36ce27499ec 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -267,9 +267,6 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): output of the same size. """ - cuda_graph_stream = None - cuda_graph_event = None - def __init__( self, config: TransformerConfig, @@ -467,6 +464,7 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): self.recompute_mlp = True self._set_offload_modules() + self.mlp_norm_manager = None # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. # TORCH_MAJOR = int(torch.__version__.split('.')[0]) @@ -563,9 +561,7 @@ def _forward_attention( # This is to ensure the main stream waits for computing on cuda graph stream to complete, # and overlaps with the H2D transfer on reload stream. 
if self.offload_module_in_cuda_graph: - hidden_states = off_interface.backward_record( - hidden_states, TransformerLayer.cuda_graph_event - ) + hidden_states = off_interface.backward_record(hidden_states) inference_context = deprecate_inference_params(inference_context, inference_params) @@ -851,11 +847,11 @@ def _forward_post_mlp( nvtx_range_pop(suffix="mlp_bda") # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. - if hasattr(self, 'mlp_norm_manager'): + if self.mlp_norm_manager is not None: hidden_states = self.mlp_norm_manager.group_offload( hidden_states, forced_released_tensors=[residual] ) - delattr(self, 'mlp_norm_manager') + self.mlp_norm_manager = None # Jit compiled function creates 'view' tensor. This tensor # potentially gets saved in the MPU checkpoint function context, @@ -1053,7 +1049,7 @@ def _te_cuda_graph_capture(self, *args, **kwargs): FineGrainedActivationOffloadingInterface as off_interface, ) - off_interface.forward_record(TransformerLayer.cuda_graph_event) + off_interface.forward_record() return tuple(cuda_graph_outputs) def _te_cuda_graph_replay(self, *args, **kwargs): @@ -1330,17 +1326,6 @@ def _set_offload_modules(self): assert ( self.config.cuda_graph_warmup_steps > 0 ), "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0." - # Set the cuda graph stream and event for the transformer layer. 
- if self.offload_module_in_cuda_graph: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - FineGrainedActivationOffloadingInterface as off_interface, - ) - - TransformerLayer.cuda_graph_stream = off_interface.cuda_graph_stream() - TransformerLayer.cuda_graph_event = off_interface.cuda_graph_event() - else: - TransformerLayer.cuda_graph_stream = torch.cuda.current_stream() - TransformerLayer.cuda_graph_event = torch.cuda.Event() def get_layer_norm_weights(self): """ diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index d26ff8e128a..41b9391e171 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -318,7 +318,6 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory( ("alltoall", True, ["mlp_norm"]), ("alltoall", False, ["expert_fc1"]), ("alltoall", False, ["moe_act"]), - # ("alltoall", False, ["mlp_norm", "expert_fc1", "moe_act"]), ( "alltoall", True, From 716e12a9311d6f840de43c7aa30a7c9a000c42b4 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 1 Mar 2026 20:30:56 -0800 Subject: [PATCH 68/74] format Signed-off-by: Hongbin Liu --- megatron/core/transformer/transformer_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index d6654cd117a..099cf9028a2 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1362,9 +1362,9 @@ def __post_init__(self): "which is needed in core_attn.backward()." ) if self.delay_offload_until_cuda_graph: - assert self.external_cuda_graph or self.enable_cuda_graph, ( - "delay_offload_until_cuda_graph must be used with cuda graph." 
- ) + assert ( + self.external_cuda_graph or self.enable_cuda_graph + ), "delay_offload_until_cuda_graph must be used with cuda graph." assert ( self.min_offloaded_tensor_size >= 0 ), "min_offloaded_tensor_size must be non-negative." From 61b589a33582ed56a11470ff82813ace930a82b0 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 1 Mar 2026 22:33:27 -0800 Subject: [PATCH 69/74] bug fix Signed-off-by: Hongbin Liu --- megatron/core/transformer/module.py | 4 +++- megatron/core/transformer/transformer_config.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index c837d7b4c9e..2d588262676 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -322,7 +322,9 @@ def _get_te_cuda_graph_replay_args(self, *args, **kwargs): cudagraph_kwargs = kwargs.copy() cudagraph_kwargs['is_first_microbatch'] = getattr(self, 'current_microbatch', 0) == 0 - if self.config.fine_grained_activation_offloading: + if self.config.fine_grained_activation_offloading and getattr( + self, 'offload_module_in_cuda_graph', False + ): from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 099cf9028a2..7581616ea9c 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1363,7 +1363,7 @@ def __post_init__(self): ) if self.delay_offload_until_cuda_graph: assert ( - self.external_cuda_graph or self.enable_cuda_graph + self.transformer_impl == "transformer_engine" ), "delay_offload_until_cuda_graph must be used with cuda graph." 
assert ( self.min_offloaded_tensor_size >= 0 @@ -1378,6 +1378,7 @@ def __post_init__(self): assert ( self.cuda_graph_impl == "transformer_engine" ), "cuda_graph_impl must be transformer_engine when enabling offloading." + if self.cuda_graph_impl == "transformer_engine": assert ( self.cuda_graph_scope is not None ), "cuda_graph_scope must be set when enabling offloading." From 0200121cd271b2d08eb183f4944aafea823ce924 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 2 Mar 2026 01:25:37 -0800 Subject: [PATCH 70/74] add flag to control flush_delayed_groups in fine_grained_callables.py Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/fine_grained_callables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index c13fbeab229..95fc29e25cc 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -602,7 +602,7 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): # Flush the delayed groups. # This process happens only during the warmup steps of cuda graph. - if node.chunk_state.flush_delayed_groups: + if layer.config.delay_offload_until_cuda_graph and node.chunk_state.flush_delayed_groups: off_interface.flush_delayed_groups() # Need to record tensors created on comp stream to comm stream From c8bd90d6ccf3a230f47b5b04d0d5a1c410886dc3 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 4 Mar 2026 23:16:30 -0800 Subject: [PATCH 71/74] 1. move backward_record() to te_cuda_graph_capture() 2. 
remove flush_delayed_groups() when the training is not in replay mode Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 8 --- .../fine_grained_activation_offload.py | 16 ++++- .../core/transformer/transformer_layer.py | 63 +++++++++++-------- 3 files changed, 52 insertions(+), 35 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 95fc29e25cc..4615b62d456 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -454,10 +454,7 @@ def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): ): layer.set_te_cuda_graph_backward_dw_wrapper() forward_func = layer._te_cuda_graph_replay - node.chunk_state.flush_delayed_groups = False else: - node.chunk_state.flush_delayed_groups = True - # wrapper function that keeps consistent api with cuda graph replay def forward_func( hidden_states: Tensor, @@ -600,11 +597,6 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) - # Flush the delayed groups. - # This process happens only during the warmup steps of cuda graph. - if layer.config.delay_offload_until_cuda_graph and node.chunk_state.flush_delayed_groups: - off_interface.flush_delayed_groups() - # Need to record tensors created on comp stream to comm stream node.layer_state.residual.record_stream(torch.cuda.current_stream()) if shared_expert_output is not None: diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 58771d18609..6e66d5ddffb 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -419,6 +419,8 @@ def __init__(self): # Whether the manager is in warmup phase. 
self._is_warmup = True + # Whether the manager is in CUDA graph replay phase. + self._in_replay = False # Cache OffloadChunkHandler objects for each virtual pipeline stage and each forward pass. self._cached_chunks_forward = [] # Cache OffloadChunkHandler objects for each virtual pipeline stage and each backward pass. @@ -1133,7 +1135,7 @@ def forward(ctx, tensor, cur_forward_chunk, name, forced_released_tensors, delay # pylint: disable=missing-function-docstring debug_rank("FineGrainedOffloadingGroupCommitFunction forward") - if delay_offload: + if delay_offload and PipelineOffloadManager.get_instance()._in_replay: PipelineOffloadManager.get_instance().push_offload_groups( cur_forward_chunk.on_group_commit_forward, name, forced_released_tensors ) @@ -1244,11 +1246,13 @@ class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): @staticmethod def forward(ctx, tensor) -> torch.Tensor: """Forward pass for cuda graph capture.""" + debug_rank("FineGrainedOffloadingBackwardRecordFunction forward") return tensor @staticmethod def backward(ctx, grad_output): """Record the backward event and wait for the h2d stream on cuda graph stream.""" + debug_rank("FineGrainedOffloadingBackwardRecordFunction backward") mgr = PipelineOffloadManager.get_instance() torch.cuda.current_stream().record_event(mgr.cuda_graph_event) torch.cuda.current_stream().wait_stream(mgr.h2d_stream) @@ -1358,3 +1362,13 @@ def disable_offload(): def enable_offload(): """Enable the offload.""" PipelineOffloadManager.get_instance().enable_offload() + + @staticmethod + def enter_replay(): + """Enter CUDA graph replay mode to enable delayed offloading.""" + PipelineOffloadManager.get_instance()._in_replay = True + + @staticmethod + def exit_replay(): + """Exit CUDA graph replay mode.""" + PipelineOffloadManager.get_instance()._in_replay = False diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 36ce27499ec..43b00a7480b 100644 --- 
a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -557,12 +557,6 @@ def _forward_attention( FineGrainedActivationOffloadingInterface as off_interface, ) - # Record the backward event on cuda graph stream in backward pass. - # This is to ensure the main stream waits for computing on cuda graph stream to complete, - # and overlaps with the H2D transfer on reload stream. - if self.offload_module_in_cuda_graph: - hidden_states = off_interface.backward_record(hidden_states) - inference_context = deprecate_inference_params(inference_context, inference_params) # Residual connection. @@ -704,7 +698,6 @@ def _forward_mlp( hidden_states: Tensor, inference_context: BaseInferenceContext | None = None, padding_mask: Tensor | None = None, - flush_delayed_groups: bool = True, ) -> Tensor | list[Tensor | None]: """ Perform a forward pass through the feed-forward layer. @@ -800,13 +793,12 @@ def _forward_mlp( self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(tensor) return list(mlp_output_with_bias) + [residual] else: - return self._forward_post_mlp(mlp_output_with_bias, residual, flush_delayed_groups) + return self._forward_post_mlp(mlp_output_with_bias, residual) def _forward_post_mlp( self, mlp_output_with_bias: tuple[Tensor, Tensor | None], residual: Tensor, - flush_delayed_groups: bool = True, ) -> Tensor: """ Perform operations after the MLP computation. @@ -814,7 +806,6 @@ def _forward_post_mlp( Args: mlp_output_with_bias (Tensor): Output tensor of the MLP layer with bias. residual (Tensor): Residual tensor. - flush_delayed_groups (bool): Whether to flush the delayed groups. Returns: output (Tensor): Transformed hidden states of shape [s, b, h]. @@ -863,14 +854,6 @@ def _forward_post_mlp( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) - # Flush the delayed groups. - # This process happens only during the warmup steps of cuda graph. 
- if self.config.delay_offload_until_cuda_graph and flush_delayed_groups: - from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( - FineGrainedActivationOffloadingInterface as off_interface, - ) - - off_interface.flush_delayed_groups() return output def sharded_state_dict( @@ -1014,6 +997,22 @@ def _te_cuda_graph_capture(self, *args, **kwargs): attribute can be set to control the scope of the CUDA graph. 2. If context is None, it cannot be returned as output. """ + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + # Record the backward event on cuda graph stream in backward pass. + # This is to ensure the main stream waits for computing on cuda graph stream to complete, + # and overlaps with the H2D transfer on reload stream. + if self.offload_module_in_cuda_graph: + if len(args) > 0: + hidden_states = args[0] + hidden_states = off_interface.backward_record(hidden_states) + args = (hidden_states,) + args[1:] + else: + hidden_states = kwargs.pop("hidden_states") + hidden_states = off_interface.backward_record(hidden_states) + kwargs["hidden_states"] = hidden_states context = None if not self.config.cuda_graph_scope or CudaGraphScope.attn in self.config.cuda_graph_scope: hidden_states, context = self._forward_attention(*args, **kwargs) @@ -1073,12 +1072,26 @@ def _te_cuda_graph_replay(self, *args, **kwargs): "For inference cuda graph, please use cuda_graph_impl=local instead." 
) + if self.config.delay_offload_until_cuda_graph: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + off_interface.enter_replay() + + try: + return self._te_cuda_graph_replay_impl(args, kwargs, context) + finally: + if self.config.delay_offload_until_cuda_graph: + off_interface.exit_replay() + + def _te_cuda_graph_replay_impl(self, args, kwargs, context): + """Implementation of _te_cuda_graph_replay, separated for replay mode cleanup.""" cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) - # Flush the delayed groups after the cuda graph replay. - # This is to reduce or eliminate the cpu overhead of offloading because - # there exists a synchronization between the cuda graph replay and the a2a comm - # in moe layer, at that point the host thread is blocked and idle. + # Flush delayed offload groups from previous layers after graph replay. + # The CPU is idle during the sync between graph replay and a2a comm, + # so we use that time to execute the delayed offload operations. 
if self.config.delay_offload_until_cuda_graph: from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, @@ -1160,9 +1173,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): # of the cudagraph, so disable the recompute hooks inside _forward_post_mlp recompute_pre_mlp_layernorm = self.recompute_pre_mlp_layernorm self.recompute_pre_mlp_layernorm = False - output = self._forward_post_mlp( - mlp_output_with_bias, residual, flush_delayed_groups=False - ) + output = self._forward_post_mlp(mlp_output_with_bias, residual) self.recompute_pre_mlp_layernorm = recompute_pre_mlp_layernorm else: # If EP overlap is enabled, needs to return same outputs as submodule.attn @@ -1178,7 +1189,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): return residual, hidden_states, probs, shared_expert_output # CUDA Graph does not capture the MLP/MoE part at all. - output = self._forward_mlp(*cuda_graph_output, flush_delayed_groups=False) + output = self._forward_mlp(*cuda_graph_output) return output, context def _get_te_cuda_graph_replay_args(self, *args, **kwargs): From ddd67d29419e1a6982481b5152cce9865a9a8864 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 4 Mar 2026 23:17:19 -0800 Subject: [PATCH 72/74] format Signed-off-by: Hongbin Liu --- megatron/core/transformer/transformer_layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 43b00a7480b..f9609165f24 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -796,9 +796,7 @@ def _forward_mlp( return self._forward_post_mlp(mlp_output_with_bias, residual) def _forward_post_mlp( - self, - mlp_output_with_bias: tuple[Tensor, Tensor | None], - residual: Tensor, + self, mlp_output_with_bias: tuple[Tensor, Tensor | None], residual: Tensor ) -> Tensor: """ Perform 
operations after the MLP computation. From e989b95455e02ed088849ce8e345fd3332cfe44c Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 4 Mar 2026 23:36:50 -0800 Subject: [PATCH 73/74] remove the knob forward_only when executing reset() Signed-off-by: Hongbin Liu --- megatron/core/pipeline_parallel/schedules.py | 6 +++--- megatron/core/transformer/cuda_graphs.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index e903f392bf0..133c027aaf8 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -689,7 +689,7 @@ def forward_backward_no_pipelining( force_all_reduce=force_all_reduce, ) - if not forward_only and config.fine_grained_activation_offloading: + if config.fine_grained_activation_offloading: off_interface.reset() if config.timers is not None: @@ -2054,7 +2054,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): force_all_reduce=force_all_reduce, ) - if not forward_only and config.fine_grained_activation_offloading: + if config.fine_grained_activation_offloading: off_interface.reset() # Restore config.grad_sync_func and config.param_sync_func. 
if forward_only: @@ -2442,7 +2442,7 @@ def enable_grad_sync(): force_all_reduce=force_all_reduce, ) - if not forward_only and config.fine_grained_activation_offloading: + if config.fine_grained_activation_offloading: off_interface.reset() if config.timers is not None: diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 7d3d2562a1c..3660f5eac9a 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -2222,7 +2222,8 @@ def _finish_capturing(self, start_time): ) from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker - off_interface.reset() + if self.config.fine_grained_activation_offloading: + off_interface.reset() torch.distributed.barrier() for model_chunk in self.model: From 19fe6b34191d8aecadb0c9adbbfd13b4d73da3d8 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 5 Mar 2026 05:18:50 -0800 Subject: [PATCH 74/74] fix ut and reviewer's comments Signed-off-by: Hongbin Liu --- megatron/core/pipeline_parallel/schedules.py | 6 ++--- .../core/transformer/transformer_config.py | 26 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 133c027aaf8..a142956068d 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -689,7 +689,7 @@ def forward_backward_no_pipelining( force_all_reduce=force_all_reduce, ) - if config.fine_grained_activation_offloading: + if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() if config.timers is not None: @@ -2054,7 +2054,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): force_all_reduce=force_all_reduce, ) - if config.fine_grained_activation_offloading: + if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() # Restore config.grad_sync_func and 
config.param_sync_func. if forward_only: @@ -2442,7 +2442,7 @@ def enable_grad_sync(): force_all_reduce=force_all_reduce, ) - if config.fine_grained_activation_offloading: + if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() if config.timers is not None: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 7581616ea9c..7338d6e97a8 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1382,18 +1382,20 @@ def __post_init__(self): assert ( self.cuda_graph_scope is not None ), "cuda_graph_scope must be set when enabling offloading." - assert ( - "attn" in self.cuda_graph_scope and "moe_router" in self.cuda_graph_scope - ) or ( - CudaGraphScope.attn in self.cuda_graph_scope - and CudaGraphScope.moe_router in self.cuda_graph_scope - ), "attn and moe_router must be in cuda_graph_scope when enabling offloading." - assert ( - "attn_norm" not in self.offload_modules - ), "input of attn_norm is the start point of cuda graph, which can't be offloaded." - assert ( - "mlp_norm" not in self.offload_modules - ), "mlp_norm goes through the boundary of cuda graph, which can't be offloaded." + if ( + "attn" in self.cuda_graph_scope + or "moe_router" in self.cuda_graph_scope + or "moe_preprocess" in self.cuda_graph_scope + or CudaGraphScope.attn in self.cuda_graph_scope + or CudaGraphScope.moe_router in self.cuda_graph_scope + or CudaGraphScope.moe_preprocess in self.cuda_graph_scope + ): + assert ( + "attn_norm" not in self.offload_modules + ), "attn_norm is the start point of cuda graph, so can't be offloaded." + assert ( + "mlp_norm" not in self.offload_modules + ), "mlp_norm goes through the boundary of cuda graph, so can't be offloaded." if ( self.num_layers_in_first_pipeline_stage is not None