Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
641d7ab
test commit
Jun 3, 2025
3b05347
scratch folder added
Jun 4, 2025
566ae05
scratch files for JAMUN run on alanine dipeptide uncapped
Jun 5, 2025
6d658e9
prep for hidden state introduction to denoiser
Jun 5, 2025
9184121
introduced hidden state in data, made a test denoiser module with con…
Jun 9, 2025
76af63d
conditional generation models, testing, and hydra debugging script
Jun 13, 2025
8e1bd9c
checks on conditioned denoising model
Jun 16, 2025
c82d428
conditioning module set up, dataloaders with time lags, and adding mu…
Jun 24, 2025
a51de44
multimeasurement loss added and tested with sweeps
Jun 25, 2025
1bc3944
added label override to dataset parser
Jul 11, 2025
df42585
Bug fix, label override, init graph recentering
Jul 14, 2025
3a3bf87
Added a data generation protocol for generating equilibrium structure…
Jul 19, 2025
4982bf6
New configs for noise check experiments, debugged multimeasurement (s…
Jul 24, 2025
c2745a7
merge from main
Jul 24, 2025
c1d7e26
added kwargs to score function processed
Jul 24, 2025
67c6c13
catch up with main
Jul 26, 2025
02219ab
Updated configs/sbatch script for single run
Jul 27, 2025
c6b04d1
feat: Add spatiotemporal conditioning with input attributes for enhan…
Jul 28, 2025
73ff35d
added spatiotemporal transformer with configs
Jul 30, 2025
265e753
wrapper for pretrained_denoiser
Aug 3, 2025
5d3e3b3
Fixes:
Aug 8, 2025
4824862
New configs, sampling sweep, modified sampler
Aug 12, 2025
c14a900
conditional-gen
Aug 18, 2025
20dfc1e
commits for cluster migration
Aug 19, 2025
e0b210d
pre commit before documentation commit
Aug 22, 2025
4c350bb
documentation commit
Aug 22, 2025
a7ccaa2
documentation commit
Aug 22, 2025
e5d2f24
documentation commit
Aug 22, 2025
148b76a
documentation commit
Aug 22, 2025
e166db1
update deps
kleinhenz Feb 5, 2026
cc4b427
format + lint
kleinhenz Feb 5, 2026
6d2a992
lint unsafe fixes
kleinhenz Feb 5, 2026
133c6b2
update .gitignore
kleinhenz Feb 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ torch_compile_debug
*.profile*
**/*.log
**/*.err
scripts/study
wandb/*
26 changes: 14 additions & 12 deletions configs/experiment/sample_capped_2AA.yaml
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
# @package _global_

defaults:
- override /callbacks:
- sampler/save_trajectory.yaml
- _self_
# defaults:
# - override /callbacks:
# - sampler/save_trajectory.yaml
# - _self_


init_datasets:
_target_: jamun.data.parse_datasets_from_directory
root: "${paths.data_path}/capped_diamines/timewarp_splits/test"
root: "${paths.data_path}/capped_diamines/timewarp_splits/train"
traj_pattern: "^(.*).xtc"
pdb_pattern: "^(.*).pdb"
subsample: 1
num_frames: 60000
filter_codes: ['ALA_ALA']
num_frames: 320000


num_sampling_steps_per_batch: 20000
num_batches: 5
num_init_samples_per_dataset: 1
num_sampling_steps_per_batch: 1000
num_batches: 10
num_init_samples_per_dataset: 50
repeat_init_samples: 1
continue_chain: true

# New 2AA
wandb_train_run_path: prescient-design/jamun/yfz3vpzg
wandb_train_run_path: sule-shashank/jamun/370wpt17

checkpoint_type: best_so_far
sigma: 0.04
M: 1.0
delta: 0.04
delta: ${sigma}
friction: 1.0
inverse_temperature: 1.0
score_fn_clip: 100.0
score_fn_clip: null

sampler:
_target_: jamun.sampling.Sampler
Expand All @@ -39,3 +40,4 @@ sampler:
logger:
wandb:
group: sample_capped_2AA
tags: ['ALA_ALA', 'sigma_0.04', 'standard JAMUN']
51 changes: 51 additions & 0 deletions configs/experiment/sample_capped_single_shape_conditioning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# @package _global_

defaults:
- override /model: denoiser_conditional_pretrained.yaml
# defaults:
# - override /callbacks:
# - sampler/save_trajectory.yaml
# - _self_

# callbacks:
# viz:
# sigma_list: ["${model.sigma_distribution.sigma}"]

init_datasets:
_target_: jamun.data.parse_datasets_from_directory
root: "${paths.data_path}/capped_diamines/timewarp_splits/train"
traj_pattern: "^(.*).xtc"
pdb_pattern: "^(.*).pdb"
filter_codes: ['ALA_ALA']
as_iterable: false
subsample: 100
total_lag_time: 10
lag_subsample_rate: 100

num_sampling_steps_per_batch: 1000
num_batches: 10
num_init_samples_per_dataset: 10
repeat_init_samples: 1
continue_chain: false

# Add your wandb run path here
wandb_train_run_path: sule-shashank/jamun/jqp09yv1
# checkpoint_dir: /data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-01_16-15-23/checkpoints
# checkpoint_dir: /data2/sules/jamun-conditional-runs/old_outputs/outputs/train/dev/runs/2a381c2d310e3d0789831338/checkpoints
checkpoint_type: last

sigma: 0.01
M: 1.0
delta: ${sigma}
friction: 1.0
inverse_temperature: 1.0
score_fn_clip: 100.0

sampler:
_target_: jamun.sampling.SamplerMemory
devices: 1

logger:
wandb:
group: sample_ALA_ALA_conditional
notes: stellar-sweep-25
50 changes: 29 additions & 21 deletions configs/experiment/sample_custom.yaml
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
# @package _global_

num_sampling_steps_per_batch: 1000
num_batches: 1
num_init_samples_per_dataset: 1
repeat_init_samples: 1
continue_chain: true
name: cfg # This ensures the config is saved with the 'cfg' key

# wandb_train_run_path: prescient-design/jamun/zzt8s3rc
# defaults:
# - override /callbacks:
# - sampler/save_trajectory.yaml
# - _self_

# Old 4AA
# wandb_train_run_path: prescient-design/jamun/ibtxmwcr
# callbacks:
# viz:
# sigma_list: ["${model.sigma_distribution.sigma}"]

# New 4AA
wandb_train_run_path: prescient-design/jamun/6297yugb
init_datasets:
_target_: jamun.data.parse_datasets_from_directory
root: "${paths.data_path}/timewarp/2AA-1-large/train/"
traj_pattern: "^(.*)-traj-arrays.npz"
pdb_file: AA-traj-state0.pdb
filter_codes:
- AA
subsample: 1
max_datasets: 1

# Finetuned new 4AA
# wandb_train_run_path: prescient-design/jamun/x6rwt91k
num_sampling_steps_per_batch: 1000
num_batches: 100
num_init_samples_per_dataset: 1
repeat_init_samples: 1
continue_chain: true

# init_pdb: /data/bucket/kleinhej/fast-folding/processed/chignolin/filtered.pdb
init_pdbs: ???
# Add your wandb run path here
wandb_train_run_path: sule-shashank/jamun/y4rm5488
# checkpoint_dir: outputs/train/dev/runs/2025-06-11_20-16-04/wandb/latest-run/checkpoints

checkpoint_type: best_so_far
sigma: 0.04
Expand All @@ -28,13 +39,9 @@ friction: 1.0
inverse_temperature: 1.0
score_fn_clip: 100.0

init_datasets:
_target_: jamun.data.create_dataset_from_pdbs
pdbfiles: ${init_pdbs}

finetune_on_init:
num_steps: ???
batch_size: 16
# Model configuration
model:
_target_: scratch.denoiser_test.Denoiser.load_from_checkpoint

sampler:
_target_: jamun.sampling.Sampler
Expand All @@ -43,3 +50,4 @@ sampler:
logger:
wandb:
group: sample_custom
notes: Custom sampling with denoiser, conditioner
106 changes: 106 additions & 0 deletions configs/experiment/sample_enhanced_conditioning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# @package _global_

defaults:
- override /model: denoiser_conditional_pretrained.yaml
- override /callbacks: null

# init_datasets:
# _target_: jamun.data.parse_datasets_from_directory
# root: "${paths.data_path}/capped_diamines/timewarp_splits/train"
# traj_pattern: "^(.*).xtc"
# pdb_pattern: "^(.*).pdb"
# filter_codes: ['ALA_ALA']
# as_iterable: false
# subsample: 10
# total_lag_time: 5
# lag_subsample_rate: 1
# max_datasets: 1
# label_override: "ALA_ALA"

init_datasets:
_target_: jamun.data.parse_datasets_from_directory
root: "/data2/sules/ALA_ALA_enhanced_full_swarm/val"
traj_pattern: "^(.*).xtc"
pdb_pattern: "^(.*).pdb"
as_iterable: false
subsample: 1
total_lag_time: 5
lag_subsample_rate: 1
max_datasets: 10
label_override: "ALA_ALA"

# model:
# conditioner:
# N_structures: ${init_datasets.total_lag_time}

num_sampling_steps_per_batch: 1000
num_batches: 10
num_init_samples_per_dataset: 1
repeat_init_samples: 1
continue_chain: true

# Add your checkpoint path here - update with actual trained model path
wandb_train_run_path: "sule-shashank/jamun/qiutegoj"
# checkpoint_type: last
# checkpoint_dir: "/data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-31_00-43-14/checkpoints/"
checkpoint_type: last

sigma: 0.06
M: 1.0
delta: 0.066
friction: 1.2
inverse_temperature: 1.0
score_fn_clip: null

sampler:
_target_: jamun.sampling.SamplerMemory
devices: 1

# Evaluation dataset - standard ALA_ALA from capped diamines for computing metrics
eval_dataset:
_target_: jamun.data.parse_datasets_from_directory
root: "${paths.data_path}/capped_diamines/timewarp_splits/train"
traj_pattern: "^(.*).xtc"
pdb_pattern: "^(.*).pdb"
filter_codes: ['ALA_ALA']
as_iterable: false
subsample: 10
total_lag_time: 5
lag_subsample_rate: 1
max_datasets: 10
label_override: "ALA_ALA"

# Override ALL callbacks to use eval_dataset for metrics computation
callbacks:
measure_sampling_time:
_target_: jamun.callbacks.sampler.MeasureSamplingTimeCallback
chemical_validity:
_target_: jamun.callbacks.sampler.ChemicalValidityMetricsCallback
datasets: ${eval_dataset}
bond_length_tolerance: 0.2
volume_exclusion_tolerance: 0.1
num_molecules_per_trajectory: 100
ramachandran_plot:
_target_: jamun.callbacks.sampler.RamachandranPlotMetricsCallback
datasets: ${eval_dataset}
trajectory_visualizer:
_target_: jamun.callbacks.sampler.TrajectoryVisualizerCallback
datasets: ${eval_dataset}
num_frames_to_animate: 100
sample_visualizer:
_target_: jamun.callbacks.sampler.SampleVisualizerCallback
datasets: ${eval_dataset}
num_samples_to_plot: 16
subsample: 100
score_distribution:
_target_: jamun.callbacks.sampler.ScoreDistributionCallback
datasets: ${eval_dataset}
save_trajectory:
_target_: jamun.callbacks.sampler.SaveTrajectoryCallback
datasets: ${eval_dataset}

logger:
wandb:
group: sample_enhanced_sampling_data
notes: "Sampling from enhanced sampling data using memory sampler"
tags: ["sample", "enhanced_sampling", "memory_sampler", "ALA_ALA", "spatiotemporal conditioner"]
Loading