
Conversation

@zhenga1 (Contributor) commented Jan 15, 2026

Flash Attention 2 only works with torch.float16 or torch.bfloat16, but many Qwen models default to loading with a dtype of torch.float32. HFModelWrapper no longer wraps the model with the correct dtype, which sometimes leads to compatibility bugs.

Example:

(FSDPPolicyWorkerBase pid=123348) Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", dtype=torch.float16)`
(FSDPPolicyWorkerBase pid=123348) [the same warning is repeated for Qwen2Model]

The fix loads the AutoConfig with a custom dtype so that the model is compatible with Flash Attention 2.
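
For reference, the load pattern the transformers warning points to looks like this; a minimal sketch assuming a Qwen checkpoint and the standard from_pretrained keyword arguments (illustrative only, not this repo's actual loading path):

import torch
from transformers import AutoModelForCausalLM

# Load directly in bf16 so Flash Attention 2 sees a supported dtype.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B",                          # example checkpoint that defaults to fp32
    torch_dtype=torch.bfloat16,               # Flash Attention 2 requires fp16 or bf16
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)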

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request aims to fix a compatibility issue with Flash Attention 2 by ensuring models are loaded with the correct dtype. However, the current change in fsdp_worker.py appears to be ineffective. It modifies how AutoConfig is loaded, but this configuration isn't used for the actual model loading, which still happens with torch.float32 due to a hardcoded bf16=False parameter. This means the original bug is not fixed. I've left a detailed comment explaining the issue and suggesting a more direct fix, which involves changing the hardcoded bf16 parameter, similar to how it's handled in FSDPRefWorkerBase. Additionally, the same bug exists in FSDPCriticWorkerBase but is not addressed in this PR. The fix should be applied there as well for consistency.

Comment on lines +119 to +121
model_config = AutoConfig.from_pretrained(model_path,
                                          torch_dtype=torch.bfloat16 if self.cfg.trainer.bf16 else None,
                                          trust_remote_code=True)

critical

This change appears to be ineffective in resolving the Flash Attention 2 dtype compatibility issue. The model_config created here is only used to determine use_meta_tensor for get_init_weight_context_manager and is not passed to HFModelWrapper for model loading.

The root cause of the bug is the hardcoded bf16=False on line 132, which forces the model to be loaded in torch.float32. This is what causes the incompatibility with Flash Attention 2.

A more direct fix would be to change line 132 to bf16=self.cfg.trainer.bf16, which is consistent with how FSDPRefWorkerBase handles it on line 365. With that change, this modification would no longer be necessary.
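
As a hypothetical sketch of that suggestion (the HFModelWrapper call shape around line 132 is assumed here, since only the bf16 argument is quoted in this thread):

# Around fsdp_worker.py line 132 (call shape assumed, not shown in this thread):
model = HFModelWrapper(
    model_path,
    bf16=self.cfg.trainer.bf16,  # was hardcoded to bf16=False, which forced an fp32 load
)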


-model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+model_config = AutoConfig.from_pretrained(model_path,
+                                          torch_dtype=torch.bfloat16 if self.cfg.trainer.bf16 else None,
+                                          trust_remote_code=True)
Collaborator


We need to load the model in fp32 for FSDP correctness reasons, which is why we hardcode bf16 to False when loading, as shown below:

Screenshot from the PR prior to release where we fixed this, for context:
[screenshot]

@erictang000 (Collaborator) commented

Closing this PR for now due to the comment I shared above. We could consider suppressing the warning since it is expected (we cast to bf16 after load to use flash attn), but we shouldn't load in bf16.
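
To illustrate the pattern described above, here is a self-contained sketch (a toy module stands in for the real policy model, so the setup is only an assumption): weights are loaded and kept in fp32, while the forward pass runs in bf16 under autocast, which is the dtype Flash Attention 2 accepts.

import torch
import torch.nn as nn

# Parameters are loaded and kept in fp32 (as FSDP wants here), while the
# forward pass under autocast runs in bf16. Use device_type="cuda" on GPU.
model = nn.Linear(16, 16)          # stands in for the fp32-loaded policy model
x = torch.randn(4, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(model.weight.dtype)          # torch.float32 -- master weights stay fp32
print(y.dtype)                     # torch.bfloat16 -- compute ran in bf16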
