From 54091fdabda5e2c4823daa382977b1e2d74294ed Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Thu, 6 Jul 2023 17:11:45 -0700
Subject: [PATCH] Don't use scaler for bf16

Summary: ShardedGradScaler is not needed for bf16 training: bf16 has the
same exponent range as fp32, so gradients do not underflow and loss
scaling serves no purpose. Only construct the scaler when AMP is enabled
and the half dtype is fp16.

Differential Revision: D47218367

fbshipit-source-id: 6b477a45c346625410aa4dfa289489952618b42d
---
 examples/flava/native/train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/flava/native/train.py b/examples/flava/native/train.py
index 6e2afa10..6fe29e73 100644
--- a/examples/flava/native/train.py
+++ b/examples/flava/native/train.py
@@ -117,7 +117,11 @@ def __init__(self, config: DictConfig):
             else torch.float16
         )
 
-        self.scaler = ShardedGradScaler() if config.training.enable_amp else None
+        self.scaler = (
+            ShardedGradScaler()
+            if config.training.enable_amp and self.half_dtype == torch.float16
+            else None
+        )
 
     def log(
         self,
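
Note on how a conditional scaler like this is typically consumed — a
minimal sketch, not the actual loop in examples/flava/native/train.py
(the names training_step/model/optimizer/batch are hypothetical): when
self.scaler is None (bf16, or AMP disabled), the step falls back to a
plain backward/step, which is safe for bf16 because its 8 exponent bits
match fp32; fp16 still goes through ShardedGradScaler's
scale/step/update cycle.

    import torch
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    def training_step(model, optimizer, batch, scaler=None):
        # Hypothetical step illustrating the pattern the patched code feeds:
        # scaler is a ShardedGradScaler for fp16 AMP, None for bf16 or no AMP.
        loss = model(batch)  # assume the model returns a scalar loss
        if scaler is not None:
            scaler.scale(loss).backward()  # scale loss to avoid fp16 underflow
            scaler.step(optimizer)         # unscales grads; skips step on inf/nan
            scaler.update()                # adapts the scale factor
        else:
            loss.backward()                # bf16/fp32: no loss scaling needed
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        return loss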