diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 25f0d0d06d0..2cb2a2a0e6a 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2477,7 +2477,7 @@ def _add_distributed_args(parser): group.add_argument('--create-all-gather-group', action='store_true', help='Create a separate process group for all-gather operations ' 'to overlap reduce-scatter and all-gather operations.') - group.add_argument('--data-parallel-sharding-strategy', type=str, default='no_shard', + group.add_argument('--data-parallel-sharding-strategy', type=str, default='optim_grads_params', choices=['no_shard', 'optim', 'optim_grads', 'optim_grads_params'], help='Sharding strategy of data parallelism.') group.add_argument('--outer-dp-sharding-strategy', type=str, default='no_shard',