diff --git a/vision.py b/vision.py index f1836f4..a2e1d59 100644 --- a/vision.py +++ b/vision.py @@ -169,8 +169,9 @@ def parse_args(argv=None): attn_implementation = args.attn_implementation ) - if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: - torch.set_float32_matmul_precision("high") + if torch.cuda.is_available(): + if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: + torch.set_float32_matmul_precision("high") if args.load_in_4bit: extra_params['load_in_4bit'] = True