Is your feature request related to a problem? Please describe.
During training, especially when the context window is long, we often run into OOM issues. The OOM may happen in the MLP layer (dense or MoE).
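To make the memory pressure concrete, here is a rough back-of-the-envelope estimate (a sketch with assumed numbers, not measured from Megatron: hidden size 8192, FFN expansion 4x, bf16, batch 1) of the MLP intermediate activation, and how chunking along the sequence dimension shrinks the per-chunk peak:

```python
def mlp_act_bytes(seq_len, hidden=8192, ffn_mult=4, dtype_bytes=2):
    """Bytes for the MLP intermediate activation of shape
    [seq_len, ffn_mult * hidden] (hypothetical example sizes)."""
    return seq_len * ffn_mult * hidden * dtype_bytes

full = mlp_act_bytes(seq_len=131072)          # 128k-token context: 8 GiB
chunked = mlp_act_bytes(seq_len=131072 // 8)  # same context in 8 chunks: 1 GiB each
assert chunked * 8 == full
```

Note the caveat for training: unlike inference prefill, the backward pass still needs each chunk's activations, so the saving applies to the transient forward peak (or requires pairing chunking with activation recomputation).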
Tagging @mcore-oncall to get the oncall's attention to this issue.
Describe the solution you'd like
I noticed that there is already a chunked-MLP path for the prefill phase during inference. The same trick could be applied to training.
```python
should_chunk_mlp_for_prefill = (
    self.config.mlp_chunks_for_prefill > 1
    and inference_context is not None
    and not inference_context.is_decode_only()
    and not isinstance(self.mlp, IdentityOp)
    and not self.config.transformer_impl == "inference_optimized"
)
# ...
elif should_chunk_mlp_for_prefill:
    # Chunk input along sequence dimension
    num_chunks = min(self.config.mlp_chunks_for_prefill, pre_mlp_layernorm_output.shape[0])
    chunks = pre_mlp_layernorm_output.chunk(num_chunks, dim=0)
    # ...
    mlp_output_with_bias = (mlp_output, bias_output)
```
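The reason the prefill trick transfers to training is that an MLP is position-wise: applying it chunk-by-chunk along the sequence dimension and concatenating gives the same result as one full pass, while only one chunk's intermediate activation is live at a time. A minimal, framework-agnostic sketch (all names here are hypothetical, not Megatron code):

```python
def mlp(tokens):
    # Stand-in for a dense/MoE MLP: any per-token (position-wise) function.
    return [[2 * x + 1 for x in tok] for tok in tokens]

def chunked_mlp(tokens, num_chunks):
    """Apply `mlp` in sequence-dimension chunks; equivalent to mlp(tokens)."""
    n = len(tokens)
    num_chunks = min(num_chunks, n)          # mirrors the guard in the prefill path
    size = (n + num_chunks - 1) // num_chunks
    out = []
    for start in range(0, n, size):
        out.extend(mlp(tokens[start:start + size]))
    return out

seq = [[float(i), float(-i)] for i in range(10)]
assert chunked_mlp(seq, 4) == mlp(seq)       # chunking does not change the output
```

In real training code the per-chunk outputs would be concatenated along the sequence dimension (e.g. `torch.cat(outputs, dim=0)`), and the memory benefit for backward would come from recomputing each chunk's intermediate in the backward pass rather than storing all of them.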
Describe alternatives you've considered
I tried to handle the efficiency issue directly inside the MLP layer. However, this simple chunking trick reduces the activation memory peak altogether.