Potential Memory Leaking Issue When Using MS-AMP GeMM #201

@Mr-Philo

Description

What would you like to be added:

Add a small fix to the MS-AMP GeMM function to solve a potential memory-leak issue.

Why is this needed:

Currently, when training DeiT models with the MS-AMP framework, GPU memory usage is abnormal:

| Scheme | FP8 activation | Mem after forward | Mem after backward | Max mem | Throughput | One epoch time |
|--------|----------------|-------------------|--------------------|---------|------------|----------------|
| FP16   | ×              | 18774.96 MB       | 1535.79 MB         | 19242.61 MB | ~14974.5128 (12708.2790) | 02:12 |
| FP8 O2 | ×              | 15696.38 MB       | 3964.60 MB         | 19298.19 MB | ~13673.2756 (11722.0941) | 02:15 |
| FP8 O2 | ✓              | 15697.34 MB       | 3964.33 MB         | 19296.50 MB | ~9812.8065 (9420.7160)   | 02:25 |

Note the Mem after backward column, which looks wrong. If memory were released correctly, the memory remaining after backward should contain only the weights plus weight gradients (w + w_grad), and should therefore be lower than in FP16, since the FP8 schemes store w_grad in FP8.
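
For reference, here is a minimal sketch of how numbers like the ones in the table can be collected with PyTorch's built-in memory counters. The training objects (`model`, `optimizer`, `loss_fn`, `x`, `target`) are placeholders, not the actual DeiT benchmarking script used above:

```python
import torch

def report_mem(tag: str) -> None:
    # Illustrative helper: print current and peak allocated GPU memory in MB.
    alloc = torch.cuda.memory_allocated() / 1024 ** 2
    peak = torch.cuda.max_memory_allocated() / 1024 ** 2
    print(f"{tag}: allocated={alloc:.2f} MB, peak={peak:.2f} MB")

def train_step(model, optimizer, loss_fn, x, target):
    # `model`, `optimizer`, `loss_fn`, `x`, and `target` are placeholders
    # standing in for the DeiT training setup behind the table above.
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(x), target)
    report_mem("Mem after forward")    # weights + live activations
    loss.backward()
    report_mem("Mem after backward")   # ideally only weights + gradients remain
    optimizer.step()
```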

Without this feature, how does the current MS-AMP work:

As described in Why is this needed, the current MS-AMP may leak GPU memory during training.

Components that may involve changes:

The custom GeMM function in the current MS-AMP framework (`_FP8GemmFunction` in `msamp/nn/functional.py`)

Brief description of your proposal if any:

Currently the custom GeMM function uses the `ctx` object to save the input tensor x and the weight tensor W, both of which are needed again in the backward pass to compute gradients. Writing `ctx.input_fp8 = ...` stores the attribute directly, but `input_fp8` is a `ScalingTensor`; in practice, saving it this way does not fully realize the memory benefit of FP8 tensors. See the sketch below for an illustration of the pattern.
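
As an illustration only, here is a minimal sketch in plain PyTorch of the pattern described above. `GemmSketch` and the int8-plus-scale pair are hypothetical stand-ins for `_FP8GemmFunction` and `ScalingTensor`; this is not the actual MS-AMP code or the change made in the PR:

```python
import torch

class GemmSketch(torch.autograd.Function):
    """Illustrative stand-in for a custom GeMM autograd function.

    MS-AMP's ScalingTensor is approximated here by an (int8 data, fp32 scale)
    pair, purely to show the memory pattern being discussed.
    """

    @staticmethod
    def forward(ctx, input_hp, weight_hp):
        # Problematic pattern: stashing the high-precision input on ctx keeps
        # the full FP16/FP32 buffer alive until backward, even though only a
        # low-precision copy is needed to compute the weight gradient:
        #   ctx.input_hp = input_hp
        #
        # Leaner pattern: save a compact quantized copy instead.
        scale = input_hp.abs().max().clamp(min=1e-12) / 127.0
        input_q = torch.round(input_hp / scale).to(torch.int8)
        ctx.save_for_backward(input_q, scale, weight_hp)
        return input_hp @ weight_hp.t()

    @staticmethod
    def backward(ctx, grad_output):
        input_q, scale, weight_hp = ctx.saved_tensors
        # Reconstruct the input only when it is actually needed.
        input_deq = input_q.to(grad_output.dtype) * scale
        grad_input = grad_output @ weight_hp
        grad_weight = grad_output.t() @ input_deq
        return grad_input, grad_weight
```

Saving the compact copy via `ctx.save_for_backward` lets the larger high-precision buffer be freed as soon as nothing else references it, which is the behavior the memory numbers above suggest is currently missing.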

Please see the PR for detailed information.
