Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.11"]

steps:
- uses: actions/checkout@v4
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
![header](envelope.png)

**Warning: Orbformer checkpoints are stored using `pickle`. Never read a checkpoint from an untrusted source.**

# OneQMC

This package provides an implementation of the [Orbformer wave function foundation model](https://arxiv.org/abs/2506.19960).
Expand Down Expand Up @@ -61,6 +63,9 @@ python scripts/transferable.py -d <subdirectory of ./data> -n <number of trainin
We recommend using distinct output directories for every training run.
Regarding other optional arguments, run `python scripts/transferable.py -h` for more information.

**Note**: Checkpoints are stored using `pickle`. This means that opening a checkpoint from an untrusted source is a major security risk. Only ever open checkpoint files from trusted sources. To prevent reading untrusted pickle files, checkpoint reading is disabled by default and can be
re-enabled by settings the environment variable `ORBFORMER_PICKLE_LOADING=1`.

### Preparing new structure data for fine-tuning

To create a new dataset for fine-tuning, we recommend using [qcelemental](https://github.com/MolSSI/QCElemental) format.
Expand Down
8 changes: 8 additions & 0 deletions src/oneqmc/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@


def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState, int]:
if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1":
raise PermissionError(
"Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow"
)
with open(chkpt, "rb") as chkpt_file:
init_step, (smpl_state, param_state, opt_state) = pickle.load(chkpt_file)
if discard_sampler_state:
Expand All @@ -28,6 +32,10 @@ def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState


def load_density_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[Tuple, int]:
if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1":
raise PermissionError(
"Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow"
)
with open(chkpt, "rb") as chkpt_file:
init_step, (param_state, opt_state) = pickle.load(chkpt_file)
return (param_state, opt_state), init_step
Expand Down
8 changes: 8 additions & 0 deletions src/oneqmc/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def close(self):
def last(self):
step_fast, step_slow = -1, -1 # account for the case where a queue is not initialized
while self.fast_chkpts:
if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1":
raise PermissionError(
"Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow"
)
with self.fast_chkpts.pop(-1).path.open("rb") as f:
step_fast, last_chkpt_fast = pickle.load(f)
if not jax.tree.reduce(
Expand All @@ -142,6 +146,10 @@ def last(self):
):
break
while self.slow_chkpts:
if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1":
raise PermissionError(
"Loading pickle files is disabled for security. Set ORBFORMER_PICKLE_LOADING=1 to allow"
)
with self.slow_chkpts.pop(-1).path.open("rb") as f:
step_slow, last_chkpt_slow = pickle.load(f)
if not jax.tree.reduce(
Expand Down
6 changes: 6 additions & 0 deletions tests/integration_tests/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def run_transferable_script(project_root):
]

def runner(extra_args):
env = os.environ.copy()
env["ORBFORMER_PICKLE_LOADING"] = "1"
result = subprocess.run(
[
"python",
Expand All @@ -52,6 +54,7 @@ def runner(extra_args):
],
cwd=project_root,
capture_output=True,
env=env,
)
if result.returncode != 0:
raise OneQMCProcessError(result.stderr.decode())
Expand Down Expand Up @@ -81,6 +84,8 @@ def run_density_script(project_root):
]

def runner(extra_args):
env = os.environ.copy()
env["ORBFORMER_PICKLE_LOADING"] = "1"
result = subprocess.run(
[
"python",
Expand All @@ -90,6 +95,7 @@ def runner(extra_args):
],
cwd=project_root,
capture_output=True,
env=env,
)
if result.returncode != 0:
raise OneQMCProcessError(result.stderr.decode())
Expand Down