From ca325bcb43c41e8a74ff1f921e9a0bb7e80060fb Mon Sep 17 00:00:00 2001 From: Adam Foster Date: Fri, 16 Jan 2026 16:11:00 +0000 Subject: [PATCH 1/4] Add warnings and protections around loading pickles --- README.md | 5 +++++ src/oneqmc/entrypoint.py | 4 ++++ src/oneqmc/log.py | 4 ++++ tests/integration_tests/test_scripts.py | 2 ++ 4 files changed, 15 insertions(+) diff --git a/README.md b/README.md index 8bb82fd..57877e1 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ ![header](envelope.png) +**Warning: Orbformer checkpoints are stored using `pickle`. Never read a checkpoint from an untrusted source.** + # OneQMC This package provides an implementation of the [Orbformer wave function foundation model](https://arxiv.org/abs/2506.19960). @@ -61,6 +63,9 @@ python scripts/transferable.py -d -n Tuple[TrainState, int]: + if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": + raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow") with open(chkpt, "rb") as chkpt_file: init_step, (smpl_state, param_state, opt_state) = pickle.load(chkpt_file) if discard_sampler_state: @@ -28,6 +30,8 @@ def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState def load_density_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[Tuple, int]: + if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": + raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow") with open(chkpt, "rb") as chkpt_file: init_step, (param_state, opt_state) = pickle.load(chkpt_file) return (param_state, opt_state), init_step diff --git a/src/oneqmc/log.py b/src/oneqmc/log.py index 4615eff..74add48 100644 --- a/src/oneqmc/log.py +++ b/src/oneqmc/log.py @@ -133,6 +133,8 @@ def close(self): def last(self): step_fast, step_slow = -1, -1 # account for the case where a queue is not initialized while self.fast_chkpts: + if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": + raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow") with self.fast_chkpts.pop(-1).path.open("rb") as f: step_fast, last_chkpt_fast = pickle.load(f) if not jax.tree.reduce( @@ -142,6 +144,8 @@ def last(self): ): break while self.slow_chkpts: + if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": + raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow") with self.slow_chkpts.pop(-1).path.open("rb") as f: step_slow, last_chkpt_slow = pickle.load(f) if not jax.tree.reduce( diff --git a/tests/integration_tests/test_scripts.py b/tests/integration_tests/test_scripts.py index 2dab3f5..faa9205 100644 --- a/tests/integration_tests/test_scripts.py +++ b/tests/integration_tests/test_scripts.py @@ -52,6 +52,7 @@ def runner(extra_args): ], cwd=project_root, capture_output=True, + env={"ORBFORMER_PICKLE_LOADING": "1"} ) if result.returncode != 0: raise OneQMCProcessError(result.stderr.decode()) @@ -90,6 +91,7 @@ def runner(extra_args): ], cwd=project_root, capture_output=True, + env={"ORBFORMER_PICKLE_LOADING": "1"} ) if result.returncode != 0: raise OneQMCProcessError(result.stderr.decode()) From 4bf6ec55ed25a29c926471f262f9919338736b83 Mon Sep 17 00:00:00 2001 From: Adam Foster Date: Fri, 16 Jan 2026 16:15:28 +0000 Subject: [PATCH 2/4] Formatting --- README.md | 2 +- src/oneqmc/entrypoint.py | 8 ++++++-- src/oneqmc/log.py | 8 ++++++-- tests/integration_tests/test_scripts.py | 4 ++-- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 57877e1..e63f621 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ We recommend using distinct output directories for every training run. Regarding other optional arguments, run `python scripts/transferable.py -h` for more information. **Note**: Checkpoints are stored using `pickle`. This means that opening a checkpoint from an untrusted source is a major security risk. Only ever open checkpoint files from trusted sources. To prevent reading untrusted pickle files, checkpoint reading is disabled by default and can be -re-enabled by settings the environment variables `ORBFORMER_PICKLE_LOADING=1`. +re-enabled by settings the environment variable `ORBFORMER_PICKLE_LOADING=1`. ### Preparing new structure data for fine-tuning diff --git a/src/oneqmc/entrypoint.py b/src/oneqmc/entrypoint.py index 415c873..3ad07ff 100644 --- a/src/oneqmc/entrypoint.py +++ b/src/oneqmc/entrypoint.py @@ -19,7 +19,9 @@ def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState, int]: if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": - raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow") + raise PermissionError( + "Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + ) with open(chkpt, "rb") as chkpt_file: init_step, (smpl_state, param_state, opt_state) = pickle.load(chkpt_file) if discard_sampler_state: @@ -31,7 +33,9 @@ def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState def load_density_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[Tuple, int]: if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": - raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow") + raise PermissionError( + "Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + ) with open(chkpt, "rb") as chkpt_file: init_step, (param_state, opt_state) = pickle.load(chkpt_file) return (param_state, opt_state), init_step diff --git a/src/oneqmc/log.py b/src/oneqmc/log.py index 74add48..36294c5 100644 --- a/src/oneqmc/log.py +++ b/src/oneqmc/log.py @@ -134,7 +134,9 @@ def last(self): step_fast, step_slow = -1, -1 # account for the case where a queue is not initialized while self.fast_chkpts: if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": - raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow") + raise PermissionError( + "Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + ) with self.fast_chkpts.pop(-1).path.open("rb") as f: step_fast, last_chkpt_fast = pickle.load(f) if not jax.tree.reduce( @@ -145,7 +147,9 @@ def last(self): break while self.slow_chkpts: if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": - raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow") + raise PermissionError( + "Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + ) with self.slow_chkpts.pop(-1).path.open("rb") as f: step_slow, last_chkpt_slow = pickle.load(f) if not jax.tree.reduce( diff --git a/tests/integration_tests/test_scripts.py b/tests/integration_tests/test_scripts.py index faa9205..708179c 100644 --- a/tests/integration_tests/test_scripts.py +++ b/tests/integration_tests/test_scripts.py @@ -52,7 +52,7 @@ def runner(extra_args): ], cwd=project_root, capture_output=True, - env={"ORBFORMER_PICKLE_LOADING": "1"} + env={"ORBFORMER_PICKLE_LOADING": "1"}, ) if result.returncode != 0: raise OneQMCProcessError(result.stderr.decode()) @@ -91,7 +91,7 @@ def runner(extra_args): ], cwd=project_root, capture_output=True, - env={"ORBFORMER_PICKLE_LOADING": "1"} + env={"ORBFORMER_PICKLE_LOADING": "1"}, ) if result.returncode != 0: raise OneQMCProcessError(result.stderr.decode()) From 860307bb61071d0ce025459e6c4e8a664cecfed8 Mon Sep 17 00:00:00 2001 From: Adam Foster Date: Fri, 16 Jan 2026 16:16:53 +0000 Subject: [PATCH 3/4] Typo --- src/oneqmc/entrypoint.py | 4 ++-- src/oneqmc/log.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/oneqmc/entrypoint.py b/src/oneqmc/entrypoint.py index 3ad07ff..643d4f3 100644 --- a/src/oneqmc/entrypoint.py +++ b/src/oneqmc/entrypoint.py @@ -20,7 +20,7 @@ def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState, int]: if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": raise PermissionError( - "Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + "Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" ) with open(chkpt, "rb") as chkpt_file: init_step, (smpl_state, param_state, opt_state) = pickle.load(chkpt_file) @@ -34,7 +34,7 @@ def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState def load_density_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[Tuple, int]: if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": raise PermissionError( - "Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + "Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" ) with open(chkpt, "rb") as chkpt_file: init_step, (param_state, opt_state) = pickle.load(chkpt_file) diff --git a/src/oneqmc/log.py b/src/oneqmc/log.py index 36294c5..941bca7 100644 --- a/src/oneqmc/log.py +++ b/src/oneqmc/log.py @@ -135,7 +135,7 @@ def last(self): while self.fast_chkpts: if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": raise PermissionError( - "Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + "Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" ) with self.fast_chkpts.pop(-1).path.open("rb") as f: step_fast, last_chkpt_fast = pickle.load(f) @@ -148,7 +148,7 @@ def last(self): while self.slow_chkpts: if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": raise PermissionError( - "Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + "Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" ) with self.slow_chkpts.pop(-1).path.open("rb") as f: step_slow, last_chkpt_slow = pickle.load(f) From b4e36bb456405adf01ea3ef0aed96fe5605a4f76 Mon Sep 17 00:00:00 2001 From: Adam Foster Date: Fri, 16 Jan 2026 16:28:36 +0000 Subject: [PATCH 4/4] Only run tests on py3.11 (conda was overwriting this anyway). Fix env --- .github/workflows/pytest.yaml | 2 +- tests/integration_tests/test_scripts.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 99cc3bf..ed5c043 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.11"] steps: - uses: actions/checkout@v4 diff --git a/tests/integration_tests/test_scripts.py b/tests/integration_tests/test_scripts.py index 708179c..8ef6787 100644 --- a/tests/integration_tests/test_scripts.py +++ b/tests/integration_tests/test_scripts.py @@ -43,6 +43,8 @@ def run_transferable_script(project_root): ] def runner(extra_args): + env = os.environ.copy() + env["ORBFORMER_PICKLE_LOADING"] = "1" result = subprocess.run( [ "python", @@ -52,7 +54,7 @@ def runner(extra_args): ], cwd=project_root, capture_output=True, - env={"ORBFORMER_PICKLE_LOADING": "1"}, + env=env, ) if result.returncode != 0: raise OneQMCProcessError(result.stderr.decode()) @@ -82,6 +84,8 @@ def run_density_script(project_root): ] def runner(extra_args): + env = os.environ.copy() + env["ORBFORMER_PICKLE_LOADING"] = "1" result = subprocess.run( [ "python", @@ -91,7 +95,7 @@ def runner(extra_args): ], cwd=project_root, capture_output=True, - env={"ORBFORMER_PICKLE_LOADING": "1"}, + env=env, ) if result.returncode != 0: raise OneQMCProcessError(result.stderr.decode())