Skip to content

Add weights_only=True to torch.load() calls#58

Open
michaelv2 wants to merge 1 commit intosaprmarks:mainfrom
michaelv2:fix/unsafe-torch-load
Open

Add weights_only=True to torch.load() calls#58
michaelv2 wants to merge 1 commit intosaprmarks:mainfrom
michaelv2:fix/unsafe-torch-load

Conversation

@michaelv2
Copy link
Copy Markdown

@michaelv2 michaelv2 commented Apr 8, 2026

Summary

  • Adds weights_only=True to all 8 torch.load() call sites across the codebase
  • Prevents arbitrary code execution via crafted .pt files (pickle deserialization attack)
  • Aligns with PyTorch's recommended practice since 2.0 and addresses CVE-2025-32434 for users on PyTorch < 2.6.0 (weights_only is available since 2.0 but defaults to False, so this is no-op on 2.6+ and avoids breaking changes for 2.0-2.5)

Context

torch.load() without weights_only=True uses Python's pickle module, which can execute arbitrary code during deserialization. Since all from_pretrained() methods in this repo load state dicts (plain dicts of tensors), weights_only=True is sufficient and doesn't change behavior for legitimate model files.

This is a narrow attack chain if the .pt files are only ever loaded from the official HuggingFace account, but in a shared research environment where files may be routinely passed around, a malicious file could propagate.

Affected call sites:

  • dictionary.py: AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder, AutoEncoderNew from_pretrained()
  • trainers/top_k.py: AutoEncoderTopK from_pretrained()
  • trainers/batch_top_k.py: BatchTopKSAE from_pretrained()
  • trainers/matryoshka_batch_top_k.py: MatryoshkaBatchTopKSAE from_pretrained()
  • activault_s3_buffer.py: compile()

Test plan

  • Existing pretrained models load correctly with weights_only=True
  • pytest passes (no changes to logic, only to deserialization safety)

PyTorch's pickle-based deserialization can execute arbitrary code when
loading a crafted .pt file. Adding weights_only=True restricts
deserialization to tensor data only, preventing this class of attack.

This is the recommended practice since PyTorch 2.0 and addresses
CVE-2025-32434 for users on PyTorch < 2.6.

Affected call sites:
- dictionary.py: AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder,
  AutoEncoderNew from_pretrained()
- trainers/top_k.py: AutoEncoderTopK from_pretrained()
- trainers/batch_top_k.py: BatchTopKSAE from_pretrained()
- trainers/matryoshka_batch_top_k.py: MatryoshkaBatchTopKSAE from_pretrained()
- activault_s3_buffer.py: compile()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant