diff --git a/train.py b/train.py
index 2e743974..98b2c590 100644
--- a/train.py
+++ b/train.py
@@ -458,6 +458,16 @@ def step(self):
 torch.manual_seed(42)
 torch.cuda.manual_seed(42)
 torch.set_float32_matmul_precision("high")
+
+# Report the GPU this process will actually run on, so logs can be checked
+# against the expected hardware. Query the current CUDA device rather than a
+# hard-coded index 0: torch.device("cuda") below resolves to the current
+# device, which differs from 0 once torch.cuda.set_device() has been called.
+gpu_index = torch.cuda.current_device()
+gpu_name = torch.cuda.get_device_name(gpu_index)
+# total_memory is reported in bytes; 1024**3 converts it to GiB.
+gpu_vram_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1024**3
+print(f"GPU: {gpu_name} ({gpu_vram_gb:.1f} GB)")
 device = torch.device("cuda")
 autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
 H100_BF16_PEAK_FLOPS = 989.5e12