diff --git a/examples/flava/native/train.py b/examples/flava/native/train.py index 6e2afa10..6fe29e73 100644 --- a/examples/flava/native/train.py +++ b/examples/flava/native/train.py @@ -117,7 +117,11 @@ def __init__(self, config: DictConfig): else torch.float16 ) - self.scaler = ShardedGradScaler() if config.training.enable_amp else None + self.scaler = ( + ShardedGradScaler() + if config.training.enable_amp and self.half_dtype == torch.float16 + else None + ) def log( self,