Scalable transition state generation for large molecules.

ronsh9/FragmentFlow

FragmentFlow: Scalable Transition State Generation for Large Molecules

[Paper], [Data]

FragmentFlow is a flow matching model for transition state structure prediction that takes a fragment-based approach. It predicts the geometry of the reactive core, which is then completed to the full structure by re-attaching the substituent atoms. Because the model only ever sees core-sized inputs, it avoids errors that stem from a distribution shift in molecular size.

This repository contains training, inference, and evaluation code.

Installation

conda env create -f fragmentflow.yaml
conda activate fragmentflow
pip install -e .

Warning

The WLN mapper (localmapper) is not included in the main fragmentflow conda environment.
To use it for dataset curation, install it separately via:

pip install localmapper

Note: For compatibility, especially due to PyTorch/DGL version constraints, you may need to install localmapper in a separate conda environment.

Training

Train the model using distributed training with torchrun:

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29503 \
  -m reactot.train.main_train \
  --distributed \
  --diffusion_type sb \
  --wandb_project "your_project" \
  --wandb_job_name "run_name" \
  --wandb_api_key YOUR_API_KEY \
  --datadir /path/to/dataset \
  --output_dir ./output/run_name \
  --max_nodes_per_batch 10000 \
  --num_workers 16 \
  --epochs 10000 \
  --timesteps 3000 \
  --lr 1e-5 \
  --nfe 25 \
  --ot_ode \
  --ema_decay 0.999 \
  --ckpt_epochs 100 \
  --eval_epochs 50 \
  --seed 0

To resume training from a checkpoint, add the following flag to the command above:

--resume /path/to/checkpoint.ckpt

Note

This training script has been adapted from the original ReactOT implementation to remove the dependency on PyTorch Lightning, allowing for standard PyTorch-based training.
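
The `--ema_decay 0.999` flag in the training command suggests that an exponential moving average (EMA) of the model weights is tracked during training. The snippet below is a minimal sketch of the standard EMA update rule, assuming the usual form `ema = decay * ema + (1 - decay) * param`. The function name `ema_update` is hypothetical and the code is not taken from this repository.

```python
def ema_update(ema_params, params, decay=0.999):
    """Return the EMA of the parameters after one training step.

    Both arguments are flat lists of floats here; a real implementation
    would apply the same rule tensor by tensor.
    """
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


# Illustrative values only: the EMA moves slowly toward the new parameters.
ema = [0.0, 1.0]
ema = ema_update(ema, [1.0, 1.0], decay=0.999)
print(ema)
```

With a decay close to 1, the averaged weights change slowly, which is why EMA weights are often preferred for evaluation and sampling.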

Sampling

Generate transition state predictions from a trained model:

python -m reactot.sample_ts \
    --nfe-values 200 \
    --output-folder /path/to/output \
    --checkpoint /path/to/checkpoint.ckpt \
    --dataset-path /path/to/dataset.pkl \
    --batch-size 20 \
    --solver ddpm \
    --gpu 0 \
    --fragmentation-mode epsilon \
    --overwrite

Arguments:

  • --nfe-values: Number of function evaluations for sampling
  • --fragmentation-mode: Filter dataset by fragmentation level (original, epsilon, or all)
  • --overwrite: Reprocess existing structures

Output structure:

output-folder/
├── ref_ts/structures/xyz_files/      # Reference TS structures
├── nfe_200/
│   ├── structures/xyz_files/         # Predicted structures
│   └── reports/                      # Metrics and logs
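
The `xyz_files` folders presumably hold structures in the plain XYZ format (an atom count, a comment line, then one `element x y z` row per atom). The sketch below reads one such structure from a string under that assumption. `read_xyz` is a hypothetical helper and is not part of this repository.

```python
def read_xyz(text):
    """Parse one XYZ-format structure into (elements, coordinates)."""
    lines = text.strip().splitlines()
    n_atoms = int(lines[0])
    # lines[1] is the free-form comment line; atom rows start at index 2.
    elements, coords = [], []
    for row in lines[2:2 + n_atoms]:
        symbol, x, y, z = row.split()[:4]
        elements.append(symbol)
        coords.append((float(x), float(y), float(z)))
    if len(elements) != n_atoms:
        raise ValueError("atom count does not match header")
    return elements, coords


# Illustrative geometry, not taken from any dataset.
example = """3
water (illustrative geometry)
O 0.000 0.000 0.117
H 0.000 0.757 -0.471
H 0.000 -0.757 -0.471
"""
elements, coords = read_xyz(example)
print(elements, len(coords))
```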

Evaluation

Calculate Energies

Compute UMA energies for generated structures:

python -m evaluation.calculate_energies_uma \
    --xyz-dir /path/to/xyz_files \
    --output-path /path/to/energies.csv \
    --model uma-s-1p1 \
    --gpu 0

Calculate Energy Differences

Compare energies between reference and predicted structures:

python -m evaluation.calculate_energy_differences \
    --ref-csv /path/to/ref_energies.csv \
    --sample-csv /path/to/sample_energies.csv \
    --output-csv /path/to/energy_diff.csv \
    --ref-units kcal/mol \
    --sample-units kcal/mol \
    --overwrite
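
The `--ref-units` and `--sample-units` flags suggest that both energies are brought to a common unit before they are subtracted. The sketch below illustrates that idea with standard conversion factors. The unit names other than `kcal/mol`, the sign convention (sample minus reference), and the function name `energy_difference` are assumptions, not the script's documented behavior.

```python
# Conversion factors to kcal/mol (standard values).
TO_KCAL_PER_MOL = {
    "kcal/mol": 1.0,
    "eV": 23.0605,
    "hartree": 627.5095,
}


def energy_difference(ref_energy, sample_energy,
                      ref_units="kcal/mol", sample_units="kcal/mol"):
    """Return sample minus reference energy, expressed in kcal/mol."""
    ref = ref_energy * TO_KCAL_PER_MOL[ref_units]
    sample = sample_energy * TO_KCAL_PER_MOL[sample_units]
    return sample - ref


# Illustrative numbers: reference in kcal/mol, sample in eV.
diff = energy_difference(-100.0, -4.30, ref_units="kcal/mol", sample_units="eV")
print(round(diff, 3))
```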

Calculate RMSD

Compute RMSD between reference and predicted structures:

python -m evaluation.calculate_rmsd \
    --ref-xyz-dir /path/to/reference/xyz_files \
    --sample-xyz-dir /path/to/predicted/xyz_files \
    --output-csv /path/to/rmsd.csv
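
For intuition, RMSD between two structures is usually reported after removing rigid translation and rotation. The sketch below does this with the Kabsch algorithm in NumPy, assuming both structures list the same atoms in the same order. Whether `calculate_rmsd` aligns structures in exactly this way is not stated here, and `kabsch_rmsd` is a hypothetical name.

```python
import numpy as np


def kabsch_rmsd(ref, sample):
    """RMSD between two (N, 3) coordinate arrays after optimal superposition."""
    ref = np.asarray(ref, dtype=float)
    sample = np.asarray(sample, dtype=float)
    # Remove translation by centering both structures.
    p = sample - sample.mean(axis=0)
    q = ref - ref.mean(axis=0)
    # Optimal rotation from the SVD of the covariance matrix.
    h = p.T @ q
    u, _, vt = np.linalg.svd(h)
    # Guard against an improper rotation (reflection).
    d = np.sign(np.linalg.det(vt.T @ u.T))
    rot = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    aligned = p @ rot.T
    return float(np.sqrt(((aligned - q) ** 2).sum() / len(q)))


# A rotated and translated copy should give an RMSD close to zero.
ref = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
theta = np.pi / 6
rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta), np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
moved = ref @ rz.T + np.array([1.0, -2.0, 0.5])
print(kabsch_rmsd(ref, moved))
```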

Post-Processing

Attaching Substituents

Complete fragmented core structures by attaching substituents using IDPP (image dependent pair potential) interpolation:

python -m post_processing.attach_substituents.complete_from_xyz_cores \
    --xyz_folder /path/to/core_structures \
    --pkl_path /path/to/fragmented_dataset.pkl \
    --output_folder /path/to/completed_structures \
    --num_idpp_images 11

This takes fragmented transition state cores and completes them by:

  1. Loading the original reactant/product structures from the dataset
  2. Finding the best IDPP image that matches the core
  3. Using Kabsch alignment and constrained IDPP refinement to position substituents
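
Step 2 above, picking the interpolation image that best matches the predicted core, can be illustrated with a rotation-invariant comparison. This sketch swaps in a comparison of pairwise-distance matrices over the core atoms as a simple stand-in; the repository's actual matching criterion is not described here. The names `best_matching_image` and `core_indices` are hypothetical.

```python
import numpy as np


def pairwise_distances(coords):
    """Return the (N, N) matrix of interatomic distances."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.linalg.norm(diff, axis=-1)


def best_matching_image(images, core_coords, core_indices):
    """Index of the image whose core distances are closest to core_coords."""
    target = pairwise_distances(np.asarray(core_coords, dtype=float))
    errors = []
    for image in images:
        core = np.asarray(image, dtype=float)[core_indices]
        errors.append(np.abs(pairwise_distances(core) - target).mean())
    return int(np.argmin(errors))


# Three toy images in which the 0-1 distance grows from 1.0 to 2.0;
# atom 2 plays the role of a substituent and is ignored in the match.
images = [np.array([[0.0, 0.0, 0.0], [d, 0.0, 0.0], [0.0, 2.0, 0.0]])
          for d in (1.0, 1.5, 2.0)]
core = np.array([[0.0, 0.0, 0.0], [0.0, 1.6, 0.0]])  # rotated, distance 1.6
best = best_matching_image(images, core, [0, 1])
print(best)
```

Comparing distance matrices avoids an explicit alignment step, at the cost of being blind to mirror images; an alignment-based RMSD would be a natural alternative.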

Sella Optimization

Optimize transition state structures using Sella with UMA or PySCF:

python -m post_processing.sella_optimization.optimize_structures_sequential \
    --xyz-dir /path/to/xyz_files \
    --output-dir /path/to/optimized \
    --calculator uma \
    --model uma-s-1p1 \
    --device cuda \
    --fmax 0.05 \
    --steps 500 \
    --process-by-size

Arguments:

  • --calculator: Energy calculator (uma or pyscf)
  • --fmax: Force convergence threshold
  • --steps: Maximum optimization steps
  • --process-by-size: Process structures smallest to largest
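
As an illustration of what `--fmax` controls: Sella builds on the ASE optimizer interface, where convergence is typically declared once the largest per-atom force norm falls below `fmax` (in eV/Å when using ASE units). The sketch below shows that check in isolation, assuming this convention applies; `is_converged` is a hypothetical helper and does not call Sella itself.

```python
import numpy as np


def is_converged(forces, fmax=0.05):
    """True if the largest per-atom force norm is below fmax."""
    forces = np.asarray(forces, dtype=float)
    return bool(np.linalg.norm(forces, axis=1).max() < fmax)


# Illustrative forces for a two-atom system.
forces = [[0.01, 0.00, 0.02], [0.00, 0.03, 0.00]]
print(is_converged(forces, fmax=0.05))
```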

Dataset Utilities

Fragment Dataset Creation

Create fragmented versions of a Transition1x dataset using backbone classification:

python -m dataset_utils.fragment_dataset.fragment_t1x \
    --input /path/to/input_dataset.pkl \
    --output /path/to/fragmented_dataset.pkl \
    --summary /path/to/summary.txt \
    --epsilon-levels 0,1,2 \
    --random-continuous-fractions 0 \
    --threshold-continuous-fractions 0.6

This identifies backbone atoms using Murcko scaffolds and reactive center detection, then creates fragmented versions at different epsilon levels. The epsilon level is the number of shells of neighboring atoms to include: epsilon=0 includes only the reactive core atoms, and epsilon=k for k>0 also includes atoms up to k bonds away from a reactive core atom.
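
The epsilon levels can be pictured as a breadth-first search over the bond graph that starts from the reactive core. The sketch below assumes shells are cumulative (epsilon=k keeps every atom within k bonds of the core) and leaves out the Murcko scaffold and backbone classification steps. `epsilon_fragment` and its bond-list input are hypothetical, not the repository's API.

```python
from collections import deque


def epsilon_fragment(bonds, core_atoms, epsilon):
    """Atoms within `epsilon` bonds of any reactive core atom (core included)."""
    neighbors = {}
    for a, b in bonds:
        neighbors.setdefault(a, set()).add(b)
        neighbors.setdefault(b, set()).add(a)
    # Multi-source breadth-first search starting from the core atoms.
    distance = {atom: 0 for atom in core_atoms}
    queue = deque(core_atoms)
    while queue:
        atom = queue.popleft()
        if distance[atom] == epsilon:
            continue  # do not expand beyond the requested shell
        for nxt in neighbors.get(atom, ()):
            if nxt not in distance:
                distance[nxt] = distance[atom] + 1
                queue.append(nxt)
    return sorted(distance)


# Linear chain 0-1-2-3-4-5 with reactive core {2, 3}.
chain = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
for eps in (0, 1, 2):
    print(eps, epsilon_fragment(chain, [2, 3], eps))
```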

Create Dataset from TS-tools Reactions

Note

The dataset_utils/larget1x directory contains scripts for generating fully mapped SMILES strings and validating reactions using TS-tools. To obtain 3D geometries of reactants, products, and transition states, please use the TS-tools package.

Create a dataset from reactions validated with TS-tools:

python -m dataset_utils.larget1x.create_fragmented_pkl_dataset \
    --csv-path /path/to/extended_dataset_smiles.csv \
    --dataset-root /path/to/reaction_folders \
    --filtered-dir /path/to/filtered_reactions \
    --output-path /path/to/output_dataset.pkl

Generate Baseline TS Guesses

Generate linear interpolation or IDPP transition state guesses:

python -m dataset_utils.simple_benchmarks.generate_trajectory_ts \
    --dataset-path /path/to/dataset.pkl \
    --output-folder /path/to/output \
    --mode idpp \
    --num-structures 10 \
    --model uma-s-1p1 \
    --gpu 0 \
    --fragmentation-mode original

Arguments:

  • --mode: Trajectory generation method (interpolation or idpp)
  • --num-structures: Number of intermediate structures in trajectory
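
For reference, the `interpolation` mode presumably corresponds to a straight-line path in Cartesian coordinates between reactant and product. The sketch below generates such a path with NumPy, assuming both endpoints share atom ordering, are already aligned, and that `--num-structures` counts interior points only. `linear_path` is a hypothetical name; IDPP refinement is not shown.

```python
import numpy as np


def linear_path(reactant, product, num_structures):
    """Evenly spaced intermediate geometries between reactant and product."""
    reactant = np.asarray(reactant, dtype=float)
    product = np.asarray(product, dtype=float)
    # Interior points only: fractions strictly between 0 and 1.
    fractions = np.linspace(0.0, 1.0, num_structures + 2)[1:-1]
    return [(1.0 - t) * reactant + t * product for t in fractions]


# Toy two-atom example in which one bond stretches from 1.0 to 2.0.
reactant = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
product = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
path = linear_path(reactant, product, num_structures=3)
print([float(image[1, 0]) for image in path])
```

The midpoint of such a path is a common baseline TS guess, though straight-line interpolation can push atoms unphysically close together, which is the issue IDPP is designed to reduce.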

Gallery

Example transition state predictions comparing FragmentFlow output to reference structures from LargeT1x.

Image columns: FragmentFlow + Sella Optimization | Reference

Citation

If you use FragmentFlow in your research, please cite:

@article{shprints2026fragmentflow,
  title={FragmentFlow: Scalable Transition State Generation for Large Molecules},
  author={Shprints, Ron and Holderrieth, Peter and Nam, Juno and G{\'o}mez-Bombarelli, Rafael and Jaakkola, Tommi},
  journal={arXiv preprint arXiv:2602.02310},
  year={2026}
}
