diff --git a/pyproject.toml b/pyproject.toml index 2a4eb2f..f65b9c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "Sparse Parameter Decomposition" requires-python = ">=3.11" readme = "README.md" dependencies = [ - "torch<2.6.0", + "torch>2.4.1,<2.6.0", "torchvision", "pydantic", "wandb", diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 1d9a189..88892b5 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -49,7 +49,11 @@ def plot_lm_results( def main( - config_path_or_obj: Path | str | Config, sweep_config_path: Path | str | None = None + config_path_or_obj: Path | str | Config, + sweep_config_path: Path | str | None = None, + # due to pytorch vuln https://nvd.nist.gov/vuln/detail/CVE-2025-32434 + # hf `transformers` library doesn't allow loading with weights_only=True + weights_only: bool = True, ) -> None: config = load_config(config_path_or_obj, config_model=Config) @@ -72,6 +76,7 @@ def main( path_to_class=config.pretrained_model_class, model_path=None, model_name_hf=config.pretrained_model_name_hf, + weights_only=weights_only, ) # --- Setup Run Name and Output Dir --- #