diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 91e26af99c6..6404fdf38ea 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2579,7 +2579,7 @@ def _add_distributed_args(parser): group.add_argument('--create-all-gather-group', action='store_true', help='Create a separate process group for all-gather operations ' 'to overlap reduce-scatter and all-gather operations.') - group.add_argument('--data-parallel-sharding-strategy', type=str, default='no_shard', + group.add_argument('--data-parallel-sharding-strategy', type=str, default='optim_grads_params', choices=['no_shard', 'optim', 'optim_grads', 'optim_grads_params'], help='Sharding strategy of data parallelism.') group.add_argument('--outer-dp-sharding-strategy', type=str, default='no_shard',