What would you like to be added:
Add a small fix to the MS-AMP GeMM function to resolve a potential memory leak.
Why is this needed:
Currently, when training DeiT models with the MS-AMP framework, GPU memory usage is abnormal:
| scheme | FP8 activation | Mem after forward | Mem after backward | Max mem | Throughput | One epoch time |
|---|---|---|---|---|---|---|
| FP16 | × | 18774.96MB | 1535.79MB | 19242.61MB | ~14974.5128 (12708.2790) | 02:12 |
| FP8 O2 | × | 15696.38MB | 3964.60MB | 19298.19MB | ~13673.2756 (11722.0941) | 02:15 |
| FP8 O2 | ✓ | 15697.34MB | 3964.33MB | 19296.50MB | ~9812.8065 (9420.7160) | 02:25 |
Note the Mem after backward column, which looks abnormal. If memory were released correctly, the memory remaining after backward should consist only of the weights plus the weight gradients, which should be somewhat less than in FP16 since FP8 weight gradients are used.
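For reference, a minimal sketch of how per-step memory figures like those in the table could be collected with standard PyTorch CUDA statistics (the helper names `report` and `profile_step` are hypothetical; this issue does not state exactly how the numbers above were measured):

```python
import torch

def report(tag):
    # Memory currently held by tensors vs. the peak since the last reset, in MB.
    allocated = torch.cuda.memory_allocated() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"{tag}: allocated={allocated:.2f}MB peak={peak:.2f}MB")

def profile_step(model, batch, target, loss_fn):
    # Resets the peak counter, then reports memory after forward and after backward.
    torch.cuda.reset_peak_memory_stats()
    loss = loss_fn(model(batch), target)
    torch.cuda.synchronize()
    report("after forward")
    loss.backward()
    torch.cuda.synchronize()
    report("after backward")
```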
Without this feature, how does current msamp work:
See Why is this needed above: the current MS-AMP implementation can introduce a potential memory leak during training.
Components that may involve changes:
The custom GeMM function in the current MS-AMP framework (`_FP8GemmFunction` in `msamp/nn/functional.py`).
Brief description of your proposal if any:
Currently the custom GeMM function uses the ctx object to save the input tensor x and the weight tensor W, both of which are needed for gradient computation in backward. Writing ctx.input_fp8 stores the tensor directly as an attribute on ctx; however, input_fp8 is a ScalingTensor, and keeping a direct reference this way does not fully leverage the memory advantage of FP8 tensors.
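For illustration only, here is a minimal sketch in plain PyTorch (not the actual MS-AMP code; class names and shapes are hypothetical, and ordinary torch.Tensor stands in for ScalingTensor) contrasting storing backward inputs directly on ctx with handing them to autograd via ctx.save_for_backward:

```python
import torch

class LinearDirectRef(torch.autograd.Function):
    """Keeps backward inputs as plain attributes on ctx (the pattern in question)."""

    @staticmethod
    def forward(ctx, x, w):
        # Direct references; autograd does not manage these tensors' lifetime.
        ctx.x = x
        ctx.w = w
        return x @ w.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.x, ctx.w
        return grad_out @ w, grad_out.t() @ x


class LinearSavedTensors(torch.autograd.Function):
    """Hands backward inputs to autograd via ctx.save_for_backward."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x @ w.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out @ w, grad_out.t() @ x


# Tiny usage example with ordinary tensors.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
LinearSavedTensors.apply(x, w).sum().backward()
```

One relevant difference in standard PyTorch is that tensors passed to save_for_backward are released when the backward pass consumes the graph, while plain ctx attributes live as long as the grad_fn node itself, which can keep memory alive longer than expected. How the PR applies this idea to ScalingTensor is described in the PR itself.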
Please see the PR for detailed information.