Skip to content

Using ZeRO 3?  #182

@muellerzr

Description

@muellerzr

What's the issue, what's expected?:
A general question I couldn't find an answer to: does MS-AMP support ZeRO-3, or only ZeRO-1 and ZeRO-2? (It's fine if it doesn't, but that limitation should be made obvious to users :) )

I ran into some issues trying to run ZeRO-3 with the following DeepSpeed config:

    config = {
        "train_batch_size": 32,
        "train_micro_batch_size_per_gpu": 16,
        "gradient_accumulation_steps": 1,
        "zero_optimization": {
            "stage": zero_stage,
            "offload_optimizer": {"device": "none", "nvme_path": None},
            "offload_param": {"device": "none", "nvme_path": None},
            "stage3_gather_16bit_weights_on_model_save": False,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "sub_group_size": 1e9,
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
        },
        "gradient_clipping": 1.0,
        "steps_per_print": np.inf,
        "bf16": {"enabled": True},
        "fp16": {"enabled": False},
        "zero_allow_untested_optimizer": True,
        "msamp": {
            "enabled": True,
            "opt_level": opt_level,
        }
    }

Relevant Errors

O1 & O2:

AttributeError: 'ScalingParameter' object has no attribute 'item'

    Traceback (most recent call last):
      File "/mnt/work/accelerate/benchmarks/fp8/ms_amp/distrib_deepspeed.py", line 168, in <module>
        baseline_not_trained, baseline_trained = train_baseline(3, "O1")
      File "/mnt/work/accelerate/benchmarks/fp8/ms_amp/distrib_deepspeed.py", line 85, in train_baseline
        ) = deepspeed.initialize(
      File "/mnt/work/MS-AMP/msamp/deepspeed/__init__.py", line 135, in initialize
        engine = MSAMPDeepSpeedEngine(
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 308, in __init__
        self._configure_optimizer(optimizer, model_parameters)
      File "/mnt/work/MS-AMP/msamp/deepspeed/runtime/engine.py", line 104, in _configure_optimizer
        self.optimizer = self._configure_zero_optimizer(basic_optimizer)
      File "/mnt/work/MS-AMP/msamp/deepspeed/runtime/engine.py", line 283, in _configure_zero_optimizer
        optimizer = DeepSpeedZeroOptimizer_Stage3(
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 177, in __init__
        self.parameter_offload = self.initialize_ds_offload(
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 423, in initialize_ds_offload
        return DeepSpeedZeRoOffload(module=module,
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 241, in __init__
        self._convert_to_zero_parameters(ds_config, module, mpu)
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 316, in _convert_to_zero_parameters
        Init(module=module,
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 996, in __init__
        self._convert_to_zero_parameters(module.parameters(recurse=True))
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 1021, in _convert_to_zero_parameters
        self._zero_init_param(param)
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 1008, in _zero_init_param
        self._convert_to_deepspeed_param(param)
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 1404, in _convert_to_deepspeed_param
        param.item = allgather_before(param.item)
    AttributeError: 'ScalingParameter' object has no attribute 'item'

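O3 (rejected by MS-AMP's config check, which only accepts ZeRO stage 1 (optimizer_states) or stage 2 (gradients)):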

AssertionError: MS-AMP O3 requires ZeRO with optimizer_states or gradients partitioning.

    Traceback (most recent call last):
      File "/mnt/work/accelerate/benchmarks/fp8/ms_amp/distrib_deepspeed.py", line 168, in <module>
        baseline_not_trained, baseline_trained = train_baseline(3, "O3")
      File "/mnt/work/accelerate/benchmarks/fp8/ms_amp/distrib_deepspeed.py", line 85, in train_baseline
        ) = deepspeed.initialize(
      File "/mnt/work/MS-AMP/msamp/deepspeed/__init__.py", line 119, in initialize
        config_class = MSAMPDeepSpeedConfig(config, mpu)
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/config.py", line 787, in __init__
        self._do_sanity_check()
      File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/config.py", line 969, in _do_sanity_check
        self._do_error_check()
      File "/mnt/work/MS-AMP/msamp/deepspeed/runtime/config.py", line 43, in _do_error_check
        self.zero_optimization_stage in [ZeroStageEnum.optimizer_states, ZeroStageEnum.gradients], \
    AssertionError: MS-AMP O3 requires ZeRO with optimizer_states or gradients partitioning.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions