diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 99cc3bf..ed5c043 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.11"] steps: - uses: actions/checkout@v4 diff --git a/README.md b/README.md index 8bb82fd..e63f621 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ ![header](envelope.png) +**Warning: Orbformer checkpoints are stored using `pickle`. Never read a checkpoint from an untrusted source.** + # OneQMC This package provides an implementation of the [Orbformer wave function foundation model](https://arxiv.org/abs/2506.19960). @@ -61,6 +63,9 @@ python scripts/transferable.py -d -n Tuple[TrainState, int]: + if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": + raise PermissionError( + "Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + ) with open(chkpt, "rb") as chkpt_file: init_step, (smpl_state, param_state, opt_state) = pickle.load(chkpt_file) if discard_sampler_state: @@ -28,6 +32,10 @@ def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState def load_density_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[Tuple, int]: + if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": + raise PermissionError( + "Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + ) with open(chkpt, "rb") as chkpt_file: init_step, (param_state, opt_state) = pickle.load(chkpt_file) return (param_state, opt_state), init_step diff --git a/src/oneqmc/log.py b/src/oneqmc/log.py index 4615eff..941bca7 100644 --- a/src/oneqmc/log.py +++ b/src/oneqmc/log.py @@ -133,6 +133,10 @@ def close(self): def last(self): step_fast, step_slow = -1, -1 # account for the case where a queue is not initialized while self.fast_chkpts: + if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": + raise PermissionError( + "Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + ) with self.fast_chkpts.pop(-1).path.open("rb") as f: step_fast, last_chkpt_fast = pickle.load(f) if not jax.tree.reduce( @@ -142,6 +146,10 @@ def last(self): ): break while self.slow_chkpts: + if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1": + raise PermissionError( + "Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow" + ) with self.slow_chkpts.pop(-1).path.open("rb") as f: step_slow, last_chkpt_slow = pickle.load(f) if not jax.tree.reduce( diff --git a/tests/integration_tests/test_scripts.py b/tests/integration_tests/test_scripts.py index 2dab3f5..8ef6787 100644 --- a/tests/integration_tests/test_scripts.py +++ b/tests/integration_tests/test_scripts.py @@ -43,6 +43,8 @@ def run_transferable_script(project_root): ] def runner(extra_args): + env = os.environ.copy() + env["ORBFORMER_PICKLE_LOADING"] = "1" result = subprocess.run( [ "python", @@ -52,6 +54,7 @@ def runner(extra_args): ], cwd=project_root, capture_output=True, + env=env, ) if result.returncode != 0: raise OneQMCProcessError(result.stderr.decode()) @@ -81,6 +84,8 @@ def run_density_script(project_root): ] def runner(extra_args): + env = os.environ.copy() + env["ORBFORMER_PICKLE_LOADING"] = "1" result = subprocess.run( [ "python", @@ -90,6 +95,7 @@ def runner(extra_args): ], cwd=project_root, capture_output=True, + env=env, ) if result.returncode != 0: raise OneQMCProcessError(result.stderr.decode())