FragmentFlow is a flow matching model for transition state structure prediction that takes a fragment-based approach: it predicts the geometry of the reactive core, which can then be completed to the full structure by re-attaching the substituent atoms. In this way, FragmentFlow avoids errors that stem from a distribution shift in molecular size.
This repository contains training, inference, and evaluation code.
```bash
conda env create -f fragmentflow.yaml
conda activate fragmentflow
pip install -e .
```

Warning
The WLN mapper (localmapper) is not included in the main fragmentflow conda environment.
To use it for dataset curation, install it separately via:
```bash
pip install localmapper
```
Note: For compatibility, especially due to PyTorch/DGL version constraints, you may need to install localmapper in a separate conda environment.
Train the model using distributed training with torchrun:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29503 \
-m reactot.train.main_train \
--distributed \
--diffusion_type sb \
--wandb_project "your_project" \
--wandb_job_name "run_name" \
--wandb_api_key YOUR_API_KEY \
--datadir /path/to/dataset \
--output_dir ./output/run_name \
--max_nodes_per_batch 10000 \
--num_workers 16 \
--epochs 10000 \
--timesteps 3000 \
--lr 1e-5 \
--nfe 25 \
--ot_ode \
--ema_decay 0.999 \
--ckpt_epochs 100 \
--eval_epochs 50 \
--seed 0
```

To resume training from a checkpoint:

```bash
--resume /path/to/checkpoint.ckpt
```

Note
This training script has been adapted from the original ReactOT implementation to remove the dependency on PyTorch Lightning, allowing for standard PyTorch-based training.
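For orientation, here is a minimal, self-contained sketch of the standard torchrun/DistributedDataParallel pattern that such a script builds on. This is not the repository's training loop; a toy model and random data stand in for FragmentFlow's actual model and dataset.

```python
# Minimal torchrun-compatible DDP loop; illustrative only.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # Toy stand-ins for the real model and data.
    model = DDP(torch.nn.Linear(16, 1).cuda(local_rank), device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    for step in range(100):
        x = torch.randn(32, 16, device=local_rank)
        y = torch.zeros(32, 1, device=local_rank)
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Launch it the same way as above, e.g. `torchrun --nproc_per_node=4 script.py`.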
Generate transition state predictions from a trained model:
```bash
python -m reactot.sample_ts \
--nfe-values 200 \
--output-folder /path/to/output \
--checkpoint /path/to/checkpoint.ckpt \
--dataset-path /path/to/dataset.pkl \
--batch-size 20 \
--solver ddpm \
--gpu 0 \
--fragmentation-mode epsilon \
--overwrite
```

Arguments:

- `--nfe-values`: Number of function evaluations for sampling
- `--fragmentation-mode`: Filter dataset by fragmentation level (`original`, `epsilon`, or `all`)
- `--overwrite`: Reprocess existing structures
Output structure:
```
output-folder/
├── ref_ts/structures/xyz_files/   # Reference TS structures
├── nfe_200/
│   ├── structures/xyz_files/      # Predicted structures
│   └── reports/                   # Metrics and logs
```
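To inspect the predictions programmatically, the `.xyz` files can be loaded with ASE. A minimal sketch, assuming ASE is installed and that predicted and reference files share names (adjust the paths to your `--output-folder`):

```python
# Sketch: pair predicted and reference TS structures from the output layout above.
from pathlib import Path
from ase.io import read

output_folder = Path("/path/to/output")
pred_dir = output_folder / "nfe_200" / "structures" / "xyz_files"
ref_dir = output_folder / "ref_ts" / "structures" / "xyz_files"

for pred_file in sorted(pred_dir.glob("*.xyz")):
    ref_file = ref_dir / pred_file.name  # assumes matching file names
    if not ref_file.exists():
        continue
    pred = read(pred_file)  # ase.Atoms with the predicted geometry
    ref = read(ref_file)    # ase.Atoms with the reference geometry
    print(pred_file.name, len(pred), "atoms")
```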
Compute UMA energies for generated structures:
```bash
python -m evaluation.calculate_energies_uma \
--xyz-dir /path/to/xyz_files \
--output-path /path/to/energies.csv \
--model uma-s-1p1 \
--gpu 0
```

Compare energies between reference and predicted structures:

```bash
python -m evaluation.calculate_energy_differences \
--ref-csv /path/to/ref_energies.csv \
--sample-csv /path/to/sample_energies.csv \
--output-csv /path/to/energy_diff.csv \
--ref-units kcal/mol \
--sample-units kcal/mol \
--overwrite
```
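The underlying computation is a join on structure identity plus a unit conversion. A rough sketch (the `name`/`energy` column names are assumptions for illustration, not the script's actual CSV schema):

```python
# Sketch: energy differences between reference and predicted structures.
import pandas as pd

HARTREE_TO_KCAL = 627.5094740631  # handy if one CSV is in Hartree

ref = pd.read_csv("/path/to/ref_energies.csv")        # columns assumed: name, energy
sample = pd.read_csv("/path/to/sample_energies.csv")  # columns assumed: name, energy

merged = ref.merge(sample, on="name", suffixes=("_ref", "_sample"))
# Both inputs are assumed to already be in kcal/mol (matching --ref-units/--sample-units);
# convert first if they are not.
merged["delta_e_kcal_mol"] = merged["energy_sample"] - merged["energy_ref"]
merged.to_csv("/path/to/energy_diff.csv", index=False)
print(merged["delta_e_kcal_mol"].abs().describe())
```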
Compute RMSD between reference and predicted structures:

```bash
python -m evaluation.calculate_rmsd \
--ref-xyz-dir /path/to/reference/xyz_files \
--sample-xyz-dir /path/to/predicted/xyz_files \
--output-csv /path/to/rmsd.csv
```
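For reference, RMSD after optimal rigid alignment can be computed directly with NumPy via the Kabsch algorithm. A minimal sketch, assuming the two structures share atom ordering (the evaluation script may handle more than this):

```python
# Sketch: RMSD after Kabsch alignment for two (N, 3) coordinate arrays.
import numpy as np

def kabsch_rmsd(p: np.ndarray, q: np.ndarray) -> float:
    """RMSD between p and q (same atom ordering) after optimally rotating p onto q."""
    p = p - p.mean(axis=0)
    q = q - q.mean(axis=0)
    u, _, vt = np.linalg.svd(p.T @ q)   # SVD of the covariance matrix
    d = np.sign(np.linalg.det(u @ vt))  # reflection correction
    rot = u @ np.diag([1.0, 1.0, d]) @ vt
    return float(np.sqrt(np.mean(np.sum((p @ rot - q) ** 2, axis=1))))

# Example usage with ASE-loaded structures:
# from ase.io import read
# print(kabsch_rmsd(read("pred.xyz").get_positions(), read("ref.xyz").get_positions()))
```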
Complete fragmented core structures by attaching substituents using IDPP interpolation:

```bash
python -m post_processing.attach_substituents.complete_from_xyz_cores \
--xyz_folder /path/to/core_structures \
--pkl_path /path/to/fragmented_dataset.pkl \
--output_folder /path/to/completed_structures \
--num_idpp_images 11
```

This takes fragmented transition state cores and completes them by:
- Loading the original reactant/product structures from the dataset
- Finding the best IDPP image that matches the core
- Using Kabsch alignment and constrained IDPP refinement to position substituents (a minimal IDPP sketch follows below)
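The IDPP image generation itself is available in ASE. A minimal sketch of just that step (the completion script additionally matches the predicted core and applies the Kabsch alignment and constrained refinement described above):

```python
# Sketch: generate IDPP images between a reactant and a product with ASE.
from ase.io import read, write
from ase.neb import NEB  # ase.mep.NEB in newer ASE releases

reactant = read("reactant.xyz")
product = read("product.xyz")  # assumed to share atom ordering with the reactant

n_images = 11  # illustrative; corresponds loosely to --num_idpp_images above
images = [reactant] + [reactant.copy() for _ in range(n_images)] + [product]

neb = NEB(images)
neb.interpolate(method="idpp")  # image-dependent pair potential interpolation

write("idpp_path.xyz", images)  # multi-frame xyz; pick the image that best matches the core
```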
Optimize transition state structures using Sella with UMA or PySCF:
```bash
python -m post_processing.sella_optimization.optimize_structures_sequential \
--xyz-dir /path/to/xyz_files \
--output-dir /path/to/optimized \
--calculator uma \
--model uma-s-1p1 \
--device cuda \
--fmax 0.05 \
--steps 500 \
--process-by-size
```

Arguments:

- `--calculator`: Energy calculator (`uma` or `pyscf`)
- `--fmax`: Force convergence threshold
- `--steps`: Maximum optimization steps
- `--process-by-size`: Process structures from smallest to largest
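A minimal sketch of a Sella run driven by an ASE calculator (the UMA/PySCF setup is omitted, EMT is only a toy stand-in so the snippet runs, and the exact Sella options used by the script are not shown here):

```python
# Sketch: refine a TS guess with Sella and an ASE calculator.
from ase.io import read, write
from ase.calculators.emt import EMT  # toy stand-in; use your UMA/PySCF calculator in practice
from sella import Sella

atoms = read("ts_guess.xyz")
atoms.calc = EMT()

opt = Sella(atoms, order=1, trajectory="sella.traj")  # order=1: first-order saddle point
opt.run(fmax=0.05, steps=500)                         # matches --fmax / --steps above

write("ts_optimized.xyz", atoms)
```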
Create fragmented versions of a Transition1x dataset using backbone classification:
```bash
python -m dataset_utils.fragment_dataset.fragment_t1x \
--input /path/to/input_dataset.pkl \
--output /path/to/fragmented_dataset.pkl \
--summary /path/to/summary.txt \
--epsilon-levels 0,1,2 \
--random-continuous-fractions 0 \
--threshold-continuous-fractions 0.6
```

This identifies backbone atoms using Murcko scaffolds and reactive center detection, then creates fragmented versions at different epsilon levels (the number of shells of neighboring atoms to include: epsilon=0 keeps only the reactive core atoms, and epsilon=k for k > 0 also includes atoms that are at most k bonds away from an atom in the reactive core).
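As an illustration of the epsilon-shell idea (not the script's actual implementation, which also relies on Murcko scaffolds and reaction-center detection), expanding a set of core atoms by k bond shells is a breadth-first walk over the bond graph:

```python
# Sketch: expand a set of "reactive core" atom indices by epsilon bond shells.
from rdkit import Chem

def expand_core(mol: Chem.Mol, core_atoms: set, epsilon: int) -> set:
    """Return core_atoms plus every atom within `epsilon` bonds of the core."""
    selected = set(core_atoms)
    frontier = set(core_atoms)
    for _ in range(epsilon):
        nxt = set()
        for idx in frontier:
            for nbr in mol.GetAtomWithIdx(idx).GetNeighbors():
                if nbr.GetIdx() not in selected:
                    nxt.add(nbr.GetIdx())
        selected |= nxt
        frontier = nxt
    return selected

# Example on a hypothetical molecule: epsilon=0 keeps only the core atom,
# epsilon=1 adds its direct neighbors, and so on.
mol = Chem.MolFromSmiles("CCOC(=O)CC")
print(sorted(expand_core(mol, {3}, epsilon=1)))
```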
Note
The dataset_utils/larget1x directory contains scripts for generating fully mapped SMILES strings and validating reactions using TS-tools. To obtain 3D geometries of reactants, products, and transition states, please use the TS-tools package.
Create a dataset from reactions validated with TS-tools:
```bash
python -m dataset_utils.larget1x.create_fragmented_pkl_dataset \
--csv-path /path/to/extended_dataset_smiles.csv \
--dataset-root /path/to/reaction_folders \
--filtered-dir /path/to/filtered_reactions \
--output-path /path/to/output_dataset.pkl
```

Generate linear interpolation or IDPP transition state guesses:

```bash
python -m dataset_utils.simple_benchmarks.generate_trajectory_ts \
--dataset-path /path/to/dataset.pkl \
--output-folder /path/to/output \
--mode idpp \
--num-structures 10 \
--model uma-s-1p1 \
--gpu 0 \
--fragmentation-mode original
```

Arguments:

- `--mode`: Trajectory generation method (`interpolation` or `idpp`)
- `--num-structures`: Number of intermediate structures in the trajectory
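The `interpolation` mode corresponds to a plain linear blend of the endpoint geometries. Conceptually (a sketch, assuming the reactant and product are already aligned and share atom ordering):

```python
# Sketch: linear interpolation between aligned reactant and product geometries.
import numpy as np
from ase.io import read, write

reactant = read("reactant.xyz")
product = read("product.xyz")

num_structures = 10  # interior images, as in --num-structures
images = []
for t in np.linspace(0.0, 1.0, num_structures + 2)[1:-1]:
    image = reactant.copy()
    image.set_positions((1.0 - t) * reactant.get_positions() + t * product.get_positions())
    images.append(image)

write("linear_guess_path.xyz", images)
```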
Example transition state predictions comparing FragmentFlow output to reference structures from LargeT1x.
| FragmentFlow + Sella Optimization | Reference |
|---|---|
| *(predicted structure image)* | *(reference structure image)* |
| *(predicted structure image)* | *(reference structure image)* |
| *(predicted structure image)* | *(reference structure image)* |
If you use FragmentFlow in your research, please cite:
```bibtex
@article{shprints2026fragmentflow,
  title={FragmentFlow: Scalable Transition State Generation for Large Molecules},
  author={Shprints, Ron and Holderrieth, Peter and Nam, Juno and G{\'o}mez-Bombarelli, Rafael and Jaakkola, Tommi},
  journal={arXiv preprint arXiv:2602.02310},
  year={2026}
}
```





