Problem
The script currently trains in bfloat16 via torch.amp.autocast. On H100 and newer architectures (compute capability 9.0+), however, FP8 tensor cores offer up to 2x the matmul throughput of BF16.
Proposal
Introduce a pathway for the AI to experiment with FP8 training, e.g. via torch.float8_e4m3fn or NVIDIA's transformer_engine. FP8's narrow dynamic range (E4M3 tops out at 448) means weights, activations, and gradients need careful per-tensor scaling. This issue tracks the AI agent's overall goal of upgrading the linear layers and attention projections to FP8, which could substantially raise the current baseline MFU.