Conversation

@kashif (Contributor) commented Nov 25, 2025

  1. Masked variance calculation (lines 95-107): Changed from the numerically unstable E[X²] - E[X]² formula to the stable centered formula E[(X-μ)²]
  2. Sigma clamping (line 609): Changed from torch.where(sigma < tolerance, 1.0, sigma) to torch.clamp(sigma, min=tolerance)

Fix by @AnMakc in huggingface/transformers#42099 (a sketch of both changes is shown below).

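For reference, a minimal sketch of the two changes. This is not the actual library code; the function and variable names here are illustrative only.

```python
import torch

def masked_mean_var(values: torch.Tensor, mask: torch.Tensor):
    """Mean/variance over masked elements using the centered formula E[(X - mu)^2]."""
    mask = mask.to(values.dtype)
    num_valid = mask.sum().clamp(min=1.0)          # avoid division by zero
    mean = (values * mask).sum() / num_valid
    # Centered formula: numerically stable, unlike E[X^2] - E[X]^2,
    # which can cancel catastrophically and even come out slightly negative.
    var = (((values - mean) ** 2) * mask).sum() / num_valid
    return mean, var

def safe_sigma(sigma: torch.Tensor, tolerance: float = 1e-8) -> torch.Tensor:
    # torch.clamp floors small values at `tolerance` instead of replacing them
    # with 1.0 (the previous torch.where behavior), so the scale stays
    # continuous around the threshold.
    return torch.clamp(sigma, min=tolerance)
```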

AnMakc commented Nov 25, 2025

@kashif
Thanks for bringing it here.
Could you also fix the JAX version? It has the same issue:

masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
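For comparison, a minimal sketch of the centered fix in JAX, reusing the names from the quoted line; the surrounding mask handling is an assumption, not the actual code:

```python
import jax.numpy as jnp

def masked_variance(values, mask):
    mask = mask.astype(values.dtype)
    num_valid_elements = jnp.maximum(mask.sum(), 1.0)  # avoid division by zero
    masked_mean = (values * mask).sum() / num_valid_elements
    # Centered E[(X - mu)^2] instead of E[X^2] - E[X]^2, which suffers from
    # catastrophic cancellation and can go negative in low precision.
    masked_var = (((values - masked_mean) ** 2) * mask).sum() / num_valid_elements
    return masked_var
```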

Btw, this potentially affected model pre-training as well.
