diff --git a/.gitignore b/.gitignore index df6fd73..2b478af 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,5 @@ torch_compile_debug *.profile* **/*.log **/*.err +scripts/study +wandb/* diff --git a/configs/experiment/sample_capped_2AA.yaml b/configs/experiment/sample_capped_2AA.yaml index f9b1fbf..af989c9 100644 --- a/configs/experiment/sample_capped_2AA.yaml +++ b/configs/experiment/sample_capped_2AA.yaml @@ -1,36 +1,37 @@ # @package _global_ -defaults: - - override /callbacks: - - sampler/save_trajectory.yaml - - _self_ +# defaults: +# - override /callbacks: +# - sampler/save_trajectory.yaml +# - _self_ init_datasets: _target_: jamun.data.parse_datasets_from_directory - root: "${paths.data_path}/capped_diamines/timewarp_splits/test" + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" traj_pattern: "^(.*).xtc" pdb_pattern: "^(.*).pdb" subsample: 1 - num_frames: 60000 + filter_codes: ['ALA_ALA'] + num_frames: 320000 -num_sampling_steps_per_batch: 20000 -num_batches: 5 -num_init_samples_per_dataset: 1 +num_sampling_steps_per_batch: 1000 +num_batches: 10 +num_init_samples_per_dataset: 50 repeat_init_samples: 1 continue_chain: true # New 2AA -wandb_train_run_path: prescient-design/jamun/yfz3vpzg +wandb_train_run_path: sule-shashank/jamun/370wpt17 checkpoint_type: best_so_far sigma: 0.04 M: 1.0 -delta: 0.04 +delta: ${sigma} friction: 1.0 inverse_temperature: 1.0 -score_fn_clip: 100.0 +score_fn_clip: null sampler: _target_: jamun.sampling.Sampler @@ -39,3 +40,4 @@ sampler: logger: wandb: group: sample_capped_2AA + tags: ['ALA_ALA', 'sigma_0.04', 'standard JAMUN'] \ No newline at end of file diff --git a/configs/experiment/sample_capped_single_shape_conditioning.yaml b/configs/experiment/sample_capped_single_shape_conditioning.yaml new file mode 100644 index 0000000..911e2ee --- /dev/null +++ b/configs/experiment/sample_capped_single_shape_conditioning.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +defaults: + - override /model: denoiser_conditional_pretrained.yaml +# defaults: +# - override /callbacks: +# - sampler/save_trajectory.yaml +# - _self_ + +# callbacks: +# viz: +# sigma_list: ["${model.sigma_distribution.sigma}"] + +init_datasets: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 100 + total_lag_time: 10 + lag_subsample_rate: 100 + +num_sampling_steps_per_batch: 1000 +num_batches: 10 +num_init_samples_per_dataset: 10 +repeat_init_samples: 1 +continue_chain: false + +# Add your wandb run path here +wandb_train_run_path: sule-shashank/jamun/jqp09yv1 +# checkpoint_dir: /data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-01_16-15-23/checkpoints +# checkpoint_dir: /data2/sules/jamun-conditional-runs/old_outputs/outputs/train/dev/runs/2a381c2d310e3d0789831338/checkpoints +checkpoint_type: last + +sigma: 0.01 +M: 1.0 +delta: ${sigma} +friction: 1.0 +inverse_temperature: 1.0 +score_fn_clip: 100.0 + +sampler: + _target_: jamun.sampling.SamplerMemory + devices: 1 + +logger: + wandb: + group: sample_ALA_ALA_conditional + notes: stellar-sweep-25 diff --git a/configs/experiment/sample_custom.yaml b/configs/experiment/sample_custom.yaml index 9b657e3..3052289 100644 --- a/configs/experiment/sample_custom.yaml +++ b/configs/experiment/sample_custom.yaml @@ -1,24 +1,35 @@ # @package _global_ -num_sampling_steps_per_batch: 1000 -num_batches: 1 -num_init_samples_per_dataset: 1 -repeat_init_samples: 1 -continue_chain: true +name: cfg # This ensures the config is saved with the 'cfg' key -# wandb_train_run_path: prescient-design/jamun/zzt8s3rc +# defaults: +# - override /callbacks: +# - sampler/save_trajectory.yaml +# - _self_ -# Old 4AA -# wandb_train_run_path: prescient-design/jamun/ibtxmwcr +# callbacks: +# viz: +# sigma_list: ["${model.sigma_distribution.sigma}"] -# New 4AA -wandb_train_run_path: prescient-design/jamun/6297yugb +init_datasets: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/timewarp/2AA-1-large/train/" + traj_pattern: "^(.*)-traj-arrays.npz" + pdb_file: AA-traj-state0.pdb + filter_codes: + - AA + subsample: 1 + max_datasets: 1 -# Finetuned new 4AA -# wandb_train_run_path: prescient-design/jamun/x6rwt91k +num_sampling_steps_per_batch: 1000 +num_batches: 100 +num_init_samples_per_dataset: 1 +repeat_init_samples: 1 +continue_chain: true -# init_pdb: /data/bucket/kleinhej/fast-folding/processed/chignolin/filtered.pdb -init_pdbs: ??? +# Add your wandb run path here +wandb_train_run_path: sule-shashank/jamun/y4rm5488 +# checkpoint_dir: outputs/train/dev/runs/2025-06-11_20-16-04/wandb/latest-run/checkpoints checkpoint_type: best_so_far sigma: 0.04 @@ -28,13 +39,9 @@ friction: 1.0 inverse_temperature: 1.0 score_fn_clip: 100.0 -init_datasets: - _target_: jamun.data.create_dataset_from_pdbs - pdbfiles: ${init_pdbs} - -finetune_on_init: - num_steps: ??? - batch_size: 16 +# Model configuration +model: + _target_: scratch.denoiser_test.Denoiser.load_from_checkpoint sampler: _target_: jamun.sampling.Sampler @@ -43,3 +50,4 @@ sampler: logger: wandb: group: sample_custom + notes: Custom sampling with denoiser, conditioner \ No newline at end of file diff --git a/configs/experiment/sample_enhanced_conditioning.yaml b/configs/experiment/sample_enhanced_conditioning.yaml new file mode 100644 index 0000000..95aa3df --- /dev/null +++ b/configs/experiment/sample_enhanced_conditioning.yaml @@ -0,0 +1,106 @@ +# @package _global_ + +defaults: + - override /model: denoiser_conditional_pretrained.yaml + - override /callbacks: null + +# init_datasets: +# _target_: jamun.data.parse_datasets_from_directory +# root: "${paths.data_path}/capped_diamines/timewarp_splits/train" +# traj_pattern: "^(.*).xtc" +# pdb_pattern: "^(.*).pdb" +# filter_codes: ['ALA_ALA'] +# as_iterable: false +# subsample: 10 +# total_lag_time: 5 +# lag_subsample_rate: 1 +# max_datasets: 1 +# label_override: "ALA_ALA" + +init_datasets: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_swarm/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + as_iterable: false + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 10 + label_override: "ALA_ALA" + +# model: +# conditioner: +# N_structures: ${init_datasets.total_lag_time} + +num_sampling_steps_per_batch: 1000 +num_batches: 10 +num_init_samples_per_dataset: 1 +repeat_init_samples: 1 +continue_chain: true + +# Add your checkpoint path here - update with actual trained model path +wandb_train_run_path: "sule-shashank/jamun/qiutegoj" +# checkpoint_type: last +# checkpoint_dir: "/data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-31_00-43-14/checkpoints/" +checkpoint_type: last + +sigma: 0.06 +M: 1.0 +delta: 0.066 +friction: 1.2 +inverse_temperature: 1.0 +score_fn_clip: null + +sampler: + _target_: jamun.sampling.SamplerMemory + devices: 1 + +# Evaluation dataset - standard ALA_ALA from capped diamines for computing metrics +eval_dataset: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 10 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 10 + label_override: "ALA_ALA" + +# Override ALL callbacks to use eval_dataset for metrics computation +callbacks: + measure_sampling_time: + _target_: jamun.callbacks.sampler.MeasureSamplingTimeCallback + chemical_validity: + _target_: jamun.callbacks.sampler.ChemicalValidityMetricsCallback + datasets: ${eval_dataset} + bond_length_tolerance: 0.2 + volume_exclusion_tolerance: 0.1 + num_molecules_per_trajectory: 100 + ramachandran_plot: + _target_: jamun.callbacks.sampler.RamachandranPlotMetricsCallback + datasets: ${eval_dataset} + trajectory_visualizer: + _target_: jamun.callbacks.sampler.TrajectoryVisualizerCallback + datasets: ${eval_dataset} + num_frames_to_animate: 100 + sample_visualizer: + _target_: jamun.callbacks.sampler.SampleVisualizerCallback + datasets: ${eval_dataset} + num_samples_to_plot: 16 + subsample: 100 + score_distribution: + _target_: jamun.callbacks.sampler.ScoreDistributionCallback + datasets: ${eval_dataset} + save_trajectory: + _target_: jamun.callbacks.sampler.SaveTrajectoryCallback + datasets: ${eval_dataset} + +logger: + wandb: + group: sample_enhanced_sampling_data + notes: "Sampling from enhanced sampling data using memory sampler" + tags: ["sample", "enhanced_sampling", "memory_sampler", "ALA_ALA", "spatiotemporal conditioner"] \ No newline at end of file diff --git a/configs/experiment/sample_enhanced_conditioning_sweep.yaml b/configs/experiment/sample_enhanced_conditioning_sweep.yaml new file mode 100644 index 0000000..2037d9c --- /dev/null +++ b/configs/experiment/sample_enhanced_conditioning_sweep.yaml @@ -0,0 +1,110 @@ +# @package _global_ + +defaults: + - override /model: denoiser_conditional_pretrained.yaml + - override /callbacks: null + +# init_datasets: +# _target_: jamun.data.parse_datasets_from_directory +# root: "${paths.data_path}/capped_diamines/timewarp_splits/train" +# traj_pattern: "^(.*).xtc" +# pdb_pattern: "^(.*).pdb" +# filter_codes: ['ALA_ALA'] +# as_iterable: false +# subsample: 10 +# total_lag_time: 5 +# lag_subsample_rate: 1 +# max_datasets: 1 +# label_override: "ALA_ALA" + +init_datasets: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_long/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + as_iterable: false + subsample: 1 + num_frames: 6 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 20 + label_override: "ALA_ALA" + +# model: +# conditioner: +# N_structures: ${init_datasets.total_lag_time} + +num_sampling_steps_per_batch: 1000 +num_batches: 1 +num_init_samples_per_dataset: 1 +repeat_init_samples: 1 +continue_chain: true + +batch_sampler: + mcmc: + history_update_frequency: 10 +# Add your checkpoint path here - update with actual trained model path +wandb_train_run_path: sule-shashank/jamun/k20xo7qb +# checkpoint_type: last +# checkpoint_dir: "/data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-31_00-43-14/checkpoints/" +checkpoint_type: last + +sigma: 0.06 +M: 1.0 +delta: 0.04 +friction: 1.0 +inverse_temperature: 1.0 +score_fn_clip: null + +sampler: + _target_: jamun.sampling.SamplerMemory + devices: 1 + +# Evaluation dataset - standard ALA_ALA from capped diamines for computing metrics +eval_dataset: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 10 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 1 + label_override: "ALA_ALA" + +# Override ALL callbacks to use eval_dataset for metrics computation +callbacks: + measure_sampling_time: + _target_: jamun.callbacks.sampler.MeasureSamplingTimeCallback + chemical_validity: + _target_: jamun.callbacks.sampler.ChemicalValidityMetricsCallback + datasets: ${eval_dataset} + bond_length_tolerance: 0.2 + volume_exclusion_tolerance: 0.1 + num_molecules_per_trajectory: 100 + ramachandran_plot: + _target_: jamun.callbacks.sampler.RamachandranPlotMetricsCallback + datasets: ${eval_dataset} + trajectory_visualizer: + _target_: jamun.callbacks.sampler.TrajectoryVisualizerCallback + datasets: ${eval_dataset} + num_frames_to_animate: 100 + sample_visualizer: + _target_: jamun.callbacks.sampler.SampleVisualizerCallback + datasets: ${eval_dataset} + num_samples_to_plot: 16 + subsample: 100 + score_distribution: + _target_: jamun.callbacks.sampler.ScoreDistributionCallback + datasets: ${eval_dataset} + save_trajectory: + _target_: jamun.callbacks.sampler.SaveTrajectoryCallback + datasets: ${eval_dataset} + +logger: + wandb: + group: sampling_hyperparameter_sweep + notes: "Tuning hyperparameter sweep for memory sampler" + tags: ["sample", "enhanced_sampling", "memory_sampler", "ALA_ALA", "spatiotemporal conditioner", "sweep"] \ No newline at end of file diff --git a/configs/experiment/sample_enhanced_standard.yaml b/configs/experiment/sample_enhanced_standard.yaml new file mode 100644 index 0000000..51d83b9 --- /dev/null +++ b/configs/experiment/sample_enhanced_standard.yaml @@ -0,0 +1,106 @@ +# @package _global_ + +defaults: + - override /model: denoiser_pretrained.yaml + - override /callbacks: null + +# init_datasets: +# _target_: jamun.data.parse_datasets_from_directory +# root: "${paths.data_path}/capped_diamines/timewarp_splits/train" +# traj_pattern: "^(.*).xtc" +# pdb_pattern: "^(.*).pdb" +# filter_codes: ['ALA_ALA'] +# as_iterable: false +# subsample: 10 +# total_lag_time: 5 +# lag_subsample_rate: 1 +# max_datasets: 1 +# label_override: "ALA_ALA" + +init_datasets: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + as_iterable: false + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 10 + label_override: "ALA_ALA" + +# model: +# conditioner: +# N_structures: ${init_datasets.total_lag_time} + +num_sampling_steps_per_batch: 1000 +num_batches: 10 +num_init_samples_per_dataset: 1 +repeat_init_samples: 1 +continue_chain: true + +# Add your checkpoint path here - update with actual trained model path +wandb_train_run_path: "sule-shashank/jamun/fh9o4mme" +# checkpoint_type: last +# checkpoint_dir: "/data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-31_00-43-14/checkpoints/" +checkpoint_type: last + +sigma: 0.06 +M: 1.0 +delta: 0.06 +friction: 1.0 +inverse_temperature: 1.0 +score_fn_clip: null + +sampler: + _target_: jamun.sampling.Sampler + devices: 1 + +# Evaluation dataset - standard ALA_ALA from capped diamines for computing metrics +eval_dataset: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 10 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 1 + label_override: "ALA_ALA" + +# Override ALL callbacks to use eval_dataset for metrics computation +callbacks: + measure_sampling_time: + _target_: jamun.callbacks.sampler.MeasureSamplingTimeCallback + chemical_validity: + _target_: jamun.callbacks.sampler.ChemicalValidityMetricsCallback + datasets: ${eval_dataset} + bond_length_tolerance: 0.2 + volume_exclusion_tolerance: 0.1 + num_molecules_per_trajectory: 100 + ramachandran_plot: + _target_: jamun.callbacks.sampler.RamachandranPlotMetricsCallback + datasets: ${eval_dataset} + trajectory_visualizer: + _target_: jamun.callbacks.sampler.TrajectoryVisualizerCallback + datasets: ${eval_dataset} + num_frames_to_animate: 100 + sample_visualizer: + _target_: jamun.callbacks.sampler.SampleVisualizerCallback + datasets: ${eval_dataset} + num_samples_to_plot: 16 + subsample: 100 + score_distribution: + _target_: jamun.callbacks.sampler.ScoreDistributionCallback + datasets: ${eval_dataset} + save_trajectory: + _target_: jamun.callbacks.sampler.SaveTrajectoryCallback + datasets: ${eval_dataset} + +logger: + wandb: + group: sample_enhanced_sampling_data + notes: "Benchmark sampling from enhanced sampling data via standard jamun" + tags: ["sample", "enhanced_sampling", "standard_jamun", "ALA_ALA"] \ No newline at end of file diff --git a/configs/experiment/sample_fake_enhanced_sampling_single_shape.yaml b/configs/experiment/sample_fake_enhanced_sampling_single_shape.yaml new file mode 100644 index 0000000..9f3c670 --- /dev/null +++ b/configs/experiment/sample_fake_enhanced_sampling_single_shape.yaml @@ -0,0 +1,86 @@ +# @package _global_ + +defaults: + # - override /model: denoiser_conditional_pretrained.yaml + - override /callbacks: null + +init_datasets: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/fake_enhanced_data/ALA_ALA_organized/train" + traj_pattern: "^(.*).xtc" + pdb_file: "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + as_iterable: false + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 10 + label_override: "ALA_ALA" + +num_sampling_steps_per_batch: 100000 +num_batches: 1 +num_init_samples_per_dataset: 1 +repeat_init_samples: 1 +continue_chain: false + +# Add your checkpoint path here - update with actual trained model path +wandb_train_run_path: sule-shashank/jamun/0mu06yg4 +# checkpoint_dir: /data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-02_00-37-08/checkpoints +checkpoint_type: last + +sigma: 0.04 +M: 1.0 +delta: ${sigma} +friction: 1.0 +inverse_temperature: 1.0 +score_fn_clip: 100.0 + +sampler: + _target_: jamun.sampling.SamplerMemory + devices: 1 + +# Evaluation dataset - standard ALA_ALA from capped diamines for computing metrics +eval_dataset: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 100 + max_datasets: 1 + label_override: "ALA_ALA" + +# Override ALL callbacks to use eval_dataset for metrics computation +callbacks: + measure_sampling_time: + _target_: jamun.callbacks.sampler.MeasureSamplingTimeCallback + chemical_validity: + _target_: jamun.callbacks.sampler.ChemicalValidityMetricsCallback + datasets: ${eval_dataset} + bond_length_tolerance: 0.2 + volume_exclusion_tolerance: 0.1 + num_molecules_per_trajectory: 100 + ramachandran_plot: + _target_: jamun.callbacks.sampler.RamachandranPlotMetricsCallback + datasets: ${eval_dataset} + trajectory_visualizer: + _target_: jamun.callbacks.sampler.TrajectoryVisualizerCallback + datasets: ${eval_dataset} + num_frames_to_animate: 100 + sample_visualizer: + _target_: jamun.callbacks.sampler.SampleVisualizerCallback + datasets: ${eval_dataset} + num_samples_to_plot: 16 + subsample: 100 + score_distribution: + _target_: jamun.callbacks.sampler.ScoreDistributionCallback + datasets: ${eval_dataset} + save_trajectory: + _target_: jamun.callbacks.sampler.SaveTrajectoryCallback + datasets: ${eval_dataset} + +logger: + wandb: + group: sample_ALA_ALA_enhanced_sampling + notes: "Sampling from enhanced sampling trained conditional denoiser" + tags: ["sample", "enhanced_sampling", "conditional_denoiser", "ALA_ALA", "sunny-paper-304"] \ No newline at end of file diff --git a/configs/experiment/sample_uncapped_single_shape.yaml b/configs/experiment/sample_uncapped_single_shape.yaml new file mode 100644 index 0000000..1553a52 --- /dev/null +++ b/configs/experiment/sample_uncapped_single_shape.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +# defaults: +# - override /callbacks: +# - sampler/save_trajectory.yaml +# - _self_ + +# callbacks: +# viz: +# sigma_list: ["${model.sigma_distribution.sigma}"] + +init_datasets: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/timewarp/2AA-1-large/train/" + traj_pattern: "^(.*)-traj-arrays.npz" + pdb_file: AA-traj-state0.pdb + filter_codes: + - AA + subsample: 1 + max_datasets: 1 + +num_sampling_steps_per_batch: 100 +num_batches: 10 +num_init_samples_per_dataset: 1 +repeat_init_samples: 1 +continue_chain: true + +# New 2AA +wandb_train_run_path: sule-shashank/jamun/7spefobw + +# Old 2AA +# wandb_train_run_path: prescient-design/jamun/zzt8s3rc + +checkpoint_type: best_so_far +sigma: 0.04 +M: 1.0 +delta: 0.04 +friction: 1.0 +inverse_temperature: 1.0 +score_fn_clip: 100.0 + +sampler: + _target_: jamun.sampling.Sampler + devices: 1 + +logger: + wandb: + group: sample_uncapped_2AA + notes: single_shape_AA diff --git a/configs/experiment/sample_uncapped_single_shape_conditioning.yaml b/configs/experiment/sample_uncapped_single_shape_conditioning.yaml new file mode 100644 index 0000000..346ca5c --- /dev/null +++ b/configs/experiment/sample_uncapped_single_shape_conditioning.yaml @@ -0,0 +1,53 @@ +# @package _global_ + +# defaults: +# - override /callbacks: +# - sampler/save_trajectory.yaml +# - _self_ + +# callbacks: +# viz: +# sigma_list: ["${model.sigma_distribution.sigma}"] + +init_datasets: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/timewarp/2AA-1-large/train/" + traj_pattern: "^(.*)-traj-arrays.npz" + pdb_file: AA-traj-state0.pdb + filter_codes: + - AA + subsample: 1 + max_datasets: 1 + +num_sampling_steps_per_batch: 100 +num_batches: 10 +num_init_samples_per_dataset: 1 +repeat_init_samples: 1 +continue_chain: true + +# Add your wandb run path here +wandb_train_run_path: sule-shashank/jamun/3wctzcjp + +checkpoint_type: best_so_far +sigma: 0.04 +M: 1.0 +delta: 0.04 +friction: 1.0 +inverse_temperature: 1.0 +score_fn_clip: 100.0 + +# Model configuration +model: + _target_: scratch.denoiser_test.Denoiser + conditioner: + _target_: scratch.conditioners.SelfConditioner + hidden_dim: 1 + +sampler: + _target_: jamun.sampling.Sampler + devices: 1 + +logger: + wandb: + group: sample_custom + notes: Custom sampling with denoiser and conditioner diff --git a/configs/experiment/train_capped_2AA.yaml b/configs/experiment/train_capped_2AA.yaml index 4b7c02c..c9a6e14 100644 --- a/configs/experiment/train_capped_2AA.yaml +++ b/configs/experiment/train_capped_2AA.yaml @@ -1,4 +1,9 @@ # @package _global_ +defaults: + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml model: sigma_distribution: @@ -8,9 +13,9 @@ model: optim: lr: 0.002 -callbacks: - viz: - sigma_list: ["${model.sigma_distribution.sigma}"] +# callbacks: +# viz: +# sigma_list: ["${model.sigma_distribution.sigma}"] data: datamodule: @@ -21,30 +26,44 @@ data: root: "${paths.data_path}/capped_diamines/timewarp_splits/train" traj_pattern: "^(.*).xtc" pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_GLU', 'GLU_ALA', 'ALA_ALA'] + as_iterable: false + subsample: 5 + total_lag_time: 5 + lag_subsample_rate: 5 num_frames: 320000 val: _target_: jamun.data.parse_datasets_from_directory - root: "${paths.data_path}/capped_diamines/timewarp_splits/val/" + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" traj_pattern: "^(.*).xtc" pdb_pattern: "^(.*).pdb" - subsample: 100 - max_datasets: 20 - num_frames: 320000 + filter_codes: ['GLU_GLU'] + as_iterable: false + subsample: 5 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 5 + num_frames: 10000 test: _target_: jamun.data.parse_datasets_from_directory - root: "${paths.data_path}/capped_diamines/timewarp_splits/test/" + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" traj_pattern: "^(.*).xtc" pdb_pattern: "^(.*).pdb" - subsample: 100 - max_datasets: 20 - num_frames: 320000 + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 1000 trainer: val_check_interval: 0.1 - max_epochs: 10 + max_epochs: 100 logger: wandb: - group: train_capped_2AA + group: 2AA_capped_diamines_conditioner_comparison + notes: "Standard JAMUN on ALA_GLU, GLU_ALA, ALA_ALA" + tags: ["standard_jamun", "2AA", "capped_diamines"] + diff --git a/configs/experiment/train_capped_2AA_conditional.yaml b/configs/experiment/train_capped_2AA_conditional.yaml new file mode 100644 index 0000000..a2e6728 --- /dev/null +++ b/configs/experiment/train_capped_2AA_conditional.yaml @@ -0,0 +1,70 @@ +# @package _global_ +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + + +callbacks: + viz: + sigma_list: ["${model.sigma_distribution.sigma}"] + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + num_frames: 320000 + total_lag_time: 5 + lag_subsample_rate: 10 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/val/" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 100 + max_datasets: 20 + num_frames: 320000 + total_lag_time: 5 + lag_subsample_rate: 10 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/test/" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 100 + max_datasets: 20 + num_frames: 320000 + total_lag_time: 5 + lag_subsample_rate: 10 + +trainer: + val_check_interval: 0.1 + max_epochs: 20 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 3 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + multimeasurement: false + +logger: + wandb: + group: train_capped_2AA diff --git a/configs/experiment/train_capped_2AA_position_conditioner.yaml b/configs/experiment/train_capped_2AA_position_conditioner.yaml new file mode 100644 index 0000000..59bd863 --- /dev/null +++ b/configs/experiment/train_capped_2AA_position_conditioner.yaml @@ -0,0 +1,75 @@ +# @package _global_ + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_GLU', 'GLU_ALA', 'ALA_ALA'] + as_iterable: false + subsample: 5 + total_lag_time: 5 + lag_subsample_rate: 5 + num_frames: 320000 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['GLU_GLU'] + as_iterable: false + subsample: 5 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 5 + num_frames: 10000 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 1000 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 4 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + pretrained_model_path: null + c_in: null + +trainer: + val_check_interval: 0.5 + max_epochs: 100 + +logger: + wandb: + group: 2AA_capped_diamines_conditional_denoiser + notes: "Running on 2AA capped diamines, conditional denoiser with position conditioner" + tags: ["position_conditioner", "2AA", "capped_diamines", "generalization"] + diff --git a/configs/experiment/train_capped_2AA_self_conditioner.yaml b/configs/experiment/train_capped_2AA_self_conditioner.yaml new file mode 100644 index 0000000..7078f82 --- /dev/null +++ b/configs/experiment/train_capped_2AA_self_conditioner.yaml @@ -0,0 +1,75 @@ +# @package _global_ +# Training SelfConditioner on 2AA capped diamines data + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_GLU', 'GLU_ALA', 'ALA_ALA'] + as_iterable: false + subsample: 5 + total_lag_time: 5 + lag_subsample_rate: 5 + num_frames: 320000 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['GLU_GLU'] + as_iterable: false + subsample: 5 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 5 + num_frames: 10000 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 1000 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 4 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.SelfConditioner + N_structures: ${model.arch.N_structures} + +trainer: + val_check_interval: 0.5 + max_epochs: 100 + devices: 1 + +logger: + wandb: + group: 2AA_capped_diamines_conditioner_comparison + notes: "SelfConditioner on 2AA capped diamines data" + tags: ["self_conditioner", "2AA", "capped_diamines", "generalization"] + diff --git a/configs/experiment/train_capped_2AA_spatiotemporal_conditioner.yaml b/configs/experiment/train_capped_2AA_spatiotemporal_conditioner.yaml new file mode 100644 index 0000000..68319b3 --- /dev/null +++ b/configs/experiment/train_capped_2AA_spatiotemporal_conditioner.yaml @@ -0,0 +1,71 @@ +# @package _global_ +# Training SpatioTemporalConditioner on 2AA capped diamines data + +defaults: + - override /model: denoiser_conditional_spatiotemporal + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_GLU', 'GLU_ALA', 'ALA_ALA'] + as_iterable: false + subsample: 5 + total_lag_time: 5 + lag_subsample_rate: 5 + num_frames: 320000 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['GLU_GLU'] + as_iterable: false + subsample: 5 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 5 + num_frames: 10000 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 1000 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 1 + max_radius: 1.0 + optim: + lr: 0.002 # Slightly reduced learning rate for stability + +trainer: + val_check_interval: 0.5 + max_epochs: 100 + devices: 1 + +logger: + wandb: + group: 2AA_capped_diamines_conditioner_comparison + notes: "SpatioTemporalConditioner on 2AA capped diamines data - processes temporal sequences through spatial and temporal modules" + tags: ["spatiotemporal_conditioner", "2AA", "capped_diamines", "transformer", "generalization"] + diff --git a/configs/experiment/train_enhanced_denoised_conditioner.yaml b/configs/experiment/train_enhanced_denoised_conditioner.yaml new file mode 100644 index 0000000..099480a --- /dev/null +++ b/configs/experiment/train_enhanced_denoised_conditioner.yaml @@ -0,0 +1,60 @@ +# @package _global_ +# Training DenoisedConditioner on enhanced sampling data + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + # max_datasets: 1 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + #max_datasets: ${data.datamodule.datasets.train.max_datasets} + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.DenoisedConditioner + N_structures: ${model.arch.N_structures} + pretrained_model_path: "sule-shashank/jamun/370wpt17" + c_in: null # Will be computed automatically by training script + +trainer: + val_check_interval: 0.5 + max_epochs: 50 + devices: 1 + +logger: + wandb: + group: enhanced_sampling_conditioner_comparison + notes: "DenoisedConditioner pretrained from 370wpt17" + tags: ["denoised_conditioner", "enhanced_sampling", "pretrained"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_mean_conditioner.yaml b/configs/experiment/train_enhanced_mean_conditioner.yaml new file mode 100644 index 0000000..84dd0c8 --- /dev/null +++ b/configs/experiment/train_enhanced_mean_conditioner.yaml @@ -0,0 +1,58 @@ +# @package _global_ +# Training DenoisedConditioner on enhanced sampling data + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + # max_datasets: 1 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + #max_datasets: ${data.datamodule.datasets.train.max_datasets} + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.MeanConditioner + N_structures: ${model.arch.N_structures} + +trainer: + val_check_interval: 0.5 + max_epochs: 50 + devices: 1 + +logger: + wandb: + group: enhanced_sampling_conditioner_comparison + notes: "Mean conditioner" + tags: ["mean_conditioner", "enhanced_sampling"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_multimeasurement.yaml b/configs/experiment/train_enhanced_multimeasurement.yaml new file mode 100644 index 0000000..3fb129e --- /dev/null +++ b/configs/experiment/train_enhanced_multimeasurement.yaml @@ -0,0 +1,60 @@ +# @package _global_ +defaults: + - override /model: denoiser_multimeasurement + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 2 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + # max_datasets: 1 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + #max_datasets: ${data.datamodule.datasets.train.max_datasets} + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 3 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + # Override multimeasurement parameters for this experiment + N_measurements: 4 + N_measurements_hidden: 4 + max_graphs_per_batch: null + +trainer: + val_check_interval: 0.5 + max_epochs: 50 + devices: 1 + +logger: + wandb: + group: ALA_ALA_enhanced_full_grid_multimeasurement + notes: "Training multimeasurement model on ALA_ALA enhanced dataset with trajectory-based split" + tags: ["multimeasurement", "ala_ala", "enhanced_dataset", "trajectory_split"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_position_conditioner.yaml b/configs/experiment/train_enhanced_position_conditioner.yaml new file mode 100644 index 0000000..86fa303 --- /dev/null +++ b/configs/experiment/train_enhanced_position_conditioner.yaml @@ -0,0 +1,58 @@ +# @package _global_ +# Training PositionConditioner on enhanced sampling data + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 250 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + max_datasets: 50 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + +trainer: + val_check_interval: 0.5 + max_epochs: 50 + devices: 1 + +logger: + wandb: + group: enhanced_sampling_conditioner_comparison + notes: "PositionConditioner on enhanced sampling data" + tags: ["position_conditioner", "enhanced_sampling"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_pretrained_spatiotemporal_conditioner.yaml b/configs/experiment/train_enhanced_pretrained_spatiotemporal_conditioner.yaml new file mode 100644 index 0000000..9e31096 --- /dev/null +++ b/configs/experiment/train_enhanced_pretrained_spatiotemporal_conditioner.yaml @@ -0,0 +1,54 @@ +# @package _global_ +defaults: + - override /model: denoiser_conditional_spatiotemporal + - override /model/conditioner: spatiotemporal_pretrained + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 # Reduced batch size due to increased model complexity + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + # max_datasets: 2 # Increased for more training data + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + # max_datasets: 1 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 1 + max_radius: 1.0 + optim: + lr: 0.002 # Slightly reduced learning rate for stability + +trainer: + val_check_interval: 0.5 + max_epochs: 1 # Increased due to model complexity + # devices: 1 + # gradient_clip_val: 1.0 # Add gradient clipping for stability + +logger: + wandb: + group: enhanced_sampling_conditioner_comparison + notes: "SpatioTemporalConditioner on enhanced sampling data - processes temporal sequences through spatial and temporal modules" + tags: ["spatiotemporal_conditioner", "enhanced_sampling", "transformer", "e3conv"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_pretrained_spatiotemporal_conditioner_multimeasurement.yaml b/configs/experiment/train_enhanced_pretrained_spatiotemporal_conditioner_multimeasurement.yaml new file mode 100644 index 0000000..bfcaef3 --- /dev/null +++ b/configs/experiment/train_enhanced_pretrained_spatiotemporal_conditioner_multimeasurement.yaml @@ -0,0 +1,55 @@ +# @package _global_ +defaults: + - override /model: denoiser_multimeasurement + - override /model/arch: e3conv_conditional_spatiotemporal + - override /model/conditioner: spatiotemporal_pretrained + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 4 # Reduced batch size due to increased model complexity + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + # max_datasets: 2 # Increased for more training data + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + # max_datasets: 1 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 1 + max_radius: 1.0 + optim: + lr: 0.002 # Slightly reduced learning rate for stability + +trainer: + val_check_interval: 0.5 + max_epochs: 50 # Increased due to model complexity + # devices: 1 + # gradient_clip_val: 1.0 # Add gradient clipping for stability + +logger: + wandb: + group: enhanced_sampling_conditioner_comparison + notes: "SpatioTemporalConditioner on enhanced sampling data - processes temporal sequences through spatial and temporal modules" + tags: ["spatiotemporal_conditioner", "enhanced_sampling", "transformer", "e3conv"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_self_conditioner.yaml b/configs/experiment/train_enhanced_self_conditioner.yaml new file mode 100644 index 0000000..31f6275 --- /dev/null +++ b/configs/experiment/train_enhanced_self_conditioner.yaml @@ -0,0 +1,56 @@ +# @package _global_ +# Training SelfConditioner on enhanced sampling data + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.SelfConditioner + N_structures: ${model.arch.N_structures} + +trainer: + val_check_interval: 0.5 + max_epochs: 50 + devices: 1 + +logger: + wandb: + group: enhanced_sampling_conditioner_comparison + notes: "SelfConditioner on enhanced sampling data" + tags: ["self_conditioner", "enhanced_sampling"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_spatiotemporal_conditioner.yaml b/configs/experiment/train_enhanced_spatiotemporal_conditioner.yaml new file mode 100644 index 0000000..43c8dce --- /dev/null +++ b/configs/experiment/train_enhanced_spatiotemporal_conditioner.yaml @@ -0,0 +1,53 @@ +# @package _global_ +defaults: + - override /model: denoiser_conditional_spatiotemporal + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 # Reduced batch size due to increased model complexity + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 1 # Increased for more training data + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + max_datasets: 1 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 1 + max_radius: 1.0 + optim: + lr: 0.002 # Slightly reduced learning rate for stability + +trainer: + val_check_interval: 0.5 + max_epochs: 50 # Increased due to model complexity + # devices: 1 + # gradient_clip_val: 1.0 # Add gradient clipping for stability + +logger: + wandb: + group: enhanced_sampling_conditioner_comparison + notes: "SpatioTemporalConditioner on enhanced sampling data - processes temporal sequences through spatial and temporal modules" + tags: ["spatiotemporal_conditioner", "enhanced_sampling", "transformer", "debug run"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_spiked_conditioner.yaml b/configs/experiment/train_enhanced_spiked_conditioner.yaml new file mode 100644 index 0000000..347e05d --- /dev/null +++ b/configs/experiment/train_enhanced_spiked_conditioner.yaml @@ -0,0 +1,59 @@ +# @package _global_ +# Training DenoiserSpiked with ConditionerSpiked on enhanced sampling data + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + # max_datasets: 10 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + #max_datasets: ${data.datamodule.datasets.train.max_datasets} + +model: + _target_: jamun.model.denoiser_spiked.DenoiserSpiked + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.ConditionerSpiked + N_structures: ${model.arch.N_structures} + +trainer: + val_check_interval: 0.5 + max_epochs: 25 + devices: 1 + +logger: + wandb: + group: enhanced_sampling_spiked_conditioner + notes: "DenoiserSpiked with ConditionerSpiked - clean structure conditioning" + tags: ["spiked_conditioner", "enhanced_sampling", "clean_conditioning"] \ No newline at end of file diff --git a/configs/experiment/train_enhanced_standard_jamun.yaml b/configs/experiment/train_enhanced_standard_jamun.yaml new file mode 100644 index 0000000..83aaade --- /dev/null +++ b/configs/experiment/train_enhanced_standard_jamun.yaml @@ -0,0 +1,52 @@ +# @package _global_ +# Training standard JAMUN Denoiser on enhanced sampling data + +defaults: + - override /model: denoiser + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/ALA_ALA_enhanced_full_grid/val" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 4 + max_radius: 1.0 + optim: + lr: 0.002 + +trainer: + val_check_interval: 0.5 + max_epochs: 500 + devices: 1 + +logger: + wandb: + group: enhanced_sampling_conditioning_comparison + notes: "Standard JAMUN Denoiser on enhanced sampling data" + tags: ["standard_jamun", "enhanced_sampling", "no_conditioning"] \ No newline at end of file diff --git a/configs/experiment/train_test_single_shape.yaml b/configs/experiment/train_test_single_shape.yaml index 8d9d1ef..d972fea 100644 --- a/configs/experiment/train_test_single_shape.yaml +++ b/configs/experiment/train_test_single_shape.yaml @@ -5,7 +5,7 @@ model: _target_: jamun.distributions.ConstantSigma sigma: 0.04 arch: - n_layers: 2 + n_layers: 4 max_radius: 1000.0 optim: lr: 0.002 @@ -20,37 +20,45 @@ data: batch_size: 32 datasets: train: - - _target_: jamun.data.MDtrajDataset - root: "${paths.data_path}/timewarp/2AA-1-large/train/" - trajfiles: - - AA-traj-arrays.npz - pdbfile: AA-traj-state0.pdb - subsample: 100 - label: "AA" + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + start_frame: 0 + num_frames: 800000 val: - - _target_: jamun.data.MDtrajDataset - root: "${paths.data_path}/timewarp/2AA-1-large/train/" - trajfiles: - - AA-traj-arrays.npz - pdbfile: AA-traj-state0.pdb - subsample: 100 - label: "AA" + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + start_frame: 800000 + num_frames: 100000 test: - - _target_: jamun.data.MDtrajDataset - root: "${paths.data_path}/timewarp/2AA-1-large/train/" - trajfiles: - - AA-traj-arrays.npz - pdbfile: AA-traj-state0.pdb - subsample: 100 - label: "AA" + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + start_frame: 900000 + num_frames: 200000 trainer: val_check_interval: 0.5 - max_epochs: 1 + max_epochs: 100 logger: wandb: - group: train_test + group: model_comparison_delta_t_T_models_graphs + notes: "Standard JAMUN on 2AA capped diamines data" + tags: ["standard_jamun", "2AA", "capped_diamines", "denoiser", "generalization"] \ No newline at end of file diff --git a/configs/experiment/train_test_single_shape_conditional.yaml b/configs/experiment/train_test_single_shape_conditional.yaml new file mode 100644 index 0000000..a6c9663 --- /dev/null +++ b/configs/experiment/train_test_single_shape_conditional.yaml @@ -0,0 +1,78 @@ +# @package _global_ + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + total_lag_time: 8 + lag_subsample_rate: 10 + start_frame: 0 + num_frames: 800000 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + start_frame: 800000 + num_frames: 100000 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + start_frame: 900000 + num_frames: 200000 +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + pretrained_model_path: null + c_in: null + + +trainer: + val_check_interval: 0.5 + max_epochs: 100 + + +logger: + wandb: + group: model_comparison_delta_t_T_models_graphs + notes: "PositionConditioner on 2AA capped diamines data" + tags: ["position_conditioner", "2AA", "capped_diamines", "conditional", "generalization"] \ No newline at end of file diff --git a/configs/experiment/train_test_single_shape_conditional_one_traj.yaml b/configs/experiment/train_test_single_shape_conditional_one_traj.yaml new file mode 100644 index 0000000..c9a3de8 --- /dev/null +++ b/configs/experiment/train_test_single_shape_conditional_one_traj.yaml @@ -0,0 +1,81 @@ +# @package _global_ + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 1 + num_workers: 4 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 10 + total_lag_time: 5 + lag_subsample_rate: 1 + start_frame: 0 + num_frames: 11 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 10 + total_lag_time: 5 + lag_subsample_rate: 1 + start_frame: 1 + num_frames: 11 + + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 10 + total_lag_time: 5 + lag_subsample_rate: 1 + start_frame: 2 + num_frames: 11 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + multimeasurement: true + N_measurements: 100 + N_measurements_hidden: 1 + +trainer: + val_check_interval: 0.5 + max_epochs: 500 + + +logger: + wandb: + group: ALA_ALA, capped diamines, conditional denoiser + notes: "Running on ALA_ALA capped diamine, conditional denoiser" + tags: ["train", "capped_diamines", "conditional denoiser", "one_traj"] \ No newline at end of file diff --git a/configs/experiment/train_test_single_shape_fake_enhanced_sampling.yaml b/configs/experiment/train_test_single_shape_fake_enhanced_sampling.yaml new file mode 100644 index 0000000..3bae047 --- /dev/null +++ b/configs/experiment/train_test_single_shape_fake_enhanced_sampling.yaml @@ -0,0 +1,71 @@ +# @package _global_ + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/fake_enhanced_data/ALA_ALA_organized/train" + traj_pattern: "^(.*).xtc" + pdb_file: "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + as_iterable: false + subsample: 1 + total_lag_time: 5 + lag_subsample_rate: 1 + max_datasets: 5000 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/fake_enhanced_data/ALA_ALA_organized/val" + traj_pattern: "^(.*).xtc" + pdb_file: "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + max_datasets: 5000 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "/data2/sules/fake_enhanced_data/ALA_ALA_organized/test" + traj_pattern: "^(.*).xtc" + pdb_file: "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + max_datasets: 5000 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.SelfConditioner + N_structures: ${model.arch.N_structures} + pretrained_model_path: null + c_in: null + +trainer: + val_check_interval: 0.5 + max_epochs: 100 + +logger: + wandb: + group: ALA_ALA, enhanced sampling, conditional denoiser + notes: "Training conditional denoiser on ALA_ALA enhanced sampling data" + tags: ["train", "enhanced_sampling", "conditional_denoiser", "ALA_ALA"] \ No newline at end of file diff --git a/configs/experiment/train_test_single_shape_spatiotemporal_conditioner.yaml b/configs/experiment/train_test_single_shape_spatiotemporal_conditioner.yaml new file mode 100644 index 0000000..99eebd5 --- /dev/null +++ b/configs/experiment/train_test_single_shape_spatiotemporal_conditioner.yaml @@ -0,0 +1,73 @@ +# @package _global_ +# Training SpatioTemporalConditioner on 2AA capped diamines data + +defaults: + - override /model: denoiser_conditional_spatiotemporal + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + total_lag_time: 8 + lag_subsample_rate: 10 + start_frame: 0 + num_frames: 800000 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + start_frame: 800000 + num_frames: 100000 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: False + subsample: 80 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + start_frame: 900000 + num_frames: 200000 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + arch: + n_layers: 1 + max_radius: 1.0 + optim: + lr: 0.002 # Slightly reduced learning rate for stability + +trainer: + val_check_interval: 0.5 + max_epochs: 100 + +logger: + wandb: + group: model_comparison_delta_t_T_models_graphs + notes: "SpatioTemporalConditioner on 2AA capped diamines data" + tags: ["spatiotemporal_conditioner", "2AA", "capped_diamines", "transformer", "generalization"] + diff --git a/configs/noise_check/ala_ala_denoiser_experiment_model1.yaml b/configs/noise_check/ala_ala_denoiser_experiment_model1.yaml new file mode 100644 index 0000000..92c9a73 --- /dev/null +++ b/configs/noise_check/ala_ala_denoiser_experiment_model1.yaml @@ -0,0 +1,73 @@ +# @package _global_ +# Model 1: Denoiser with self conditioner, two structures, noise level sigma + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: 2 # Two structures: current + 1 hidden state + lag_subsample_rate: 1 + num_frames: 10000 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + num_frames: 100 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 10 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 # Base sigma level + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.SelfConditioner + N_structures: ${model.arch.N_structures} + +trainer: + max_epochs: 100 + val_check_interval: 0.5 + +logger: + wandb: + group: ALA_ALA_noise_check + notes: "Model 1: Denoiser with SelfConditioner, 2 structures, sigma=0.04" + tags: ["noise_check", "high noise", "identical measurements"] \ No newline at end of file diff --git a/configs/noise_check/ala_ala_denoiser_experiment_model2.yaml b/configs/noise_check/ala_ala_denoiser_experiment_model2.yaml new file mode 100644 index 0000000..5d47ce6 --- /dev/null +++ b/configs/noise_check/ala_ala_denoiser_experiment_model2.yaml @@ -0,0 +1,72 @@ +# @package _global_ +# Model 2: Denoiser with self conditioner, two structures, noise level sigma/sqrt(2) + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: 2 # Two structures: current + 1 hidden state + lag_subsample_rate: 1 + num_frames: 10000 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + num_frames: 100 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 10 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.02828427124 # sigma/sqrt(2) = 0.04/sqrt(2) ≈ 0.0283 + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.SelfConditioner + N_structures: ${model.arch.N_structures} + +trainer: + max_epochs: 100 + +logger: + wandb: + group: ALA_ALA_noise_check + notes: "Model 2: Denoiser with SelfConditioner, 2 structures, sigma=0.04/sqrt(2)" + tags: ["noise_check", "low noise", "identical measurements"] \ No newline at end of file diff --git a/configs/noise_check/ala_ala_denoiser_experiment_model3.yaml b/configs/noise_check/ala_ala_denoiser_experiment_model3.yaml new file mode 100644 index 0000000..1f92fa5 --- /dev/null +++ b/configs/noise_check/ala_ala_denoiser_experiment_model3.yaml @@ -0,0 +1,73 @@ +# @package _global_ +# Model 3: Denoiser with position conditioner, noise level sigma, +# but hidden states are repeated copies of y.pos (noise added by denoiser) + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: 2 # Two structures: current + 1 copy + lag_subsample_rate: 1 + num_frames: 10000 + + val: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + num_frames: 100 + + test: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 10 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 # Base sigma level for denoising + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + +trainer: + max_epochs: 200 + +logger: + wandb: + group: ALA_ALA_noise_check + notes: "Model 3: Denoiser with PositionConditioner, 2 structures, sigma=0.04" + tags: ["noise_check", "high noise", "non-identical, i.i.d. measurements"] \ No newline at end of file diff --git a/configs/noise_check/ala_ala_denoiser_experiment_model4.yaml b/configs/noise_check/ala_ala_denoiser_experiment_model4.yaml new file mode 100644 index 0000000..77e181a --- /dev/null +++ b/configs/noise_check/ala_ala_denoiser_experiment_model4.yaml @@ -0,0 +1,74 @@ +# @package _global_ +# Model 3: Denoiser with position conditioner, noise level sigma, +# but hidden states are repeated copies of y.pos (noise added by denoiser) + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: 2 # Two structures: current + 1 copy + lag_subsample_rate: 1 + num_frames: 10000 + + val: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + num_frames: 100 + + test: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 10 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 # Base sigma level for denoising + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + align_hidden_states: false + +trainer: + max_epochs: 100 + +logger: + wandb: + group: ALA_ALA_noise_check + notes: "Model 4: Denoiser with PositionConditioner, 2 structures, sigma=0.04, hidden states not aligned" + tags: ["noise_check", "high noise", "non-identical, i.i.d. measurements", "hidden states not aligned"] \ No newline at end of file diff --git a/configs/noise_check/ala_ala_denoiser_experiment_spike_check.yaml b/configs/noise_check/ala_ala_denoiser_experiment_spike_check.yaml new file mode 100644 index 0000000..b99df8a --- /dev/null +++ b/configs/noise_check/ala_ala_denoiser_experiment_spike_check.yaml @@ -0,0 +1,73 @@ +# @package _global_ +# Model 3: Denoiser with position conditioner, noise level sigma, +# but hidden states are repeated copies of y.pos (noise added by denoiser) + +defaults: + - override /model: denoiser_conditional + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: 2 # Two structures: current + 1 copy + lag_subsample_rate: 1 + num_frames: 10000 + + val: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: ${data.datamodule.datasets.train.lag_subsample_rate} + num_frames: 100 + + test: + _target_: jamun.data.parse_repeated_position_datasets_from_directory + root: "${paths.data_path}/capped_diamines/timewarp_splits/train" + traj_pattern: "^(.*).xtc" + pdb_pattern: "^(.*).pdb" + filter_codes: ['ALA_ALA'] + as_iterable: false + subsample: 1 + total_lag_time: ${data.datamodule.datasets.train.total_lag_time} + lag_subsample_rate: 1 + num_frames: 10 + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 # Base sigma level for denoising + arch: + n_layers: 2 + N_structures: ${data.datamodule.datasets.train.total_lag_time} + max_radius: 1000.0 + optim: + lr: 0.002 + conditioner: + _target_: jamun.model.conditioners.PositionConditioner + N_structures: ${model.arch.N_structures} + +trainer: + max_epochs: 100 + +logger: + wandb: + group: ALA_ALA_noise_check_spike_check + notes: "Model 3: Denoiser with PositionConditioner, 2 structures, sigma=0.04" + tags: ["noise_check", "high noise", "non-identical, i.i.d. measurements", "spike check"] \ No newline at end of file diff --git a/configs/sweep.yaml b/configs/sweep.yaml new file mode 100644 index 0000000..6f8c3a1 --- /dev/null +++ b/configs/sweep.yaml @@ -0,0 +1,23 @@ +program: jamun_train +method: grid +project: jamun +name: conditional_vs_self +metric: + name: val/loss + goal: minimize +parameters: + model.conditioner._target_: + values: + - jamun.model.conditioners.PositionConditioner + - jamun.model.conditioners.SelfConditioner + model.N_measurements_hidden: + values: [1, 5, 10, 20, 25] + model.sigma_distribution.sigma: + values: [0.04, 0.1085767, 0.29472252, 0.8] + +command: + - ${program} + - "--config-dir=configs" + - "experiment=train_test_single_shape_conditional_one_traj" + - "++trainer.log_every_n_steps=10" + - ${args_no_hyphens} \ No newline at end of file diff --git a/configs/sweep_conditioning.yaml b/configs/sweep_conditioning.yaml new file mode 100644 index 0000000..bfa7b25 --- /dev/null +++ b/configs/sweep_conditioning.yaml @@ -0,0 +1,21 @@ +program: jamun_train +method: grid +project: jamun +name: conditional_vs_self +metric: + name: val/loss + goal: minimize +parameters: + model.conditioner._target_: + values: + - jamun.model.conditioners.PositionConditioner + - jamun.model.conditioners.SelfConditioner + data.datamodule.datasets.train.total_lag_time: + values: [4, 5, 6, 7, 8, 9, 10] +command: + - ${program} + - "--config-dir=configs" + - "experiment=train_test_single_shape_conditional" + - "++trainer.log_every_n_steps=10" + - "++paths.root_path=/data2/sules/jamun-conditional-runs" + - ${args_no_hyphens} \ No newline at end of file diff --git a/configs/sweep_conditioning_noise.yaml b/configs/sweep_conditioning_noise.yaml new file mode 100644 index 0000000..4f3f293 --- /dev/null +++ b/configs/sweep_conditioning_noise.yaml @@ -0,0 +1,24 @@ +program: jamun_train +method: grid +project: jamun +name: conditional_vs_self +metric: + name: val/loss + goal: minimize +parameters: + model.conditioner._target_: + values: + - jamun.model.conditioners.PositionConditioner + - jamun.model.conditioners.SelfConditioner + data.datamodule.datasets.train.total_lag_time: + values: [4, 6, 8, 10] + model.sigma_distribution.sigma: + values: [0.01, 0.03684031, 0.13572088, 0.5] +command: + - ${program} + - "--config-dir=configs" + - "experiment=train_test_single_shape_conditional" + - "++trainer.log_every_n_steps=10" + - "++trainer.max_epochs=100" + - "++paths.root_path=/data2/sules/jamun-conditional-runs" + - ${args_no_hyphens} \ No newline at end of file diff --git a/configs/sweep_delta_friction.yaml b/configs/sweep_delta_friction.yaml new file mode 100644 index 0000000..cd3e24c --- /dev/null +++ b/configs/sweep_delta_friction.yaml @@ -0,0 +1,19 @@ +program: jamun_sample +method: grid +project: jamun +name: sweep_delta_friction_large_noise +metric: + name: Jenson-Shannon Divergence vs. Number of Samples for Predicted Trajectory joined + goal: minimize +parameters: + delta: + values: [0.012, 0.039, 0.066, 0.093, 0.12] + friction: + values: [2.52572864, 1.2552661, 0.71334989, 0.36384343, 0.10536052] + +command: + - ${program} + - "--config-dir=configs" + - "experiment=sample_enhanced_conditioning_sweep" + - "++batch_sampler.mcmc.history_update_frequency=100" + - ${args_no_hyphens} \ No newline at end of file diff --git a/docs/KALA_JAMUN_documentation.md b/docs/KALA_JAMUN_documentation.md new file mode 100644 index 0000000..d270791 --- /dev/null +++ b/docs/KALA_JAMUN_documentation.md @@ -0,0 +1,1015 @@ +# KALA-JAMUN: Spatiotemporal Conditional Generation Documentation + +## Introduction + +KALA-JAMUN introduces conditioning into the JAMUN workflow to enable temporal-aware molecular generation. The key innovation is conditioning the denoising process on past noisy states, allowing the model to learn and maintain temporal correlations in molecular dynamics. This enhancement necessitated significant changes across three core components of the system: + +1. **Modified Dataset Infrastructure**: To support conditioning on historical states +2. **Enhanced Model Architectures**: To process both current and historical information +3. **Memory-Aware Sampling**: To maintain temporal consistency during generation + +The conditioning mechanism works by feeding past noisy states directly to the model as part of the input data. This enables the model to learn temporal dependencies and generate more realistic molecular trajectories that respect the underlying dynamics. + +This document provides a comprehensive guide covering the complete KALA-JAMUN workflow from data preparation through model architecture to sampling procedures. + +## Table of Contents + +1. [Chapter 1: Datasets](#chapter-1-datasets) +2. [Chapter 2: Architecture](#chapter-2-architecture) +3. [Chapter 3: Sampling](#chapter-3-sampling) + +--- + +## Chapter 1: Datasets + +### Overview + +In KALA-JAMUN, the conditioning is based on past noisy states that are fed directly to the model as part of the input data. This required fundamental modifications to the data structure itself. + +### Data Structure Modifications + +The core data class `DataWithResidueInformation` has been enhanced with a new field to support temporal conditioning: + +**Source:** [`src/jamun/utils/data_with_residue_info.py`](src/jamun/utils/data_with_residue_info.py), lines 5-16 + +```python +class DataWithResidueInformation(torch_geometric.data.Data): + """Graph with residue-level information.""" + + pos: torch.Tensor + atom_type_index: torch.Tensor + atom_code_index: torch.Tensor + residue_code_index: torch.Tensor + residue_sequence_index: torch.Tensor + residue_index: torch.Tensor + num_residues: int + loss_weight: float + hidden_state: Any # NEW: Stores past trajectory states +``` + +**Key Addition:** +- **`hidden_state`**: This new field stores the historical molecular configurations that enable temporal conditioning. It keeps the past trajectories that the model will condition on during the denoising process. + +This modification allows the data graph to carry both current molecular state (`pos`) and historical context (`hidden_state`), enabling the model to learn temporal dependencies in molecular dynamics. + +### 1.1 MDTrajDataset with Subsampling + +The core dataset class `MDTrajDataset` has been enhanced to support KALA-JAMUN's conditioning mechanism through historical state management. + +#### Enhanced MDTrajDataset Structure + +When a data graph is processed in KALA-JAMUN: +- **`graph.pos`**: Contains the current molecular state (positions) +- **`graph.hidden_state`**: Contains `total_lag_time - 1` past states + +The historical states are selected using two key parameters: +- **`lag_subsample_rate`**: Temporal difference between consecutive stored states +- **`total_lag_time`**: Total number of states stored (including the present state) + +#### Subsampling Implementation + +**Source:** [`src/jamun/data/_mdtraj.py`](src/jamun/data/_mdtraj.py), lines 249-261 (in MDTrajDataset.__init__) + +```python +if total_lag_time is not None and lag_subsample_rate is not None: + lagged_indices = get_subsampled_indices( + self.traj.n_frames, subsample, total_lag_time, lag_subsample_rate + ) + # Extract subsampled indices (first element of each list) + subsampled_indices = [indices[0] for indices in lagged_indices] + # Extract lagged indices (all except first element) + self.lagged_indices = [indices[1:] for indices in lagged_indices] + # Subsample the trajectory using the subsampled indices + self.hidden_state = [self.traj[indices] for indices in self.lagged_indices] + self.traj = self.traj[subsampled_indices] # self.traj is permanently modified. +``` + +**Example**: With `total_lag_time=5` and `lag_subsample_rate=10`: +- Present state: frame 100 +- Historical states: frames 90, 80, 70, 60 +- Total stored: 5 states (1 current + 4 historical) + +### 1.2 Loading with parse_datasets_from_directory + +Such datasets can be loaded using the `parse_datasets_from_directory` function: + +**Source:** [`src/jamun/data/_utils.py`](src/jamun/data/_utils.py), lines 38-49 (function definition) + +```python +datasets = parse_datasets_from_directory( + root="/data/trajectories", + traj_pattern=r"traj_(\w+)\.dcd", + pdb_pattern=r"(\w+)\.pdb", + total_lag_time=5, + lag_subsample_rate=10, + max_datasets=100 +) +``` + +This function automatically: +- Discovers trajectory files using regex patterns +- Matches trajectory files with corresponding PDB topology files +- Creates `MDTrajDataset` objects with proper historical state management +- Applies the index subsampling procedure to populate `hidden_state` + +### 1.3 RepeatedPositionDataset (Multimeasurement) + +Multimeasurement refers to collecting T independent noisy copies of the same present state. This enables the model to learn from multiple noise realizations applied to identical molecular configurations, improving robustness and sample diversity. + +#### Concept + +Instead of using historical states from different time points, multimeasurement uses: +- **`batch.pos`**: The current molecular state +- **`batch.hidden_state`**: Contains `T-1` copies of the same `pos` + +When noise is independently added in the denoiser, independent realizations of the noise get added to the exact same underlying state, allowing the model to learn the noise distribution more effectively. + +#### MDTrajRepeatedDataset + +To enable multimeasurement, KALA-JAMUN provides a specialized dataset: + +**Source:** [`src/jamun/data/noisy_position_dataset.py`](src/jamun/data/noisy_position_dataset.py), lines 5-37 + +```python +class RepeatedPositionDataset(MDtrajDataset): + def __getitem__(self, idx: int) -> torch_geometric.data.Data: + graph = super().__getitem__(idx) + # Hidden state contains repeated copies of the current position + # instead of historical states from different time points + return graph +``` + +#### Loading with parse_repeated_position_datasets_from_directory + +Multimeasurement datasets are created using: + +**Source:** [`src/jamun/data/_utils.py`](src/jamun/data/_utils.py), lines 362-373 (function definition) + +```python +datasets = parse_repeated_position_datasets_from_directory( + root="/data/trajectories", + traj_pattern=r"traj_(\w+)\.dcd", + pdb_pattern=r"(\w+)\.pdb", + # No temporal parameters - using repeated states instead + max_datasets=100 +) +``` + +**Key Difference**: +- **Temporal Conditioning**: `hidden_state` contains past states from different time points +- **Multimeasurement**: `hidden_state` contains repeated copies of the current state + +This approach allows the model to: +1. Learn from multiple noise realizations on the same state +2. Improve denoising performance through ensemble-like training +3. Generate more diverse samples from the same initial condition + +--- + +## Chapter 2: Architecture + +### Overview + +KALA-JAMUN implements a new denoiser class, `denoiser_conditional.Denoiser`, which handles the internal training process for temporal conditioning. This model consists of two main operating submodules, the conditioning module and the architecture module, called according to the following workflow: + +1. Just like `jamun.model.Denoiser`, the modules within `jamun.model.denoiser_conditional.Denoiser` are called from within `xhat_normalized`. +2. The **Conditioning Module**: `denoiser_conditional.Denoiser.conditioner` - first calculates features based on historical states +3. Next, the **Architecture Module**: `denoiser_conditional.Denoiser.g` - processes features from the conditioning module into the final output, which is combined with the noisy coordinates with the normalization factors to construct xhat. + + +**Key Components:** +1. **Conditioning Module**: Implements various conditioning strategies, with the spatiotemporal conditioner being the most sophisticated +2. **Architecture Module (model.g)**: Enhanced E3Conv variants that handle conditional inputs +3. **Training Process**: Integrated mean centering, alignment, scaling, propagation, and loss computation + +**⚠️ Important**: The input signatures for `denoiser_conditional.Denoiser` do not match those of the original `Denoiser`. + + + +### 2.1 Conditioning Module + +Every conditioning module in KALA-JAMUN is of the class `Conditioner`, which defines the interface for calculating features based on historical states. + +**Source:** [`src/jamun/model/conditioners/conditioners.py`](src/jamun/model/conditioners/conditioners.py) + +#### Available Conditioning Modules + +KALA-JAMUN provides several conditioning strategies: + +1. **PositionConditioner**: Returns just the input positions (baseline) +2. **MeanConditioner**: Provides mean-centered positions and repeated structures +3. **SpatioTemporalConditioner**: Uses spatiotemporal processing for feature extraction (most sophisticated) + +#### SpatioTemporalConditioner + +The most sophisticated conditioning module uses a spatiotemporal GNN to output features based on both current and historical states. This conditioner processes the input through a complete spatiotemporal architecture before passing features to the main denoiser. + +**Source:** [`src/jamun/model/conditioners/conditioners.py`](src/jamun/model/conditioners/conditioners.py), SpatioTemporalConditioner class + +```python +class SpatioTemporalConditioner(pl.LightningModule): + def forward(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + # Process through spatiotemporal model + spatial_features = self.spatiotemporal_model(y, c_noise=self.c_noise) + + # Return [positions, features] for concatenation + return [y.pos, spatial_features] +``` + +### 2.2 Spatiotemporal Model (E3SpatioTemporal) + +The spatiotemporal model is the core of the `SpatioTemporalConditioner`. It processes molecular data through several stages: + +**Source:** [`src/jamun/model/arch/spatiotemporal.py`](src/jamun/model/arch/spatiotemporal.py), lines 403+ + +#### Architecture Components + +1. **Spatial Module**: Processes individual molecular graphs (current and historical states) +2. **Temporal Graph Construction**: Converts spatial graphs into temporal graphs with temporal connections +3. **Temporal Module (E3Transformer)**: Applies transformer architecture on temporal graphs +4. **Reconversion**: Converts temporal graph features back to spatial representation + +#### Spatial Module Processing + +The spatial module processes each molecular configuration (current and historical) independently: + +```python +# Process current positions +node_attr_current = self.spatial_module( + pos=batch.pos, + topology=topology, + batch=batch.batch, + num_graphs=batch.num_graphs, + c_noise=c_noise, + effective_radial_cutoff=self.radial_cutoff +) + +# Process historical positions +for hidden_pos in batch.hidden_state: + node_attr_hidden = self.spatial_module( + pos=hidden_pos, + topology=topology, + batch=batch.batch, + num_graphs=batch.num_graphs, + c_noise=c_noise, + effective_radial_cutoff=self.radial_cutoff + ) +``` + +#### Temporal Graph Construction + +After spatial processing, the individual graphs are converted into temporal graphs where nodes across different time steps are connected: + +**Source:** [`src/jamun/model/arch/spatiotemporal.py`](src/jamun/model/arch/spatiotemporal.py), create_temporal_graph function + +The temporal graph construction creates edges between the intertemporal copies of the same atom. This gives some freedom as to how to define the connectivity structure of the temporal graph. Three stratgies that have been explored are: + +1. Fan graph--the present node connects to all nodes, and the ith historical node connects to the (i+1)th and (i-1)th historical node (whenever such nodes are available). +2. Hub and spoke--the present node connects to all nodes, and no historical nodes are mutually connected. +3. Complete graph--all nodes are mutually connected. + +In the temporal graph we also need to define what features the nodes and the edges have. This will be discussed below. + +#### E3Transformer (Temporal Module) + +The temporal module applies a transformer architecture specifically designed for temporal graphs: + +**Source:** [`src/jamun/model/arch/spatiotemporal.py`](src/jamun/model/arch/spatiotemporal.py), lines 217-284 + +##### Temporal Embeddings and Encoding Functions + +The E3Transformer uses several specialized encoding functions to handle temporal information: + +**Key Parameters:** +- **`irreps_node_attr_temporal`**: Irreducible representations for temporal node attributes (default: "1x1e") +- **`node_attr_temporal_encoding_function`**: Encodes temporal position information (default: "gaussian") +- **`edge_attr_temporal_encoding_function`**: Encodes temporal edge attributes (default: "gaussian") +- **`radial_edge_attr_encoding_function`**: Encodes radial distances (default: "gaussian") + +##### Temporal Attribute Processing + +```python +# Split edge attribute dimensions: radial and temporal +self.radial_edge_attr_dim = self.edge_attr_dim // 2 +self.temporal_edge_attr_dim = self.edge_attr_dim - self.radial_edge_attr_dim + +# Temporal gate for combining node attributes with temporal position +irreps_with_temporal = self.irreps_node_attr + self.irreps_node_attr_temporal +self.temporal_gate = e3tools.nn.GateWrapper( + irreps_in=irreps_with_temporal, + irreps_out=self.irreps_hidden, + irreps_gate=irreps_with_temporal, +) +``` + +The temporal gate combines: +- **Node attributes**: From spatial processing of individual timesteps +- **Temporal position**: Encoded position in the temporal sequence +- **Temporal edges**: Connections between atoms across different timesteps + +##### Transformer Layers + +The temporal transformer processes the combined spatial-temporal information through multiple attention layers: + +**Source:** [`src/jamun/model/arch/spatiotemporal.py`](src/jamun/model/arch/spatiotemporal.py), lines 267-284 + +```python +for _ in range(num_layers): + self.layers.append( + e3tools.nn.TransformerBlock( + irreps_in=self.irreps_hidden, + irreps_out=self.irreps_hidden, + irreps_sh=self.irreps_sh, + edge_attr_dim=self.edge_attr_dim, + num_heads=self.num_attention_heads, + conv=self.conv, + ) + ) +``` + +#### Integrating Pretrained Spatial Modules + +A powerful feature of KALA-JAMUN is the ability to use a pretrained unconditional JAMUN model as the spatial module within the spatiotemporal architecture. This enables leveraging existing trained models as building blocks for more sophisticated temporal conditioning. + +##### Architecture Overview + +In this setup, we have: +- **Overlying Conditional Denoiser**: `jamun.model.denoiser_conditional.Denoiser` - the main KALA-JAMUN model +- **Sub-Denoiser**: `jamun.model.Denoiser` - the pretrained unconditional JAMUN model used as spatial module + +The overlying conditional denoiser has its own normalization factor `c_in`, but when scaled data `y_scaled = c_in * y` goes into the sub-denoiser, this scaling must be divided out since the sub-denoiser has its own internal normalization. + +##### DenoiserWrapper: Input Signature Unification + +The core challenge is that `Denoiser.xhat` and `E3Conv` have different input signatures, requiring a wrapper to make them compatible: + +**Source:** [`src/jamun/utils/pretrained_wrapper.py`](src/jamun/utils/pretrained_wrapper.py), lines 56-85 + +```python +class DenoiserWrapper(nn.Module): + """ + Wrapper around a denoiser model that matches the spatial module interface. + + This allows pretrained denoiser models to be used as spatial/temporal modules + in the spatiotemporal architecture by replicating the full denoiser logic + including normalization factors computed from the denoiser's own parameters. + """ + + def __init__(self, denoiser_model: nn.Module, c_in: float = 1.0, trainable: bool = True): + """ + Args: + denoiser_model: The pretrained denoiser model + c_in: Rescaling factor to convert positions from overlaying model scale + trainable: Whether to keep the model trainable (default: True) + """ + super().__init__() + self.denoiser = denoiser_model + self.c_in = c_in # Rescaling factor from overlying denoiser +``` + +##### Rescaling Mechanism + +The `DenoiserWrapper` handles the critical `c_in` rescaling between the overlying and sub-denoiser: + +**Source:** [`src/jamun/utils/pretrained_wrapper.py`](src/jamun/utils/pretrained_wrapper.py), lines 98-118 + +```python +def forward(self, pos, topology, batch, num_graphs, c_noise, effective_radial_cutoff): + # Sample sigma from the denoiser's own sigma distribution + sigma = self.denoiser.sigma_distribution.sample().to(pos.device) + + # CRITICAL: Rescale positions from overlaying model scale + y = pos / self.c_in # Divide out overlying denoiser's c_in + + # Apply sub-denoiser's own normalization + c_in, c_skip, c_out, _ = compute_normalization_factors( + sigma, + average_squared_distance=self.denoiser.average_squared_distance, + normalization_type=self.denoiser.normalization_type, + sigma_data=self.denoiser.sigma_data, + D=y.shape[-1], + device=y.device, + ) +``` + +**Key Steps:** +1. **Input Rescaling**: `y = pos / self.c_in` - divides out the overlying denoiser's scaling +2. **Internal Normalization**: The sub-denoiser computes its own `c_in`, `c_skip`, `c_out` +3. **Denoiser Processing**: Full `xhat_normalized` logic is replicated internally +4. **Output**: Features compatible with the spatiotemporal architecture + +This design allows seamless integration of pretrained models while maintaining proper normalization hierarchies between the overlying conditional denoiser and the embedded unconditional denoiser. + +### 2.3 Architecture Module (model.g) + +The architecture module `model.g` is the main processing unit that receives features from the conditioning module. Unlike standard E3Conv models, these variants are designed to handle conditional inputs. + +#### E3ConvConditional vs Standard E3Conv + +The key difference from standard E3Conv models is the ability to process multiple input structures and conditional information: + +**Source:** [`src/jamun/model/arch/e3conv_conditional.py`](src/jamun/model/arch/e3conv_conditional.py), lines 15-40 + +```python +class E3ConvConditional(torch.nn.Module): + def __init__( + self, + # Standard E3Conv parameters... + N_structures: int = 1, # NEW: Number of input structures + # ... + ): +``` + +**Key Features:** +- **Multi-Structure Support**: Processes multiple molecular structures simultaneously via `N_structures` +- **Noise Conditioning**: Integrates noise level information throughout the network +- **Skip Connections**: Uses noise-conditional skip connections for stable training + +#### E3ConvConditionalSpatioTemporal + +A specialized variant designed specifically for spatiotemporal conditioning: + +**Source:** [`src/jamun/model/arch/e3conv_conditional.py`](src/jamun/model/arch/e3conv_conditional.py), lines 312+ + +This variant handles concatenated position and feature data from the spatiotemporal model: + +```python +def forward( + self, + pos: Tensor, # [N, 3 + spatial_features_dim] from [y.pos, spatial_features] + topology: torch_geometric.data.Batch, + c_noise: Tensor, + effective_radial_cutoff: float, +) -> Tensor: + # Split positions: first 3 coords are physical, rest are features + pos_physical = pos[:, :3] # [N, 3] - physical coordinates + pos_features = pos[:, 3:] # [N, spatial_features_dim] - spatial features + + # Compute edge attributes using ONLY physical positions + edge_vec_physical = pos_physical[src] - pos_physical[dst] + edge_sh = self.sh(edge_vec_physical) + + # Combine node attributes with spatial features + combined_attr = torch.cat([node_attr, pos_features], dim=-1) + node_attr = self.spatial_feature_aggregator(combined_attr) +``` + +**Design Principle**: Separates physical coordinates (for geometric operations) from feature coordinates (for conditioning). + +### 2.4 Noising and Denoising Process + +KALA-JAMUN implements a comprehensive training pipeline that handles mean centering, alignment, scaling, model propagation, and loss computation. + +**Source:** [`src/jamun/model/denoiser_conditional.py`](src/jamun/model/denoiser_conditional.py), xhat and xhat_normalized methods + +#### Complete Denoising Workflow + +The denoising process follows these steps: + +##### 1. Mean Centering (Input Preparation) + +**Source:** [`src/jamun/model/denoiser_conditional.py`](src/jamun/model/denoiser_conditional.py), lines 308-311 + +```python +if self.mean_center: + y = mean_center(y) + y = self._mean_center_hidden_states(y) +``` + +Both current and historical states are mean-centered to ensure translational invariance. + +##### 2. Alignment (if enabled) + +Molecular configurations are aligned to a reference to handle rotational variance. This is done after adding noise but before denoising during training. + +##### 3. Scaling (Normalization) + +**Source:** [`src/jamun/model/denoiser_conditional.py`](src/jamun/model/denoiser_conditional.py), lines 264-286 + +```python +# Compute normalization factors +c_in, c_skip, c_out, c_noise = self.normalization_factors(sigma, D) + +# Scale input positions and hidden states +y_scaled = y.clone() +y_scaled.pos = y.pos * c_in +if hasattr(y, "hidden_state") and y.hidden_state is not None: + y_scaled.hidden_state = [] + for positions in y.hidden_state: + y_scaled.hidden_state.append(positions * c_in) +``` + +**Key scaling factors:** +- **`c_in`**: Scales input coordinates and hidden states +- **`c_skip`**: Skip connection scaling +- **`c_out`**: Output scaling +- **`c_noise`**: Noise conditioning scaling + +##### 4. Model Propagation + +**Source:** [`src/jamun/model/denoiser_conditional.py`](src/jamun/model/denoiser_conditional.py), lines 294-299 + +```python +# Step 1: Conditioning +conditioned_structures = self.conditioner(y_scaled) + +# Step 2: Architecture processing +g_pred = self.g(torch.cat([*conditioned_structures], dim=-1), + topology=y_scaled, + c_noise=c_noise, + effective_radial_cutoff=radial_cutoff) +``` + +The conditioner processes scaled inputs and hidden states, then the architecture module processes the concatenated conditioned structures. + +##### 5. Skip Connection and Output Scaling + +**Source:** [`src/jamun/model/denoiser_conditional.py`](src/jamun/model/denoiser_conditional.py), lines 301-304 + +```python +# Apply skip connection and output scaling +xhat.pos = c_skip * y.pos + c_out * g_pred + +# Update hidden state for next iteration +if hasattr(y, "hidden_state") and y.hidden_state is not None: + xhat.hidden_state = [y.pos, *y.hidden_state[:-1]] +``` + +##### 6. Mean Centering (Output) + +**Source:** [`src/jamun/model/denoiser_conditional.py`](src/jamun/model/denoiser_conditional.py), lines 317-319 + +```python +# Mean center the prediction +if self.mean_center: + xhat = mean_center(xhat) +``` + +##### 7. Loss Computation + +The loss is computed between the denoised prediction and the clean target, typically using MSE loss on the positions while maintaining proper handling of the hidden states for temporal consistency. + +#### Key Differences from Standard Denoiser + +1. **Hidden State Handling**: All scaling operations are applied to both current and hidden states +2. **Conditioning Integration**: The conditioner processes scaled inputs before the main architecture +3. **Temporal Consistency**: Hidden states are properly updated to maintain temporal sequence +4. **Multi-Structure Processing**: Conditioned structures are concatenated before processing + +This comprehensive pipeline ensures that KALA-JAMUN properly handles temporal information throughout the entire denoising process, maintaining consistency between current and historical states while applying the appropriate transformations for effective learning. + +--- + +## Chapter 3: Sampling + +### Overview + +Once the KALA-JAMUN model is trained, it is time to sample using the score function obtained via the Miyasawa-Tweedie formula from the denoiser. The score function relates the denoised prediction to the true score of the data distribution: + +``` +score(y, σ) = (x̂(y, σ) - y) / σ² +``` + +Where `x̂(y, σ)` is the denoised prediction from the trained model. + +Since KALA-JAMUN conditions on historical states, the sampling process must be modified to properly handle memory. This requires changes to the standard ABOBA and BAOAB samplers to account for the temporal dependencies and ensure that historical information is correctly propagated through the sampling chain. + +### 3.1 Model Loading + +Loading the trained conditional model is straightforward: + +**Source:** [`src/jamun/model/denoiser_conditional.py`](src/jamun/model/denoiser_conditional.py), load_from_checkpoint method + +```python +model = denoiser_conditional.Denoiser.load_from_checkpoint(checkpoint_path) +``` + +The loaded model retains its conditional structure with both the conditioning module and architecture module properly initialized. + +### 3.2 Memory-Aware Sampling Modifications + +#### Changes to ABOBA and BAOAB Samplers + +The standard ABOBA and BAOAB sampling algorithms must be modified to handle historical states. The key changes involve: + +1. **State Representation**: Instead of just current positions `y`, we now track `(y, y_hist)` where `y_hist` is a list of historical states +2. **Score Function Interface**: The score function now takes both current and historical states as input +3. **Memory Updates**: Historical states must be updated periodically during sampling to maintain temporal consistency + +#### Modified BAOAB with Memory + +The memory-aware BAOAB sampler uses a two-loop structure to handle temporal conditioning: + +**Source:** [`src/jamun/sampling/mcmc/functional/_splitting.py`](src/jamun/sampling/mcmc/functional/_splitting.py), lines 255-327 + +```python +def baoab_memory( + y: torch.Tensor, # Current positions + y_hist: list, # Historical states list + score_fn: Callable, # Score function accepting (y, y_hist) + steps: int, + history_update_frequency=1, # Inner loop length + **kwargs +): + """BAOAB splitting scheme with two-loop structure for memory updates.""" + + # Initialize velocity and score processing + v = initialize_velocity(v_init=v_init, y=y, u=u) + score_fn_processed = create_score_fn(score_fn, inverse_temperature, score_fn_clip) + psi, orig_score = score_fn_processed(y, y_hist=y_hist) + + # OUTER LOOP: Iterate over memory updates + for i in range(1, steps): + + # INNER LOOP: Equilibrate to conditional density p(y_t | y_hist) + for j in range(1, history_update_frequency): + y_current = y.clone().detach() + + # Standard BAOAB steps with FIXED history + v = v + u * (delta / 2) * psi # B: velocity update + y = y + (delta / 2) * v # A: position update + R = torch.randn_like(y) + vhat = math.exp(-friction) * v + zeta2 * math.sqrt(u) * R # O: Ornstein-Uhlenbeck + y = y + (delta / 2) * vhat # A: position update + psi, orig_score = score_fn_processed(y, y_hist=y_hist) # B: score update + v = vhat + (delta / 2) * psi + + # MEMORY UPDATE: Shift history after equilibration + y_hist.pop(-1) # Remove oldest state + y_hist.insert(0, y_current) # Add equilibrated state to history +``` + +##### Two-Loop Structure and Conditional Equilibration + +**Outer Loop (Memory Updates):** +- Iterates `steps` times over the complete sampling process +- Each iteration updates the historical memory `y_hist` +- Represents the temporal progression of the molecular system + +**Inner Loop (Conditional Equilibration):** +- Runs for `history_update_frequency` steps with **fixed historical context** +- Equilibrates the current state `y` to the conditional density `p(y_t | y_hist)` +- The historical states `y_hist` remain constant during this equilibration + +**Conditional Density Equilibration:** +The inner loop is crucial because it allows the sampler to properly explore the conditional distribution given the current historical context. Without sufficient equilibration: +- The sampler might not fully explore `p(y_t | y_hist)` before updating history +- This could lead to poor mixing and biased samples +- The temporal correlations learned during training might not be properly respected + +**Key Parameters:** +- **`history_update_frequency`**: Controls the balance between: + - **Computational cost**: Higher values require more inner loop steps + - **Sampling quality**: Longer equilibration ensures better conditional sampling + - **Temporal accuracy**: More frequent updates maintain tighter temporal consistency + +### 3.3 Score Function Wrapper Modifications + +The score function wrapper has been modified to handle current and historical states differently: + +**Source:** [`src/jamun/utils/sampling_wrapper.py`](src/jamun/utils/sampling_wrapper.py), lines 132-140 + +```python +def score(self, y, y_hist, sigma): + """Score function that handles current and historical states.""" + graph = self.positions_to_graph(y, y_hist).to(self.device) + return self._model.score(graph, sigma) + +def positions_to_graph(self, y, y_hist): + """Convert positions and history to graph format.""" + graph = self.init_graphs.clone() + graph.pos = y + graph.hidden_state = y_hist # Assign historical states + return graph +``` + +**Key Changes:** +1. **Dual Input**: Score function now accepts both `y` (current) and `y_hist` (historical) positions +2. **Graph Construction**: `positions_to_graph` method creates data graphs with `hidden_state` populated from `y_hist` +3. **Model Interface**: The wrapped model's score method processes the complete graph with temporal information + +### 3.4 ModelSamplingWrapperMemory + +The sampling wrapper provides the interface between the memory-aware samplers and the conditional model: + +**Source:** [`src/jamun/utils/sampling_wrapper.py`](src/jamun/utils/sampling_wrapper.py), lines 95-127 + +```python +class ModelSamplingWrapperMemory: + """Wrapper for models that depend on a memory of states.""" + + def __init__(self, model: nn.Module, init_graphs: torch_geometric.data.Data, sigma: float, recenter_on_init: bool = True): + self._model = model + self.init_graphs = init_graphs + self.sigma = sigma + + # Mean center both positions and hidden states + if recenter_on_init: + self.init_graphs = mean_center(self.init_graphs) + if hasattr(self.init_graphs, 'hidden_state') and self.init_graphs.hidden_state: + for i in range(len(self.init_graphs.hidden_state)): + mean = scatter(self.init_graphs.hidden_state[i], self.init_graphs.batch, dim=0, reduce="mean") + self.init_graphs.hidden_state[i] = self.init_graphs.hidden_state[i] - mean[self.init_graphs.batch] + + def sample_initial_noisy_positions(self) -> torch.Tensor: + """Sample initial noisy current positions.""" + pos = self.init_graphs.pos + return pos + torch.randn_like(pos) * self.sigma + + def sample_initial_noisy_history(self) -> list: + """Sample initial noisy historical states.""" + noisy_history = [] + for hidden_state in self.init_graphs.hidden_state: + noisy_history.append(hidden_state + torch.randn_like(hidden_state) * self.sigma) + return noisy_history +``` + +**Key Features:** +- **Dual Initialization**: Separately initializes current positions and historical states with noise +- **Mean Centering**: Applies mean centering to both current and historical states +- **Memory Management**: Handles the `hidden_state` list structure properly + +### 3.5 SamplerMemory Module + +The `SamplerMemory` class wraps the memory-aware sampling functionality: + +**Source:** [`src/jamun/sampling/_sampler.py`](src/jamun/sampling/_sampler.py), lines 101-130 + +```python +class SamplerMemory(Sampler): + """A sampler for molecular dynamics simulations that uses memory.""" + + def sample( + self, + model, + batch_sampler, + num_batches: int, + init_graphs: torch_geometric.data.Data, + continue_chain: bool = False, + ): + # Setup model and device + self.fabric.launch() + self.fabric.setup(model) + model.eval() + + # Create memory-aware wrapper + model_wrapped = utils.ModelSamplingWrapperMemory( + model=model, + init_graphs=init_graphs, + sigma=batch_sampler.sigma, + ) + + # Initialize with memory + y_init = model_wrapped.sample_initial_noisy_positions() + y_hist_init = model_wrapped.sample_initial_noisy_history() +``` + +**Responsibilities:** +- **Model Wrapping**: Creates `ModelSamplingWrapperMemory` instance +- **Memory Initialization**: Sets up both current and historical initial states +- **Sampling Coordination**: Manages the overall sampling process with memory + +### 3.6 Memory Loop Mechanics + +The memory loop in KALA-JAMUN sampling works as follows: + +1. **Initialization**: + - Current positions: `y_init` (noisy version of initial state) + - Historical states: `y_hist_init` (noisy versions of `hidden_state`) + +2. **Sampling Step**: + - Compute score using both current and historical states + - Update current positions using BAOAB/ABOBA dynamics + - Periodically update historical memory + +3. **Memory Update**: + - Shift historical states: `y_hist = [y_current, y_hist[:-1]]` + - Maintain constant memory length matching training configuration + +4. **Temporal Consistency**: + - Memory update frequency controls temporal coherence + - Balance between computational cost and temporal accuracy + +This design ensures that the sampling process respects the temporal dependencies learned during training while maintaining computational efficiency through configurable memory update frequencies. + +--- + +## Chapter 4: Usage + +### Overview + +This chapter describes the experimental setup for KALA-JAMUN, covering the enhanced dataset generation, training procedures, and sampling protocols. The experiments focus on alanine dipeptide (ALA_ALA) systems with enhanced sampling data to evaluate temporal conditioning performance. + +### 4.1 Enhanced Sampled Data + +#### Dataset Overview + +KALA-JAMUN experiments utilize enhanced sampling data from two main series: + +1. **ALA_ALA_enhanced**: 5 swarms of 50 frames each from 184 different grid points +2. **ALA_ALA_enhanced_long**: 2 swarms of 100,000 frames each from 184 different grid points + +Both datasets consist of swarms sampled at **20 fs intervals**, providing high temporal resolution for learning molecular dynamics. + +#### Data Source and Organization + +**Source Location**: `/data/bucket/vanib/ALA_ALA/swarms/swarm_results` + +The raw swarm data has been reorganized using the script: `scratch/reorganize_swarm_data.py` which sorts the trajectories into training and validation buckets according to different splitting strategies. + +#### Data Splitting Strategies + +The reorganization script implements four distinct splitting strategies: + +##### 1. Grid Split (`grid_split`) +- **Training Set**: 172 randomly selected grid codes, all trajectories (001-005) +- **Validation Set**: Remaining 12 grid codes, all trajectories +- **Principle**: Complete separation by spatial location in conformational space +- **Use Case**: Tests generalization to unseen regions of conformational space +- **Output**: `/data2/sules/ALA_ALA_enhanced_full_swarm` + +##### 2. Trajectory Split (`trajectory_split`) +- **Training Set**: All grid points, trajectories 001-004 +- **Validation Set**: All grid points, trajectory 005 +- **Principle**: Ensures both train/val cover all conformational regions +- **Use Case**: Tests temporal generalization within known conformational regions +- **Output**: `/data2/sules/ALA_ALA_enhanced_full_grid` + +##### 3. Long Grid Split (`long_grid_split`) +- **Training Set**: 172 grid codes, 2000ps trajectories (001, 003) +- **Validation Set**: Remaining 12 grid codes, 2000ps trajectories +- **Principle**: Grid-based splitting with extended trajectories +- **Use Case**: Tests generalization with longer temporal context +- **Output**: `/data2/sules/ALA_ALA_enhanced_long` + +##### 4. State Split (`state_split`) +- **Criterion**: Conformational state based on phi/psi angles of first residue +- **Training Set**: Trajectories outside phi ∈ (0,100°), psi ∈ (-50,100°) +- **Validation Set**: Trajectories with first residue in specified phi/psi range +- **Principle**: Complete withholding of specific conformational states +- **Use Case**: Tests ability to generate unseen metastable conformations +- **Output**: `/data2/sules/ALA_ALA_enhanced_long_state_split` + +**Script Usage:** +```bash +python reorganize_swarm_data.py SPLITTING_STRATEGY +``` + +**Available Strategies:** +- `grid_split`: For standard enhanced data with grid-based splitting +- `trajectory_split`: For full grid coverage with trajectory-based splitting +- `long_grid_split`: For long trajectories with grid-based splitting +- `state_split`: For conformational state-based splitting + +#### Using 2000ps Trajectories + +For experiments requiring the 2000ps trajectory data (strategies `long_grid_split` and `state_split`), the script automatically uses the longer trajectories. However, if you need to modify this behavior, update the following locations in `reorganize_swarm_data.py`: + +**Source:** [`scratch/reorganize_swarm_data.py`](scratch/reorganize_swarm_data.py), lines 468 and 480 + +```python +# Line 468: In reorganize_with_long_grid_split function +copy_files_for_grid_split( + SOURCE_DIR, + os.path.join(target_dir, 'train'), + train_codes, + trajectory_codes, + SINGLE_PDB_FILE, + 'TRAIN', + use_2000ps=True # Set to True for 2000ps trajectories +) + +# Line 480: In the same function for validation split +copy_files_for_grid_split( + SOURCE_DIR, + os.path.join(target_dir, 'val'), + val_codes, + trajectory_codes, + SINGLE_PDB_FILE, + 'VAL', + use_2000ps=True # Set to True for 2000ps trajectories +) +``` + +**Function Parameter:** [`scratch/reorganize_swarm_data.py`](scratch/reorganize_swarm_data.py), line 151 + +```python +def copy_files_for_grid_split( + source_dir: str, + target_dir: str, + grid_codes: List[str], + trajectory_codes: List[str], + single_pdb_file: str, + split_name: str, + use_2000ps: bool = False # Set to True for 2000ps trajectories +): +``` + +### 4.2 Training + +#### Training Configuration + +Once data selection is completed, models can be trained using the `train_enhanced_*` configuration series: + +**Command Syntax:** +```bash +jamun_train --config-dir=configs experiment={experimental_config_name} +``` + +### 4.3 Sampling + +#### Sampling Configuration + +Once training is completed, sampling requires the memory-aware configuration: + +**Critical Configuration**: Set `config="sample_memory"` in the sampling script to enable memory-aware sampling with historical state management. + +#### Sampling Command + +**Standard Syntax:** +```bash +jamun_sample --config-dir=configs experiment={experimental_config_name} +``` + +**Example:** +```bash +jamun_sample --config-dir=configs experiment=train_enhanced_full_grid +``` + +#### Memory-Aware Sampling Setup + +The `sample_memory` configuration automatically handles: + +1. **Model Loading**: Loads conditional denoiser from checkpoint +2. **Memory Initialization**: Sets up initial historical states from validation data +3. **Sampler Selection**: Uses `SamplerMemory` with `baoab_memory` algorithm +4. **Wrapper Configuration**: Employs `ModelSamplingWrapperMemory` for proper interface + +### 4.4 Experiments + +This section describes key experiments designed to evaluate KALA-JAMUN's performance and validate design choices for temporal conditioning. + +#### 4.4.1 Model Comparison + +**Objective**: Compare different conditioning strategies and temporal graph topologies to establish the effectiveness of spatiotemporal conditioning. + +**Models Compared**: +1. **Standard JAMUN**: Baseline unconditional denoiser without temporal information +2. **Position Conditioner**: Simple conditioning using current positions only +3. **Spatiotemporal Conditioner (Fan Graph)**: Full spatiotemporal model with fan temporal graph topology +4. **Spatiotemporal Conditioner (Hub-and-Spoke)**: Full spatiotemporal model with hub-and-spoke temporal graph topology + +For instance, check out this wandb [run](https://genentech.wandb.io/sule-shashank/jamun/runs/scxc4bt4/overview) and its associated group. + + +#### 4.4.2 Noise Check (Multimeasurement Validation) + +**Objective**: Validate the multimeasurement approach by comparing standard JAMUN with reduced noise against spatiotemporal models using repeated position datasets. + +**Experimental Setup**: + +**Standard JAMUN Configuration**: +- Noise level: `σ/√T` (reduced noise to account for T measurements) +- Dataset: Standard molecular trajectory data +- Model: Unconditional denoiser + +**Spatiotemporal Model Configuration**: +- **Repeated Position Dataset**: `total_lag_time = T` with repeated copies of current state +- **Standard Temporal Dataset**: `total_lag_time = T` with historical trajectory states +- Noise level: Standard `σ` +- Model: Spatiotemporal conditioner + +**Experimental Script**: [`scripts/slurm/train_noise_check.sh`](scripts/slurm/train_noise_check.sh) + +**Key Comparisons**: +1. **Standard JAMUN (σ/√T)** vs **Spatiotemporal + Repeated Dataset (σ)** +2. **Standard JAMUN (σ/√T)** vs **Spatiotemporal + Temporal Dataset (σ)** +3. **Repeated Dataset** vs **Temporal Dataset** (both with spatiotemporal conditioning) + +**Sample wandb run** +Run [here](https://genentech.wandb.io/sule-shashank/jamun/runs/4j8bfj5k/overview) and check out its associated group. + +#### 4.4.3 Total Lag Time vs Lag Subsample Rate Experiment + +**Objective**: Systematically evaluate the impact of temporal parameters (`total_lag_time` and `lag_subsample_rate`) across different temporal graph topologies. + +**Parameter Space**: +- **Total Lag Time**: Number of historical states included (e.g., 2, 4, 6, 8, 10) +- **Lag Subsample Rate**: Temporal spacing between consecutive states (e.g., 5, 10, 20, 50 timesteps) +- **Graph Types**: Fan, Hub-and-Spoke, Complete graph topologies + +**Experimental Design**: +- Grid search across parameter combinations +- Fixed computational budget per configuration +- Consistent evaluation metrics across all runs + +**Experimental Script**: [`scripts/slurm/train_graph_type_comparison.sh`](scripts/slurm/train_graph_type_comparison.sh) + +**Wandb runs** + +Run [here](https://genentech.wandb.io/sule-shashank/jamun/runs/tjwcsf4g/overview) and check out its associated group +#### 4.4.4 Sampling runs + +1. **Bond degradation** The bond degradation of KALA-JAMUN vs Standard JAMUN with a trajectory of 50K steps was compared. We also compared KALA-JAMUN to a standard JAMUN trained for 500 epochs. The run for KALA jamun is [here](https://genentech.wandb.io/sule-shashank/jamun/runs/1j4us3nx?nw=nwusersuleshashank). The run for Standard JAMUN is [here](https://genentech.wandb.io/sule-shashank/jamun/runs/vigqbemt/overview) and the run for the highly trained standard JAMUN is [here](https://genentech.wandb.io/sule-shashank/jamun/runs/9u4qo5ax/overview). + +2. **Comparing ensembles**: We compared KALA JAMUN vs JAMUN in terms of being able to converge the distribution from the short swarm data (1ps). The results for KALA-JAMUN are [here](https://genentech.wandb.io/sule-shashank/jamun/runs/jwk7i45j/overview) and standard JAMUN are [here](https://genentech.wandb.io/sule-shashank/jamun/runs/u2of58jn/overview). \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2b99b1b..bb4b014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires-python = ">=3.10" dependencies = [ "ase>=3.23.0", "e3nn>=0.5.6", - "e3tools>=0.1.1", + "e3tools>=0.1.2", "einops>=0.8.0", "hydra-core>=1.3.2", "lightning>=2.4.0", @@ -38,12 +38,20 @@ dependencies = [ "universal-pathlib>=0.2.6", "wandb>=0.19.1", "orb_models>=0.5.4", + "optree>=0.17.0", ] [project.scripts] jamun_train = "jamun.cmdline.train:main" jamun_sample = "jamun.cmdline.sample:main" +[project.optional-dependencies] +analysis = [ + "polars>=1.32.0", + "pyarrow>=21.0.0", + "seaborn>=0.13.2", +] + [build-system] requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" diff --git a/scratch/README_reorganization.md b/scratch/README_reorganization.md new file mode 100644 index 0000000..856b5d0 --- /dev/null +++ b/scratch/README_reorganization.md @@ -0,0 +1,122 @@ +# ALA_ALA Swarm Data Reorganization Scripts + +## Overview +These scripts reorganize molecular dynamics swarm data from `/data/bucket/vanib/ALA_ALA/swarm_results/` into a machine learning-ready format in `/data2/sules/ALA_ALA_enhanced/`. + +## Scripts + +### 1. `reorganize_swarm_data.py` - Main Script +**Input Structure:** +- Source: `/data/bucket/vanib/ALA_ALA/swarm_results/` +- 184 directories: `AA_000/`, `AA_001/`, ..., `AA_183/` +- Each contains: `swarm_1ps_001.xtc`, `swarm_1ps_002.xtc`, ..., `swarm_1ps_005.xtc` +- Single PDB file: `/data/bucket/vanib/ALA_ALA/ALA_ALA.pdb` + +**Output Structure:** +- Target: `/data2/sules/ALA_ALA_enhanced/` +- `train/` - 172 randomly selected grid codes (860 .xtc + 860 .pdb files) +- `val/` - Remaining 12 grid codes (60 .xtc + 60 .pdb files) + +**File Naming Convention:** +- Original: `swarm_1ps_{traj_code}.xtc` → New: `swarm_1ps_{grid_code}_{traj_code}.xtc` +- PDB files: `swarm_1ps_{grid_code}_{traj_code}.pdb` (copied from single source) + +**Features:** +- ✅ **Progress bars** with tqdm showing copy progress +- ✅ **Reproducible random split** (seed=42) +- ✅ **mdtraj validation** - Tests that .xtc + .pdb pairs load correctly +- ✅ **Comprehensive logging** - Detailed progress and error reporting +- ✅ **Safe operation** - Only copies, never moves/deletes source data +- ✅ **Verification** - File count validation and structure checking + +### 2. `test_reorganize_swarm_data.py` - Test Script +- Creates mock data structure with 5 grid codes +- Tests the reorganization logic with small dataset +- Validates file organization and naming +- Tests mdtraj integration (expected to fail on mock data) +- Verifies progress bar functionality + +## Usage + +```bash +# Activate conda environment +conda activate jamun + +# Navigate to scripts directory +cd scratch + +# Run test first (optional) +python test_reorganize_swarm_data.py + +# Run reorganization with trajectory split (default) +python reorganize_swarm_data.py trajectory_split + +# Or run reorganization with grid split +python reorganize_swarm_data.py grid_split + +# Or run without arguments (defaults to trajectory_split) +python reorganize_swarm_data.py +``` + +## Splitting Strategies + +The script supports two different data splitting strategies: + +### 1. Grid Split (`grid_split`) +- **Random grid codes split**: 172 grids for train, 12 grids for val, all trajectories +- **Output folder**: `/data2/sules/ALA_ALA_enhanced_full_swarm/` +- **Train**: 172 grid codes × 5 trajectories × 2 file types = 1,720 files +- **Val**: 12 grid codes × 5 trajectories × 2 file types = 120 files + +### 2. Trajectory Split (`trajectory_split`) - **DEFAULT** +- **All grids split by trajectory**: trajectories 001-004 for train, 005 for val +- **Output folder**: `/data2/sules/ALA_ALA_enhanced_full_grid/` +- **Train**: 184 grid codes × 4 trajectories × 2 file types = 1,472 files +- **Val**: 184 grid codes × 1 trajectory × 2 file types = 368 files + +## Expected Results + +**Grid Split structure:** +``` +/data2/sules/ALA_ALA_enhanced_full_swarm/ +├── train/ # 1720 files total +│ ├── swarm_1ps_000_001.xtc # Random 172 grids, all trajectories +│ ├── swarm_1ps_000_001.pdb +│ └── ... (172 grid codes × 5 trajectories × 2 file types) +└── val/ # 120 files total + ├── swarm_1ps_XXX_001.xtc # Remaining 12 grids, all trajectories + └── ... (12 grid codes × 5 trajectories × 2 file types) +``` + +**Trajectory Split structure:** +``` +/data2/sules/ALA_ALA_enhanced_full_grid/ +├── train/ # 1472 files total +│ ├── swarm_1ps_000_001.xtc # All 184 grids, trajectories 001-004 +│ ├── swarm_1ps_000_002.xtc +│ ├── swarm_1ps_000_003.xtc +│ ├── swarm_1ps_000_004.xtc +│ └── ... (184 grid codes × 4 trajectories × 2 file types) +└── val/ # 368 files total + ├── swarm_1ps_000_005.xtc # All 184 grids, trajectory 005 only + ├── swarm_1ps_001_005.xtc + └── ... (184 grid codes × 1 trajectory × 2 file types) +``` + +## Dependencies +- **mdtraj** - For trajectory validation (available in jamun environment) +- **tqdm** - For progress bars (available in jamun environment) +- **Standard library** - os, shutil, random, logging, pathlib + +## Runtime Estimate +- **Test script**: ~1 second +- **Grid split**: ~52 seconds (1,840 file operations) +- **Trajectory split**: ~37 seconds (1,840 file operations) +- **Disk space needed**: ~Same as source data for each folder (copying, not moving) + +## Safety Features +- Never deletes or moves source data +- Validates all paths before starting +- Reports missing files without stopping +- Tests mdtraj compatibility with random samples +- Detailed error logging and recovery \ No newline at end of file diff --git a/scratch/__init__.py b/scratch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scratch/analyze_grid_code_distribution.py b/scratch/analyze_grid_code_distribution.py new file mode 100755 index 0000000..610d14b --- /dev/null +++ b/scratch/analyze_grid_code_distribution.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +""" +Script to analyze the distribution of trajectories by grid codes. +Creates a histogram showing how many trajectory codes exist for each grid code. +""" + +import argparse +import os +import re +from collections import Counter, defaultdict + +import matplotlib.pyplot as plt +import numpy as np + + +def parse_trajectory_files(data_dir): + """Parse trajectory files and extract grid codes and traj codes.""" + # Pattern to match traj_{grid_code}_{traj_code} + pattern = re.compile(r"^traj_(\d+)_(\d+)") + + grid_traj_mapping = defaultdict(list) + + # Scan directory for trajectory files + if not os.path.exists(data_dir): + raise ValueError(f"Directory {data_dir} does not exist") + + files = os.listdir(data_dir) + trajectory_files = [] + + for filename in files: + match = pattern.match(filename) + if match: + grid_code = int(match.group(1)) + traj_code = int(match.group(2)) + grid_traj_mapping[grid_code].append(traj_code) + trajectory_files.append(filename) + + print(f"Found {len(trajectory_files)} trajectory files") + print(f"Found {len(grid_traj_mapping)} unique grid codes") + + return grid_traj_mapping + + +def create_histogram(grid_traj_mapping, output_path=None): + """Create histogram of trajectory counts per grid code.""" + + if not grid_traj_mapping: + print("No trajectory files found!") + return + + # Get the full range of grid codes (min to max) + all_grid_codes = list(grid_traj_mapping.keys()) + min_grid = min(all_grid_codes) + max_grid = max(all_grid_codes) + + print(f"Grid code range: {min_grid} to {max_grid}") + + # Create array for all grid codes in range + full_range = list(range(min_grid, max_grid + 1)) + traj_counts = [] + + for grid_code in full_range: + count = len(grid_traj_mapping.get(grid_code, [])) + traj_counts.append(count) + + # Print some statistics + total_trajs = sum(traj_counts) + non_zero_grids = sum(1 for count in traj_counts if count > 0) + zero_grids = len(full_range) - non_zero_grids + + print(f"Total trajectories: {total_trajs}") + print(f"Grid codes with trajectories: {non_zero_grids}") + print(f"Grid codes with no trajectories: {zero_grids}") + print(f"Average trajectories per grid code: {total_trajs / len(full_range):.2f}") + print(f"Max trajectories for single grid code: {max(traj_counts)}") + print(f"Min trajectories for single grid code: {min(traj_counts)}") + + # Create histogram + plt.figure(figsize=(12, 6)) + plt.bar(full_range, traj_counts, width=0.8, alpha=0.7, edgecolor="black", linewidth=0.5) + plt.xlabel("Grid Code") + plt.ylabel("Number of Trajectories") + plt.title("Distribution of Trajectories by Grid Code") + plt.grid(True, alpha=0.3) + + # Add some formatting + if len(full_range) > 50: + # If too many grid codes, adjust x-axis ticks + step = max(1, len(full_range) // 20) + plt.xticks(full_range[::step], rotation=45) + else: + plt.xticks(full_range) + + plt.tight_layout() + + # Save plot + if output_path: + plt.savefig(output_path, dpi=300, bbox_inches="tight") + print(f"Histogram saved to: {output_path}") + + plt.show() + + return full_range, traj_counts + + +def print_detailed_stats(grid_traj_mapping): + """Print detailed statistics about the distribution.""" + + counts = [len(trajs) for trajs in grid_traj_mapping.values()] + + print("\n" + "=" * 50) + print("DETAILED STATISTICS") + print("=" * 50) + + print(f"Total unique grid codes found: {len(grid_traj_mapping)}") + print(f"Total trajectories: {sum(counts)}") + + if counts: + print(f"Mean trajectories per grid code: {np.mean(counts):.2f}") + print(f"Median trajectories per grid code: {np.median(counts):.2f}") + print(f"Std dev trajectories per grid code: {np.std(counts):.2f}") + + # Show distribution of counts + count_dist = Counter(counts) + print("\nDistribution of trajectory counts:") + for count, frequency in sorted(count_dist.items()): + print(f" {count} trajectories: {frequency} grid codes") + + # Show some examples + print("\nExample grid codes and their trajectory counts:") + for i, (grid_code, trajs) in enumerate(sorted(grid_traj_mapping.items())[:10]): + print(f" Grid {grid_code}: {len(trajs)} trajectories") + + if len(grid_traj_mapping) > 10: + print(f" ... and {len(grid_traj_mapping) - 10} more") + + +def main(): + parser = argparse.ArgumentParser(description="Analyze trajectory distribution by grid codes") + parser.add_argument( + "--data-dir", default="/data2/sules/fake_enhanced_data/ALA_ALA", help="Directory containing trajectory files" + ) + parser.add_argument("--output", default="scratch/grid_code_histogram.png", help="Output path for histogram plot") + + args = parser.parse_args() + + print(f"Analyzing trajectories in: {args.data_dir}") + + # Parse trajectory files + grid_traj_mapping = parse_trajectory_files(args.data_dir) + + if not grid_traj_mapping: + print("No trajectory files found matching pattern traj_{grid_code}_{traj_code}") + return + + # Print detailed statistics + print_detailed_stats(grid_traj_mapping) + + # Create histogram + print("\nCreating histogram...") + full_range, traj_counts = create_histogram(grid_traj_mapping, args.output) + + print("\nAnalysis complete!") + + +if __name__ == "__main__": + main() diff --git a/scratch/analyze_trajectory_noise.py b/scratch/analyze_trajectory_noise.py new file mode 100644 index 0000000..8db0ba1 --- /dev/null +++ b/scratch/analyze_trajectory_noise.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +""" +Trajectory noise analysis script. +Loads .xtc trajectory files, adds noise, computes norms between successive points, +and creates histograms with flexible time point filtering. +""" + +import argparse +import glob +import os + +import matplotlib.pyplot as plt +import mdtraj as md +import numpy as np + + +def load_trajectories(traj_dir: str, topology_file: str, max_files: int | None = None) -> list[md.Trajectory]: + """ + Load trajectory files from directory using MDTraj. + + Args: + traj_dir: Directory containing .xtc files + topology_file: Path to PDB topology file + max_files: Maximum number of files to load (None for all) + + Returns: + List of MDTraj trajectory objects + """ + xtc_files = glob.glob(os.path.join(traj_dir, "*.xtc")) + xtc_files.sort() + + if max_files is not None: + xtc_files = xtc_files[:max_files] + + print(f"Loading {len(xtc_files)} trajectory files...") + + trajectories = [] + for i, xtc_file in enumerate(xtc_files): + try: + traj = md.load(xtc_file, top=topology_file) + trajectories.append(traj) + if (i + 1) % 50 == 0: + print(f"Loaded {i + 1}/{len(xtc_files)} trajectories") + except Exception as e: + print(f"Warning: Failed to load {xtc_file}: {e}") + + print(f"Successfully loaded {len(trajectories)} trajectories") + return trajectories + + +def add_noise_to_trajectory(traj: md.Trajectory, noise_magnitude: float = 0.04) -> md.Trajectory: + """ + Add Gaussian noise to trajectory coordinates. + + Args: + traj: MDTraj trajectory object + noise_magnitude: Standard deviation of Gaussian noise to add (in nm) + + Returns: + New trajectory with added noise + """ + # Copy the trajectory to avoid modifying the original + noisy_traj = traj.slice(range(traj.n_frames)) + + # Add Gaussian noise to xyz coordinates + noise = np.random.normal(0, noise_magnitude, noisy_traj.xyz.shape) + noisy_traj.xyz += noise + + return noisy_traj + + +def compute_successive_norms(traj: md.Trajectory) -> np.ndarray: + """ + Compute norms between successive trajectory points. + + Args: + traj: MDTraj trajectory object + + Returns: + Array of norms between successive points for each atom + """ + if traj.n_frames < 2: + return np.array([]) + + # Calculate differences between successive frames + diff = traj.xyz[1:] - traj.xyz[:-1] # Shape: (n_frames-1, n_atoms, 3) + + # Compute norms for each atom at each time step + norms = np.linalg.norm(diff, axis=2) # Shape: (n_frames-1, n_atoms) + + return norms + + +def compute_norms_for_time_points(traj: md.Trajectory, time_points: list[tuple[int, int]]) -> np.ndarray: + """ + Compute norms between specific time points. + + Args: + traj: MDTraj trajectory object + time_points: List of (start_frame, end_frame) tuples + + Returns: + Array of norms for specified time point pairs + """ + norms = [] + + for start_frame, end_frame in time_points: + if start_frame < traj.n_frames and end_frame < traj.n_frames: + diff = traj.xyz[end_frame] - traj.xyz[start_frame] # Shape: (n_atoms, 3) + frame_norms = np.linalg.norm(diff, axis=1) # Shape: (n_atoms,) + norms.extend(frame_norms) + + return np.array(norms) + + +def analyze_trajectories( + trajectories: list[md.Trajectory], + noise_magnitude: float = 0.04, + time_point_filter: list[tuple[int, int]] | None = None, +) -> tuple[np.ndarray, np.ndarray]: + """ + Analyze trajectories by adding noise and computing norms. + + Args: + trajectories: List of MDTraj trajectory objects + noise_magnitude: Standard deviation of Gaussian noise + time_point_filter: Optional list of (start, end) frame pairs to analyze + + Returns: + Tuple of (original_norms, noisy_norms) + """ + original_norms = [] + noisy_norms = [] + + print(f"Analyzing {len(trajectories)} trajectories...") + + for i, traj in enumerate(trajectories): + if time_point_filter is not None: + # Compute norms for specific time points + orig_norm = compute_norms_for_time_points(traj, time_point_filter) + + # Add noise and compute norms for same time points + noisy_traj = add_noise_to_trajectory(traj, noise_magnitude) + noisy_norm = compute_norms_for_time_points(noisy_traj, time_point_filter) + else: + # Compute successive norms for all time points + orig_norm = compute_successive_norms(traj) + + # Add noise and compute successive norms + noisy_traj = add_noise_to_trajectory(traj, noise_magnitude) + noisy_norm = compute_successive_norms(noisy_traj) + + # Flatten and collect norms + original_norms.extend(orig_norm.flatten()) + noisy_norms.extend(noisy_norm.flatten()) + + if (i + 1) % 10 == 0: + print(f"Analyzed {i + 1}/{len(trajectories)} trajectories") + + return np.array(original_norms), np.array(noisy_norms) + + +def create_histogram( + original_norms: np.ndarray, + noisy_norms: np.ndarray, + title: str = "Norm Differences Between Successive Trajectory Points", + bins: int = 50, + save_path: str | None = None, +): + """ + Create histogram comparing original and noisy trajectory norms. + + Args: + original_norms: Array of norms from original trajectories + noisy_norms: Array of norms from noisy trajectories + title: Plot title + bins: Number of histogram bins + save_path: Optional path to save the plot + """ + plt.figure(figsize=(12, 8)) + + # Create histogram + plt.hist(original_norms, bins=bins, alpha=0.7, label="Original", density=True, color="blue") + plt.hist(noisy_norms, bins=bins, alpha=0.7, label="With Noise (σ=0.04)", density=True, color="red") + + plt.xlabel("Norm (nm)") + plt.ylabel("Density") + plt.title(title) + plt.legend() + plt.grid(True, alpha=0.3) + + # Add statistics + orig_mean, orig_std = np.mean(original_norms), np.std(original_norms) + noisy_mean, noisy_std = np.mean(noisy_norms), np.std(noisy_norms) + + stats_text = f"Original: μ={orig_mean:.4f}, σ={orig_std:.4f}\n" + stats_text += f"Noisy: μ={noisy_mean:.4f}, σ={noisy_std:.4f}" + + plt.text( + 0.98, + 0.98, + stats_text, + transform=plt.gca().transAxes, + verticalalignment="top", + horizontalalignment="right", + bbox=dict(boxstyle="round", facecolor="white", alpha=0.8), + ) + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches="tight") + print(f"Plot saved to {save_path}") + + plt.show() + + +def main(): + parser = argparse.ArgumentParser(description="Analyze trajectory noise effects") + parser.add_argument( + "--traj_dir", + type=str, + default="/data2/sules/fake_enhanced_data/ALA_ALA_organized/train", + help="Directory containing trajectory files", + ) + parser.add_argument( + "--topology", + type=str, + default="/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb", + help="PDB topology file", + ) + parser.add_argument("--noise_magnitude", type=float, default=0.04, help="Noise magnitude (standard deviation)") + parser.add_argument("--max_files", type=int, default=50, help="Maximum number of trajectory files to load") + parser.add_argument( + "--time_filter", + type=str, + default=None, + help='Time point filter as "start1,end1;start2,end2" (e.g., "0,1" for initial->next)', + ) + parser.add_argument("--output", type=str, default="trajectory_noise_analysis.png", help="Output plot filename") + parser.add_argument("--bins", type=int, default=50, help="Number of histogram bins") + + args = parser.parse_args() + + # Parse time filter if provided + time_point_filter = None + if args.time_filter: + try: + pairs = args.time_filter.split(";") + time_point_filter = [] + for pair in pairs: + start, end = map(int, pair.split(",")) + time_point_filter.append((start, end)) + print(f"Using time point filter: {time_point_filter}") + except: + print("Warning: Invalid time filter format. Using all successive points.") + + # Load trajectories + trajectories = load_trajectories(args.traj_dir, args.topology, args.max_files) + + if not trajectories: + print("No trajectories loaded. Exiting.") + return + + # Analyze trajectories + original_norms, noisy_norms = analyze_trajectories(trajectories, args.noise_magnitude, time_point_filter) + + # Create title based on analysis type + if time_point_filter: + title = f"Norm Differences for Time Points {time_point_filter}" + else: + title = "Norm Differences Between Successive Trajectory Points" + title += f" (Noise σ={args.noise_magnitude})" + + # Create histogram + create_histogram(original_norms, noisy_norms, title, args.bins, args.output) + + # Print summary statistics + print("\nSummary Statistics:") + print(f"Original trajectories: {len(original_norms)} data points") + print(f" Mean norm: {np.mean(original_norms):.6f} nm") + print(f" Std norm: {np.std(original_norms):.6f} nm") + print(f"Noisy trajectories: {len(noisy_norms)} data points") + print(f" Mean norm: {np.mean(noisy_norms):.6f} nm") + print(f" Std norm: {np.std(noisy_norms):.6f} nm") + + +if __name__ == "__main__": + main() diff --git a/scratch/bond_length_issues.py b/scratch/bond_length_issues.py new file mode 100644 index 0000000..9f1b2f3 --- /dev/null +++ b/scratch/bond_length_issues.py @@ -0,0 +1,119 @@ +import matplotlib.pyplot as plt +import mdtraj as md +import numpy as np + +from jamun.metrics._chemical_validity import check_bond_lengths + +# a. Load trajectory and topology +traj_path_conditional = "/data2/sules/jamun-conditional-runs/outputs/sample/dev/runs/2025-08-13_20-22-36/sampler/ALA_ALA/predicted_samples/dcd/joined.dcd" +pdb_path_conditional = ( + "/data2/sules/jamun-conditional-runs/outputs/sample/dev/runs/2025-08-13_20-22-36/sampler/ALA_ALA/topology.pdb" +) +traj_path_unconditional = "/data2/sules/jamun-conditional-runs//outputs/sample/dev/runs/2025-08-19_18-56-30/sampler/ALA_ALA/predicted_samples/dcd/joined.dcd" +pdb_path_unconditional = ( + "/data2/sules/jamun-conditional-runs//outputs/sample/dev/runs/2025-08-19_18-56-30/sampler/ALA_ALA/topology.pdb" +) +md_traj_path = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.xtc" +md_pdb_path = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" +# Load trajectory with topology +traj_conditional = md.load(traj_path_conditional, top=pdb_path_conditional) +traj_unconditional = md.load(traj_path_unconditional, top=pdb_path_unconditional) +md_traj = md.load(md_traj_path, top=md_pdb_path) +md_traj = md_traj[::28] +breakpoint() + +# b. Check bond length issues +tolerance = 0.1 # 20% tolerance (commonly used value) +bond_length_issues_conditional = check_bond_lengths(traj_conditional, tolerance=tolerance) +bond_length_issues_unconditional = check_bond_lengths(traj_unconditional, tolerance=tolerance) +bond_length_issues_md = check_bond_lengths(md_traj, tolerance=tolerance) +breakpoint() +print(f"\nBond length analysis (tolerance: {tolerance * 100}%):") +print(f"Number of frames analyzed: {len(bond_length_issues_conditional)}") + +# Convert to numpy array for easier analysis +issues_array_conditional = np.array(bond_length_issues_conditional) +total_issues_conditional = np.sum(issues_array_conditional) +cumulants_conditional = np.array( + [ + np.sum(issues_array_conditional[:i]) / np.sum(issues_array_conditional) + for i in range(issues_array_conditional.shape[0]) + ] +) + +issues_array_unconditional = np.array(bond_length_issues_unconditional) +total_issues_unconditional = np.sum(issues_array_unconditional) +cumulants_unconditional = np.array( + [ + np.sum(issues_array_unconditional[:i]) / np.sum(issues_array_unconditional) + for i in range(issues_array_unconditional.shape[0]) + ] +) + +issues_array_md = np.array(bond_length_issues_md) +total_issues_md = np.sum(issues_array_md) +cumulants_md = np.array( + [np.sum(issues_array_md[:i]) / np.sum(issues_array_md) for i in range(issues_array_md.shape[0])] +) + +breakpoint() +# Create histogram +plt.figure(figsize=(12, 8)) + +# Main histogram +plt.subplot(1, 2, 1) +plt.hist( + issues_array_conditional, bins=10, alpha=0.7, range=(0.0, 1.0), edgecolor="black", color="blue", label="KALA-JAMUN" +) +plt.hist( + issues_array_unconditional, bins=10, alpha=0.7, range=(0.0, 1.0), edgecolor="black", color="red", label="JAMUN" +) +plt.hist( + issues_array_md, + bins=10, + alpha=0.7, + range=(0.0, 1.0), + edgecolor="black", + color="green", + label="Reference MD Trajectory issues", +) +plt.legend() +plt.xlabel("Fraction of Bonds with Issues", fontsize=14) +plt.ylabel("Number of Frames", fontsize=14) +plt.ylim(0, 5.0e4) +plt.title(f"Distribution of Bond Length Issues\n(Tolerance: {tolerance * 100}%)", fontsize=14) +plt.grid(True, alpha=0.3) + +# Time series plot +plt.subplot(1, 2, 2) +plt.plot( + np.linspace(0, 1, issues_array_conditional.shape[0]), + cumulants_conditional, + alpha=0.5, + color="blue", + label="KALA-JAMUN", + linewidth=5, +) +plt.plot( + np.linspace(0, 1, issues_array_unconditional.shape[0]), + cumulants_unconditional, + alpha=0.5, + color="red", + label="JAMUN", + linewidth=5, +) +plt.plot( + np.linspace(0, 1, issues_array_md.shape[0]), + cumulants_md, + alpha=0.5, + color="green", + label="Reference MD Trajectory", + linewidth=5, +) +plt.legend() +plt.xlabel("Prop. of trajectory length", fontsize=14) +plt.ylabel("Fraction of issues arising", fontsize=14) +plt.title("Bond Issues Over Time", fontsize=14) +plt.grid(True, alpha=0.3) + +plt.savefig("bond_length_issues_conditional_traj.png") diff --git a/scratch/check_denoiser.py b/scratch/check_denoiser.py new file mode 100644 index 0000000..b8e527a --- /dev/null +++ b/scratch/check_denoiser.py @@ -0,0 +1,246 @@ +# %% +import logging +import os + +import dotenv + +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("jamun") + +import torch + +torch.cuda.is_available() +torch.set_float32_matmul_precision("high") + +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) + +import jamun +import jamun.data +import jamun.distributions +import jamun.model +import jamun.model.arch + +# %% +# dataset +dotenv.load_dotenv("../.env", verbose=True) +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +# %% +datasets = { + "test": jamun.data.parse_datasets_from_directory( + root=f"{JAMUN_DATA_PATH}/timewarp/2AA-1-large/train/", + traj_pattern="^(.*)-traj-arrays.npz", + pdb_file="AA-traj-state0.pdb", + filter_codes=["AA"], + as_iterable=False, + subsample=100, + max_datasets=1, + ) +} + +datamodule = jamun.data.MDtrajDataModule( + datasets=datasets, + batch_size=3, + num_workers=2, +) +datamodule.setup("test") +_, data_batch = next(enumerate(datamodule.test_dataloader())) + +# %% check paths +import sys + +project_root = "/homefs/home/sules/jamun" # Or use os.path.abspath("..") if your notebook is in a subdir of jamun + +if project_root not in sys.path: + sys.path.insert(0, project_root) + py_logger.info(f"Added '{project_root}' to sys.path for module discovery.") +else: + py_logger.info(f"'{project_root}' is already in sys.path.") + +# %% get configs +import hydra +from hydra import compose, initialize +from omegaconf import OmegaConf + +# Load the configuration file +with initialize(config_path="", job_name="jamun_test"): + cfg = compose( + config_name="config", + overrides=[ + "model.arch._target_=scratch.e3conv_test.E3Conv", # Override to use E3Conv from e3conv_test.py + "model._target_=scratch.denoiser_test.Denoiser", # Override to use Denoiser from denoiser_test.py + "+model.arch.N_structures=2", # Ensure N_structures is set, defaulting to 2 + ], + ) +# Log the configuration +py_logger.info("Loaded configuration:") +py_logger.info(OmegaConf.to_yaml(cfg)) + + +# %% Re-instantiate the model with the updated configuration +try: + py_logger.info("Attempting to re-instantiate model with updated arch...") + # Ensure average_squared_distance is still set correctly + if not hasattr(cfg.model, "average_squared_distance") or cfg.model.average_squared_distance is None: + from jamun.utils import compute_average_squared_distance_from_datasets # Ensure import + + average_squared_distance = compute_average_squared_distance_from_datasets( + datasets["test"], cfg.model.max_radius + ) + cfg.model.average_squared_distance = average_squared_distance + py_logger.info(f"Set cfg.model.average_squared_distance to {cfg.model.average_squared_distance}") + # Provide conditioner if needed + if not hasattr(cfg.model, "conditioner"): + OmegaConf.set_struct(cfg.model, False) # Allow modification + cfg.model.conditioner = OmegaConf.create({}) + cfg.model.conditioner._target_ = ( + "scratch.conditioners.PositionConditioner" # Use the PositionConditioner from scratch.conditioners + ) + OmegaConf.set_struct(cfg.model, True) # Lock structure again + py_logger.info("Set cfg.model.conditioner to instantiate 'scratch.conditioners.PositionConditioner'") + model = hydra.utils.instantiate(cfg.model) + py_logger.info("Successfully re-instantiated model with E3Conv from e3conv_test.py:") + print(model) + # You can inspect model.arch to confirm it's an instance of E3Conv + py_logger.info(f"Instantiated model architecture type: {type(model.g)}") + +except Exception as e: + py_logger.error(f"Error during model re-instantiation: {e}") + import traceback + + traceback.print_exc() + +# %% Tests for Denoiser.noise_and_denoise + +# Ensure 'model' (your Denoiser instance) and 'data_batch' are available from previous cells. +# If 'model' is not the correct Denoiser instance, re-instantiate it as needed. +# For example, if you were using the custom_denoiser_model: +# denoiser_to_test = custom_denoiser_model +# Or if you are using the one from the cfg re-instantiation: +denoiser_to_test = model + +# Make sure data_batch is on the same device as the model +if hasattr(denoiser_to_test, "device"): + data_batch = data_batch.to(denoiser_to_test.device) +elif next(denoiser_to_test.parameters()).is_cuda: + data_batch = data_batch.to(next(denoiser_to_test.parameters()).device) + + +py_logger.info(f"Testing Denoiser instance of type: {type(denoiser_to_test)}") +py_logger.info(f"Data batch has {data_batch.num_graphs} graphs and {data_batch.num_nodes} nodes.") + +# %% Tests for Denoiser object (denoiser_to_test) + + +py_logger.info("Starting tests for Denoiser object...") +_, data_batch = next(enumerate(datamodule.test_dataloader())) +# Ensure data_batch has hidden_state for the tests, matching model's N_structures +if ( + not hasattr(data_batch, "hidden_state") + or not isinstance(data_batch.hidden_state, list) + or len(data_batch.hidden_state) != denoiser_to_test.g._orig_mod.N_structures - 1 +): # Use _orig_mod to access N_structures if g is compiled + n_structures = denoiser_to_test.g._orig_mod.N_structures + py_logger.info(f"data_batch.hidden_state is missing or incorrect. Re-creating with {n_structures} structures.") + data_batch.hidden_state = [torch.randn_like(data_batch.pos) for _ in range(n_structures)] + data_batch.hidden_state = [hs.to(data_batch.pos.device) for hs in data_batch.hidden_state] +else: + py_logger.info(f"data_batch.hidden_state already exists with {len(data_batch.hidden_state)} structures.") + + +# %% Test 1: Denoiser.noise_and_denoise (align_noisy_input=False) + +try: + py_logger.info("Test 1: Denoiser.noise_and_denoise (align_noisy_input=False)") + original_x = data_batch.clone() + # sigma_test1 = torch.tensor(0.5, device=denoiser_to_test.device) + sigma = denoiser_to_test.sigma_distribution.sample() * 1e-5 + xhat1, y_processed1 = denoiser_to_test.noise_and_denoise(original_x.clone(), sigma, align_noisy_input=True) + + # assert isinstance(xhat1, torch_geometric.data.Batch), "xhat1 is not a PyG Batch object" + # assert isinstance(y_processed1, torch_geometric.data.Batch), "y_processed1 is not a PyG Batch object" + + assert xhat1.pos.shape == original_x.pos.shape, "xhat1.pos shape mismatch" + assert y_processed1.pos.shape == original_x.pos.shape, "y_processed1.pos shape mismatch" + + assert not torch.allclose(y_processed1.pos, original_x.pos), ( + "y_processed1.pos should be different from original x.pos" + ) + + assert xhat1.num_graphs == original_x.num_graphs, "xhat1 num_graphs mismatch" + assert y_processed1.num_graphs == original_x.num_graphs, "y_processed1 num_graphs mismatch" + assert xhat1.num_nodes == original_x.num_nodes, "xhat1 num_nodes mismatch" + assert y_processed1.num_nodes == original_x.num_nodes, "y_processed1 num_nodes mismatch" + assert torch.allclose(xhat1.batch, original_x.batch), "xhat1.batch mismatch" + assert torch.allclose(y_processed1.batch, original_x.batch), "y_processed1.batch mismatch" + + # Check hidden_state in y_processed1 (noisy input) + if hasattr(original_x, "hidden_state") and original_x.hidden_state: + assert hasattr(y_processed1, "hidden_state") and len(y_processed1.hidden_state) == len( + original_x.hidden_state + ), "y_processed1.hidden_state length mismatch" + for i in range(len(original_x.hidden_state)): + assert not torch.allclose(y_processed1.hidden_state[i], original_x.hidden_state[i]), ( + f"y_processed1.hidden_state[{i}] should be different" + ) + + # xhat inherits attributes from the input to xhat_normalized, which is the noisy graph (y_processed1) + # So, xhat1 should also have hidden_state if y_processed1 does. + if hasattr(y_processed1, "hidden_state") and y_processed1.hidden_state: + assert hasattr(xhat1, "hidden_state") and len(xhat1.hidden_state) == len(y_processed1.hidden_state), ( + "xhat1.hidden_state length mismatch with y_processed1" + ) + + py_logger.info("Test 1 PASSED.") +except Exception as e: + py_logger.error(f"Test 1 FAILED: {e}") + import traceback + + traceback.print_exc() + +# %% Test 3: Denoiser.training_step +try: + py_logger.info("Test 3: Denoiser.training_step") + # Get a fresh batch for training_step to avoid issues with modified data_batch from other tests + _, train_batch = next(enumerate(datamodule.test_dataloader())) # Using test_dataloader for convenience + + # Ensure train_batch has hidden_state + if ( + not hasattr(train_batch, "hidden_state") + or not isinstance(train_batch.hidden_state, list) + or len(train_batch.hidden_state) != denoiser_to_test.g._orig_mod.N_structures - 1 + ): + n_structures = denoiser_to_test.g._orig_mod.N_structures + train_batch.hidden_state = [torch.randn_like(train_batch.pos) for _ in range(n_structures)] + else: + py_logger.info(f"train_batch.hidden_state already exists with {len(train_batch.hidden_state)} structures.") + train_batch.hidden_state = [hs.to(denoiser_to_test.device) for hs in train_batch.hidden_state] + train_batch = train_batch.to(denoiser_to_test.device) + + # Manually set align_noisy_input_during_training if not set (it's a param of Denoiser) + if not hasattr(denoiser_to_test, "align_noisy_input_during_training"): + py_logger.warning("Denoiser missing 'align_noisy_input_during_training', defaulting to False for this test.") + denoiser_to_test.align_noisy_input_during_training = False # Or True, as needed + + logs_dict = denoiser_to_test.training_step(train_batch, 0) + + assert isinstance(logs_dict, dict), "Logs is not a dictionary" + expected_keys = ["mse", "rmsd", "scaled_rmsd", "loss"] + for key in expected_keys: + assert key in logs_dict, f"Key '{key}' missing in logs" + assert isinstance(logs_dict[key], torch.Tensor), f"Log value for '{key}' is not a tensor" + if key == "loss": + assert logs_dict[key].ndim == 0, f"logs_dict['loss'] is not scalar, shape: {logs_dict[key].shape}" + else: # mse, rmsd, scaled_rmsd are averaged in training_step's aux_mean + assert logs_dict[key].ndim == 0, f"logs_dict['{key}'] is not scalar, shape: {logs_dict[key].shape}" + + py_logger.info(f"Test 3 PASSED. Loss: {logs_dict['mse'].item()}") +except Exception as e: + py_logger.error(f"Test 3 FAILED: {e}") + import traceback + + traceback.print_exc() +# %% diff --git a/scratch/check_hidden_state.py b/scratch/check_hidden_state.py new file mode 100644 index 0000000..e0c7a6c --- /dev/null +++ b/scratch/check_hidden_state.py @@ -0,0 +1,106 @@ +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) +import logging +import os +import sys + +import dotenv +import torch +from denoiser_test import Denoiser +from hydra import compose, initialize + +import jamun.data +from jamun.utils import find_checkpoint + +# Setup logging +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("check_hidden_state") + +# Add project root to path +project_root = "/homefs/home/sules/jamun" +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# Load configuration +with initialize(config_path="", job_name="check_hidden_state"): + cfg = compose( + config_name="config", + overrides=[ + "model.arch._target_=scratch.e3conv_test.E3Conv", + "model._target_=scratch.denoiser_test.Denoiser", + "+model.arch.N_structures=2", # We need at least 2 structures to test hidden state + "model.use_torch_compile=false", # Disable torch.compile to avoid ScriptModule issues + "+model.conditioner._target_=scratch.conditioners.SelfConditioner", + ], + ) + +# Load checkpoint +checkpoint_path = find_checkpoint(wandb_train_run_path="sule-shashank/jamun/y4rm5488", checkpoint_type="last") +checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + +# Modify hyperparameters to disable torch.compile +if "hyper_parameters" in checkpoint: + checkpoint["hyper_parameters"]["use_torch_compile"] = False + checkpoint["hyper_parameters"]["torch_compile_kwargs"] = None + +# Load model with modified hyperparameters +breakpoint() +model = Denoiser.load_from_checkpoint(checkpoint_path) +model.eval() + +# Get test data +dotenv.load_dotenv("../.env", verbose=True) +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +datasets = { + "test": jamun.data.parse_datasets_from_directory( + root=f"{JAMUN_DATA_PATH}/timewarp/2AA-1-large/train/", + traj_pattern="^(.*)-traj-arrays.npz", + pdb_file="AA-traj-state0.pdb", + filter_codes=["AA"], + as_iterable=False, + subsample=100, + max_datasets=1, + ) +} + +datamodule = jamun.data.MDtrajDataModule( + datasets=datasets, + batch_size=3, + num_workers=2, +) +datamodule.setup("test") +_, test_data = next(enumerate(datamodule.test_dataloader())) +test_data = test_data.to(model.device) + +# Ensure test_data has hidden_state +if not hasattr(test_data, "hidden_state") or not test_data.hidden_state: + py_logger.info("Adding hidden state to test data") + test_data.hidden_state = [torch.randn_like(test_data.pos) for _ in range(model.g.N_structures - 1)] + +breakpoint() +# Add noise and denoise +sigma = torch.tensor(0.04) # Same sigma as in config +with torch.no_grad(): + xhat, y = model.noise_and_denoise(test_data, sigma, align_noisy_input=False) + +# Check if hidden state is preserved +print("\nChecking hidden state preservation:") +print(f"Original hidden state shapes: {[hs.shape for hs in test_data.hidden_state]}") +print(f"Noisy hidden state shapes: {[hs.shape for hs in y.hidden_state]}") +print(f"Denoised hidden state shapes: {[hs.shape for hs in xhat.hidden_state]}") + +# Check if hidden state values are preserved +for i in range(len(test_data.hidden_state)): + hidden_state_diff = torch.abs(xhat.hidden_state[i] - test_data.hidden_state[i]).mean() + print(f"\nMean absolute difference between original and denoised hidden state {i}: {hidden_state_diff.item()}") + +# Check if positions are actually denoised +pos_diff = torch.abs(xhat.pos - test_data.pos).mean() +print(f"Mean absolute difference between original and denoised positions: {pos_diff.item()}") + +# Check if noisy positions are different from original +noisy_pos_diff = torch.abs(y.pos - test_data.pos).mean() +print(f"Mean absolute difference between original and noisy positions: {noisy_pos_diff.item()}") diff --git a/scratch/check_model_arch.py b/scratch/check_model_arch.py new file mode 100644 index 0000000..849a0df --- /dev/null +++ b/scratch/check_model_arch.py @@ -0,0 +1,149 @@ +# %% +import functools +import logging +import os + +import dotenv + +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("jamun") + +import torch + +torch.cuda.is_available() +torch.set_float32_matmul_precision("high") + +import e3nn +import e3tools.nn + +e3nn.set_optimization_defaults(jit_script_fx=False) + +import jamun +import jamun.data +import jamun.distributions +import jamun.model +import jamun.model.arch + +# %% +# dataset +dotenv.load_dotenv("../.env", verbose=True) +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +# %% +datasets = { + "test": jamun.data.parse_datasets_from_directory( + root=f"{JAMUN_DATA_PATH}/capped_diamines/timewarp_splits/train/", + traj_pattern="^(.*).xtc", + pdb_file="ALA_ALA.pdb", + filter_codes=["ALA_ALA"], + as_iterable=False, + subsample=100, + total_lag_time=10, + lag_subsample_rate=100, + max_datasets=1, + ) +} + +# %% +datamodule = jamun.data.MDtrajDataModule( + datasets=datasets, + batch_size=5, + num_workers=2, +) +datamodule.setup("test") +_, data_batch = next(enumerate(datamodule.test_dataloader())) +print(f"Number of hidden states: {len(data_batch.hidden_state)}") +print(f"Size of one hidden state: {data_batch.hidden_state[0].shape}") +# %% test the new e3conv_test class +import torch_geometric +from e3conv_test import E3Conv +from e3tools import radius_graph + +trial_model = E3Conv( + irreps_out="1x1e", + irreps_hidden="120x0e + 32x1e", + irreps_sh="1x0e + 1x1e", + n_layers=1, + edge_attr_dim=8, + atom_type_embedding_dim=8, + atom_code_embedding_dim=8, + residue_code_embedding_dim=32, + residue_index_embedding_dim=8, + use_residue_information=False, + use_residue_sequence_index=False, + hidden_layer_factory=functools.partial( + e3tools.nn.ConvBlock, + conv=e3tools.nn.Conv, + ), + output_head_factory=functools.partial(e3tools.nn.EquivariantMLP, irreps_hidden_list=["120x0e + 32x1e"]), + N_structures=2, +) + + +# %% postprocess data for plugging into model +def add_bond_mask(y: torch_geometric.data.Batch, cutoff: float = 1.0) -> torch_geometric.data.Batch: + radial_edge_index = radius_graph(y.pos, cutoff, batch=y["batch"]) + bonded_edge_index = y.edge_index + edge_index = torch.cat((radial_edge_index, bonded_edge_index), dim=-1) + bond_mask = torch.cat( + ( + torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.edge_index.device), + torch.ones(bonded_edge_index.shape[1], dtype=torch.long, device=y.edge_index.device), + ), + dim=0, + ) + y.edge_index = edge_index + y.bond_mask = bond_mask + return y + + +# add bond mask--do this only once! +bond_mask_exists = hasattr(data_batch, "bond_mask") and data_batch.bond_mask is not None +if not bond_mask_exists: + py_logger.info("Adding bond mask to data_batch...") + # Ensure data_batch is a torch_geometric.data.Batch + if not isinstance(data_batch, torch_geometric.data.Batch): + raise TypeError(f"Expected data_batch to be a torch_geometric.data.Batch, got {type(data_batch)}") + + # Add bond mask + data_batch = add_bond_mask(data_batch) +else: + py_logger.info("Bond mask already exists in data_batch, skipping addition.") + # If bond mask already exists, we can still use it + # but we should ensure it's in the correct format + if not isinstance(data_batch.bond_mask, torch.Tensor): + raise TypeError(f"Expected data_batch.bond_mask to be a torch.Tensor, got {type(data_batch.bond_mask)}") + if data_batch.bond_mask.dtype != torch.long: + raise ValueError(f"Expected data_batch.bond_mask to be of dtype torch.long, got {data_batch.bond_mask.dtype}") + + # Ensure edge_index is set correctly + if not hasattr(data_batch, "edge_index") or data_batch.edge_index is None: + raise ValueError("data_batch.edge_index is not set. Please ensure it is initialized before adding bond mask.") + + # If everything is fine, we can proceed with the existing bond mask +py_logger.info(f"data_batch has {data_batch.num_graphs} graphs and {data_batch.num_nodes} nodes.") +# Ensure data_batch is on the same device as the model +if hasattr(trial_model, "device"): + data_batch = data_batch.to(trial_model.device) +elif next(trial_model.parameters()).is_cuda: + data_batch = data_batch.to(next(trial_model.parameters()).device) + +# %% Test the E3Conv model with the data_batch +py_logger.info("Testing E3Conv model with data_batch...") +# Ensure data_batch is on the same device as the model +y = data_batch +if hasattr(trial_model, "device"): + y = y.to(trial_model.device) +elif next(trial_model.parameters()).is_cuda: + y = y.to(next(trial_model.parameters()).device) +# Run a forward pass through the model +try: + py_logger.info("Running forward pass through E3Conv model...") + output = trial_model(torch.cat([y.pos, *y.hidden_state], dim=-1), y, torch.Tensor([1e-5]), 100.0) + py_logger.info(f"Output shape: {output.shape}") +except Exception as e: + py_logger.error(f"Error during forward pass: {e}") + import traceback + + traceback.print_exc() diff --git a/scratch/check_trajectory_length.py b/scratch/check_trajectory_length.py new file mode 100644 index 0000000..7197083 --- /dev/null +++ b/scratch/check_trajectory_length.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +Check the raw trajectory length in xtc files. +""" + +import logging +import os + +import mdtraj as md + +# Set up logging +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("traj_length_check") + + +def check_trajectory_lengths(): + """Check the length of trajectories in raw xtc files.""" + + dataset_root = "/data2/sules/fake_enhanced_data/ALA_ALA_organized/train" + pdb_file = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + + py_logger.info("CHECKING RAW TRAJECTORY LENGTHS") + py_logger.info("=" * 50) + + # Get a few trajectory files to sample + import glob + + xtc_files = glob.glob(os.path.join(dataset_root, "*.xtc")) + + py_logger.info(f"Found {len(xtc_files)} total xtc files") + py_logger.info("Checking first 10 files...") + + lengths = [] + + for i, xtc_file in enumerate(xtc_files[:10]): + try: + # Load trajectory with topology + traj = md.load(xtc_file, top=pdb_file) + length = traj.n_frames + lengths.append(length) + + filename = os.path.basename(xtc_file) + py_logger.info(f"{i + 1:2d}. {filename}: {length} frames") + + except Exception as e: + py_logger.error(f"Error loading {xtc_file}: {e}") + + if lengths: + py_logger.info("-" * 50) + py_logger.info(f"Statistics from {len(lengths)} files:") + py_logger.info(f" Minimum length: {min(lengths)} frames") + py_logger.info(f" Maximum length: {max(lengths)} frames") + py_logger.info(f" Average length: {sum(lengths) / len(lengths):.1f} frames") + py_logger.info(f" All lengths: {sorted(set(lengths))}") + + # Show how subsampling affects this + py_logger.info("\nEffect of subsampling (with lag requirements):") + original_length = sum(lengths) / len(lengths) + + # The lag requirements mean we need at least total_lag_time * lag_subsample_rate frames + # to get any output, and then we lose some frames at the beginning + test_cases = [ + {"subsample": 1, "total_lag_time": 5, "lag_subsample_rate": 1}, + {"subsample": 5, "total_lag_time": 5, "lag_subsample_rate": 1}, + {"subsample": 10, "total_lag_time": 5, "lag_subsample_rate": 1}, + {"subsample": 20, "total_lag_time": 5, "lag_subsample_rate": 1}, + ] + + for params in test_cases: + # Estimate how many frames we'd get after subsampling and lag filtering + # The algorithm starts from frames that have enough history + min_start_frame = (params["total_lag_time"] - 1) * params["lag_subsample_rate"] + available_frames = max(0, original_length - min_start_frame) + subsampled_frames = available_frames // params["subsample"] + + py_logger.info( + f" subsample={params['subsample']:2d}: ~{subsampled_frames:.0f} frames " + f"(from {original_length:.0f} original)" + ) + + +if __name__ == "__main__": + check_trajectory_lengths() diff --git a/scratch/config.yaml b/scratch/config.yaml new file mode 100644 index 0000000..f1446de --- /dev/null +++ b/scratch/config.yaml @@ -0,0 +1,126 @@ +float32_matmul_precision: high +task_name: train +run_group: dev +run_key: ${now:%Y-%m-%d}_${now:%H-%M-%S} +python: + version: ${python_version:micro} +init_time: ${now:%y-%m-%d_%H:%M:%S} +compute_average_squared_distance_from_data: true +data: + datamodule: + _target_: jamun.data.MDtrajDataModule + batch_size: 32 + num_workers: 4 + datasets: + train: + - _target_: jamun.data.MDtrajDataset + root: ${paths.data_path}/timewarp/2AA-1-large/train/ + traj_files: + - AA-traj-arrays.npz + pdb_file: AA-traj-state0.pdb + subsample: 100 + label: AA + val: + - _target_: jamun.data.MDtrajDataset + root: ${paths.data_path}/timewarp/2AA-1-large/train/ + traj_files: + - AA-traj-arrays.npz + pdb_file: AA-traj-state0.pdb + subsample: 100 + label: AA + test: + - _target_: jamun.data.MDtrajDataset + root: ${paths.data_path}/timewarp/2AA-1-large/train/ + traj_files: + - AA-traj-arrays.npz + pdb_file: AA-traj-state0.pdb + subsample: 100 + label: AA + use_residue_information: true +model: + arch: + _target_: scratch.e3conv_test.E3Conv + _partial_: true + irreps_out: 1x1e + irreps_hidden: 120x0e + 32x1e + irreps_sh: 1x0e + 1x1e + n_layers: 2 + edge_attr_dim: 64 + atom_type_embedding_dim: 8 + atom_code_embedding_dim: 8 + residue_code_embedding_dim: 32 + residue_index_embedding_dim: 8 + use_residue_information: ${data.use_residue_information} + use_residue_sequence_index: false + num_atom_types: 20 + max_sequence_length: 10 + num_atom_codes: 10 + num_residue_types: 25 + hidden_layer_factory: + _target_: e3tools.nn.ConvBlock + _partial_: true + conv: + _target_: e3tools.nn.Conv + _partial_: true + output_head_factory: + _target_: e3tools.nn.EquivariantMLP + _partial_: true + irreps_hidden_list: + - ${model.arch.irreps_hidden} + optim: + _target_: torch.optim.Adam + _partial_: true + lr: 0.002 + weight_decay: 0.0 + max_radius: 1000.0 + average_squared_distance: null + add_fixed_noise: false + add_fixed_ones: false + align_noisy_input_during_training: true + align_noisy_input_during_evaluation: true + mean_center: true + mirror_augmentation_rate: 0.0 + use_torch_compile: true + torch_compile_kwargs: + fullgraph: true + dynamic: true + mode: default + _target_: jamun.model.Denoiser + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 +trainer: + _target_: lightning.Trainer + limit_train_batches: 1.0 + val_check_interval: 0.5 + max_epochs: 200 +logger: + wandb: + _target_: lightning.pytorch.loggers.WandbLogger + project: jamun + entity: null + offline: false + log_model: all + group: train_test + notes: alanine dipeptide, without conditioning. + save_dir: +callbacks: + timing: + _target_: jamun.callbacks.Timing + lr_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${hydra:runtime.output_dir}/checkpoints + save_top_k: 5 + save_last: true + monitor: val/loss + viz: + _target_: jamun.callbacks.VisualizeDenoise + datasets: ${data.datamodule.datasets.val} + sigma_list: + - ${model.sigma_distribution.sigma} +paths: + root_path: ${oc.env:JAMUN_ROOT_PATH, "."} + data_path: ${oc.env:JAMUN_DATA_PATH, ${paths.root_path}/data} + run_path: ${paths.root_path}/outputs/${task_name}/${run_group}/runs/${run_key} diff --git a/scratch/diagnose_mdtraj_dataset.py b/scratch/diagnose_mdtraj_dataset.py new file mode 100644 index 0000000..97916a0 --- /dev/null +++ b/scratch/diagnose_mdtraj_dataset.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +""" +Diagnostic script to understand MDtrajDataset subsampling and lag behavior. +""" + +import os +import sys + +import mdtraj as md + +# Add the project root to the path +sys.path.insert(0, "/homefs/home/sules/jamun") + +from jamun.data._mdtraj import MDtrajDataset, get_subsampled_indices + + +def test_trajectory_loading(): + """Test basic trajectory loading without any subsampling.""" + print("=" * 60) + print("TESTING BASIC TRAJECTORY LOADING") + print("=" * 60) + + # Use one of your actual trajectory files + traj_file = "/data2/sules/ALA_ALA_enhanced_full_grid/train/swarm_1ps_000_001.xtc" + pdb_file = "/data2/sules/ALA_ALA_enhanced_full_grid/train/swarm_1ps_000_001.pdb" + + print(f"Loading trajectory: {traj_file}") + print(f"Loading topology: {pdb_file}") + + # Test direct mdtraj loading + direct_traj = md.load(traj_file, top=pdb_file) + print(f"Direct mdtraj load: {direct_traj.n_frames} frames, {direct_traj.n_atoms} atoms") + + # Test MDtrajDataset without any subsampling parameters + print("\n--- Testing MDtrajDataset with default parameters ---") + dataset = MDtrajDataset( + root="/data2/sules/ALA_ALA_enhanced_full_grid/train", + traj_files=["swarm_1ps_000_001.xtc"], + pdb_file="swarm_1ps_000_001.pdb", + label="test_basic", + verbose=True, + ) + + print(f"MDtrajDataset length: {len(dataset)}") + print(f"Dataset trajectory frames: {dataset.traj.n_frames}") + print(f"Dataset trajectory atoms: {dataset.traj.n_atoms}") + + # Check if there are any default parameters being set + print(f"Dataset num_frames param: {getattr(dataset, 'num_frames', 'Not set')}") + print(f"Dataset start_frame param: {getattr(dataset, 'start_frame', 'Not set')}") + print(f"Dataset subsample param: {getattr(dataset, 'subsample', 'Not set')}") + + +def test_subsampling_behavior(): + """Test different subsampling scenarios.""" + print("\n" + "=" * 60) + print("TESTING SUBSAMPLING BEHAVIOR") + print("=" * 60) + + base_params = { + "root": "/data2/sules/ALA_ALA_enhanced_full_grid/train", + "traj_files": ["swarm_1ps_000_001.xtc"], + "pdb_file": "swarm_1ps_000_001.pdb", + "label": "test_subsample", + "verbose": True, + } + + # Test 1: Explicit num_frames + print("\n--- Test 1: Explicit num_frames ---") + dataset1 = MDtrajDataset(**base_params, num_frames=100) + print(f"With num_frames=100: {len(dataset1)} frames") + + # Test 2: Explicit num_frames = -1 (should load all) + print("\n--- Test 2: num_frames=-1 (load all) ---") + dataset2 = MDtrajDataset(**base_params, num_frames=-1) + print(f"With num_frames=-1: {len(dataset2)} frames") + + # Test 3: No num_frames specified + print("\n--- Test 3: No num_frames specified ---") + dataset3 = MDtrajDataset(**base_params) + print(f"With default num_frames: {len(dataset3)} frames") + + # Test 4: Explicit subsample + print("\n--- Test 4: With subsample=2 ---") + dataset4 = MDtrajDataset(**base_params, num_frames=-1, subsample=2) + print(f"With subsample=2: {len(dataset4)} frames") + + +def test_lag_subsampling(): + """Test lag-based subsampling behavior.""" + print("\n" + "=" * 60) + print("TESTING LAG SUBSAMPLING") + print("=" * 60) + + # First get the actual trajectory length + traj_file = "/data2/sules/ALA_ALA_enhanced_full_grid/train/swarm_1ps_000_001.xtc" + pdb_file = "/data2/sules/ALA_ALA_enhanced_full_grid/train/swarm_1ps_000_001.pdb" + direct_traj = md.load(traj_file, top=pdb_file) + print(f"Actual trajectory length: {direct_traj.n_frames} frames") + + # Test the get_subsampled_indices function directly + print("\n--- Testing get_subsampled_indices function ---") + + test_cases = [ + {"N": 50, "subsample": 1, "total_lag_time": 5, "lag_subsample_rate": 1}, + {"N": 250, "subsample": 1, "total_lag_time": 5, "lag_subsample_rate": 1}, + {"N": 50, "subsample": 1, "total_lag_time": 2, "lag_subsample_rate": 1}, + ] + + for i, params in enumerate(test_cases): + print(f"\nTest case {i + 1}: {params}") + try: + indices = get_subsampled_indices(**params) + print(f" Result: {len(indices)} valid starting points") + if len(indices) <= 5: + print(f" Indices: {indices}") + else: + print(f" First 3 indices: {indices[:3]}") + print(f" Last 3 indices: {indices[-3:]}") + except Exception as e: + print(f" Error: {e}") + + # Test actual MDtrajDataset with lag parameters + print("\n--- Testing MDtrajDataset with lag parameters ---") + + base_params = { + "root": "/data2/sules/ALA_ALA_enhanced_full_grid/train", + "traj_files": ["swarm_1ps_000_001.xtc"], + "pdb_file": "swarm_1ps_000_001.pdb", + "label": "test_lag", + "verbose": True, + } + + # Test with different configurations + lag_configs = [ + {"total_lag_time": 5, "lag_subsample_rate": 1}, + {"total_lag_time": 5, "lag_subsample_rate": 1, "num_frames": -1}, + {"total_lag_time": 2, "lag_subsample_rate": 1}, + {"total_lag_time": 5, "lag_subsample_rate": 1, "subsample": 1}, + ] + + for i, config in enumerate(lag_configs): + print(f"\nLag config {i + 1}: {config}") + try: + dataset = MDtrajDataset(**base_params, **config) + print(f" Dataset length: {len(dataset)}") + print(f" Trajectory frames: {dataset.traj.n_frames}") + if hasattr(dataset, "hidden_state") and dataset.hidden_state: + print(f" Hidden states: {len(dataset.hidden_state)} sets") + if len(dataset.hidden_state) > 0: + print(f" Hidden state 0 length: {len(dataset.hidden_state[0])}") + print(f" Lagged indices available: {dataset.lagged_indices is not None}") + except Exception as e: + print(f" Error: {e}") + + +def test_configuration_parsing(): + """Test how the configuration parameters are being processed.""" + print("\n" + "=" * 60) + print("TESTING CONFIGURATION PARAMETER PROCESSING") + print("=" * 60) + + # Simulate the exact configuration from your experiment + print("Simulating experiment configuration:") + config = { + "root": "/data2/sules/ALA_ALA_enhanced_full_grid/train", + "traj_pattern": "^(.*).xtc", + "pdb_pattern": "^(.*).pdb", + "subsample": 1, + "total_lag_time": 5, + "lag_subsample_rate": 1, + "max_datasets": 1, # This limits to 1 dataset + } + + print(f"Config: {config}") + + # This should be what parse_datasets_from_directory creates + from jamun.data._utils import parse_datasets_from_directory + + print("\nCreating datasets with parse_datasets_from_directory...") + try: + datasets = parse_datasets_from_directory(**config) + print(f"Number of datasets created: {len(datasets)}") + + for i, dataset in enumerate(datasets[:3]): # Show first 3 + print(f"\nDataset {i}: {dataset.label()}") + print(f" Length: {len(dataset)}") + print(f" Trajectory frames: {dataset.traj.n_frames}") + print(f" Has hidden states: {dataset.hidden_state is not None}") + if dataset.hidden_state: + print(f" Hidden states count: {len(dataset.hidden_state)}") + + except Exception as e: + print(f"Error creating datasets: {e}") + import traceback + + traceback.print_exc() + + +def main(): + """Run all diagnostic tests.""" + print("MDTRAJ DATASET DIAGNOSTIC SCRIPT") + print("=" * 60) + + # Check if files exist + test_files = [ + "/data2/sules/ALA_ALA_enhanced_full_grid/train/swarm_1ps_000_001.xtc", + "/data2/sules/ALA_ALA_enhanced_full_grid/train/swarm_1ps_000_001.pdb", + ] + + for file_path in test_files: + if os.path.exists(file_path): + print(f"✅ Found: {file_path}") + else: + print(f"❌ Missing: {file_path}") + return + + try: + test_trajectory_loading() + test_subsampling_behavior() + test_lag_subsampling() + test_configuration_parsing() + + except Exception as e: + print(f"\n❌ ERROR: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/scratch/explore_fake_enhanced_dataset.py b/scratch/explore_fake_enhanced_dataset.py new file mode 100644 index 0000000..a882b6a --- /dev/null +++ b/scratch/explore_fake_enhanced_dataset.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +""" +Script to explore the fake enhanced dataset using parse_datasets_from_directory. + +Specifically explores how many trajectories we get when: +- subsample_rate = 10 (called 'subsample' in the function) +- total_lag_time = 5 +- lag_subsample_rate = 1 +""" + +import logging +import os +import sys + +import dotenv + +# Set up logging +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("fake_enhanced_exploration") + +# Load environment variables +dotenv.load_dotenv(".env", verbose=True) +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") + +# Add jamun to path if needed +project_root = os.path.abspath(".") +if project_root not in sys.path: + sys.path.insert(0, project_root) + +import jamun +import jamun.data + + +def explore_dataset_parameters(): + """Explore the fake enhanced dataset with specified parameters.""" + + # Dataset parameters as requested + dataset_root = "/data2/sules/fake_enhanced_data/ALA_ALA_organized" + pdb_file = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + traj_pattern = "^(.*).xtc" + + # Subsampling parameters as specified by user + subsample_rate = 10 # Called 'subsample' in the function + total_lag_time = 5 + lag_subsample_rate = 1 + + py_logger.info("=" * 60) + py_logger.info("EXPLORING FAKE ENHANCED DATASET") + py_logger.info("=" * 60) + py_logger.info(f"Dataset root: {dataset_root}") + py_logger.info(f"PDB file: {pdb_file}") + py_logger.info(f"Trajectory pattern: {traj_pattern}") + py_logger.info(f"Subsample rate: {subsample_rate}") + py_logger.info(f"Total lag time: {total_lag_time}") + py_logger.info(f"Lag subsample rate: {lag_subsample_rate}") + py_logger.info("=" * 60) + + # Parse datasets for each split + for split in ["train", "val", "test"]: + py_logger.info(f"\n--- Exploring {split.upper()} split ---") + + try: + datasets = jamun.data.parse_datasets_from_directory( + root=f"{dataset_root}/{split}", + traj_pattern=traj_pattern, + pdb_file=pdb_file, + as_iterable=False, + subsample=subsample_rate, + total_lag_time=total_lag_time, + lag_subsample_rate=lag_subsample_rate, + max_datasets=None, # Load all datasets to get full count + verbose=True, + ) + + py_logger.info(f"Number of datasets found: {len(datasets)}") + + if datasets: + # Analyze first dataset in detail + first_dataset = datasets[0] + py_logger.info(f"First dataset label: {first_dataset.label()}") + py_logger.info(f"Number of frames in first dataset: {len(first_dataset)}") + + # Check hidden state structure + sample_data = first_dataset[0] + if hasattr(sample_data, "hidden_state") and sample_data.hidden_state: + py_logger.info(f"Hidden state length: {len(sample_data.hidden_state)}") + py_logger.info(f"Shape of first hidden state: {sample_data.hidden_state[0].shape}") + else: + py_logger.info("No hidden state found (expected for regular subsampling)") + + # Calculate total trajectories across all datasets + total_frames = sum(len(dataset) for dataset in datasets) + py_logger.info(f"Total frames across all datasets: {total_frames}") + + # Estimate original frames before subsampling + original_frames_estimate = total_frames * subsample_rate + py_logger.info(f"Estimated original frames (before subsampling): {original_frames_estimate}") + + # Show some dataset labels + py_logger.info(f"First 5 dataset labels: {[ds.label() for ds in datasets[:5]]}") + if len(datasets) > 5: + py_logger.info(f"... and {len(datasets) - 5} more datasets") + + except Exception as e: + py_logger.error(f"Error processing {split} split: {e}") + continue + + py_logger.info("\n" + "=" * 60) + py_logger.info("EXPLORATION COMPLETE") + py_logger.info("=" * 60) + + +def compare_with_different_parameters(): + """Compare trajectory counts with different subsampling parameters.""" + + py_logger.info("\n" + "=" * 60) + py_logger.info("PARAMETER COMPARISON") + py_logger.info("=" * 60) + + dataset_root = "/data2/sules/fake_enhanced_data/ALA_ALA_organized/train" # Just use train split + pdb_file = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + traj_pattern = "^(.*).xtc" + + # Test different parameter combinations + test_cases = [ + {"subsample": 1, "total_lag_time": None, "lag_subsample_rate": None, "desc": "No subsampling"}, + {"subsample": 10, "total_lag_time": None, "lag_subsample_rate": None, "desc": "Subsample 10, no lag"}, + {"subsample": 10, "total_lag_time": 5, "lag_subsample_rate": 1, "desc": "User's requested parameters"}, + {"subsample": 10, "total_lag_time": 3, "lag_subsample_rate": 1, "desc": "Different lag time"}, + {"subsample": 5, "total_lag_time": 5, "lag_subsample_rate": 1, "desc": "Different subsample rate"}, + ] + + for i, params in enumerate(test_cases): + py_logger.info(f"\nTest case {i + 1}: {params['desc']}") + py_logger.info( + f"Parameters: subsample={params['subsample']}, total_lag_time={params['total_lag_time']}, lag_subsample_rate={params['lag_subsample_rate']}" + ) + + try: + # Limit to first few datasets for speed + datasets = jamun.data.parse_datasets_from_directory( + root=dataset_root, + traj_pattern=traj_pattern, + pdb_file=pdb_file, + as_iterable=False, + subsample=params["subsample"], + total_lag_time=params["total_lag_time"], + lag_subsample_rate=params["lag_subsample_rate"], + max_datasets=3, # Limit for speed + verbose=False, + ) + + if datasets: + frames_per_dataset = [len(ds) for ds in datasets] + total_frames = sum(frames_per_dataset) + py_logger.info(f" -> {len(datasets)} datasets, {total_frames} total frames") + py_logger.info(f" -> Frames per dataset: {frames_per_dataset}") + + # Check if lagged data exists + sample = datasets[0][0] + if hasattr(sample, "hidden_state") and sample.hidden_state: + py_logger.info(f" -> Hidden state length: {len(sample.hidden_state)}") + else: + py_logger.info(" -> No hidden state") + else: + py_logger.warning(" -> No datasets found") + + except Exception as e: + py_logger.error(f" -> Error: {e}") + + +if __name__ == "__main__": + # First explore with user's specific parameters + explore_dataset_parameters() + + # Then compare with different parameters + compare_with_different_parameters() diff --git a/scratch/explore_fake_enhanced_minimal.py b/scratch/explore_fake_enhanced_minimal.py new file mode 100644 index 0000000..b666e37 --- /dev/null +++ b/scratch/explore_fake_enhanced_minimal.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +""" +Minimal script to explore fake enhanced dataset trajectory counts. + +Answers: How many trajectories do we get when subsample=10, total_lag_time=5, lag_subsample_rate=1? +""" + +import logging +import os +import sys + +import dotenv + +# Set up logging +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("minimal_exploration") + +# Load environment variables +dotenv.load_dotenv(".env", verbose=True) + +# Add jamun to path +project_root = os.path.abspath(".") +if project_root not in sys.path: + sys.path.insert(0, project_root) + +import jamun +import jamun.data + + +def quick_exploration(): + """Quick exploration of dataset with limited scope.""" + + # Dataset parameters + dataset_root = "/data2/sules/fake_enhanced_data/ALA_ALA_organized/train" # Just train for speed + pdb_file = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + traj_pattern = "^(.*).xtc" + + # User's parameters + subsample_rate = 10 + total_lag_time = 5 + lag_subsample_rate = 1 + + py_logger.info("MINIMAL EXPLORATION OF FAKE ENHANCED DATASET") + py_logger.info("=" * 60) + py_logger.info( + f"Parameters: subsample={subsample_rate}, total_lag_time={total_lag_time}, lag_subsample_rate={lag_subsample_rate}" + ) + py_logger.info("=" * 60) + + try: + # Limit to first 5 datasets for speed + py_logger.info("Loading first 5 datasets from train split...") + datasets = jamun.data.parse_datasets_from_directory( + root=dataset_root, + traj_pattern=traj_pattern, + pdb_file=pdb_file, + as_iterable=False, + subsample=subsample_rate, + total_lag_time=total_lag_time, + lag_subsample_rate=lag_subsample_rate, + max_datasets=5, # LIMIT for speed + verbose=True, + ) + + py_logger.info(f"Successfully loaded {len(datasets)} datasets") + + if datasets: + # Analyze each dataset + total_frames = 0 + for i, dataset in enumerate(datasets): + frames = len(dataset) + total_frames += frames + py_logger.info(f"Dataset {i + 1} ('{dataset.label()}'): {frames} frames") + + # Check first dataset in detail + if i == 0: + sample = dataset[0] + py_logger.info(f" Sample position shape: {sample.pos.shape}") + if hasattr(sample, "hidden_state") and sample.hidden_state: + py_logger.info(f" Hidden state: {len(sample.hidden_state)} lag frames") + py_logger.info(f" First hidden state shape: {sample.hidden_state[0].shape}") + else: + py_logger.info(" No hidden state found") + + py_logger.info("-" * 40) + py_logger.info(f"TOTAL FRAMES across {len(datasets)} datasets: {total_frames}") + py_logger.info(f"Average frames per dataset: {total_frames / len(datasets):.1f}") + + # Extrapolate to estimate full dataset + py_logger.info("\nESTIMATING FULL DATASET:") + py_logger.info("Assuming all datasets have similar sizes...") + + # We could try to count total datasets but that might be slow + # Instead, let's just report what we found + py_logger.info( + f"With subsample={subsample_rate}, each dataset gives ~{total_frames / len(datasets):.0f} trajectories" + ) + py_logger.info(f"Total lag time {total_lag_time} creates hidden states for conditional training") + + else: + py_logger.warning("No datasets found!") + + except Exception as e: + py_logger.error(f"Error: {e}") + import traceback + + traceback.print_exc() + + +def test_different_subsample_rates(): + """Test how trajectory count changes with different subsample rates.""" + + py_logger.info("\n" + "=" * 60) + py_logger.info("TESTING DIFFERENT SUBSAMPLE RATES") + py_logger.info("=" * 60) + + dataset_root = "/data2/sules/fake_enhanced_data/ALA_ALA_organized/train" + pdb_file = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + traj_pattern = "^(.*).xtc" + + # Test different subsample rates (keeping lag parameters constant) + test_params = [ + {"subsample": 1, "total_lag_time": 5, "lag_subsample_rate": 1, "desc": "No subsampling"}, + {"subsample": 5, "total_lag_time": 5, "lag_subsample_rate": 1, "desc": "Subsample 5"}, + {"subsample": 10, "total_lag_time": 5, "lag_subsample_rate": 1, "desc": "User's parameters (subsample 10)"}, + {"subsample": 20, "total_lag_time": 5, "lag_subsample_rate": 1, "desc": "Subsample 20"}, + ] + + for params in test_params: + py_logger.info(f"\nTesting: {params['desc']}") + py_logger.info( + f" subsample={params['subsample']}, total_lag_time={params['total_lag_time']}, lag_subsample_rate={params['lag_subsample_rate']}" + ) + + try: + # Load just one dataset for comparison + datasets = jamun.data.parse_datasets_from_directory( + root=dataset_root, + traj_pattern=traj_pattern, + pdb_file=pdb_file, + as_iterable=False, + subsample=params["subsample"], + total_lag_time=params["total_lag_time"], + lag_subsample_rate=params["lag_subsample_rate"], + max_datasets=1, # Just one dataset for speed + verbose=False, + ) + + if datasets: + frames = len(datasets[0]) + py_logger.info(f" -> {frames} frames in first dataset") + else: + py_logger.info(" -> No datasets found") + + except Exception as e: + py_logger.info(f" -> Error: {e}") + + +if __name__ == "__main__": + quick_exploration() + test_different_subsample_rates() diff --git a/scratch/hydra_trials.py b/scratch/hydra_trials.py new file mode 100644 index 0000000..65008fa --- /dev/null +++ b/scratch/hydra_trials.py @@ -0,0 +1,72 @@ +# %% imports +import logging +import os +import sys +import traceback + +import dotenv +import hydra +from omegaconf import OmegaConf + +# --- Basic Setup --- +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("jamun_sampling_script") + +# Add project root to path for custom modules +project_root = "/homefs/home/sules/jamun" +if project_root not in sys.path: + sys.path.insert(0, project_root) + py_logger.info(f"Added '{project_root}' to sys.path for module discovery.") +else: + py_logger.info(f"'{project_root}' is already in sys.path.") + +dotenv.load_dotenv("../.env", verbose=True) +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + + +def print_config_sections(cfg): + """Print different sections of the configuration.""" + print("\nFull Configuration:") + print(OmegaConf.to_yaml(cfg)) + + if hasattr(cfg, "model"): + print("\nModel Configuration:") + print(OmegaConf.to_yaml(cfg.model)) + + if hasattr(cfg, "init_datasets"): + print("\nDataset Configuration:") + print(OmegaConf.to_yaml(cfg.init_datasets)) + + if hasattr(cfg, "sampler"): + print("\nSampler Configuration:") + print(OmegaConf.to_yaml(cfg.sampler)) + + +def run(cfg): + """Main function to run the config loading and printing.""" + try: + # Print the loaded configuration + print_config_sections(cfg) + + # Print specific config values + if hasattr(cfg, "model"): + print("\nModel target:", cfg.model._target_) + if hasattr(cfg, "sampler"): + print("Sampler target:", cfg.sampler._target_) + except Exception: + traceback.print_exc(file=sys.stderr) + raise + + +@hydra.main(version_base=None, config_path="../src/jamun/hydra_config", config_name="sample") +def main(cfg): + try: + run(cfg) + except Exception: + traceback.print_exc(file=sys.stderr) + raise + + +if __name__ == "__main__": + main() diff --git a/scratch/inspect_model_minimal.py b/scratch/inspect_model_minimal.py new file mode 100644 index 0000000..9a94a58 --- /dev/null +++ b/scratch/inspect_model_minimal.py @@ -0,0 +1,18 @@ +import sys + +sys.path.insert(0, "src") + +from jamun.model.denoiser_conditional import Denoiser + +print("Loading model from checkpoint...") +# Load model +model = Denoiser.load_from_checkpoint( + "/data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-08_22-31-01/checkpoints/last.ckpt" +) + +# Print key info +print(f"Model: {type(model).__name__}") +print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}") +print(f"Conditioner: {model.conditioner}") +print(f"Architecture: {model.g}") +print(f"Hyperparams: {dict(model.hparams)}") diff --git a/scratch/load_model_state_dict.py b/scratch/load_model_state_dict.py new file mode 100644 index 0000000..2470b83 --- /dev/null +++ b/scratch/load_model_state_dict.py @@ -0,0 +1,24 @@ +import hydra +import torch +from omegaconf import OmegaConf + +# Load the config +config_path = "/data2/sules/jamun-conditional-runs//outputs/train/dev/runs/2025-08-05_04-24-31/wandb/run-20250805_042516-yqn9mm7x/files/config.yaml" +cfg = OmegaConf.load(config_path) + +# Find the checkpoint file +checkpoint_path = ( + "/data2/sules/jamun-conditional-runs//outputs/train/dev/runs/2025-08-05_04-24-31/checkpoints/last.ckpt" +) +# checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')] +# checkpoint_path = os.path.join(checkpoint_dir, checkpoint_files[0]) # or choose specific one +breakpoint() +# Instantiate the model using the config +model = hydra.utils.instantiate(cfg.model) +breakpoint() +# Load the state dict from checkpoint +checkpoint = torch.load(checkpoint_path, map_location="cpu") +model.load_state_dict(checkpoint["state_dict"]) + +print(f"Loaded model: {type(model).__name__}") +print(f"From checkpoint: {checkpoint_path}") diff --git a/scratch/load_wandb_checkpoint.py b/scratch/load_wandb_checkpoint.py new file mode 100644 index 0000000..6a52906 --- /dev/null +++ b/scratch/load_wandb_checkpoint.py @@ -0,0 +1,40 @@ +import os + +from jamun.model.denoiser_conditional import Denoiser + + +def load_model_from_local_checkpoint(checkpoint_dir: str): + """ + Loads a model from a local checkpoint directory. + """ + try: + checkpoint_file = None + for file_name in os.listdir(checkpoint_dir): + if file_name.endswith(".ckpt"): + if "last.ckpt" in file_name: + checkpoint_file = file_name + break + checkpoint_file = file_name # fallback to first .ckpt + + if checkpoint_file: + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file) + print(f"Found checkpoint file: {checkpoint_path}") + + # load model + model = Denoiser.load_from_checkpoint(checkpoint_path) + print("Model loaded successfully!") + print(model) + + return model + else: + print(f"No checkpoint file (.ckpt) found in directory: {checkpoint_dir}") + return None + + except Exception as e: + print(f"An error occurred: {e}") + return None + + +if __name__ == "__main__": + checkpoint_dir = "/data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-06-30_19-07-58/checkpoints" + load_model_from_local_checkpoint(checkpoint_dir) diff --git a/scratch/my_histogram.png b/scratch/my_histogram.png new file mode 100644 index 0000000..716b783 Binary files /dev/null and b/scratch/my_histogram.png differ diff --git a/scratch/organize_data.py b/scratch/organize_data.py new file mode 100755 index 0000000..eca00f1 --- /dev/null +++ b/scratch/organize_data.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +""" +Script to organize files from ALA_ALA dataset into train/val/test directories +with random 70/10/20 split. +""" + +import random +import shutil +from pathlib import Path + + +def get_files_from_directory(source_dir: str) -> list[str]: + """Get all files from the source directory.""" + source_path = Path(source_dir) + if not source_path.exists(): + raise FileNotFoundError(f"Source directory {source_dir} does not exist") + + files = [f for f in source_path.iterdir() if f.is_file()] + return files + + +def create_target_directories(base_dir: str) -> dict: + """Create train/val/test directories and return their paths.""" + base_path = Path(base_dir) + + directories = {"train": base_path / "train", "val": base_path / "val", "test": base_path / "test"} + + for dir_path in directories.values(): + dir_path.mkdir(parents=True, exist_ok=True) + print(f"Created directory: {dir_path}") + + return directories + + +def split_files(files: list[Path], train_ratio: float = 0.7, val_ratio: float = 0.1, test_ratio: float = 0.2): + """Split files randomly into train/val/test sets.""" + if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: + raise ValueError("Split ratios must sum to 1.0") + + # Shuffle files randomly + files_copy = files.copy() + random.shuffle(files_copy) + + total_files = len(files_copy) + train_count = int(total_files * train_ratio) + val_count = int(total_files * val_ratio) + + # Split the files + train_files = files_copy[:train_count] + val_files = files_copy[train_count : train_count + val_count] + test_files = files_copy[train_count + val_count :] + + return {"train": train_files, "val": val_files, "test": test_files} + + +def copy_files(file_splits: dict, target_dirs: dict, copy_mode: str = "copy"): + """Copy or move files to their respective directories.""" + for split_name, files in file_splits.items(): + target_dir = target_dirs[split_name] + + print(f"\n{copy_mode.capitalize()}ing {len(files)} files to {split_name} directory...") + + for file_path in files: + target_path = target_dir / file_path.name + + if copy_mode == "copy": + shutil.copy2(file_path, target_path) + elif copy_mode == "move": + shutil.move(str(file_path), str(target_path)) + else: + raise ValueError("copy_mode must be either 'copy' or 'move'") + + print(f"Completed {split_name}: {len(files)} files") + + +def main(): + # Configuration + source_directory = "/data2/sules/fake_enhanced_data/ALA_ALA" + target_base_directory = "/data2/sules/fake_enhanced_data/ALA_ALA_organized" + + # Split ratios + train_ratio = 0.8 + val_ratio = 0.1 + test_ratio = 0.1 + + # Set random seed for reproducibility (optional) + random.seed(42) + + print(f"Organizing files from: {source_directory}") + print(f"Target directory: {target_base_directory}") + print(f"Split ratios - Train: {train_ratio}, Val: {val_ratio}, Test: {test_ratio}") + + try: + # Get all files from source directory + print("\nGetting files from source directory...") + files = get_files_from_directory(source_directory) + print(f"Found {len(files)} files") + + if len(files) == 0: + print("No files found in source directory. Exiting.") + return + + # Create target directories + print("\nCreating target directories...") + target_dirs = create_target_directories(target_base_directory) + + # Split files randomly + print("\nSplitting files randomly...") + file_splits = split_files(files, train_ratio, val_ratio, test_ratio) + + # Print split statistics + print("\nSplit statistics:") + for split_name, files_in_split in file_splits.items(): + percentage = (len(files_in_split) / len(files)) * 100 + print(f" {split_name}: {len(files_in_split)} files ({percentage:.1f}%)") + + # Copy files to target directories + print("\nCopying files...") + copy_files(file_splits, target_dirs, copy_mode="copy") + + print(f"\n✅ Successfully organized {len(files)} files!") + print(f"Files copied to: {target_base_directory}") + + except Exception as e: + print(f"❌ Error: {e}") + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/scratch/override_config.yaml b/scratch/override_config.yaml new file mode 100644 index 0000000..826332f --- /dev/null +++ b/scratch/override_config.yaml @@ -0,0 +1,32 @@ +_target_: scratch.denoiser_test.Denoiser +arch: + _target_: scratch.e3conv_test.E3Conv + _partial_: true + irreps_out: 1x1e + irreps_hidden: 120x0e + 32x1e + irreps_sh: 1x0e + 1x1e + n_layers: 2 + edge_attr_dim: 64 + atom_type_embedding_dim: 8 + atom_code_embedding_dim: 8 + residue_code_embedding_dim: 32 + residue_index_embedding_dim: 8 + use_residue_information: ${data.use_residue_information} + use_residue_sequence_index: false + num_atom_types: 20 + max_sequence_length: 10 + num_atom_codes: 10 + num_residue_types: 25 + hidden_layer_factory: + _target_: e3tools.nn.ConvBlock + _partial_: true + conv: + _target_: e3tools.nn.Conv + _partial_: true + output_head_factory: + _target_: e3tools.nn.EquivariantMLP + _partial_: true + irreps_hidden_list: + - ${model.arch.irreps_hidden} +conditioner: + _target_: scratch.conditioners.SelfConditioner \ No newline at end of file diff --git a/scratch/reorganize_swarm_data.py b/scratch/reorganize_swarm_data.py new file mode 100644 index 0000000..d05350c --- /dev/null +++ b/scratch/reorganize_swarm_data.py @@ -0,0 +1,758 @@ +#!/usr/bin/env python3 +""" +Script to reorganize ALA_ALA swarm results data. + +This script takes data from /data/bucket/vanib/ALA_ALA/swarm_results/ and organizes it +into /data2/sules/ALA_ALA_enhanced/ with train/val splits. + +Input structure: +- /data/bucket/vanib/ALA_ALA/swarm_results/AA_{grid_code}/ + - swarm_1ps_{traj_code}.xtc (where traj_code is 001-005) +- /data/bucket/vanib/ALA_ALA/ALA_ALA.pdb (single PDB file to use for all) + +Output structure: +- /data2/sules/ALA_ALA_enhanced/train/ + - swarm_1ps_{grid_code}_{traj_code}.xtc + - swarm_1ps_{grid_code}_{traj_code}.pdb +- /data2/sules/ALA_ALA_enhanced/val/ + - swarm_1ps_{grid_code}_{traj_code}.xtc + - swarm_1ps_{grid_code}_{traj_code}.pdb + +The train folder contains 172 randomly sampled grid codes, val contains the remaining 12. +Each grid code has 5 swarms (001-005), so: +- Train: 172 × 5 = 860 .xtc files + 860 .pdb files +- Val: 12 × 5 = 60 .xtc files + 60 .pdb files +""" + +import logging +import os +import random +import shutil + +try: + import mdtraj as md + import numpy as np + + MDTRAJ_AVAILABLE = True +except ImportError: + MDTRAJ_AVAILABLE = False + logging.warning("mdtraj not available. Trajectory validation will be skipped.") + +try: + from tqdm import tqdm + + TQDM_AVAILABLE = True +except ImportError: + TQDM_AVAILABLE = False + logging.warning("tqdm not available. Progress bars will be disabled.") + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +# Configuration +SOURCE_DIR = "/data/bucket/vanib/ALA_ALA/swarms/swarm_results" +SINGLE_PDB_FILE = "/data/bucket/vanib/ALA_ALA/swarms/ALA_ALA.pdb" +TRAJECTORY_CODES = ["001", "002", "003", "004", "005"] +LONG_TRAJECTORY_CODES = ["001", "003"] # For 2000ps trajectories + +# Splitting strategies +SPLITTING_STRATEGIES = { + "grid_split": { + "target_dir": "/data2/sules/ALA_ALA_enhanced_full_swarm", + "train_size": 172, # Number of grid codes for training + "description": "Random grid codes split: 172 grids for train, 12 grids for val, all trajectories", + }, + "trajectory_split": { + "target_dir": "/data2/sules/ALA_ALA_enhanced_full_grid", + "train_trajectories": ["001", "002", "003", "004"], # First 4 trajectories for train + "val_trajectories": ["005"], # Last trajectory for val + "description": "All grids split by trajectory: trajectories 001-004 for train, 005 for val", + }, + "long_grid_split": { + "target_dir": "/data2/sules/ALA_ALA_enhanced_long", + "trajectory_codes": ["001", "003"], # Only 2000ps trajectories + "train_size": 172, # Number of grid codes for training + "description": "Random grid codes split for 2000ps trajectories: 172 grids for train, 12 grids for val", + }, + "state_split": { + "target_dir": "/data2/sules/ALA_ALA_enhanced_long_state_split", + "trajectory_codes": ["001", "003"], # Only 2000ps trajectories + "phi_range": (0, 100), # First residue phi range for validation set + "psi_range": (-50, 100), # First residue psi range for validation set + "description": "Split by conformational state: trajectories with first residue phi,psi in (0,100)x(-50,100) go to val, others to train", + }, +} + + +def get_all_grid_codes(source_dir: str) -> list[str]: + """ + Get all grid codes from the source directory. + + Args: + source_dir: Path to the swarm results directory + + Returns: + List of grid codes (e.g., ['000', '001', '002', ...]) + """ + grid_codes = [] + for item in os.listdir(source_dir): + if os.path.isdir(os.path.join(source_dir, item)) and item.startswith("AA_"): + grid_code = item[3:] # Remove 'AA_' prefix + grid_codes.append(grid_code) + + grid_codes.sort() + logger.info(f"Found {len(grid_codes)} grid codes") + return grid_codes + + +def split_train_val(grid_codes: list[str], train_size: int, random_seed: int = 42) -> tuple[list[str], list[str]]: + """ + Randomly split grid codes into train and validation sets. + + Args: + grid_codes: List of all grid codes + train_size: Number of grid codes for training + random_seed: Random seed for reproducibility + + Returns: + Tuple of (train_grid_codes, val_grid_codes) + """ + random.seed(random_seed) + shuffled_codes = grid_codes.copy() + random.shuffle(shuffled_codes) + + train_codes = shuffled_codes[:train_size] + val_codes = shuffled_codes[train_size:] + + logger.info(f"Train set: {len(train_codes)} grid codes") + logger.info(f"Val set: {len(val_codes)} grid codes") + + return train_codes, val_codes + + +def create_target_directories(target_dir: str): + """Create target directory structure.""" + train_dir = os.path.join(target_dir, "train") + val_dir = os.path.join(target_dir, "val") + + os.makedirs(train_dir, exist_ok=True) + os.makedirs(val_dir, exist_ok=True) + + logger.info(f"Created directories: {train_dir}, {val_dir}") + + +def copy_files_for_grid_split( + source_dir: str, + target_dir: str, + grid_codes: list[str], + trajectory_codes: list[str], + single_pdb_file: str, + split_name: str, + use_2000ps: bool = False, +): + """ + Copy and rename files for a specific split (train or val). + + Args: + source_dir: Source swarm results directory + target_dir: Target directory for this split + grid_codes: List of grid codes for this split + trajectory_codes: List of trajectory codes (001-005) + single_pdb_file: Path to the single PDB file to copy + split_name: Name of the split for logging + use_2000ps: If True, use swarm_2000ps_*.xtc files instead of swarm_1ps_*.xtc + """ + total_operations = len(grid_codes) * len(trajectory_codes) * 2 # ×2 for .xtc and .pdb + + # Create progress bar + if TQDM_AVAILABLE: + pbar = tqdm( + total=total_operations, + desc=f"Copying {split_name} files", + unit="files", + bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) + + copied_files = 0 + missing_files = 0 + + if use_2000ps: + traj_prefix = "swarm_2000ps" + else: + traj_prefix = "swarm_1ps" + for grid_code in grid_codes: + source_grid_dir = os.path.join(source_dir, f"AA_{grid_code}") + + if not os.path.exists(source_grid_dir): + logger.warning(f"Source directory does not exist: {source_grid_dir}") + # Skip all files for this grid code + if TQDM_AVAILABLE: + pbar.update(len(trajectory_codes) * 2) + continue + + for traj_code in trajectory_codes: + # Handle .xtc file + source_xtc = os.path.join(source_grid_dir, f"{traj_prefix}_{traj_code}.xtc") + target_xtc = os.path.join(target_dir, f"{traj_prefix}_{grid_code}_{traj_code}.xtc") + + if os.path.exists(source_xtc): + shutil.copy2(source_xtc, target_xtc) + copied_files += 1 + else: + logger.warning(f"Source file does not exist: {source_xtc}") + missing_files += 1 + + if TQDM_AVAILABLE: + pbar.update(1) + + # Handle .pdb file (copy the single PDB file) + target_pdb = os.path.join(target_dir, f"{traj_prefix}_{grid_code}_{traj_code}.pdb") + if os.path.exists(single_pdb_file): + shutil.copy2(single_pdb_file, target_pdb) + copied_files += 1 + else: + logger.error(f"Single PDB file does not exist: {single_pdb_file}") + missing_files += 1 + + if TQDM_AVAILABLE: + pbar.update(1) + + if TQDM_AVAILABLE: + pbar.close() + + logger.info(f"{split_name}: Completed copying {copied_files} files ({missing_files} missing/failed)") + + +def copy_files_for_trajectory_split( + source_dir: str, + target_dir: str, + all_grid_codes: list[str], + trajectory_codes: list[str], + single_pdb_file: str, + split_name: str, +): + """ + Copy and rename files for trajectory-based split (all grids, specific trajectories). + + Args: + source_dir: Source swarm results directory + target_dir: Target directory for this split + all_grid_codes: List of all grid codes to include + trajectory_codes: List of trajectory codes for this split + single_pdb_file: Path to the single PDB file to copy + split_name: Name of the split for logging + """ + total_operations = len(all_grid_codes) * len(trajectory_codes) * 2 # ×2 for .xtc and .pdb + + # Create progress bar + if TQDM_AVAILABLE: + pbar = tqdm( + total=total_operations, + desc=f"Copying {split_name} files", + unit="files", + bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) + + copied_files = 0 + missing_files = 0 + + for grid_code in all_grid_codes: + source_grid_dir = os.path.join(source_dir, f"AA_{grid_code}") + + if not os.path.exists(source_grid_dir): + logger.warning(f"Source directory does not exist: {source_grid_dir}") + # Skip all files for this grid code + if TQDM_AVAILABLE: + pbar.update(len(trajectory_codes) * 2) + continue + + for traj_code in trajectory_codes: + # Handle .xtc file + source_xtc = os.path.join(source_grid_dir, f"swarm_1ps_{traj_code}.xtc") + target_xtc = os.path.join(target_dir, f"swarm_1ps_{grid_code}_{traj_code}.xtc") + + if os.path.exists(source_xtc): + shutil.copy2(source_xtc, target_xtc) + copied_files += 1 + else: + logger.warning(f"Source file does not exist: {source_xtc}") + missing_files += 1 + + if TQDM_AVAILABLE: + pbar.update(1) + + # Handle .pdb file (copy the single PDB file) + target_pdb = os.path.join(target_dir, f"swarm_1ps_{grid_code}_{traj_code}.pdb") + if os.path.exists(single_pdb_file): + shutil.copy2(single_pdb_file, target_pdb) + copied_files += 1 + else: + logger.error(f"Single PDB file does not exist: {single_pdb_file}") + missing_files += 1 + + if TQDM_AVAILABLE: + pbar.update(1) + + if TQDM_AVAILABLE: + pbar.close() + + logger.info(f"{split_name}: Completed copying {copied_files} files ({missing_files} missing/failed)") + + +def analyze_trajectory_state(xtc_path: str, pdb_path: str, phi_range: tuple, psi_range: tuple) -> bool: + """ + Analyze a trajectory to determine if any point has first residue phi,psi in the specified ranges. + + Args: + xtc_path: Path to trajectory file + pdb_path: Path to topology file + phi_range: Tuple of (min, max) for phi angles in degrees + psi_range: Tuple of (min, max) for psi angles in degrees + + Returns: + True if any point in trajectory has first residue phi,psi in the specified ranges + """ + if not MDTRAJ_AVAILABLE: + logger.error("mdtraj not available, cannot analyze trajectory states") + return False + + try: + # Load trajectory + traj = md.load(xtc_path, top=pdb_path) + traj = traj[:1000] # only use first 1000 frames to avoid memory issues + # Compute phi and psi angles + _, phi_angles = md.compute_phi(traj) + _, psi_angles = md.compute_psi(traj) + + # Convert to degrees + phi_deg = np.degrees(phi_angles) + psi_deg = np.degrees(psi_angles) + + # Check first residue (index 0) for points in specified ranges + first_phi_in_range = (phi_deg[:, 0] > phi_range[0]) & (phi_deg[:, 0] < phi_range[1]) + first_psi_in_range = (psi_deg[:, 0] > psi_range[0]) & (psi_deg[:, 0] < psi_range[1]) + first_residue_in_range = first_phi_in_range & first_psi_in_range + + # Return True if any point is in range + has_points_in_range = np.any(first_residue_in_range) + n_points_in_range = np.sum(first_residue_in_range) + + logger.debug(f"Trajectory {xtc_path}: {n_points_in_range}/{len(phi_deg)} points in target range") + + return has_points_in_range + + except Exception as e: + logger.error(f"Failed to analyze trajectory {xtc_path}: {str(e)}") + return False + + +def test_mdtraj_compatibility(target_dir: str, num_samples: int = 3): + """ + Test that mdtraj can successfully load swarm + PDB combinations. + + Args: + target_dir: Target directory containing train/val splits + num_samples: Number of random samples to test from each split + """ + if not MDTRAJ_AVAILABLE: + logger.warning("⚠️ mdtraj not available, skipping trajectory compatibility tests") + return True + + logger.info("=== TESTING MDTRAJ COMPATIBILITY ===") + + for split in ["train", "val"]: + split_dir = os.path.join(target_dir, split) + if not os.path.exists(split_dir): + continue + + # Get all .xtc files + xtc_files = [f for f in os.listdir(split_dir) if f.endswith(".xtc")] + + if not xtc_files: + logger.warning(f"No .xtc files found in {split} directory") + continue + + # Sample a few files to test + test_files = random.sample(xtc_files, min(num_samples, len(xtc_files))) + + success_count = 0 + for xtc_file in test_files: + # Get corresponding PDB file + base_name = xtc_file.replace(".xtc", "") + pdb_file = f"{base_name}.pdb" + + xtc_path = os.path.join(split_dir, xtc_file) + pdb_path = os.path.join(split_dir, pdb_file) + + if not os.path.exists(pdb_path): + logger.error(f"Missing PDB file: {pdb_path}") + continue + + try: + # Try to load trajectory with mdtraj + traj = md.load(xtc_path, top=pdb_path) + logger.info( + f"✅ {split}: Successfully loaded {xtc_file} + {pdb_file} " + f"({traj.n_frames} frames, {traj.n_atoms} atoms)" + ) + success_count += 1 + + # Clean up memory + del traj + + except Exception as e: + logger.error(f"❌ {split}: Failed to load {xtc_file} + {pdb_file}: {str(e)}") + + logger.info(f"{split}: {success_count}/{len(test_files)} trajectory tests passed") + + logger.info("mdtraj compatibility testing completed") + return True + + +def verify_output(target_dir: str, expected_train_files: int, expected_val_files: int): + """ + Verify the output directory structure and file counts. + + Args: + target_dir: Target directory path + expected_train_files: Expected number of files in train directory + expected_val_files: Expected number of files in val directory + """ + train_dir = os.path.join(target_dir, "train") + val_dir = os.path.join(target_dir, "val") + + train_files = len([f for f in os.listdir(train_dir) if os.path.isfile(os.path.join(train_dir, f))]) + val_files = len([f for f in os.listdir(val_dir) if os.path.isfile(os.path.join(val_dir, f))]) + + train_xtc = len([f for f in os.listdir(train_dir) if f.endswith(".xtc")]) + train_pdb = len([f for f in os.listdir(train_dir) if f.endswith(".pdb")]) + val_xtc = len([f for f in os.listdir(val_dir) if f.endswith(".xtc")]) + val_pdb = len([f for f in os.listdir(val_dir) if f.endswith(".pdb")]) + + logger.info("=== VERIFICATION RESULTS ===") + logger.info(f"Train directory: {train_files} total files ({train_xtc} .xtc, {train_pdb} .pdb)") + logger.info(f"Val directory: {val_files} total files ({val_xtc} .xtc, {val_pdb} .pdb)") + logger.info(f"Expected train files: {expected_train_files}") + logger.info(f"Expected val files: {expected_val_files}") + + if train_files == expected_train_files and val_files == expected_val_files: + logger.info("✅ File counts match expectations!") + + # Test mdtraj compatibility + test_mdtraj_compatibility(target_dir) + + else: + logger.warning("❌ File counts do not match expectations!") + + +def reorganize_with_long_grid_split(grid_codes: list[str], strategy_config: dict): + """Reorganize 2000ps data using grid-based splitting strategy.""" + target_dir = strategy_config["target_dir"] + train_size = strategy_config["train_size"] + trajectory_codes = strategy_config["trajectory_codes"] + + logger.info(f"Using long grid split strategy: {strategy_config['description']}") + + if len(grid_codes) < train_size: + logger.error(f"Not enough grid codes found. Expected at least {train_size}, found {len(grid_codes)}") + return + + # Split into train and validation + train_codes, val_codes = split_train_val(grid_codes, train_size) + + # Create target directories + create_target_directories(target_dir) + + # Copy files for train split (using 2000ps trajectories) + logger.info("Copying 2000ps files for train split...") + copy_files_for_grid_split( + SOURCE_DIR, + os.path.join(target_dir, "train"), + train_codes, + trajectory_codes, + SINGLE_PDB_FILE, + "TRAIN", + use_2000ps=True, + ) + + # Copy files for val split (using 2000ps trajectories) + logger.info("Copying 2000ps files for val split...") + copy_files_for_grid_split( + SOURCE_DIR, + os.path.join(target_dir, "val"), + val_codes, + trajectory_codes, + SINGLE_PDB_FILE, + "VAL", + use_2000ps=True, + ) + + # Verify output + expected_train_files = len(train_codes) * len(trajectory_codes) * 2 # ×2 for .xtc and .pdb + expected_val_files = len(val_codes) * len(trajectory_codes) * 2 + + verify_output(target_dir, expected_train_files, expected_val_files) + + +def reorganize_with_grid_split(grid_codes: list[str], strategy_config: dict): + """Reorganize data using grid-based splitting strategy.""" + target_dir = strategy_config["target_dir"] + train_size = strategy_config["train_size"] + + logger.info(f"Using grid split strategy: {strategy_config['description']}") + + if len(grid_codes) < train_size: + logger.error(f"Not enough grid codes found. Expected at least {train_size}, found {len(grid_codes)}") + return + + # Split into train and validation + train_codes, val_codes = split_train_val(grid_codes, train_size) + + # Create target directories + create_target_directories(target_dir) + + # Copy files for train split + logger.info("Copying files for train split...") + copy_files_for_grid_split( + SOURCE_DIR, os.path.join(target_dir, "train"), train_codes, TRAJECTORY_CODES, SINGLE_PDB_FILE, "TRAIN" + ) + + # Copy files for val split + logger.info("Copying files for val split...") + copy_files_for_grid_split( + SOURCE_DIR, os.path.join(target_dir, "val"), val_codes, TRAJECTORY_CODES, SINGLE_PDB_FILE, "VAL" + ) + + # Verify output + expected_train_files = len(train_codes) * len(TRAJECTORY_CODES) * 2 # ×2 for .xtc and .pdb + expected_val_files = len(val_codes) * len(TRAJECTORY_CODES) * 2 + + verify_output(target_dir, expected_train_files, expected_val_files) + + +def reorganize_with_trajectory_split(grid_codes: list[str], strategy_config: dict): + """Reorganize data using trajectory-based splitting strategy.""" + target_dir = strategy_config["target_dir"] + train_trajectories = strategy_config["train_trajectories"] + val_trajectories = strategy_config["val_trajectories"] + + logger.info(f"Using trajectory split strategy: {strategy_config['description']}") + + # Create target directories + create_target_directories(target_dir) + + # Copy files for train split (all grids, first 4 trajectories) + logger.info("Copying files for train split...") + copy_files_for_trajectory_split( + SOURCE_DIR, os.path.join(target_dir, "train"), grid_codes, train_trajectories, SINGLE_PDB_FILE, "TRAIN" + ) + + # Copy files for val split (all grids, last trajectory) + logger.info("Copying files for val split...") + copy_files_for_trajectory_split( + SOURCE_DIR, os.path.join(target_dir, "val"), grid_codes, val_trajectories, SINGLE_PDB_FILE, "VAL" + ) + + # Verify output + expected_train_files = len(grid_codes) * len(train_trajectories) * 2 # ×2 for .xtc and .pdb + expected_val_files = len(grid_codes) * len(val_trajectories) * 2 + + verify_output(target_dir, expected_train_files, expected_val_files) + + +def copy_files_for_state_split( + source_dir: str, + target_dir: str, + all_grid_codes: list[str], + trajectory_codes: list[str], + single_pdb_file: str, + phi_range: tuple, + psi_range: tuple, +): + """ + Copy and organize files based on conformational state analysis. + + Args: + source_dir: Source swarm results directory + target_dir: Target directory with train/val subdirectories + all_grid_codes: List of all grid codes to process + trajectory_codes: List of trajectory codes to include + single_pdb_file: Path to the single PDB file to copy + phi_range: Tuple of (min, max) for phi angles in degrees + psi_range: Tuple of (min, max) for psi angles in degrees + """ + if not MDTRAJ_AVAILABLE: + logger.error("mdtraj not available, cannot perform state-based splitting") + return + + train_dir = os.path.join(target_dir, "train") + val_dir = os.path.join(target_dir, "val") + + total_operations = len(all_grid_codes) * len(trajectory_codes) * 2 # ×2 for .xtc and .pdb + + # Create progress bar + if TQDM_AVAILABLE: + pbar = tqdm( + total=total_operations, + desc="Analyzing and copying files", + unit="files", + bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) + + copied_train = 0 + copied_val = 0 + missing_files = 0 + analysis_errors = 0 + + for grid_code in all_grid_codes: + source_grid_dir = os.path.join(source_dir, f"AA_{grid_code}") + + if not os.path.exists(source_grid_dir): + logger.warning(f"Source directory does not exist: {source_grid_dir}") + # Skip all files for this grid code + if TQDM_AVAILABLE: + pbar.update(len(trajectory_codes) * 2) + continue + + for traj_code in trajectory_codes: + # Handle .xtc file - need to analyze it first + source_xtc = os.path.join(source_grid_dir, f"swarm_2000ps_{traj_code}.xtc") + + if not os.path.exists(source_xtc): + logger.warning(f"Source file does not exist: {source_xtc}") + missing_files += 1 + if TQDM_AVAILABLE: + pbar.update(2) # Skip both .xtc and .pdb + continue + + # Analyze trajectory to determine train/val split + try: + goes_to_val = analyze_trajectory_state(source_xtc, single_pdb_file, phi_range, psi_range) + + if goes_to_val: + target_xtc = os.path.join(val_dir, f"swarm_2000ps_{grid_code}_{traj_code}.xtc") + target_pdb = os.path.join(val_dir, f"swarm_2000ps_{grid_code}_{traj_code}.pdb") + split_name = "VAL" + copied_val += 1 + else: + target_xtc = os.path.join(train_dir, f"swarm_2000ps_{grid_code}_{traj_code}.xtc") + target_pdb = os.path.join(train_dir, f"swarm_2000ps_{grid_code}_{traj_code}.pdb") + split_name = "TRAIN" + copied_train += 1 + + # Copy .xtc file + shutil.copy2(source_xtc, target_xtc) + logger.debug(f"Copied {source_xtc} to {split_name}") + + except Exception as e: + logger.error(f"Failed to analyze trajectory {source_xtc}: {str(e)}") + analysis_errors += 1 + if TQDM_AVAILABLE: + pbar.update(2) # Skip both .xtc and .pdb + continue + + if TQDM_AVAILABLE: + pbar.update(1) + + # Handle .pdb file (copy the single PDB file) + if os.path.exists(single_pdb_file): + shutil.copy2(single_pdb_file, target_pdb) + else: + logger.error(f"Single PDB file does not exist: {single_pdb_file}") + missing_files += 1 + + if TQDM_AVAILABLE: + pbar.update(1) + + if TQDM_AVAILABLE: + pbar.close() + + logger.info("State split completed:") + logger.info(f" TRAIN: {copied_train} trajectories") + logger.info(f" VAL: {copied_val} trajectories") + logger.info(f" Missing files: {missing_files}") + logger.info(f" Analysis errors: {analysis_errors}") + + +def reorganize_with_state_split(grid_codes: list[str], strategy_config: dict): + """Reorganize data using conformational state-based splitting strategy.""" + target_dir = strategy_config["target_dir"] + trajectory_codes = strategy_config["trajectory_codes"] + phi_range = strategy_config["phi_range"] + psi_range = strategy_config["psi_range"] + + logger.info(f"Using state split strategy: {strategy_config['description']}") + logger.info(f"Target ranges: phi {phi_range}, psi {psi_range}") + logger.info(f"Using trajectory codes: {trajectory_codes}") + + if not MDTRAJ_AVAILABLE: + logger.error("mdtraj not available, cannot perform state-based splitting") + return + + # Create target directories + create_target_directories(target_dir) + + # Copy and split files based on conformational state + logger.info("Analyzing trajectories and copying files...") + copy_files_for_state_split( + SOURCE_DIR, target_dir, grid_codes, trajectory_codes, SINGLE_PDB_FILE, phi_range, psi_range + ) + + # Note: We can't predict exact file counts since they depend on trajectory analysis + logger.info("State-based reorganization completed!") + + +def main(strategy: str = "trajectory_split"): + """ + Main function to reorganize the swarm data. + + Args: + strategy: Either 'grid_split' or 'trajectory_split' + """ + logger.info("Starting swarm data reorganization...") + + # Validate input paths + if not os.path.exists(SOURCE_DIR): + logger.error(f"Source directory does not exist: {SOURCE_DIR}") + return + + if not os.path.exists(SINGLE_PDB_FILE): + logger.error(f"Single PDB file does not exist: {SINGLE_PDB_FILE}") + return + + if strategy not in SPLITTING_STRATEGIES: + logger.error(f"Invalid strategy: {strategy}. Choose from {list(SPLITTING_STRATEGIES.keys())}") + return + + # Get all grid codes + grid_codes = get_all_grid_codes(SOURCE_DIR) + strategy_config = SPLITTING_STRATEGIES[strategy] + + # Execute the appropriate strategy + if strategy == "grid_split": + reorganize_with_grid_split(grid_codes, strategy_config) + elif strategy == "trajectory_split": + reorganize_with_trajectory_split(grid_codes, strategy_config) + elif strategy == "long_grid_split": + reorganize_with_long_grid_split(grid_codes, strategy_config) + elif strategy == "state_split": + reorganize_with_state_split(grid_codes, strategy_config) + + logger.info("Swarm data reorganization completed!") + + +if __name__ == "__main__": + import sys + + # Default to trajectory_split for the new requirement + strategy = "trajectory_split" + + # Allow command line argument to choose strategy + if len(sys.argv) > 1: + strategy = sys.argv[1] + + print(f"Running reorganization with strategy: {strategy}") + print(f"Description: {SPLITTING_STRATEGIES[strategy]['description']}") + + main(strategy) diff --git a/scratch/run_single_shape_AA_conditional.sh b/scratch/run_single_shape_AA_conditional.sh new file mode 100644 index 0000000..e92a26f --- /dev/null +++ b/scratch/run_single_shape_AA_conditional.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash + +#SBATCH --partition gpu2 +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 1 # Adjusted to 1 as typically one training script runs per job +#SBATCH --gpus-per-node 1 # Assuming your script uses 1 GPU. Adjust if it uses more. +#SBATCH --cpus-per-task 8 # Number of CPUs for your task +#SBATCH --time 08:00:00 # 7 days runtime limit +#SBATCH --mem-per-cpu=32G # Memory per CPU +#SBATCH --job-name=train_prototype # Descriptive job name +#SBATCH --output=slurm_logs/train_prototype_%A_%a.out # Standard output file +#SBATCH --error=slurm_logs/train_prototype_%A_%a.err # Standard error file + +# --- Environment Setup --- +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "SLURM_JOB_NODELIST = ${SLURM_JOB_NODELIST}" +echo "hostname = $(hostname)" +echo "Running on partition: ${SLURM_JOB_PARTITION}" +echo "Allocated GPUs: ${CUDA_VISIBLE_DEVICES:-"Not set"}" # SLURM usually sets CUDA_VISIBLE_DEVICES + +# Activate Conda environment +eval "$(conda shell.bash hook)" +conda activate jamun +echo "Conda environment 'jamun' activated." +echo "Python version: $(python --version)" +echo "PyTorch version: $(python -c 'import torch; print(torch.__version__)')" +echo "CUDA available: $(python -c 'import torch; print(torch.cuda.is_available())')" + +# --- Create Log Directory --- +# Create a directory for SLURM logs if it doesn't exist +# This should be relative to where you submit the job from, or an absolute path. +# Assuming you submit from /homefs/home/sules/jamun/ +LOG_DIR_BASE="/homefs/home/sules/jamun/slurm_logs" +mkdir -p "${LOG_DIR_BASE}" +# The %A_%a in sbatch output/error directives will be replaced by JobID and TaskID + +# --- Application Execution --- +# Navigate to the directory containing your script, if necessary +# Assuming training_prototype.py is in /homefs/home/sules/jamun/scratch/ +# SCRIPT_DIR="/homefs/home/sules/jamun/scratch" +# PYTHON_SCRIPT="training_prototype.py" + +# echo "Changing directory to ${SCRIPT_DIR}" +# cd "${SCRIPT_DIR}" || { echo "Failed to cd to ${SCRIPT_DIR}"; exit 1; } + +echo "Starting Python script: ${PYTHON_SCRIPT}" +# Run the Python script +# Add any necessary command-line arguments for your script here +python "${PYTHON_SCRIPT}" + +echo "Python script finished." +echo "Job finished at: $(date)" diff --git a/scratch/test_conditional_denoiser.py b/scratch/test_conditional_denoiser.py new file mode 100644 index 0000000..999ce1e --- /dev/null +++ b/scratch/test_conditional_denoiser.py @@ -0,0 +1,196 @@ +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) +import math +import os +import sys + +import dotenv +import hydra +import torch +from omegaconf import OmegaConf + +from jamun.utils import compute_average_squared_distance_from_datasets +from jamun.utils._normalizations import normalization_factors +from jamun.utils.average_squared_distance import compute_temporal_average_squared_distance_from_datasets + + +def compute_radial_cutoff(max_radius: float, average_squared_distance: float, sigma: float, D: int = 3) -> float: + """ + Compute radial cutoff using the same formula as the denoiser. + + This replicates the computation from denoiser_conditional.py: + radial_cutoff = effective_radial_cutoff(sigma) / c_in + where: + - effective_radial_cutoff = sqrt(max_radius² + 6σ²) + - c_in = 1.0 / sqrt(average_squared_distance + 2Dσ²) + + Args: + max_radius: Maximum radius parameter + average_squared_distance: Average squared distance from dataset + sigma: Noise level + D: Dimensionality (default 3 for 3D coordinates) + + Returns: + Computed radial cutoff + """ + # Effective radial cutoff based on noise level + effective_radial_cutoff = math.sqrt(max_radius**2 + 6 * sigma**2) + + # JAMUN normalization factor c_in + A = average_squared_distance + B = 2 * D * sigma**2 + c_in = 1.0 / math.sqrt(A + B) + + # Final radial cutoff + radial_cutoff = effective_radial_cutoff / c_in + + print("Radial cutoff computation:") + print(f" max_radius: {max_radius}") + print(f" average_squared_distance: {average_squared_distance}") + print(f" sigma: {sigma}") + print(f" D: {D}") + print(f" effective_radial_cutoff: {effective_radial_cutoff}") + print(f" c_in: {c_in}") + print(f" final radial_cutoff: {radial_cutoff}") + + return radial_cutoff + + +dotenv.load_dotenv("../.env", verbose=True) # Adjust path if script is not in scratch/ +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +project_root = "/homefs/home/sules/jamun" # Adjust if necessary +if project_root not in sys.path: + sys.path.insert(0, project_root) + print(f"Added '{project_root}' to sys.path for module discovery.") +else: + print(f"'{project_root}' is already in sys.path.") + + +def compute_average_squared_distance_from_config(cfg: OmegaConf) -> float: + """Computes the average squared distance for normalization from the data.""" + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("compute_normalization") + train_datasets = datamodule.datasets["train"] + cutoff = cfg.model.max_radius + average_squared_distance = compute_average_squared_distance_from_datasets(train_datasets, cutoff) + return average_squared_distance + + +def compute_temporal_average_squared_distance_from_config(cfg: OmegaConf) -> float: + """Computes the temporal average squared distance for normalization from the data.""" + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("compute_normalization") + train_datasets = datamodule.datasets["train"] + + average_squared_distance = compute_temporal_average_squared_distance_from_datasets( + train_datasets, + num_samples=100, # Use reasonable number of samples + verbose=True, + ) + return average_squared_distance + + +@hydra.main(version_base=None, config_path="../src/jamun/hydra_config", config_name="train") +def main(cfg): + # Override configuration to use denoiser_conditional with DenoisingConditioner + # cfg.model._target_ = "jamun.model.denoiser_conditional.Denoiser" + # cfg.model.sigma_distribution._target_ = "jamun.distributions.ConstantSigma" + # cfg.model.sigma_distribution.sigma = 0.04 + breakpoint() + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("test") + breakpoint() + # Load the test config + average_squared_distance = compute_average_squared_distance_from_config(cfg) + temporal_average_squared_distance = compute_temporal_average_squared_distance_from_config(cfg) + cfg.model.average_squared_distance = average_squared_distance + + # Compute radial cutoff for spatiotemporal model using the same formula as denoiser + sigma = cfg.model.sigma_distribution.sigma + max_radius = cfg.model.max_radius + spatial_radial_cutoff = compute_radial_cutoff( + max_radius=max_radius, + average_squared_distance=average_squared_distance, # Use temporal for spatiotemporal model + sigma=sigma, + D=3, + ) + temporal_radial_cutoff = compute_radial_cutoff( + max_radius=max_radius, + average_squared_distance=temporal_average_squared_distance, # Use temporal for spatiotemporal model + sigma=sigma, + D=3, + ) + cfg.model.conditioner.spatiotemporal_model.radial_cutoff = spatial_radial_cutoff + cfg.model.conditioner.spatiotemporal_model.temporal_cutoff = temporal_radial_cutoff + # Compute c_in using the utility function + sigma = cfg.model.sigma_distribution.sigma + c_in, c_skip, c_out, c_noise = normalization_factors(sigma, average_squared_distance) + c_in_float = float(c_in) + c_noise_float = float(c_noise) + print(f"Computed normalization factors with sigma={sigma}:") + print(f" c_in: {c_in_float}") + print(f" c_skip: {c_skip}") + print(f" c_out: {c_out}") + print(f" c_noise: {c_noise}") + breakpoint() + # Configure DenoisingConditioner with computed c_in + if cfg.model.conditioner._target_ == "jamun.model.conditioners.DenoisedConditioner": + # cfg.model.conditioner.N_structures = 2 # Must match architecture N_structures + cfg.model.conditioner.pretrained_model_path = "sule-shashank/jamun/88i7qkj2" + cfg.model.conditioner.c_in = c_in_float + + if cfg.model.conditioner._target_ == "jamun.model.conditioners.conditioners.SpatioTemporalConditioner": + cfg.model.conditioner.spatiotemporal_model.radial_cutoff = average_squared_distance + max_radius = cfg.model.max_radius + temporal_average_squared_distance = compute_temporal_average_squared_distance_from_config(cfg) + temporal_radial_cutoff = compute_radial_cutoff( + max_radius=max_radius, + average_squared_distance=temporal_average_squared_distance, # Use temporal for spatiotemporal model + sigma=sigma, + D=3, + ) + cfg.model.conditioner.spatiotemporal_model.temporal_cutoff = temporal_radial_cutoff + cfg.model.conditioner.c_noise = c_noise_float + cfg.model.conditioner.c_in = c_in_float + + print("Loading model...") + model = hydra.utils.instantiate(cfg.model) + model.conditioning_module.c_noise = c_noise + print(f"Model loaded: {type(model)}") + print(f"Conditioner: {type(model.conditioning_module)}") + print(f"Sigma: {model.sigma_distribution.sigma}") + # print(f"Conditioner c_in: {model.conditioning_module.c_in}") + breakpoint() + + # Get a single batch + print("Getting a batch of data...") + train_loader = datamodule.train_dataloader() + _, batch = next(enumerate(train_loader)) + + print(f"Batch shape: {batch.pos.shape}") + print(f"Hidden state shape: {[h.shape for h in batch.hidden_state]}") + breakpoint() + + # Test forward pass + print("Testing forward pass...") + with torch.no_grad(): + sigma = model.sigma_distribution.sample() + x_target, xhat, y = model.noise_and_denoise(batch, sigma, align_noisy_input=True) + + print(f"Input shape: {batch.pos.shape}") + print(f"Noisy shape: {y.pos.shape}") + print(f"Output shape: {xhat.pos.shape}") + breakpoint() + + # Test single training step + print("Testing training step...") + loss_output = model.training_step(batch, 0) + print(f"Loss: {loss_output['loss']:.4f}") + breakpoint() + + +if __name__ == "__main__": + main() diff --git a/scratch/test_conditional_sampling.py b/scratch/test_conditional_sampling.py new file mode 100644 index 0000000..d00616c --- /dev/null +++ b/scratch/test_conditional_sampling.py @@ -0,0 +1,186 @@ +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) +import os +import sys + +import dotenv +import hydra +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from jamun.data import MDtrajDataModule +from jamun.utils import ModelSamplingWrapperMemory, find_checkpoint + +dotenv.load_dotenv("../.env", verbose=True) # Adjust path if script is not in scratch/ +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +project_root = "/homefs/home/sules/jamun" # Adjust if necessary +if project_root not in sys.path: + sys.path.insert(0, project_root) + print(f"Added '{project_root}' to sys.path for module discovery.") +else: + print(f"'{project_root}' is already in sys.path.") + + +@hydra.main(version_base=None, config_path="../src/jamun/hydra_config", config_name="sample_memory") +def main(cfg): + print("Configuration loaded:") + print(OmegaConf.to_yaml(cfg)) + breakpoint() + + # Load checkpoint using find_checkpoint function + print("Finding checkpoint...") + checkpoint_path = find_checkpoint( + wandb_train_run_path=cfg.get("wandb_train_run_path"), + checkpoint_dir=cfg.get("checkpoint_dir"), + checkpoint_type=cfg.get("checkpoint_type"), + ) + print(f"Checkpoint found at: {checkpoint_path}") + # cfg.M = 1/6.0 + # cfg.delta = float(cfg.sigma) + # cfg.friction = float(-np.log(np.sqrt(1-4*cfg.M))) + # u = 1/cfg.M + # cfg.inverse_temperature = float(4/(u*(1- np.sqrt(1 - 4/u)))) + print(f"Sampler params: {cfg.M}, {cfg.delta}, {cfg.friction}, {cfg.inverse_temperature}") + breakpoint() + + # Load the model from checkpoint by instantiating it with the checkpoint path + print("Loading model from checkpoint...") + cfg.model.checkpoint_path = checkpoint_path + model = hydra.utils.instantiate(cfg.model) + from e3tools.nn import LayerNorm + + model.conditioning_module.spatiotemporal_model.temporal_to_spatial_pooler.layer_norm = LayerNorm( + model.conditioning_module.spatiotemporal_model.temporal_module.irreps_out + ) + model.conditioning_module.spatiotemporal_model.spatial_to_temporal_pooler.layer_norm = LayerNorm( + model.conditioning_module.spatiotemporal_model.spatial_module.irreps_out + ) + print(f"Model loaded: {type(model)}") + breakpoint() + + # Set up initial datasets for sampling + print("Setting up initial datasets...") + init_datasets = hydra.utils.instantiate(cfg.init_datasets) + print(f"Initial datasets loaded: {len(init_datasets)} datasets") + print(f"Dataset types: {[type(ds) for ds in init_datasets]}") + breakpoint() + + # Manually construct the DataModule + print("Creating datamodule for testing...") + datamodule = MDtrajDataModule( + datasets={"train": init_datasets, "val": init_datasets, "test": init_datasets}, batch_size=1, num_workers=1 + ) + + datamodule.setup("test") + print("Datamodule setup complete") + breakpoint() + + # Get a sample batch + print("Getting a sample batch...") + test_loader = datamodule.test_dataloader() + batch_idx, batch = next(enumerate(test_loader)) + print(f"Batch shape: {batch.pos.shape}") + print(f"Batch keys: {batch.keys}") + # if hasattr(batch, 'hidden_state') and len(batch.hidden_state) > 0: + # print(f"Hidden state shapes: {[h.shape for h in batch.hidden_state]}") + breakpoint() + + # Set up sampler + print("Setting up sampler...") + sampler = hydra.utils.instantiate(cfg.sampler) + print(f"Sampler created: {type(sampler)}") + breakpoint() + + # set up batch sampler + batch_sampler = hydra.utils.instantiate(cfg.batch_sampler) + print(f"Batch sampler created: {type(batch_sampler)}") + print(f"Batch sampler mcmc: {batch_sampler.mcmc}") + breakpoint() + + # Write test for score + print("Testing score function...") + with torch.no_grad(): + init_graphs = batch + init_graphs = init_graphs.to(sampler.fabric.device) + model_wrapped = ModelSamplingWrapperMemory( + model=model, init_graphs=init_graphs, sigma=batch_sampler.sigma, recenter_on_init=True + ) + y_init = model_wrapped.sample_initial_noisy_positions() + y_hist_init = model_wrapped.sample_initial_noisy_history() + init_score = model_wrapped.score(y_init, y_hist_init, batch_sampler.sigma) + print(f"Initial score: {init_score}") + breakpoint() + + # Test walk + with torch.no_grad(): + y, v, y_hist, y_traj, score_traj, y_hist_traj = batch_sampler.mcmc( + y_init, + y_hist_init, + lambda y, y_hist: model_wrapped.score(y, y_hist, batch_sampler.sigma), + v_init="zero", + steps=5, + ) + print(f"Score trajectory: {score_traj}") + breakpoint() + + # Test jump + with torch.no_grad(): + xhat_traj = torch.stack( + [ + model_wrapped.xhat(y_traj[i, :], y_hist_traj[i], sigma=batch_sampler.sigma) + for i in tqdm(range(y_traj.size(0)), leave=False, desc="Jump") + ], + dim=0, + ) + print(f"Xhat trajectory: {xhat_traj}") + breakpoint() + + # Test walkjump + with torch.no_grad(): + out = batch_sampler.sample(model_wrapped, y_init=y_init, v_init="zero", y_hist_init=y_hist_init) + print(f"Out: {out}") + breakpoint() + + # Test unbatching + with torch.no_grad(): + samples = model_wrapped.unbatch_samples(out) + print(f"Samples: {samples}") + breakpoint() + + # Test sampling parameters + print("Testing sampling setup...") + print(f"Sigma: {cfg.sigma}") + print(f"M: {cfg.M}") + print(f"Delta: {cfg.delta}") + print(f"Friction: {cfg.friction}") + print(f"Number of sampling steps per batch: {cfg.num_sampling_steps_per_batch}") + print(f"Number of batches: {cfg.num_batches}") + breakpoint() + + # Test a forward pass with the model + print("Testing model forward pass...") + model.eval() + with torch.no_grad(): + # Test if the model can process the batch + if hasattr(model, "noise_and_denoise"): + sigma_tensor = torch.tensor([cfg.sigma]) + x_target, xhat, y = model.noise_and_denoise(batch, sigma_tensor, align_noisy_input=True) + print("Forward pass successful!") + print(f"Input shape: {batch.pos.shape}") + print(f"Noisy shape: {y.pos.shape}") + print(f"Denoised shape: {xhat.pos.shape}") + else: + print("Model doesn't have noise_and_denoise method, testing direct forward pass") + output = model(batch) + print(f"Model output: {type(output)}") + breakpoint() + + print("Script completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/scratch/test_conditional_simple.py b/scratch/test_conditional_simple.py new file mode 100644 index 0000000..bae7188 --- /dev/null +++ b/scratch/test_conditional_simple.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Simple test script for the debugged denoiser_conditional using default hydra config. +Tests with sigma = 0.0 and sigma = 0.1. +""" + +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) +import os +import sys + +import dotenv +import hydra +import torch +from omegaconf import OmegaConf + +from jamun.utils import compute_average_squared_distance_from_datasets + +breakpoint() # Start debugging + +dotenv.load_dotenv("../.env", verbose=True) +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +project_root = "/homefs/home/sules/jamun" +if project_root not in sys.path: + sys.path.insert(0, project_root) + print(f"Added '{project_root}' to sys.path for module discovery.") + +breakpoint() # After setup + + +def compute_average_squared_distance_from_config(cfg: OmegaConf) -> float: + """Computes the average squared distance for normalization from the data.""" + breakpoint() # Start of function + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("compute_normalization") + train_datasets = datamodule.datasets["train"] + cutoff = cfg.model.max_radius + average_squared_distance = compute_average_squared_distance_from_datasets(train_datasets, cutoff) + breakpoint() # After computation + return average_squared_distance + + +@hydra.main(version_base=None, config_path="../src/jamun/hydra_config", config_name="train") +def main(cfg): + breakpoint() # Start of main + print("=" * 60) + print("Testing debugged denoiser_conditional") + print("=" * 60) + + # Compute average squared distance + print("Computing average squared distance...") + breakpoint() # Before distance computation + average_squared_distance = compute_average_squared_distance_from_config(cfg) + cfg.model.average_squared_distance = average_squared_distance + print(f"Average squared distance: {average_squared_distance:.6f}") + + # Load datamodule + print("Loading datamodule...") + breakpoint() # Before datamodule + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("test") + + # Load model + print("Loading model...") + breakpoint() # Before model loading + model = hydra.utils.instantiate(cfg.model) + + # Get a single batch + print("Getting a batch of data...") + breakpoint() # Before getting batch + train_loader = datamodule.train_dataloader() + _, batch = next(enumerate(train_loader)) + + print("Batch info:") + print(f" Position shape: {batch.pos.shape}") + print(f" Number of atoms: {batch.pos.shape[0]}") + print(f" Hidden state shapes: {[h.shape for h in batch.hidden_state]}") + print(f" Number of hidden states: {len(batch.hidden_state)}") + + breakpoint() # After batch info + + # Test with sigma = 0.0 + print("\n" + "=" * 40) + print("Testing with sigma = 0.0 (no noise)") + print("=" * 40) + + breakpoint() # Before sigma=0.0 test + + with torch.no_grad(): + breakpoint() # Before noise_and_denoise + sigma = torch.tensor(0.0) + x_target, xhat, y = model.noise_and_denoise(batch, sigma, align_noisy_input=True) + + print(f"Input shape: {batch.pos.shape}") + print(f"Noisy shape: {y.pos.shape}") + print(f"Output shape: {xhat.pos.shape}") + + breakpoint() # After noise_and_denoise + + # Compute loss + loss, aux = model.compute_loss(x_target, xhat, sigma) + print(f"Loss: {loss.mean().item():.6f}") + print(f"Metrics: {aux}") + + # Check if positions are preserved (should be identical with sigma=0) + pos_diff = torch.abs(batch.pos - y.pos).max() + print(f"Max position difference (sigma=0): {pos_diff.item():.8f}") + + breakpoint() # After sigma=0.0 test + + # Test with sigma = 0.1 + print("\n" + "=" * 40) + print("Testing with sigma = 0.1 (with noise)") + print("=" * 40) + + breakpoint() # Before sigma=0.1 test + + with torch.no_grad(): + sigma = torch.tensor(0.1) + breakpoint() # Before noise_and_denoise with sigma=0.1 + x_target, xhat, y = model.noise_and_denoise(batch, sigma, align_noisy_input=True) + + print(f"Input shape: {batch.pos.shape}") + print(f"Noisy shape: {y.pos.shape}") + print(f"Output shape: {xhat.pos.shape}") + + # Compute loss + loss, aux = model.compute_loss(x_target, xhat, sigma) + print(f"Loss: {loss.mean().item():.6f}") + print(f"Metrics: {aux}") + + # Check noise level + pos_diff = torch.abs(batch.pos - y.pos).max() + print(f"Max position difference (sigma=0.1): {pos_diff.item():.6f}") + + # Check denoising quality + denoise_diff = torch.abs(batch.pos - xhat.pos).max() + print(f"Max denoising difference: {denoise_diff.item():.6f}") + + breakpoint() # After sigma=0.1 test + + print("\n" + "=" * 60) + print("Testing complete!") + print("=" * 60) + + breakpoint() # End of main + + +if __name__ == "__main__": + main() diff --git a/scratch/test_conditioners.py b/scratch/test_conditioners.py new file mode 100644 index 0000000..bc66643 --- /dev/null +++ b/scratch/test_conditioners.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +""" +Test script for SelfConditioner, PositionConditioner, and MeanConditioner with both +MDtrajDataset and RepeatedPositionDataset. +""" + +# Add the src directory to the path so we can import jamun modules +import sys +from pathlib import Path + +import torch +import torch_geometric + +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from jamun.data._mdtraj import MDtrajDataset +from jamun.data.noisy_position_dataset import RepeatedPositionDataset +from jamun.model.conditioners.conditioners import MeanConditioner, PositionConditioner, SelfConditioner + + +def print_tensor_summary(tensor, name, max_elements=6): + """Print a summary of a tensor with first few elements.""" + if tensor.numel() <= max_elements: + print(f"{name}: {tensor.flatten().tolist()}") + else: + flat = tensor.flatten() + print( + f"{name} (shape {tensor.shape}): [{flat[0]:.6f}, {flat[1]:.6f}, {flat[2]:.6f}, ..., {flat[-3]:.6f}, {flat[-2]:.6f}, {flat[-1]:.6f}]" + ) + + +def create_datasets(): + """Create both types of datasets with 3 total structures (2 hidden states).""" + + root = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train" + traj_files = ["ALA_ALA.xtc"] + pdb_file = "ALA_ALA.pdb" + total_lag_time = 3 # This should create 2 hidden states (3 - 1 = 2) + + print(f"Creating datasets with total_lag_time={total_lag_time} (expecting 2 hidden states)...") + + # Create MDtrajDataset (with real lag processing) + mdtraj_dataset = MDtrajDataset( + root=root, + traj_files=traj_files, + pdb_file=pdb_file, + label="ALA_ALA_mdtraj", + total_lag_time=total_lag_time, + lag_subsample_rate=1, + num_frames=10, + verbose=True, + ) + + # Create RepeatedPositionDataset (with position copies) + repeated_dataset = RepeatedPositionDataset( + root=root, + traj_files=traj_files, + pdb_file=pdb_file, + label="ALA_ALA_repeated", + total_lag_time=total_lag_time, + lag_subsample_rate=1, + num_frames=10, + verbose=True, + ) + + return mdtraj_dataset, repeated_dataset + + +def create_batch_from_dataset(dataset, sample_idx=0): + """Create a batched graph from a single dataset sample.""" + graph = dataset[sample_idx] + batch = torch_geometric.data.Batch.from_data_list([graph]) + return batch + + +def print_batch_details(batch, batch_name): + """Print detailed information about a batch.""" + print(f"\n--- {batch_name} Details ---") + print(f"Position shape: {batch.pos.shape}") + print_tensor_summary(batch.pos, "Position") + + # Check if position is mean centered + pos_mean = torch.mean(batch.pos, dim=0) # Mean over atoms + pos_mean_magnitude = torch.norm(pos_mean).item() + print(f"Position mean: [{pos_mean[0]:.6f}, {pos_mean[1]:.6f}, {pos_mean[2]:.6f}]") + print(f"Position mean magnitude: {pos_mean_magnitude:.6f}") + if pos_mean_magnitude < 1e-6: + print("✅ Input position is mean centered") + else: + print("❌ Input position is NOT mean centered") + + print(f"Number of hidden states: {len(batch.hidden_state)}") + for i, hidden_state in enumerate(batch.hidden_state): + print(f"Hidden state {i} shape: {hidden_state.shape}") + print_tensor_summary(hidden_state, f"Hidden state {i}") + + # Check if hidden state is mean centered + hidden_mean = torch.mean(hidden_state, dim=0) # Mean over atoms + hidden_mean_magnitude = torch.norm(hidden_mean).item() + print(f"Hidden state {i} mean: [{hidden_mean[0]:.6f}, {hidden_mean[1]:.6f}, {hidden_mean[2]:.6f}]") + print(f"Hidden state {i} mean magnitude: {hidden_mean_magnitude:.6f}") + if hidden_mean_magnitude < 1e-6: + print(f"✅ Hidden state {i} is mean centered") + else: + print(f"❌ Hidden state {i} is NOT mean centered") + + +def test_conditioner_detailed(conditioner, batch, test_name): + """Test a conditioner with detailed output.""" + print(f"\n{'=' * 70}") + print(f"{test_name}") + print(f"{'=' * 70}") + + # Print input details + print_batch_details(batch, "Input Batch") + + # Run the conditioner + try: + print(f"\nRunning {conditioner.__class__.__name__}...") + conditioned_structures = conditioner(batch) + + print("\n--- Conditioner Output ---") + print(f"Number of conditioned structures: {len(conditioned_structures)}") + + # Print each conditioned structure + for i, structure in enumerate(conditioned_structures): + print(f"\nConditioned structure {i} shape: {structure.shape}") + print_tensor_summary(structure, f"Conditioned structure {i}") + + # Compare with input position + pos_diff = torch.max(torch.abs(structure - batch.pos)).item() + print(f"Max difference from current position: {pos_diff:.10f}") + + # Compare with hidden states if available + if i < len(batch.hidden_state): + hidden_diff = torch.max(torch.abs(structure - batch.hidden_state[i])).item() + print(f"Max difference from hidden state {i}: {hidden_diff:.10f}") + + # Check if structure is mean centered (for PositionConditioner) + if conditioner.__class__.__name__ == "PositionConditioner": + structure_mean = torch.mean(structure, dim=0) # Mean over atoms + mean_magnitude = torch.norm(structure_mean).item() + print( + f"Mean of structure {i}: [{structure_mean[0]:.6f}, {structure_mean[1]:.6f}, {structure_mean[2]:.6f}]" + ) + print(f"Magnitude of mean: {mean_magnitude:.6f}") + + # Check if it's close to zero (mean centered) + if mean_magnitude < 1e-6: + print(f"✅ Structure {i} is mean centered (mean ≈ 0)") + else: + print(f"❌ Structure {i} is NOT mean centered") + + # Check if structure contains means across time steps (for MeanConditioner) + if conditioner.__class__.__name__ == "MeanConditioner": + # For MeanConditioner, each structure should be the mean across time steps + # All structures should be identical, but atoms can have different coordinates + print(f"✅ Structure {i} contains mean across time steps") + + # Check if this structure is the same as the first structure (all should be the same mean) + if i > 0: + first_structure = conditioned_structures[0] + structures_same = torch.allclose(structure, first_structure, atol=1e-6) + if structures_same: + print(f"✅ Structure {i} matches structure 0 (all structures are identical means)") + else: + print(f"❌ Structure {i} doesn't match structure 0 (all should be identical)") + max_diff = torch.max(torch.abs(structure - first_structure)).item() + print(f"Maximum difference: {max_diff:.10f}") + + # Verify the mean computation is correct by manually computing it + if hasattr(batch, "hidden_state") and batch.hidden_state is not None: + all_positions = [batch.pos] + batch.hidden_state + manual_mean = torch.mean(torch.stack(all_positions, dim=0), dim=0) + mean_correct = torch.allclose(structure, manual_mean, atol=1e-6) + if mean_correct: + print(f"✅ Structure {i} correctly computed as mean across {len(all_positions)} time steps") + else: + print(f"❌ Structure {i} mean computation incorrect") + max_diff = torch.max(torch.abs(structure - manual_mean)).item() + print(f"Maximum difference from expected mean: {max_diff:.10f}") + else: + # If no hidden states, should just be y.pos repeated + pos_same = torch.allclose(structure, batch.pos, atol=1e-6) + if pos_same: + print(f"✅ Structure {i} correctly equals y.pos (no hidden states)") + else: + print(f"❌ Structure {i} should equal y.pos when no hidden states") + + return conditioned_structures + + except Exception as e: + print(f"❌ ERROR: Exception in {test_name}: {e}") + import traceback + + traceback.print_exc() + return False + + +def verify_results(conditioned_structures, batch, test_name, expected_behavior): + """Verify the results match expected behavior.""" + print(f"\n--- Verification for {test_name} ---") + print(f"Expected behavior: {expected_behavior}") + + expected_count = 3 # N_structures = 3, so we expect 3 total structures including current position + if len(conditioned_structures) != expected_count: + print(f"❌ ERROR: Expected {expected_count} structures, got {len(conditioned_structures)}") + return False + + print(f"✅ Correct count: {len(conditioned_structures)} structures") + + success = True + for i, structure in enumerate(conditioned_structures): + if structure.shape != batch.pos.shape: + print(f"❌ ERROR: Structure {i} shape mismatch!") + success = False + else: + print(f"✅ Structure {i} has correct shape") + + # Verify that the first structure is the current position for most conditioners + # Exception: MeanConditioner returns time-averaged means, not current position + first_structure = conditioned_structures[0] + pos_diff = torch.max(torch.abs(first_structure - batch.pos)).item() + + if test_name.startswith("MeanConditioner"): + # For MeanConditioner, first structure should be the time-averaged mean, not y.pos + if hasattr(batch, "hidden_state") and batch.hidden_state is not None: + all_positions = [batch.pos] + batch.hidden_state + expected_mean = torch.mean(torch.stack(all_positions, dim=0), dim=0) + mean_diff = torch.max(torch.abs(first_structure - expected_mean)).item() + if mean_diff < 1e-6: + print(f"✅ First structure correctly equals time-averaged mean (diff: {mean_diff:.2e})") + else: + print(f"❌ ERROR: First structure doesn't match expected time-averaged mean (diff: {mean_diff:.2e})") + success = False + else: + # If no hidden states, should equal y.pos + if pos_diff < 1e-10: + print(f"✅ First structure correctly equals y.pos (no hidden states, diff: {pos_diff:.2e})") + else: + print(f"❌ ERROR: First structure doesn't match y.pos when no hidden states (diff: {pos_diff:.2e})") + success = False + else: + # For other conditioners, first structure should be y.pos + if pos_diff < 1e-10: + print(f"✅ First structure matches current position (diff: {pos_diff:.2e})") + else: + print(f"❌ ERROR: First structure doesn't match current position (diff: {pos_diff:.2e})") + success = False + + return success + + +def main(): + """Main test function.""" + print( + "Testing Conditioners: SelfConditioner, PositionConditioner, MeanConditioner with 3 total structures (2 hidden states)" + ) + print("=" * 70) + + # Create datasets + try: + mdtraj_dataset, repeated_dataset = create_datasets() + print("✅ Created datasets") + print(f" MDtrajDataset length: {len(mdtraj_dataset)}") + print(f" RepeatedPositionDataset length: {len(repeated_dataset)}") + except Exception as e: + print(f"❌ ERROR: Failed to create datasets: {e}") + return False + + # Create batches + try: + mdtraj_batch = create_batch_from_dataset(mdtraj_dataset, sample_idx=0) + repeated_batch = create_batch_from_dataset(repeated_dataset, sample_idx=0) + print("✅ Created batches") + except Exception as e: + print(f"❌ ERROR: Failed to create batches: {e}") + return False + + # Create conditioners + try: + self_conditioner = SelfConditioner(N_structures=3) + position_conditioner = PositionConditioner(N_structures=3) + mean_conditioner = MeanConditioner(N_structures=3) + print("✅ Created conditioners") + except Exception as e: + print(f"❌ ERROR: Failed to create conditioners: {e}") + return False + + # Test 1: SelfConditioner on MDtrajDataset + result1 = test_conditioner_detailed(self_conditioner, mdtraj_batch, "TEST 1: SelfConditioner on MDtrajDataset") + if result1 is False: + return False + success1 = verify_results( + result1, + mdtraj_batch, + "SelfConditioner + MDtrajDataset", + "Should return [y.pos, y.pos, y.pos] - 3 copies of current position", + ) + + # Test 2: SelfConditioner on RepeatedPositionDataset + result2 = test_conditioner_detailed( + self_conditioner, repeated_batch, "TEST 2: SelfConditioner on RepeatedPositionDataset" + ) + if result2 is False: + return False + success2 = verify_results( + result2, + repeated_batch, + "SelfConditioner + RepeatedPositionDataset", + "Should return [y.pos, y.pos, y.pos] - 3 copies of current position", + ) + + # Test 3: PositionConditioner on MDtrajDataset + result3 = test_conditioner_detailed( + position_conditioner, mdtraj_batch, "TEST 3: PositionConditioner on MDtrajDataset" + ) + if result3 is False: + return False + success3 = verify_results( + result3, + mdtraj_batch, + "PositionConditioner + MDtrajDataset", + "Should return [y.pos, aligned_hidden_state_1, aligned_hidden_state_2] - current position + 2 aligned hidden states", + ) + + # Test 4: PositionConditioner on RepeatedPositionDataset + result4 = test_conditioner_detailed( + position_conditioner, repeated_batch, "TEST 4: PositionConditioner on RepeatedPositionDataset" + ) + if result4 is False: + return False + success4 = verify_results( + result4, + repeated_batch, + "PositionConditioner + RepeatedPositionDataset", + "Should return [y.pos, aligned_copy_1, aligned_copy_2] - current position + 2 aligned copies", + ) + + # Test 5: MeanConditioner on MDtrajDataset + result5 = test_conditioner_detailed(mean_conditioner, mdtraj_batch, "TEST 5: MeanConditioner on MDtrajDataset") + if result5 is False: + return False + success5 = verify_results( + result5, + mdtraj_batch, + "MeanConditioner + MDtrajDataset", + "Should return [time_mean, time_mean, time_mean] - 3 copies of mean across time steps (y.pos + hidden states)", + ) + + # Test 6: MeanConditioner on RepeatedPositionDataset + result6 = test_conditioner_detailed( + mean_conditioner, repeated_batch, "TEST 6: MeanConditioner on RepeatedPositionDataset" + ) + if result6 is False: + return False + success6 = verify_results( + result6, + repeated_batch, + "MeanConditioner + RepeatedPositionDataset", + "Should return [time_mean, time_mean, time_mean] - 3 copies of mean across time steps (y.pos + hidden states)", + ) + + # Summary + print(f"\n{'=' * 70}") + print("SUMMARY") + print(f"{'=' * 70}") + + tests = [ + ("SelfConditioner + MDtrajDataset", success1), + ("SelfConditioner + RepeatedPositionDataset", success2), + ("PositionConditioner + MDtrajDataset", success3), + ("PositionConditioner + RepeatedPositionDataset", success4), + ("MeanConditioner + MDtrajDataset", success5), + ("MeanConditioner + RepeatedPositionDataset", success6), + ] + + all_passed = True + for test_name, success in tests: + status = "✅ PASS" if success else "❌ FAIL" + print(f"{test_name}: {status}") + if not success: + all_passed = False + + if all_passed: + print("\n🎉 All conditioner tests passed!") + return True + else: + print("\n💥 Some conditioner tests failed!") + return False + + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) diff --git a/scratch/test_denoised_conditioner.py b/scratch/test_denoised_conditioner.py new file mode 100644 index 0000000..fef2df2 --- /dev/null +++ b/scratch/test_denoised_conditioner.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 +""" +Comprehensive test for DenoisedConditioner using real ALA_ALA data. +This test emulates the behavior of xhat_normalized to properly test the scaling parameter. +""" + +import os + +import e3nn +import numpy as np +import torch +import torch_geometric + +from jamun.data import parse_datasets_from_directory +from jamun.model.conditioners import DenoisedConditioner +from jamun.utils import mean_center, unsqueeze_trailing +from jamun.utils._normalizations import normalization_factors + +# Fix e3nn optimization for avoiding script issues +e3nn.set_optimization_defaults(jit_script_fx=False) + + +def load_ala_ala_data(): + """Load actual ALA_ALA data from the capped diamines dataset.""" + + # Get data path from environment variable + data_path = os.getenv("JAMUN_DATA_PATH") + if not data_path: + raise ValueError("JAMUN_DATA_PATH environment variable not set") + + # Load ALA_ALA dataset + datasets = parse_datasets_from_directory( + root=f"{data_path}/capped_diamines/timewarp_splits/train", + traj_pattern="^(.*).xtc", + pdb_pattern="^(.*).pdb", + filter_codes=["ALA_ALA"], + as_iterable=False, + subsample=1, + total_lag_time=3, # This will give us hidden states + lag_subsample_rate=1, + num_frames=10, + max_datasets=1, + ) + + if not datasets: + raise ValueError("No ALA_ALA datasets found") + + # Get the first dataset + dataset = datasets[0] + print(f"Loaded ALA_ALA dataset with {len(dataset)} frames") + + # Get a few samples to create a batch + samples = [] + for i in range(min(2, len(dataset))): # Get 2 samples for batch + sample = dataset[i] + samples.append(sample) + + # Create batch + batch = torch_geometric.data.Batch.from_data_list(samples) + + print(f"Created batch with {batch.num_graphs} graphs") + print(f"Batch position shape: {batch.pos.shape}") + + if hasattr(batch, "hidden_state") and batch.hidden_state: + print(f"Hidden states: {len(batch.hidden_state)} states") + for i, hidden_state in enumerate(batch.hidden_state): + print(f" Hidden state {i}: shape {hidden_state.shape}") + else: + print("No hidden states found") + + return batch + + +def add_noise_to_batch(x: torch_geometric.data.Batch, sigma: float) -> torch_geometric.data.Batch: + """Add noise to a batch, similar to the denoiser's add_noise method.""" + sigma = unsqueeze_trailing(torch.tensor(sigma), x.pos.ndim) + + y = x.clone() + + # Add noise to positions + noise = torch.randn_like(x.pos) + y.pos = x.pos + sigma * noise + + # Add noise to hidden states if they exist + if hasattr(x, "hidden_state") and x.hidden_state is not None: + y.hidden_state = [] + for hidden_positions in x.hidden_state: + hidden_noise = torch.randn_like(hidden_positions) + y.hidden_state.append(hidden_positions + sigma * hidden_noise) + + return y + + +def mean_center_positions(batch: torch_geometric.data.Batch) -> torch_geometric.data.Batch: + """Mean-center positions and hidden states for each graph in the batch.""" + + # Mean-center the main positions using the jamun utils function + batch = mean_center(batch) + + # Mean-center each hidden state individually + if hasattr(batch, "hidden_state") and batch.hidden_state is not None: + for i, hidden_positions in enumerate(batch.hidden_state): + # Create a temporary batch with just the hidden state positions to mean-center + temp_batch = batch.clone() + temp_batch.pos = hidden_positions + temp_batch_centered = mean_center(temp_batch) + batch.hidden_state[i] = temp_batch_centered.pos + + return batch + + +def emulate_xhat_normalized_scaling(batch, sigma: float, average_squared_distance: float = 0.332): + """ + Emulate the scaling behavior in xhat_normalized method. + This simulates how the denoiser scales data before passing to conditioner. + """ + print("\n=== Emulating xhat_normalized scaling ===") + print(f"Input sigma: {sigma}") + print(f"Average squared distance: {average_squared_distance}") + + # Mean-center the batch positions and hidden states (as done in actual xhat_normalized) + batch = mean_center_positions(batch) + + # Compute normalization factors (same as in denoiser) + c_in, c_skip, c_out, c_noise = normalization_factors(sigma, average_squared_distance) + + print("Normalization factors:") + print(f" c_in: {c_in}") + print(f" c_skip: {c_skip}") + print(f" c_out: {c_out}") + print(f" c_noise: {c_noise}") + + # Adjust dimensions (same as in denoiser) + c_in = unsqueeze_trailing(c_in, batch.pos.ndim - 1) + c_skip = unsqueeze_trailing(c_skip, batch.pos.ndim - 1) + c_out = unsqueeze_trailing(c_out, batch.pos.ndim - 1) + c_noise = c_noise.unsqueeze(0) + + # Scale the batch (same as in denoiser) + y_scaled = batch.clone() + y_scaled.pos = batch.pos * c_in + + print(f"Original position mean: {batch.pos.mean():.6f}") + print(f"Scaled position mean: {y_scaled.pos.mean():.6f}") + + # Scale hidden states (same as in denoiser) + if hasattr(batch, "hidden_state") and batch.hidden_state is not None: + y_scaled.hidden_state = [] + for i, positions in enumerate(batch.hidden_state): + scaled_positions = positions * c_in + y_scaled.hidden_state.append(scaled_positions) + print( + f"Hidden state {i} - Original mean: {positions.mean():.6f}, Scaled mean: {scaled_positions.mean():.6f}" + ) + + return y_scaled, c_in, c_skip, c_out, c_noise + + +def test_denoised_conditioner_with_scaling(): + """Test the DenoisedConditioner with proper scaling emulation.""" + + print("=== Testing DenoisedConditioner with xhat_normalized scaling ===") + + # Test parameters + N_structures = 3 # Must match architecture N_structures (updated to match hidden states) + pretrained_model_path = "sule-shashank/jamun/370wpt17" # Update this to your desired checkpoint + test_sigma = 0.04 + + try: + # Load real ALA_ALA data + print("\n1. Loading real ALA_ALA data...") + original_batch = load_ala_ala_data() + print("✓ Real ALA_ALA data loaded successfully") + + # Mean-center the original batch (clean reference) - positions and hidden states + print("\n2. Mean-centering the data...") + print(f" Original position mean: {original_batch.pos.mean():.6f}") + if hasattr(original_batch, "hidden_state") and original_batch.hidden_state: + for i, hidden_state in enumerate(original_batch.hidden_state): + print(f" Original hidden state {i} mean: {hidden_state.mean():.6f}") + + x_clean = mean_center_positions(original_batch) + print(f" Mean-centered position mean: {x_clean.pos.mean():.6f}") + + if hasattr(x_clean, "hidden_state") and x_clean.hidden_state: + for i, hidden_state in enumerate(x_clean.hidden_state): + print(f" Mean-centered hidden state {i} mean: {hidden_state.mean():.6f}") + + # Add noise to the mean-centered data + print(f"\n3. Adding noise with sigma={test_sigma}...") + y_noisy = add_noise_to_batch(x_clean, test_sigma) + print(f" Noisy position mean: {y_noisy.pos.mean():.6f}") + print(f" Noisy position std: {y_noisy.pos.std():.6f}") + + if hasattr(y_noisy, "hidden_state") and y_noisy.hidden_state: + for i, hidden_state in enumerate(y_noisy.hidden_state): + print(f" Noisy hidden state {i} mean: {hidden_state.mean():.6f}") + print(f" Noisy hidden state {i} std: {hidden_state.std():.6f}") + + # Initialize conditioner and extract average_squared_distance from checkpoint + print("\n4. Initializing DenoisedConditioner and extracting average_squared_distance...") + print(f" N_structures: {N_structures}") + print(f" pretrained_model_path: {pretrained_model_path}") + + # Use a temporary c_in for initialization + temp_c_in, _, _, _ = normalization_factors(test_sigma, 0.332) # temporary default + temp_c_in_float = float(temp_c_in) + + # Initialize conditioner + conditioner = DenoisedConditioner( + N_structures=N_structures, pretrained_model_path=pretrained_model_path, c_in=temp_c_in_float + ) + + print("✓ DenoisedConditioner initialized successfully") + print(f" Denoiser sigma: {conditioner.denoiser_sigma}") + + # Extract average_squared_distance from the loaded checkpoint + average_squared_distance = None + if hasattr(conditioner.pretrained_denoiser, "average_squared_distance"): + average_squared_distance = float(conditioner.pretrained_denoiser.average_squared_distance) + print(f" ✓ Extracted average_squared_distance from checkpoint: {average_squared_distance}") + elif hasattr(conditioner.pretrained_denoiser, "hparams") and hasattr( + conditioner.pretrained_denoiser.hparams, "average_squared_distance" + ): + average_squared_distance = float(conditioner.pretrained_denoiser.hparams.average_squared_distance) + print(f" ✓ Extracted average_squared_distance from hparams: {average_squared_distance}") + else: + # Try to extract from the config if available + if hasattr(conditioner.pretrained_denoiser, "cfg"): + cfg = conditioner.pretrained_denoiser.cfg + if hasattr(cfg, "average_squared_distance"): + average_squared_distance = float(cfg.average_squared_distance) + print(f" ✓ Extracted average_squared_distance from config: {average_squared_distance}") + + if average_squared_distance is None: + print(" ⚠️ Could not extract average_squared_distance from checkpoint") + print(" Available attributes on pretrained_denoiser:") + for attr in dir(conditioner.pretrained_denoiser): + if not attr.startswith("_"): + print(f" - {attr}") + # Use default + average_squared_distance = 0.332 + print(f" Using default average_squared_distance: {average_squared_distance}") + + # Recompute c_in with the correct average_squared_distance + c_in, _, _, _ = normalization_factors(test_sigma, average_squared_distance) + c_in_float = float(c_in) + + # Update the conditioner's c_in + if abs(c_in_float - temp_c_in_float) > 1e-6: + print(f" Updating c_in from {temp_c_in_float} to {c_in_float}") + conditioner.c_in = c_in_float + else: + print(f" c_in remains: {c_in_float}") + + # Test sigma consistency + print("\n5. Testing sigma consistency...") + if abs(conditioner.denoiser_sigma - test_sigma) < 1e-5: + print(f"✓ Test sigma ({test_sigma}) matches denoiser sigma ({conditioner.denoiser_sigma})") + else: + print(f"⚠️ Test sigma ({test_sigma}) differs from denoiser sigma ({conditioner.denoiser_sigma})") + print(" Using denoiser sigma for consistency") + test_sigma = conditioner.denoiser_sigma + # Recompute c_in with corrected sigma + c_in, _, _, _ = normalization_factors(test_sigma, average_squared_distance) + c_in_float = float(c_in) + conditioner.c_in = c_in_float + print(f" Updated c_in to: {c_in_float}") + + # Emulate xhat_normalized scaling on the noisy batch + print("\n6. Emulating xhat_normalized scaling...") + scaled_batch, c_in_tensor, c_skip, c_out, c_noise = emulate_xhat_normalized_scaling( + y_noisy, test_sigma, average_squared_distance + ) + + # Verify our c_in calculation matches + assert abs(float(c_in_tensor) - c_in_float) < 1e-6, f"c_in mismatch: {c_in_tensor} vs {c_in_float}" + print("✓ c_in calculation verified") + + # Test conditioner with scaled noisy data + print("\n7. Testing conditioner with scaled noisy data...") + + # Move scaled_batch to the same device as the conditioner + device = next(conditioner.parameters()).device + scaled_batch = scaled_batch.to(device) + x_clean = x_clean.to(device) + y_noisy = y_noisy.to(device) + print(f" Moved batches to device: {device}") + + conditioned_structures = conditioner.forward(scaled_batch) + + print("✓ Conditioner forward pass completed") + print(f" Returned {len(conditioned_structures)} structures") + print(f" Expected N_structures: {N_structures}") + + # Verify output structure + assert len(conditioned_structures) == N_structures, ( + f"Expected {N_structures} structures, got {len(conditioned_structures)}" + ) + + for i, structure in enumerate(conditioned_structures): + assert structure.shape == scaled_batch.pos.shape, ( + f"Structure {i} has wrong shape: {structure.shape} vs {scaled_batch.pos.shape}" + ) + print(f" Structure {i}: shape {structure.shape}") + + # Check that first structure is the scaled current position + assert torch.allclose(conditioned_structures[0], scaled_batch.pos), ( + "First structure should be scaled current position" + ) + print("✓ First structure matches scaled current position") + + # Comprehensive denoising quality test + print("\n8. COMPREHENSIVE DENOISING QUALITY TEST...") + denoising_improvements = [] + + if hasattr(x_clean, "hidden_state") and x_clean.hidden_state and len(conditioned_structures) > 1: + print(f" Testing denoising on {len(x_clean.hidden_state)} hidden states...") + + for i in range(1, len(conditioned_structures)): # Skip first structure (current position) + hidden_idx = i - 1 # Map to hidden state index + if hidden_idx < len(x_clean.hidden_state): + denoised_structure = conditioned_structures[i] + clean_hidden = x_clean.hidden_state[hidden_idx] + noisy_hidden = y_noisy.hidden_state[hidden_idx] + + # Calculate RMSE between denoised and clean + denoised_rmse = torch.sqrt(torch.mean((denoised_structure - clean_hidden) ** 2)) + + # Calculate RMSE between noisy and clean for comparison + noisy_rmse = torch.sqrt(torch.mean((noisy_hidden - clean_hidden) ** 2)) + + # Calculate improvement + improvement = noisy_rmse - denoised_rmse + improvement_percent = (improvement / noisy_rmse) * 100 + + print(f" Hidden State {hidden_idx}:") + print(f" Noisy RMSE vs clean: {noisy_rmse.item():.6f}") + print(f" Denoised RMSE vs clean: {denoised_rmse.item():.6f}") + print(f" Improvement: {improvement.item():.6f} ({improvement_percent.item():.2f}%)") + + denoising_improvements.append(improvement.item()) + + if improvement > 0: + print(" ✓ DENOISING SUCCESSFUL (RMSE reduced)") + else: + print(" ❌ DENOISING FAILED (RMSE increased)") + + # Verify denoised is different from both noisy and original + assert not torch.allclose(denoised_structure, noisy_hidden, atol=1e-4), ( + f"Denoised structure {i} should be different from noisy" + ) + assert not torch.allclose(denoised_structure, scaled_batch.pos, atol=1e-4), ( + f"Denoised structure {i} should be different from current position" + ) + + # Overall denoising assessment + print("\n9. OVERALL DENOISING ASSESSMENT...") + if denoising_improvements: + avg_improvement = np.mean(denoising_improvements) + successful_denoising = sum(1 for imp in denoising_improvements if imp > 0) + total_tests = len(denoising_improvements) + success_rate = (successful_denoising / total_tests) * 100 + + print(f" Total hidden states tested: {total_tests}") + print(f" Successful denoising: {successful_denoising}/{total_tests} ({success_rate:.1f}%)") + print(f" Average RMSE improvement: {avg_improvement:.6f}") + + if success_rate >= 80: # Require at least 80% success rate + print(" ✅ DENOISING QUALITY: EXCELLENT (≥80% success)") + elif success_rate >= 60: + print(" ⚠️ DENOISING QUALITY: GOOD (≥60% success)") + elif success_rate >= 40: + print(" ⚠️ DENOISING QUALITY: MODERATE (≥40% success)") + else: + print(" ❌ DENOISING QUALITY: POOR (<40% success)") + + # Assert that at least some denoising occurred + assert successful_denoising > 0, "At least one hidden state should show denoising improvement" + print(" ✓ At least some denoising improvement verified") + + else: + print(" No hidden states available for denoising quality assessment") + + # Additional validation + print("\n10. Additional validation...") + for i, structure in enumerate(conditioned_structures): + # Check for NaN values + assert not torch.isnan(structure).any(), f"Structure {i} contains NaN values" + # Check for infinite values + assert not torch.isinf(structure).any(), f"Structure {i} contains infinite values" + print(f"✓ Structure {i} contains valid values") + + print("\n🎉 All tests passed! DenoisedConditioner works correctly.") + + # Print final summary + print("\n=== FINAL SUMMARY ===") + print(f" Checkpoint: {pretrained_model_path}") + print(f" Test sigma: {test_sigma}") + print(f" Denoiser sigma: {conditioner.denoiser_sigma}") + print(f" Average squared distance: {average_squared_distance}") + print(f" Computed c_in: {c_in_float}") + if denoising_improvements: + print(f" Denoising success rate: {success_rate:.1f}%") + print(f" Average RMSE improvement: {avg_improvement:.6f}") + + return True + + except Exception as e: + print(f"❌ Test failed with error: {e}") + import traceback + + traceback.print_exc() + return False + + +if __name__ == "__main__": + success = test_denoised_conditioner_with_scaling() + if success: + print("\n✅ DenoisedConditioner scaling test passed!") + print("The conditioner correctly handles the scaling parameter and emulates xhat_normalized behavior.") + else: + print("\n❌ DenoisedConditioner scaling test failed!") diff --git a/scratch/test_gradient_equivalence.py b/scratch/test_gradient_equivalence.py new file mode 100644 index 0000000..e338c63 --- /dev/null +++ b/scratch/test_gradient_equivalence.py @@ -0,0 +1,279 @@ +""" +Test gradient equivalence between automatic and manual optimization in DenoiserMultimeasurement. + +To run this test with proper hydra configuration: + +python3 scratch/test_gradient_equivalence.py --config-dir=configs experiment=train_test_single_shape_conditional ++model._target_=jamun.model.denoiser_multimeasurement.DenoiserMultimeasurement ++model.multimeasurement=True ++model.N_measurements_hidden=2 ++model.N_measurements=2 ++model.max_graphs_per_batch=1 + +This will: +- Use the train_test_single_shape_conditional experiment config +- Override model to use DenoiserMultimeasurement +- Enable multimeasurement with 2 hidden measurements and 2 measurements +- Set max_graphs_per_batch=1 for manual optimization testing +""" + +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) +import os +import sys + +import dotenv +import hydra +import numpy as np +import torch +import torch_geometric + +dotenv.load_dotenv("../.env", verbose=True) +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") + +project_root = "/homefs/home/sules/jamun" +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from jamun.utils import compute_average_squared_distance_from_datasets + + +def create_model_and_data(cfg, max_graphs_per_batch=None): + """Create a model and datamodule with specified optimization mode.""" + # Configure DenoiserMultimeasurement + cfg.model._target_ = "jamun.model.denoiser_multimeasurement.DenoiserMultimeasurement" + cfg.model.sigma_distribution._target_ = "jamun.distributions.ConstantSigma" + cfg.model.sigma_distribution.sigma = 0.04 + cfg.model.multimeasurement = True + cfg.model.N_measurements_hidden = 2 + cfg.model.N_measurements = 2 + cfg.model.max_graphs_per_batch = max_graphs_per_batch + + # # Set up data - use correct attribute name filter_codes + # cfg.data.datamodule.datasets.train.filter_codes = ['ALA_ALA'] + # cfg.data.datamodule.datasets.val.filter_codes = ['ALA_ALA'] + # cfg.data.datamodule.datasets.test.filter_codes = ['ALA_ALA'] + # cfg.data.datamodule.batch_size = 8 # Larger batch for meaningful chunking + + # Compute normalization + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("compute_normalization") + train_datasets = datamodule.datasets["train"] + cutoff = cfg.model.max_radius + average_squared_distance = compute_average_squared_distance_from_datasets(train_datasets, cutoff) + cfg.model.average_squared_distance = average_squared_distance + + # Create model and data + model = hydra.utils.instantiate(cfg.model) + datamodule.setup("test") + + return model, datamodule + + +def get_model_gradients(model): + """Extract gradients from model parameters.""" + gradients = {} + for name, param in model.named_parameters(): + if param.grad is not None: + gradients[name] = param.grad.clone().detach() + return gradients + + +def compute_gradient_norm(gradients): + """Compute the total gradient norm across all parameters.""" + total_norm = 0.0 + for grad in gradients.values(): + total_norm += grad.norm().item() ** 2 + return total_norm**0.5 + + +def compare_gradients(grad1, grad2, tolerance=1e-3): + """Compare two gradient dictionaries.""" + if set(grad1.keys()) != set(grad2.keys()): + print("ERROR: Different parameter names!") + return False + + max_relative_diff = 0.0 + for name in grad1.keys(): + g1, g2 = grad1[name], grad2[name] + + # Compute relative difference + diff = torch.abs(g1 - g2) + max_val = torch.max(torch.abs(g1), torch.abs(g2)) + relative_diff = torch.where(max_val > 1e-8, diff / (max_val + 1e-8), diff) + max_rel_diff_param = relative_diff.max().item() + max_relative_diff = max(max_relative_diff, max_rel_diff_param) + + print( + f"{name:30s}: max_rel_diff = {max_rel_diff_param:.6f}, norm_ratio = {g1.norm().item() / g2.norm().item():.6f}" + ) + + print(f"\nOverall max relative difference: {max_relative_diff:.6f}") + return max_relative_diff < tolerance + + +@hydra.main(version_base=None, config_path="../src/jamun/hydra_config", config_name="train") +def test_gradient_equivalence(cfg): + """Test that automatic and manual optimization produce equivalent gradients.""" + print("=" * 80) + print("TESTING GRADIENT EQUIVALENCE: AUTOMATIC vs MANUAL OPTIMIZATION") + print("=" * 80) + + # Set seeds for reproducibility + torch.manual_seed(42) + np.random.seed(42) + + # Create automatic optimization model + print("\n1. Creating AUTOMATIC optimization model...") + model_auto, datamodule_auto = create_model_and_data(cfg.copy(), max_graphs_per_batch=None) + print(f" Automatic optimization: {model_auto.automatic_optimization}") + + # Create manual optimization model with same architecture + print("2. Creating MANUAL optimization model...") + model_manual, datamodule_manual = create_model_and_data(cfg.copy(), max_graphs_per_batch=2) # 2 graphs per chunk + print(f" Automatic optimization: {model_manual.automatic_optimization}") + + # Ensure models are in training mode and parameters require gradients + print("3. Setting up models for gradient computation...") + model_auto.train() + model_manual.train() + + # Ensure all parameters require gradients + for param in model_auto.parameters(): + param.requires_grad_(True) + for param in model_manual.parameters(): + param.requires_grad_(True) + + # Copy weights from auto to manual model to ensure identical starting point + print("4. Synchronizing model weights...") + model_manual.load_state_dict(model_auto.state_dict()) + + # Get the same batch of data + print("5. Getting identical batch...") + torch.manual_seed(42) # Reset seed to get same batch + train_loader_auto = datamodule_auto.train_dataloader() + batch_auto = next(iter(train_loader_auto)) + + torch.manual_seed(42) # Reset seed to get same batch + train_loader_manual = datamodule_manual.train_dataloader() + batch_manual = next(iter(train_loader_manual)) + + print(f" Batch shapes: auto={batch_auto.pos.shape}, manual={batch_manual.pos.shape}") + print(f" Batch equality: {torch.allclose(batch_auto.pos, batch_manual.pos)}") + + # Use the same sigma for both (important!) + print("\n6. Setting identical sigma values...") + sigma_value = 0.04 # Fixed sigma instead of sampling + sigma_auto = torch.tensor(sigma_value) + sigma_manual = torch.tensor(sigma_value) + + print(f" Using fixed sigma: {sigma_value}") + + # Test forward pass equivalence (without multimeasurement first) + print("\n7. Testing forward pass equivalence...") + + # Disable multimeasurement temporarily for cleaner testing + model_auto.multimeasurement = False + model_manual.multimeasurement = False + + with torch.no_grad(): + # Reset random seeds before each forward pass + torch.manual_seed(123) + x_target_auto, xhat_auto, y_auto = model_auto.noise_and_denoise(batch_auto, sigma_auto, align_noisy_input=True) + + torch.manual_seed(123) # Same seed for manual + x_target_manual, xhat_manual, y_manual = model_manual.noise_and_denoise( + batch_manual, sigma_manual, align_noisy_input=True + ) + + forward_equal = torch.allclose(xhat_auto.pos, xhat_manual.pos, atol=1e-5) + print(f" Forward pass output equality: {forward_equal}") + + if not forward_equal: + print(f" Max difference: {(xhat_auto.pos - xhat_manual.pos).abs().max().item():.8f}") + print(" Continuing with test anyway...") + + # Test gradient computation + print("\n8. Computing gradients...") + + # AUTOMATIC OPTIMIZATION + print(" Computing automatic optimization gradients...") + model_auto.zero_grad() + + # Use same random seed for loss computation + torch.manual_seed(456) + x_target_auto, xhat_auto, y_auto = model_auto.noise_and_denoise(batch_auto, sigma_auto, align_noisy_input=True) + loss_auto, aux_auto = model_auto.compute_loss(x_target_auto, xhat_auto, sigma_auto) + loss_auto_mean = loss_auto.mean() + + print(f" Auto loss: {loss_auto_mean.item():.6f}") + print(f" Auto loss requires_grad: {loss_auto_mean.requires_grad}") + + loss_auto_mean.backward() + + gradients_auto = get_model_gradients(model_auto) + grad_norm_auto = compute_gradient_norm(gradients_auto) + + print(f" Auto grad_norm: {grad_norm_auto:.6f}") + + # MANUAL OPTIMIZATION (simulate _manual_step) + print(" Computing manual optimization gradients...") + model_manual.zero_grad() + + # Use same random seed for noise generation + torch.manual_seed(456) + y_manual, x_target_manual_prep = model_manual._prepare_noisy_batch( + batch_manual, sigma_manual, align_noisy_input=True + ) + + # Split into chunks + y_list = y_manual.to_data_list() + x_target_list = x_target_manual_prep.to_data_list() + chunk_size = model_manual.max_graphs_per_batch + num_chunks = (len(y_list) + chunk_size - 1) // chunk_size + + print(f" Manual: {len(y_list)} graphs → {num_chunks} chunks of size {chunk_size}") + + # Process chunks and accumulate gradients (simulate manual_step) + total_loss_manual = 0.0 + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min(start_idx + chunk_size, len(y_list)) + + y_chunk = torch_geometric.data.Batch.from_data_list(y_list[start_idx:end_idx]) + x_target_chunk = torch_geometric.data.Batch.from_data_list(x_target_list[start_idx:end_idx]) + + xhat_chunk = model_manual.xhat(y_chunk, sigma_manual) + loss_chunk, aux_chunk = model_manual.compute_loss(x_target_chunk, xhat_chunk, sigma_manual) + loss_chunk_mean = loss_chunk.mean() + + # Scale loss by number of chunks (the fix we implemented) + scaled_loss = loss_chunk_mean / num_chunks + scaled_loss.backward() + + total_loss_manual += loss_chunk_mean.item() + + gradients_manual = get_model_gradients(model_manual) + grad_norm_manual = compute_gradient_norm(gradients_manual) + + print(f" Manual total loss: {total_loss_manual:.6f}, grad_norm: {grad_norm_manual:.6f}") + + # Compare gradients + print("\n9. COMPARING GRADIENTS:") + print(f" Gradient norm ratio (manual/auto): {grad_norm_manual / grad_norm_auto:.6f}") + print(f" Loss ratio (manual/auto): {total_loss_manual / loss_auto_mean.item():.6f}") + + print("\n Parameter-wise comparison:") + gradients_match = compare_gradients(gradients_auto, gradients_manual, tolerance=1e-2) + + print("\n" + "=" * 80) + if gradients_match: + print("✅ SUCCESS: Manual and automatic optimization produce equivalent gradients!") + print(f" Gradient norms: auto={grad_norm_auto:.6f}, manual={grad_norm_manual:.6f}") + print(f" Relative difference: {abs(grad_norm_manual - grad_norm_auto) / grad_norm_auto * 100:.3f}%") + else: + print("❌ FAILURE: Gradients do not match within tolerance!") + print(" This indicates an issue with the manual optimization implementation.") + print("=" * 80) + + return gradients_match + + +if __name__ == "__main__": + test_gradient_equivalence() diff --git a/scratch/test_multimeasurement.py b/scratch/test_multimeasurement.py new file mode 100644 index 0000000..d6fe553 --- /dev/null +++ b/scratch/test_multimeasurement.py @@ -0,0 +1,175 @@ +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) +import os +import sys + +import dotenv +import hydra +import lightning.pytorch as pl +import torch +from omegaconf import OmegaConf + +from jamun.utils import compute_average_squared_distance_from_datasets + +# Fix PyTorch Geometric backend issues +try: + import torch_cluster + import torch_scatter + import torch_sparse +except ImportError: + print("Warning: Some PyTorch Geometric extensions not available") + +# Use GPU if available +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +dotenv.load_dotenv("../.env", verbose=True) # Adjust path if script is not in scratch/ +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +project_root = "/homefs/home/sules/jamun" # Adjust if necessary +if project_root not in sys.path: + sys.path.insert(0, project_root) + print(f"Added '{project_root}' to sys.path for module discovery.") +else: + print(f"'{project_root}' is already in sys.path.") + + +def compute_average_squared_distance_from_config(cfg: OmegaConf) -> float: + """Computes the average squared distance for normalization from the data.""" + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("compute_normalization") + train_datasets = datamodule.datasets["train"] + cutoff = cfg.model.max_radius + average_squared_distance = compute_average_squared_distance_from_datasets(train_datasets, cutoff) + return average_squared_distance + + +def test_training_mode(cfg, mode_name, max_graphs_per_batch): + """Test training with specified optimization mode.""" + print(f"\n{'=' * 50}") + print(f"Testing {mode_name} mode (max_graphs_per_batch={max_graphs_per_batch})") + print(f"{'=' * 50}") + + # # Configure DenoiserMultimeasurement + # cfg.model._target_ = "jamun.model.denoiser_multimeasurement.DenoiserMultimeasurement" + # cfg.model.sigma_distribution._target_ = "jamun.distributions.ConstantSigma" + # cfg.model.sigma_distribution.sigma = 0.04 + + # Set multimeasurement parameters + # cfg.model.multimeasurement = True + cfg.model.N_measurements_hidden = 2 + cfg.model.N_measurements = 2 + cfg.model.max_graphs_per_batch = max_graphs_per_batch + + # Compute normalization + average_squared_distance = compute_average_squared_distance_from_config(cfg) + cfg.model.average_squared_distance = average_squared_distance + breakpoint() + print("Loading datamodule...") + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("test") + breakpoint() + print("Loading model...") + model = hydra.utils.instantiate(cfg.model) + print(f"Model loaded: {type(model)}") + print(f"Multimeasurement: {model.multimeasurement}") + print(f"N_measurements_hidden: {model.N_measurements_hidden}") + print(f"N_measurements: {model.N_measurements}") + print(f"Automatic optimization: {model.automatic_optimization}") + print(f"Sigma: {model.sigma_distribution.sigma}") + breakpoint() + # Get a single batch + print("Getting a batch of data...") + train_loader = datamodule.train_dataloader() + _, batch = next(enumerate(train_loader)) + breakpoint() + print(f"Batch shape: {batch.pos.shape}") + print(f"Batch num_graphs: {batch.num_graphs}") + if hasattr(batch, "hidden_state") and batch.hidden_state is not None: + print(f"Hidden state shapes: {[h.shape for h in batch.hidden_state]}") + else: + print("No hidden states in batch") + breakpoint() + # Test forward pass + print("Testing forward pass...") + with torch.no_grad(): + sigma = model.sigma_distribution.sample() + x_target, xhat, y = model.noise_and_denoise(batch, sigma, align_noisy_input=True) + breakpoint() + print(f"Input shape: {batch.pos.shape}") + print(f"Noisy shape: {y.pos.shape}") + print(f"Output shape: {xhat.pos.shape}") + print(f"Target shape: {x_target.pos.shape}") + + # Verify multimeasurement expansion + expected_graphs = batch.num_graphs * model.N_measurements_hidden * model.N_measurements + actual_graphs = y.num_graphs + print(f"Expected graphs after multimeasurement: {expected_graphs}") + print(f"Actual graphs: {actual_graphs}") + assert actual_graphs == expected_graphs, f"Graph count mismatch: expected {expected_graphs}, got {actual_graphs}" + + # Test actual training with fast_dev_run + print("Testing training with fast_dev_run...") + + # Configure trainer to use only 1 GPU + if torch.cuda.is_available(): + print(f"CUDA available with {torch.cuda.device_count()} GPUs - using GPU 0 only") + trainer = pl.Trainer( + fast_dev_run=1, # Run 1 train, 1 val batch and stop + enable_checkpointing=False, + logger=False, + enable_progress_bar=True, + enable_model_summary=False, + accelerator="gpu", + devices=[0], # Explicitly use only GPU 0 + strategy="auto", # Single device strategy + ) + else: + print("CUDA not available - using CPU") + trainer = pl.Trainer( + fast_dev_run=1, + enable_checkpointing=False, + logger=False, + enable_progress_bar=True, + enable_model_summary=False, + accelerator="cpu", + devices=1, + ) + + try: + trainer.fit(model, datamodule) + print(f"{mode_name} mode training completed successfully!") + except Exception as e: + print(f"Error during training: {e}") + raise + + print(f"{mode_name} mode test completed successfully!") + return model + + +@hydra.main(version_base=None, config_path="../src/jamun/hydra_config", config_name="train") +def main(cfg): + # # Override data config to use only ALA_ALA + # cfg.data.datamodule.filter_codes = ['ALA_ALA'] + # cfg.data.datamodule.subsample = 10 # Use fewer samples for faster testing + # cfg.data.datamodule.batch_size = 4 # Small batch size for testing + + print("Testing DenoiserMultimeasurement training modes") + print(f"Using ALA_ALA data from: {JAMUN_DATA_PATH}") + + # Test automatic optimization mode + test_training_mode(cfg.copy(), "AUTOMATIC", None) + + # Test manual optimization mode + test_training_mode(cfg.copy(), "MANUAL", 2) # Process 2 graphs at a time + + print(f"\n{'=' * 50}") + print("ALL TESTS PASSED!") + print("Both automatic and manual optimization modes work correctly.") + print(f"{'=' * 50}") + + +if __name__ == "__main__": + main() diff --git a/scratch/test_reorganize_swarm_data.py b/scratch/test_reorganize_swarm_data.py new file mode 100644 index 0000000..456d2b9 --- /dev/null +++ b/scratch/test_reorganize_swarm_data.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 +""" +Test script for the swarm data reorganization. + +This script tests the reorganization functionality on a small subset of data +before running the full reorganization. +""" + +import os +import sys +import tempfile + +# Add the scratch directory to the path so we can import the main script +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +try: + import mdtraj as md + + MDTRAJ_AVAILABLE = True +except ImportError: + MDTRAJ_AVAILABLE = False + print("⚠️ mdtraj not available. Trajectory validation will be skipped.") + +try: + from tqdm import tqdm + + TQDM_AVAILABLE = True +except ImportError: + TQDM_AVAILABLE = False + print("⚠️ tqdm not available. Progress bars will be disabled.") + + +def create_test_data(test_source_dir: str, test_pdb_file: str): + """Create a small test dataset structure.""" + print("Creating test data structure...") + + # Create test source directory + os.makedirs(test_source_dir, exist_ok=True) + + # Create a few test grid directories with mock files + test_grid_codes = ["000", "001", "002", "003", "004"] + trajectory_codes = ["001", "002", "003", "004", "005"] + + for grid_code in test_grid_codes: + grid_dir = os.path.join(test_source_dir, f"AA_{grid_code}") + os.makedirs(grid_dir, exist_ok=True) + + # Create mock .xtc files + for traj_code in trajectory_codes: + xtc_file = os.path.join(grid_dir, f"swarm_1ps_{traj_code}.xtc") + # Create empty files that will be copied, but note they won't be valid XTC format + # This is just for testing the file organization logic + with open(xtc_file, "w") as f: + f.write(f"Mock XTC data for grid {grid_code}, trajectory {traj_code}\n") + + # Create mock single PDB file with more realistic content + os.makedirs(os.path.dirname(test_pdb_file), exist_ok=True) + with open(test_pdb_file, "w") as f: + # Write a minimal but valid PDB structure for 2 alanine residues + f.write("TITLE MOCK ALA-ALA DIPEPTIDE\n") + f.write("ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N\n") + f.write("ATOM 2 CA ALA A 1 1.458 0.000 0.000 1.00 0.00 C\n") + f.write("ATOM 3 C ALA A 1 2.009 1.420 0.000 1.00 0.00 C\n") + f.write("ATOM 4 O ALA A 1 1.332 2.445 0.000 1.00 0.00 O\n") + f.write("ATOM 5 CB ALA A 1 1.978 -0.750 1.202 1.00 0.00 C\n") + f.write("ATOM 6 H ALA A 1 -0.481 0.000 0.890 1.00 0.00 H\n") + f.write("ATOM 7 HA ALA A 1 1.804 -0.531 -0.900 1.00 0.00 H\n") + f.write("ATOM 8 HB1 ALA A 1 1.642 -1.785 1.202 1.00 0.00 H\n") + f.write("ATOM 9 HB2 ALA A 1 3.068 -0.750 1.202 1.00 0.00 H\n") + f.write("ATOM 10 HB3 ALA A 1 1.642 -0.281 2.132 1.00 0.00 H\n") + f.write("ATOM 11 N ALA A 2 3.332 1.420 0.000 1.00 0.00 N\n") + f.write("ATOM 12 CA ALA A 2 4.009 2.709 0.000 1.00 0.00 C\n") + f.write("ATOM 13 C ALA A 2 5.509 2.709 0.000 1.00 0.00 C\n") + f.write("ATOM 14 O ALA A 2 6.134 1.649 0.000 1.00 0.00 O\n") + f.write("ATOM 15 CB ALA A 2 3.489 3.459 1.202 1.00 0.00 C\n") + f.write("ATOM 16 H ALA A 2 3.855 0.556 0.000 1.00 0.00 H\n") + f.write("ATOM 17 HA ALA A 2 3.673 3.240 -0.900 1.00 0.00 H\n") + f.write("ATOM 18 HB1 ALA A 2 3.825 4.494 1.202 1.00 0.00 H\n") + f.write("ATOM 19 HB2 ALA A 2 2.399 3.459 1.202 1.00 0.00 H\n") + f.write("ATOM 20 HB3 ALA A 2 3.825 2.990 2.132 1.00 0.00 H\n") + f.write("ATOM 21 OXT ALA A 2 6.032 3.829 0.000 1.00 0.00 O\n") + f.write("TER 22 ALA A 2\n") + f.write("END\n") + + print(f"Created test data with {len(test_grid_codes)} grid codes") + + +def test_mdtraj_with_mock_data(train_dir: str, val_dir: str): + """ + Test mdtraj functionality with mock data. + Note: This will likely fail since we're creating mock XTC files that aren't real trajectories. + """ + if not MDTRAJ_AVAILABLE: + print("⚠️ mdtraj not available, skipping trajectory compatibility tests") + return + + print("Testing mdtraj compatibility (expected to fail with mock data)...") + + for split_name, split_dir in [("train", train_dir), ("val", val_dir)]: + xtc_files = [f for f in os.listdir(split_dir) if f.endswith(".xtc")] + if not xtc_files: + continue + + # Test one file from each split + test_file = xtc_files[0] + base_name = test_file.replace(".xtc", "") + pdb_file = f"{base_name}.pdb" + + xtc_path = os.path.join(split_dir, test_file) + pdb_path = os.path.join(split_dir, pdb_file) + + try: + # This will likely fail since we have mock XTC data + traj = md.load(xtc_path, top=pdb_path) + print( + f"✅ {split_name}: Successfully loaded {test_file} + {pdb_file} " + f"({traj.n_frames} frames, {traj.n_atoms} atoms)" + ) + del traj + except Exception as e: + print(f"❌ {split_name}: Failed to load {test_file} + {pdb_file}: {str(e)}") + print(" (This is expected with mock data - real data should work)") + + print("Note: mdtraj tests with mock data are expected to fail.") + print("The real script will test with actual trajectory files.") + + +def test_reorganization(): + """Test the reorganization script with mock data.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Set up test paths + test_source = os.path.join(temp_dir, "test_swarm_results") + test_target = os.path.join(temp_dir, "test_enhanced") + test_pdb = os.path.join(temp_dir, "test_ALA_ALA.pdb") + + # Create test data + create_test_data(test_source, test_pdb) + + # Import and modify the main script for testing + import reorganize_swarm_data + + # Temporarily override the configuration + original_source = reorganize_swarm_data.SOURCE_DIR + original_pdb = reorganize_swarm_data.SINGLE_PDB_FILE + original_strategies = reorganize_swarm_data.SPLITTING_STRATEGIES.copy() + + reorganize_swarm_data.SOURCE_DIR = test_source + reorganize_swarm_data.SINGLE_PDB_FILE = test_pdb + + # Override the splitting strategies for testing + reorganize_swarm_data.SPLITTING_STRATEGIES = { + "grid_split": { + "target_dir": test_target + "_grid_split", + "train_size": 3, # Use 3 for train, 2 for val + "description": "Test grid split", + }, + "trajectory_split": { + "target_dir": test_target + "_trajectory_split", + "train_trajectories": ["001", "002", "003"], # First 3 for testing + "val_trajectories": ["004", "005"], # Last 2 for testing + "description": "Test trajectory split", + }, + } + + try: + print("\n" + "=" * 50) + print("RUNNING TEST REORGANIZATION") + print("=" * 50) + + # Test both strategies + print("Testing grid split strategy...") + reorganize_swarm_data.main("grid_split") + + print("Testing trajectory split strategy...") + reorganize_swarm_data.main("trajectory_split") + + # Verify results + print("\n" + "=" * 50) + print("VERIFYING TEST RESULTS") + print("=" * 50) + + # Check both strategies + for strategy_name in ["grid_split", "trajectory_split"]: + strategy_dir = test_target + f"_{strategy_name}" + train_dir = os.path.join(strategy_dir, "train") + val_dir = os.path.join(strategy_dir, "val") + + if os.path.exists(train_dir) and os.path.exists(val_dir): + train_files = os.listdir(train_dir) + val_files = os.listdir(val_dir) + + train_xtc = [f for f in train_files if f.endswith(".xtc")] + train_pdb = [f for f in train_files if f.endswith(".pdb")] + val_xtc = [f for f in val_files if f.endswith(".xtc")] + val_pdb = [f for f in val_files if f.endswith(".pdb")] + + print(f"\n{strategy_name.upper()} STRATEGY:") + print(f"Train directory: {len(train_files)} files ({len(train_xtc)} .xtc, {len(train_pdb)} .pdb)") + print(f"Val directory: {len(val_files)} files ({len(val_xtc)} .xtc, {len(val_pdb)} .pdb)") + + # Calculate expected files based on strategy + if strategy_name == "grid_split": + # 3 grid codes × 5 trajectories × 2 file types = 30 train files + # 2 grid codes × 5 trajectories × 2 file types = 20 val files + expected_train = 3 * 5 * 2 + expected_val = 2 * 5 * 2 + else: # trajectory_split + # 5 grid codes × 3 trajectories × 2 file types = 30 train files + # 5 grid codes × 2 trajectories × 2 file types = 20 val files + expected_train = 5 * 3 * 2 + expected_val = 5 * 2 * 2 + + if len(train_files) == expected_train and len(val_files) == expected_val: + print(f"✅ {strategy_name} Test PASSED! File counts are correct.") + + # Check a few file names + print("Sample train files:", sorted(train_files)[:3]) + print("Sample val files:", sorted(val_files)[:3]) + else: + print( + f"❌ {strategy_name} Test FAILED! Expected {expected_train} train, {expected_val} val files" + ) + return False + else: + print(f"❌ {strategy_name} Test FAILED! Output directories were not created") + return False + + # Test mdtraj compatibility on one strategy + strategy_dir = test_target + "_grid_split" + train_dir = os.path.join(strategy_dir, "train") + val_dir = os.path.join(strategy_dir, "val") + print("\n=== Testing mdtraj compatibility ===") + test_mdtraj_with_mock_data(train_dir, val_dir) + + print("\n✅ All tests completed successfully!") + return True + + finally: + # Restore original configuration + reorganize_swarm_data.SOURCE_DIR = original_source + reorganize_swarm_data.SINGLE_PDB_FILE = original_pdb + reorganize_swarm_data.SPLITTING_STRATEGIES = original_strategies + + +def main(): + """Run the test.""" + print("Starting test of swarm data reorganization script...") + + success = test_reorganization() + + if success: + print("\n🎉 Test passed! The script should work correctly on the real data.") + print("\nTo run the full reorganization, execute:") + print("python scratch/reorganize_swarm_data.py") + else: + print("\n❌ Test failed! Please check the script before running on real data.") + + return success + + +if __name__ == "__main__": + main() diff --git a/scratch/test_repeated_position_dataset.py b/scratch/test_repeated_position_dataset.py new file mode 100644 index 0000000..2474f54 --- /dev/null +++ b/scratch/test_repeated_position_dataset.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +""" +Test script for RepeatedPositionDataset to verify that hidden states +are exact copies of the current position. +""" + +import os + +# Add the src directory to the path so we can import jamun modules +import sys +from pathlib import Path + +import torch + +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from jamun.data.noisy_position_dataset import RepeatedPositionDataset + + +def test_repeated_position_dataset(): + """Test RepeatedPositionDataset with ALA_ALA capped diamines data.""" + + print("Testing RepeatedPositionDataset...") + print("=" * 50) + + # Set up dataset parameters + root = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train" + traj_files = ["ALA_ALA.xtc"] + pdb_file = "ALA_ALA.pdb" + label = "ALA_ALA_test" + total_lag_time = 4 # This should create 3 hidden states (4 - 1 = 3) + + # Check if files exist + xtc_path = os.path.join(root, traj_files[0]) + pdb_path = os.path.join(root, pdb_file) + + if not os.path.exists(xtc_path): + print(f"❌ ERROR: XTC file not found: {xtc_path}") + return False + if not os.path.exists(pdb_path): + print(f"❌ ERROR: PDB file not found: {pdb_path}") + return False + + print(f"✅ Found XTC file: {xtc_path}") + print(f"✅ Found PDB file: {pdb_path}") + + try: + # Create the dataset + print(f"\nCreating RepeatedPositionDataset with total_lag_time={total_lag_time}...") + dataset = RepeatedPositionDataset( + root=root, + traj_files=traj_files, + pdb_file=pdb_file, + label=label, + total_lag_time=total_lag_time, + num_frames=5, # Only load 5 frames for testing + verbose=True, + ) + + print("✅ Dataset created successfully") + print(f" Dataset length: {len(dataset)}") + print(f" Dataset label: {dataset.label()}") + + # Test a few samples + print("\nTesting samples...") + + for idx in range(min(3, len(dataset))): + print(f"\n--- Sample {idx} ---") + + # Get sample from dataset + graph = dataset[idx] + + print(f"Graph pos shape: {graph.pos.shape}") + print(f"Number of hidden states: {len(graph.hidden_state)}") + + # Verify we have the expected number of hidden states + expected_hidden_states = total_lag_time - 1 + if len(graph.hidden_state) != expected_hidden_states: + print(f"❌ ERROR: Expected {expected_hidden_states} hidden states, got {len(graph.hidden_state)}") + return False + + print(f"✅ Correct number of hidden states: {len(graph.hidden_state)}") + + # Test each hidden state + for i, hidden_pos in enumerate(graph.hidden_state): + print(f"Hidden state {i} shape: {hidden_pos.shape}") + + # Check if shapes match + if hidden_pos.shape != graph.pos.shape: + print(f"❌ ERROR: Shape mismatch! pos: {graph.pos.shape}, hidden_state[{i}]: {hidden_pos.shape}") + return False + + # Check if values are exactly equal + if not torch.allclose(hidden_pos, graph.pos, atol=1e-10): + print(f"❌ ERROR: Hidden state {i} is not exactly equal to current position!") + print(f" Max difference: {torch.max(torch.abs(hidden_pos - graph.pos)).item()}") + return False + + # Check if they are the exact same tensor (should be different objects but same values) + if hidden_pos is graph.pos: + print(f"⚠️ WARNING: Hidden state {i} is the same object as pos (should be different objects)") + else: + print(f"✅ Hidden state {i} is a different object with same values as pos") + + print(f"✅ Hidden state {i} exactly matches current position") + + print("\n🎉 All tests passed!") + print(" ✅ Dataset loads correctly") + print(f" ✅ Correct number of hidden states ({total_lag_time - 1})") + print(" ✅ Hidden states exactly match current position") + print(" ✅ Hidden states are separate objects (not references)") + + return True + + except Exception as e: + print(f"❌ ERROR: Exception occurred: {e}") + import traceback + + traceback.print_exc() + return False + + +if __name__ == "__main__": + success = test_repeated_position_dataset() + if success: + print("\n🎉 Test completed successfully!") + exit(0) + else: + print("\n💥 Test failed!") + exit(1) diff --git a/scratch/training_prototype.py b/scratch/training_prototype.py new file mode 100644 index 0000000..8452e53 --- /dev/null +++ b/scratch/training_prototype.py @@ -0,0 +1,369 @@ +# %% Imports and Basic Setup +import logging +import os +import sys + +import dotenv +import e3nn +import hydra +import lightning.pytorch as pl +import torch +import torch_geometric.data +from hydra import compose, initialize +from omegaconf import OmegaConf + +import jamun +import jamun.data +import jamun.distributions +import jamun.model +import jamun.model.arch +from jamun.hydra import instantiate_dict_cfg + +# Assuming these are in jamun.utils and jamun.hydra respectively +from jamun.utils import compute_average_squared_distance_from_datasets, find_checkpoint + +# --- Basic Setup --- +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("jamun_script") + +torch.cuda.is_available() # Good to check, but PL trainer will also handle device +torch.set_float32_matmul_precision("high") +e3nn.set_optimization_defaults(jit_script_fx=False) + +# %% Environment and Paths +dotenv.load_dotenv("../.env", verbose=True) # Adjust path if script is not in scratch/ +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +project_root = "/homefs/home/sules/jamun" # Adjust if necessary +if project_root not in sys.path: + sys.path.insert(0, project_root) + py_logger.info(f"Added '{project_root}' to sys.path for module discovery.") +else: + py_logger.info(f"'{project_root}' is already in sys.path.") + +# %% Load Hydra Configuration +py_logger.info("Loading Hydra configuration...") +# Adjust config_path relative to the script's location if it's not in scratch/ +# If jamun_training_script.py is in scratch/, and configs are in jamun/configs/ +# then config_path should be "../configs" +with initialize(config_path=".", job_name="conditioning_initial_run"): # Corrected job_name from previous context + cfg = compose( + config_name="config", # Main config file + overrides=[ + "model.arch._target_=scratch.e3conv_test.E3Conv", # Relative to project_root + "model._target_=scratch.denoiser_test.Denoiser", # Relative to project_root + "+model.arch.N_structures=2", + "trainer.max_epochs=100", # Example: train for 10 epochs + # Add other overrides, e.g. "logger=null" if you don't want default loggers for a quick test + ], + ) +py_logger.info("Loaded configuration:") +py_logger.info(OmegaConf.to_yaml(cfg)) + + +# %% Initial Dataset Setup (for model properties like average_squared_distance) +py_logger.info("Setting up initial dataset for model properties...") +initial_datasets_for_props = { + "props_dataset": jamun.data.parse_datasets_from_directory( # Renamed key for clarity + root=f"{JAMUN_DATA_PATH}/timewarp/2AA-1-large/train/", + traj_pattern="^(.*)-traj-arrays.npz", + pdb_file="AA-traj-state0.pdb", + filter_codes=["AA"], + as_iterable=False, + subsample=100, # Keep this small for this purpose + max_datasets=1, + ) +} + +# %% Model Instantiation +py_logger.info("Instantiating model...") +try: + if not hasattr(cfg.model, "average_squared_distance") or cfg.model.average_squared_distance is None: + py_logger.info("Computing average_squared_distance for the model...") + average_squared_distance = compute_average_squared_distance_from_datasets( + initial_datasets_for_props["props_dataset"], # Use the small dataset for this + cfg.model.max_radius, + ) + cfg.model.average_squared_distance = average_squared_distance + py_logger.info(f"Set cfg.model.average_squared_distance to {cfg.model.average_squared_distance}") + + # Provide conditioner if needed + if not hasattr(cfg.model, "conditioner"): + OmegaConf.set_struct(cfg.model, False) # Allow modification + cfg.model.conditioner = OmegaConf.create({}) + cfg.model.conditioner._target_ = ( + "scratch.conditioners.SelfConditioner" # Use the SelfConditioner from scratch.conditioners + ) + OmegaConf.set_struct(cfg.model, True) # Lock structure again + py_logger.info(f"Set cfg.model.conditioner to instantiate {cfg.model.conditioner._target_}.") + + model = hydra.utils.instantiate(cfg.model) + + py_logger.info("Successfully instantiated model.") + py_logger.info(f"Instantiated model architecture type: {type(model.g)}") +except Exception as e: + py_logger.error(f"Error during model instantiation: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + +# %% Device Setup for Model +# Determine the target device, preferring GPU if available. +if torch.cuda.is_available(): + device = torch.device("cuda") + py_logger.info("CUDA is available. Attempting to use GPU.") +else: + device = torch.device("cpu") + py_logger.info("CUDA not available. Using CPU.") + +# Move the model to the determined device. +model = model.to(device) + +# Verify and log the model's actual device. +# After .to(device), the model's internal device attribute should update. +# We can also check a parameter's device as a fallback verification. +final_model_device = None +if hasattr(model, "device") and model.device is not None: + final_model_device = model.device +elif next(model.parameters(), None) is not None: + final_model_device = next(model.parameters()).device + +py_logger.info(f"Model '{type(model).__name__}' is now on device: {final_model_device}") + + +# %% Setup for Actual Training +py_logger.info("Setting up for actual training...") + +# 1. Prepare datasets for training, validation, and testing +# Using the same dataset source for all splits as a placeholder. +# In a real scenario, these would be different datasets or splits. +# You might want to parse a larger dataset here for actual training. +py_logger.info("Parsing dataset for training...") +training_dataset_source = jamun.data.parse_datasets_from_directory( + root=f"{JAMUN_DATA_PATH}/timewarp/2AA-1-large/train/", # Consider using full dataset + traj_pattern="^(.*)-traj-arrays.npz", + pdb_file="AA-traj-state0.pdb", + filter_codes=["AA"], + as_iterable=False, # Set to True for very large datasets if memory is an issue + subsample=cfg.data.datamodule.datasets.train[0].subsample, # Use subsample from config or None for full + max_datasets=1, # Use from config or None for all +) + +datasets_for_training = { + "train": training_dataset_source, + "val": training_dataset_source, # Replace with actual validation set + "test": training_dataset_source, # Replace with actual test set +} +py_logger.info(f"Prepared datasets for training: { {k: type(v).__name__ for k, v in datasets_for_training.items()} }") +if isinstance(training_dataset_source, torch_geometric.data.Dataset): + py_logger.info(f"Training dataset size: {len(training_dataset_source)}") + + +# 2. Initialize DataModule for training +datamodule_for_training = jamun.data.MDtrajDataModule( + datasets=datasets_for_training, + batch_size=cfg.data.datamodule.batch_size, + num_workers=cfg.data.datamodule.num_workers, +) + +# 3. Model is already instantiated and on device +py_logger.info(f"Model '{type(model).__name__}' is ready for training.") + +# %% Loggers and Callbacks Setup +py_logger.info("Setting up loggers and callbacks...") + +if cfg.get("logger") and cfg.logger.get("wandb"): + try: + # This requires JAMUN_ROOT_PATH, task_name, run_group, and run_key to be correctly + # defined and resolvable in your configuration. + wandb_save_dir = str(cfg.paths.run_path) # Resolve the path from OmegaConf + + # Update the logger config before instantiation + OmegaConf.update(cfg, "logger.wandb.save_dir", wandb_save_dir, merge=False) + py_logger.info(f"Explicitly setting WandbLogger save_dir to: {wandb_save_dir}") + + # Ensure the target directory for wandb files exists. + # WandbLogger will create its 'wandb/' subdirectory and run-specific folders inside this save_dir. + os.makedirs(wandb_save_dir, exist_ok=True) + py_logger.info(f"Ensured WandbLogger save_dir exists: {wandb_save_dir}") + + except Exception as e: + py_logger.error(f"Could not resolve or set wandb save_dir from cfg.paths.run_path: {e}") + py_logger.warning(f"Wandb will use default save directory (likely ./wandb in CWD: {os.getcwd()}).") + +# 1. Instantiate Loggers and Callbacks from Hydra config +loggers_list = [] +if cfg.get("logger"): + if hasattr(cfg.logger, "_target_"): + loggers_list.append(hydra.utils.instantiate(cfg.logger)) + else: + # Assuming instantiate_dict_cfg iterates and calls hydra.utils.instantiate for each logger config + loggers_list = instantiate_dict_cfg(cfg.logger) +py_logger.info(f"Instantiated loggers: {[type(l).__name__ for l in loggers_list]}") + +# %% 2. Determine and set the ModelCheckpoint directory path +final_checkpoint_dir = None +# Check if the first logger is a WandbLogger and provides a directory +if ( + loggers_list + and isinstance(loggers_list[0], pl.loggers.WandbLogger) + and hasattr(loggers_list[0], "experiment") + and loggers_list[0].experiment + and hasattr(loggers_list[0].experiment, "dir") + and loggers_list[0].experiment.dir +): + wandb_run_root_dir = loggers_list[0].experiment.dir + # Adjust if wandb_run_root_dir points to a 'files' subdirectory + if os.path.basename(wandb_run_root_dir) == "files": + wandb_run_root_dir = os.path.dirname(wandb_run_root_dir) + final_checkpoint_dir = os.path.join(wandb_run_root_dir, "checkpoints") + py_logger.info(f"Using WandB logger's experiment directory for checkpoints: {final_checkpoint_dir}") +else: + # Default path if no suitable WandB logger is found or no loggers are configured + final_checkpoint_dir = os.path.join(os.getcwd(), "outputs", "checkpoints") + if not loggers_list: + py_logger.info(f"No loggers configured. Defaulting checkpoint directory: {final_checkpoint_dir}") + else: + py_logger.info( + f"First logger is not a suitable WandB logger. Defaulting checkpoint directory: {final_checkpoint_dir}" + ) + +# Update the config if model_checkpoint callback is defined +if cfg.get("callbacks") and cfg.callbacks.get("model_checkpoint"): + # Use OmegaConf.update to safely set the possibly nested key, + # this will create it if it doesn't exist or overwrite if it does. + OmegaConf.update(cfg, "callbacks.model_checkpoint.dirpath", final_checkpoint_dir, merge=False) + py_logger.info(f"Set cfg.callbacks.model_checkpoint.dirpath to: {final_checkpoint_dir}") + # Ensure the directory exists + os.makedirs(final_checkpoint_dir, exist_ok=True) +else: + py_logger.info("ModelCheckpoint callback not configured in cfg.callbacks, dirpath not set.") + + +# %% Instantiate Callbacks +callbacks_list = [] +if cfg.get("callbacks"): + if hasattr(cfg.callbacks, "_target_"): # Single callback config + callbacks_list.append(hydra.utils.instantiate(cfg.callbacks)) + else: # Dictionary of callback configs + callbacks_list = instantiate_dict_cfg(cfg.callbacks) # This will now use the modified dirpath +py_logger.info(f"Instantiated callbacks: {[type(c).__name__ for c in callbacks_list]}") + +# 2. Instantiate PyTorch Lightning Trainer +trainer_config = cfg.trainer +if not hasattr(trainer_config, "_target_") and isinstance(trainer_config, dict): + trainer_config = OmegaConf.merge(trainer_config, {"_target_": "lightning.pytorch.Trainer"}) + +trainer: pl.Trainer = hydra.utils.instantiate( + trainer_config, + logger=loggers_list if loggers_list else True, + callbacks=callbacks_list, +) +py_logger.info(f"Instantiated Trainer: {type(trainer)}") +py_logger.info(f"Trainer will run for {trainer.max_epochs} epochs.") + +# 3. Handle checkpoint resumption (optional) +checkpoint_path = None +if resume_checkpoint_cfg := cfg.get("resume_from_checkpoint"): + if resume_checkpoint_cfg.get("enabled", False): + py_logger.info(f"Attempting to resume from checkpoint with config: {resume_checkpoint_cfg}") + try: + checkpoint_path = find_checkpoint( + wandb_train_run_path=resume_checkpoint_cfg.get("wandb_train_run_path"), + checkpoint_dir=resume_checkpoint_cfg.get("checkpoint_dir"), + checkpoint_type=resume_checkpoint_cfg.get("checkpoint_type", "last"), + ) + if checkpoint_path: + py_logger.info(f"Found checkpoint to resume from: {checkpoint_path}") + else: + py_logger.warning("No checkpoint found for resumption based on config.") + except Exception as e: + py_logger.error(f"Error finding checkpoint: {e}. Starting training from scratch.") + checkpoint_path = None + else: + py_logger.info("Checkpoint resumption is configured but not enabled.") + +# %% Start Training +py_logger.info("Starting training...") +try: + trainer.fit( + model=model, # Use the main model instance + datamodule=datamodule_for_training, + ckpt_path=checkpoint_path if checkpoint_path else None, + ) + py_logger.info("Training finished.") + + if cfg.get("run_test_after_train", False): + py_logger.info("Running test phase...") + trainer.test(model=model, datamodule=datamodule_for_training) + py_logger.info("Test phase finished.") + +except Exception as e: + py_logger.error(f"Training FAILED: {e}") + traceback.print_exc() + +py_logger.info("Script finished.") + +# %% Log the final configuration and save it locally +wandb_logger_instance = None +# loggers_list should still be in scope from when it was passed to the Trainer +for logger_from_list in loggers_list: # Use a different variable name to avoid conflict if logger is defined elsewhere + if isinstance(logger_from_list, pl.loggers.WandbLogger): + wandb_logger_instance = logger_from_list + break + +if wandb_logger_instance and hasattr(wandb_logger_instance, "experiment") and wandb_logger_instance.experiment: + py_logger.info( + f"WandbLogger experiment active (run_id: {wandb_logger_instance.experiment.id}). Logging final script config to wandb." + ) + # Convert the current OmegaConf object 'cfg' to a plain dictionary + # This 'cfg' includes all modifications made throughout the script + final_script_cfg_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + try: + wandb_logger_instance.experiment.config.update( + {"cfg": final_script_cfg_dict, "jamun_version_at_end": jamun.__version__, "script_cwd_at_end": os.getcwd()} + ) + py_logger.info("Updated wandb.config with final_script_cfg.") + except Exception as e: + py_logger.error(f"Failed to update wandb.config with final script config: {e}") +else: + if cfg.get("logger") and cfg.logger.get("wandb"): + py_logger.warning( + "WandbLogger was configured but not found or experiment not active at script end. Final script config not logged to wandb.config." + ) + +# 2. Explicitly save the final state of the OmegaConf object 'cfg' to a local file + +final_config_output_dir = None +if cfg.get("logger") and cfg.logger.get("wandb") and cfg.logger.wandb.get("save_dir"): + final_config_output_dir = cfg.logger.wandb.save_dir +elif "wandb_save_dir" in locals() and wandb_save_dir: # If it was set in a previous cell + final_config_output_dir = wandb_save_dir +else: + # Fallback if a specific run directory isn't easily available + # This might not be ideal as it won't be co-located with W&B run files if save_dir wasn't set + final_config_output_dir = os.path.join( + os.getcwd(), "outputs", cfg.get("task_name", "unknown_task"), cfg.get("run_key", "unknown_run") + ) + os.makedirs(final_config_output_dir, exist_ok=True) + + +if final_config_output_dir: + final_config_path = os.path.join(final_config_output_dir, "final_resolved_script_config.yaml") + try: + with open(final_config_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + py_logger.info(f"Final script configuration saved locally to: {final_config_path}") + except Exception as e: + py_logger.error(f"Failed to save final script configuration locally: {e}") +else: + py_logger.warning( + "Could not determine a definitive output directory for final_resolved_script_config.yaml. Not saving locally." + ) + +loggers_list[0].experiment.finish() if loggers_list and isinstance(loggers_list[0], pl.loggers.WandbLogger) else None +py_logger.info("Script finished.") +# %% diff --git a/scratch/transformer/convert_spatiotemporal.py b/scratch/transformer/convert_spatiotemporal.py new file mode 100644 index 0000000..2cb5053 --- /dev/null +++ b/scratch/transformer/convert_spatiotemporal.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +""" +Functions for converting between spatial and temporal graph representations. +""" + +import torch +import torch_geometric + + +def calculate_temporal_positions(temporal_length, device=None): + """ + Calculate normalized temporal positions for nodes in a temporal graph. + + Args: + temporal_length: Total number of nodes in the temporal sequence + device: Device to create tensors on + + Returns: + torch.Tensor: Normalized positions [0, 1/T, 2/T, ..., (T-1)/T] + """ + if temporal_length <= 1: + return torch.tensor([0.0], device=device) + + # Create positions [0, 1, 2, ..., T-1] and normalize by T + positions = torch.arange(temporal_length, dtype=torch.float32, device=device) + normalized_positions = positions / temporal_length + + return normalized_positions + + +def spatial_to_temporal_graphs(batch, graph_type="fan"): + """ + Convert a batch of spatial graphs to temporal graphs with configurable connectivity. + + For each spatial node with position + hidden states, create a temporal graph where: + - Node 0: current position + - Nodes 1-T: hidden state positions + - Connectivity depends on graph_type parameter + + Args: + batch: Input spatial graph batch + graph_type: Type of connectivity to use + - "fan": Hub connects to all + sequential connections (0->all, i->(i+1)) + - "hub_n_spoke": Only hub-spoke connections (0->all, no sequential) + - "complete": Complete graph with self-loops (all-to-all including self) + - "complete_no_self": Complete graph without self-loops (all-to-all excluding self) + """ + + # Validate graph_type + valid_types = ["fan", "hub_n_spoke", "complete", "complete_no_self"] + if graph_type not in valid_types: + raise ValueError(f"graph_type must be one of {valid_types}, got {graph_type}") + + # Get device from input batch + device = batch.pos.device + + # Get dimensions + num_spatial_nodes = batch.pos.shape[0] + + # Check if we have hidden states + if hasattr(batch, "hidden_state") and batch.hidden_state is not None and len(batch.hidden_state) > 0: + num_hidden_states = len(batch.hidden_state) + temporal_length = 1 + num_hidden_states # current + hidden + else: + # If no hidden states, just use current position + num_hidden_states = 0 + temporal_length = 1 + + # print(f"Creating {graph_type} temporal graphs: {num_spatial_nodes} spatial nodes -> {num_spatial_nodes} temporal graphs of length {temporal_length}") + + # Store reference to spatial graph + spatial_graph = batch.clone() + + # Set connectivity type code for tracking + connectivity_type_map = {"fan": 0, "hub_n_spoke": 1, "complete": 2, "complete_no_self": 3} + + temporal_graphs = [] + + for node_idx in range(num_spatial_nodes): + # Build temporal positions: [current_pos, hidden_1, hidden_2, ...] + temporal_positions = [batch.pos[node_idx]] # Start with current position + + # Add hidden state positions + if num_hidden_states > 0: + for hidden_pos in batch.hidden_state: + temporal_positions.append(hidden_pos[node_idx]) + + temporal_pos = torch.stack(temporal_positions) # Shape: [T, 3] + + # Calculate temporal positions for this sequence + temporal_position = calculate_temporal_positions(temporal_length, device=device) + + # Create edge connectivity based on graph_type + if temporal_length > 1: + if graph_type == "fan": + # Original fan system: hub-spoke + sequential + # Hub connections: 0->1, 0->2, 0->3, ..., 0->T-1 + hub_src = [0] * (temporal_length - 1) + hub_dst = list(range(1, temporal_length)) + + # Sequential connections: 1->2, 2->3, ..., (T-2)->(T-1) + seq_src = list(range(1, temporal_length - 1)) + seq_dst = list(range(2, temporal_length)) + + # Combine all edges + all_src = hub_src + seq_src + all_dst = hub_dst + seq_dst + + edge_index = torch.tensor([all_src, all_dst], dtype=torch.long, device=device) + + elif graph_type == "hub_n_spoke": + # Hub-and-spoke only: 0 connects to all others, no sequential + hub_src = [0] * (temporal_length - 1) + hub_dst = list(range(1, temporal_length)) + + edge_index = torch.tensor([hub_src, hub_dst], dtype=torch.long, device=device) + + elif graph_type == "complete": + # Complete graph without self-loops: all-to-all excluding self + src_nodes = [] + dst_nodes = [] + + for i in range(temporal_length): + for j in range(temporal_length): + if i != j: # Exclude self-loops + src_nodes.append(i) + dst_nodes.append(j) + + edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long, device=device) + + else: + # Single node case + if graph_type == "complete": + # Single node with self-loop + edge_index = torch.tensor([[0], [0]], dtype=torch.long, device=device) + else: + # Single node, no edges for other types + edge_index = torch.tensor([[], []], dtype=torch.long, device=device) + + # Create temporal graph for this spatial node + temporal_graph = torch_geometric.data.Data( + pos=temporal_pos, + edge_index=edge_index, + spatial_node_idx=torch.tensor([node_idx], device=device), # Track which spatial node this came from + temporal_length=torch.tensor([temporal_length], device=device), + temporal_position=temporal_position, # Normalized position in sequence [0, 1/T, 2/T, ...] + connectivity_type=torch.tensor([connectivity_type_map[graph_type]], device=device), + graph_type=graph_type, # Store graph type as string for debugging + ) + temporal_graphs.append(temporal_graph) + + # Batch all temporal graphs + temporal_batch = torch_geometric.data.Batch.from_data_list(temporal_graphs) + + # Store spatial graph reference and graph type + temporal_batch.spatial_graph = spatial_graph + temporal_batch.graph_type = graph_type + + return temporal_batch + + +def temporal_to_spatial_graphs(temporal_batch): + """ + Convert temporal graphs back to spatial graphs. + Take the 0th node position from each temporal graph as the updated spatial position. + """ + # Get the spatial graph template + spatial_graph = temporal_batch.spatial_graph.clone() + + # Extract 0th node positions from each temporal graph + num_temporal_graphs = temporal_batch.num_graphs + updated_positions = [] + + # Iterate through each temporal graph in the batch + for graph_idx in range(num_temporal_graphs): + # Get the node range for this temporal graph + start_idx = temporal_batch.ptr[graph_idx] + + # The 0th node of each temporal graph is at the start of its range + updated_positions.append(temporal_batch.pos[start_idx]) + + # Stack to create new position tensor + updated_positions = torch.stack(updated_positions) + + # Update spatial graph with new positions + spatial_graph.pos = updated_positions + + return spatial_graph diff --git a/scratch/transformer/develop_transformer.py b/scratch/transformer/develop_transformer.py new file mode 100644 index 0000000..69a7aec --- /dev/null +++ b/scratch/transformer/develop_transformer.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +""" +Simple test script for the debugged denoiser_conditional using default hydra config. +Tests with sigma = 0.0 and sigma = 0.1. + +Device Handling: +- This script manually handles CUDA device placement for standalone testing +- PyTorch Lightning DOES handle device placement automatically when using the Trainer +- In Lightning, you typically don't need to call .to(device) manually on models or data +- Lightning moves models to the specified device and handles data loading automatically +- For standalone scripts like this one, manual device handling is required +""" + +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) +import torch +import torch_geometric + +# Import spatial-temporal conversion functions +from convert_spatiotemporal import spatial_to_temporal_graphs, temporal_to_spatial_graphs + +# Import node attribute conversion functions +from pooling import SpatialTemporalToTemporalNodeAttr, TemporalToSpatialNodeAttrMean + +from jamun.data import parse_datasets_from_directory +from jamun.utils import unsqueeze_trailing + +# Setup device - use CUDA if available +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") +if torch.cuda.is_available(): + print(f"CUDA device: {torch.cuda.get_device_name()}") + print(f"CUDA memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.1f} GB") + + +def to_device(obj, device): + """Helper function to move objects to device, handling various types.""" + if hasattr(obj, "to"): + return obj.to(device) + elif isinstance(obj, list | tuple): + return type(obj)(to_device(item, device) for item in obj) + elif isinstance(obj, dict): + return {key: to_device(value, device) for key, value in obj.items()} + else: + return obj + + +def move_graph_to_device(graph, device): + """Move a PyTorch Geometric graph and all its tensor attributes to device.""" + # Move the graph using standard .to() method + graph = graph.to(device) + + # Manually move any custom tensor attributes that might not be handled + for attr_name in dir(graph): + if not attr_name.startswith("_"): # Skip private attributes + attr_value = getattr(graph, attr_name, None) + if isinstance(attr_value, torch.Tensor): + setattr(graph, attr_name, attr_value.to(device)) + + return graph + + +dataset = parse_datasets_from_directory( + root="/data/bucket/kleinhej/capped_diamines/timewarp_splits/train", + traj_pattern="^(.*).xtc", + pdb_pattern="^(.*).pdb", + filter_codes=["ALA_ALA"], + as_iterable=False, + subsample=80, + total_lag_time=8, + lag_subsample_rate=10, + start_frame=800000, + num_frames=200000, +) +# temporal_distance_cutoff = compute_temporal_average_squared_distance_from_datasets(dataset) +breakpoint() +# convert to dataloader, then pull data + +graph = dataset[0].__getitem__(0) +batch = torch_geometric.data.Batch.from_data_list([graph]) +from helpers import add_edges + +batch = add_edges(batch.pos, batch, batch.batch, 0.05) + +# Move batch data to device +batch = move_graph_to_device(batch, device) +print(f"Moved batch to device: {batch.pos.device}") + +# Use E3Conv architecture for spatial feature processing +from e3nn import o3 +from helpers import create_e3conv_network, get_e3conv_output_irreps + +# Create E3Conv network with yaml configuration parameters +spatial_e3conv = create_e3conv_network() + +# Move E3Conv model to device +spatial_e3conv = spatial_e3conv.to(device) +print(f"Moved E3Conv model to device: {next(spatial_e3conv.parameters()).device}") + +print(f"E3Conv output irreps: {get_e3conv_output_irreps()}") +output_irreps = o3.Irreps(get_e3conv_output_irreps()) +print(f"E3Conv output dimension: {output_irreps.dim}") + +print("\n" + "=" * 50) +print("TEMPORAL GRAPH CONVERSION") +print("=" * 50) + + +def test_temporal_conversion(batch, graph_type="fan"): + """Test the conversion functions with example output.""" + print("=== Testing Temporal Graph Conversion ===") + + print("Original spatial batch:") + print(f" - pos shape: {batch.pos.shape}") + print( + f" - hidden_state length: {len(batch.hidden_state) if hasattr(batch, 'hidden_state') and batch.hidden_state else 0}" + ) + if hasattr(batch, "hidden_state") and batch.hidden_state: + print(f" - hidden_state[0] shape: {batch.hidden_state[0].shape}") + + # Convert to temporal + temporal_batch = spatial_to_temporal_graphs(batch) + + print("\nTemporal batch:") + print(f" - pos shape: {temporal_batch.pos.shape}") + print(f" - edge_index shape: {temporal_batch.edge_index.shape}") + print(f" - num_graphs: {temporal_batch.num_graphs}") + print(" - example edge_index for first temporal graph:") + + # Show first temporal graph structure + first_graph_end = temporal_batch.ptr[1] if temporal_batch.num_graphs > 1 else len(temporal_batch.pos) + first_graph_edges = temporal_batch.edge_index[:, temporal_batch.edge_index[0] < first_graph_end] + print(f" {first_graph_edges}") + + print("\n - COMPLETE edge_index for entire temporal batch:") + print(f" Shape: {temporal_batch.edge_index.shape}") + print(f" {temporal_batch.edge_index}") + + print(f"\n - Temporal graph boundaries (ptr): {temporal_batch.ptr}") + print(" - Graph node ranges:") + for i in range(temporal_batch.num_graphs): + start = temporal_batch.ptr[i] + end = temporal_batch.ptr[i + 1] if i + 1 < len(temporal_batch.ptr) else len(temporal_batch.pos) + print(f" Graph {i}: nodes {start}-{end - 1} ({end - start} nodes)") + + print("\n - Temporal positions for each graph:") + print(f" Shape: {temporal_batch.temporal_position.shape}") + print(f" First graph temporal_position: {temporal_batch.temporal_position[:5]}") # Show first 5 positions + print(f" All temporal_position values: {temporal_batch.temporal_position}") + + # Convert back to spatial + reconstructed_spatial = temporal_to_spatial_graphs(temporal_batch) + + print("\nReconstructed spatial:") + print(f" - pos shape: {reconstructed_spatial.pos.shape}") + print(f" - position difference from original: {torch.norm(reconstructed_spatial.pos - batch.pos)}") + + return temporal_batch, reconstructed_spatial + + +# Test the temporal conversion +temporal_batch, reconstructed_spatial = test_temporal_conversion(batch) + +# Move temporal batch to device (should already be on correct device now) +temporal_batch = move_graph_to_device(temporal_batch, device) +reconstructed_spatial = move_graph_to_device(reconstructed_spatial, device) +print("Temporal batch device verification:") +print(f" - pos device: {temporal_batch.pos.device}") +print(f" - edge_index device: {temporal_batch.edge_index.device}") +print(f" - batch device: {temporal_batch.batch.device}") +print(f" - ptr device: {temporal_batch.ptr.device}") +if hasattr(temporal_batch, "temporal_position"): + print(f" - temporal_position device: {temporal_batch.temporal_position.device}") +if hasattr(temporal_batch, "spatial_node_idx"): + print(f" - spatial_node_idx device: {temporal_batch.spatial_node_idx.device}") + +print("\n" + "=" * 50) +print("PROCESSING ALL TEMPORAL POSITIONS WITH E3CONV") +print("=" * 50) +breakpoint() +# Process all temporal positions with E3Conv +with torch.no_grad(): + # Create topology without positions for E3Conv processing + # add edges to the topology + sigma = torch.tensor(0.0, device=device) + from jamun.utils import unsqueeze_trailing + + sigma = unsqueeze_trailing(sigma, 1) + topology = batch.clone() + topology = move_graph_to_device(topology, device) + del topology.pos, topology.batch, topology.num_graphs + + # Process current positions: [N, 3] -> [N, 1, num_features] + node_attr_current = spatial_e3conv( + batch.pos, topology, batch.batch, num_graphs=batch.num_graphs, c_noise=sigma, effective_radial_cutoff=0.05 + ).unsqueeze(1) + + # Process hidden state positions and collect all temporal features + node_attr_list = [node_attr_current] + breakpoint() + if hasattr(batch, "hidden_state") and batch.hidden_state: + for hidden_pos in batch.hidden_state: + node_attr_hidden = node_attr_current = spatial_e3conv( + hidden_pos, + topology, + batch.batch, + num_graphs=batch.num_graphs, + c_noise=sigma, + effective_radial_cutoff=0.05, + ).unsqueeze(1) + node_attr_list.append(node_attr_hidden) + + # Stack along temporal dimension: [N, T, num_features] + breakpoint() + node_attr_spatial_temporal = torch.cat(node_attr_list, dim=1) + + breakpoint() + # Convert spatial-temporal features to temporal node attributes with proper ordering + spatial_temporal_pooler = SpatialTemporalToTemporalNodeAttr() + spatial_node_attr_all_temporal = spatial_temporal_pooler(node_attr_spatial_temporal, temporal_batch) + +breakpoint() +print(f"Node attributes for all temporal positions: {spatial_node_attr_all_temporal.shape}") +print(f"First spatial node temporal features: {node_attr_spatial_temporal[0].shape}") +print(f"Total norm (should be nonzero): {torch.norm(spatial_node_attr_all_temporal):.6f}") +print(f"Spatial node attributes device: {spatial_node_attr_all_temporal.device}") + +print("\n" + "=" * 50) +print("E3TRANSFORMER TEST") +print("=" * 50) + +breakpoint() + + +def test_e3_transformer(batch, temporal_batch, spatial_node_attr_all_temporal, device): + """Test the E3Transformer with temporal graphs.""" + from temporal_transformer import E3Transformer + + print("=== Testing E3Transformer ===") + + # Use the precomputed temporal node attributes (processed by E3Conv) + print(f"Using temporal node attributes (processed by E3Conv): {spatial_node_attr_all_temporal.shape}") + print(f"Sample temporal node attr: {spatial_node_attr_all_temporal[0]}") + + # The node attributes are already arranged to match temporal graph ordering + temporal_node_attr = spatial_node_attr_all_temporal + + print("Input shapes:") + print(f" - temporal_node_attr: {temporal_node_attr.shape}") + print(f" - temporal_graph.pos: {temporal_batch.pos.shape}") + print(f" - temporal_graph.edge_index: {temporal_batch.edge_index.shape}") + print(f" - temporal_graph.temporal_position: {temporal_batch.temporal_position.shape}") + print(f" - temporal_graph.batch: {temporal_batch.batch.shape}") + print(f" - temporal_graph.num_graphs: {temporal_batch.num_graphs}") + + # Create E3Transformer model that takes 1x1e node attributes (E3Conv output) + transformer = E3Transformer( + irreps_out="3x1e", # 3D output (like positions) + irreps_hidden="8x0e + 4x1e", # Hidden representations + irreps_sh="1x0e + 1x1e", # Spherical harmonics + irreps_node_attr="1x1e", # Input node attributes match E3Conv output + num_layers=2, + edge_attr_dim=24, # Split into 2 parts: 12+12 (radial+temporal) + num_attention_heads=1, # Single attention head for simpler test + ) + + # Move transformer to device + transformer = transformer.to(device) + print(f"Moved transformer to device: {next(transformer.parameters()).device}") + + print("\nTransformer parameters:") + print(f" - irreps_out: {transformer.irreps_out}") + print(f" - irreps_hidden: {transformer.irreps_hidden}") + print(f" - irreps_node_attr: {transformer.irreps_node_attr}") + print(f" - temporal_gate.irreps_out: {transformer.temporal_gate.irreps_out}") + print(f" - radial_edge_attr_dim: {transformer.radial_edge_attr_dim}") + print(f" - temporal_edge_attr_dim: {transformer.temporal_edge_attr_dim}") + + # Forward pass with tensor and graph (like E3Conv) + effective_radial_cutoff = 5.0 # Define the cutoff in forward pass + temporal_cutoff = 1.0 # Default temporal cutoff (no cutoff for temporal contributions) + with torch.no_grad(): + try: + transformer_output = transformer( + temporal_node_attr, temporal_batch, effective_radial_cutoff, temporal_cutoff + ) + print("\n✅ Transformer forward pass successful!") + print(f"Transformer output shape: {transformer_output.shape}") + print(f"Transformer output sample: {transformer_output[0]}") + print(f"Transformer output norm: {torch.norm(transformer_output):.6f}") + print(f"Used effective_radial_cutoff: {effective_radial_cutoff}") + print(f"Used temporal_cutoff: {temporal_cutoff}") + return True + except Exception as e: + print(f"\n❌ Transformer forward pass failed: {e}") + import traceback + + traceback.print_exc() + return False + + +# Test the complete workflow: E3Conv -> Transformer +success = test_e3_transformer(batch, temporal_batch, spatial_node_attr_all_temporal, device) + +print("\n" + "=" * 50) +print("TEMPORAL TO SPATIAL MEAN POOLING") +print("=" * 50) +breakpoint() +# Demonstrate mean pooling from temporal features back to spatial features +print("=== Testing Mean Pooling ===") + +# Use the transformer output or the original temporal features for pooling demonstration +print(f"Input temporal features shape: {spatial_node_attr_all_temporal.shape}") + +# Create mean pooling module and apply it +temporal_to_spatial_pooler = TemporalToSpatialNodeAttrMean() +spatial_features_pooled = temporal_to_spatial_pooler(spatial_node_attr_all_temporal, temporal_batch) + +print(f"Output spatial features shape: {spatial_features_pooled.shape}") +print(f"Number of spatial nodes recovered: {spatial_features_pooled.shape[0]}") +print(f"Original spatial nodes: {batch.pos.shape[0]}") +print(f"Feature dimension: {spatial_features_pooled.shape[1]}") +print(f"Sample pooled features (first node): {spatial_features_pooled[0]}") +print(f"Pooled features norm: {torch.norm(spatial_features_pooled):.6f}") + +# Verify that we correctly recovered the spatial dimension +assert spatial_features_pooled.shape[0] == batch.pos.shape[0], ( + f"Spatial node count mismatch: {spatial_features_pooled.shape[0]} vs {batch.pos.shape[0]}" +) +print("✅ Mean pooling successfully converted temporal features back to spatial!") + +print("\n" + "=" * 50) +print("TESTS COMPLETED") +print("=" * 50) + +print("\n" + "=" * 50) +print("FINAL SUMMARY") +print("=" * 50) + +# Device summary +print("Device Summary:") +print(f" - Used device: {device}") +if torch.cuda.is_available(): + print(f" - CUDA memory allocated: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB") + print(f" - CUDA memory cached: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB") + +print("\nTest Results:") +print(f" - Manual workflow tests: {'✅ PASSED' if success else '❌ FAILED'}") + +if success: + print("\n🎉 ALL TESTS PASSED!") + print("The manual spatio-temporal workflow is working correctly.") + print("To test the unified E3SpatioTemporal model, run: python3 test_e3_spatiotemporal.py") +else: + print("\n⚠️ Some tests failed. Check the output above for details.") diff --git a/scratch/transformer/helpers.py b/scratch/transformer/helpers.py new file mode 100644 index 0000000..d06d705 --- /dev/null +++ b/scratch/transformer/helpers.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +Helper functions for creating network architectures used in transformer development. +""" + +import functools + +import e3tools +import numpy as np +import torch +import torch_geometric +from convert_spatiotemporal import spatial_to_temporal_graphs + +from jamun.model.arch.e3conv import E3Conv +from jamun.utils.average_squared_distance import compute_average_squared_distance + + +def compute_temporal_average_squared_distance_from_dataset( + dataset, num_samples: int = 100, verbose: bool = False +) -> float: + """ + Compute average squared distance between neighboring vertices in temporal graphs. + + Args: + dataset: Dataset containing spatial graphs with hidden states + num_samples: Number of samples to use for estimation + verbose: Whether to print verbose output + + Returns: + float: Average squared distance between temporal neighbors + """ + + avg_sq_dists = [] + num_graphs = 0 + + # Follow pattern from average_squared_distance.py + for item in dataset: + if num_graphs >= num_samples: + break + for graph in item: + if num_graphs >= num_samples: + break + # Convert to temporal graphs here + temporal_batch = spatial_to_temporal_graphs(graph) + temporal_graphs = torch_geometric.data.Batch.to_data_list(temporal_batch) + graph_mean = 0.0 + num_nodes = graph.pos.shape[0] + for temporal_graph in temporal_graphs: + avg_sq_dist = compute_average_squared_distance(temporal_graph.pos, cutoff=None) + graph_mean += avg_sq_dist / num_nodes + avg_sq_dists.append(graph_mean) + num_graphs += 1 + mean_avg_sq_dist = sum(avg_sq_dists) / num_graphs + + if verbose: + print(f"Total graphs processed: {num_graphs}") + print(f"Total temporal graphs processed: {len(avg_sq_dists)}") + print(f"Mean average squared distance between temporal nodes: {mean_avg_sq_dist:.6f}") + print(f"Standard deviation: {np.std(avg_sq_dists):.6f}") + + return float(mean_avg_sq_dist) + + +def add_edges( + y: torch.Tensor, + topology: torch_geometric.data.Batch, + batch: torch.Tensor, + radial_cutoff: float, +) -> torch_geometric.data.Batch: + """Add edges to the graph based on the effective radial cutoff.""" + if topology.get("edge_index") is not None: + return topology + + topology = topology.clone() + with torch.cuda.nvtx.range("radial_graph"): + radial_edge_index = e3tools.radius_graph(y, radial_cutoff, batch) + + with torch.cuda.nvtx.range("concatenate_edges"): + edge_index = torch.cat((radial_edge_index, topology.bonded_edge_index), dim=-1) + if topology.bonded_edge_index.numel() == 0: + bond_mask = torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.device) + else: + bond_mask = torch.cat( + ( + torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.device), + torch.ones(topology.bonded_edge_index.shape[1], dtype=torch.long, device=y.device), + ), + dim=0, + ) + + topology.edge_index = edge_index + topology.bond_mask = bond_mask + return topology + + +def apply_e3conv_to_positions(e3conv_model, pos, topology, batch, effective_radial_cutoff=5.0): + """ + Apply E3Conv model to a set of positions using existing graph topology. + + Args: + e3conv_model: E3Conv model instance + pos (torch.Tensor): Positions [N, 3] + topology (torch_geometric.data.Batch): Existing graph topology from dataloader + batch (torch.Tensor): Batch tensor from the graph + effective_radial_cutoff (float): Radial cutoff for edges + + Returns: + torch.Tensor: Node features [N, feature_dim] + """ + # Clone topology to avoid modifying original + topology_with_edges = topology.clone() + + # Add edges using the local add_edges function + topology_with_edges = add_edges(pos, topology_with_edges, batch, effective_radial_cutoff) + + # Use noise conditioning of 0.0 (no noise) + c_noise = torch.zeros(pos.shape[0], dtype=pos.dtype, device=pos.device) + + # Apply E3Conv + num_graphs = batch.max().item() + 1 # Number of graphs in the batch + node_features = e3conv_model( + pos=pos, + topology=topology_with_edges, + batch=batch, + num_graphs=num_graphs, + c_noise=c_noise, + effective_radial_cutoff=effective_radial_cutoff, + ) + + return node_features + + +def create_e3conv_network(): + """ + Create an E3Conv network with parameters matching the yaml configuration. + + Returns: + E3Conv: Configured E3Conv network + """ + + # Hidden layer factory as specified in yaml + hidden_layer_factory = functools.partial(e3tools.nn.ConvBlock, conv=functools.partial(e3tools.nn.Conv)) + + # Output head factory as specified in yaml + output_head_factory = functools.partial( + e3tools.nn.EquivariantMLP, + irreps_hidden_list=["120x0e + 32x1e"], # Using irreps_hidden from yaml + ) + + # Create E3Conv with exact parameters from yaml + e3conv = E3Conv( + irreps_out="1x1e", # 3D vector output + irreps_hidden="120x0e + 32x1e", # Hidden representations + irreps_sh="1x0e + 1x1e", # Spherical harmonics + hidden_layer_factory=hidden_layer_factory, + output_head_factory=output_head_factory, + use_residue_information=True, # Assuming True, matches yaml ${data.use_residue_information} + n_layers=1, # Number of layers + edge_attr_dim=64, # Edge attribute dimension + atom_type_embedding_dim=8, # Atom type embedding + atom_code_embedding_dim=8, # Atom code embedding + residue_code_embedding_dim=32, # Residue code embedding + residue_index_embedding_dim=8, # Residue index embedding + use_residue_sequence_index=False, # As specified in yaml + num_atom_types=20, # Number of atom types + max_sequence_length=10, # Max sequence length + num_atom_codes=10, # Number of atom codes + num_residue_types=25, # Number of residue types + test_equivariance=False, # Disable for production + reduce=None, # No reduction + ) + + return e3conv + + +def get_e3conv_output_irreps(): + """ + Get the output irreps of the E3Conv network. + + Returns: + str: Output irreps string + """ + return "1x1e" # 3D vector output as specified in yaml diff --git a/scratch/transformer/pooling.py b/scratch/transformer/pooling.py new file mode 100644 index 0000000..418aed8 --- /dev/null +++ b/scratch/transformer/pooling.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +""" +Lightning modules for converting node attributes between spatial and temporal representations. +""" + +import pytorch_lightning as pl +import torch + + +class SpatialToTemporalNodeAttr(pl.LightningModule): + """ + Lightning module to transfer node attributes from spatial nodes to temporal nodes + by repeating first temporal feature. + """ + + def __init__(self): + super().__init__() + + def forward(self, spatial_node_attr_temporal, temporal_batch): + """ + Transfer node attributes from spatial nodes to temporal nodes by repeating first temporal feature. + Takes the first temporal feature (t=0) and repeats it T times for each spatial node. + + Args: + spatial_node_attr_temporal (torch.Tensor): Node attributes [N_spatial, T, attr_dim] + temporal_batch (torch_geometric.data.Batch): Batch of temporal graphs + + Returns: + torch.Tensor: Node attributes for temporal nodes [N_temporal, attr_dim] + """ + num_spatial_nodes, temporal_length, attr_dim = spatial_node_attr_temporal.shape + num_temporal_graphs = temporal_batch.num_graphs + + # Verify consistency + assert num_spatial_nodes == num_temporal_graphs, ( + f"Mismatch: {num_spatial_nodes} spatial nodes vs {num_temporal_graphs} temporal graphs" + ) + + # Verify temporal length consistency + expected_temporal_nodes = temporal_batch.pos.shape[0] + expected_total_nodes = num_spatial_nodes * temporal_length + assert expected_total_nodes == expected_temporal_nodes, ( + f"Temporal length mismatch: {expected_total_nodes} vs {expected_temporal_nodes}" + ) + + # Extract first temporal feature (t=0) and repeat it T times for each spatial node + first_temporal_features = spatial_node_attr_temporal[:, 0, :] # [N, attr_dim] + + # Repeat each spatial node's first temporal feature T times + temporal_node_attr = first_temporal_features.repeat_interleave(temporal_length, dim=0) # [N*T, attr_dim] + + # Verify the output shape matches the temporal batch + assert temporal_node_attr.shape[0] == expected_temporal_nodes, ( + f"Output shape mismatch: {temporal_node_attr.shape[0]} vs expected {expected_temporal_nodes}" + ) + + return temporal_node_attr + + +class TemporalToSpatialNodeAttr(pl.LightningModule): + """ + Lightning module to convert temporal node attributes back to spatial node attributes. + Takes the first temporal node attribute from each temporal graph. + """ + + def __init__(self): + super().__init__() + + def forward(self, temporal_node_attr, temporal_batch): + """ + Convert temporal node attributes back to spatial node attributes. + Takes the first temporal node attribute from each temporal graph. + + Args: + temporal_node_attr (torch.Tensor): Node attributes for temporal nodes [N_temporal, attr_dim] + temporal_batch (torch_geometric.data.Batch): Batch of temporal graphs + + Returns: + torch.Tensor: Node attributes for spatial nodes [N_spatial, attr_dim] + """ + num_temporal_graphs = temporal_batch.num_graphs + attr_dim = temporal_node_attr.shape[1] + + # Extract the first node attribute from each temporal graph + spatial_node_attr = [] + + for graph_idx in range(num_temporal_graphs): + # Get the node range for this temporal graph + start_idx = temporal_batch.ptr[graph_idx] + + # The 0th node of each temporal graph is at the start of its range + first_node_attr = temporal_node_attr[start_idx] + spatial_node_attr.append(first_node_attr) + + # Stack to create spatial node attribute tensor + spatial_node_attr = torch.stack(spatial_node_attr) + + # Verify output shape + assert spatial_node_attr.shape == (num_temporal_graphs, attr_dim), ( + f"Output shape mismatch: {spatial_node_attr.shape} vs expected ({num_temporal_graphs}, {attr_dim})" + ) + + return spatial_node_attr + + +class TemporalToSpatialNodeAttrMean(pl.LightningModule): + """ + Lightning module to convert temporal node attributes back to spatial node attributes by averaging. + Takes the mean of all temporal node attributes for each temporal graph. + """ + + def __init__(self): + super().__init__() + + def forward(self, temporal_node_attr, temporal_batch): + """ + Convert temporal node attributes back to spatial node attributes by averaging. + Takes the mean of all temporal node attributes for each temporal graph. + + Args: + temporal_node_attr (torch.Tensor): Node attributes for temporal nodes [N_temporal, attr_dim] + temporal_batch (torch_geometric.data.Batch): Batch of temporal graphs + + Returns: + torch.Tensor: Node attributes for spatial nodes [N_spatial, attr_dim] + """ + num_temporal_graphs = temporal_batch.num_graphs + attr_dim = temporal_node_attr.shape[1] + + # Extract the mean node attributes from each temporal graph + spatial_node_attr = [] + + for graph_idx in range(num_temporal_graphs): + # Get the node range for this temporal graph + start_idx = temporal_batch.ptr[graph_idx] + end_idx = ( + temporal_batch.ptr[graph_idx + 1] + if graph_idx + 1 < len(temporal_batch.ptr) + else len(temporal_node_attr) + ) + + # Take the mean of all temporal nodes for this spatial node + temporal_nodes_attr = temporal_node_attr[start_idx:end_idx] # [temporal_length, attr_dim] + mean_node_attr = temporal_nodes_attr.mean(dim=0) # [attr_dim] + spatial_node_attr.append(mean_node_attr) + + # Stack to create spatial node attribute tensor + spatial_node_attr = torch.stack(spatial_node_attr) + + # Verify output shape + assert spatial_node_attr.shape == (num_temporal_graphs, attr_dim), ( + f"Output shape mismatch: {spatial_node_attr.shape} vs expected ({num_temporal_graphs}, {attr_dim})" + ) + + return spatial_node_attr + + +class SpatialTemporalToTemporalNodeAttr(pl.LightningModule): + """ + Lightning module to convert spatial node attributes arranged temporally to temporal node attributes. + Converts from [N, T, features] to [NT, features] with correct temporal graph ordering. + """ + + def __init__(self): + super().__init__() + + def forward(self, spatial_node_attr_temporal, temporal_batch): + """ + Convert spatial node attributes arranged temporally to temporal node attributes. + Converts from [N, T, features] to [NT, features] with correct temporal graph ordering. + + Args: + spatial_node_attr_temporal (torch.Tensor): Node attributes [N_spatial, T, attr_dim] + temporal_batch (torch_geometric.data.Batch): Batch of temporal graphs for validation + + Returns: + torch.Tensor: Node attributes for temporal nodes [N_temporal, attr_dim] + """ + num_spatial_nodes, temporal_length, attr_dim = spatial_node_attr_temporal.shape + num_temporal_graphs = temporal_batch.num_graphs + + # Verify consistency with temporal batch + assert num_spatial_nodes == num_temporal_graphs, ( + f"Mismatch: {num_spatial_nodes} spatial nodes vs {num_temporal_graphs} temporal graphs" + ) + + # Verify temporal length consistency + expected_temporal_nodes = temporal_batch.pos.shape[0] + expected_total_nodes = num_spatial_nodes * temporal_length + assert expected_total_nodes == expected_temporal_nodes, ( + f"Temporal length mismatch: {expected_total_nodes} vs {expected_temporal_nodes}" + ) + + # Reshape to match temporal graph ordering: [N, T, features] -> [N*T, features] + # Temporal graph arranges nodes as: [node0_t0, node0_t1, ..., node0_tT-1, node1_t0, ...] + temporal_node_attr = spatial_node_attr_temporal.reshape(num_spatial_nodes * temporal_length, attr_dim) + + # Verify the output shape matches the temporal batch + assert temporal_node_attr.shape[0] == expected_temporal_nodes, ( + f"Output shape mismatch: {temporal_node_attr.shape[0]} vs expected {expected_temporal_nodes}" + ) + + return temporal_node_attr + + +# Legacy function interfaces for backward compatibility +def spatial_to_temporal_node_attr(spatial_node_attr_temporal, temporal_batch): + """Legacy function interface for backward compatibility.""" + module = SpatialToTemporalNodeAttr() + return module(spatial_node_attr_temporal, temporal_batch) + + +def temporal_to_spatial_node_attr(temporal_node_attr, temporal_batch): + """Legacy function interface for backward compatibility.""" + module = TemporalToSpatialNodeAttr() + return module(temporal_node_attr, temporal_batch) + + +def temporal_to_spatial_node_attr_mean(temporal_node_attr, temporal_batch): + """Legacy function interface for backward compatibility.""" + module = TemporalToSpatialNodeAttrMean() + return module(temporal_node_attr, temporal_batch) + + +def spatial_temporal_to_temporal_node_attr(spatial_node_attr_temporal, temporal_batch): + """Legacy function interface for backward compatibility.""" + module = SpatialTemporalToTemporalNodeAttr() + return module(spatial_node_attr_temporal, temporal_batch) diff --git a/scratch/transformer/temporal_transformer.py b/scratch/transformer/temporal_transformer.py new file mode 100644 index 0000000..2824943 --- /dev/null +++ b/scratch/transformer/temporal_transformer.py @@ -0,0 +1,298 @@ +import e3nn +import e3tools +import e3tools.nn +import torch +import torch.nn as nn +import torch_geometric.data +from e3nn import o3 + + +class E3Transformer(nn.Module): + """E(3)-equivariant transformer with temporal graph support.""" + + def __init__( + self, + irreps_out: str | e3nn.o3.Irreps, + irreps_hidden: str | e3nn.o3.Irreps, + irreps_sh: str | e3nn.o3.Irreps, + irreps_node_attr: str | e3nn.o3.Irreps, + num_layers: int, + edge_attr_dim: int, + num_attention_heads: int, + reduce: str | None = None, + ): + super().__init__() + + self.irreps_out = o3.Irreps(irreps_out) + self.irreps_hidden = o3.Irreps(irreps_hidden) + self.irreps_sh = o3.Irreps(irreps_sh) + self.irreps_node_attr = o3.Irreps(irreps_node_attr) # input irreps + self.num_layers = num_layers + self.edge_attr_dim = edge_attr_dim + self.num_attention_heads = num_attention_heads + self.reduce = reduce + self.sh = o3.SphericalHarmonics(irreps_out=self.irreps_sh, normalize=True, normalization="component") + # Split edge attribute dimensions: radial and temporal (bondedness is optional) + self.radial_edge_attr_dim = self.edge_attr_dim // 2 + self.temporal_edge_attr_dim = self.edge_attr_dim - self.radial_edge_attr_dim + + # Optional bondedness embedding (only used if bond_mask exists in graph) + self.embed_bondedness = nn.Embedding(2, self.edge_attr_dim // 3) + + # Gate for combining node attributes with temporal position + # Input: node_attr (from data) + temporal_position (1x0e scalar) + irreps_with_temporal = self.irreps_node_attr + o3.Irreps("1x0e") + self.temporal_gate = e3tools.nn.GateWrapper( + irreps_in=irreps_with_temporal, + irreps_out=self.irreps_hidden, + irreps_gate=irreps_with_temporal, + ) + # self.initial_linear = o3.Linear( + # self.temporal_gate.irreps_out, self.irreps_hidden + # ) + + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + e3tools.nn.TransformerBlock( + irreps_in=self.irreps_hidden, + irreps_out=self.irreps_hidden, + irreps_sh=self.irreps_sh, + edge_attr_dim=self.edge_attr_dim, + num_heads=self.num_attention_heads, + ) + ) + self.output_head = e3tools.nn.EquivariantMLP( + irreps_in=self.irreps_hidden, + irreps_out=self.irreps_out, + irreps_hidden_list=[self.irreps_hidden], + ) + + def forward( + self, + node_attr: torch.Tensor, + temporal_graph: torch_geometric.data.Batch, + effective_radial_cutoff: float, + temporal_cutoff: float = 1.0, + ) -> torch.Tensor: + """Forward pass of the E3Transformer model.""" + # Extract graph data + pos = temporal_graph.pos + edge_index = temporal_graph.edge_index + temporal_position = temporal_graph.temporal_position + batch = temporal_graph.batch + num_graphs = temporal_graph.num_graphs + + src, dst = edge_index + edge_vec = pos[src] - pos[dst] + edge_sh = self.sh(edge_vec) + + # Compute edge attributes: radial and temporal + radial_edge_attr = e3nn.math.soft_one_hot_linspace( + edge_vec.norm(dim=1), + 0.0, + effective_radial_cutoff, + self.radial_edge_attr_dim, + basis="gaussian", + cutoff=True, + ) + + # Temporal edge attributes from temporal_position differences + temporal_edge_vec = temporal_position[src] - temporal_position[dst] + temporal_edge_attr = e3nn.math.soft_one_hot_linspace( + temporal_edge_vec.abs(), # Use absolute difference + 0.0, + temporal_cutoff, + self.temporal_edge_attr_dim, + basis="gaussian", + cutoff=True, + ) + + # Optional bondedness (if bond_mask exists in the temporal graph) + if hasattr(temporal_graph, "bond_mask") and temporal_graph.bond_mask is not None: + bonded_edge_attr = self.embed_bondedness(temporal_graph.bond_mask) + edge_attr = torch.cat((bonded_edge_attr, radial_edge_attr, temporal_edge_attr), dim=-1) + else: + edge_attr = torch.cat((radial_edge_attr, temporal_edge_attr), dim=-1) + + # Process node attributes with temporal gating + + # Concatenate node_attr with temporal_position (scalar) + temporal_position_expanded = temporal_position.unsqueeze(-1) # [N, 1] for concatenation + node_attr_with_temporal = torch.cat([node_attr, temporal_position_expanded], dim=-1) + + # Apply temporal gate + node_attr_processed = self.temporal_gate(node_attr_with_temporal) + # node_attr_processed = self.initial_linear(node_attr_gated) + + # Perform message passing with gated node attributes + for layer in self.layers: + node_attr_processed = layer(node_attr_processed, edge_index, edge_attr, edge_sh) + node_attr_processed = self.output_head(node_attr_processed) + + # Pool over nodes. + if self.reduce is not None: + node_attr_processed = e3tools.scatter( + node_attr_processed, + index=batch, + dim=0, + dim_size=num_graphs, + reduce=self.reduce, + ) + + return node_attr_processed + + +class E3SpatioTemporal(nn.Module): + """ + E(3)-equivariant spatio-temporal model that combines spatial and temporal processing. + + This model implements the complete workflow: + 1. Process input spatial graph and hidden states through spatial module + 2. Pool spatial features to temporal graph representation + 3. Process temporal graph through temporal module + 4. Pool temporal features back to spatial representation + 5. Convert temporal graph back to spatial graph + """ + + def __init__( + self, + spatial_module: nn.Module, + temporal_module: nn.Module, + spatial_to_temporal_pooler: nn.Module, + temporal_to_spatial_pooler: nn.Module, + radial_cutoff: float, + temporal_cutoff: float = 1.0, + ): + """ + Initialize the E3SpatioTemporal model. + + Args: + spatial_module: Module for processing spatial positions (e.g., E3Conv) + temporal_module: Module for processing temporal graphs (e.g., E3Transformer) + spatial_to_temporal_pooler: Module to convert spatial-temporal features to temporal node attributes + temporal_to_spatial_pooler: Module to convert temporal features back to spatial features + radial_cutoff: Cutoff for spatial radial edge weights + temporal_cutoff: Cutoff for temporal edge weights + """ + super().__init__() + + self.spatial_module = spatial_module + self.temporal_module = temporal_module + self.spatial_to_temporal_pooler = spatial_to_temporal_pooler + self.temporal_to_spatial_pooler = temporal_to_spatial_pooler + self.radial_cutoff = radial_cutoff + self.temporal_cutoff = temporal_cutoff + + def forward( + self, + batch: torch_geometric.data.Batch, + c_noise: torch.Tensor, + return_temporal_features: bool = False, + return_temporal_graph: bool = False, + ) -> torch.Tensor | dict[str, torch.Tensor]: + """ + Forward pass implementing the complete spatio-temporal workflow. + + Args: + batch: Input spatial graph batch with pos, batch, num_graphs, and optionally hidden_state + c_noise: Noise conditioning tensor + return_temporal_features: Whether to return intermediate temporal features + return_temporal_graph: Whether to return the temporal graph + + Returns: + If return_temporal_features or return_temporal_graph is True, returns dict with: + - 'spatial_features': Final spatial features + - 'spatial_graph': Output spatial graph + - 'temporal_features': Temporal features (if requested) + - 'temporal_graph': Temporal graph (if requested) + Otherwise returns just the final spatial features tensor + """ + from convert_spatiotemporal import spatial_to_temporal_graphs, temporal_to_spatial_graphs + + # Store original device + + # Step 1: Convert spatial graph to temporal graphs + temporal_batch = spatial_to_temporal_graphs(batch) + + # Step 2: Process all positions (current + hidden states) with spatial module + # Create topology for spatial processing (without positions) + topology = batch.clone() + # Remove position-dependent attributes but keep graph structure + if hasattr(topology, "pos"): + del topology.pos + if hasattr(topology, "batch"): + del topology.batch + if hasattr(topology, "num_graphs"): + del topology.num_graphs + + node_attr_list = [] + + # Process current positions + node_attr_current = self.spatial_module( + batch.pos, + topology, + batch.batch, + num_graphs=batch.num_graphs, + c_noise=c_noise, + effective_radial_cutoff=self.radial_cutoff, + ).unsqueeze(1) # [N, 1, features] + node_attr_list.append(node_attr_current) + + # Process hidden state positions if they exist + if hasattr(batch, "hidden_state") and batch.hidden_state is not None and len(batch.hidden_state) > 0: + for hidden_pos in batch.hidden_state: + node_attr_hidden = self.spatial_module( + hidden_pos, + topology, + batch.batch, + num_graphs=batch.num_graphs, + c_noise=c_noise, + effective_radial_cutoff=self.radial_cutoff, + ).unsqueeze(1) # [N, 1, features] + node_attr_list.append(node_attr_hidden) + + # Step 3: Stack spatial-temporal features + node_attr_spatial_temporal = torch.cat(node_attr_list, dim=1) # [N, T, features] + + # Step 4: Convert spatial-temporal features to temporal node attributes + temporal_node_attr = self.spatial_to_temporal_pooler(node_attr_spatial_temporal, temporal_batch) + + # Step 5: Process temporal graph through temporal module + temporal_output = self.temporal_module( + temporal_node_attr, temporal_batch, self.radial_cutoff, self.temporal_cutoff + ) + + # Step 6: Pool temporal features back to spatial features + spatial_features = self.temporal_to_spatial_pooler(temporal_output, temporal_batch) + + # Step 7: Convert temporal graph back to spatial graph + output_spatial_graph = temporal_to_spatial_graphs(temporal_batch) + + # Prepare return values + if return_temporal_features or return_temporal_graph: + result = { + "spatial_features": spatial_features, + "spatial_graph": output_spatial_graph, + } + if return_temporal_features: + result["temporal_features"] = temporal_output + if return_temporal_graph: + result["temporal_graph"] = temporal_batch + return result + else: + return spatial_features + + def get_spatial_output_irreps(self): + """Get the irreps of the spatial module output.""" + if hasattr(self.spatial_module, "irreps_out"): + return self.spatial_module.irreps_out + else: + raise AttributeError("Spatial module does not have irreps_out attribute") + + def get_temporal_output_irreps(self): + """Get the irreps of the temporal module output.""" + if hasattr(self.temporal_module, "irreps_out"): + return self.temporal_module.irreps_out + else: + raise AttributeError("Temporal module does not have irreps_out attribute") diff --git a/scratch/transformer/test_e3_spatiotemporal.py b/scratch/transformer/test_e3_spatiotemporal.py new file mode 100644 index 0000000..c6a65a7 --- /dev/null +++ b/scratch/transformer/test_e3_spatiotemporal.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +""" +Test script for the E3SpatioTemporal model. + +This script tests the unified E3SpatioTemporal model that encapsulates +the complete spatio-temporal processing workflow. +""" + +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) +import torch +import torch_geometric + +# Import modules needed for the test +from helpers import add_edges, create_e3conv_network +from pooling import SpatialTemporalToTemporalNodeAttr, TemporalToSpatialNodeAttrMean +from temporal_transformer import E3SpatioTemporal, E3Transformer + +from jamun.data import parse_datasets_from_directory +from jamun.utils import unsqueeze_trailing + + +def setup_device(): + """Setup CUDA device if available.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + if torch.cuda.is_available(): + print(f"CUDA device: {torch.cuda.get_device_name()}") + print(f"CUDA memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.1f} GB") + return device + + +def move_graph_to_device(graph, device): + """Move a PyTorch Geometric graph and all its tensor attributes to device.""" + # Move the graph using standard .to() method + graph = graph.to(device) + + # Manually move any custom tensor attributes that might not be handled + for attr_name in dir(graph): + if not attr_name.startswith("_"): # Skip private attributes + attr_value = getattr(graph, attr_name, None) + if isinstance(attr_value, torch.Tensor): + setattr(graph, attr_name, attr_value.to(device)) + + return graph + + +def load_test_data(device): + """Load and prepare test data.""" + print("Loading test data...") + + dataset = parse_datasets_from_directory( + root="/data2/sules/ALA_ALA_enhanced_full_grid/train", + traj_pattern="^(.*).xtc", + pdb_pattern="^(.*).pdb", + subsample=1, + total_lag_time=5, + lag_subsample_rate=1, + max_datasets=3, + num_frames=10, + ) + + # Get first graph and create batch + graph = dataset[0].__getitem__(0) + batch = torch_geometric.data.Batch.from_data_list([graph]) + batch = add_edges(batch.pos, batch, batch.batch, 0.05) + + # Move to device + batch = move_graph_to_device(batch, device) + print(f"Loaded batch with {batch.pos.shape[0]} nodes on device: {batch.pos.device}") + + return batch + + +def create_spatiotemporal_model(device): + """Create and configure the E3SpatioTemporal model.""" + print("Creating E3SpatioTemporal model...") + + # Create component modules + spatial_module = create_e3conv_network().to(device) + + temporal_module = E3Transformer( + irreps_out="3x1e", # 3D output (like positions) + irreps_hidden="8x0e + 4x1e", # Hidden representations + irreps_sh="1x0e + 1x1e", # Spherical harmonics + irreps_node_attr="1x1e", # Input node attributes match E3Conv output + num_layers=2, + edge_attr_dim=24, # Split into 2 parts: 12+12 (radial+temporal) + num_attention_heads=1, # Single attention head for simpler test + ).to(device) + + spatial_to_temporal_pooler = SpatialTemporalToTemporalNodeAttr() + temporal_to_spatial_pooler = TemporalToSpatialNodeAttrMean() + + # Create the unified model + spatiotemporal_model = E3SpatioTemporal( + spatial_module=spatial_module, + temporal_module=temporal_module, + spatial_to_temporal_pooler=spatial_to_temporal_pooler, + temporal_to_spatial_pooler=temporal_to_spatial_pooler, + radial_cutoff=0.05, + temporal_cutoff=1.0, + ).to(device) + + print(f"Created E3SpatioTemporal model on device: {next(spatiotemporal_model.parameters()).device}") + return spatiotemporal_model + + +def test_spatiotemporal_model(model, batch, device): + """Test the E3SpatioTemporal model with various configurations.""" + print("=" * 50) + print("TESTING E3SPATIOTEMPORAL MODEL") + print("=" * 50) + + # Print model information + print("Model components:") + print(f" - Spatial module output irreps: {model.get_spatial_output_irreps()}") + print(f" - Temporal module output irreps: {model.get_temporal_output_irreps()}") + print(f" - Radial cutoff: {model.radial_cutoff}") + print(f" - Temporal cutoff: {model.temporal_cutoff}") + + # Prepare noise conditioning + sigma = torch.tensor(0.0, device=device) + sigma = unsqueeze_trailing(sigma, 1) + + print("\nInput batch:") + print(f" - pos shape: {batch.pos.shape}") + print( + f" - hidden_state length: {len(batch.hidden_state) if hasattr(batch, 'hidden_state') and batch.hidden_state else 0}" + ) + print(f" - batch device: {batch.pos.device}") + + success = True + + try: + with torch.no_grad(): + print("\n1. Testing simple forward pass (spatial features only)...") + spatial_features = model(batch, sigma) + print(f" ✅ Success! Spatial features shape: {spatial_features.shape}") + print(f" Spatial features device: {spatial_features.device}") + print(f" Spatial features norm: {torch.norm(spatial_features):.6f}") + + print("\n2. Testing full forward pass (all outputs)...") + results = model(batch, sigma, return_temporal_features=True, return_temporal_graph=True) + + print(" ✅ Success! Full output results:") + print(f" - spatial_features shape: {results['spatial_features'].shape}") + print(f" - temporal_features shape: {results['temporal_features'].shape}") + print(f" - temporal_graph num_graphs: {results['temporal_graph'].num_graphs}") + print(f" - spatial_graph pos shape: {results['spatial_graph'].pos.shape}") + + # Verify spatial graph reconstruction + pos_difference = torch.norm(results["spatial_graph"].pos - batch.pos) + print(f" - Spatial position reconstruction error: {pos_difference:.6f}") + + print("\n3. Testing consistency between simple and full forward pass...") + simple_features = spatial_features + full_features = results["spatial_features"] + consistency_error = torch.norm(simple_features - full_features) + print(f" Consistency error: {consistency_error:.6f}") + + if consistency_error < 1e-6: + print(" ✅ Results are consistent!") + else: + print(" ⚠️ Results differ between simple and full forward pass") + success = False + + except Exception as e: + print(f"\n❌ Test failed: {e}") + import traceback + + traceback.print_exc() + success = False + + return success + + +def main(): + """Main test function.""" + print("E3SpatioTemporal Model Test") + print("=" * 50) + + # Setup + device = setup_device() + batch = load_test_data(device) + model = create_spatiotemporal_model(device) + + # Run tests + success = test_spatiotemporal_model(model, batch, device) + + # Final summary + print("\n" + "=" * 50) + print("FINAL RESULTS") + print("=" * 50) + + # Device summary + print("Device Summary:") + print(f" - Used device: {device}") + if torch.cuda.is_available(): + print(f" - CUDA memory allocated: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB") + print(f" - CUDA memory cached: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB") + + print("\nTest Results:") + if success: + print("🎉 ALL TESTS PASSED! The E3SpatioTemporal model works correctly!") + print("\nThe model successfully:") + print(" ✅ Processes spatial graphs with hidden states") + print(" ✅ Converts to temporal representation") + print(" ✅ Applies temporal transformations") + print(" ✅ Pools back to spatial features") + print(" ✅ Reconstructs spatial graphs") + print(" ✅ Maintains consistency across different call patterns") + else: + print("❌ SOME TESTS FAILED! Check the output above for details.") + + return success + + +if __name__ == "__main__": + main() diff --git a/scratch/transformer/test_spatiotemporal_conditioner.py b/scratch/transformer/test_spatiotemporal_conditioner.py new file mode 100755 index 0000000..b525237 --- /dev/null +++ b/scratch/transformer/test_spatiotemporal_conditioner.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +""" +Test script for loading and testing a conditional denoiser with spatiotemporal conditioner. +Uses the new approach where SpatioTemporalConditioner outputs [y.pos, spatial_features] +and E3ConvConditionalSpatioTemporal handles concatenated inputs. +""" + +import e3nn + +e3nn.set_optimization_defaults(jit_script_fx=False) + +import sys +from typing import Any + +import torch +import torch_geometric + +# Add the src directory to path to import jamun modules +sys.path.insert(0, "src") + +from jamun.data import parse_datasets_from_directory +from jamun.distributions._distributions import ConstantSigma +from jamun.model.arch.e3conv import E3Conv +from jamun.model.arch.e3conv_conditional import ( + E3ConvConditionalSpatioTemporal, # Changed from E3ConvConditionalWithInputAttr +) +from jamun.model.arch.spatiotemporal import E3SpatioTemporal, E3Transformer +from jamun.model.conditioners.conditioners import SpatioTemporalConditioner +from jamun.model.denoiser_conditional import Denoiser # Changed from DenoiserWithInputAttr +from jamun.model.pooling import SpatialTemporalToTemporalNodeAttr, TemporalToSpatialNodeAttrMean +from jamun.utils.average_squared_distance import ( + compute_temporal_average_squared_distance_from_datasets, # Import temporal function +) + +# Setup device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + + +def create_spatial_module() -> E3Conv: + """Create E3Conv spatial module with reasonable parameters.""" + import functools + + import e3tools + + # Create factory functions + hidden_layer_factory = functools.partial(e3tools.nn.ConvBlock, conv=functools.partial(e3tools.nn.Conv)) + + output_head_factory = functools.partial(e3tools.nn.EquivariantMLP, irreps_hidden_list=["120x0e + 32x1e"]) + + return E3Conv( + irreps_out="3x1e", # Changed to match temporal module input + irreps_hidden="120x0e + 32x1e", + irreps_sh="1x0e + 1x1e", + hidden_layer_factory=hidden_layer_factory, + output_head_factory=output_head_factory, + n_layers=1, + edge_attr_dim=64, + use_residue_information=True, + atom_type_embedding_dim=8, + atom_code_embedding_dim=8, + residue_code_embedding_dim=32, + residue_index_embedding_dim=8, + use_residue_sequence_index=False, + num_atom_types=20, + max_sequence_length=10, + num_atom_codes=10, + num_residue_types=25, + test_equivariance=False, + reduce=None, + ) + + +def create_temporal_module() -> E3Transformer: + """Create E3Transformer temporal module.""" + return E3Transformer( + irreps_out="3x1e", # Final spatial features output + irreps_hidden="8x0e + 4x1e", + irreps_sh="1x0e + 1x1e", + irreps_node_attr="3x1e", # Match spatial module output + num_layers=2, + edge_attr_dim=24, + num_attention_heads=1, + reduce=None, + ) + + +def create_spatiotemporal_model() -> E3SpatioTemporal: + """Create the complete E3SpatioTemporal model.""" + spatial_module = create_spatial_module() + temporal_module = create_temporal_module() + + # Create pooling modules + spatial_to_temporal_pooler = SpatialTemporalToTemporalNodeAttr(irreps_out="3x1e") # Match spatial module output + temporal_to_spatial_pooler = TemporalToSpatialNodeAttrMean(irreps_out="3x1e") # Match temporal module output + + # Compute radial cutoff using temporal average squared distance + print("Computing radial cutoff from temporal dataset...") + try: + # Load dataset to compute temporal average squared distance + dataset = parse_datasets_from_directory( + root="/data2/sules/ALA_ALA_enhanced_full_grid/train", + traj_pattern="^(.*).xtc", + pdb_pattern="^(.*).pdb", + subsample=1, + total_lag_time=5, + lag_subsample_rate=1, + max_datasets=2, # Keep small for testing + num_frames=5, # Small number of frames + ) + + # Compute temporal average squared distance + temporal_avg_sq_dist = compute_temporal_average_squared_distance_from_datasets( + [dataset], # Pass as list since function expects multiple datasets + num_samples=50, # Use fewer samples for testing + verbose=True, + ) + + # Use a multiple of the temporal average squared distance as the radial cutoff + # Typically we might use sqrt(temporal_avg_sq_dist) * some_factor + import math + + radial_cutoff = math.sqrt(temporal_avg_sq_dist) * 2.0 # Scale factor of 2.0 + print(f"Computed radial cutoff: {radial_cutoff:.6f} nm") + + except Exception as e: + print(f"Warning: Failed to compute temporal cutoff ({e}), using default value 0.05") + radial_cutoff = 0.05 + + return E3SpatioTemporal( + spatial_module=spatial_module, + temporal_module=temporal_module, + spatial_to_temporal_pooler=spatial_to_temporal_pooler, + temporal_to_spatial_pooler=temporal_to_spatial_pooler, + radial_cutoff=radial_cutoff, + temporal_cutoff=1.0, + ) + + +def create_spatiotemporal_conditioner() -> SpatioTemporalConditioner: + """Create SpatioTemporalConditioner with E3SpatioTemporal model.""" + spatiotemporal_model = create_spatiotemporal_model() + + return SpatioTemporalConditioner( + N_structures=1, # Changed to 2 for [y.pos, spatial_features] + spatiotemporal_model=spatiotemporal_model, + c_noise=0.0, + freeze_spatiotemporal_model=False, # Keep trainable + ) + + +def create_conditional_denoiser_config() -> dict[str, Any]: + """Create configuration for Denoiser with spatiotemporal conditioner.""" + import functools + + import e3tools.nn + + def create_arch(): + """Create the E3ConvConditionalSpatioTemporal architecture module.""" + # Hidden layer factory + hidden_layer_factory = functools.partial(e3tools.nn.ConvBlock, conv=functools.partial(e3tools.nn.Conv)) + + # Output head factory + output_head_factory = functools.partial(e3tools.nn.EquivariantMLP, irreps_hidden_list=["16x0e + 8x1e"]) + + return E3ConvConditionalSpatioTemporal( + irreps_out="1x1e", # Output should be 3 components (1x1e) to match position + irreps_hidden="16x0e + 8x1e", + irreps_sh="1x0e + 1x1e", + hidden_layer_factory=hidden_layer_factory, + output_head_factory=output_head_factory, + n_layers=2, + edge_attr_dim=32, + use_residue_information=True, + atom_type_embedding_dim=8, + atom_code_embedding_dim=8, + residue_code_embedding_dim=16, + residue_index_embedding_dim=8, + use_residue_sequence_index=False, + num_atom_types=20, + max_sequence_length=10, + num_atom_codes=10, + num_residue_types=25, + test_equivariance=False, + reduce=None, + N_structures=1, # Changed to 2 for [y.pos, spatial_features] + input_attr_irreps="3x1e", # spatial_features only (9 components = 3x1e) + ) + + def create_optim(params): + """Create the optimizer.""" + return torch.optim.Adam(params, lr=0.001) + + return { + # Required Denoiser parameters (changed from DenoiserWithInputAttr) + "arch": create_arch, + "optim": create_optim, + "sigma_distribution": ConstantSigma(sigma=0.1), + "max_radius": 1000.0, + "average_squared_distance": 10.0, # Dummy value for testing + "add_fixed_noise": False, + "add_fixed_ones": False, + "align_noisy_input_during_training": True, + "align_noisy_input_during_evaluation": True, + "mean_center": True, + "mirror_augmentation_rate": 0.0, + "bond_loss_coefficient": 1.0, + "normalization_type": "JAMUN", + "sigma_data": None, + "lr_scheduler_config": None, + "use_torch_compile": False, # Disable for testing + "torch_compile_kwargs": None, + "conditioner": create_spatiotemporal_conditioner(), + } + + +def add_edges_to_batch(batch: torch_geometric.data.Batch, cutoff: float = 0.05) -> torch_geometric.data.Batch: + """Add edges to batch using existing utility from denoiser.""" + # Use e3tools radius_graph directly since we don't need the full denoiser add_edges logic + import e3tools + + if hasattr(batch, "edge_index") and batch.edge_index is not None: + return batch + + # Add radius-based edges + edge_index = e3tools.radius_graph(batch.pos, cutoff, batch.batch) + batch.edge_index = edge_index + + # Add bonded edges if they exist + if hasattr(batch, "bonded_edge_index") and batch.bonded_edge_index is not None: + bond_mask = torch.cat( + [ + torch.zeros(edge_index.shape[1], dtype=torch.long, device=batch.pos.device), + torch.ones(batch.bonded_edge_index.shape[1], dtype=torch.long, device=batch.pos.device), + ] + ) + batch.edge_index = torch.cat([edge_index, batch.bonded_edge_index], dim=1) + batch.bond_mask = bond_mask + else: + batch.bond_mask = torch.zeros(edge_index.shape[1], dtype=torch.long, device=batch.pos.device) + + return batch + + +def load_test_data(): + """Load ALA_ALA test dataset.""" + print("Loading ALA_ALA dataset...") + + dataset = parse_datasets_from_directory( + root="/data2/sules/ALA_ALA_enhanced_full_grid/train", + traj_pattern="^(.*).xtc", + pdb_pattern="^(.*).pdb", + subsample=1, + total_lag_time=5, + lag_subsample_rate=1, + max_datasets=2, # Keep small for testing + num_frames=5, # Small number of frames + ) + + print(f"Loaded dataset with {len(dataset)} samples") + + # Get a sample and create batch + graph = dataset[0].__getitem__(0) + batch = torch_geometric.data.Batch.from_data_list([graph]) + + # Add edges + batch = add_edges_to_batch(batch, cutoff=0.05) + + # Move to device + batch = batch.to(device) + + print("Batch info:") + print(f" - pos shape: {batch.pos.shape}") + print(f" - edge_index shape: {batch.edge_index.shape}") + print( + f" - hidden_state length: {len(batch.hidden_state) if hasattr(batch, 'hidden_state') and batch.hidden_state else 0}" + ) + if hasattr(batch, "hidden_state") and batch.hidden_state: + print(f" - hidden_state[0] shape: {batch.hidden_state[0].shape}") + + return batch + + +def test_spatiotemporal_conditioner(conditioner: SpatioTemporalConditioner, batch: torch_geometric.data.Batch): + """Test the spatiotemporal conditioner.""" + print("\n" + "=" * 50) + print("TESTING SPATIOTEMPORAL CONDITIONER") + print("=" * 50) + + try: + # Test forward pass + conditioned_structures = conditioner(batch) + + print("✅ Conditioner forward pass successful!") + print(f"Number of conditioned structures: {len(conditioned_structures)} (expected: 2)") + print(f"First structure (y.pos) shape: {conditioned_structures[0].shape}") + print(f"Second structure (spatial_features) shape: {conditioned_structures[1].shape}") + print(f"Original position shape: {batch.pos.shape}") + print(f"Position difference norm: {torch.norm(conditioned_structures[0] - batch.pos):.6f}") + + # Verify we got exactly two structures + assert len(conditioned_structures) == 2, f"Expected 2 structures, got {len(conditioned_structures)}" + + return True, conditioned_structures + + except Exception as e: + print(f"❌ Conditioner test failed: {e}") + import traceback + + traceback.print_exc() + return False, None + + +def test_conditional_denoiser_creation(): + """Test creating Denoiser with spatiotemporal conditioner.""" + print("\n" + "=" * 50) + print("TESTING DENOISER WITH SPATIOTEMPORAL CONDITIONER CREATION") + print("=" * 50) + + try: + # Create configuration + config = create_conditional_denoiser_config() + + # Create denoiser (this will instantiate all components) + denoiser = Denoiser(**config) + denoiser = denoiser.to(device) + + print("✅ Denoiser created successfully!") + print(f"Denoiser device: {next(denoiser.parameters()).device}") + print(f"Has conditioner: {hasattr(denoiser, 'conditioning_module')}") + print(f"Architecture type: {type(denoiser.g).__name__}") + print(f"Conditioner type: {type(denoiser.conditioning_module).__name__}") + + # Check if spatiotemporal model is properly set up + if hasattr(denoiser.conditioning_module, "spatiotemporal_model"): + st_model = denoiser.conditioning_module.spatiotemporal_model + print(f"SpatioTemporal model type: {type(st_model).__name__}") + print(f"Spatial module type: {type(st_model.spatial_module).__name__}") + print(f"Temporal module type: {type(st_model.temporal_module).__name__}") + + return True, denoiser + + except Exception as e: + print(f"❌ Denoiser creation failed: {e}") + import traceback + + traceback.print_exc() + return False, None + + +def test_denoiser_forward_pass(denoiser: Denoiser, batch: torch_geometric.data.Batch): + """Test the complete denoiser forward pass.""" + print("\n" + "=" * 50) + print("TESTING DENOISER WITH SPATIOTEMPORAL CONDITIONER FORWARD PASS") + print("=" * 50) + + try: + # Test with sigma = 0.1 + sigma = 0.1 + + # Debug: check conditioned structures shapes + conditioned_structures = denoiser.conditioning_module(batch) + print("DEBUG: Conditioned structures shapes:") + for i, struct in enumerate(conditioned_structures): + print(f" Structure {i}: {struct.shape}") + + concatenated = torch.cat([*conditioned_structures], dim=-1) + print(f"DEBUG: Concatenated shape: {concatenated.shape}") + print("DEBUG: Expected irreps: 4x1e = 12 components") + + with torch.no_grad(): + xhat_batch = denoiser.xhat(batch, sigma) + + print("✅ Denoiser forward pass successful!") + print(f"Input shape: {batch.pos.shape}") + print(f"Output shape: {xhat_batch.pos.shape}") + print(f"Output norm: {torch.norm(xhat_batch.pos):.6f}") + print(f"Used sigma: {sigma}") + + # Verify output shapes match input + assert xhat_batch.pos.shape == batch.pos.shape, f"Shape mismatch: {xhat_batch.pos.shape} vs {batch.pos.shape}" + + return True + + except Exception as e: + print(f"❌ Denoiser forward pass failed: {e}") + import traceback + + traceback.print_exc() + return False + + +def main(): + """Main test function.""" + print("=" * 60) + print("CONDITIONAL DENOISER WITH SPATIOTEMPORAL CONDITIONER TEST") + print("=" * 60) + + # Load test data + batch = load_test_data() + + # Test conditioner creation and forward pass + conditioner = create_spatiotemporal_conditioner() + conditioner = conditioner.to(device) + + conditioner_success, conditioned_structures = test_spatiotemporal_conditioner(conditioner, batch) + + if not conditioner_success: + print("❌ Conditioner test failed, stopping here.") + return + + # Test complete denoiser creation + denoiser_success, denoiser = test_conditional_denoiser_creation() + + if not denoiser_success: + print("❌ Denoiser creation failed, stopping here.") + return + + # Test complete forward pass + forward_success = test_denoiser_forward_pass(denoiser, batch) + + # Final summary + print("\n" + "=" * 60) + print("FINAL SUMMARY") + print("=" * 60) + + print("Test Results:") + print(f" - Conditioner test: {'✅ PASSED' if conditioner_success else '❌ FAILED'}") + print(f" - Denoiser creation: {'✅ PASSED' if denoiser_success else '❌ FAILED'}") + print(f" - Forward pass test: {'✅ PASSED' if forward_success else '❌ FAILED'}") + + if conditioner_success and denoiser_success and forward_success: + print("\n🎉 ALL TESTS PASSED!") + print("The conditional denoiser with spatiotemporal conditioner is working correctly!") + else: + print("\n⚠️ Some tests failed. Check the output above for details.") + + # Device memory summary + if torch.cuda.is_available(): + print("\nCUDA Memory:") + print(f" - Allocated: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB") + print(f" - Cached: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/scratch/visualize_fake_enhanced_data.py b/scratch/visualize_fake_enhanced_data.py new file mode 100644 index 0000000..9189cf3 --- /dev/null +++ b/scratch/visualize_fake_enhanced_data.py @@ -0,0 +1,255 @@ +import glob +import itertools +import os +import re +from collections import defaultdict + +import matplotlib.colors as colors +import matplotlib.pyplot as plt +import mdtraj as md +import numpy as np +from tqdm import tqdm + + +def parse_grid_code_from_filename(filename): + """ + Parse grid code from trajectory filename of format traj_{grid_code}_{traj_code}.xtc + """ + basename = os.path.basename(filename) + match = re.match(r"^traj_(\d+)_(\d+)\.xtc$", basename) + if match: + return int(match.group(1)), int(match.group(2)) + return None, None + + +def select_trajectories_with_max_per_grid(traj_files, max_traj_per_grid): + """ + Select trajectories ensuring no grid code has more than max_traj_per_grid trajectories. + """ + grid_trajectories = defaultdict(list) + + # Group trajectories by grid code + for traj_file in traj_files: + grid_code, traj_code = parse_grid_code_from_filename(traj_file) + if grid_code is not None: + grid_trajectories[grid_code].append((traj_file, traj_code)) + + print(f"Found {len(grid_trajectories)} unique grid codes") + + # Limit trajectories per grid code + selected_files = [] + grid_stats = {} + + for grid_code, traj_list in grid_trajectories.items(): + # Sort by trajectory code for deterministic selection + traj_list.sort(key=lambda x: x[1]) + + # Select up to max_traj_per_grid trajectories + selected_count = min(len(traj_list), max_traj_per_grid) + selected_for_grid = traj_list[:selected_count] + + grid_stats[grid_code] = {"total": len(traj_list), "selected": selected_count} + + for traj_file, _ in selected_for_grid: + selected_files.append(traj_file) + + # Print statistics + print("\nGrid code statistics:") + print(f"Total grid codes: {len(grid_stats)}") + total_original = sum(stats["total"] for stats in grid_stats.values()) + total_selected = sum(stats["selected"] for stats in grid_stats.values()) + print(f"Total trajectories: {total_original} -> {total_selected}") + print(f"Max trajectories per grid: {max_traj_per_grid}") + + # Show distribution + selected_counts = [stats["selected"] for stats in grid_stats.values()] + print("Distribution of selected trajectories per grid:") + for count in sorted(set(selected_counts)): + num_grids = sum(1 for c in selected_counts if c == count) + print(f" {count} trajectories: {num_grids} grid codes") + + return sorted(selected_files) + + +def create_ramachandran_plot(traj_path, topology, output_dir): + """ + Loads a trajectory, computes phi and psi angles, and saves a Ramachandran plot. + """ + # Load trajectory + try: + traj = md.load(traj_path, top=topology) + except Exception as e: + print(f"Could not load trajectory {traj_path}. Error: {e}") + return + + # Compute dihedral angles + phi_indices, phi_angles = md.compute_phi(traj) + psi_indices, psi_angles = md.compute_psi(traj) + + # Convert radians to degrees + phi_degrees = np.rad2deg(phi_angles.flatten()) + psi_degrees = np.rad2deg(psi_angles.flatten()) + + # Create plot + plt.figure(figsize=(8, 8)) + # Use hexbin for a nicer look + plt.hexbin(phi_degrees, psi_degrees, gridsize=180, cmap="viridis", mincnt=1) + plt.colorbar(label="Count in bin") + plt.title(f"Ramachandran Plot for {os.path.basename(traj_path)}") + plt.xlabel("Phi (degrees)") + plt.ylabel("Psi (degrees)") + plt.xlim(-180, 180) + plt.ylim(-180, 180) + plt.grid(True, linestyle="--", alpha=0.6) + plt.axhline(0, color="k", linestyle="--", linewidth=0.5) + plt.axvline(0, color="k", linestyle="--", linewidth=0.5) + + # Save plot + output_filename = f"ramachandran_{os.path.basename(traj_path).replace('.xtc', '.png')}" + output_path = os.path.join(output_dir, output_filename) + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + + +def create_histogram_plot(dihedrals, name1, name2, output_dir, name_string): + """ + Creates a 2D histogram with density for a pair of dihedrals. + """ + # Flatten all data for the pair of dihedrals + all_x_data = np.concatenate(dihedrals[name1]) + all_y_data = np.concatenate(dihedrals[name2]) + + plt.figure(figsize=(10, 10)) + + # Create 2D histogram with density + plt.hist2d( + all_x_data, + all_y_data, + range=((-np.pi, np.pi), (-np.pi, np.pi)), + bins=100, + cmap="viridis", + alpha=0.8, + norm=colors.LogNorm(), + ) + plt.colorbar(label="Density") + + plt.title(f"Histogram (Density): {name1} vs {name2}") + plt.xlabel(f"{name1} (radians)") + plt.ylabel(f"{name2} (radians)") + plt.xlim(-np.pi, np.pi) + plt.ylim(-np.pi, np.pi) + plt.grid(True, linestyle="--", alpha=0.6) + plt.axhline(0, color="k", linestyle="--", linewidth=0.5) + plt.axvline(0, color="k", linestyle="--", linewidth=0.5) + + output_filename = f"histogram_density_{name1}_vs_{name2}_{name_string}.png" + output_path = os.path.join(output_dir, output_filename) + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + + +def main(max_traj_per_grid=10): + """ + Loads trajectories, computes dihedral angles, and creates pairwise scatter plots and histograms. + Only keeps up to max_traj_per_grid trajectories per grid code. + """ + data_dir = "/data2/sules/fake_enhanced_data/ALA_ALA" + pdb_path = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + output_dir = f"/data2/sules/ramachandran_plots_ala_ala_fake_enhanced_data_max{max_traj_per_grid}" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + try: + topology = md.load_pdb(pdb_path) + except Exception as e: + print(f"Could not load topology file {pdb_path}. Error: {e}") + return + + # Get all trajectory files + all_traj_files = sorted(glob.glob(os.path.join(data_dir, "traj_*.xtc"))) + print(f"Found {len(all_traj_files)} total trajectory files") + + if not all_traj_files: + print(f"No trajectory files found in {data_dir}") + return + + # Select trajectories with max per grid code + traj_files = select_trajectories_with_max_per_grid(all_traj_files, max_traj_per_grid) + print(f"Selected {len(traj_files)} trajectory files after filtering") + + all_phi_angles = [] + all_psi_angles = [] + num_phi, num_psi = None, None + + for traj_file in tqdm(traj_files, desc="Loading trajectories"): + try: + traj = md.load(traj_file, top=topology) + _, phi_angles = md.compute_phi(traj) + _, psi_angles = md.compute_psi(traj) + + if num_phi is None: + num_phi = phi_angles.shape[1] + num_psi = psi_angles.shape[1] + + all_phi_angles.append(phi_angles[:100, :]) + all_psi_angles.append(psi_angles[:100, :]) + except Exception as e: + print(f"Could not load or process trajectory {traj_file}. Error: {e}") + continue + + if not all_phi_angles or not all_psi_angles: + print("No valid trajectories were processed.") + return + + # Dynamically create dihedral dictionary + dihedrals = {} + for i in range(num_phi): + dihedrals[f"phi_{i + 1}"] = [angles[:, i] for angles in all_phi_angles] + for i in range(num_psi): + dihedrals[f"psi_{i + 1}"] = [angles[:, i] for angles in all_psi_angles] + + dihedral_names = list(dihedrals.keys()) + + # Create line plots (existing functionality) + create_line_plots = False + if create_line_plots: + print("Creating line plots...") + for name1, name2 in itertools.combinations(dihedral_names, 2): + plt.figure(figsize=(10, 10)) + + for i in tqdm(range(len(traj_files)), desc=f"Plotting {name1} vs {name2}"): + x_angles = dihedrals[name1][i] + y_angles = dihedrals[name2][i] + plt.plot(x_angles, y_angles, linestyle="-", alpha=0.5) + plt.scatter(x_angles[0], y_angles[0], c="white", marker="o", edgecolor="black", s=50, zorder=5) + + plt.title(f"Ramachandran Plot: {name1} vs {name2}") + plt.xlabel(f"{name1} (degrees)") + plt.ylabel(f"{name2} (degrees)") + plt.xlim(-180, 180) + plt.ylim(-180, 180) + plt.grid(True, linestyle="--", alpha=0.6) + plt.axhline(0, color="k", linestyle="--", linewidth=0.5) + plt.axvline(0, color="k", linestyle="--", linewidth=0.5) + + output_filename = f"ramachandran_{name1}_vs_{name2}.png" + output_path = os.path.join(output_dir, output_filename) + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + + create_histogram_plots = True + if create_histogram_plots: + # Create histogram plots (new functionality) + print("Creating histogram plots with density...") + for name1, name2 in tqdm(itertools.combinations(dihedral_names, 2), desc="Creating histograms"): + create_histogram_plot(dihedrals, name1, name2, output_dir, f"100_frames_max{max_traj_per_grid}") + + print(f"\nDone. Histogram plots are saved in {output_dir}") + print(f"Used max {max_traj_per_grid} trajectories per grid code") + + +if __name__ == "__main__": + # Set max trajectories per grid code to 10 + max_traj = 10 + main(max_traj_per_grid=max_traj) diff --git a/scratch/visualize_noise_denoise.py b/scratch/visualize_noise_denoise.py new file mode 100644 index 0000000..c329688 --- /dev/null +++ b/scratch/visualize_noise_denoise.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +""" +Script to analyze validation trajectories using a trained model. +Generates Ramachandran plots for clean, noisy, and denoised samples. +""" + +import glob +import os +import pdb +import tempfile +from pathlib import Path + +import matplotlib.pyplot as plt +import mdtraj as md +import numpy as np +import torch +import torch_geometric +from tqdm import tqdm + +pdb.set_trace() +from jamun.data import MDtrajDataset, parse_datasets_from_directory +from jamun.metrics._visualize_denoise import plot_ramachandran_grid +from jamun.model.denoiser_conditional import Denoiser +from jamun.utils.checkpoint import find_checkpoint + +# from jamun.model.denoiser import Denoiser as Denoiser_unconditional + + +def load_model_from_wandb(wandb_run_path: str, checkpoint_type: str = "last", checkpoint_path: str = None): + """Load model from wandb run.""" + print(f"Loading model from {wandb_run_path}...") + + # Use jamun utilities to find the checkpoint + checkpoint_path_wandb = find_checkpoint(wandb_train_run_path=wandb_run_path, checkpoint_type=checkpoint_type) + + if checkpoint_path is None: + checkpoint_path = checkpoint_path_wandb + + print(f"Loading model from checkpoint: {checkpoint_path}") + + # Load the model + model = Denoiser.load_from_checkpoint(checkpoint_path) + model.eval() + + print("✓ Model loaded successfully") + return model + + +def create_dataset_from_trajectory(traj_file: str, pdb_file: str, total_lag_time: int = 2): + """Create a dataset from a single trajectory file.""" + # Create temporary directory structure expected by parse_datasets_from_directory + temp_dir = tempfile.mkdtemp() + + # Copy trajectory file to temp directory + traj_name = Path(traj_file).stem + temp_traj_path = os.path.join(temp_dir, f"{traj_name}.xtc") + temp_pdb_path = os.path.join(temp_dir, f"{traj_name}.pdb") + + # Create symlinks + os.symlink(traj_file, temp_traj_path) + os.symlink(pdb_file, temp_pdb_path) + + # Parse dataset + datasets = parse_datasets_from_directory( + root=temp_dir, + traj_pattern="^(.*).xtc", + pdb_pattern="^(.*).pdb", + as_iterable=False, + subsample=1, + total_lag_time=total_lag_time, + lag_subsample_rate=1, + max_datasets=1, + label_override=traj_name, + ) + + return datasets[0] if datasets else None + + +def process_trajectory(model, dataset: MDtrajDataset, sigma: float = 0.04): + """Process a single trajectory through the model.""" + # Create dataloader + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=32, shuffle=False, collate_fn=torch_geometric.data.Batch.from_data_list + ) + + # Store all samples + all_clean = [] + all_noisy = [] + all_denoised = [] + + model.eval() + with torch.no_grad(): + for batch in dataloader: + batch = batch.to(model.device) + + # # Ensure all batch attributes are the correct dtype + # if hasattr(batch, 'pos'): + # batch.pos = batch.pos.float() + # if hasattr(batch, 'batch'): + # batch.batch = batch.batch.long() + # if hasattr(batch, 'edge_index'): + # batch.edge_index = batch.edge_index.long() # Keep as long for indexing! + # if hasattr(batch, 'edge_attr') and batch.edge_attr is not None: + # batch.edge_attr = batch.edge_attr.float() + + # Convert sigma to tensor with correct dtype and device + sigma_tensor = torch.tensor(sigma, dtype=torch.float32, device=model.device) + + # Run noise and denoise + _, xhat, y = model.noise_and_denoise( + batch, sigma_tensor, align_noisy_input=model.align_noisy_input_during_evaluation + ) + + # Convert to data lists + clean_samples = torch_geometric.data.Batch.to_data_list(batch) + noisy_samples = torch_geometric.data.Batch.to_data_list(y) + denoised_samples = torch_geometric.data.Batch.to_data_list(xhat) + + all_clean.extend(clean_samples) + all_noisy.extend(noisy_samples) + all_denoised.extend(denoised_samples) + + return all_clean, all_noisy, all_denoised + + +def samples_to_trajectory(samples: list, dataset: MDtrajDataset): + """Convert list of samples to MDTraj trajectory.""" + coordinates = [] + for sample in samples: + coords = sample.pos.cpu().numpy() + coordinates.append(coords) + + coords_array = np.array(coordinates) # Shape: (n_frames, n_atoms, 3) + + # Create trajectory + traj = md.Trajectory(coords_array, dataset.topology) + return traj + + +def create_ramachandran_plot(clean_traj, noisy_traj, denoised_traj, title: str, save_path: str): + """Create and save Ramachandran plot for three trajectories.""" + trajs = {"x": clean_traj, "y": noisy_traj, "xhat": denoised_traj} + + try: + fig, axes = plot_ramachandran_grid(trajs, title) + fig.savefig(save_path, dpi=300, bbox_inches="tight") + plt.close(fig) + print(f"✓ Saved Ramachandran plot: {save_path}") + except Exception as e: + print(f"✗ Error creating Ramachandran plot for {title}: {e}") + + +def main(): + # Configuration + wandb_run_path = "sule-shashank/jamun/4p0ejn0z" + val_dir = "/data2/sules/ALA_ALA_enhanced_full_grid/val" + pdb_file = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + checkpoint_type = "epoch=49-step=52900-v1.ckpt" + # checkpoint_path = "/data2/sules/jamun-conditional-runs/outputs/train/dev/runs/2025-07-31_00-43-14/checkpoints/epoch=9-step=10051.ckpt" + total_lag_time = 5 + sigma = 0.04 + output_dir = "val_ramachandrans" + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Load model + checkpoint_path = find_checkpoint(wandb_run_path, checkpoint_type=checkpoint_type) + model = Denoiser.load_from_checkpoint(checkpoint_path) + model.eval() + model.to("cuda:0") + print("Model loaded and moved to cuda:0") + # config_path = "/data2/sules/jamun-conditional-runs//outputs/train/dev/runs/2025-08-05_04-24-31/wandb/run-20250805_042516-yqn9mm7x/files/config.yaml" + # cfg = OmegaConf.load(config_path) + # checkpoint_path = "/data2/sules/jamun-conditional-runs//outputs/train/dev/runs/2025-08-05_04-24-31/checkpoints/last.ckpt" + # model = hydra.utils.instantiate(cfg.cfg.value.model) + # checkpoint = torch.load(checkpoint_path, weights_only=False) + # model.load_state_dict(checkpoint['state_dict']) + # model.eval() + # model.to('cuda:0') + # print(f"Model loaded and moved to cuda:0") + print(f"Model device: {next(model.parameters()).device}") + # Get all trajectory files + traj_files = glob.glob(os.path.join(val_dir, "*.xtc")) + traj_files.sort() + + print(f"Found {len(traj_files)} trajectory files") + # do one trial run + dataset = create_dataset_from_trajectory(traj_files[0], pdb_file, total_lag_time) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=32, shuffle=False, collate_fn=torch_geometric.data.Batch.from_data_list + ) + _, batch = next(enumerate(dataloader)) + + batch = batch.to(model.device) + sigma_tensor = torch.tensor(sigma, dtype=torch.float32, device=model.device) + _, xhat, y = model.noise_and_denoise( + batch, sigma_tensor, align_noisy_input=model.align_noisy_input_during_evaluation + ) + # Store all samples for concatenated analysis + all_clean_samples = [] + all_noisy_samples = [] + all_denoised_samples = [] + breakpoint() + # Process each trajectory with progress bar + for traj_file in tqdm(traj_files, desc="Processing trajectories"): + traj_name = Path(traj_file).stem + + try: + # Create dataset + dataset = create_dataset_from_trajectory(traj_file, pdb_file, total_lag_time) + if dataset is None: + tqdm.write(f"Failed to create dataset for {traj_name}") + continue + + # Process trajectory + # breakpoint() + clean_samples, noisy_samples, denoised_samples = process_trajectory(model, dataset, sigma) + + # Store samples for concatenated analysis + all_clean_samples.extend(clean_samples) + all_noisy_samples.extend(noisy_samples) + all_denoised_samples.extend(denoised_samples) + + except Exception as e: + tqdm.write(f"Error processing {traj_name}: {e}") + continue + + # Create concatenated analysis + if all_clean_samples: + print("\nCreating concatenated Ramachandran plot...") + + # Use the last dataset for topology (they should all be the same) + concat_clean_traj = samples_to_trajectory(all_clean_samples, dataset) + concat_noisy_traj = samples_to_trajectory(all_noisy_samples, dataset) + concat_denoised_traj = samples_to_trajectory(all_denoised_samples, dataset) + + # Create concatenated plot + concat_plot_path = os.path.join(output_dir, "concatenated_ramachandran.png") + create_ramachandran_plot( + concat_clean_traj, concat_noisy_traj, concat_denoised_traj, "Concatenated Trajectories", concat_plot_path + ) + + print(f"\nAnalysis complete! Concatenated Ramachandran plot saved in {output_dir}/") + print(f"Processed {len(all_clean_samples)} total samples from {len(traj_files)} trajectories") + else: + print("No samples were processed successfully.") + + +if __name__ == "__main__": + main() diff --git a/scratch/visualize_traj_data.py b/scratch/visualize_traj_data.py new file mode 100644 index 0000000..5cb9780 --- /dev/null +++ b/scratch/visualize_traj_data.py @@ -0,0 +1,60 @@ +import itertools +import os + +import matplotlib.colors as colors +import matplotlib.pyplot as plt +import mdtraj as md +import numpy as np + +# Set file paths for ALA_ALA in capped diamines +xtc_file = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.xtc" +pdb_file = "/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb" + +output_dir = "/data2/sules/ramachandran_plots_ala_ala_fake_enhanced_data" +os.makedirs(output_dir, exist_ok=True) + +print(f"XTC file exists: {os.path.exists(xtc_file)}") +print(f"PDB file exists: {os.path.exists(pdb_file)}") + +# Load the trajectory (subsample=1 means load all frames) +traj = md.load(xtc_file, top=pdb_file) +print(f"Loaded trajectory with {traj.n_frames} frames and {traj.n_atoms} atoms.") + +# Compute backbone dihedrals (phi and psi) +phi_indices, phi_angles = md.compute_phi(traj) +psi_indices, psi_angles = md.compute_psi(traj) + +num_phi = phi_angles.shape[1] +num_psi = psi_angles.shape[1] + +# Collect all dihedral arrays in a dict for easy access +# Each entry is (n_frames,) +dihedrals = {} +for i in range(num_phi): + dihedrals[f"phi_{i + 1}"] = phi_angles[:, i] +for i in range(num_psi): + dihedrals[f"psi_{i + 1}"] = psi_angles[:, i] + +dihedral_names = list(dihedrals.keys()) + +# Make 2D histograms for all pairs +for name1, name2 in itertools.combinations(dihedral_names, 2): + x = dihedrals[name1] + y = dihedrals[name2] + plt.figure(figsize=(8, 8)) + plt.hist2d(x, y, bins=100, range=((-np.pi, np.pi), (-np.pi, np.pi)), cmap="viridis", norm=colors.LogNorm()) + plt.colorbar(label="Density") + plt.title(f"2D Histogram: {name1} vs {name2}") + plt.xlabel(f"{name1} (radians)") + plt.ylabel(f"{name2} (radians)") + plt.xlim(-np.pi, np.pi) + plt.ylim(-np.pi, np.pi) + plt.grid(True, linestyle="--", alpha=0.6) + plt.axhline(0, color="k", linestyle="--", linewidth=0.5) + plt.axvline(0, color="k", linestyle="--", linewidth=0.5) + output_filename = f"hist2d_{name1}_vs_{name2}_true_distribution.png" + output_path = os.path.join(output_dir, output_filename) + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + +print(f"All 2D histograms saved in {output_dir}") diff --git a/scripts/analyze_sweep_results.py b/scripts/analyze_sweep_results.py new file mode 100755 index 0000000..3e26e65 --- /dev/null +++ b/scripts/analyze_sweep_results.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +Script to analyze results from the delta-friction parameter sweep. +""" + +import numpy as np +import math +import argparse +import wandb +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +def calculate_parameter_grid(): + """Calculate the parameter grid used in the sweep.""" + sigma = 0.04 + + # Delta values + delta_min = sigma / math.sqrt(5) + delta_max = math.sqrt(5) * sigma + deltas = np.linspace(delta_min, delta_max, 5) + + # Friction values + linear_points = np.linspace(0.01, 0.99, 5) + frictions = [-math.log(p) for p in linear_points] + + return deltas, frictions, sigma + +def fetch_sweep_results(project_name="sule-shashank/jamun"): + """Fetch results from wandb for the parameter sweep.""" + api = wandb.Api() + + # Get runs with the sweep tag + runs = api.runs(project_name, filters={"tags": {"$in": ["sweep", "delta_friction"]}}) + + results = [] + for run in runs: + if run.state == "finished": + # Extract parameters from tags or config + delta = None + friction = None + + # Try to extract from tags first + for tag in run.tags: + if tag.startswith("delta_"): + try: + delta = float(tag.replace("delta_", "")) + except ValueError: + pass + elif tag.startswith("friction_"): + try: + friction = float(tag.replace("friction_", "")) + except ValueError: + pass + + # Try to extract from config if not found in tags + if delta is None: + delta = run.config.get("delta") + if friction is None: + friction = run.config.get("friction") + + if delta is not None and friction is not None: + # Get metrics (adjust these based on what metrics you're interested in) + metrics = {} + if run.summary: + # Add metrics you want to analyze + for key in ["sampling_time", "chemical_validity", "ramachandran_score"]: + if key in run.summary: + metrics[key] = run.summary[key] + + results.append({ + "run_id": run.id, + "run_name": run.name, + "delta": delta, + "friction": friction, + **metrics + }) + + return pd.DataFrame(results) + +def create_heatmaps(df, deltas, frictions): + """Create heatmaps for each metric.""" + if df.empty: + print("No results found to plot.") + return + + # Get metric columns (exclude parameter and metadata columns) + metric_cols = [col for col in df.columns if col not in ["run_id", "run_name", "delta", "friction"]] + + if not metric_cols: + print("No metrics found in the data.") + return + + # Create a figure with subplots for each metric + n_metrics = len(metric_cols) + fig, axes = plt.subplots(1, n_metrics, figsize=(6*n_metrics, 5)) + if n_metrics == 1: + axes = [axes] + + for i, metric in enumerate(metric_cols): + # Create a pivot table for the heatmap + pivot_data = df.pivot(index="friction", columns="delta", values=metric) + + # Create heatmap + sns.heatmap( + pivot_data, + ax=axes[i], + annot=True, + fmt=".4f", + cmap="viridis", + cbar_kws={"label": metric} + ) + axes[i].set_title(f"{metric.replace('_', ' ').title()}") + axes[i].set_xlabel("Delta") + axes[i].set_ylabel("Friction") + + plt.tight_layout() + plt.savefig("sweep_results_heatmap.png", dpi=300, bbox_inches="tight") + plt.show() + +def print_summary(df, deltas, frictions): + """Print a summary of the sweep results.""" + print("\n" + "="*60) + print("PARAMETER SWEEP SUMMARY") + print("="*60) + + print(f"Parameter grid:") + print(f" Deltas: {len(deltas)} values from {deltas[0]:.6f} to {deltas[-1]:.6f}") + print(f" Frictions: {len(frictions)} values from {frictions[0]:.6f} to {frictions[-1]:.6f}") + print(f" Total combinations: {len(deltas) * len(frictions)}") + + print(f"\nResults found: {len(df)} / {len(deltas) * len(frictions)}") + + if not df.empty: + print(f"\nMetrics available:") + metric_cols = [col for col in df.columns if col not in ["run_id", "run_name", "delta", "friction"]] + for metric in metric_cols: + print(f" - {metric}") + + print(f"\nBest performing combinations:") + for metric in metric_cols: + if metric in df.columns: + if "time" in metric.lower(): + # For time metrics, lower is better + best_idx = df[metric].idxmin() + print(f" {metric} (lowest): delta={df.loc[best_idx, 'delta']:.6f}, friction={df.loc[best_idx, 'friction']:.6f}, value={df.loc[best_idx, metric]:.6f}") + else: + # For other metrics, higher is usually better + best_idx = df[metric].idxmax() + print(f" {metric} (highest): delta={df.loc[best_idx, 'delta']:.6f}, friction={df.loc[best_idx, 'friction']:.6f}, value={df.loc[best_idx, metric]:.6f}") + +def main(): + parser = argparse.ArgumentParser(description="Analyze delta-friction parameter sweep results") + parser.add_argument("--project", default="sule-shashank/jamun", help="Wandb project name") + parser.add_argument("--plot", action="store_true", help="Create heatmap plots") + parser.add_argument("--save-csv", help="Save results to CSV file") + + args = parser.parse_args() + + # Calculate parameter grid + deltas, frictions, sigma = calculate_parameter_grid() + + print("Fetching results from wandb...") + df = fetch_sweep_results(args.project) + + # Print summary + print_summary(df, deltas, frictions) + + # Save to CSV if requested + if args.save_csv: + df.to_csv(args.save_csv, index=False) + print(f"\nResults saved to {args.save_csv}") + + # Create plots if requested + if args.plot and not df.empty: + create_heatmaps(df, deltas, frictions) + + print("\nAnalysis complete!") + +if __name__ == "__main__": + main() diff --git a/scripts/concatenate_trajectories.py b/scripts/concatenate_trajectories.py new file mode 100755 index 0000000..9d4955e --- /dev/null +++ b/scripts/concatenate_trajectories.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 + +import os +import glob +import mdtraj as md +import numpy as np +from pathlib import Path +import argparse +from tqdm import tqdm + +def concatenate_trajectories(folder_path, pdb_file, output_name="ALA_ALA.xtc"): + """ + Concatenate all .xtc trajectories in a folder into a single long trajectory. + + Args: + folder_path (str): Path to the folder containing .xtc files + pdb_file (str): Path to the PDB topology file + output_name (str): Name of the output trajectory file + """ + folder_path = Path(folder_path) + + # Find all .xtc files in the folder + xtc_files = sorted(glob.glob(str(folder_path / "*.xtc"))) + + # Filter out any existing ALA_ALA.xtc to avoid including it in concatenation + xtc_files = [f for f in xtc_files if not f.endswith("ALA_ALA.xtc")] + + if not xtc_files: + print(f"No .xtc files found in {folder_path}") + return + + print(f"Found {len(xtc_files)} .xtc files in {folder_path}") + print(f"First few files: {xtc_files[:5]}") + + # Load the first trajectory to get the topology + print("Loading first trajectory to get topology...") + first_traj = md.load(xtc_files[0], top=pdb_file) + print(f"Topology: {first_traj.n_atoms} atoms, {first_traj.n_frames} frames") + + # Initialize the concatenated trajectory with the first one + concat_traj = first_traj + + # Load and concatenate the rest of the trajectories + for xtc_file in tqdm(xtc_files[1:], desc="Concatenating trajectories", unit="file"): + try: + traj = md.load(xtc_file, top=pdb_file) + concat_traj = concat_traj.join(traj) + + except Exception as e: + tqdm.write(f"Error loading {os.path.basename(xtc_file)}: {e}") + continue + + # Save the concatenated trajectory + output_path = folder_path / output_name + print(f"Saving concatenated trajectory to {output_path}") + print(f"Final trajectory: {concat_traj.n_frames} frames, {concat_traj.n_atoms} atoms") + + concat_traj.save_xtc(str(output_path)) + print(f"Successfully saved {output_path}") + + return concat_traj + +def main(): + parser = argparse.ArgumentParser(description="Concatenate .xtc trajectories in folders") + parser.add_argument("--base-dir", + default="/data2/sules/fake_enhanced_data/ALA_ALA_organized", + help="Base directory containing train/val/test folders") + parser.add_argument("--pdb-file", + default="/data/bucket/kleinhej/capped_diamines/timewarp_splits/train/ALA_ALA.pdb", + help="PDB topology file") + parser.add_argument("--folders", nargs='+', + default=["train", "val", "test"], + help="Folders to process") + + args = parser.parse_args() + + base_dir = Path(args.base_dir) + pdb_file = args.pdb_file + + # Check if PDB file exists + if not os.path.exists(pdb_file): + print(f"Error: PDB file not found: {pdb_file}") + return + + print(f"Using PDB file: {pdb_file}") + print(f"Base directory: {base_dir}") + + # Process each folder + for folder in args.folders: + folder_path = base_dir / folder + + if not folder_path.exists(): + print(f"Folder {folder_path} does not exist, skipping...") + continue + + print(f"\n{'='*60}") + print(f"Processing folder: {folder}") + print(f"{'='*60}") + + try: + concatenate_trajectories(folder_path, pdb_file) + except Exception as e: + print(f"Error processing {folder}: {e}") + import traceback + traceback.print_exc() + + print(f"\n{'='*60}") + print("Concatenation completed!") + print(f"{'='*60}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/generate_data/generate_swarms.py b/scripts/generate_data/generate_swarms.py new file mode 100755 index 0000000..32486b6 --- /dev/null +++ b/scripts/generate_data/generate_swarms.py @@ -0,0 +1,584 @@ +#!/usr/bin/env python3 + +import argparse +import logging +import os +import glob +from dataclasses import dataclass +from typing import Optional, List, Tuple +from pathlib import Path + +import openmm_utils as op +from openmm.app import ForceField, Simulation, Topology + +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +py_logger = logging.getLogger("generate_swarms") + + +@dataclass +class SwarmConfig: + """Configuration parameters for swarm trajectory generation""" + + # Input/Output + input_pdbs: List[str] + output_dir: str + + # MD simulation parameters (similar to run_simulation.py) + dt_ps: float = 0.002 + temp_K: float = 300 + pressure_bar: float = 1.0 + position_restraint_k: float = 10.0 # kJ/(mol.A^2) + forcefield: tuple[str, str] = ("amber99sbildn.xml", "tip3p.xml") + padding_nm: float = 1.0 + water_model: str = "tip3p" + positive_ion: str = "Na+" + negative_ion: str = "Cl-" + + # Equilibration parameters + energy_minimization_steps: int = 1500 + nvt_restraint_steps: int = 75_000 # Reduced from run_simulation defaults + npt_restraint_steps: int = 75_000 # Reduced from run_simulation defaults + nvt_equil_steps: int = 100_000 # Reduced from run_simulation defaults + npt_equil_steps: int = 100_000 # Reduced from run_simulation defaults + + # Swarm generation parameters + num_swarms: int = 10 + swarm_steps: int = 10_000 + save_frequency: int = 10 + + # Processing options + save_intermediate_files: bool = False + single_structure_mode: bool = False # For processing just one structure + structure_index: Optional[int] = None # For processing a specific structure by index + + # New options for separated workflow + skip_equilibration: bool = False # Skip equilibration if already done + equilibrate_only: bool = False # Only do equilibration, no swarms + append_swarms: bool = True # Start swarm indexing from existing trajectories + + +def parse_args() -> SwarmConfig: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Generate swarm trajectories from equilibrated structures", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Process folder of PDBs + %(prog)s --input-folder /path/to/pdbs --output-dir results --num-swarms 50 --swarm-steps 10000 + + # Process specific PDB files + %(prog)s --input-pdbs struct1.pdb struct2.pdb --output-dir results --num-swarms 20 --swarm-steps 5000 + + # Process single structure (for parallelization) + %(prog)s --input-pdbs struct1.pdb --output-dir results --single-structure --structure-index 1 + """, + ) + + # Input arguments (mutually exclusive) + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--input-folder", + type=str, + help="Folder containing PDB files to process" + ) + input_group.add_argument( + "--input-pdbs", + nargs="+", + help="List of PDB files to process" + ) + + # Output + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory for swarm trajectories" + ) + + # Simulation parameters + sim_group = parser.add_argument_group("Simulation Parameters") + sim_group.add_argument( + "--dt", type=float, default=SwarmConfig.dt_ps, + help="Timestep in ps (default: %(default)s)" + ) + sim_group.add_argument( + "--temp", type=float, default=SwarmConfig.temp_K, + help="Temperature in K (default: %(default)s)" + ) + sim_group.add_argument( + "--pressure", type=float, default=SwarmConfig.pressure_bar, + help="Pressure in bar (default: %(default)s)" + ) + sim_group.add_argument( + "--position-restraint-k", type=float, default=SwarmConfig.position_restraint_k, + help="Position restraint force constant in kJ/(mol.A^2) (default: %(default)s)" + ) + + # Forcefield options + ff_group = parser.add_argument_group("Forcefield Options") + ff_group.add_argument( + "--forcefield", nargs=2, default=SwarmConfig.forcefield, + metavar=("FF1", "FF2"), help="Forcefield XML files (default: %(default)s)" + ) + + # Equilibration steps + equil_group = parser.add_argument_group("Equilibration Steps") + equil_group.add_argument( + "--energy-minimization-steps", type=int, default=SwarmConfig.energy_minimization_steps, + help="Steps for energy minimization (default: %(default)s)" + ) + equil_group.add_argument( + "--nvt-restraint-steps", type=int, default=SwarmConfig.nvt_restraint_steps, + help="Steps for NVT equilibration with restraints (default: %(default)s)" + ) + equil_group.add_argument( + "--npt-restraint-steps", type=int, default=SwarmConfig.npt_restraint_steps, + help="Steps for NPT equilibration with restraints (default: %(default)s)" + ) + equil_group.add_argument( + "--nvt-equil-steps", type=int, default=SwarmConfig.nvt_equil_steps, + help="Steps for NVT equilibration without restraints (default: %(default)s)" + ) + equil_group.add_argument( + "--npt-equil-steps", type=int, default=SwarmConfig.npt_equil_steps, + help="Steps for NPT equilibration without restraints (default: %(default)s)" + ) + + # Swarm parameters + swarm_group = parser.add_argument_group("Swarm Parameters") + swarm_group.add_argument( + "--num-swarms", type=int, default=SwarmConfig.num_swarms, + help="Number of swarm trajectories to generate per structure (default: %(default)s)" + ) + swarm_group.add_argument( + "--swarm-steps", type=int, default=SwarmConfig.swarm_steps, + help="Number of steps per swarm trajectory (default: %(default)s)" + ) + swarm_group.add_argument( + "--save-frequency", type=int, default=SwarmConfig.save_frequency, + help="Frequency of saving frames in swarm trajectories (default: %(default)s)" + ) + + # Processing options + proc_group = parser.add_argument_group("Processing Options") + proc_group.add_argument( + "--save-intermediate-files", action="store_true", + help="Save intermediate files during equilibration (default: False)" + ) + proc_group.add_argument( + "--single-structure", action="store_true", + help="Process only a single structure (for parallelization)" + ) + proc_group.add_argument( + "--structure-index", type=int, + help="Index of structure to process (0-based, for parallelization)" + ) + + # New workflow options + proc_group.add_argument( + "--skip-equilibration", action="store_true", + help="Skip equilibration if equilibrated_start.pdb already exists (default: False)" + ) + proc_group.add_argument( + "--equilibrate-only", action="store_true", + help="Only perform equilibration, do not generate swarms (default: False)" + ) + proc_group.add_argument( + "--append-swarms", action="store_true", default=True, + help="Start swarm indexing from existing trajectories rather than overwriting (default: True)" + ) + + args = parser.parse_args() + + # Handle input parsing + if args.input_folder: + # Find all PDB files in the folder + pdb_pattern = os.path.join(args.input_folder, "*.pdb") + input_pdbs = sorted(glob.glob(pdb_pattern)) + if not input_pdbs: + raise ValueError(f"No PDB files found in {args.input_folder}") + py_logger.info(f"Found {len(input_pdbs)} PDB files in {args.input_folder}") + else: + input_pdbs = args.input_pdbs + # Verify all files exist + for pdb_file in input_pdbs: + if not os.path.exists(pdb_file): + raise FileNotFoundError(f"PDB file not found: {pdb_file}") + + # Handle single structure processing + if args.single_structure: + if args.structure_index is not None: + if args.structure_index >= len(input_pdbs): + raise ValueError(f"Structure index {args.structure_index} out of range (0-{len(input_pdbs)-1})") + input_pdbs = [input_pdbs[args.structure_index]] + else: + if len(input_pdbs) > 1: + py_logger.warning("Single structure mode with multiple PDbs - processing only the first one") + input_pdbs = [input_pdbs[0]] + + return SwarmConfig( + input_pdbs=input_pdbs, + output_dir=args.output_dir, + dt_ps=args.dt, + temp_K=args.temp, + pressure_bar=args.pressure, + position_restraint_k=args.position_restraint_k, + forcefield=tuple(args.forcefield), + energy_minimization_steps=args.energy_minimization_steps, + nvt_restraint_steps=args.nvt_restraint_steps, + npt_restraint_steps=args.npt_restraint_steps, + nvt_equil_steps=args.nvt_equil_steps, + npt_equil_steps=args.npt_equil_steps, + num_swarms=args.num_swarms, + swarm_steps=args.swarm_steps, + save_frequency=args.save_frequency, + save_intermediate_files=args.save_intermediate_files, + single_structure_mode=args.single_structure, + structure_index=args.structure_index, + skip_equilibration=args.skip_equilibration, + equilibrate_only=args.equilibrate_only, + append_swarms=args.append_swarms, + ) + + +def get_structure_name(pdb_file: str) -> str: + """Get a clean structure name from PDB filename.""" + return os.path.splitext(os.path.basename(pdb_file))[0] + + +def setup_structure_directory(pdb_file: str, config: SwarmConfig, structure_idx: int) -> Tuple[str, str]: + """Create output directory for a structure and return paths.""" + structure_name = get_structure_name(pdb_file) + structure_dir = os.path.join(config.output_dir, f"AA_{structure_idx:03d}") + + os.makedirs(structure_dir, exist_ok=True) + py_logger.info(f"Created structure directory: {structure_dir}") + + return structure_dir, structure_name + + +def find_existing_swarms(structure_dir: str, swarm_steps: int, dt_ps: float) -> int: + """Find existing swarm trajectories and return the next available index.""" + trajectory_time_ps = swarm_steps * dt_ps + pattern = os.path.join(structure_dir, f"swarm_{trajectory_time_ps:.0f}ps_*.xtc") + existing_swarms = glob.glob(pattern) + + if not existing_swarms: + return 0 # Start from 1 if no existing swarms + + # Extract indices from existing filenames + indices = [] + for swarm_file in existing_swarms: + filename = os.path.basename(swarm_file) + # Extract index from filename like "swarm_1ps_001.xtc" + try: + index_part = filename.split('_')[-1].split('.')[0] # Get "001" part + indices.append(int(index_part)) + except (ValueError, IndexError): + continue + + if indices: + next_index = max(indices) + 1 + py_logger.info(f"Found {len(indices)} existing swarm trajectories, starting from index {next_index}") + return next_index + else: + return 0 + + +def check_equilibration_exists(structure_dir: str) -> bool: + """Check if equilibration has already been completed.""" + equilibrated_pdb = os.path.join(structure_dir, "equilibrated_start.pdb") + return os.path.exists(equilibrated_pdb) + + +def equilibrate_structure( + pdb_file: str, + structure_dir: str, + structure_name: str, + config: SwarmConfig +) -> Tuple[op.Positions, op.Velocities, Simulation]: + """ + Equilibrate a single structure starting from solvation. + Returns the equilibrated positions, velocities, and simulation object. + """ + py_logger.info(f"Starting equilibration for {structure_name}") + + # Convert to absolute path before changing directories + pdb_file_abs = os.path.abspath(pdb_file) + + # Change to structure directory + original_dir = os.getcwd() + os.chdir(structure_dir) + + try: + # Load the initial structure (assume it's already fixed and hydrogenated) + from openmm.app import PDBFile + pdb = PDBFile(pdb_file_abs) + positions = pdb.positions + topology = pdb.topology + + # Create forcefield + ff = ForceField(*config.forcefield) + + # Solvate the structure + py_logger.info("Solvating structure...") + positions, topology = op.solvate( + positions, + topology, + ff, + padding_nm=config.padding_nm, + water_model=config.water_model, + positive_ion=config.positive_ion, + negative_ion=config.negative_ion, + output_file_prefix=f"{structure_name}_solvated", + save_file=config.save_intermediate_files, + ) + + # Create simulation + simulation = op.get_system_with_Langevin_integrator( + topology, ff, config.temp_K, dt_ps=config.dt_ps + ) + + # Add position restraints for equilibration + simulation = op.add_position_restraints( + positions, topology, simulation, k=config.position_restraint_k + ) + + # Energy minimization + py_logger.info("Energy minimization...") + positions, simulation = op.minimize_energy( + positions, + simulation, + num_steps=config.energy_minimization_steps, + output_file_prefix=f"{structure_name}_minimized", + save_file=config.save_intermediate_files, + save_protein_only_file=False, # Don't need protein-only file here + ) + + # NVT equilibration with restraints + py_logger.info("NVT equilibration with restraints...") + positions, velocities, simulation = op.run_simulation( + positions=positions, + simulation=simulation, + velocities=None, + output_frequency=1000, # Less frequent output for equilibration + save_intermediate_files=config.save_intermediate_files, + ensemble="NVT", + output_file_prefix=f"{structure_name}_restrainedNVT", + num_steps=config.nvt_restraint_steps, + ) + + # NPT equilibration with restraints + py_logger.info("NPT equilibration with restraints...") + positions, velocities, simulation = op.run_simulation( + positions=positions, + simulation=simulation, + velocities=velocities, + temp_K=config.temp_K, + pressure_bar=config.pressure_bar, + output_frequency=1000, + save_intermediate_files=config.save_intermediate_files, + ensemble="NPT", + output_file_prefix=f"{structure_name}_restrainedNPT", + num_steps=config.npt_restraint_steps, + ) + + # Remove position restraints + py_logger.info("Removing position restraints...") + simulation.context.getSystem().removeForce(simulation.context.getSystem().getNumForces() - 1) + + # NVT equilibration without restraints + py_logger.info("NVT equilibration without restraints...") + positions, velocities, simulation = op.run_simulation( + positions=positions, + simulation=simulation, + velocities=velocities, + output_frequency=1000, + save_intermediate_files=config.save_intermediate_files, + ensemble="NVT", + output_file_prefix=f"{structure_name}_equilNVT", + num_steps=config.nvt_equil_steps, + ) + + # Final NPT equilibration + py_logger.info("Final NPT equilibration...") + positions, velocities, simulation = op.run_simulation( + positions=positions, + simulation=simulation, + velocities=velocities, + temp_K=config.temp_K, + pressure_bar=config.pressure_bar, + output_frequency=1000, + save_intermediate_files=config.save_intermediate_files, + ensemble="NPT", + output_file_prefix=f"{structure_name}_equilNPT", + num_steps=config.npt_equil_steps, + save_pdb=True, + pdb_output_file="equilibrated_start.pdb", # Save the starting structure for swarms + ) + + py_logger.info(f"Equilibration completed for {structure_name}") + return positions, velocities, simulation + + finally: + # Always return to original directory + os.chdir(original_dir) + + +def generate_swarms( + positions: op.Positions, + velocities: op.Velocities, + simulation: Simulation, + structure_dir: str, + structure_name: str, + config: SwarmConfig +) -> None: + """Generate swarm trajectories from equilibrated structure.""" + py_logger.info(f"Generating {config.num_swarms} swarm trajectories for {structure_name}") + + # Change to structure directory + original_dir = os.getcwd() + os.chdir(structure_dir) + + try: + # Calculate trajectory time in picoseconds + trajectory_time_ps = config.swarm_steps * config.dt_ps + + # Determine starting index based on existing swarms + if config.append_swarms: + start_idx = find_existing_swarms(structure_dir, config.swarm_steps, config.dt_ps) + else: + start_idx = 1 + + for swarm_count in range(config.num_swarms): + swarm_idx = start_idx + swarm_count + py_logger.info(f"Generating swarm {swarm_idx + 1}/{config.num_swarms}") + + # Set initial conditions (same positions, slightly perturbed velocities for variation) + simulation.context.setPositions(positions) + + # Add small random perturbation to velocities for each swarm + import numpy as np + from openmm.unit import nanometer, picosecond + np.random.seed(swarm_idx) # Reproducible but different per swarm + + # Get original velocities as numpy array + velocities_array = np.array(velocities.value_in_unit(nanometer/picosecond)) + + # Add small random perturbation (0.1% of thermal velocity) + perturbation_scale = 0.001 + thermal_velocity = np.sqrt(3 * 8.314 * config.temp_K / 1000) # Approximate thermal velocity + perturbation = np.random.normal(0, perturbation_scale * thermal_velocity, velocities_array.shape) + perturbed_velocities = velocities_array + perturbation + + # Convert back to OpenMM format + from openmm.unit import Quantity + perturbed_velocities_unit = Quantity(perturbed_velocities, nanometer/picosecond) + simulation.context.setVelocities(perturbed_velocities_unit) + + # Generate swarm trajectory + swarm_filename = f"swarm_{trajectory_time_ps:.0f}ps_{swarm_idx + 1:03d}.xtc" + + _, _, simulation = op.run_simulation( + positions=positions, + simulation=simulation, + velocities=velocities, + temp_K=config.temp_K, + pressure_bar=config.pressure_bar, + output_frequency=config.save_frequency, + save_intermediate_files=False, # No intermediate files for swarms + ensemble="NPT", + output_file_prefix=f"swarm_{swarm_idx + 1:03d}", + num_steps=config.swarm_steps, + save_xtc=True, + xtc_output_file=swarm_filename, + save_pdb=False, # Don't save PDB for each swarm + ) + + py_logger.info(f"Completed swarm {swarm_idx + 1}: {swarm_filename}") + + finally: + # Always return to original directory + os.chdir(original_dir) + + +def process_structure(pdb_file: str, structure_idx: int, config: SwarmConfig) -> None: + """Process a single structure: equilibrate and/or generate swarms.""" + structure_name = get_structure_name(pdb_file) + py_logger.info(f"Processing structure {structure_idx + 1}: {structure_name}") + + # Setup structure directory + structure_dir, structure_name = setup_structure_directory(pdb_file, config, structure_idx) + + # Check if equilibration exists and should be skipped + equilibration_exists = check_equilibration_exists(structure_dir) + + if config.skip_equilibration and not equilibration_exists: + py_logger.warning(f"--skip-equilibration set but no equilibrated_start.pdb found in {structure_dir}") + py_logger.info("Proceeding with equilibration...") + config.skip_equilibration = False + + # Handle equilibration + if config.skip_equilibration and equilibration_exists: + py_logger.info(f"Skipping equilibration for {structure_name} (equilibrated_start.pdb exists)") + # TODO: Load from equilibrated state if needed for swarm generation + positions, velocities, simulation = None, None, None + else: + # Perform equilibration + py_logger.info(f"Starting equilibration for {structure_name}") + positions, velocities, simulation = equilibrate_structure( + pdb_file, structure_dir, structure_name, config + ) + + # Stop here if only equilibrating + if config.equilibrate_only: + py_logger.info(f"Equilibration-only mode: completed equilibration for {structure_name}") + return + + # Generate swarms (need to implement loading from equilibrated state if skipped equilibration) + if config.skip_equilibration and equilibration_exists: + py_logger.info("Loading equilibrated state for swarm generation...") + # TODO: Implement loading from saved equilibrated state + py_logger.warning("Loading from saved equilibrated state not yet implemented!") + py_logger.warning("Please run without --skip-equilibration for now") + return + + # Generate swarm trajectories + generate_swarms(positions, velocities, simulation, structure_dir, structure_name, config) + + py_logger.info(f"Completed processing structure {structure_idx + 1}: {structure_name}") + + +def main(): + """Main execution function.""" + config = parse_args() + + # Create main output directory + os.makedirs(config.output_dir, exist_ok=True) + py_logger.info(f"Output directory: {config.output_dir}") + py_logger.info(f"Processing {len(config.input_pdbs)} structure(s)") + + # Process each structure + for idx, pdb_file in enumerate(config.input_pdbs): + try: + # Use global structure index in single-structure mode, otherwise use enumeration index + if config.single_structure_mode and config.structure_index is not None: + structure_idx = config.structure_index + else: + structure_idx = idx + + process_structure(pdb_file, structure_idx, config) + except Exception as e: + py_logger.error(f"Error processing {pdb_file}: {str(e)}") + if config.single_structure_mode: + raise # Re-raise in single structure mode for debugging + else: + py_logger.warning("Continuing with next structure...") + continue + + py_logger.info("Swarm generation completed!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/generate_data/openmm_utils.py b/scripts/generate_data/openmm_utils.py index 5fe640e..a2a8652 100644 --- a/scripts/generate_data/openmm_utils.py +++ b/scripts/generate_data/openmm_utils.py @@ -145,6 +145,46 @@ def solvate( return modeller.positions, modeller.topology +def select_best_platform(): + """Select the best available OpenMM platform (preferably GPU) with fallback logic.""" + try: + import openmm + platforms = [] + for i in range(openmm.Platform.getNumPlatforms()): + platform = openmm.Platform.getPlatform(i) + platforms.append((platform.getName(), platform.getSpeed(), platform)) + + # Sort by speed (higher is better) and prefer CUDA > OpenCL > CPU + platform_priority = {'CUDA': 3, 'OpenCL': 2, 'CPU': 1, 'Reference': 0} + platforms.sort(key=lambda x: (x[1], platform_priority.get(x[0], 0)), reverse=True) + + # Try each platform in order until one works + for platform_name, speed, platform in platforms: + try: + # Quick test to see if platform can create a context + # We'll use a minimal system for testing + test_system = openmm.System() + test_system.addParticle(1.0) # Add one particle + test_integrator = openmm.LangevinMiddleIntegrator(300, 1.0, 0.002) + test_context = openmm.Context(test_system, test_integrator, platform) + del test_context # Clean up + del test_integrator + del test_system + + py_logger.info(f"Using OpenMM platform: {platform_name} (speed: {speed})") + return platform + except Exception as e: + py_logger.warning(f"Platform {platform_name} failed test: {e}") + continue + + py_logger.warning("No working OpenMM platforms found, using default") + return None + + except Exception as e: + py_logger.warning(f"Could not select optimal platform: {e}, using default") + return None + + def get_system_with_Langevin_integrator( topology: Topology, forcefield: ForceField, temp_K: float, dt_ps: float, state: Optional[str] = None ) -> Simulation: @@ -157,7 +197,13 @@ def get_system_with_Langevin_integrator( constraints=HBonds, ) integrator = LangevinMiddleIntegrator(temp_K * kelvin, 1 / picoseconds, dt_ps * picoseconds) - simulation = Simulation(topology, system, integrator) + + platform = select_best_platform() + if platform is not None: + simulation = Simulation(topology, system, integrator, platform) + else: + simulation = Simulation(topology, system, integrator) + if state is not None: simulation.loadState(state) return simulation @@ -175,7 +221,13 @@ def get_system_with_NoseHoover_integrator( constraints=HBonds, ) integrator = NoseHooverIntegrator(temp_K * kelvin, 1 / picoseconds, dt_ps * picoseconds) - simulation = Simulation(topology, system, integrator) + + platform = select_best_platform() + if platform is not None: + simulation = Simulation(topology, system, integrator, platform) + else: + simulation = Simulation(topology, system, integrator) + return simulation diff --git a/scripts/generate_data/run_swarms_parallel.sh b/scripts/generate_data/run_swarms_parallel.sh new file mode 100755 index 0000000..fe55b46 --- /dev/null +++ b/scripts/generate_data/run_swarms_parallel.sh @@ -0,0 +1,294 @@ +#!/bin/bash + +# Helper script for parallelizing swarm generation +# This script demonstrates different approaches for running generate_swarms.py in parallel + +# Default parameters +INPUT_FOLDER="" +INPUT_PDBS="" +OUTPUT_DIR="" +NUM_SWARMS=10 +SWARM_STEPS=10000 +SAVE_FREQUENCY=10 + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --input-folder) + INPUT_FOLDER="$2" + shift 2 + ;; + --input-pdbs) + shift + INPUT_PDBS="" + while [[ $# -gt 0 && $1 != --* ]]; do + INPUT_PDBS="$INPUT_PDBS $1" + shift + done + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --num-swarms) + NUM_SWARMS="$2" + shift 2 + ;; + --swarm-steps) + SWARM_STEPS="$2" + shift 2 + ;; + --save-frequency) + SAVE_FREQUENCY="$2" + shift 2 + ;; + --help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --input-folder DIR Folder containing PDB files" + echo " --input-pdbs FILE... List of PDB files" + echo " --output-dir DIR Output directory" + echo " --num-swarms N Number of swarms per structure (default: 10)" + echo " --swarm-steps N Steps per swarm (default: 10000)" + echo " --save-frequency N Save frequency (default: 10)" + echo " --help Show this help" + echo "" + echo "Examples:" + echo " # Using SLURM job arrays:" + echo " $0 --input-folder /path/to/pdbs --output-dir results" + echo "" + echo " # Using GNU parallel:" + echo " $0 --input-pdbs *.pdb --output-dir results" + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Validate required arguments +if [[ -z "$OUTPUT_DIR" ]]; then + echo "Error: --output-dir is required" + exit 1 +fi + +if [[ -z "$INPUT_FOLDER" && -z "$INPUT_PDBS" ]]; then + echo "Error: Either --input-folder or --input-pdbs is required" + exit 1 +fi + +# Get list of PDB files +if [[ -n "$INPUT_FOLDER" ]]; then + PDB_FILES=($(find "$INPUT_FOLDER" -name "*.pdb" | sort)) + echo "Found ${#PDB_FILES[@]} PDB files in $INPUT_FOLDER" +else + PDB_FILES=($INPUT_PDBS) + echo "Processing ${#PDB_FILES[@]} specified PDB files" +fi + +if [[ ${#PDB_FILES[@]} -eq 0 ]]; then + echo "Error: No PDB files found" + exit 1 +fi + +# Create output directory +mkdir -p "$OUTPUT_DIR" + +echo "==========================================" +echo "Swarm Generation Parallelization Helper" +echo "==========================================" +echo "PDB files to process: ${#PDB_FILES[@]}" +echo "Output directory: $OUTPUT_DIR" +echo "Swarms per structure: $NUM_SWARMS" +echo "Steps per swarm: $SWARM_STEPS" +echo "Save frequency: $SAVE_FREQUENCY" +echo "" + +# Method 1: SLURM Job Array +echo "=== SLURM Job Array Approach ===" +echo "To submit as a SLURM job array, create a file 'submit_swarms.sh':" +echo "" +cat << 'EOF' +#!/bin/bash +#SBATCH --job-name=swarms +#SBATCH --array=0-NUM_STRUCTURES-1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --mem=8GB +#SBATCH --time=24:00:00 +#SBATCH --output=swarms_%A_%a.out +#SBATCH --error=swarms_%A_%a.err + +# Get PDB file for this array job +PDB_FILES=(PDB_FILE_LIST) +PDB_FILE=${PDB_FILES[$SLURM_ARRAY_TASK_ID]} + +# Run swarm generation for single structure +python scripts/generate_data/generate_swarms.py \ + --input-pdbs "$PDB_FILE" \ + --output-dir OUTPUT_DIR \ + --single-structure \ + --structure-index $SLURM_ARRAY_TASK_ID \ + --num-swarms NUM_SWARMS \ + --swarm-steps SWARM_STEPS \ + --save-frequency SAVE_FREQUENCY +EOF + +# Create actual SLURM script +SLURM_SCRIPT="submit_swarms_$(date +%Y%m%d_%H%M%S).sh" +sed "s/NUM_STRUCTURES/$((${#PDB_FILES[@]}-1))/" << 'EOF' > "$SLURM_SCRIPT" +#!/bin/bash +#SBATCH --job-name=swarms +#SBATCH --array=0-NUM_STRUCTURES +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --mem=8GB +#SBATCH --time=24:00:00 +#SBATCH --output=swarms_%A_%a.out +#SBATCH --error=swarms_%A_%a.err + +# Get PDB file for this array job +EOF + +echo "PDB_FILES=(" >> "$SLURM_SCRIPT" +for pdb in "${PDB_FILES[@]}"; do + echo " \"$pdb\"" >> "$SLURM_SCRIPT" +done +echo ")" >> "$SLURM_SCRIPT" + +cat << EOF >> "$SLURM_SCRIPT" +PDB_FILE=\${PDB_FILES[\$SLURM_ARRAY_TASK_ID]} + +# Run swarm generation for single structure +python scripts/generate_data/generate_swarms.py \\ + --input-pdbs "\$PDB_FILE" \\ + --output-dir "$OUTPUT_DIR" \\ + --single-structure \\ + --structure-index \$SLURM_ARRAY_TASK_ID \\ + --num-swarms $NUM_SWARMS \\ + --swarm-steps $SWARM_STEPS \\ + --save-frequency $SAVE_FREQUENCY +EOF + +echo "Created SLURM script: $SLURM_SCRIPT" +echo "To submit: sbatch $SLURM_SCRIPT" +echo "" + +# Method 2: GNU Parallel +echo "=== GNU Parallel Approach ===" +echo "To run with GNU parallel:" + +# Create parallel command file +PARALLEL_SCRIPT="run_swarms_parallel_$(date +%Y%m%d_%H%M%S).sh" +cat << EOF > "$PARALLEL_SCRIPT" +#!/bin/bash + +# Function to process a single PDB file +process_pdb() { + local pdb_file="\$1" + local structure_idx="\$2" + + echo "Processing \$pdb_file (structure \$structure_idx)" + + python scripts/generate_data/generate_swarms.py \\ + --input-pdbs "\$pdb_file" \\ + --output-dir "$OUTPUT_DIR" \\ + --single-structure \\ + --structure-index "\$structure_idx" \\ + --num-swarms $NUM_SWARMS \\ + --swarm-steps $SWARM_STEPS \\ + --save-frequency $SAVE_FREQUENCY +} + +export -f process_pdb + +# Run in parallel (adjust -j for number of parallel jobs) +parallel -j 4 process_pdb {1} {#} ::: \\ +EOF + +for pdb in "${PDB_FILES[@]}"; do + echo " \"$pdb\" \\" >> "$PARALLEL_SCRIPT" +done + +# Remove last backslash +sed -i '$ s/ \\$//' "$PARALLEL_SCRIPT" + +chmod +x "$PARALLEL_SCRIPT" +echo "Created parallel script: $PARALLEL_SCRIPT" +echo "To run: ./$PARALLEL_SCRIPT" +echo "" + +# Method 3: Simple Background Jobs +echo "=== Background Jobs Approach ===" +BACKGROUND_SCRIPT="run_swarms_background_$(date +%Y%m%d_%H%M%S).sh" +cat << EOF > "$BACKGROUND_SCRIPT" +#!/bin/bash + +echo "Running swarm generation with background jobs..." + +# Process each PDB file in background (limit concurrent jobs) +max_jobs=4 # Adjust based on your system +job_count=0 + +EOF + +for i in "${!PDB_FILES[@]}"; do + cat << EOF >> "$BACKGROUND_SCRIPT" +# Wait if we've hit the job limit +while [ \$(jobs -r | wc -l) -ge \$max_jobs ]; do + sleep 1 +done + +echo "Starting structure $i: ${PDB_FILES[$i]}" +python scripts/generate_data/generate_swarms.py \\ + --input-pdbs "${PDB_FILES[$i]}" \\ + --output-dir "$OUTPUT_DIR" \\ + --single-structure \\ + --structure-index $i \\ + --num-swarms $NUM_SWARMS \\ + --swarm-steps $SWARM_STEPS \\ + --save-frequency $SAVE_FREQUENCY & + +EOF +done + +cat << 'EOF' >> "$BACKGROUND_SCRIPT" + +# Wait for all background jobs to complete +echo "Waiting for all jobs to complete..." +wait + +echo "All swarm generation jobs completed!" +EOF + +chmod +x "$BACKGROUND_SCRIPT" +echo "Created background jobs script: $BACKGROUND_SCRIPT" +echo "To run: ./$BACKGROUND_SCRIPT" +echo "" + +# Method 4: Single command (no parallelization) +echo "=== Single Process Approach ===" +echo "To run all structures in a single process:" +if [[ -n "$INPUT_FOLDER" ]]; then + SINGLE_CMD="python scripts/generate_data/generate_swarms.py --input-folder \"$INPUT_FOLDER\"" +else + SINGLE_CMD="python scripts/generate_data/generate_swarms.py --input-pdbs" + for pdb in "${PDB_FILES[@]}"; do + SINGLE_CMD="$SINGLE_CMD \"$pdb\"" + done +fi +SINGLE_CMD="$SINGLE_CMD --output-dir \"$OUTPUT_DIR\" --num-swarms $NUM_SWARMS --swarm-steps $SWARM_STEPS --save-frequency $SAVE_FREQUENCY" + +echo "$SINGLE_CMD" +echo "" + +echo "==========================================" +echo "Choose the parallelization method that best fits your computing environment:" +echo "1. SLURM job arrays - Best for HPC clusters" +echo "2. GNU parallel - Good for multi-core workstations" +echo "3. Background jobs - Simple shell-based parallelization" +echo "4. Single process - No parallelization, simplest approach" +echo "==========================================" \ No newline at end of file diff --git a/scripts/scrape_grid_points.py b/scripts/scrape_grid_points.py new file mode 100644 index 0000000..5a58bbe --- /dev/null +++ b/scripts/scrape_grid_points.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +""" +Script to scrape grid point information from trajectory filenames in the enhanced_full_swarm dataset. + +This script analyzes the filenames in /data2/sules/ALA_ALA_enhanced_full_swarm/train to extract +grid codes and provide statistics about the dataset structure. +""" + +import os +import re +from pathlib import Path +from collections import defaultdict, Counter +import pandas as pd +import argparse + +def extract_grid_code_from_filename(filename): + """ + Extract grid code from trajectory filename. + + Expected format: swarm_1ps_{grid_code}_{traj_code}.xtc + Example: swarm_1ps_042_003.xtc -> grid_code = 042 + + Args: + filename: Name of the trajectory file + + Returns: + Grid code as string, or None if pattern doesn't match + """ + # Pattern to match swarm trajectory files + pattern = r'swarm_1ps_(\d{3})_(\d{3})\.xtc' + match = re.match(pattern, filename) + + if match: + grid_code = match.group(1) + traj_code = match.group(2) + return grid_code, traj_code + else: + return None, None + +def scrape_grid_points(data_dir, output_file=None): + """ + Scrape grid point information from trajectory directory. + + Args: + data_dir: Path to directory containing trajectory files + output_file: Optional path to save results as CSV + + Returns: + Dictionary with grid point statistics + """ + data_path = Path(data_dir) + + if not data_path.exists(): + raise FileNotFoundError(f"Directory does not exist: {data_dir}") + + print(f"Scraping grid points from: {data_dir}") + print("=" * 60) + + # Collect grid point information + grid_data = [] + grid_codes = set() + traj_codes = set() + grid_traj_counts = defaultdict(list) + + # Scan all files in directory + all_files = list(data_path.iterdir()) + xtc_files = [f for f in all_files if f.suffix == '.xtc'] + + print(f"Total files in directory: {len(all_files)}") + print(f"XTC trajectory files: {len(xtc_files)}") + print() + + # Process each XTC file + for file_path in xtc_files: + filename = file_path.name + grid_code, traj_code = extract_grid_code_from_filename(filename) + + if grid_code is not None and traj_code is not None: + grid_data.append({ + 'filename': filename, + 'grid_code': grid_code, + 'traj_code': traj_code, + 'grid_point': int(grid_code), + 'trajectory': int(traj_code) + }) + + grid_codes.add(grid_code) + traj_codes.add(traj_code) + grid_traj_counts[grid_code].append(traj_code) + else: + print(f"Warning: Could not parse filename: {filename}") + + # Create DataFrame for analysis + df = pd.DataFrame(grid_data) + + # Print statistics + print("GRID POINT STATISTICS") + print("=" * 60) + print(f"Total valid trajectory files: {len(grid_data)}") + print(f"Unique grid codes: {len(grid_codes)}") + print(f"Unique trajectory codes: {len(traj_codes)}") + print() + + print("Grid code range:") + if grid_codes: + grid_nums = sorted([int(gc) for gc in grid_codes]) + print(f" Min: {min(grid_nums):03d}") + print(f" Max: {max(grid_nums):03d}") + print(f" Grid codes: {', '.join(sorted(grid_codes))}") + print() + + print("Trajectory code range:") + if traj_codes: + traj_nums = sorted([int(tc) for tc in traj_codes]) + print(f" Min: {min(traj_nums):03d}") + print(f" Max: {max(traj_nums):03d}") + print(f" Trajectory codes: {', '.join(sorted(traj_codes))}") + print() + + # Trajectories per grid point + trajs_per_grid = [len(trajs) for trajs in grid_traj_counts.values()] + if trajs_per_grid: + print("Trajectories per grid point:") + print(f" Min: {min(trajs_per_grid)}") + print(f" Max: {max(trajs_per_grid)}") + print(f" Mean: {sum(trajs_per_grid) / len(trajs_per_grid):.1f}") + print() + + # Count distribution + traj_count_dist = Counter(trajs_per_grid) + print("Distribution of trajectories per grid:") + for count, freq in sorted(traj_count_dist.items()): + print(f" {count} trajectories: {freq} grid points") + print() + + # Check for missing trajectories + if grid_codes and traj_codes: + expected_total = len(grid_codes) * len(traj_codes) + actual_total = len(grid_data) + print(f"Expected files (grid × traj): {expected_total}") + print(f"Actual files found: {actual_total}") + if actual_total != expected_total: + print(f"Missing files: {expected_total - actual_total}") + + # Find missing combinations + missing = [] + for gc in grid_codes: + for tc in traj_codes: + if tc not in grid_traj_counts[gc]: + missing.append(f"swarm_1ps_{gc}_{tc}.xtc") + + if missing and len(missing) <= 20: # Only print if not too many + print("Missing files:") + for mf in missing: + print(f" {mf}") + elif missing: + print(f" (too many to list - {len(missing)} missing files)") + print() + + # Sample of grid points + if len(grid_data) > 0: + print("Sample trajectories:") + sample_size = min(10, len(grid_data)) + sample_df = df.sample(n=sample_size, random_state=42) + for _, row in sample_df.iterrows(): + print(f" {row['filename']} -> Grid: {row['grid_code']}, Traj: {row['traj_code']}") + + # Save to file if requested + if output_file: + output_path = Path(output_file) + df.to_csv(output_path, index=False) + print(f"\nResults saved to: {output_path}") + + # Also save summary statistics + summary_file = output_path.with_suffix('.summary.txt') + with open(summary_file, 'w') as f: + f.write(f"Grid Point Analysis Summary\n") + f.write(f"Directory: {data_dir}\n") + f.write(f"Total trajectory files: {len(grid_data)}\n") + f.write(f"Unique grid codes: {len(grid_codes)}\n") + f.write(f"Unique trajectory codes: {len(traj_codes)}\n") + f.write(f"Grid codes: {', '.join(sorted(grid_codes))}\n") + f.write(f"Trajectory codes: {', '.join(sorted(traj_codes))}\n") + + print(f"Summary saved to: {summary_file}") + + return { + 'data': df, + 'grid_codes': sorted(grid_codes), + 'traj_codes': sorted(traj_codes), + 'grid_traj_counts': dict(grid_traj_counts), + 'total_files': len(grid_data), + 'unique_grids': len(grid_codes), + 'unique_trajs': len(traj_codes) + } + +def main(): + parser = argparse.ArgumentParser(description='Scrape grid point information from trajectory files') + parser.add_argument('--data-dir', '-d', + default='/data2/sules/ALA_ALA_enhanced_full_swarm/train', + help='Directory containing trajectory files') + parser.add_argument('--output', '-o', + help='Output CSV file for results') + parser.add_argument('--also-val', action='store_true', + help='Also analyze validation set') + + args = parser.parse_args() + + # Analyze training set + print("ANALYZING TRAINING SET") + print("=" * 80) + train_results = scrape_grid_points(args.data_dir, args.output) + + # Optionally analyze validation set + if args.also_val: + val_dir = args.data_dir.replace('/train', '/val') + if os.path.exists(val_dir): + print("\n" + "=" * 80) + print("ANALYZING VALIDATION SET") + print("=" * 80) + val_output = args.output.replace('.csv', '_val.csv') if args.output else None + val_results = scrape_grid_points(val_dir, val_output) + + # Compare train vs val + print("\n" + "=" * 80) + print("TRAIN vs VAL COMPARISON") + print("=" * 80) + train_grids = set(train_results['grid_codes']) + val_grids = set(val_results['grid_codes']) + + print(f"Train grid codes: {len(train_grids)}") + print(f"Val grid codes: {len(val_grids)}") + print(f"Overlap: {len(train_grids & val_grids)}") + + if train_grids & val_grids: + print("Warning: Train and validation sets have overlapping grid codes!") + print(f"Overlapping codes: {sorted(train_grids & val_grids)}") + else: + print("Good: No overlap between train and validation grid codes") + else: + print(f"\nValidation directory not found: {val_dir}") + +if __name__ == "__main__": + main() diff --git a/scripts/slurm/check_sweep_runs.sh b/scripts/slurm/check_sweep_runs.sh new file mode 100755 index 0000000..277e2e9 --- /dev/null +++ b/scripts/slurm/check_sweep_runs.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +# Helper script to check how many runs are in the enhanced sampling sweep +# and preview the array range for the SLURM script + +set -e + +# Configuration +WANDB_GROUP="fake_enhanced_data_jul_11_sweep" +ENTITY="sule-shashank" +PROJECT="jamun" + +echo "=========================================" +echo "Enhanced Sampling Sweep - Run Check" +echo "Group: $WANDB_GROUP" +echo "Entity/Project: $ENTITY/$PROJECT" +echo "=========================================" + +# Initialize conda +echo "Initializing conda..." +source ~/.bashrc +eval "$(conda shell.bash hook)" + +# Activate conda environment +echo "Activating jamun environment..." +conda activate jamun + +# Python script to fetch wandb runs and show summary +python -c " +import wandb +import sys +from collections import defaultdict + +# Configuration +entity = '$ENTITY' +project = '$PROJECT' +group = '$WANDB_GROUP' + +try: + # Initialize wandb API + api = wandb.Api() + + # Get all runs from the specified group + print(f'Fetching runs from {entity}/{project} with group \"{group}\"...') + runs = api.runs(f'{entity}/{project}', filters={'group': group}) + runs_list = list(runs) + + print(f'\\nFound {len(runs_list)} runs in group \"{group}\"') + + if len(runs_list) == 0: + print('No runs found in this group!') + sys.exit(1) + + # Collect parameter combinations + param_combinations = [] + conditioner_counts = defaultdict(int) + sigma_counts = defaultdict(int) + lag_time_counts = defaultdict(int) + + for i, run in enumerate(runs_list): + try: + config = run.config + cfg = config.get('cfg', {}) + + conditioner = cfg.get('model', {}).get('conditioner', {}).get('_target_', 'Unknown') + sigma = cfg.get('model', {}).get('sigma_distribution', {}).get('sigma', 'Unknown') + total_lag_time = cfg.get('data', {}).get('datamodule', {}).get('datasets', {}).get('train', {}).get('total_lag_time', 'Unknown') + + conditioner_name = conditioner.split('.')[-1] if conditioner != 'Unknown' else 'Unknown' + + param_combinations.append({ + 'index': i, + 'name': run.name, + 'run_path': '/'.join(run.path), + 'conditioner': conditioner_name, + 'sigma': sigma, + 'lag_time': total_lag_time, + 'state': run.state + }) + + conditioner_counts[conditioner_name] += 1 + sigma_counts[sigma] += 1 + lag_time_counts[total_lag_time] += 1 + + except Exception as e: + print(f'Warning: Could not extract parameters for run {i}: {e}') + param_combinations.append({ + 'index': i, + 'name': run.name, + 'run_path': '/'.join(run.path), + 'conditioner': 'Error', + 'sigma': 'Error', + 'lag_time': 'Error', + 'state': run.state + }) + + # Print summary + print('\\n========================================') + print('PARAMETER DISTRIBUTION SUMMARY:') + print('========================================') + + print('\\nConditioner types:') + for conditioner, count in sorted(conditioner_counts.items()): + print(f' {conditioner}: {count} runs') + + print('\\nSigma values:') + for sigma, count in sorted(sigma_counts.items()): + print(f' {sigma}: {count} runs') + + print('\\nLag time values:') + for lag_time, count in sorted(lag_time_counts.items()): + print(f' {lag_time}: {count} runs') + + # Print first 5 runs as examples + print('\\n========================================') + print('FIRST 5 RUNS (EXAMPLES):') + print('========================================') + print(f'{'Index':<6} {'Name':<25} {'Conditioner':<18} {'Sigma':<8} {'Lag':<5} {'State':<10}') + print('-' * 75) + + for combo in param_combinations[:5]: + print(f'{combo[\"index\"]:<6} {combo[\"name\"]:<25} {combo[\"conditioner\"]:<18} {combo[\"sigma\"]:<8} {combo[\"lag_time\"]:<5} {combo[\"state\"]:<10}') + + if len(param_combinations) > 5: + print(f'... and {len(param_combinations) - 5} more runs') + + print('\\n========================================') + print('SLURM ARRAY CONFIGURATION:') + print('========================================') + print(f'Total runs: {len(runs_list)}') + print(f'Array range: 0-{len(runs_list) - 1}') + print(f'\\nUpdate your SLURM script with:') + print(f'#SBATCH --array=0-{len(runs_list) - 1}') + print('\\nTo submit the job:') + print('sbatch scripts/slurm/sweep_enhanced_sampling.sh') + print('\\nTo submit a subset (e.g., first 5 runs):') + print('sbatch --array=0-4 scripts/slurm/sweep_enhanced_sampling.sh') + +except Exception as e: + print(f'Error: {e}', file=sys.stderr) + import traceback + traceback.print_exc() + sys.exit(1) +" + +echo "=========================================" +echo "Run check completed!" +echo "=========================================" \ No newline at end of file diff --git a/scripts/slurm/debug_sweep_enhanced_sampling.sh b/scripts/slurm/debug_sweep_enhanced_sampling.sh new file mode 100755 index 0000000..bfd8b34 --- /dev/null +++ b/scripts/slurm/debug_sweep_enhanced_sampling.sh @@ -0,0 +1,189 @@ +#!/bin/bash + +# Debug version of the enhanced sampling sweep script +# Usage: ./debug_sweep_enhanced_sampling.sh + +set -e + +# Check if run index is provided +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "Please provide the 0-based index of the run to process." + exit 1 +fi + +RUN_INDEX=$1 + +# Set up environment +export JAMUN_ROOT_PATH=/homefs/home/sules/jamun +cd $JAMUN_ROOT_PATH + +# Initialize conda +echo "Initializing conda..." +source ~/.bashrc +eval "$(conda shell.bash hook)" + +# Activate conda environment +echo "Activating jamun environment..." +conda activate jamun + +# Configuration +WANDB_GROUP="fake_enhanced_data_jul_11_sweep" +ENTITY="sule-shashank" +PROJECT="jamun" + +echo "=========================================" +echo "DEBUG MODE - Enhanced Sampling Sweep" +echo "Working directory: $(pwd)" +echo "Processing run index: $RUN_INDEX from group: $WANDB_GROUP" +echo "=========================================" + +# Python script to fetch wandb runs and build jamun_sample command +python -c " +import wandb +import sys +import os + +# Configuration +entity = '$ENTITY' +project = '$PROJECT' +group = '$WANDB_GROUP' +run_index = $RUN_INDEX + +try: + # Initialize wandb API + api = wandb.Api() + + # Get all runs from the specified group + print(f'Fetching runs from {entity}/{project} with group \"{group}\"...') + runs = api.runs(f'{entity}/{project}', filters={'group': group}) + runs_list = list(runs) + + print(f'Found {len(runs_list)} runs in group \"{group}\"') + + # Check if run_index is valid + if run_index >= len(runs_list) or run_index < 0: + print(f'Error: Run index {run_index} is out of bounds. The group has {len(runs_list)} runs (indices 0 to {len(runs_list) - 1}).', file=sys.stderr) + sys.exit(1) + + # Get the specific run + run = runs_list[run_index] + run_path = '/'.join(run.path) + + print(f'\\nProcessing run: {run.name} ({run_path})') + print(f'Run URL: {run.url}') + print(f'Run state: {run.state}') + + # Extract parameters from the run config + config = run.config + print(f'\\nAvailable config keys: {list(config.keys())}') + + cfg_key = 'cfg' # This is the key used in jamun configs + + if cfg_key not in config: + print(f'Error: Config key \"{cfg_key}\" not found in run config.', file=sys.stderr) + sys.exit(1) + + cfg = config[cfg_key] + print(f'Config structure keys: {list(cfg.keys())}') + + # Extract required parameters with detailed debugging + try: + print(f'\\nExtracting parameters...') + + # Extract conditioner + conditioner = cfg['model']['conditioner']['_target_'] + print(f' ✓ Conditioner: {conditioner}') + + # Extract sigma + sigma = cfg['model']['sigma_distribution']['sigma'] + print(f' ✓ Sigma: {sigma}') + + # Extract total_lag_time + total_lag_time = cfg['data']['datamodule']['datasets']['train']['total_lag_time'] + print(f' ✓ Total Lag Time: {total_lag_time}') + + # Optional: Extract other useful parameters + model_arch_N_structures = cfg.get('model', {}).get('arch', {}).get('N_structures', total_lag_time) + print(f' ✓ Model N_structures: {model_arch_N_structures}') + + except KeyError as e: + print(f'Error: Could not extract required parameter: {e}', file=sys.stderr) + print(f'Available model keys: {cfg.get(\"model\", {}).keys()}', file=sys.stderr) + if 'model' in cfg: + print(f'Available model.conditioner keys: {cfg[\"model\"].get(\"conditioner\", {}).keys()}', file=sys.stderr) + print(f'Available model.sigma_distribution keys: {cfg[\"model\"].get(\"sigma_distribution\", {}).keys()}', file=sys.stderr) + if 'data' in cfg: + print(f'Available data keys: {cfg[\"data\"].keys()}', file=sys.stderr) + sys.exit(1) + + # Ensure all required parameters are present + if not all(v is not None for v in [run_path, conditioner, total_lag_time, sigma]): + print(f'Error: Could not extract all required parameters for run at index {run_index}.', file=sys.stderr) + sys.exit(1) + + # Create a meaningful run group name for sampling + conditioner_name = conditioner.split('.')[-1] # Get class name without module path + sampling_group = 'sample_enhanced_sampling_from_jul_11' + + # Create tags for better organization + tags = [ + f'sweep_run_{run_index}', + f'conditioner_{conditioner_name}', + f'sigma_{sigma}', + f'lag_time_{total_lag_time}', + 'enhanced_sampling', + 'sample_from_sweep' + ] + tags_string = '[' + ', '.join(f'\"{tag}\"' for tag in tags) + ']' + + print('\\n========================================') + print(f'Parameters extracted successfully!') + print(f' Run Path: {run_path}') + print(f' Conditioner: {conditioner}') + print(f' Conditioner Name: {conditioner_name}') + print(f' Sigma: {sigma}') + print(f' Total Lag Time: {total_lag_time}') + print(f' Sample Group: {sampling_group}') + print('========================================\\n') + + # Build the jamun_sample command + # Note: We don't override model.conditioner._target_ because the checkpoint + # already contains the correct conditioner configuration with all parameters + cmd_parts = [ + 'jamun_sample', + '--config-dir=configs', + 'experiment=sample_enhanced_sampling_single_shape.yaml', + f'++wandb_train_run_path={run_path}', + f'++init_datasets.total_lag_time={total_lag_time}', + f'++sigma={sigma}', + f'++delta={sigma}', + f'++logger.wandb.group={sampling_group}', + f'++logger.wandb.tags={tags_string}', + f'++logger.wandb.notes=\"Sampling from enhanced sampling sweep run {run_index} - {conditioner_name} sigma={sigma} lag={total_lag_time}\"', + f'++run_key=sweep_sample_{run_index}_{conditioner_name}_sigma_{sigma}_lag_{total_lag_time}' + ] + + print('DEBUG: Generated jamun_sample command:') + print('=' * 80) + cmd_string = ' \\\\\\n '.join(cmd_parts) + print(cmd_string) + print('=' * 80) + + print('\\nDEBUG: Single line command:') + print('=' * 80) + print(' '.join(cmd_parts)) + print('=' * 80) + + print(f'\\nDEBUG: Successfully processed run index {run_index}') + +except Exception as e: + print(f'Error: {e}', file=sys.stderr) + import traceback + traceback.print_exc() + sys.exit(1) +" + +echo "=========================================" +echo "DEBUG: Finished processing run index: $RUN_INDEX" +echo "=========================================" \ No newline at end of file diff --git a/scripts/slurm/generate_swarms_batch_final.sh b/scripts/slurm/generate_swarms_batch_final.sh new file mode 100755 index 0000000..1364de1 --- /dev/null +++ b/scripts/slurm/generate_swarms_batch_final.sh @@ -0,0 +1,183 @@ +#!/bin/bash + +# FINAL CORRECTED: Master script to generate SLURM batch jobs for swarm generation +# Correct equilibration steps: 50k restrained, 10 unrestrained +# Sequential processing: equilibration + swarms per structure before moving to next + +echo "🚀 FINAL CORRECTED Swarm Batch Generation Script" +echo "================================================" + +# Configuration +STRUCTURES_PER_BATCH=20 +INPUT_DIR="data/swarm_data/test" +OUTPUT_DIR="data/swarm_data/test/swarm_results" +SCRIPT_DIR="scripts/slurm/batches_final" + +# Equilibration settings (CORRECTED) +NVT_RESTRAINT_STEPS=50000 # 50k as requested +NPT_RESTRAINT_STEPS=50000 # 50k as requested +NVT_EQUIL_STEPS=10 # 10 steps (not 10k!) as requested +NPT_EQUIL_STEPS=10 # 10 steps (not 10k!) as requested + +# Swarm settings +NUM_SWARMS=5 +SWARM_STEPS=500 # 1ps ÷ 2fs/step = 500 steps per swarm +SAVE_FREQUENCY=10 + +# Create batch script directory +mkdir -p "$SCRIPT_DIR" + +# Get list of all PDB files +PDB_FILES=($(ls -1 "$INPUT_DIR"/*.pdb | sort)) +TOTAL_STRUCTURES=${#PDB_FILES[@]} + +echo "📊 Configuration:" +echo " Total structures: $TOTAL_STRUCTURES" +echo " Structures per batch: $STRUCTURES_PER_BATCH" +echo " Equilibration steps: NVT/NPT restrained=50k, unrestrained=10 (CORRECTED)" +echo " Swarms: $NUM_SWARMS × 1ps ($SWARM_STEPS steps) per structure" +echo " Workflow: Sequential (equil+swarms per structure)" +echo "" + +# Calculate number of batches needed +NUM_BATCHES=$(( (TOTAL_STRUCTURES + STRUCTURES_PER_BATCH - 1) / STRUCTURES_PER_BATCH )) + +echo "📝 Generating $NUM_BATCHES FINAL CORRECTED batch scripts..." + +for ((batch=1; batch<=NUM_BATCHES; batch++)); do + # Calculate structure range for this batch + start_idx=$(( (batch - 1) * STRUCTURES_PER_BATCH )) + end_idx=$(( start_idx + STRUCTURES_PER_BATCH - 1 )) + + # Don't exceed total number of structures + if [ $end_idx -ge $TOTAL_STRUCTURES ]; then + end_idx=$(( TOTAL_STRUCTURES - 1 )) + fi + + structures_in_batch=$(( end_idx - start_idx + 1 )) + + echo " Batch $batch: global indices $start_idx-$end_idx ($structures_in_batch structures)" + + # Create SLURM script for this batch + script_file="$SCRIPT_DIR/swarms_batch_${batch}_final.sh" + + cat > "$script_file" << EOF +#!/bin/bash +#SBATCH --job-name=swarms_batch_${batch}_final +#SBATCH --partition=gpu2 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=16GB +#SBATCH --time=4:00:00 +#SBATCH --output=swarms_batch_${batch}_final_%j.out +#SBATCH --error=swarms_batch_${batch}_final_%j.err + +echo "SLURM_JOB_ID = \$SLURM_JOB_ID" +echo "hostname = \$(hostname)" +echo "Starting FINAL CORRECTED swarm batch ${batch} on GPU..." +echo "" + +# Print GPU info +nvidia-smi + +# Activate conda environment +echo "Activating conda environment..." +source /homefs/home/vanib/miniforge3/etc/profile.d/conda.sh +conda activate jamun +echo "Python path: \$(which python)" +echo "Conda environment: \$CONDA_DEFAULT_ENV" +echo "" + +# Change to working directory +cd /homefs/home/vanib/jamun + +# Create full PDB file list +ALL_PDB_FILES=( +EOF + + # Add ALL PDB files to each script (needed for structure index validation) + for ((i=0; i> "$script_file" + done + + cat >> "$script_file" << EOF +) + +echo "🧬 Processing structures with GLOBAL indices $start_idx to $end_idx:" +echo "Using full PDB list of \${#ALL_PDB_FILES[@]} files for proper indexing" +echo "Sequential workflow: equilibration + swarms per structure" +echo "" + +# Process each structure in this batch (SEQUENTIAL: equil+swarms per structure) +EOF + + # Add individual structure processing (SEQUENTIAL) + for ((global_idx=start_idx; global_idx<=end_idx; global_idx++)); do + pdb_file="${PDB_FILES[$global_idx]}" + cat >> "$script_file" << EOF + +echo "⚖️ Processing structure $global_idx: \$(basename "${pdb_file}")" +echo "============================================================" + +# SINGLE COMMAND: Do both equilibration AND swarms for this structure +echo "🔄 Processing structure $global_idx: equilibration + $NUM_SWARMS × 1ps swarms..." +python scripts/generate_data/generate_swarms.py \\ + --input-pdbs "\${ALL_PDB_FILES[@]}" \\ + --output-dir "$OUTPUT_DIR" \\ + --single-structure \\ + --structure-index $global_idx \\ + --nvt-restraint-steps $NVT_RESTRAINT_STEPS \\ + --npt-restraint-steps $NPT_RESTRAINT_STEPS \\ + --nvt-equil-steps $NVT_EQUIL_STEPS \\ + --npt-equil-steps $NPT_EQUIL_STEPS \\ + --num-swarms $NUM_SWARMS \\ + --swarm-steps $SWARM_STEPS \\ + --save-frequency $SAVE_FREQUENCY \\ + --save-intermediate-files + +if [ \$? -ne 0 ]; then + echo "❌ Processing failed for structure $global_idx" + exit 1 +fi + +echo "✅ COMPLETED structure $global_idx: \$(basename "${pdb_file}") (equilibration + $NUM_SWARMS swarms)" +echo "" +EOF + done + + cat >> "$script_file" << EOF + +# Summary +echo "📈 BATCH ${batch} SUMMARY" +echo "======================" +echo " Global structure indices: $start_idx to $end_idx" +echo " Structures processed: $structures_in_batch" +echo " Swarms per structure: $NUM_SWARMS" +echo " Total swarms generated: $(( structures_in_batch * NUM_SWARMS ))" +echo " Swarm duration: 1ps each" +echo " Equilibration: 50k restrained, 10 unrestrained steps" +echo " Workflow: Sequential (equil+swarms per structure)" +echo "" +echo "🎉 Batch ${batch} completed successfully!" +EOF + + chmod +x "$script_file" +done + +echo "" +echo "✅ Generated $NUM_BATCHES FINAL CORRECTED batch scripts in $SCRIPT_DIR/" +echo "" +echo "📋 To submit jobs:" +echo " # Submit first batch only (for testing):" +echo " sbatch $SCRIPT_DIR/swarms_batch_1_final.sh" +echo "" +echo " # After verification, submit remaining batches:" +echo " for i in {2..$NUM_BATCHES}; do" +echo " sbatch $SCRIPT_DIR/swarms_batch_\${i}_final.sh" +echo " done" +echo "" +echo "🔧 FINAL CORRECTIONS APPLIED:" +echo " ✅ Correct equilibration steps: 50k restrained, 10 unrestrained" +echo " ✅ Sequential workflow: equilibration + swarms per structure" +echo " ✅ Passes all PDB files for proper structure index validation" +echo " ✅ Each structure gets unique AA_XXX directory" \ No newline at end of file diff --git a/scripts/slurm/noise_check.sh b/scripts/slurm/noise_check.sh new file mode 100755 index 0000000..0286321 --- /dev/null +++ b/scripts/slurm/noise_check.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +#SBATCH --job-name=noise_check +#SBATCH --partition gpu2 +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node=1 # Number of agents to run in parallel on this node +#SBATCH --gpus-per-node=1 # Assign one GPU to each agent +#SBATCH --cpus-per-task=12 +#SBATCH --time 1-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array 3-4 + +# Print job information +echo "Starting job $SLURM_JOB_ID, array task $SLURM_ARRAY_TASK_ID" +echo "Running on node: $(hostname)" +echo "Job started at: $(date)" + +# Set up environment +source ~/.bashrc +conda activate jamun + +# Change to project directory +cd /homefs/home/sules/jamun + +# Run training with the corresponding model config +echo "Training with experiment: ala_ala_denoiser_experiment_model${SLURM_ARRAY_TASK_ID}" +jamun_train --config-dir=configs experiment=ala_ala_denoiser_experiment_model${SLURM_ARRAY_TASK_ID} + +echo "Job completed at: $(date)" \ No newline at end of file diff --git a/scripts/slurm/run_denoiser_experiments.py b/scripts/slurm/run_denoiser_experiments.py new file mode 100644 index 0000000..457201a --- /dev/null +++ b/scripts/slurm/run_denoiser_experiments.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +""" +Script to run the three ALA_ALA denoiser experiments. + +Experiment Setup: +================= + +1. Model 1: Denoiser with SelfConditioner, 2 structures, noise level sigma=0.04 + - Uses real lagged frames from trajectory as hidden states + - SelfConditioner just repeats the current position + +2. Model 2: Denoiser with SelfConditioner, 2 structures, noise level sigma/sqrt(2)≈0.0283 + - Same as Model 1 but with reduced noise level + - Uses real lagged frames from trajectory as hidden states + +3. Model 3: Denoiser with PositionConditioner, 2 structures, noise level sigma=0.04 + - Hidden states are repeated copies of current position (not real trajectory frames) + - PositionConditioner aligns these copies to current position + - Noise is added by the denoiser during training + +Usage: +====== +python run_denoiser_experiments.py [model_number] + +Where model_number is 1, 2, or 3. If no number is provided, all models will be run. +""" + +import subprocess +import sys +import time +from pathlib import Path + +def run_experiment(model_num: int, root_path: str = "/data2/sules/jamun-denoiser-experiments"): + """Run a specific experiment model.""" + config_name = f"ala_ala_denoiser_experiment_model{model_num}" + + # Map model numbers to descriptions + descriptions = { + 1: "Model 1: SelfConditioner, sigma=0.04", + 2: "Model 2: SelfConditioner, sigma/sqrt(2)≈0.0283", + 3: "Model 3: PositionConditioner with repeated position copies, sigma=0.04" + } + + print(f"\n{'='*60}") + print(f"Starting {descriptions[model_num]}") + print(f"Config: {config_name}") + print(f"Output path: {root_path}/model{model_num}") + print(f"{'='*60}\n") + + cmd = [ + "python", "jamun_train.py", + "--config-dir=configs", + f"experiment={config_name}", + f"++paths.root_path={root_path}/model{model_num}", + "++trainer.max_epochs=500", + "++trainer.log_every_n_steps=10" + ] + + print(f"Running command: {' '.join(cmd)}") + + try: + result = subprocess.run(cmd, check=True, cwd=Path(__file__).parent) + print(f"\n✅ Model {model_num} completed successfully!") + return True + except subprocess.CalledProcessError as e: + print(f"\n❌ Model {model_num} failed with error: {e}") + return False + except KeyboardInterrupt: + print(f"\n⚠️ Model {model_num} interrupted by user") + return False + +def main(): + """Main function to run experiments.""" + print(__doc__) + + # Parse command line arguments + if len(sys.argv) > 1: + try: + model_num = int(sys.argv[1]) + if model_num not in [1, 2, 3]: + raise ValueError() + models_to_run = [model_num] + except ValueError: + print("Error: Please provide a valid model number (1, 2, or 3)") + sys.exit(1) + else: + models_to_run = [1, 2, 3] + print("No model specified. Running all three models...") + + # Check if we're in the right directory + if not Path("jamun_train.py").exists(): + print("Error: jamun_train.py not found. Please run this script from the jamun root directory.") + sys.exit(1) + + # Run experiments + start_time = time.time() + results = {} + + for model_num in models_to_run: + print(f"\n\nStarting Model {model_num}...") + results[model_num] = run_experiment(model_num) + + if len(models_to_run) > 1 and model_num != models_to_run[-1]: + print(f"\nWaiting 10 seconds before starting next model...") + time.sleep(10) + + # Print summary + elapsed = time.time() - start_time + print(f"\n\n{'='*60}") + print(f"EXPERIMENT SUMMARY") + print(f"{'='*60}") + print(f"Total time: {elapsed/3600:.2f} hours") + print() + + for model_num in models_to_run: + status = "✅ SUCCESS" if results[model_num] else "❌ FAILED" + print(f"Model {model_num}: {status}") + + print(f"\nResults saved to: /data2/sules/jamun-denoiser-experiments/") + print(f"{'='*60}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/slurm/run_model3_beta_comparison.sh b/scripts/slurm/run_model3_beta_comparison.sh new file mode 100644 index 0000000..fc96d45 --- /dev/null +++ b/scripts/slurm/run_model3_beta_comparison.sh @@ -0,0 +1,54 @@ +#!/bin/bash +#SBATCH --job-name=model3_beta_comparison +#SBATCH --partition=gpu2 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 # Number of agents to run in parallel on this node +#SBATCH --gpus-per-node=1 # Assign one GPU to each agent +#SBATCH --cpus-per-task=12 +#SBATCH --time=1-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=1-2 + +# Create logs directory if it doesn't exist +mkdir -p logs + +eval "$(conda shell.bash hook)" +conda activate jamun + +set -eux + +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "SLURM_ARRAY_TASK_ID = ${SLURM_ARRAY_TASK_ID}" +echo "hostname = $(hostname)" + +export HYDRA_FULL_ERROR=1 + +# Generate unique run key for this experiment +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +nvidia-smi + +# Navigate to project directory +cd /homefs/home/sules/jamun + +# Determine which beta configuration to run based on array task ID +if [ $SLURM_ARRAY_TASK_ID -eq 1 ]; then + # First run: Adam betas (0.9, 0.9) + echo "Running model3 experiment with Adam betas (0.9, 0.9)" + jamun_train --config-dir=configs \ + experiment=ala_ala_denoiser_experiment_model3 \ + ++run_key=$RUN_KEY \ + '++model.optim.betas=[0.9,0.9]' \ + ++logger.wandb.name="model3_beta09_09" +else + # Second run: Adam betas (0.9, 0.999) - PyTorch default + echo "Running model3 experiment with Adam betas (0.9, 0.999)" + jamun_train --config-dir=configs \ + experiment=ala_ala_denoiser_experiment_model3 \ + ++run_key=$RUN_KEY \ + '++model.optim.betas=[0.9,0.999]' \ + ++logger.wandb.name="model3_beta09_0999" +fi + +echo "Model3 beta comparison experiment $SLURM_ARRAY_TASK_ID completed" \ No newline at end of file diff --git a/scripts/slurm/run_single_sample.sh b/scripts/slurm/run_single_sample.sh new file mode 100755 index 0000000..4730be3 --- /dev/null +++ b/scripts/slurm/run_single_sample.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# This script runs sampling for a specific run from a wandb sweep, selected by an index. + +jamun_sample --config-dir=configs experiment=sample_capped_single_shape_conditioning ++wandb_train_run_path=sule-shashank/jamun/zchesftt ++logger.wandb.notes=jumping-sweep-29 +jamun_sample --config-dir=configs experiment=sample_capped_single_shape_conditioning ++wandb_train_run_path=sule-shashank/jamun/jqp09yv1 ++logger.wandb.notes=stellar-sweep-25 \ No newline at end of file diff --git a/scripts/slurm/run_single_sample_from_sweep.sh b/scripts/slurm/run_single_sample_from_sweep.sh new file mode 100755 index 0000000..a53eb1f --- /dev/null +++ b/scripts/slurm/run_single_sample_from_sweep.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +# This script runs sampling for a specific run from a wandb sweep, selected by an index. + +set -e + +# --- Configuration --- +SWEEP_ID="sule-shashank/jamun/evgtrff4" + +# --- Argument Parsing --- +if [ "$#" -ne 1 ]; then + echo "Usage: $0 " + echo "Please provide the 0-based index of the run to process." + exit 1 +fi +RUN_INDEX=$1 + +# --- Main Logic --- +echo "Fetching run at index $RUN_INDEX from sweep $SWEEP_ID..." + +python -c " +import wandb +import subprocess +import sys + +# The sweep ID and run index are passed as command-line arguments +if len(sys.argv) < 3: + print('Usage: python_script.py ', file=sys.stderr) + sys.exit(1) +sweep_id = sys.argv[1] +run_index = int(sys.argv[2]) + +try: + api = wandb.Api() + sweep = api.sweep(sweep_id) + # wandb runs are often ordered from newest to oldest; reverse to make index stable + runs_list = list(reversed(list(sweep.runs))) + + if run_index >= len(runs_list) or run_index < 0: + print(f'Error: Run index {run_index} is out of bounds. The sweep has {len(runs_list)} runs (indices 0 to {len(runs_list) - 1}).', file=sys.stderr) + sys.exit(1) + + run = runs_list[run_index] + run_path = '/'.join(run.path) + conditioner = run.config.get('cfg', {}).get('model', {}).get('conditioner', {}).get('_target_') + total_lag_time = run.config.get('cfg', {}).get('data', {}).get('datamodule', {}).get('datasets', {}).get('train', {}).get('total_lag_time') + sigma = run.config.get('cfg', {}).get('model', {}).get('sigma_distribution', {}).get('sigma') + + # Ensure all required parameters are present + if not all(v is not None for v in [run_path, conditioner, total_lag_time, sigma]): + print(f'Error: Could not extract all required parameters for run at index {run_index}.', file=sys.stderr) + sys.exit(1) + + print('========================================') + print(f'Starting sampling for run at index {run_index}: {run_path}') + print(f' Conditioner: {conditioner}') + print(f' Total Lag Time: {total_lag_time}') + print(f' Sigma: {sigma}') + print('========================================') + + # Execute jamun_sample with the extracted parameters + tags_string = '[' + f'\"{str(conditioner)}\", \"{str(total_lag_time)}\", \"{str(sigma)}\"' + ']' + cmd = [ + 'jamun_sample', + '--config-dir=configs', + 'experiment=sample_capped_single_shape_conditioning.yaml', + f'wandb_train_run_path={run_path}', + f'++init_datasets.total_lag_time={total_lag_time}', + f'++sigma={sigma}', + f'++delta={sigma}', + f'++logger.wandb.group=sampling_from_sweep_{run_index}', + f'++logger.wandb.tags={tags_string}' + ] + + result = subprocess.run(cmd, check=True) + + print('----------------------------------------') + print(f'Finished sampling for run index {run_index}.') + +except subprocess.CalledProcessError as e: + print(f'Error running jamun_sample: {e}', file=sys.stderr) + sys.exit(1) +except Exception as e: + print(f'Error fetching data from wandb: {e}', file=sys.stderr) + sys.exit(1) +" "$SWEEP_ID" "$RUN_INDEX" \ No newline at end of file diff --git a/scripts/slurm/run_sweep_sampling.sh b/scripts/slurm/run_sweep_sampling.sh new file mode 100644 index 0000000..c0a6011 --- /dev/null +++ b/scripts/slurm/run_sweep_sampling.sh @@ -0,0 +1,33 @@ +#!/bin/bash +#SBATCH --job-name=sweep_sampling +#SBATCH --partition gpu2 +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node=1 # Number of agents to run in parallel on this node +#SBATCH --gpus-per-node=1 # Assign one GPU to each agent +#SBATCH --cpus-per-task=12 +#SBATCH --time 1-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array 2-31 + +# Create logs directory if it doesn't exist +mkdir -p logs + +# Print job information +echo "Job ID: $SLURM_JOB_ID" +echo "Array Task ID: $SLURM_ARRAY_TASK_ID" +echo "Running on node: $HOSTNAME" +echo "Starting time: $(date)" + +# Activate conda environment (adjust the environment name as needed) +source ~/.bashrc +conda activate jamun + +# Change to the working directory +cd /homefs/home/sules/jamun + +# Run the sampling script with the array task ID as the run index +echo "Running sampling for sweep run index: $SLURM_ARRAY_TASK_ID" +bash run_single_sample_from_sweep.sh $SLURM_ARRAY_TASK_ID + +echo "Finished sampling for run index: $SLURM_ARRAY_TASK_ID" +echo "End time: $(date)" \ No newline at end of file diff --git a/scripts/slurm/run_train_noise_check.sh b/scripts/slurm/run_train_noise_check.sh new file mode 100755 index 0000000..617fc39 --- /dev/null +++ b/scripts/slurm/run_train_noise_check.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# +# Wrapper script to run train_noise_check.sh for multiple m values +# +# This script loops over m values from 2 to 10 and submits the train_noise_check.sh +# SLURM script for each value. This ensures only 4 parallel jobs are submitted at a time +# (one for each model configuration) rather than submitting all 36 jobs at once. +# +# Usage: ./run_train_noise_check.sh +# + +# Set script directory +# SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TRAIN_SCRIPT="scripts/slurm/train_noise_check.sh" + +# Check if train_noise_check.sh exists +if [ ! -f "$TRAIN_SCRIPT" ]; then + echo "Error: train_noise_check.sh not found at $TRAIN_SCRIPT" + exit 1 +fi + +# Make sure the script is executable +chmod +x "$TRAIN_SCRIPT" + +echo "Starting noise check experiments for m values 2-10" +echo "Each submission will run 4 parallel jobs (one for each model configuration)" +echo "" + +# Loop over m values from 2 to 10 +for M in {2..10}; do + echo "Submitting jobs for M=$M..." + + # Submit the SLURM script with the current m value + JOB_ID=$(sbatch --parsable scripts/slurm/train_noise_check.sh $M) + + if [ $? -eq 0 ]; then + echo " ✓ Successfully submitted job ID: $JOB_ID for M=$M" + echo " This will run 4 parallel jobs (array 0-3) for the 4 model configurations" + + # Wait for all array jobs to complete before submitting next batch + # echo " ⏳ Waiting for job ID $JOB_ID to complete before submitting next batch..." + echo " ⏳ Submitted ID $JOB_ID..." + # # Wait for the job to finish (all array tasks) + # while squeue -j "$JOB_ID" 2>/dev/null | grep -q "$JOB_ID"; do + # sleep 30 # Check every 30 seconds + # done + + echo " ✅ All jobs for M=$M (Job ID: $JOB_ID) completed!" + echo "" + + else + echo " ✗ Failed to submit job for M=$M" + exit 1 + fi +done + +echo "" +echo "All jobs submitted successfully!" +echo "Total submissions: 9 (one for each m value from 2-10)" +echo "Total jobs: 36 (4 models × 9 m values)" +echo "" +echo "Monitor job status with: squeue -u \$USER" +echo "View job outputs in the current directory with pattern: slurm-_.out" diff --git a/scripts/slurm/sample_delta_friction_sweep.sh b/scripts/slurm/sample_delta_friction_sweep.sh new file mode 100644 index 0000000..59acb70 --- /dev/null +++ b/scripts/slurm/sample_delta_friction_sweep.sh @@ -0,0 +1,39 @@ +#!/bin/bash +#SBATCH --job-name=sweep_delta_friction +#SBATCH --partition=gpu2 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=1-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-24 + +# Load environment +source ~/.bashrc +conda activate jamun +cd /homefs/home/sules/jamun +mkdir -p logs + +# Precomputed values for sigma=0.04 +# Delta: 5 values from sigma/sqrt(5) to sqrt(5)*sigma +DELTAS=(0.017889 0.026833 0.040000 0.059665 0.089443) + +# Friction: -log of 5 values from 0.01 to 0.99 +FRICTIONS=(2.52572864 1.2552661 0.71334989 0.36384343 0.10536052) + +# Get parameter values based on array index +DELTA_INDEX=$((SLURM_ARRAY_TASK_ID / 5)) +FRICTION_INDEX=$((SLURM_ARRAY_TASK_ID % 5)) +DELTA=${DELTAS[$DELTA_INDEX]} +FRICTION=${FRICTIONS[$FRICTION_INDEX]} + +echo "Running: delta=$DELTA, friction=$FRICTION (job $SLURM_ARRAY_TASK_ID)" + +# Run experiment +jamun_sample \ + --config-dir=configs \ + experiment=sample_enhanced_conditioning_sweep \ + ++delta=$DELTA \ + ++friction=$FRICTION \ + ++logger.wandb.name="sweep_d${DELTA}_f${FRICTION}_${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}" \ No newline at end of file diff --git a/scripts/slurm/show_sweep_combinations.py b/scripts/slurm/show_sweep_combinations.py new file mode 100644 index 0000000..b89a908 --- /dev/null +++ b/scripts/slurm/show_sweep_combinations.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +""" +Show all parameter combinations for the enhanced sampling sweep. +""" + +# Define parameter arrays (same as in the bash script) +CONDITIONERS = ["PositionConditioner", "SelfConditioner"] +SIGMAS = [0.01, 0.04, 0.08, 0.1] +LAG_TIMES = [2, 5, 8] + +print("Enhanced Sampling Training Sweep - Parameter Combinations") +print("=" * 60) +print(f"Total combinations: {len(CONDITIONERS)} × {len(SIGMAS)} × {len(LAG_TIMES)} = {len(CONDITIONERS) * len(SIGMAS) * len(LAG_TIMES)}") +print() +print(f"{'Task ID':<8} {'Conditioner':<18} {'Sigma':<8} {'Lag Time':<10}") +print("-" * 45) + +task_id = 0 +for cond_idx, conditioner in enumerate(CONDITIONERS): + for sigma_idx, sigma in enumerate(SIGMAS): + for lag_idx, lag_time in enumerate(LAG_TIMES): + print(f"{task_id:<8} {conditioner:<18} {sigma:<8} {lag_time:<10}") + task_id += 1 + +print() +print("To run the sweep:") +print("sbatch scripts/slurm/train_enhanced_sampling_sweep.sh") \ No newline at end of file diff --git a/scripts/slurm/sweep.sh b/scripts/slurm/sweep.sh new file mode 100644 index 0000000..f319723 --- /dev/null +++ b/scripts/slurm/sweep.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +#SBATCH --partition=b200 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 # Number of agents to run in parallel on this node +#SBATCH --gpus-per-node=1 # Assign one GPU to each agent +#SBATCH --cpus-per-task=12 +#SBATCH --time 1-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array 0-24 + +# Check if a Sweep ID is provided as an argument +export JAMUN_ROOT_PATH=/data2/sules/jamun-conditional-runs +if [ -z "$1" ]; then + echo "Error: Please provide the W&B Sweep ID as the first argument." + echo "Usage: sbatch scripts/slurm/sweep.sh " + exit 1 +fi + +SWEEP_ID=$1 + +# Set up the environment +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +set -eux + +echo "SLURM_JOB_ID: ${SLURM_JOB_ID}" +echo "Running on hostname: $(hostname)" +echo "Starting ${SLURM_NTASKS} agents for sweep: ${SWEEP_ID}" + +# Launch multiple wandb agents in parallel using srun. +# Each agent will poll the sweep server, get a configuration, and run one training job. +# PyTorch Lightning will automatically use the single GPU assigned by Slurm to each task. +wandb agent "${SWEEP_ID}" \ No newline at end of file diff --git a/scripts/slurm/sweep_enhanced_sampling.sh b/scripts/slurm/sweep_enhanced_sampling.sh new file mode 100755 index 0000000..a3fdea4 --- /dev/null +++ b/scripts/slurm/sweep_enhanced_sampling.sh @@ -0,0 +1,179 @@ +#!/bin/bash +#SBATCH --job-name=sweep_enhanced_sampling +#SBATCH --partition=gpu2 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=12 +#SBATCH --time=2-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-23 # Adjust this range based on number of runs in your sweep +#SBATCH --output=logs/%A_%a_sweep_enhanced_sampling.log +#SBATCH --error=logs/%A_%a_sweep_enhanced_sampling.err + +# Set up environment +set -e +export JAMUN_ROOT_PATH=/homefs/home/sules/jamun +cd $JAMUN_ROOT_PATH + +# Create logs directory if it doesn't exist +mkdir -p logs + +# Initialize conda +echo "Initializing conda..." +source ~/.bashrc +eval "$(conda shell.bash hook)" + +# Activate conda environment +echo "Activating jamun environment..." +conda activate jamun + +# Configuration +WANDB_GROUP="fake_enhanced_data_jul_11_sweep" +ENTITY="sule-shashank" +PROJECT="jamun" +RUN_INDEX=$SLURM_ARRAY_TASK_ID + +echo "=========================================" +echo "Enhanced Sampling Sweep - Production Run" +echo "SLURM Job ID: $SLURM_JOB_ID" +echo "Array Task ID: $SLURM_ARRAY_TASK_ID" +echo "Running on hostname: $(hostname)" +echo "Working directory: $(pwd)" +echo "Starting time: $(date)" +echo "Processing run index: $RUN_INDEX from group: $WANDB_GROUP" +echo "=========================================" + +# Python script to fetch wandb runs and execute jamun_sample +python -c " +import wandb +import subprocess +import sys +import os + +# Configuration +entity = '$ENTITY' +project = '$PROJECT' +group = '$WANDB_GROUP' +run_index = $RUN_INDEX + +try: + # Initialize wandb API + api = wandb.Api() + + # Get all runs from the specified group + print(f'Fetching runs from {entity}/{project} with group \"{group}\"...') + runs = api.runs(f'{entity}/{project}', filters={'group': group}) + runs_list = list(runs) + + print(f'Found {len(runs_list)} runs in group \"{group}\"') + + # Check if run_index is valid + if run_index >= len(runs_list) or run_index < 0: + print(f'Error: Run index {run_index} is out of bounds. The group has {len(runs_list)} runs (indices 0 to {len(runs_list) - 1}).', file=sys.stderr) + sys.exit(1) + + # Get the specific run + run = runs_list[run_index] + run_path = '/'.join(run.path) + + print(f'\\nProcessing run: {run.name} ({run_path})') + print(f'Run URL: {run.url}') + print(f'Run state: {run.state}') + + # Extract parameters from the run config + config = run.config + cfg_key = 'cfg' # This is the key used in jamun configs + + if cfg_key not in config: + print(f'Error: Config key \"{cfg_key}\" not found in run config. Available keys: {list(config.keys())}', file=sys.stderr) + sys.exit(1) + + cfg = config[cfg_key] + + # Extract required parameters + try: + conditioner = cfg['model']['conditioner']['_target_'] + sigma = cfg['model']['sigma_distribution']['sigma'] + total_lag_time = cfg['data']['datamodule']['datasets']['train']['total_lag_time'] + + print(f'\\nExtracted parameters:') + print(f' Conditioner: {conditioner}') + print(f' Sigma: {sigma}') + print(f' Total Lag Time: {total_lag_time}') + + except KeyError as e: + print(f'Error: Could not extract required parameter: {e}', file=sys.stderr) + print(f'Available config structure: {cfg.keys()}', file=sys.stderr) + sys.exit(1) + + # Ensure all required parameters are present + if not all(v is not None for v in [run_path, conditioner, total_lag_time, sigma]): + print(f'Error: Could not extract all required parameters for run at index {run_index}.', file=sys.stderr) + sys.exit(1) + + # Create a meaningful run group name for sampling + conditioner_name = conditioner.split('.')[-1] # Get class name without module path + sampling_group = 'sample_enhanced_sampling_from_jul_11' + + # Create tags for better organization + tags = [ + f'sweep_run_{run_index}', + f'conditioner_{conditioner_name}', + f'sigma_{sigma}', + f'lag_time_{total_lag_time}', + 'enhanced_sampling', + 'sample_from_sweep' + ] + tags_string = '[' + ', '.join(f'\"{tag}\"' for tag in tags) + ']' + + print('\\n========================================') + print(f'Starting sampling for run at index {run_index}: {run_path}') + print(f' Conditioner: {conditioner}') + print(f' Sigma: {sigma}') + print(f' Total Lag Time: {total_lag_time}') + print(f' Sample Group: {sampling_group}') + print('========================================\\n') + + # Build the jamun_sample command + # Note: We don't override model.conditioner._target_ because the checkpoint + # already contains the correct conditioner configuration with all parameters + cmd = [ + 'jamun_sample', + '--config-dir=configs', + 'experiment=sample_enhanced_sampling_single_shape.yaml', + f'++wandb_train_run_path={run_path}', + f'++init_datasets.total_lag_time={total_lag_time}', + f'++sigma={sigma}', + f'++delta={sigma}', + f'++logger.wandb.group={sampling_group}', + f'++logger.wandb.tags={tags_string}', + f'++logger.wandb.notes=\"Sampling from enhanced sampling sweep run {run_index} - {conditioner_name} sigma={sigma} lag={total_lag_time}\"', + f'++run_key=sweep_sample_{run_index}_{conditioner_name}_sigma_{sigma}_lag_{total_lag_time}' + ] + + print('Executing command:') + print(' '.join(cmd)) + print('\\n' + '='*50 + '\\n') + + # Execute the command + result = subprocess.run(cmd, check=True, env=os.environ.copy()) + + print('\\n' + '='*50) + print(f'Successfully completed sampling for run index {run_index}') + print(f'End time: {os.popen(\"date\").read().strip()}') + +except subprocess.CalledProcessError as e: + print(f'Error running jamun_sample: {e}', file=sys.stderr) + sys.exit(1) +except Exception as e: + print(f'Error: {e}', file=sys.stderr) + import traceback + traceback.print_exc() + sys.exit(1) +" + +echo "=========================================" +echo "Finished processing run index: $RUN_INDEX" +echo "End time: $(date)" +echo "=========================================" \ No newline at end of file diff --git a/scripts/slurm/train_capped_2AA.sh b/scripts/slurm/train_capped_2AA.sh index f8117be..4a54eeb 100644 --- a/scripts/slurm/train_capped_2AA.sh +++ b/scripts/slurm/train_capped_2AA.sh @@ -28,7 +28,7 @@ echo "RUN_KEY = ${RUN_KEY}" nvidia-smi srun --cpus-per-task 8 --cpu-bind=cores,verbose \ - jamun_train --config-dir=/homefs/home/daigavaa/jamun/configs \ + jamun_train --config-dir=configs \ experiment=train_capped_2AA.yaml \ ++trainer.devices=$SLURM_GPUS_PER_NODE \ ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ diff --git a/scripts/slurm/train_capped_2AA_ALA_ALA_conditional.sh b/scripts/slurm/train_capped_2AA_ALA_ALA_conditional.sh new file mode 100644 index 0000000..7d10607 --- /dev/null +++ b/scripts/slurm/train_capped_2AA_ALA_ALA_conditional.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +#SBATCH --partition gpu3 +#SBATCH --qos=preempt +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 1 +#SBATCH --gpus-per-node 1 +#SBATCH --cpus-per-task 8 +#SBATCH --time 3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array 4-10 + +eval "$(conda shell.bash hook)" +conda activate jamun + +set -eux + +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "hostname = $(hostname)" + +export HYDRA_FULL_ERROR=1 +# export TORCH_COMPILE_DEBUG=1 +# export TORCH_LOGS="+dynamo" +# export TORCHDYNAMO_VERBOSE=1 + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +nvidia-smi + +srun --cpus-per-task 8 --cpu-bind=cores,verbose \ + jamun_train --config-dir=/homefs/home/sules/jamun/configs \ + experiment=train_test_single_shape_conditional.yaml \ + ++model.conditioner._target_=jamun.model.conditioners.SelfConditioner \ + ++trainer.devices=$SLURM_GPUS_PER_NODE \ + ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ + ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","capped_2AA, ALA_ALA, conditional denoiser"] \ + ++run_key=$RUN_KEY \ No newline at end of file diff --git a/scripts/slurm/train_capped_2AA_comparison.sh b/scripts/slurm/train_capped_2AA_comparison.sh new file mode 100755 index 0000000..fc976f5 --- /dev/null +++ b/scripts/slurm/train_capped_2AA_comparison.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash + +#SBATCH --partition=gpu2 +#SBATCH --job-name=capped_2AA_comparison +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-4 + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" +nvidia-smi + +echo "Running array job ${SLURM_ARRAY_TASK_ID}" + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +# Define configurations for each job +case ${SLURM_ARRAY_TASK_ID} in + 0) + echo "Job 0: Standard JAMUN on 2AA capped diamines" + CONFIG="train_capped_2AA" + OVERRIDES="" + ;; + 1) + echo "Job 1: Position conditioner on 2AA capped diamines" + CONFIG="train_capped_2AA_position_conditioner" + OVERRIDES="" + ;; + 2) + echo "Job 2: Self conditioner on 2AA capped diamines" + CONFIG="train_capped_2AA_self_conditioner" + OVERRIDES="" + ;; + 3) + echo "Job 3: Spatiotemporal conditioner with temporal embedding and mean pooling on 2AA capped diamines" + CONFIG="train_capped_2AA_spatiotemporal_conditioner" + OVERRIDES="++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean" + ;; + 4) + echo "Job 4: Spatiotemporal conditioner with ones temporal encoding and mean pooling on 2AA capped diamines" + CONFIG="train_capped_2AA_spatiotemporal_conditioner" + OVERRIDES="++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean ++model.conditioner.spatiotemporal_model.temporal_module.node_attr_temporal_encoding_function=ones ++model.conditioner.spatiotemporal_model.temporal_module.edge_attr_temporal_encoding_function=ones" + ;; + *) + echo "Unknown job ID: ${SLURM_ARRAY_TASK_ID}" + exit 1 + ;; +esac + +# Build the command with base config +CMD="jamun_train --config-dir=configs experiment=${CONFIG}.yaml" + +# Add overrides if any +if [ -n "$OVERRIDES" ]; then + CMD="$CMD $OVERRIDES" +fi + +# Add common training overrides +CMD="$CMD ++trainer.max_epochs=100" +CMD="$CMD ++logger.wandb.group=capped_2AA_model_comparison" +CMD="$CMD ++run_key=${RUN_KEY}" + +# Add dataset overrides for debugging (quick completion) +# CMD="$CMD ++data.datamodule.datasets.train.max_datasets=1" +# CMD="$CMD ++data.datamodule.datasets.val.max_datasets=1" + +# Add job-specific wandb tags +WANDB_TAG="job_${SLURM_ARRAY_TASK_ID}" +CMD="$CMD ++logger.wandb.tags=[${WANDB_TAG},capped_2AA_comparison,generalization_test]" + +echo "Running command: $CMD" +exec $CMD diff --git a/scripts/slurm/train_capped_2AA_conditional.sh b/scripts/slurm/train_capped_2AA_conditional.sh new file mode 100644 index 0000000..5384bd9 --- /dev/null +++ b/scripts/slurm/train_capped_2AA_conditional.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash + +#SBATCH --partition gpu3 +#SBATCH --qos=preempt +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 4 +#SBATCH --gpus-per-node 4 +#SBATCH --cpus-per-task 8 +#SBATCH --time 3-0 +#SBATCH --mem-per-cpu=32G + +eval "$(conda shell.bash hook)" +conda activate jamun + +set -eux + +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "hostname = $(hostname)" + +export HYDRA_FULL_ERROR=1 +# export TORCH_COMPILE_DEBUG=1 +# export TORCH_LOGS="+dynamo" +# export TORCHDYNAMO_VERBOSE=1 + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +nvidia-smi + +# srun --cpus-per-task 8 --cpu-bind=cores,verbose \ +jamun_train --config-dir=configs \ + experiment=train_capped_2AA_conditional.yaml \ + ++trainer.devices=$SLURM_GPUS_PER_NODE \ + ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ + ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","capped_2AA"] \ + ++run_key=$RUN_KEY diff --git a/scripts/slurm/train_enhanced_long_comparison.sh b/scripts/slurm/train_enhanced_long_comparison.sh new file mode 100644 index 0000000..9874b5a --- /dev/null +++ b/scripts/slurm/train_enhanced_long_comparison.sh @@ -0,0 +1,119 @@ +#!/usr/bin/env bash + +#SBATCH --partition=b200 +#SBATCH --job-name=enhanced_long_comparison + #SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-5 + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" +nvidia-smi + +echo "Running array job ${SLURM_ARRAY_TASK_ID}" + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +# Define configurations for each job +case ${SLURM_ARRAY_TASK_ID} in + 0) + echo "Job 0: Standard JAMUN on enhanced_long, noise 0.04" + CONFIG="train_enhanced_standard_jamun" + DATA_PATH="/data2/sules/ALA_ALA_enhanced_long" + WANDB_GROUP="model_comparison_enhanced_long_take2" + NOISE_LEVEL="0.04" + RUN_NAME="enhanced_long_standard_noise0.04" + ;; + 1) + echo "Job 1: Spatiotemporal JAMUN on enhanced_long, noise 0.04" + CONFIG="train_enhanced_spatiotemporal_conditioner" + DATA_PATH="/data2/sules/ALA_ALA_enhanced_long" + WANDB_GROUP="model_comparison_enhanced_long_take2" + NOISE_LEVEL="0.04" + RUN_NAME="enhanced_long_spatiotemporal_noise0.04" + ;; + 2) + echo "Job 2: Standard JAMUN on enhanced_long, noise 0.06" + CONFIG="train_enhanced_standard_jamun" + DATA_PATH="/data2/sules/ALA_ALA_enhanced_long" + WANDB_GROUP="model_comparison_enhanced_long_take2" + NOISE_LEVEL="0.06" + RUN_NAME="enhanced_long_standard_noise0.06" + ;; + 3) + echo "Job 3: Spatiotemporal JAMUN on enhanced_long, noise 0.06" + CONFIG="train_enhanced_spatiotemporal_conditioner" + DATA_PATH="/data2/sules/ALA_ALA_enhanced_long" + WANDB_GROUP="model_comparison_enhanced_long_take2" + NOISE_LEVEL="0.06" + RUN_NAME="enhanced_long_spatiotemporal_noise0.06" + ;; + 4) + echo "Job 4: Standard JAMUN on enhanced_long_state_split, noise 0.04" + CONFIG="train_enhanced_standard_jamun" + DATA_PATH="/data2/sules/ALA_ALA_enhanced_long_state_split" + WANDB_GROUP="withheld_state_take2" + NOISE_LEVEL="0.04" + RUN_NAME="enhanced_long_state_split_standard_noise0.04" + ;; + 5) + echo "Job 5: Spatiotemporal JAMUN on enhanced_long_state_split, noise 0.04" + CONFIG="train_enhanced_spatiotemporal_conditioner" + DATA_PATH="/data2/sules/ALA_ALA_enhanced_long_state_split" + WANDB_GROUP="withheld_state_take2" + NOISE_LEVEL="0.04" + RUN_NAME="enhanced_long_state_split_spatiotemporal_noise0.04" + ;; + *) + echo "Unknown job ID: ${SLURM_ARRAY_TASK_ID}" + exit 1 + ;; +esac + +# Build the command with base config +CMD="jamun_train --config-dir=configs experiment=${CONFIG}.yaml" + +# Add common training parameters +CMD="$CMD ++data.datamodule.datasets.train.root=${DATA_PATH}/train" +CMD="$CMD ++data.datamodule.datasets.val.root=${DATA_PATH}/val" +CMD="$CMD ++data.datamodule.datasets.train.num_frames=10000" +CMD="$CMD ++data.datamodule.datasets.val.num_frames=10000" +CMD="$CMD ++data.datamodule.datasets.train.lag_subsample_rate=1" +CMD="$CMD ++data.datamodule.datasets.val.lag_subsample_rate=1" +CMD="$CMD ++data.datamodule.datasets.train.subsample=10" +CMD="$CMD ++data.datamodule.datasets.val.subsample=10" +CMD="$CMD ++data.datamodule.datasets.train.total_lag_time=5" +CMD="$CMD ++data.datamodule.datasets.val.total_lag_time=5" + +# Add test run parameters (change for full training) +CMD="$CMD ++data.datamodule.datasets.train.max_datasets=200" +CMD="$CMD ++data.datamodule.datasets.val.max_datasets=50" +CMD="$CMD ++trainer.max_epochs=100" +CMD="$CMD ++trainer.val_check_interval=0.2" + +# Add noise level +CMD="$CMD ++model.sigma_distribution.sigma=${NOISE_LEVEL}" + +# Add wandb configuration +CMD="$CMD ++logger.wandb.group=${WANDB_GROUP}" +CMD="$CMD ++logger.wandb.tags=[${RUN_NAME},enhanced_long_comparison,job_${SLURM_ARRAY_TASK_ID}]" +CMD="$CMD ++run_key=${RUN_KEY}" + +# # Add experiment name +CMD="$CMD ++logger.wandb.name=${RUN_NAME}" + +echo "Running command: $CMD" +exec $CMD diff --git a/scripts/slurm/train_enhanced_sampling_conditioners.sh b/scripts/slurm/train_enhanced_sampling_conditioners.sh new file mode 100644 index 0000000..8812440 --- /dev/null +++ b/scripts/slurm/train_enhanced_sampling_conditioners.sh @@ -0,0 +1,81 @@ +#!/bin/bash + +#SBATCH --job-name=conditioner_lag_sweep +#SBATCH --array=0-5 +#SBATCH --partition=gpu2 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=12 +#SBATCH --time=1-0 + +# Experiment: Sweep over SelfConditioner, PositionConditioner, and DenoisedConditioner +# with different total lag times (2, 5, 8) on enhanced sampling data with 2 layers +# Testing mode: 1 epoch, max_datasets=1 + +# Array of config names to run (base configs) +CONFIG_NAMES=( + "train_enhanced_self_conditioner" + "train_enhanced_position_conditioner" + "train_enhanced_denoised_conditioner" +) + +# Array of conditioner names for logging +CONDITIONER_NAMES=( + "SelfConditioner" + "PositionConditioner" + "DenoisedConditioner" +) + +# Array of lag times to test +LAG_TIMES=(2 5) + +# Calculate which conditioner and lag time based on array index +# Array index 0-5 maps to: +# 0-1: SelfConditioner with lag_time 2,5 +# 2-3: PositionConditioner with lag_time 2,5 +# 4-5: DenoisedConditioner with lag_time 2,5 +CONDITIONER_IDX=$((SLURM_ARRAY_TASK_ID / 2)) +LAG_TIME_IDX=$((SLURM_ARRAY_TASK_ID % 2)) + +CONFIG_NAME=${CONFIG_NAMES[$CONDITIONER_IDX]} +CONDITIONER_NAME=${CONDITIONER_NAMES[$CONDITIONER_IDX]} +LAG_TIME=${LAG_TIMES[$LAG_TIME_IDX]} + +echo "=== SLURM Array Job ${SLURM_ARRAY_TASK_ID}: Training ${CONDITIONER_NAME} with lag_time=${LAG_TIME} ===" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Array Task ID: ${SLURM_ARRAY_TASK_ID}" +echo "Conditioner Index: ${CONDITIONER_IDX}" +echo "Lag Time Index: ${LAG_TIME_IDX}" +echo "Config: ${CONFIG_NAME}" +echo "Lag Time: ${LAG_TIME}" +echo "Starting at $(date)" +echo "" + +# Set environment variables +export JAMUN_DATA_PATH=/data/bucket/kleinhej/ +export WANDB_PROJECT=jamun + +# Activate conda environment +source ~/.bashrc +conda activate jamun + +# Create logs directory if it doesn't exist +mkdir -p logs + +# Build command with testing overrides and lag time sweep +CMD="jamun_train --config-dir=configs experiment=${CONFIG_NAME}" +CMD="${CMD} ++data.datamodule.datasets.train.total_lag_time=${LAG_TIME}" +CMD="${CMD} ++data.datamodule.datasets.val.total_lag_time=${LAG_TIME}" +CMD="${CMD} ++trainer.max_epochs=100" + +echo "Running command:" +echo "${CMD}" +echo "" + +eval ${CMD} + +echo "" +echo "=== ${CONDITIONER_NAME} (lag_time=${LAG_TIME}) Training Complete ===" +echo "Finished at $(date)" +echo "Check Weights & Biases group 'conditioner_lag_sweep_test' for results" \ No newline at end of file diff --git a/scripts/slurm/train_enhanced_sampling_noise_conditioner_sweep.sh b/scripts/slurm/train_enhanced_sampling_noise_conditioner_sweep.sh new file mode 100755 index 0000000..043b0e2 --- /dev/null +++ b/scripts/slurm/train_enhanced_sampling_noise_conditioner_sweep.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash + +#SBATCH --partition gpu2 +#SBATCH --qos=preempt +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 1 +#SBATCH --gpus-per-node 1 +#SBATCH --cpus-per-task 8 +#SBATCH --time 1-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0 + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" + +set -eux + +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "SLURM_ARRAY_TASK_ID = ${SLURM_ARRAY_TASK_ID}" +echo "hostname = $(hostname)" + +export HYDRA_FULL_ERROR=1 + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +# Define parameter arrays +CONDITIONERS=("jamun.model.conditioners.PositionConditioner" "jamun.model.conditioners.SelfConditioner") +CONDITIONER_NAMES=("PositionConditioner" "SelfConditioner") +SIGMAS=(0.01 0.04 0.08 0.1) +LAG_TIMES=(2 5 8) + +# Calculate parameter indices from SLURM_ARRAY_TASK_ID +# Total combinations: 2 conditioners * 4 sigmas * 3 lag_times = 24 +COND_IDX=$((SLURM_ARRAY_TASK_ID / 12)) +SIGMA_IDX=$(((SLURM_ARRAY_TASK_ID % 12) / 3)) +LAG_IDX=$((SLURM_ARRAY_TASK_ID % 3)) + +# Get parameter values +CONDITIONER=${CONDITIONERS[$COND_IDX]} +CONDITIONER_NAME=${CONDITIONER_NAMES[$COND_IDX]} +SIGMA=${SIGMAS[$SIGMA_IDX]} +LAG_TIME=${LAG_TIMES[$LAG_IDX]} + +echo "TRIAL RUN - Parameter combination ${SLURM_ARRAY_TASK_ID}:" +echo " Conditioner: ${CONDITIONER_NAME}" +echo " Sigma: ${SIGMA}" +echo " Total lag time: ${LAG_TIME}" + +nvidia-smi + +# Run training with parameter overrides (reduced epochs for trial) +jamun_train --config-dir=configs \ + experiment=train_test_single_shape_enhanced_sampling.yaml \ + ++trainer.max_epochs=2 \ + ++data.datamodule.datasets.train.subsample=10 \ + ++data.datamodule.datasets.val.subsample=10 \ + ++data.datamodule.datasets.test.subsample=10 \ + ++data.datamodule.datasets.train.max_datasets=5 \ + ++data.datamodule.datasets.val.max_datasets=5 \ + ++data.datamodule.datasets.test.max_datasets=5 \ + ++model.conditioner._target_=${CONDITIONER} \ + ++model.sigma_distribution.sigma=${SIGMA} \ + ++data.datamodule.datasets.train.total_lag_time=${LAG_TIME} \ + ++model.arch.N_structures=${LAG_TIME} \ + ++logger.wandb.group="fake_enhanced_data_trial" \ + ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","trial","enhanced_sampling","${CONDITIONER_NAME}","sigma_${SIGMA}","lag_${LAG_TIME}"] \ + ++run_key=$RUN_KEY \ No newline at end of file diff --git a/scripts/slurm/train_enhanced_sampling_single.sh b/scripts/slurm/train_enhanced_sampling_single.sh new file mode 100755 index 0000000..a57ad75 --- /dev/null +++ b/scripts/slurm/train_enhanced_sampling_single.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +#SBATCH --partition gpu2 +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 1 +#SBATCH --gpus-per-node 1 +#SBATCH --cpus-per-task 8 +#SBATCH --time 12:00:00 +#SBATCH --mem-per-cpu=32G + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" +nvidia-smi + +# Run training with parameter overrides +jamun_train --config-dir=configs experiment=train_enhanced_pretrained_spatiotemporal_conditioner.yaml model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean model.conditioner.spatiotemporal_model.spatial_module.trainable=false trainer.max_epochs=10 logger.wandb.tags=[job_4,spatiotemporal_comparison] \ No newline at end of file diff --git a/scripts/slurm/train_enhanced_sampling_spatiotemporal_comparison.sh b/scripts/slurm/train_enhanced_sampling_spatiotemporal_comparison.sh new file mode 100755 index 0000000..f38a064 --- /dev/null +++ b/scripts/slurm/train_enhanced_sampling_spatiotemporal_comparison.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash + +#SBATCH --partition=gpu2 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-5 + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" +nvidia-smi + +echo "Running array job ${SLURM_ARRAY_TASK_ID}" + +# Define configurations for each job +case ${SLURM_ARRAY_TASK_ID} in + 0) + echo "Job 0: Initialized spatial module + TemporalToSpatialNodeAttrMean" + CONFIG="train_enhanced_spatiotemporal_conditioner" + POOLER="jamun.model.pooling.TemporalToSpatialNodeAttrMean" + TRAINABLE_OVERRIDE="" + ;; + 1) + echo "Job 1: Initialized spatial module + TemporalToSpatialNodeAttr" + CONFIG="train_enhanced_spatiotemporal_conditioner" + POOLER="jamun.model.pooling.TemporalToSpatialNodeAttr" + TRAINABLE_OVERRIDE="" + ;; + 2) + echo "Job 2: Pretrained trainable spatial module + TemporalToSpatialNodeAttrMean" + CONFIG="train_enhanced_pretrained_spatiotemporal_conditioner" + POOLER="jamun.model.pooling.TemporalToSpatialNodeAttrMean" + TRAINABLE_OVERRIDE="model.conditioner.spatiotemporal_model.spatial_module.trainable=true" + ;; + 3) + echo "Job 3: Pretrained trainable spatial module + TemporalToSpatialNodeAttr" + CONFIG="train_enhanced_pretrained_spatiotemporal_conditioner" + POOLER="jamun.model.pooling.TemporalToSpatialNodeAttr" + TRAINABLE_OVERRIDE="model.conditioner.spatiotemporal_model.spatial_module.trainable=true" + ;; + 4) + echo "Job 4: Pretrained non-trainable spatial module + TemporalToSpatialNodeAttrMean" + CONFIG="train_enhanced_pretrained_spatiotemporal_conditioner" + POOLER="jamun.model.pooling.TemporalToSpatialNodeAttrMean" + TRAINABLE_OVERRIDE="model.conditioner.spatiotemporal_model.spatial_module.trainable=false" + ;; + 5) + echo "Job 5: Pretrained non-trainable spatial module + TemporalToSpatialNodeAttr" + CONFIG="train_enhanced_pretrained_spatiotemporal_conditioner" + POOLER="jamun.model.pooling.TemporalToSpatialNodeAttr" + TRAINABLE_OVERRIDE="model.conditioner.spatiotemporal_model.spatial_module.trainable=false" + ;; + *) + echo "Unknown job ID: ${SLURM_ARRAY_TASK_ID}" + exit 1 + ;; +esac + +# Build the command with overrides +CMD="jamun_train --config-dir=configs experiment=${CONFIG}.yaml" + +# Add pooler override (keeping the irreps_out parameter from base config) +CMD="$CMD ++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=${POOLER}" + +# Add trainable override if needed +if [ -n "$TRAINABLE_OVERRIDE" ]; then + CMD="$CMD $TRAINABLE_OVERRIDE" +fi + +# Add dataset and training overrides +# CMD="$CMD data.datamodule.datasets.train.max_datasets=1" +# CMD="$CMD data.datamodule.datasets.val.max_datasets=1" +CMD="$CMD ++trainer.max_epochs=100" +CMD="$CMD ++wandb.logger.group=spatiotemporal_comparison" + +# Add job-specific wandb tags +WANDB_TAG="job_${SLURM_ARRAY_TASK_ID}" +# CMD="$CMD ++wandb.logger.tags=[${WANDB_TAG},spatiotemporal_comparison]" + +echo "Running command: $CMD" +exec $CMD \ No newline at end of file diff --git a/scripts/slurm/train_graph_type_comparison.sh b/scripts/slurm/train_graph_type_comparison.sh new file mode 100755 index 0000000..e2f13ed --- /dev/null +++ b/scripts/slurm/train_graph_type_comparison.sh @@ -0,0 +1,138 @@ +#!/usr/bin/env bash +# +# Graph type comparison experiment script +# +# Iterates over: +# - Lag subsample rates: 1, 2, 3, 4 +# - Total lag times: 2, 4, 6, 8 +# - Configs: train_test_single_shape.yaml, train_test_single_shape_conditional.yaml, train_test_single_shape_spatiotemporal_conditioner.yaml +# - For spatiotemporal: hub_n_spoke vs complete graph types (both with ones encoding) +# +# Total jobs: 4 lag_subsample_rates × 4 total_lag_times × (1 conditional + 2 spatiotemporal variants) = 48 jobs +# +#SBATCH --partition=gpu3 +#SBATCH --job-name=graph_type_comparison +#SBATCH --qos=preempt +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-task=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-15 + +# Set up the environment +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Define experiment parameters +LAG_SUBSAMPLE_RATES=(1 2 3 4) +TOTAL_LAG_TIMES=(2 4 6 8) +CONFIGS=( + # "train_test_single_shape" + "train_test_single_shape_conditional" + "train_test_single_shape_spatiotemporal_conditioner_default" + "train_test_single_shape_spatiotemporal_conditioner_hub_spoke" + # "train_test_single_shape_spatiotemporal_conditioner_complete" +) + +NUM_CONFIGS=1 +NUM_LAG_TIMES=4 +NUM_LAG_SUBSAMPLE_RATES=4 + +# Calculate indices +LAG_SUBSAMPLE_INDEX=$((SLURM_ARRAY_TASK_ID / (NUM_LAG_TIMES * NUM_CONFIGS))) +REMAINDER=$((SLURM_ARRAY_TASK_ID % (NUM_LAG_TIMES * NUM_CONFIGS))) +LAG_TIME_INDEX=$((REMAINDER / NUM_CONFIGS)) +CONFIG_INDEX=$((REMAINDER % NUM_CONFIGS)) + +LAG_SUBSAMPLE_RATE=${LAG_SUBSAMPLE_RATES[$LAG_SUBSAMPLE_INDEX]} +TOTAL_LAG_TIME=${TOTAL_LAG_TIMES[$LAG_TIME_INDEX]} +CONFIG_TYPE=${CONFIGS[$CONFIG_INDEX]} + +echo "Job ${SLURM_ARRAY_TASK_ID}: lag_subsample_rate=${LAG_SUBSAMPLE_RATE}, total_lag_time=${TOTAL_LAG_TIME}, config=${CONFIG_TYPE}" + +# Generate unique run key +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +# Calculate max_datasets using the formula: floor(250 * (49 / (lag_subsample_rate * (total_lag_time - 1) - 1))) +DENOMINATOR=$((LAG_SUBSAMPLE_RATE * (TOTAL_LAG_TIME - 1) - 1)) +if [ $DENOMINATOR -le 0 ]; then + echo "Warning: Invalid denominator ($DENOMINATOR) for max_datasets calculation. Setting to 250." + MAX_DATASETS=250 +else + # Using bc for floating point calculation and floor function + MAX_DATASETS=$(echo "scale=0; 250 * 49 / $DENOMINATOR" | bc) +fi +echo "Calculated max_datasets = $MAX_DATASETS (lag_subsample_rate=$LAG_SUBSAMPLE_RATE, total_lag_time=$TOTAL_LAG_TIME)" + +# Set base configuration and overrides based on config type +case $CONFIG_INDEX in + # 0) # Standard JAMUN + # CONFIG="train_test_single_shape" + # OVERRIDES="" + # WANDB_TAG="standard_jamun_lag_${LAG_SUBSAMPLE_RATE}_time_${TOTAL_LAG_TIME}" + # ;; + # 0) # Position Conditioner + # CONFIG="train_enhanced_position_conditioner" + # OVERRIDES="++data.datamodule.datasets.train.total_lag_time=${TOTAL_LAG_TIME}" + # OVERRIDES="$OVERRIDES ++data.datamodule.datasets.train.lag_subsample_rate=${LAG_SUBSAMPLE_RATE}" + # OVERRIDES="$OVERRIDES ++model.arch.N_structures=${TOTAL_LAG_TIME}" + # OVERRIDES="$OVERRIDES ++model.conditioner.N_structures=${TOTAL_LAG_TIME}" + # WANDB_TAG="position_conditioner_lag_${LAG_SUBSAMPLE_RATE}_time_${TOTAL_LAG_TIME}" + # ;; + 0) # SpatioTemporal Conditioner - Default (fan graph, temporal encoding) + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++data.datamodule.datasets.train.total_lag_time=${TOTAL_LAG_TIME}" + OVERRIDES="$OVERRIDES ++data.datamodule.datasets.train.lag_subsample_rate=${LAG_SUBSAMPLE_RATE}" + WANDB_TAG="spatiotemporal_default_fan_temporal_lag_${LAG_SUBSAMPLE_RATE}_time_${TOTAL_LAG_TIME}" + ;; + # 2) # SpatioTemporal Conditioner - Hub & Spoke + # CONFIG="train_enhanced_spatiotemporal_conditioner" + # OVERRIDES="++data.datamodule.datasets.train.total_lag_time=${TOTAL_LAG_TIME}" + # OVERRIDES="$OVERRIDES ++data.datamodule.datasets.train.lag_subsample_rate=${LAG_SUBSAMPLE_RATE}" + # OVERRIDES="$OVERRIDES ++model.conditioner.spatiotemporal_model.graph_type=hub_n_spoke" + # OVERRIDES="$OVERRIDES ++model.conditioner.spatiotemporal_model.temporal_module.node_attr_temporal_encoding_function=ones" + # OVERRIDES="$OVERRIDES ++model.conditioner.spatiotemporal_model.temporal_module.edge_attr_temporal_encoding_function=ones" + # WANDB_TAG="spatiotemporal_hub_spoke_ones_lag_${LAG_SUBSAMPLE_RATE}_time_${TOTAL_LAG_TIME}" + # ;; + # 4) # SpatioTemporal Conditioner - Complete + # CONFIG="train_test_single_shape_spatiotemporal_conditioner" + # OVERRIDES="++data.datamodule.datasets.train.total_lag_time=${TOTAL_LAG_TIME}" + # OVERRIDES="$OVERRIDES ++data.datamodule.datasets.train.lag_subsample_rate=${LAG_SUBSAMPLE_RATE}" + # OVERRIDES="$OVERRIDES ++model.conditioner.spatiotemporal_model.graph_type=complete" + # OVERRIDES="$OVERRIDES ++model.conditioner.spatiotemporal_model.temporal_module.node_attr_temporal_encoding_function=ones" + # OVERRIDES="$OVERRIDES ++model.conditioner.spatiotemporal_model.temporal_module.edge_attr_temporal_encoding_function=ones" + # WANDB_TAG="spatiotemporal_complete_ones_lag_${LAG_SUBSAMPLE_RATE}_time_${TOTAL_LAG_TIME}" + # ;; +esac + +# Build command +CMD="jamun_train --config-dir=configs experiment=${CONFIG}.yaml" + +# Add overrides if any +if [ -n "$OVERRIDES" ]; then + CMD="$CMD $OVERRIDES" +fi + +# Calculate validation max_datasets using multiplier of 50 +VAL_MAX_DATASETS=$(echo "scale=0; 50 * 49 / $DENOMINATOR" | bc) +if [ $DENOMINATOR -le 0 ]; then + VAL_MAX_DATASETS=50 +fi + +# Add common overrides +CMD="$CMD ++run_key=${RUN_KEY}" +CMD="$CMD ++data.datamodule.datasets.train.max_datasets=${MAX_DATASETS}" +CMD="$CMD ++data.datamodule.datasets.val.max_datasets=${VAL_MAX_DATASETS}" +CMD="$CMD ++logger.wandb.group=graph_type_comparison_experiment_enhanced_sampling_data_onlyfan_aug17" +CMD="$CMD ++logger.wandb.tags=[${WANDB_TAG},graph_comparison,lag_subsample_${LAG_SUBSAMPLE_RATE},total_lag_${TOTAL_LAG_TIME}]" + +# Set wandb run name +WANDB_RUN_NAME="graph_comparison_${WANDB_TAG}_enhanced_sampling_data" +CMD="$CMD ++logger.wandb.name=${WANDB_RUN_NAME}" + +echo "Running command: $CMD" +exec $CMD diff --git a/scripts/slurm/train_model_comparison.sh b/scripts/slurm/train_model_comparison.sh new file mode 100644 index 0000000..ac4da0f --- /dev/null +++ b/scripts/slurm/train_model_comparison.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash + +#SBATCH --partition=gpu3 +#SBATCH --job-name=model_comparison +#SBATCH --qos=preempt +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-4 + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" +nvidia-smi + +echo "Running array job ${SLURM_ARRAY_TASK_ID}" + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +# Define configurations for each job +case ${SLURM_ARRAY_TASK_ID} in + 0) + echo "Job 0: Standard JAMUN" + CONFIG="train_enhanced_standard_jamun" + OVERRIDES="" + ;; + 1) + echo "Job 1: Position conditioner" + CONFIG="train_enhanced_position_conditioner" + OVERRIDES="" + ;; + 2) + echo "Job 2: Self conditioner" + CONFIG="train_enhanced_self_conditioner" + OVERRIDES="" + ;; + 3) + echo "Job 3: Spatiotemporal conditioner with temporal embedding and mean pooling" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean" + ;; + 4) + echo "Job 4: Spatiotemporal conditioner with ones temporal encoding and mean pooling" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean ++model.conditioner.spatiotemporal_model.temporal_module.node_attr_temporal_encoding_function=ones ++model.conditioner.spatiotemporal_model.temporal_module.edge_attr_temporal_encoding_function=ones" + ;; + *) + echo "Unknown job ID: ${SLURM_ARRAY_TASK_ID}" + exit 1 + ;; +esac + +# Build the command with base config +CMD="jamun_train --config-dir=configs experiment=${CONFIG}.yaml" + +# Add overrides if any +if [ -n "$OVERRIDES" ]; then + CMD="$CMD $OVERRIDES" +fi + +# Add common training overrides +CMD="$CMD ++trainer.max_epochs=100" +CMD="$CMD ++logger.wandb.group=model_comparison" +CMD="$CMD ++run_key=${RUN_KEY}" + +# Add dataset overrides for debugging (quick completion) +# CMD="$CMD ++data.datamodule.datasets.train.max_datasets=1" +# CMD="$CMD ++data.datamodule.datasets.val.max_datasets=1" + +# Add job-specific wandb tags +WANDB_TAG="job_${SLURM_ARRAY_TASK_ID}" +CMD="$CMD ++logger.wandb.tags=[${WANDB_TAG},model_comparison,separable_conv]" + +echo "Running command: $CMD" +exec $CMD diff --git a/scripts/slurm/train_model_comparison_full_swarm.sh b/scripts/slurm/train_model_comparison_full_swarm.sh new file mode 100644 index 0000000..0ce02c7 --- /dev/null +++ b/scripts/slurm/train_model_comparison_full_swarm.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +#SBATCH --partition=gpu3 +#SBATCH --job-name=model_comparison +#SBATCH --qos=preempt +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-4 + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" +nvidia-smi + +echo "Running array job ${SLURM_ARRAY_TASK_ID}" + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +# Define configurations for each job +case ${SLURM_ARRAY_TASK_ID} in + 0) + echo "Job 0: Standard JAMUN" + CONFIG="train_enhanced_standard_jamun" + OVERRIDES="" + ;; + 1) + echo "Job 1: Position conditioner" + CONFIG="train_enhanced_position_conditioner" + OVERRIDES="" + ;; + 2) + echo "Job 2: Self conditioner" + CONFIG="train_enhanced_self_conditioner" + OVERRIDES="" + ;; + 3) + echo "Job 3: Spatiotemporal conditioner with temporal embedding and mean pooling" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean" + ;; + 4) + echo "Job 4: Spatiotemporal conditioner with ones temporal encoding and mean pooling" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean ++model.conditioner.spatiotemporal_model.temporal_module.node_attr_temporal_encoding_function=ones ++model.conditioner.spatiotemporal_model.temporal_module.edge_attr_temporal_encoding_function=ones" + ;; + *) + echo "Unknown job ID: ${SLURM_ARRAY_TASK_ID}" + exit 1 + ;; +esac + +# Build the command with base config +CMD="jamun_train --config-dir=configs experiment=${CONFIG}.yaml" + +# Add overrides if any +if [ -n "$OVERRIDES" ]; then + CMD="$CMD $OVERRIDES" +fi + +# Add common training overrides +CMD="$CMD ++trainer.max_epochs=100" +CMD="$CMD ++logger.wandb.group=model_comparison_full_swarm" +CMD="$CMD ++data.datamodule.datasets.train.root=/data2/sules/ALA_ALA_enhanced_full_swarm/train" +CMD="$CMD ++data.datamodule.datasets.val.root=/data2/sules/ALA_ALA_enhanced_full_swarm/val" +CMD="$CMD ++run_key=${RUN_KEY}" + +# Add dataset overrides for debugging (quick completion) +# CMD="$CMD ++data.datamodule.datasets.train.max_datasets=1" +# CMD="$CMD ++data.datamodule.datasets.val.max_datasets=1" + +# Add job-specific wandb tags +WANDB_TAG="job_${SLURM_ARRAY_TASK_ID}" +CMD="$CMD ++logger.wandb.tags=[${WANDB_TAG},model_comparison,separable_conv]" + +echo "Running command: $CMD" +exec $CMD diff --git a/scripts/slurm/train_model_comparison_high_noise.sh b/scripts/slurm/train_model_comparison_high_noise.sh new file mode 100644 index 0000000..ad1e943 --- /dev/null +++ b/scripts/slurm/train_model_comparison_high_noise.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash + +#SBATCH --partition=gpu2 +#SBATCH --job-name=model_comparison +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=0-4 + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" +nvidia-smi + +echo "Running array job ${SLURM_ARRAY_TASK_ID}" + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +# Define configurations for each job +case ${SLURM_ARRAY_TASK_ID} in + 0) + echo "Job 0: Standard JAMUN" + CONFIG="train_enhanced_standard_jamun" + OVERRIDES="" + ;; + 1) + echo "Job 1: Position conditioner" + CONFIG="train_enhanced_position_conditioner" + OVERRIDES="" + ;; + 2) + echo "Job 2: Self conditioner" + CONFIG="train_enhanced_self_conditioner" + OVERRIDES="" + ;; + 3) + echo "Job 3: Spatiotemporal conditioner with temporal embedding and mean pooling" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean" + ;; + 4) + echo "Job 4: Spatiotemporal conditioner with ones temporal encoding and mean pooling" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++model.conditioner.spatiotemporal_model.temporal_to_spatial_pooler._target_=jamun.model.pooling.TemporalToSpatialNodeAttrMean ++model.conditioner.spatiotemporal_model.temporal_module.node_attr_temporal_encoding_function=ones ++model.conditioner.spatiotemporal_model.temporal_module.edge_attr_temporal_encoding_function=ones" + ;; + *) + echo "Unknown job ID: ${SLURM_ARRAY_TASK_ID}" + exit 1 + ;; +esac + +# Build the command with base config +CMD="jamun_train --config-dir=configs experiment=${CONFIG}.yaml" + +# Add overrides if any +if [ -n "$OVERRIDES" ]; then + CMD="$CMD $OVERRIDES" +fi + +# Add common training overrides +CMD="$CMD ++trainer.max_epochs=100" +CMD="$CMD ++logger.wandb.group=model_comparison_high_noise" +CMD="$CMD ++model.sigma_distribution.sigma=0.06" +CMD="$CMD ++run_key=${RUN_KEY}" + +# Add dataset overrides for debugging (quick completion) +# CMD="$CMD ++data.datamodule.datasets.train.max_datasets=1" +# CMD="$CMD ++data.datamodule.datasets.val.max_datasets=1" + +# Add job-specific wandb tags +WANDB_TAG="job_${SLURM_ARRAY_TASK_ID}" +CMD="$CMD ++logger.wandb.tags=[${WANDB_TAG},model_comparison,separable_conv]" + +echo "Running command: $CMD" +exec $CMD diff --git a/scripts/slurm/train_noise_check.sh b/scripts/slurm/train_noise_check.sh new file mode 100755 index 0000000..7be084c --- /dev/null +++ b/scripts/slurm/train_noise_check.sh @@ -0,0 +1,150 @@ +#!/usr/bin/env bash +# +# Noise check experiment script +# +# Implements 4 model configurations for a given m value (passed as command line argument): +# 1. Standard JAMUN with repeated position dataset and noise level sigma/sqrt(m) +# 2. Spatiotemporal JAMUN with repeated position dataset and total lag time m +# 3. Spatiotemporal JAMUN with total lag time m +# 4. Spatiotemporal JAMUN with total lag time m, hub_n_spoke graph type, ones encoding +# +# Usage: sbatch train_noise_check.sh +# Total jobs: 4 models (array 0-3) + +#SBATCH --partition=gpu2 +#SBATCH --job-name=noise_check +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=3-0 +#SBATCH --mem-per-cpu=32G +#SBATCH --array=1-3 + +# Initialize conda +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate jamun + +# Verify conda activation worked +which python +echo "Python path: $(which python)" +echo "Conda environment: $CONDA_DEFAULT_ENV" +nvidia-smi + +# Get m value from command line argument +if [ $# -eq 0 ]; then + echo "Error: Please provide m value as command line argument" + echo "Usage: sbatch train_noise_check.sh " + exit 1 +fi + +M=$1 +echo "Running array job ${SLURM_ARRAY_TASK_ID} with M=${M}" + +# Define experiment parameters +# Model types: 4 types +# Total jobs: 4 models (array 0-3) + +MODEL_TYPES=( + "standard_jamun_repeated_pos" + "spatiotemporal_repeated_pos" + "spatiotemporal_default" + "spatiotemporal_hub_spoke_ones" +) + +# Calculate model index directly from array task ID +MODEL_INDEX=${SLURM_ARRAY_TASK_ID} +MODEL_TYPE=${MODEL_TYPES[$MODEL_INDEX]} + +echo "Job ${SLURM_ARRAY_TASK_ID}: M=${M}, Model=${MODEL_TYPE}" + +# Generate unique run key to prevent checkpoint overwrites +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +# Calculate noise level: sigma / sqrt(m) where base sigma = 0.04 +BASE_SIGMA=0.04 +NOISE_LEVEL=$(python3 -c "import math; print(${BASE_SIGMA} / math.sqrt(${M}))") + +echo "Noise level: ${NOISE_LEVEL}" + +# Configure base parameters based on model type +case ${MODEL_TYPE} in + # "standard_jamun_repeated_pos") + # echo "Model 1: Standard JAMUN with repeated position dataset and noise level sigma/sqrt(m)" + # CONFIG="train_enhanced_standard_jamun" + # OVERRIDES="++model.sigma_distribution.sigma=${NOISE_LEVEL}" + # OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.train._target_=jamun.data.parse_repeated_position_datasets_from_directory" + # OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val._target_=jamun.data.parse_repeated_position_datasets_from_directory" + # OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.train.total_lag_time=${M}" + # OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val.total_lag_time=${M}" + # OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.train.max_datasets=500" + # OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val.max_datasets=100" + # WANDB_TAG="standard_jamun_repeated_pos_m${M}" + # ;; + "spatiotemporal_repeated_pos") + echo "Model 2: Spatiotemporal JAMUN with repeated position dataset and total lag time m" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++data.datamodule.datasets.train._target_=jamun.data.parse_repeated_position_datasets_from_directory" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val._target_=jamun.data.parse_repeated_position_datasets_from_directory" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.train.total_lag_time=${M}" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val.total_lag_time=${M}" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.train.max_datasets=500" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val.max_datasets=100" + WANDB_TAG="spatiotemporal_repeated_pos_m${M}" + ;; + "spatiotemporal_default") + echo "Model 3: Spatiotemporal JAMUN with total lag time m" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++data.datamodule.datasets.train.total_lag_time=${M}" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val.total_lag_time=${M}" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.train.max_datasets=500" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val.max_datasets=100" + WANDB_TAG="spatiotemporal_default_m${M}" + ;; + "spatiotemporal_hub_spoke_ones") + echo "Model 4: Spatiotemporal JAMUN with total lag time m, hub_n_spoke graph type, ones encoding" + CONFIG="train_enhanced_spatiotemporal_conditioner" + OVERRIDES="++data.datamodule.datasets.train.total_lag_time=${M}" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val.total_lag_time=${M}" + OVERRIDES="${OVERRIDES} ++model.conditioner.spatiotemporal_model.graph_type=hub_n_spoke" + OVERRIDES="${OVERRIDES} ++model.conditioner.spatiotemporal_model.temporal_module.node_attr_temporal_encoding_function=ones" + OVERRIDES="${OVERRIDES} ++model.conditioner.spatiotemporal_model.temporal_module.edge_attr_temporal_encoding_function=ones" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.train.max_datasets=500" + OVERRIDES="${OVERRIDES} ++data.datamodule.datasets.val.max_datasets=100" + WANDB_TAG="spatiotemporal_hub_spoke_ones_m${M}" + ;; + *) + echo "Unknown model type: ${MODEL_TYPE}" + exit 1 + ;; +esac + +# Build the command with base config +CMD="jamun_train --config-dir=configs experiment=${CONFIG}.yaml" + +# Add overrides +if [ -n "$OVERRIDES" ]; then + CMD="$CMD $OVERRIDES" +fi + +# Add common training overrides +CMD="$CMD ++trainer.max_epochs=50" +CMD="$CMD ++run_key=${RUN_KEY}" +CMD="$CMD ++logger.wandb.group=noise_check_experiment_multimeasurement_vs_correlation" + +# Add job-specific wandb tags and run name +WANDB_RUN_NAME="noise_check_${WANDB_TAG}" +CMD="$CMD ++logger.wandb.tags=[${WANDB_TAG},noise_check,m_${M}]" +CMD="$CMD ++logger.wandb.name=${WANDB_RUN_NAME}" + +# Add notes about the experiment +WANDB_NOTES="Noise_check_experiment:_${MODEL_TYPE}_with_m=${M}" +if [[ ${MODEL_TYPE} == "standard_jamun" ]]; then + WANDB_NOTES="${WANDB_NOTES}, noise_level=${NOISE_LEVEL}" +fi +CMD="$CMD ++logger.wandb.notes=\"${WANDB_NOTES}\"" + +echo "Running command: $CMD" +exec $CMD diff --git a/src/jamun/cmdline/sample.py b/src/jamun/cmdline/sample.py index 4e8b5ef..f23d22a 100644 --- a/src/jamun/cmdline/sample.py +++ b/src/jamun/cmdline/sample.py @@ -22,6 +22,22 @@ dotenv.load_dotenv(".env", verbose=True) OmegaConf.register_new_resolver("format", format_resolver) +import logging + +# Setup logging +logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO) +logger = logging.getLogger("load_wandb_checkpoint") + +dotenv.load_dotenv("../.env", verbose=True) # Adjust path if script is not in scratch/ +JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH") +JAMUN_ROOT_PATH = os.getenv("JAMUN_ROOT_PATH") + +project_root = "/homefs/home/sules/jamun" # Adjust if necessary +if project_root not in sys.path: + sys.path.insert(0, project_root) + logger.info(f"Added '{project_root}' to sys.path for module discovery.") +else: + logger.info(f"'{project_root}' is already in sys.path.") def get_initial_graphs( @@ -50,7 +66,6 @@ def run(cfg): if matmul_prec := cfg.get("float32_matmul_precision"): dist_log(f"Setting float_32_matmul_precision to {matmul_prec}") torch.set_float32_matmul_precision(matmul_prec) - loggers = instantiate_dict_cfg(cfg.get("logger"), verbose=(rank_zero_only.rank == 0)) wandb_logger = None for logger in loggers: @@ -72,21 +87,29 @@ def run(cfg): cfg.model.checkpoint_path = checkpoint_path model = hydra.utils.instantiate(cfg.model) + # Set default graph_type to "fan" if spatiotemporal model exists but doesn't have graph_type + if ( + hasattr(model, "conditioner") + and hasattr(model.conditioner, "spatiotemporal_model") + and not hasattr(model.conditioner.spatiotemporal_model, "graph_type") + ): + model.conditioner.spatiotemporal_model.graph_type = "fan" + + print(f"Checkpoint path at: {checkpoint_path}") init_datasets = hydra.utils.instantiate(cfg.init_datasets) + # breakpoint() init_graphs = get_initial_graphs( init_datasets, num_init_samples_per_dataset=cfg.num_init_samples_per_dataset, repeat=cfg.repeat_init_samples, ) - callbacks = instantiate_dict_cfg(cfg.get("callbacks"), verbose=(rank_zero_only.rank == 0)) sampler = hydra.utils.instantiate(cfg.sampler, callbacks=callbacks, loggers=loggers) batch_sampler = hydra.utils.instantiate(cfg.batch_sampler) - if seed := cfg.get("seed"): # During sampling, we want ranks to generate different chains. pl.seed_everything(seed + sampler.fabric.global_rank) - + # breakpoint() # Run test-time adapation, if specified. if finetuning_cfg := cfg.get("finetune_on_init"): num_finetuning_steps = finetuning_cfg.get("num_steps") @@ -129,7 +152,7 @@ def run(cfg): # Needed for submitit error output. # See https://github.com/facebookresearch/hydra/issues/2664 -@hydra.main(version_base=None, config_path="../hydra_config", config_name="sample") +@hydra.main(version_base=None, config_path="../hydra_config", config_name="sample_memory") def main(cfg): try: run(cfg) diff --git a/src/jamun/cmdline/train.py b/src/jamun/cmdline/train.py index 5ec8883..f43cdaa 100644 --- a/src/jamun/cmdline/train.py +++ b/src/jamun/cmdline/train.py @@ -14,15 +14,61 @@ e3nn.set_optimization_defaults(jit_script_fx=False) +import math + import jamun # noqa: E402 from jamun.hydra import instantiate_dict_cfg # noqa: E402 from jamun.hydra.utils import format_resolver # noqa: E402 from jamun.utils import compute_average_squared_distance_from_datasets, dist_log, find_checkpoint # noqa: E402 +from jamun.utils._normalizations import normalization_factors # noqa: E402 +from jamun.utils.average_squared_distance import compute_temporal_average_squared_distance_from_datasets # noqa: E402 dotenv.load_dotenv(".env", verbose=True) OmegaConf.register_new_resolver("format", format_resolver) +def compute_radial_cutoff(max_radius: float, average_squared_distance: float, sigma: float, D: int = 3) -> float: + """ + Compute radial cutoff using the same formula as the denoiser. + + This replicates the computation from denoiser_conditional.py: + radial_cutoff = effective_radial_cutoff(sigma) / c_in + where: + - effective_radial_cutoff = sqrt(max_radius² + 6σ²) + - c_in = 1.0 / sqrt(average_squared_distance + 2Dσ²) + + Args: + max_radius: Maximum radius parameter + average_squared_distance: Average squared distance from dataset + sigma: Noise level + D: Dimensionality (default 3 for 3D coordinates) + + Returns: + Computed radial cutoff + """ + # Effective radial cutoff based on noise level + effective_radial_cutoff = math.sqrt(max_radius**2 + 6 * sigma**2) + + # JAMUN normalization factor c_in + A = average_squared_distance + B = 2 * D * sigma**2 + c_in = 1.0 / math.sqrt(A + B) + + # Final radial cutoff + radial_cutoff = effective_radial_cutoff / c_in + + print("Radial cutoff computation:") + print(f" max_radius: {max_radius}") + print(f" average_squared_distance: {average_squared_distance}") + print(f" sigma: {sigma}") + print(f" D: {D}") + print(f" effective_radial_cutoff: {effective_radial_cutoff}") + print(f" c_in: {c_in}") + print(f" final radial_cutoff: {radial_cutoff}") + + return radial_cutoff + + def compute_average_squared_distance_from_config(cfg: OmegaConf) -> float: """Computes the average squared distance for normalization from the data.""" datamodule = hydra.utils.instantiate(cfg.data.datamodule) @@ -33,17 +79,37 @@ def compute_average_squared_distance_from_config(cfg: OmegaConf) -> float: return average_squared_distance +def compute_temporal_average_squared_distance_from_config(cfg: OmegaConf) -> float: + """Computes the temporal average squared distance for normalization from the data.""" + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + datamodule.setup("compute_normalization") + train_datasets = datamodule.datasets["train"] + + average_squared_distance = compute_temporal_average_squared_distance_from_datasets( + train_datasets, + num_samples=100, # Use reasonable number of samples + verbose=True, + ) + return average_squared_distance + + def run(cfg): log_cfg = OmegaConf.to_container(cfg, throw_on_missing=True, resolve=True) dist_log(f"{OmegaConf.to_yaml(log_cfg)}") dist_log(f"{os.getcwd()=}") dist_log(f"{torch.__config__.parallel_info()}") + dist_log(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}") dist_log(f"{os.sched_getaffinity(0)=}") # Set the start method to spawn to avoid issues with the default fork method. torch.multiprocessing.set_start_method("spawn", force=True) + # Set random seed for reproducible training + if seed := cfg.get("seed"): + lightning.seed_everything(seed) + dist_log(f"Set random seed to {seed} for reproducible training") + # Compute data normalization. if cfg.get("compute_average_squared_distance_from_data"): average_squared_distance = compute_average_squared_distance_from_config(cfg) @@ -52,12 +118,74 @@ def run(cfg): ) cfg.model.average_squared_distance = average_squared_distance + sigma = cfg.model.sigma_distribution.sigma + average_squared_distance = cfg.model.average_squared_distance + c_in, c_skip, c_out, c_noise = normalization_factors(sigma, average_squared_distance) + c_in_float = float(c_in) + c_noise_float = float(c_noise) + + # Compute normalization factors for conditioner c_in parameter + if ( + cfg.model.get("conditioner") + and cfg.model.conditioner.get("_target_") == "jamun.model.conditioners.DenoisedConditioner" + ): + if hasattr(cfg.model.sigma_distribution, "sigma"): + dist_log(f"Computing normalization factors for DenoisedConditioner with sigma={sigma}") + dist_log(f" average_squared_distance: {average_squared_distance}") + dist_log(f" c_in: {c_in_float}") + dist_log(f" c_skip: {c_skip}") + dist_log(f" c_out: {c_out}") + dist_log(f" c_noise: {c_noise}") + + cfg.model.conditioner.c_in = c_in_float + dist_log(f"Set cfg.model.conditioner.c_in to {c_in_float}") + # breakpoint() + if ( + cfg.model.get("conditioner") + and cfg.model.conditioner.get("_target_") == "jamun.model.conditioners.conditioners.SpatioTemporalConditioner" + ): + cfg.model.conditioner.spatiotemporal_model.radial_cutoff = average_squared_distance + max_radius = cfg.model.max_radius + temporal_average_squared_distance = compute_temporal_average_squared_distance_from_config(cfg) + temporal_radial_cutoff = compute_radial_cutoff( + max_radius=max_radius, + average_squared_distance=temporal_average_squared_distance, # Use temporal for spatiotemporal model + sigma=sigma, + D=3, + ) + cfg.model.conditioner.spatiotemporal_model.temporal_cutoff = temporal_radial_cutoff + cfg.model.conditioner.c_noise = c_noise_float + cfg.model.conditioner.c_in = c_in_float + dist_log(f"Set cfg.model.conditioner.spatiotemporal_model.c_noise to {c_noise_float}") + dist_log(f"Set cfg.model.conditioner.c_in to {c_in_float}") + # # do this for the sweep + # if cfg.model.N_measurements_hidden is not None: + # dist_log(f"Number of hidden measurements: {cfg.model.N_measurements_hidden}") + # dist_log(f"Overwriting N_measurements...") + # cfg.model.N_measurements = 100 // cfg.model.N_measurements_hidden + # dist_log(f"New num of measurements: {cfg.model.N_measurements=}") + # breakpoint() datamodule = hydra.utils.instantiate(cfg.data.datamodule) model = hydra.utils.instantiate(cfg.model) if matmul_prec := cfg.get("float32_matmul_precision"): dist_log(f"Setting float_32_matmul_precision to {matmul_prec}") torch.set_float32_matmul_precision(matmul_prec) + # breakpoint() + # # If running under Slurm, ensure the number of devices matches the allocation. + # if "SLURM_GPUS_PER_TASK" in os.environ and torch.cuda.is_available(): + # dist_log(f"torch.cuda.device_count(): {torch.cuda.device_count()}") + # try: + # num_gpus = int(os.environ["SLURM_GPUS_PER_TASK"]) + # dist_log(f"Slurm-allocated GPUs per task: {num_gpus}") + # # Explicitly create a list of device IDs [0, 1, ..., n-1] for Lightning. + # device_ids = list(range(num_gpus)) + # # This will override any value from the config file, ensuring it matches the Slurm allocation. + # cfg.trainer.devices = device_ids + # dist_log(f"Explicitly set cfg.trainer.devices to {cfg.trainer.devices}") + # except (ValueError, KeyError): + # dist_log("Could not parse or find SLURM_GPUS_PER_TASK.") + loggers = instantiate_dict_cfg(cfg.get("logger"), verbose=(rank_zero_only.rank == 0)) wandb_logger = None for logger in loggers: @@ -70,7 +198,7 @@ def run(cfg): callbacks = instantiate_dict_cfg(cfg.get("callbacks"), verbose=(rank_zero_only.rank == 0)) trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=loggers) - + # breakpoint() # TODO support wandb notes/description if rank_zero_only.rank == 0 and wandb_logger: wandb_logger.experiment.config.update({"cfg": log_cfg, "version": jamun.__version__, "cwd": os.getcwd()}) @@ -85,9 +213,10 @@ def run(cfg): ) else: checkpoint_path = None + print(f"Saving checkpoints @ {checkpoint_path}") trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_path) - + # breakpoint() if wandb_logger and isinstance(trainer.profiler, lightning.pytorch.profilers.PyTorchProfiler): profile_art = wandb.Artifact("trace", type="profile") for trace in pathlib.Path(trainer.profiler.dirpath).glob("*.pt.trace.json"): diff --git a/src/jamun/data/__init__.py b/src/jamun/data/__init__.py index 8800cc2..a49bd3c 100644 --- a/src/jamun/data/__init__.py +++ b/src/jamun/data/__init__.py @@ -7,5 +7,7 @@ dloader_map_reduce, parse_datasets_from_directory, parse_datasets_from_directory_new, + parse_repeated_position_datasets_from_directory, parse_sdf_datasets_from_directory, ) +from .noisy_position_dataset import RepeatedPositionDataset diff --git a/src/jamun/data/_mdtraj.py b/src/jamun/data/_mdtraj.py index 54a6b23..3e83386 100644 --- a/src/jamun/data/_mdtraj.py +++ b/src/jamun/data/_mdtraj.py @@ -11,6 +11,47 @@ from jamun import utils +def get_subsampled_indices( + N: int, + subsample_rate: int, + total_lag_time: int, + lag_subsample_rate: int, +) -> list[np.ndarray]: + """ + Generate subsampled indices and their corresponding lagged indices. + + Args: + N: Total number of frames + subsample_rate: Rate at which to subsample the frames + total_lag_time: Number of lagged frames to generate for each subsampled frame + lag_subsample_rate: Rate at which to subsample the lagged frames + + Returns: + List of arrays, where each array contains the lagged indices for a subsampled frame + + Raises: + ValueError: If the input parameters don't satisfy the required constraints + """ + # Check guardrails + if N / subsample_rate < 1: + raise ValueError(f"Number of samples (N/subsample_rate = {N / subsample_rate}) must be >= 1") + + # Generate subsampled indices + subsampled_indices = np.arange(0, N, subsample_rate) + + # Generate lagged indices for each subsampled index + lagged_indices = [] + for idx in subsampled_indices: + # Calculate lagged indices + lagged = [int(idx - j * lag_subsample_rate) for j in range(total_lag_time)] + + # Check if we have enough lagged indices + if len(lagged) == total_lag_time and all(x >= 0 for x in lagged): + lagged_indices.append(lagged) + + return lagged_indices + + def make_graph_from_topology( topology: md.Topology, ) -> torch_geometric.data.Data: @@ -36,6 +77,7 @@ def make_graph_from_topology( num_residues=residue_sequence_index.max().item() + 1, bonded_edge_index=bonds, pos=None, + hidden_state=None, ) graph.residues = [x.residue.name for x in topology.atoms] graph.atom_names = [x.name for x in topology.atoms] @@ -109,7 +151,7 @@ def __init__( self.graph.dataset_label = self.label() self.graph.loss_weight = torch.tensor([loss_weight], dtype=torch.float32) - + self.graph.hidden_state = None # self.save_topology_pdb() if verbose: @@ -165,6 +207,8 @@ def __init__( start_frame: int | None = None, transform: Callable | None = None, subsample: int | None = None, + total_lag_time: int | None = None, + lag_subsample_rate: int | None = None, loss_weight: float = 1.0, verbose: bool = False, keep_hydrogens: bool = False, @@ -201,32 +245,41 @@ def __init__( if subsample is None or subsample == 0: subsample = 1 - # Subsample the trajectory. - self.traj = self.traj[start_frame : start_frame + num_frames : subsample] - topology = self.traj.topology - - self.top_with_H, self.topology_slice_with_H = preprocess_topology(topology, keep_hydrogens=True) - self.top_without_H, self.topology_slice_without_H = preprocess_topology(topology, keep_hydrogens=False) - self.graph_with_H = make_graph_from_topology(self.top_with_H) - self.graph_without_H = make_graph_from_topology(self.top_without_H) - - if keep_hydrogens: - self.graph = self.graph_with_H - self.top = self.top_with_H - self.topology_slice = self.topology_slice_with_H + # Get lagged indices if lag parameters are provided + if total_lag_time is not None and lag_subsample_rate is not None: + # print(f"total_lag_time: {total_lag_time}, lag_subsample_rate: {lag_subsample_rate}") + self.traj = self.traj[start_frame : start_frame + num_frames] # accommodate for start_frame and num_frames + lagged_indices = get_subsampled_indices(self.traj.n_frames, subsample, total_lag_time, lag_subsample_rate) + # Extract subsampled indices (first element of each list) + subsampled_indices = [indices[0] for indices in lagged_indices] + # Extract lagged indices (all except first element) + self.lagged_indices = [indices[1:] for indices in lagged_indices] + # Subsample the trajectory using the subsampled indices + self.hidden_state = [self.traj[indices] for indices in self.lagged_indices] + self.traj = self.traj[subsampled_indices] # self.traj is permanently modified. else: - self.graph = self.graph_without_H - self.top = self.top_without_H - self.topology_slice = self.topology_slice_without_H + # Regular subsampling without lag + # print(f"subsample: {subsample}, regular subsampling") + self.traj = self.traj[start_frame : start_frame + num_frames : subsample] + self.hidden_state = None + self.lagged_indices = None - self.traj = self.traj.atom_slice(self.topology_slice) + topology = self.traj.topology + self.top, atom_selection = preprocess_topology(topology, keep_hydrogens=False) + self.graph = make_graph_from_topology(self.top) + self.traj = self.traj.atom_slice(atom_selection) + if self.hidden_state is not None: + self.hidden_state = [ + traj.atom_slice(atom_selection) for traj in self.hidden_state + ] # select protein atoms for hidden state(s) self.graph.pos = torch.tensor(self.traj.xyz[0], dtype=torch.float32) self.graph.loss_weight = torch.tensor([loss_weight], dtype=torch.float32) self.graph.dataset_label = self.label() - - # self.save_topology_pdb() - + if self.hidden_state is not None: + self.graph.hidden_state = [self.hidden_state[0].xyz[i] for i in range(self.hidden_state[0].n_frames)] + else: + self.graph.hidden_state = [] if verbose: utils.dist_log(f"Dataset {self.label()}: Loading trajectory files {traj_files} and PDB file {pdb_file}.") utils.dist_log( @@ -246,6 +299,14 @@ def save_topology_pdb(self, filename: str | None = None): def __getitem__(self, idx): graph = self.graph.clone() graph.pos = torch.tensor(self.traj.xyz[idx]) + + if self.hidden_state is not None: + graph.hidden_state = [ + torch.tensor(self.hidden_state[idx].xyz[i]) for i in range(self.hidden_state[idx].n_frames) + ] + else: + graph.hidden_state = [] + if self.transform: graph = self.transform(graph) return graph diff --git a/src/jamun/data/_subsample.py b/src/jamun/data/_subsample.py new file mode 100644 index 0000000..fcc3973 --- /dev/null +++ b/src/jamun/data/_subsample.py @@ -0,0 +1,86 @@ +import numpy as np + + +def get_subsampled_indices( + N: int, + subsample_rate: int, + total_lag_time: int, + lag_subsample_rate: int, +) -> list[np.ndarray]: + """ + Generate subsampled indices and their corresponding lagged indices. + + Args: + N: Total number of frames + subsample_rate: Rate at which to subsample the frames + total_lag_time: Number of lagged frames to generate for each subsampled frame + lag_subsample_rate: Rate at which to subsample the lagged frames + + Returns: + List of arrays, where each array contains the lagged indices for a subsampled frame + + Raises: + ValueError: If the input parameters don't satisfy the required constraints + """ + # Check guardrails + if N / subsample_rate < 1: + raise ValueError(f"Number of samples (N/subsample_rate = {N / subsample_rate}) must be >= 1") + + # if total_lag_time * lag_subsample_rate > subsample_rate: + # raise ValueError( + # f"total_lag_time * lag_subsample_rate ({total_lag_time * lag_subsample_rate}) " + # f"must be <= subsample_rate ({subsample_rate})" + # ) + + # Generate subsampled indices + subsampled_indices = np.arange(0, N, subsample_rate) + + # Generate lagged indices for each subsampled index + lagged_indices = [] + for idx in subsampled_indices: + # Calculate lagged indices + lagged = [int(idx - j * lag_subsample_rate) for j in range(total_lag_time)] + + # Check if we have enough lagged indices + if len(lagged) == total_lag_time and all(x >= 0 for x in lagged): + lagged_indices.append(lagged) + + return lagged_indices + + +def get_subsampled_trajectory( + positions: np.ndarray, + subsample_rate: int, + total_lag_time: int, + lag_subsample_rate: int, +) -> tuple[np.ndarray, list[np.ndarray]]: + """ + Subsample a trajectory and generate lagged states for each subsampled frame. + + Args: + positions: Array of shape (N, ...) containing trajectory positions + subsample_rate: Rate at which to subsample the frames + total_lag_time: Number of lagged frames to generate for each subsampled frame + lag_subsample_rate: Rate at which to subsample the lagged frames + + Returns: + Tuple containing: + - subsampled_positions: Array of subsampled positions + - lagged_positions: List of arrays, where each array contains the lagged positions + for the corresponding subsampled frame + + Raises: + ValueError: If the input parameters don't satisfy the required constraints + """ + N = len(positions) + + # Get the lagged indices + lagged_indices = get_subsampled_indices(N, subsample_rate, total_lag_time, lag_subsample_rate) + + # Extract subsampled positions (first element of each lagged indices list) + subsampled_positions = np.array([positions[indices[0]] for indices in lagged_indices]) + + # Generate lagged positions for each subsampled frame + lagged_positions = [[positions[idx] for idx in indices[1:]] for indices in lagged_indices] + + return subsampled_positions, lagged_positions diff --git a/src/jamun/data/_utils.py b/src/jamun/data/_utils.py index bf953c7..20fc614 100644 --- a/src/jamun/data/_utils.py +++ b/src/jamun/data/_utils.py @@ -43,6 +43,7 @@ def parse_datasets_from_directory( max_datasets_offset: int | None = None, filter_codes: Sequence[str] | None = None, as_iterable: bool = False, + label_override: str | None = None, **dataset_kwargs, ) -> list[MDtrajDataset]: """Helper function to create MDtrajDataset objects from a directory of trajectory files.""" @@ -106,11 +107,18 @@ def parse_datasets_from_directory( datasets = [] for code in tqdm(codes, desc="Creating datasets"): + # Use label_override for the dataset label, but keep original code for file lookups + if label_override is not None: + print(f"Label override: {label_override}") + dataset_label = str(label_override) + else: + dataset_label = code + dataset = dataset_class( root, traj_files=traj_files[code], pdb_file=pdb_files[code], - label=code, + label=dataset_label, **dataset_kwargs, ) datasets.append(dataset) @@ -348,3 +356,100 @@ def create_dataset_from_pdbs(pdbfiles: str, label_prefix: str | None = None) -> datasets.append(dataset) return datasets + + +def parse_repeated_position_datasets_from_directory( + root: str, + traj_pattern: str, + pdb_pattern: str | None = None, + pdb_file: Sequence[str] | None = None, + max_datasets: int | None = None, + max_datasets_offset: int | None = None, + filter_codes: Sequence[str] | None = None, + as_iterable: bool = False, + label_override: str | None = None, + **dataset_kwargs, +) -> list: + """Helper function to create RepeatedPositionDataset objects from a directory of trajectory files.""" + # Import here to avoid circular imports + from jamun.data.noisy_position_dataset import RepeatedPositionDataset + + # Print the dataset_kwargs for debugging + print("=== parse_repeated_position_datasets_from_directory dataset_kwargs ===") + print(f"dataset_kwargs: {dataset_kwargs}") + print("=== End dataset_kwargs ===") + + if pdb_file is not None and pdb_pattern is not None: + raise ValueError("Exactly one of pdb_file and pdb_pattern should be provided.") + + traj_prefix, traj_pattern = os.path.split(traj_pattern) + traj_pattern_compiled = re.compile(traj_pattern) + if "*" in traj_prefix or "?" in traj_prefix: + raise ValueError("traj_prefix should not contain wildcards.") + + traj_files = collections.defaultdict(list) + codes = set() + for entry in os.scandir(os.path.join(root, traj_prefix)): + match = traj_pattern_compiled.match(entry.name) + if not match: + continue + + code = match.group(1) + codes.add(code) + traj_files[code].append(os.path.join(traj_prefix, entry.name)) + + if len(codes) == 0: + raise ValueError("No codes found in directory.") + + pdb_files = {} + if pdb_pattern is not None: + pdb_prefix, pdb_pattern = os.path.split(pdb_pattern) + pdb_pattern_compiled = re.compile(pdb_pattern) + if "*" in pdb_prefix or "?" in pdb_prefix: + raise ValueError("pdb_prefix should not contain wildcards.") + + for entry in os.scandir(os.path.join(root, pdb_prefix)): + match = pdb_pattern_compiled.match(entry.name) + if not match: + continue + + code = match.group(1) + if code not in codes: + continue + pdb_files[code] = os.path.join(pdb_prefix, entry.name) + else: + for code in codes: + pdb_files[code] = pdb_file + + # Filter out codes. + if filter_codes is not None: + codes = [code for code in codes if code in set(filter_codes)] + + # Sort the codes and offset them, if necessary. + codes = list(sorted(codes)) + if max_datasets_offset is not None: + codes = codes[max_datasets_offset:] + if max_datasets is not None: + codes = codes[:max_datasets] + + if as_iterable: + raise ValueError("RepeatedPositionDataset does not support iterable mode") + + datasets = [] + for code in tqdm(codes, desc="Creating RepeatedPositionDatasets"): + # Use label_override for the dataset label, but keep original code for file lookups + if label_override is not None: + print(f"Label override: {label_override}") + dataset_label = str(label_override) + else: + dataset_label = code + + dataset = RepeatedPositionDataset( + root, + traj_files=traj_files[code], + pdb_file=pdb_files[code], + label=dataset_label, + **dataset_kwargs, + ) + datasets.append(dataset) + return datasets diff --git a/src/jamun/data/noisy_position_dataset.py b/src/jamun/data/noisy_position_dataset.py new file mode 100644 index 0000000..c8d070c --- /dev/null +++ b/src/jamun/data/noisy_position_dataset.py @@ -0,0 +1,36 @@ +from jamun.data._mdtraj import MDtrajDataset + + +class RepeatedPositionDataset(MDtrajDataset): + """ + Dataset that replaces hidden states with copies of the current position. + This is used for Model 3 experiment where the structures passed to the denoiser + are copies of the same structure given by y.pos. The denoiser will add noise during training. + """ + + def __init__(self, *args, **kwargs): + """Initialize but store total_lag_time before modifying parent behavior.""" + # Store the total_lag_time for our own use + self._target_total_lag_time = kwargs.get("total_lag_time", 2) + + # Prevent parent from doing lag processing by removing lag parameters + kwargs_no_lag = kwargs.copy() + kwargs_no_lag["total_lag_time"] = None + kwargs_no_lag["lag_subsample_rate"] = None + + super().__init__(*args, **kwargs_no_lag) + + def __getitem__(self, idx): + """Override to create position copies instead of using real hidden states.""" + # Get the normal item from parent class (without lag processing) + graph = super().__getitem__(idx) + + # Create the number of hidden states we want based on our target total_lag_time + num_hidden_states = self._target_total_lag_time - 1 + + graph.hidden_state = [] + for _ in range(num_hidden_states): + # Create a copy of the current position (no noise added here) + graph.hidden_state.append(graph.pos.clone()) + + return graph diff --git a/src/jamun/data/tests/test_subsample.py b/src/jamun/data/tests/test_subsample.py new file mode 100644 index 0000000..99392e2 --- /dev/null +++ b/src/jamun/data/tests/test_subsample.py @@ -0,0 +1,239 @@ +import numpy as np + +from jamun.data._subsample import get_subsampled_indices, get_subsampled_trajectory + + +def test_basic_functionality(): + """Test basic functionality with valid inputs.""" + print("\nTesting basic functionality...") + N = 100 + subsample_rate = 10 + total_lag_time = 3 + lag_subsample_rate = 10 + + print( + f"Input parameters: N={N}, subsample_rate={subsample_rate}, " + f"total_lag_time={total_lag_time}, lag_subsample_rate={lag_subsample_rate}" + ) + + breakpoint() # Debug point 1: Check input parameters before function call + + lagged_indices = get_subsampled_indices(N, subsample_rate, total_lag_time, lag_subsample_rate) + + breakpoint() # Debug point 2: Check function output + + # Extract subsampled indices (first element of each list) + subsampled_indices = np.array([indices[0] for indices in lagged_indices]) + print(f"Subsampled indices: {subsampled_indices}") + print(f"Number of lagged indices lists: {len(lagged_indices)}") + + # Check subsampled indices + expected_subsampled = np.array([20, 30, 40, 50, 60, 70, 80, 90]) + assert np.array_equal(subsampled_indices, expected_subsampled), ( + f"Expected {expected_subsampled}, got {subsampled_indices}" + ) + + # Check lagged indices + for i, lagged in enumerate(lagged_indices): + expected_lagged = np.array( + [ + subsampled_indices[i], + subsampled_indices[i] - lag_subsample_rate, + subsampled_indices[i] - 2 * lag_subsample_rate, + ] + ) + assert np.array_equal(lagged, expected_lagged), f"For index {i}, expected {expected_lagged}, got {lagged}" + + print("Basic functionality test passed!") + + +def test_edge_cases(): + """Test edge cases and boundary conditions.""" + print("\nTesting edge cases...") + N = 10 + subsample_rate = 10 + total_lag_time = 1 + lag_subsample_rate = 1 + + print( + f"Input parameters: N={N}, subsample_rate={subsample_rate}, " + f"total_lag_time={total_lag_time}, lag_subsample_rate={lag_subsample_rate}" + ) + + breakpoint() # Debug point 5: Check edge case parameters before function call + + lagged_indices = get_subsampled_indices(N, subsample_rate, total_lag_time, lag_subsample_rate) + + breakpoint() # Debug point 6: Check edge case results + + # Extract subsampled indices (first element of each list) + subsampled_indices = np.array([indices[0] for indices in lagged_indices]) + print(f"Subsampled indices: {subsampled_indices}") + print(f"Lagged indices: {lagged_indices}") + + assert len(subsampled_indices) == 1, f"Expected 1 subsampled index, got {len(subsampled_indices)}" + assert len(lagged_indices) == 1, f"Expected 1 lagged indices list, got {len(lagged_indices)}" + assert np.array_equal(subsampled_indices, np.array([0])), f"Expected [0], got {subsampled_indices}" + assert np.array_equal(lagged_indices[0], np.array([0])), f"Expected [0], got {lagged_indices[0]}" + + print("Edge cases test passed!") + + +def test_lagged_indices_filtering(): + """Test that lagged indices are properly filtered when they would go negative.""" + print("\nTesting lagged indices filtering...") + N = 20 + subsample_rate = 5 + total_lag_time = 3 + lag_subsample_rate = 3 + + print( + f"Input parameters: N={N}, subsample_rate={subsample_rate}, " + f"total_lag_time={total_lag_time}, lag_subsample_rate={lag_subsample_rate}" + ) + + breakpoint() # Debug point 7: Check filtering parameters before function call + + lagged_indices = get_subsampled_indices(N, subsample_rate, total_lag_time, lag_subsample_rate) + + breakpoint() # Debug point 8: Check filtering results + + # Extract subsampled indices (first element of each list) + subsampled_indices = np.array([indices[0] for indices in lagged_indices]) + print(f"Subsampled indices: {subsampled_indices}") + print(f"Number of lagged indices lists: {len(lagged_indices)}") + + expected_subsampled = np.array([5, 10, 15]) + assert np.array_equal(subsampled_indices, expected_subsampled), ( + f"Expected {expected_subsampled}, got {subsampled_indices}" + ) + + assert len(lagged_indices) == len(subsampled_indices), ( + f"Expected {len(subsampled_indices)} lagged indices lists, got {len(lagged_indices)}" + ) + + expected_first_lagged = np.array([5, 2, -1]) + assert not any(np.array_equal(lagged, expected_first_lagged) for lagged in lagged_indices), ( + "Found unexpected lagged indices that should have been filtered out" + ) + + print("Lagged indices filtering test passed!") + + +def test_large_numbers(): + """Test with larger numbers to ensure scalability.""" + print("\nTesting large numbers...") + N = 10000 + subsample_rate = 100 + total_lag_time = 5 + lag_subsample_rate = 10 + + print( + f"Input parameters: N={N}, subsample_rate={subsample_rate}, " + f"total_lag_time={total_lag_time}, lag_subsample_rate={lag_subsample_rate}" + ) + + breakpoint() # Debug point 9: Check large number parameters before function call + + lagged_indices = get_subsampled_indices(N, subsample_rate, total_lag_time, lag_subsample_rate) + + breakpoint() # Debug point 10: Check large number results + + # Extract subsampled indices (first element of each list) + subsampled_indices = np.array([indices[0] for indices in lagged_indices]) + print(f"Number of subsampled indices: {len(subsampled_indices)}") + print(f"Number of lagged indices lists: {len(lagged_indices)}") + + assert len(subsampled_indices) == N // subsample_rate, ( + f"Expected {N // subsample_rate} subsampled indices, got {len(subsampled_indices)}" + ) + + for i, lagged in enumerate(lagged_indices): + print(f"\nChecking lagged indices list {i}:") + print(f"Lagged indices: {lagged}") + print(f"Type of lagged indices: {type(lagged)}") + print("Individual values and their types:") + for j, val in enumerate(lagged): + print(f" Index {j}: value={val}, type={type(val)}") + + assert len(lagged) == total_lag_time, f"Expected lagged indices length {total_lag_time}, got {len(lagged)}" + + # Check each value individually + for j, val in enumerate(lagged): + assert isinstance(val, int | np.integer), ( + f"Value at index {j} is not an integer: {val} (type: {type(val)})" + ) + assert val >= 0, f"Found negative value at index {j}: {val}" + + print("Large numbers test passed!") + + +def test_trajectory_subsampling(): + """Test subsampling of trajectory positions.""" + print("\nTesting trajectory subsampling...") + + # Create a random trajectory with 100 frames, 10 particles, and 3 coordinates + N = 100 + np.random.seed(42) # For reproducibility + positions = np.random.randn(N, 10, 3) + + subsample_rate = 10 + total_lag_time = 3 + lag_subsample_rate = 10 + + print( + f"Input parameters: N={N}, subsample_rate={subsample_rate}, " + f"total_lag_time={total_lag_time}, lag_subsample_rate={lag_subsample_rate}" + ) + + breakpoint() # Debug point 11: Check trajectory parameters + + subsampled_positions, lagged_positions = get_subsampled_trajectory( + positions, subsample_rate, total_lag_time, lag_subsample_rate + ) + + breakpoint() # Debug point 12: Check trajectory results + + print(f"Original positions shape: {positions.shape}") + print(f"Subsampled positions shape: {subsampled_positions.shape}") + print(f"Number of lagged position lists: {len(lagged_positions)}") + + # Check shapes + expected_num_subsampled = (N - 20) // subsample_rate # Starting from index 20 + assert subsampled_positions.shape[0] == expected_num_subsampled, ( + f"Expected {expected_num_subsampled} subsampled positions, got {subsampled_positions.shape[0]}" + ) + assert subsampled_positions.shape[1:] == (10, 3), ( + f"Expected subsampled positions to have shape (N, 10, 3), got {subsampled_positions.shape}" + ) + + # Check values + for i in range(expected_num_subsampled): + # Check subsampled position + expected_sub_pos = positions[20 + i * subsample_rate] + assert np.array_equal(subsampled_positions[i], expected_sub_pos), ( + f"For index {i}, expected subsampled position {expected_sub_pos}, got {subsampled_positions[i]}" + ) + + # Check lagged positions + assert len(lagged_positions[i]) == total_lag_time, ( + f"For index {i}, expected {total_lag_time} lagged positions, got {len(lagged_positions[i])}" + ) + + for j, lag_pos in enumerate(lagged_positions[i]): + expected_lag_pos = positions[20 + i * subsample_rate - j * lag_subsample_rate] + assert np.array_equal(lag_pos, expected_lag_pos), ( + f"For index {i}, lag {j}, expected position {expected_lag_pos}, got {lag_pos}" + ) + + print("Trajectory subsampling test passed!") + + +if __name__ == "__main__": + print("Starting tests...") + test_basic_functionality() + # test_edge_cases() + # test_lagged_indices_filtering() + # test_large_numbers() + test_trajectory_subsampling() + print("\nAll tests completed!") diff --git a/src/jamun/hydra_config/batch_sampler/mcmc/aboba_memory.yaml b/src/jamun/hydra_config/batch_sampler/mcmc/aboba_memory.yaml new file mode 100644 index 0000000..40ea97d --- /dev/null +++ b/src/jamun/hydra_config/batch_sampler/mcmc/aboba_memory.yaml @@ -0,0 +1,12 @@ +_target_: jamun.sampling.mcmc.ABOBA_memory +delta: ${delta} +friction: ${friction} +steps: ${num_sampling_steps_per_batch} +save_trajectory: true +cpu_offload: true +verbose: true +inverse_temperature: ${inverse_temperature} +score_fn_clip: ${score_fn_clip} +M: ${M} +burn_in_steps: 0 +v_init: "zero" \ No newline at end of file diff --git a/src/jamun/hydra_config/batch_sampler/mcmc/baoab_memory.yaml b/src/jamun/hydra_config/batch_sampler/mcmc/baoab_memory.yaml new file mode 100644 index 0000000..01e4a04 --- /dev/null +++ b/src/jamun/hydra_config/batch_sampler/mcmc/baoab_memory.yaml @@ -0,0 +1,12 @@ +_target_: jamun.sampling.mcmc.BAOAB_memory +delta: ${delta} +friction: ${friction} +steps: ${num_sampling_steps_per_batch} +save_trajectory: true +cpu_offload: true +verbose: true +inverse_temperature: ${inverse_temperature} +score_fn_clip: ${score_fn_clip} +M: ${M} +burn_in_steps: 0 +v_init: "zero" \ No newline at end of file diff --git a/src/jamun/hydra_config/batch_sampler/single_measurement_sampler_memory.yaml b/src/jamun/hydra_config/batch_sampler/single_measurement_sampler_memory.yaml new file mode 100644 index 0000000..a58a9cb --- /dev/null +++ b/src/jamun/hydra_config/batch_sampler/single_measurement_sampler_memory.yaml @@ -0,0 +1,7 @@ +defaults: + - mcmc: baoab_memory.yaml + - callbacks: null + - _self_ + +_target_: jamun.sampling.walkjump.SingleMeasurementSamplerMemory +sigma: ${sigma} diff --git a/src/jamun/hydra_config/callbacks/model_checkpoint.yaml b/src/jamun/hydra_config/callbacks/model_checkpoint.yaml index da57e74..0a59fcd 100644 --- a/src/jamun/hydra_config/callbacks/model_checkpoint.yaml +++ b/src/jamun/hydra_config/callbacks/model_checkpoint.yaml @@ -1,6 +1,8 @@ model_checkpoint: _target_: "lightning.pytorch.callbacks.ModelCheckpoint" dirpath: "${hydra:runtime.output_dir}/checkpoints" - save_top_k: 5 + save_top_k: -1 save_last: true monitor: "val/loss" + save_on_train_epoch_end: true + every_n_epochs: 1 diff --git a/src/jamun/hydra_config/model/arch/e3conv.yaml b/src/jamun/hydra_config/model/arch/e3conv.yaml index 4bfd8ee..3e04b19 100644 --- a/src/jamun/hydra_config/model/arch/e3conv.yaml +++ b/src/jamun/hydra_config/model/arch/e3conv.yaml @@ -19,7 +19,7 @@ hidden_layer_factory: _target_: e3tools.nn.ConvBlock _partial_: true conv: - _target_: e3tools.nn.Conv + _target_: e3tools.nn.SeparableConv _partial_: true output_head_factory: _target_: e3tools.nn.EquivariantMLP diff --git a/src/jamun/hydra_config/model/arch/e3conv_conditional.yaml b/src/jamun/hydra_config/model/arch/e3conv_conditional.yaml new file mode 100644 index 0000000..3a272e7 --- /dev/null +++ b/src/jamun/hydra_config/model/arch/e3conv_conditional.yaml @@ -0,0 +1,31 @@ +_target_: jamun.model.arch.E3ConvConditional +_partial_: true +irreps_out: "1x1e" +irreps_hidden: "120x0e + 32x1e" +irreps_sh: "1x0e + 1x1e" +n_layers: 5 +edge_attr_dim: 64 +atom_type_embedding_dim: 8 +atom_code_embedding_dim: 8 +residue_code_embedding_dim: 32 +residue_index_embedding_dim: 8 +use_residue_information: ${data.use_residue_information} +use_residue_sequence_index: false +num_atom_types: 20 +max_sequence_length: 10 +num_atom_codes: 10 +num_residue_types: 25 +hidden_layer_factory: + _target_: e3tools.nn.ConvBlock + _partial_: true + conv: + _target_: e3tools.nn.SeparableConv + _partial_: true +output_head_factory: + _target_: e3tools.nn.EquivariantMLP + _partial_: true + irreps_hidden_list: + - ${model.arch.irreps_hidden} +test_equivariance: false +reduce: null +N_structures: 2 \ No newline at end of file diff --git a/src/jamun/hydra_config/model/arch/e3conv_conditional_spatiotemporal.yaml b/src/jamun/hydra_config/model/arch/e3conv_conditional_spatiotemporal.yaml new file mode 100644 index 0000000..3833eae --- /dev/null +++ b/src/jamun/hydra_config/model/arch/e3conv_conditional_spatiotemporal.yaml @@ -0,0 +1,32 @@ +_target_: jamun.model.arch.e3conv_conditional.E3ConvConditionalSpatioTemporal +_partial_: true +irreps_out: "1x1e" # Output 3 components (1x1e) to match position +irreps_hidden: "120x0e + 32x1e" +irreps_sh: "1x0e + 1x1e" +n_layers: 1 +edge_attr_dim: 64 # Match spatiotemporal spatial module +atom_type_embedding_dim: 8 +atom_code_embedding_dim: 8 +residue_code_embedding_dim: 32 # Match spatiotemporal spatial module +residue_index_embedding_dim: 8 +use_residue_information: ${data.use_residue_information} +use_residue_sequence_index: false +num_atom_types: 20 +max_sequence_length: 10 +num_atom_codes: 10 +num_residue_types: 25 +test_equivariance: false +reduce: null +hidden_layer_factory: + _target_: e3tools.nn.ConvBlock + _partial_: true + conv: + _target_: e3tools.nn.SeparableConv # replace with Conv for non-separable case + _partial_: true +output_head_factory: + _target_: e3tools.nn.EquivariantMLP + _partial_: true + irreps_hidden_list: + - ${model.arch.irreps_hidden} +N_structures: 1 # For [y.pos, spatial_features] input +input_attr_irreps: "120x0e + 32x1e" # spatial_features (9 components = 3x1e) \ No newline at end of file diff --git a/src/jamun/hydra_config/model/arch/spatiotemporal.yaml b/src/jamun/hydra_config/model/arch/spatiotemporal.yaml new file mode 100644 index 0000000..c4384ee --- /dev/null +++ b/src/jamun/hydra_config/model/arch/spatiotemporal.yaml @@ -0,0 +1,58 @@ +# Configuration for E3SpatioTemporal architecture +_target_: jamun.model.arch.spatiotemporal.E3SpatioTemporal + +# Cutoff parameters +radial_cutoff: 0.05 # radial cutoff for the spatial module +temporal_cutoff: 1.0 # radial cutoff for the temporal module + +# Spatial module (E3Conv) +spatial_module: + _target_: jamun.model.arch.E3Conv + _partial_: true + irreps_out: "1x1e" + irreps_hidden: "120x0e + 32x1e" + irreps_sh: "1x0e + 1x1e" + n_layers: 5 + edge_attr_dim: 64 + atom_type_embedding_dim: 8 + atom_code_embedding_dim: 8 + residue_code_embedding_dim: 32 + residue_index_embedding_dim: 8 + use_residue_information: ${data.use_residue_information} + use_residue_sequence_index: false + num_atom_types: 20 + max_sequence_length: 10 + num_atom_codes: 10 + num_residue_types: 25 + hidden_layer_factory: + _target_: e3tools.nn.ConvBlock + _partial_: true + conv: + _target_: e3tools.nn.SeparableConv + _partial_: true + output_head_factory: + _target_: e3tools.nn.EquivariantMLP + _partial_: true + irreps_hidden_list: + - ${model.arch.irreps_hidden} + +# Temporal module (E3Transformer) +temporal_module: + _target_: jamun.model.arch.spatiotemporal.E3Transformer + irreps_out: "3x1e" + irreps_hidden: "8x0e + 4x1e" + irreps_sh: "1x0e + 1x1e" + irreps_node_attr: "1x1e" # Match spatial module output + num_layers: 2 + edge_attr_dim: 24 + num_attention_heads: 1 + reduce: null + +# Pooling modules +spatial_to_temporal_pooler: + _target_: jamun.model.pooling.SpatialTemporalToTemporalNodeAttr + irreps_out: "1x1e" # Match spatial module output + +temporal_to_spatial_pooler: + _target_: jamun.model.pooling.TemporalToSpatialNodeAttrMean + irreps_out: "3x1e" # Match temporal module output \ No newline at end of file diff --git a/src/jamun/hydra_config/model/conditioner/denoised.yaml b/src/jamun/hydra_config/model/conditioner/denoised.yaml new file mode 100644 index 0000000..c3c242f --- /dev/null +++ b/src/jamun/hydra_config/model/conditioner/denoised.yaml @@ -0,0 +1,5 @@ +# @package _global_ +_target_: jamun.model.conditioners.conditioners.DenoisedConditioner +N_structures: 1 +pretrained_model_path: "wandb_run_path/here" # Replace with actual wandb run path +c_in: 1.0 \ No newline at end of file diff --git a/src/jamun/hydra_config/model/conditioner/mean.yaml b/src/jamun/hydra_config/model/conditioner/mean.yaml new file mode 100644 index 0000000..8157c8e --- /dev/null +++ b/src/jamun/hydra_config/model/conditioner/mean.yaml @@ -0,0 +1,3 @@ +# @package _global_ +_target_: jamun.model.conditioners.conditioners.MeanConditioner +N_structures: 1 \ No newline at end of file diff --git a/src/jamun/hydra_config/model/conditioner/position.yaml b/src/jamun/hydra_config/model/conditioner/position.yaml new file mode 100644 index 0000000..0cb184c --- /dev/null +++ b/src/jamun/hydra_config/model/conditioner/position.yaml @@ -0,0 +1,4 @@ +# @package _global_ +_target_: jamun.model.conditioners.conditioners.PositionConditioner +N_structures: 1 +align_hidden_states: true \ No newline at end of file diff --git a/src/jamun/hydra_config/model/conditioner/self.yaml b/src/jamun/hydra_config/model/conditioner/self.yaml new file mode 100644 index 0000000..f6334c7 --- /dev/null +++ b/src/jamun/hydra_config/model/conditioner/self.yaml @@ -0,0 +1,3 @@ +# @package _global_ +_target_: jamun.model.conditioners.conditioners.SelfConditioner +N_structures: 1 \ No newline at end of file diff --git a/src/jamun/hydra_config/model/conditioner/spatiotemporal.yaml b/src/jamun/hydra_config/model/conditioner/spatiotemporal.yaml new file mode 100644 index 0000000..5799ca2 --- /dev/null +++ b/src/jamun/hydra_config/model/conditioner/spatiotemporal.yaml @@ -0,0 +1,66 @@ +_target_: jamun.model.conditioners.conditioners.SpatioTemporalConditioner +N_structures: 2 # Now returns [y.pos, spatial_features] +c_noise: 0.0 +c_in: 1.0 +freeze_spatiotemporal_model: false # Trainable by default + +# Spatiotemporal model configuration +spatiotemporal_model: + _target_: jamun.model.arch.spatiotemporal.E3SpatioTemporal + radial_cutoff: 0.05 + temporal_cutoff: 1.0 + + # Spatial module (E3Conv) + spatial_module: + _target_: jamun.model.arch.e3conv.E3Conv + irreps_out: "1x1e" + irreps_hidden: "120x0e + 32x1e" + irreps_sh: "1x0e + 1x1e" + n_layers: 1 + edge_attr_dim: 64 + atom_type_embedding_dim: 8 + atom_code_embedding_dim: 8 + residue_code_embedding_dim: 32 + residue_index_embedding_dim: 8 + use_residue_information: ${data.use_residue_information} + use_residue_sequence_index: false + num_atom_types: 20 + max_sequence_length: 10 + num_atom_codes: 10 + num_residue_types: 25 + hidden_layer_factory: + _target_: e3tools.nn.ConvBlock + _partial_: true + conv: + _target_: e3tools.nn.SeparableConv # replace with Conv for non-separable case + _partial_: true + output_head_factory: + _target_: e3tools.nn.EquivariantMLP + _partial_: true + irreps_hidden_list: + - "120x0e + 32x1e" + + # Temporal module (E3Transformer) + temporal_module: + _target_: jamun.model.arch.spatiotemporal.E3Transformer + irreps_out: "120x0e + 32x1e" # Final spatial features output + irreps_hidden: "120x0e + 32x1e" + irreps_sh: "1x0e + 1x1e" + irreps_node_attr: "1x1e" # Match spatial module output + irreps_node_attr_temporal: "3x0e" + conv: + _target_: e3tools.nn.SeparableConv + _partial_: true + num_layers: 2 + edge_attr_dim: 24 + num_attention_heads: 1 + reduce: null + + # Pooling modules + spatial_to_temporal_pooler: + _target_: jamun.model.pooling.SpatialTemporalToTemporalNodeAttr + irreps_out: "1x1e" # Match spatial module output + + temporal_to_spatial_pooler: + _target_: jamun.model.pooling.TemporalToSpatialNodeAttrMean + irreps_out: "120x0e + 32x1e" # Match temporal module output \ No newline at end of file diff --git a/src/jamun/hydra_config/model/conditioner/spatiotemporal_pretrained.yaml b/src/jamun/hydra_config/model/conditioner/spatiotemporal_pretrained.yaml new file mode 100644 index 0000000..920b565 --- /dev/null +++ b/src/jamun/hydra_config/model/conditioner/spatiotemporal_pretrained.yaml @@ -0,0 +1,41 @@ +_target_: jamun.model.conditioners.conditioners.SpatioTemporalConditioner +N_structures: 2 # Now returns [y.pos, spatial_features] +c_noise: 0.0 +freeze_spatiotemporal_model: false # Trainable by default +c_in: 1.0 +# Spatiotemporal model configuration +spatiotemporal_model: + _target_: jamun.model.arch.spatiotemporal.E3SpatioTemporal + radial_cutoff: 0.05 + temporal_cutoff: 1.0 + + # Spatial module (E3Conv) + spatial_module: + _target_: jamun.utils.pretrained_wrapper.return_wrapped_denoiser + wandb_run_path: "sule-shashank/jamun/sxcdx4wf" + checkpoint_type: "last" + trainable: false + c_in: ${model.conditioner.c_in} + + + # Temporal module (E3Transformer) + temporal_module: + _target_: jamun.model.arch.spatiotemporal.E3Transformer + irreps_out: "120x0e + 32x1e" # Final spatial features output + irreps_hidden: "120x0e + 32x1e" + irreps_sh: "1x0e + 1x1e" + irreps_node_attr: "1x1e" # Match spatial module output + irreps_node_attr_temporal: "3x0e" + num_layers: 2 + edge_attr_dim: 24 + num_attention_heads: 1 + reduce: null + + # Pooling modules + spatial_to_temporal_pooler: + _target_: jamun.model.pooling.SpatialTemporalToTemporalNodeAttr + irreps_out: "1x1e" # Match spatial module output (from temporal module irreps_node_attr) + + temporal_to_spatial_pooler: + _target_: jamun.model.pooling.TemporalToSpatialNodeAttrMean + irreps_out: "120x0e + 32x1e" # Match temporal module output \ No newline at end of file diff --git a/src/jamun/hydra_config/model/conditioner/spiked.yaml b/src/jamun/hydra_config/model/conditioner/spiked.yaml new file mode 100644 index 0000000..0189deb --- /dev/null +++ b/src/jamun/hydra_config/model/conditioner/spiked.yaml @@ -0,0 +1,3 @@ +# @package _global_ +_target_: jamun.model.conditioners.conditioners.ConditionerSpiked +N_structures: 1 \ No newline at end of file diff --git a/src/jamun/hydra_config/model/denoiser_conditional.yaml b/src/jamun/hydra_config/model/denoiser_conditional.yaml new file mode 100644 index 0000000..d4369ee --- /dev/null +++ b/src/jamun/hydra_config/model/denoiser_conditional.yaml @@ -0,0 +1,25 @@ +defaults: + - arch: e3conv_conditional.yaml + - optim: adam.yaml + - lr_scheduler_config: null + - conditioner: position.yaml # Default conditioner, can be overridden + - _self_ + +max_radius: null +average_squared_distance: null +add_fixed_noise: false +add_fixed_ones: false +align_noisy_input_during_training: true +align_noisy_input_during_evaluation: true +mean_center: true +mirror_augmentation_rate: 0.0 +use_torch_compile: true +torch_compile_kwargs: + fullgraph: true + dynamic: true + mode: default + +# Conditioner configuration now comes from defaults above +# Use conditioners=spatiotemporal to use the spatio-temporal conditioner + +_target_: jamun.model.denoiser_conditional.Denoiser \ No newline at end of file diff --git a/src/jamun/hydra_config/model/denoiser_conditional_pretrained.yaml b/src/jamun/hydra_config/model/denoiser_conditional_pretrained.yaml new file mode 100644 index 0000000..b473d6b --- /dev/null +++ b/src/jamun/hydra_config/model/denoiser_conditional_pretrained.yaml @@ -0,0 +1,2 @@ +_target_: jamun.model.denoiser_conditional.Denoiser.load_from_checkpoint +checkpoint_path: null diff --git a/src/jamun/hydra_config/model/denoiser_conditional_spatiotemporal.yaml b/src/jamun/hydra_config/model/denoiser_conditional_spatiotemporal.yaml new file mode 100644 index 0000000..69102df --- /dev/null +++ b/src/jamun/hydra_config/model/denoiser_conditional_spatiotemporal.yaml @@ -0,0 +1,31 @@ +defaults: + - arch: e3conv_conditional_spatiotemporal.yaml + - optim: adam.yaml + - conditioner: spatiotemporal.yaml + - _self_ + +# Model configuration +sigma_distribution: + _target_: jamun.model.sigma_distribution.ConstantSigma + sigma: 0.1 + +max_radius: 1.0 +average_squared_distance: 10.0 +add_fixed_noise: false +add_fixed_ones: false +align_noisy_input_during_training: true +align_noisy_input_during_evaluation: true +mean_center: true +mirror_augmentation_rate: 0.0 +bond_loss_coefficient: 1.0 +normalization_type: "JAMUN" +sigma_data: null +lr_scheduler_config: null +use_torch_compile: false +torch_compile_kwargs: null + +# Override to ensure N_structures matches spatiotemporal conditioner output +arch: + N_structures: 1 # For [y.pos, spatial_features] + +_target_: jamun.model.denoiser_conditional.Denoiser diff --git a/src/jamun/hydra_config/model/denoiser_multimeasurement.yaml b/src/jamun/hydra_config/model/denoiser_multimeasurement.yaml new file mode 100644 index 0000000..e098a80 --- /dev/null +++ b/src/jamun/hydra_config/model/denoiser_multimeasurement.yaml @@ -0,0 +1,32 @@ +defaults: + - arch: e3conv_conditional.yaml + - optim: adam.yaml + - conditioner: position_conditioner.yaml + - lr_scheduler_config: null + - _self_ + +max_radius: null +average_squared_distance: null +add_fixed_noise: false +add_fixed_ones: false +align_noisy_input_during_training: true +align_noisy_input_during_evaluation: true +mean_center: true +mirror_augmentation_rate: 0.0 +use_torch_compile: true +torch_compile_kwargs: + fullgraph: true + dynamic: true + mode: default + +# conditioner: +# _target_: jamun.model.conditioners.PositionConditioner +# N_structures: ${model.arch.N_structures} + +# Multimeasurement specific parameters +multimeasurement: true +N_measurements: 4 +N_measurements_hidden: 4 +max_graphs_per_batch: 32 + +_target_: jamun.model.DenoiserMultimeasurement \ No newline at end of file diff --git a/src/jamun/hydra_config/sample.yaml b/src/jamun/hydra_config/sample.yaml index dc97330..e7f53d2 100644 --- a/src/jamun/hydra_config/sample.yaml +++ b/src/jamun/hydra_config/sample.yaml @@ -6,7 +6,7 @@ defaults: - paths: default - hydra: default - callbacks: sampler/default - - experiment: null + - experiment: sample_uncapped_single_shape_conditioning float32_matmul_precision: "high" diff --git a/src/jamun/hydra_config/sample_memory.yaml b/src/jamun/hydra_config/sample_memory.yaml new file mode 100644 index 0000000..4d00052 --- /dev/null +++ b/src/jamun/hydra_config/sample_memory.yaml @@ -0,0 +1,27 @@ +defaults: + - _self_ + - model: denoiser_conditional_pretrained + - batch_sampler: single_measurement_sampler_memory + - logger: default + - paths: default + - hydra: default + - callbacks: sampler/default + - experiment: sample_uncapped_single_shape_conditioning + +float32_matmul_precision: "high" + +sample_pdb: null +repeat_init_samples: 1 +num_batches: 1 +continue_chain: true +finetune_on_init: false + +seed: 42 +task_name: "sample" +run_group: "dev" +run_key: ${now:%Y-%m-%d}_${now:%H-%M-%S} # NOTE in DDP this must be set consistently across ranks + +sampler: + _target_: jamun.sampling.SamplerMemory + _convert_: "partial" # loggers argument must be passed as plain list + precision: "32-true" diff --git a/src/jamun/hydra_config/train.yaml b/src/jamun/hydra_config/train.yaml index 3255d93..92623cd 100644 --- a/src/jamun/hydra_config/train.yaml +++ b/src/jamun/hydra_config/train.yaml @@ -12,6 +12,7 @@ defaults: float32_matmul_precision: "high" +seed: 42 task_name: "train" run_group: "dev" run_key: ${now:%Y-%m-%d}_${now:%H-%M-%S} # NOTE in DDP this must be set consistently across ranks diff --git a/src/jamun/model/__init__.py b/src/jamun/model/__init__.py index de26ff3..3369ea0 100644 --- a/src/jamun/model/__init__.py +++ b/src/jamun/model/__init__.py @@ -1,2 +1,5 @@ +from .conditioners import ConditionerSpiked from .denoiser import Denoiser +from .denoiser_multimeasurement import DenoiserMultimeasurement +from .denoiser_spiked import DenoiserSpiked from .energy import EnergyModel diff --git a/src/jamun/model/arch/__init__.py b/src/jamun/model/arch/__init__.py index 4b9edfd..3a35d9d 100644 --- a/src/jamun/model/arch/__init__.py +++ b/src/jamun/model/arch/__init__.py @@ -1,3 +1,4 @@ from .e3conv import E3Conv +from .e3conv_conditional import E3ConvConditional from .ophiuchus import Ophiuchus -from .orb import MoleculeGNSWrapper +from .spatiotemporal import E3SpatioTemporal, E3Transformer diff --git a/src/jamun/model/arch/e3conv.py b/src/jamun/model/arch/e3conv.py index 93d8aab..d0b0591 100644 --- a/src/jamun/model/arch/e3conv.py +++ b/src/jamun/model/arch/e3conv.py @@ -113,7 +113,7 @@ def forward( src, dst = edge_index edge_vec = pos[src] - pos[dst] edge_sh = self.sh(edge_vec) - + # print(f"Edge spherical harmonics: {type(edge_sh)}") bonded_edge_attr = self.embed_bondedness(bond_mask) radial_edge_attr = e3nn.math.soft_one_hot_linspace( edge_vec.norm(dim=1), diff --git a/src/jamun/model/arch/e3conv_conditional.py b/src/jamun/model/arch/e3conv_conditional.py new file mode 100644 index 0000000..a199712 --- /dev/null +++ b/src/jamun/model/arch/e3conv_conditional.py @@ -0,0 +1,466 @@ +from collections.abc import Callable + +import e3nn +import e3tools.nn +import torch +import torch_geometric +from e3nn import o3 +from e3nn.o3 import Irreps +from e3tools import scatter +from torch import Tensor + +from jamun.model.atom_embedding import AtomEmbeddingWithResidueInformation, SimpleAtomEmbedding +from jamun.model.noise_conditioning import NoiseConditionalScaling, NoiseConditionalSkipConnection + + +class E3ConvConditional(torch.nn.Module): + """A simple E(3)-equivariant convolutional neural network, similar to NequIP.""" + + def __init__( + self, + irreps_out: str | Irreps, + irreps_hidden: str | Irreps, + irreps_sh: str | Irreps, + hidden_layer_factory: Callable[..., torch.nn.Module], + output_head_factory: Callable[..., torch.nn.Module], + use_residue_information: bool, + n_layers: int, + edge_attr_dim: int, + atom_type_embedding_dim: int, + atom_code_embedding_dim: int, + residue_code_embedding_dim: int, + residue_index_embedding_dim: int, + use_residue_sequence_index: bool, + num_atom_types: int = 20, + max_sequence_length: int = 10, + num_atom_codes: int = 10, + num_residue_types: int = 25, + test_equivariance: bool = False, + reduce: str | None = None, + N_structures: int = 1, + ): + super().__init__() + + self.test_equivariance = test_equivariance + self.irreps_out = o3.Irreps(irreps_out) + self.irreps_hidden = o3.Irreps(irreps_hidden) + self.irreps_sh = o3.Irreps(irreps_sh) + self.n_layers = n_layers + self.edge_attr_dim = edge_attr_dim + self.N_structures = N_structures + self.sh = o3.SphericalHarmonics(irreps_out=self.irreps_sh, normalize=True, normalization="component") + self.bonded_edge_attr_dim, self.radial_edge_attr_dim = self.edge_attr_dim // 2, (self.edge_attr_dim + 1) // 2 + self.embed_bondedness = torch.nn.Embedding(2, self.bonded_edge_attr_dim) + + if use_residue_information: + self.atom_embedder = AtomEmbeddingWithResidueInformation( + atom_type_embedding_dim=atom_type_embedding_dim, + atom_code_embedding_dim=atom_code_embedding_dim, + residue_code_embedding_dim=residue_code_embedding_dim, + residue_index_embedding_dim=residue_index_embedding_dim, + use_residue_sequence_index=use_residue_sequence_index, + num_atom_types=num_atom_types, + max_sequence_length=max_sequence_length, + num_atom_codes=num_atom_codes, + num_residue_types=num_residue_types, + ) + else: + self.atom_embedder = SimpleAtomEmbedding( + embedding_dim=atom_type_embedding_dim + + atom_code_embedding_dim + + residue_code_embedding_dim + + residue_index_embedding_dim + ) + + self.initial_noise_scaling = NoiseConditionalScaling(self.atom_embedder.irreps_out) + self.initial_projector = hidden_layer_factory( + irreps_in=self.initial_noise_scaling.irreps_out, + irreps_out=self.irreps_hidden, + irreps_sh=N_structures * self.irreps_sh, + edge_attr_dim=edge_attr_dim, + ) + + self.layers = torch.nn.ModuleList() + self.noise_scalings = torch.nn.ModuleList() + self.skip_connections = torch.nn.ModuleList() + for _ in range(n_layers): + self.layers.append( + hidden_layer_factory( + irreps_in=self.irreps_hidden, + irreps_out=self.irreps_hidden, + irreps_sh=N_structures * self.irreps_sh, + edge_attr_dim=self.edge_attr_dim, + ) + ) + self.noise_scalings.append(NoiseConditionalScaling(self.irreps_hidden)) + self.skip_connections.append(NoiseConditionalSkipConnection(self.irreps_hidden)) + + self.output_head = output_head_factory(irreps_in=self.irreps_hidden, irreps_out=self.irreps_out) + self.output_gain = torch.nn.Parameter(torch.tensor(0.0)) + self.reduce = reduce + + def forward( + self, + pos: Tensor, # should be [batch_size*N, 3T], T is the number of previous time-steps + topology: torch_geometric.data.Batch, + c_noise: Tensor, + effective_radial_cutoff: float, + ) -> torch_geometric.data.Batch: + # Extract edge attributes. + edge_index = topology["edge_index"] + bond_mask = topology["bond_mask"] + + src, dst = edge_index # compute edge spherical harmonics over concat structures + positions = torch.split(pos, 3, dim=-1) + edge_sh = [] + for block in positions: + edge_vec = block[src] - block[dst] + edge_sh.append(self.sh(edge_vec)) + edge_sh = torch.cat(edge_sh, dim=-1) + + # print(f"Edge spherical harmonics: {type(edge_sh)}") + bonded_edge_attr = self.embed_bondedness(bond_mask) + edge_vec_main = positions[0][src] - positions[0][dst] + radial_edge_attr = e3nn.math.soft_one_hot_linspace( + edge_vec_main.norm(dim=1), + 0.0, + effective_radial_cutoff, + self.radial_edge_attr_dim, + basis="gaussian", + cutoff=True, + ) + edge_attr = torch.cat((bonded_edge_attr, radial_edge_attr), dim=-1) + + node_attr = self.atom_embedder(topology) + node_attr = self.initial_noise_scaling(node_attr, c_noise) + node_attr = self.initial_projector(node_attr, edge_index, edge_attr, edge_sh) + for scaling, skip, layer in zip(self.noise_scalings, self.skip_connections, self.layers): + node_attr = skip(node_attr, layer(scaling(node_attr, c_noise), edge_index, edge_attr, edge_sh), c_noise) + node_attr = self.output_head(node_attr) + node_attr = node_attr * self.output_gain + + if self.reduce is not None: + node_attr = scatter(node_attr, topology.batch, dim=0, reduce=self.reduce) + + return node_attr + + +class E3ConvConditionalWithInputAttr(E3ConvConditional): + """ + Extension of E3ConvConditional that can accept additional input attributes + and combine them with the computed node attributes. + """ + + def __init__( + self, + irreps_out: str | Irreps, + irreps_hidden: str | Irreps, + irreps_sh: str | Irreps, + hidden_layer_factory: Callable[..., torch.nn.Module], + output_head_factory: Callable[..., torch.nn.Module], + use_residue_information: bool, + n_layers: int, + edge_attr_dim: int, + atom_type_embedding_dim: int, + atom_code_embedding_dim: int, + residue_code_embedding_dim: int, + residue_index_embedding_dim: int, + use_residue_sequence_index: bool, + num_atom_types: int = 20, + max_sequence_length: int = 10, + num_atom_codes: int = 10, + num_residue_types: int = 25, + test_equivariance: bool = False, + reduce: str | None = None, + N_structures: int = 1, + input_attr_irreps: str | Irreps | None = None, + ): + """ + Initialize E3ConvConditionalWithInputAttr. + + Args: + input_attr_irreps: Irreps of the input attributes that will be combined with node_attr. + If None, the model behaves like the parent class. + All other args: Same as parent E3ConvConditional class. + """ + super().__init__( + irreps_out=irreps_out, + irreps_hidden=irreps_hidden, + irreps_sh=irreps_sh, + hidden_layer_factory=hidden_layer_factory, + output_head_factory=output_head_factory, + use_residue_information=use_residue_information, + n_layers=n_layers, + edge_attr_dim=edge_attr_dim, + atom_type_embedding_dim=atom_type_embedding_dim, + atom_code_embedding_dim=atom_code_embedding_dim, + residue_code_embedding_dim=residue_code_embedding_dim, + residue_index_embedding_dim=residue_index_embedding_dim, + use_residue_sequence_index=use_residue_sequence_index, + num_atom_types=num_atom_types, + max_sequence_length=max_sequence_length, + num_atom_codes=num_atom_codes, + num_residue_types=num_residue_types, + test_equivariance=test_equivariance, + reduce=reduce, + N_structures=N_structures, + ) + + self.input_attr_irreps = o3.Irreps(input_attr_irreps) if input_attr_irreps is not None else None + + # Create input irrep aggregator if input attributes are provided + if self.input_attr_irreps is not None: + # Combined irreps: node_attr irreps + input_attr irreps + combined_irreps = self.irreps_hidden + self.input_attr_irreps + + # Create aggregator that takes combined input and outputs node_attr irreps + self.input_irrep_aggregator = e3tools.nn.EquivariantMLP( + irreps_in=combined_irreps, + irreps_out=self.irreps_hidden, + irreps_hidden_list=[self.irreps_hidden], # Single hidden layer + ) + else: + self.input_irrep_aggregator = None + + def forward( + self, + pos: Tensor, + topology: torch_geometric.data.Batch, + c_noise: Tensor, + effective_radial_cutoff: float, + input_attr: Tensor | None = None, + ) -> Tensor: + """ + Forward pass with optional input attributes. + + Args: + pos: Node positions + topology: Graph topology + c_noise: Noise conditioning + effective_radial_cutoff: Radial cutoff for edges + input_attr: Optional input attributes to combine with node_attr. + Should have shape [N, input_attr_irreps.dim] where N is number of nodes. + + Returns: + Node attributes after processing + """ + # Extract edge attributes. + edge_index = topology["edge_index"] + bond_mask = topology["bond_mask"] + + src, dst = edge_index # compute edge spherical harmonics over concat structures + positions = torch.split(pos, 3, dim=-1) + edge_sh = [] + for block in positions: + edge_vec = block[src] - block[dst] + edge_sh.append(self.sh(edge_vec)) + edge_sh = torch.cat(edge_sh, dim=-1) + + # print(f"Edge spherical harmonics: {type(edge_sh)}") + bonded_edge_attr = self.embed_bondedness(bond_mask) + edge_vec_main = positions[0][src] - positions[0][dst] + radial_edge_attr = e3nn.math.soft_one_hot_linspace( + edge_vec_main.norm(dim=1), + 0.0, + effective_radial_cutoff, + self.radial_edge_attr_dim, + basis="gaussian", + cutoff=True, + ) + edge_attr = torch.cat((bonded_edge_attr, radial_edge_attr), dim=-1) + + node_attr = self.atom_embedder(topology) + node_attr = self.initial_noise_scaling(node_attr, c_noise) + node_attr = self.initial_projector(node_attr, edge_index, edge_attr, edge_sh) + + # Combine with input attributes if provided + if input_attr is not None and self.input_irrep_aggregator is not None: + # Validate input_attr shape + expected_dim = self.input_attr_irreps.dim + if input_attr.shape[-1] != expected_dim: + raise ValueError( + f"Expected input_attr to have dimension {expected_dim}, but got {input_attr.shape[-1]}" + ) + if input_attr.shape[0] != node_attr.shape[0]: + raise ValueError( + f"Expected input_attr to have {node_attr.shape[0]} nodes, but got {input_attr.shape[0]}" + ) + + # Concatenate node_attr with input_attr + combined_attr = torch.cat([node_attr, input_attr], dim=-1) + + # Aggregate to get back to node_attr irreps + node_attr = self.input_irrep_aggregator(combined_attr) + elif input_attr is not None and self.input_irrep_aggregator is None: + raise ValueError("input_attr provided but input_attr_irreps was not specified during initialization") + + # Continue with normal processing + for scaling, skip, layer in zip(self.noise_scalings, self.skip_connections, self.layers): + node_attr = skip(node_attr, layer(scaling(node_attr, c_noise), edge_index, edge_attr, edge_sh), c_noise) + node_attr = self.output_head(node_attr) + node_attr = node_attr * self.output_gain + + if self.reduce is not None: + node_attr = scatter(node_attr, topology.batch, dim=0, reduce=self.reduce) + + return node_attr + + +class E3ConvConditionalSpatioTemporal(E3ConvConditional): + """ + E3ConvConditional specifically designed for spatiotemporal conditioning. + + This class expects input positions to be concatenated as [y.pos, spatial_features] + where y.pos are the physical 3D coordinates and spatial_features are additional + attributes from the spatiotemporal model. + + Key differences from E3ConvConditional: + - Edge spherical harmonics are only computed for the first 3 coordinates (y.pos) + - Remaining coordinates are treated as per-node input attributes + - Input attributes are combined with computed node attributes + """ + + def __init__( + self, + irreps_out: str | Irreps, + irreps_hidden: str | Irreps, + irreps_sh: str | Irreps, + hidden_layer_factory: Callable[..., torch.nn.Module], + output_head_factory: Callable[..., torch.nn.Module], + use_residue_information: bool, + n_layers: int, + edge_attr_dim: int, + atom_type_embedding_dim: int, + atom_code_embedding_dim: int, + residue_code_embedding_dim: int, + residue_index_embedding_dim: int, + use_residue_sequence_index: bool, + num_atom_types: int = 20, + max_sequence_length: int = 10, + num_atom_codes: int = 10, + num_residue_types: int = 25, + test_equivariance: bool = False, + reduce: str | None = None, + N_structures: int = 1, # Should be 2 for [y.pos, spatial_features] + input_attr_irreps: str | Irreps = "3x1e", # Default for spatial features + ): + """ + Initialize E3ConvConditionalSpatioTemporal. + + Args: + input_attr_irreps: Irreps of the spatial features from spatiotemporal model. + Should match the irreps_out of the spatiotemporal model. + N_structures: Should be 2 for [y.pos, spatial_features] + All other args: Same as parent E3ConvConditional class. + """ + super().__init__( + irreps_out=irreps_out, + irreps_hidden=irreps_hidden, + irreps_sh=irreps_sh, + hidden_layer_factory=hidden_layer_factory, + output_head_factory=output_head_factory, + use_residue_information=use_residue_information, + n_layers=n_layers, + edge_attr_dim=edge_attr_dim, + atom_type_embedding_dim=atom_type_embedding_dim, + atom_code_embedding_dim=atom_code_embedding_dim, + residue_code_embedding_dim=residue_code_embedding_dim, + residue_index_embedding_dim=residue_index_embedding_dim, + use_residue_sequence_index=use_residue_sequence_index, + num_atom_types=num_atom_types, + max_sequence_length=max_sequence_length, + num_atom_codes=num_atom_codes, + num_residue_types=num_residue_types, + test_equivariance=test_equivariance, + reduce=reduce, + N_structures=N_structures, + ) + + # Set up input attribute handling + self.input_attr_irreps = o3.Irreps(input_attr_irreps) + self.input_attr_irreps_dim = self.input_attr_irreps.dim + # Create input irrep aggregator to combine node_attr with input_attr + # Combined irreps: node_attr irreps + input_attr irreps + combined_irreps = self.irreps_hidden + self.input_attr_irreps + + # Create aggregator that takes combined input and outputs node_attr irreps + self.input_irrep_aggregator = e3tools.nn.EquivariantMLP( + irreps_in=combined_irreps, + irreps_out=self.irreps_hidden, + irreps_hidden_list=[self.irreps_hidden], # Single hidden layer + ) + + def forward( + self, + pos: Tensor, # should be [N, 3 + spatial_features_dim] from [y.pos, spatial_features] + topology: torch_geometric.data.Batch, + c_noise: Tensor, + effective_radial_cutoff: float, + ) -> Tensor: + """ + Forward pass with spatiotemporal conditioning. + + Args: + pos: Concatenated positions [y.pos, spatial_features] with shape [N, 3 + spatial_features_dim] + topology: Graph topology + c_noise: Noise conditioning + effective_radial_cutoff: Radial cutoff for edges + + Returns: + Node attributes after processing + """ + # Extract edge attributes. + edge_index = topology["edge_index"] + bond_mask = topology["bond_mask"] + + src, dst = edge_index + + # Split positions: first 3 coords are physical positions, rest are spatial features + pos_physical = pos[:, :3] # [N, 3] - physical coordinates + pos_features = pos[:, 3:] # [N, spatial_features_dim] - spatial features + + # Compute edge spherical harmonics ONLY for physical positions + edge_vec_physical = pos_physical[src] - pos_physical[dst] + edge_sh = self.sh(edge_vec_physical) + + # Compute edge attributes using physical positions + bonded_edge_attr = self.embed_bondedness(bond_mask) + radial_edge_attr = e3nn.math.soft_one_hot_linspace( + edge_vec_physical.norm(dim=1), + 0.0, + effective_radial_cutoff, + self.radial_edge_attr_dim, + basis="gaussian", + cutoff=True, + ) + edge_attr = torch.cat((bonded_edge_attr, radial_edge_attr), dim=-1) + + # Compute initial node attributes + node_attr = self.atom_embedder(topology) + node_attr = self.initial_noise_scaling(node_attr, c_noise) + node_attr = self.initial_projector(node_attr, edge_index, edge_attr, edge_sh) + + # Combine node_attr with spatial features (input_attr) + # Validate spatial features shape + expected_dim = self.input_attr_irreps_dim + if pos_features.shape[-1] != expected_dim: + raise ValueError( + f"Expected spatial features to have dimension {expected_dim}, but got {pos_features.shape[-1]}" + ) + + # Concatenate node_attr with spatial features + combined_attr = torch.cat([node_attr, pos_features], dim=-1) + + # Aggregate to get back to node_attr irreps + node_attr = self.input_irrep_aggregator(combined_attr) + + # Continue with normal processing using only physical positions for edge computations + for scaling, skip, layer in zip(self.noise_scalings, self.skip_connections, self.layers): + node_attr = skip(node_attr, layer(scaling(node_attr, c_noise), edge_index, edge_attr, edge_sh), c_noise) + node_attr = self.output_head(node_attr) + node_attr = node_attr * self.output_gain + + if self.reduce is not None: + node_attr = scatter(node_attr, topology.batch, dim=0, reduce=self.reduce) + + return node_attr diff --git a/src/jamun/model/arch/spatiotemporal.py b/src/jamun/model/arch/spatiotemporal.py new file mode 100644 index 0000000..34845ea --- /dev/null +++ b/src/jamun/model/arch/spatiotemporal.py @@ -0,0 +1,551 @@ +""" +E(3)-equivariant spatio-temporal models and conversion functions. + +This module contains: +- E3Transformer: E(3)-equivariant transformer for temporal graph processing +- E3SpatioTemporal: Unified spatio-temporal processing model +- Spatial-temporal graph conversion utilities +""" + +import e3nn +import e3tools +import e3tools.nn +import torch +import torch.nn as nn +import torch_geometric +import torch_geometric.data +from e3nn import o3 + + +def calculate_temporal_positions(temporal_length, mode="linear", device=None): + """ + Calculate normalized temporal positions for nodes in a temporal graph. + + Args: + temporal_length: Total number of nodes in the temporal sequence + device: Device to create tensors on + + Returns: + torch.Tensor: Normalized positions [0, 1/T, 2/T, ..., (T-1)/T] + """ + if temporal_length <= 1: + return torch.tensor([0.0], device=device) + + if mode == "linear": + # Create positions [0, 1, 2, ..., T-1] and normalize by T + positions = torch.arange(temporal_length, dtype=torch.float32, device=device) + normalized_positions = positions / temporal_length + elif mode == "zeros": + # Create positions [0, 1, 2, ..., T-1] and normalize by T + positions = torch.arange(temporal_length, dtype=torch.float32, device=device) + positions = torch.zeros_like(positions) + + return normalized_positions + + +def spatial_to_temporal_graphs(batch, graph_type="fan"): + """ + Convert a batch of spatial graphs to temporal graphs with configurable connectivity. + + For each spatial node with position + hidden states, create a temporal graph where: + - Node 0: current position + - Nodes 1-T: hidden state positions + - Connectivity depends on graph_type parameter + + Args: + batch: Input spatial graph batch + graph_type: Type of connectivity to use + - "fan": Hub connects to all + sequential connections (0->all, i->(i+1)) + - "hub_n_spoke": Only hub-spoke connections (0->all, no sequential) + - "complete": Complete graph with self-loops (all-to-all including self) + - "complete_no_self": Complete graph without self-loops (all-to-all excluding self) + """ + import torch_geometric + + # Validate graph_type + valid_types = ["fan", "hub_n_spoke", "complete", "complete_no_self"] + if graph_type not in valid_types: + raise ValueError(f"graph_type must be one of {valid_types}, got {graph_type}") + + # Get device from input batch + device = batch.pos.device + + # Get dimensions + num_spatial_nodes = batch.pos.shape[0] + + # Check if we have hidden states + if hasattr(batch, "hidden_state") and batch.hidden_state is not None and len(batch.hidden_state) > 0: + num_hidden_states = len(batch.hidden_state) + temporal_length = 1 + num_hidden_states # current + hidden + else: + # If no hidden states, just use current position + num_hidden_states = 0 + temporal_length = 1 + + # print(f"Creating {graph_type} temporal graphs: {num_spatial_nodes} spatial nodes -> {num_spatial_nodes} temporal graphs of length {temporal_length}") + + # Store reference to spatial graph + spatial_graph = batch.clone() + + # Set connectivity type code for tracking + connectivity_type_map = {"fan": 0, "hub_n_spoke": 1, "complete": 2, "complete_no_self": 3} + + temporal_graphs = [] + + for node_idx in range(num_spatial_nodes): + # Build temporal positions: [current_pos, hidden_1, hidden_2, ...] + temporal_positions = [batch.pos[node_idx]] # Start with current position + + # Add hidden state positions + if num_hidden_states > 0: + for hidden_pos in batch.hidden_state: + temporal_positions.append(hidden_pos[node_idx]) + + temporal_pos = torch.stack(temporal_positions) # Shape: [T, 3] + + # Calculate temporal positions for this sequence + temporal_position = calculate_temporal_positions(temporal_length, device=device) + + # Create edge connectivity based on graph_type + if temporal_length > 1: + if graph_type == "fan": + # Original fan system: hub-spoke + sequential + # Hub connections: 0->1, 0->2, 0->3, ..., 0->T-1 + hub_src = [0] * (temporal_length - 1) + hub_dst = list(range(1, temporal_length)) + + # Sequential connections: 1->2, 2->3, ..., (T-2)->(T-1) + seq_src = list(range(1, temporal_length - 1)) + seq_dst = list(range(2, temporal_length)) + + # Combine all edges + all_src = hub_src + seq_src + all_dst = hub_dst + seq_dst + + edge_index = torch.tensor([all_src, all_dst], dtype=torch.long, device=device) + + elif graph_type == "hub_n_spoke": + # Hub-and-spoke only: 0 connects to all others, no sequential + hub_src = [0] * (temporal_length - 1) + hub_dst = list(range(1, temporal_length)) + + edge_index = torch.tensor([hub_src, hub_dst], dtype=torch.long, device=device) + + elif graph_type == "complete": + # Complete graph without self-loops: all-to-all excluding self + src_nodes = [] + dst_nodes = [] + + for i in range(temporal_length): + for j in range(temporal_length): + if i != j: # Exclude self-loops + src_nodes.append(i) + dst_nodes.append(j) + + edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long, device=device) + + else: + # Single node case + if graph_type == "complete": + # Single node with self-loop + edge_index = torch.tensor([[0], [0]], dtype=torch.long, device=device) + else: + # Single node, no edges for other types + edge_index = torch.tensor([[], []], dtype=torch.long, device=device) + + # Create temporal graph for this spatial node + temporal_graph = torch_geometric.data.Data( + pos=temporal_pos, + edge_index=edge_index, + spatial_node_idx=torch.tensor([node_idx], device=device), # Track which spatial node this came from + temporal_length=torch.tensor([temporal_length], device=device), + temporal_position=temporal_position, # Normalized position in sequence [0, 1/T, 2/T, ...] + connectivity_type=torch.tensor([connectivity_type_map[graph_type]], device=device), + # Note: Removed graph_type string to avoid batching issues with PyTorch Geometric + ) + temporal_graphs.append(temporal_graph) + + # Batch all temporal graphs + temporal_batch = torch_geometric.data.Batch.from_data_list(temporal_graphs) + + # Store spatial graph reference + temporal_batch.spatial_graph = spatial_graph + # Note: Removed graph_type string to avoid batching issues with PyTorch Geometric + # Graph type can be inferred from connectivity_type tensor attribute + + return temporal_batch + + +def temporal_to_spatial_graphs(temporal_batch): + """ + Convert temporal graphs back to spatial graphs. + Take the 0th node position from each temporal graph as the updated spatial position. + """ + # Get the spatial graph template + spatial_graph = temporal_batch.spatial_graph.clone() + + # Extract 0th node positions from each temporal graph + num_temporal_graphs = temporal_batch.num_graphs + updated_positions = [] + + # Iterate through each temporal graph in the batch + for graph_idx in range(num_temporal_graphs): + # Get the node range for this temporal graph + start_idx = temporal_batch.ptr[graph_idx] + + # The 0th node of each temporal graph is at the start of its range + updated_positions.append(temporal_batch.pos[start_idx]) + + # Stack to create new position tensor + updated_positions = torch.stack(updated_positions) + + # Update spatial graph with new positions + spatial_graph.pos = updated_positions + + return spatial_graph + + +class E3Transformer(nn.Module): + """E(3)-equivariant transformer with temporal graph support.""" + + def __init__( + self, + irreps_out: str | e3nn.o3.Irreps, + irreps_hidden: str | e3nn.o3.Irreps, + irreps_sh: str | e3nn.o3.Irreps, + irreps_node_attr: str | e3nn.o3.Irreps, + num_layers: int, + edge_attr_dim: int, + num_attention_heads: int, + reduce: str | None = None, + conv=e3tools.nn.Conv, + irreps_node_attr_temporal: str | e3nn.o3.Irreps = "1x1e", + radial_edge_attr_encoding_function: str = "gaussian", + node_attr_temporal_encoding_function: str = "gaussian", + edge_attr_temporal_encoding_function: str = "gaussian", + ): + super().__init__() + + self.irreps_out = o3.Irreps(irreps_out) + self.irreps_hidden = o3.Irreps(irreps_hidden) + self.irreps_sh = o3.Irreps(irreps_sh) + self.irreps_node_attr = o3.Irreps(irreps_node_attr) # input irreps + self.irreps_node_attr_temporal = o3.Irreps(irreps_node_attr_temporal) + self.num_layers = num_layers + self.edge_attr_dim = edge_attr_dim + self.num_attention_heads = num_attention_heads + self.reduce = reduce + self.sh = o3.SphericalHarmonics(irreps_out=self.irreps_sh, normalize=True, normalization="component") + # Split edge attribute dimensions: radial and temporal (bondedness is optional) + self.radial_edge_attr_dim = self.edge_attr_dim // 2 + self.temporal_edge_attr_dim = self.edge_attr_dim - self.radial_edge_attr_dim + self.temporal_node_attr_dim = self.irreps_node_attr_temporal.dim + # Optional bondedness embedding (only used if bond_mask exists in graph) + self.embed_bondedness = nn.Embedding(2, self.edge_attr_dim // 3) + self.edge_attr_temporal_encoding_function = edge_attr_temporal_encoding_function + self.node_attr_temporal_encoding_function = node_attr_temporal_encoding_function + self.radial_edge_attr_encoding_function = radial_edge_attr_encoding_function + # Gate for combining node attributes with temporal position + # Input: node_attr (from data) + temporal_position (1x0e scalar) + # irreps_with_temporal = self.irreps_node_attr + o3.Irreps("1x0e") + irreps_with_temporal = self.irreps_node_attr + self.irreps_node_attr_temporal + self.temporal_gate = e3tools.nn.GateWrapper( + irreps_in=irreps_with_temporal, + irreps_out=self.irreps_hidden, + irreps_gate=irreps_with_temporal, + ) + + self.layers = nn.ModuleList() + self.conv = conv + for _ in range(num_layers): + self.layers.append( + e3tools.nn.TransformerBlock( + irreps_in=self.irreps_hidden, + irreps_out=self.irreps_hidden, + irreps_sh=self.irreps_sh, + edge_attr_dim=self.edge_attr_dim, + num_heads=self.num_attention_heads, + conv=self.conv, + ) + ) + self.output_head = e3tools.nn.EquivariantMLP( + irreps_in=self.irreps_hidden, + irreps_out=self.irreps_out, + irreps_hidden_list=[self.irreps_hidden], + ) + + def forward( + self, + node_attr: torch.Tensor, + temporal_graph: torch_geometric.data.Batch, + effective_radial_cutoff: float, + temporal_cutoff: float = 1.0, + ) -> torch.Tensor: + """Forward pass of the E3Transformer model.""" + # Extract graph data + pos = temporal_graph.pos + edge_index = temporal_graph.edge_index + temporal_position = temporal_graph.temporal_position + batch = temporal_graph.batch + num_graphs = temporal_graph.num_graphs + + src, dst = edge_index + edge_vec = pos[src] - pos[dst] + edge_sh = self.sh(edge_vec) + + # Compute edge attributes: radial and temporal + if self.radial_edge_attr_encoding_function != "ones": + radial_edge_attr = e3nn.math.soft_one_hot_linspace( + edge_vec.norm(dim=1), + 0.0, + temporal_cutoff, + self.radial_edge_attr_dim, + basis=self.radial_edge_attr_encoding_function, + cutoff=True, + ) + else: + radial_edge_attr = e3nn.math.soft_one_hot_linspace( + edge_vec.norm(dim=1), + 0.0, + temporal_cutoff, + self.radial_edge_attr_dim, + basis="gaussian", + cutoff=True, + ) + radial_edge_attr = torch.ones_like(radial_edge_attr) + + # Temporal edge attributes from temporal_position differences + temporal_edge_vec = temporal_position[src] - temporal_position[dst] + if self.edge_attr_temporal_encoding_function != "ones": + temporal_edge_attr = e3nn.math.soft_one_hot_linspace( + temporal_edge_vec.abs(), # Use absolute difference + 0.0, + 2.0, + self.temporal_edge_attr_dim, + basis=self.edge_attr_temporal_encoding_function, + cutoff=True, + ) + else: + temporal_edge_attr = e3nn.math.soft_one_hot_linspace( + temporal_edge_vec.abs(), # Use absolute difference + 0.0, + 2.0, + self.temporal_edge_attr_dim, + basis="gaussian", + cutoff=True, + ) + temporal_edge_attr = torch.ones_like(temporal_edge_attr) + + # temporal_edge_attr = torch.ones_like(temporal_edge_attr) # TODO: remove this, this is hacking. + + # Optional bondedness (if bond_mask exists in the temporal graph) + if hasattr(temporal_graph, "bond_mask") and temporal_graph.bond_mask is not None: + bonded_edge_attr = self.embed_bondedness(temporal_graph.bond_mask) + edge_attr = torch.cat((bonded_edge_attr, radial_edge_attr, temporal_edge_attr), dim=-1) + else: + edge_attr = torch.cat((radial_edge_attr, temporal_edge_attr), dim=-1) + + # Process node attributes with temporal gating + + # Concatenate node_attr with temporal_position (scalar) + if self.node_attr_temporal_encoding_function != "ones": + temporal_position = e3nn.math.soft_one_hot_linspace( + temporal_position, # Use absolute difference + 0.0, # time always starts at 0 + 1.0, # time always ends at 1 + self.temporal_node_attr_dim, + basis=self.node_attr_temporal_encoding_function, + cutoff=True, + ) + else: + temporal_position = e3nn.math.soft_one_hot_linspace( + temporal_position, # Use absolute difference + 0.0, # time always starts at 0 + 1.0, # time always ends at 1 + self.temporal_node_attr_dim, + basis="gaussian", + cutoff=True, + ) + temporal_position = torch.ones_like(temporal_position) + temporal_position_expanded = temporal_position # [N, 1] for concatenation + node_attr_with_temporal = torch.cat([node_attr, temporal_position_expanded], dim=-1) + + # Apply temporal gate + node_attr_processed = self.temporal_gate(node_attr_with_temporal) + + # Perform message passing with gated node attributes + for layer in self.layers: + node_attr_processed = layer(node_attr_processed, edge_index, edge_attr, edge_sh) + node_attr_processed = self.output_head(node_attr_processed) + + # Pool over nodes. + if self.reduce is not None: + node_attr_processed = e3tools.scatter( + node_attr_processed, + index=batch, + dim=0, + dim_size=num_graphs, + reduce=self.reduce, + ) + + return node_attr_processed + + +class E3SpatioTemporal(nn.Module): + """ + E(3)-equivariant spatio-temporal model that combines spatial and temporal processing. + + This model implements the complete workflow: + 1. Process input spatial graph and hidden states through spatial module + 2. Pool spatial features to temporal graph representation + 3. Process temporal graph through temporal module + 4. Pool temporal features back to spatial representation + 5. Convert temporal graph back to spatial graph + """ + + def __init__( + self, + spatial_module: nn.Module, + temporal_module: nn.Module, + spatial_to_temporal_pooler: nn.Module, + temporal_to_spatial_pooler: nn.Module, + radial_cutoff: float, + temporal_cutoff: float = 1.0, + graph_type: str | None = "fan", + ): + """ + Initialize the E3SpatioTemporal model. + + Args: + spatial_module: Module for processing spatial positions (e.g., E3Conv) + temporal_module: Module for processing temporal graphs (e.g., E3Transformer) + spatial_to_temporal_pooler: Module to convert spatial-temporal features to temporal node attributes + temporal_to_spatial_pooler: Module to convert temporal features back to spatial features + radial_cutoff: Cutoff for spatial radial edge weights + temporal_cutoff: Cutoff for temporal edge weights + """ + super().__init__() + + self.spatial_module = spatial_module + self.temporal_module = temporal_module + self.spatial_to_temporal_pooler = spatial_to_temporal_pooler + self.temporal_to_spatial_pooler = temporal_to_spatial_pooler + self.radial_cutoff = radial_cutoff + self.temporal_cutoff = temporal_cutoff + self.graph_type = graph_type + + def forward( + self, + batch: torch_geometric.data.Batch, + c_noise: torch.Tensor, + return_temporal_features: bool = False, + return_temporal_graph: bool = False, + ) -> torch.Tensor | dict[str, torch.Tensor]: + """ + Forward pass implementing the complete spatio-temporal workflow. + + Args: + batch: Input spatial graph batch with pos, batch, num_graphs, and optionally hidden_state + c_noise: Noise conditioning tensor + return_temporal_features: Whether to return intermediate temporal features + return_temporal_graph: Whether to return the temporal graph + + Returns: + If return_temporal_features or return_temporal_graph is True, returns dict with: + - 'spatial_features': Final spatial features + - 'spatial_graph': Output spatial graph + - 'temporal_features': Temporal features (if requested) + - 'temporal_graph': Temporal graph (if requested) + - 'graph_type': Graph type used for conversion + Otherwise returns just the final spatial features tensor + """ + # Store original device + + # Step 1: Convert spatial graph to temporal graphs + if hasattr(self, "graph_type") and self.graph_type is not None: + temporal_batch = spatial_to_temporal_graphs(batch, graph_type=self.graph_type) + else: + temporal_batch = spatial_to_temporal_graphs(batch) # default to fan graph type + + # Step 2: Process all positions (current + hidden states) with spatial module + # Create topology for spatial processing (without positions) + topology = batch.clone() + # Remove position-dependent attributes but keep graph structure + if hasattr(topology, "pos"): + del topology.pos + if hasattr(topology, "batch"): + del topology.batch + if hasattr(topology, "num_graphs"): + del topology.num_graphs + + node_attr_list = [] + + # Process current positions + node_attr_current = self.spatial_module( + pos=batch.pos, + topology=topology, + batch=batch.batch, + num_graphs=batch.num_graphs, + c_noise=c_noise, + effective_radial_cutoff=self.radial_cutoff, + ).unsqueeze(1) # [N, 1, features] + node_attr_list.append(node_attr_current) + + # Process hidden state positions if they exist + if hasattr(batch, "hidden_state") and batch.hidden_state is not None and len(batch.hidden_state) > 0: + for hidden_pos in batch.hidden_state: + node_attr_hidden = self.spatial_module( + pos=hidden_pos, + topology=topology, + batch=batch.batch, + num_graphs=batch.num_graphs, + c_noise=c_noise, + effective_radial_cutoff=self.radial_cutoff, + ).unsqueeze(1) # [N, 1, features] + node_attr_list.append(node_attr_hidden) + + # Step 3: Stack spatial-temporal features + node_attr_spatial_temporal = torch.cat(node_attr_list, dim=1) # [N, T, features] + + # Step 4: Convert spatial-temporal features to temporal node attributes + temporal_node_attr = self.spatial_to_temporal_pooler(node_attr_spatial_temporal, temporal_batch) + + # Step 5: Process temporal graph through temporal module + temporal_output = self.temporal_module( + temporal_node_attr, temporal_batch, self.radial_cutoff, self.temporal_cutoff + ) + + # Step 6: Pool temporal features back to spatial features + spatial_features = self.temporal_to_spatial_pooler(temporal_output, temporal_batch) + + # Step 7: Convert temporal graph back to spatial graph + # output_spatial_graph = temporal_to_spatial_graphs(temporal_batch) + output_spatial_graph = batch + + # Prepare return values + if return_temporal_features or return_temporal_graph: + result = { + "spatial_features": spatial_features, + "spatial_graph": output_spatial_graph, + } + if return_temporal_features: + result["temporal_features"] = temporal_output + if return_temporal_graph: + result["temporal_graph"] = temporal_batch + return result + else: + return spatial_features + + def get_spatial_output_irreps(self): + """Get the irreps of the spatial module output.""" + if hasattr(self.spatial_module, "irreps_out"): + return self.spatial_module.irreps_out + else: + raise AttributeError("Spatial module does not have irreps_out attribute") + + def get_temporal_output_irreps(self): + """Get the irreps of the temporal module output.""" + if hasattr(self.temporal_module, "irreps_out"): + return self.temporal_module.irreps_out + else: + raise AttributeError("Temporal module does not have irreps_out attribute") diff --git a/src/jamun/model/conditioner_usage_example.py b/src/jamun/model/conditioner_usage_example.py new file mode 100644 index 0000000..a684ff6 --- /dev/null +++ b/src/jamun/model/conditioner_usage_example.py @@ -0,0 +1,327 @@ +""" +Test for ConditionerSpiked with DenoiserSpiked using ALA_ALA data. + +This file demonstrates and tests the DenoiserSpiked model with ConditionerSpiked. +""" + +import functools +import os +from pathlib import Path + +import torch + +import jamun +import jamun.data +import jamun.distributions +from jamun.model import DenoiserSpiked +from jamun.model.arch import E3ConvConditional +from jamun.model.conditioners import ConditionerSpiked + + +def get_ala_ala_data(num_frames=20, total_lag_time=5): + """ + Load ALA_ALA data with specified parameters. + + Args: + num_frames: Number of frames to load per dataset + total_lag_time: Number of hidden states (total time lag) + + Returns: + List of datasets + """ + # Check if data path exists + data_path = os.getenv("JAMUN_DATA_PATH") + if data_path is None: + # Try common locations + possible_paths = ["/data/bucket/kleinhej/", "/data2/sules/", "/path/to/data/"] + for path in possible_paths: + ala_path = Path(path) / "capped_diamines/timewarp_splits/train" + if ala_path.exists(): + data_path = path + break + + if data_path is None: + raise ValueError( + "JAMUN_DATA_PATH not set and cannot find data. Please set JAMUN_DATA_PATH environment variable." + ) + + print(f"Using data path: {data_path}") + root_path = f"{data_path}/capped_diamines/timewarp_splits/train" + + datasets = jamun.data.parse_datasets_from_directory( + root=root_path, + traj_pattern="^(.*).xtc", + pdb_pattern="^(.*).pdb", + filter_codes=["ALA_ALA"], + as_iterable=False, + subsample=1, + total_lag_time=total_lag_time, + lag_subsample_rate=1, + num_frames=num_frames, + max_datasets=1, # Just use one dataset for testing + ) + + return datasets + + +def create_test_denoiser_spiked(total_lag_time=5): + """ + Create a simple DenoiserSpiked model for testing. + + Args: + total_lag_time: Number of structures for conditioning + + Returns: + DenoiserSpiked model + """ + import e3tools.nn + + # Note: The actual data has 4 hidden states, so we'll have 4 + 1 clean = 5 structures + actual_n_structures = 5 # 4 hidden states + 1 clean structure + + arch = functools.partial( + E3ConvConditional, + irreps_out="1x1e", + irreps_hidden="32x0e + 8x1e", # Smaller for testing + irreps_sh="1x0e + 1x1e", + n_layers=2, # Fewer layers for faster testing + edge_attr_dim=32, + atom_type_embedding_dim=8, + atom_code_embedding_dim=8, + residue_code_embedding_dim=16, + residue_index_embedding_dim=8, + use_residue_information=True, + use_residue_sequence_index=False, + N_structures=actual_n_structures, # Match actual data structure count + hidden_layer_factory=functools.partial( + e3tools.nn.ConvBlock, + conv=e3tools.nn.Conv, + ), + output_head_factory=functools.partial(e3tools.nn.EquivariantMLP, irreps_hidden_list=["32x0e + 8x1e"]), + ) + + conditioner = ConditionerSpiked(N_structures=actual_n_structures) + + denoiser = DenoiserSpiked( + arch=arch, + optim=functools.partial(torch.optim.Adam, lr=1e-3), + sigma_distribution=jamun.distributions.ConstantSigma(sigma=0.04), + max_radius=1000.0, # Large radius for testing + average_squared_distance=10.0, + add_fixed_noise=False, + add_fixed_ones=False, + align_noisy_input_during_training=True, + align_noisy_input_during_evaluation=True, + mean_center=True, + mirror_augmentation_rate=0.0, + conditioner=conditioner, + ) + + return denoiser + + +def test_noise_and_denoise(): + """ + Test the noise_and_denoise method with ALA_ALA data. + """ + print("=" * 60) + print("Testing DenoiserSpiked with ConditionerSpiked on ALA_ALA data") + print("=" * 60) + + # Load data + try: + total_lag_time = 5 + datasets = get_ala_ala_data(num_frames=10, total_lag_time=total_lag_time) + print(f"✅ Successfully loaded {len(datasets)} datasets") + + dataset = datasets[0] + print(f" Dataset label: {dataset.label()}") + print(f" Dataset length: {len(dataset)}") + + # Get a sample + sample = dataset[0] + print(f" Sample positions shape: {sample.pos.shape}") + print( + f" Sample hidden states: {len(sample.hidden_state) if hasattr(sample, 'hidden_state') and sample.hidden_state else 0}" + ) + if hasattr(sample, "hidden_state") and sample.hidden_state: + for i, h in enumerate(sample.hidden_state): + print(f" Hidden state {i} shape: {h.shape}") + + except Exception as e: + print(f"❌ Failed to load data: {e}") + return False + + # Create model + try: + denoiser = create_test_denoiser_spiked(total_lag_time=total_lag_time) + print("✅ Successfully created DenoiserSpiked model") + print(f" Conditioner: {type(denoiser.conditioning_module).__name__}") + + except Exception as e: + print(f"❌ Failed to create model: {e}") + return False + + # Test noise_and_denoise + try: + print("\n" + "-" * 40) + print("Testing noise_and_denoise method...") + print("-" * 40) + + # Convert to batch for testing + import torch_geometric.data + + batch = torch_geometric.data.Batch.from_data_list([sample]) + print(f" Batch positions shape: {batch.pos.shape}") + print(f" Batch num_graphs: {batch.num_graphs}") + + # Set model to eval mode + denoiser.eval() + + # Test with different sigma values + sigma_values = [0.01, 0.04, 0.1] + + for sigma in sigma_values: + print(f"\n Testing with sigma = {sigma}") + + # Run noise_and_denoise + with torch.no_grad(): + x_target, xhat, y_noisy = denoiser.noise_and_denoise(batch, sigma=sigma, align_noisy_input=True) + + print(" ✅ noise_and_denoise completed successfully") + print(f" Target shape: {x_target.pos.shape}") + print(f" Prediction shape: {xhat.pos.shape}") + print(f" Noisy input shape: {y_noisy.pos.shape}") + + # Check that shapes match + assert x_target.pos.shape == xhat.pos.shape == y_noisy.pos.shape + + # Test conditioning + print(" Testing conditioner...") + print( + f" y_noisy has hidden_state: {hasattr(y_noisy, 'hidden_state') and y_noisy.hidden_state is not None}" + ) + if hasattr(y_noisy, "hidden_state") and y_noisy.hidden_state is not None: + print(f" y_noisy hidden_state count: {len(y_noisy.hidden_state)}") + print(f" x_target is not None: {x_target is not None}") + if x_target is not None: + print(f" x_target.pos shape: {x_target.pos.shape}") + print( + f" x_target has hidden_state: {hasattr(x_target, 'hidden_state') and x_target.hidden_state is not None}" + ) + + conditioned_structures = denoiser.conditioner(y_noisy, x_target) + print(f" Conditioned structures count: {len(conditioned_structures)}") + + for i, struct in enumerate(conditioned_structures): + print(f" Structure {i} shape: {struct.shape}") + + # Verify that the last structure is the clean structure + if len(conditioned_structures) > 0: + last_structure = conditioned_structures[-1] + clean_structure = x_target.pos + if torch.allclose(last_structure, clean_structure, atol=1e-6): + print(" ✅ Last conditioned structure matches x_clean.pos") + else: + print(" ❌ Last conditioned structure does NOT match x_clean.pos") + print(f" Max difference: {torch.max(torch.abs(last_structure - clean_structure)).item():.8f}") + + # Calculate some basic metrics + noise_level = torch.mean(torch.norm(y_noisy.pos - x_target.pos, dim=-1)) + prediction_error = torch.mean(torch.norm(xhat.pos - x_target.pos, dim=-1)) + print(f" Average noise level: {noise_level:.4f}") + print(f" Average prediction error: {prediction_error:.4f}") + + print("\n✅ All noise_and_denoise tests passed!") + return True + + except Exception as e: + print(f"❌ Failed during noise_and_denoise testing: {e}") + import traceback + + traceback.print_exc() + return False + + +def test_conditioning_shapes(): + """ + Test that conditioning produces expected shapes. + """ + print("\n" + "=" * 60) + print("Testing ConditionerSpiked shape outputs") + print("=" * 60) + + # Create dummy data + N_atoms = 22 # ALA_ALA has 22 atoms + N_structures = 5 + + # Create fake batch + pos = torch.randn(N_atoms, 3) + hidden_states = [torch.randn(N_atoms, 3) for _ in range(N_structures - 2)] # -2 for current pos and clean pos + + # Create fake torch_geometric batch + import torch_geometric.data + + y = torch_geometric.data.Data(pos=pos, hidden_state=hidden_states) + x_clean = torch_geometric.data.Data(pos=torch.randn(N_atoms, 3)) + + # Test conditioner + conditioner = ConditionerSpiked(N_structures=N_structures) + conditioned_structures = conditioner.forward(y, x_clean) + + print("Input shapes:") + print(f" y.pos: {y.pos.shape}") + print(f" y.hidden_state: {[h.shape for h in y.hidden_state]}") + print(f" x_clean.pos: {x_clean.pos.shape}") + + print("\nConditioned structures:") + for i, struct in enumerate(conditioned_structures): + print(f" Structure {i}: {struct.shape}") + + # Test concatenation (like in the model) + concatenated = torch.cat(conditioned_structures, dim=-1) + print(f"\nConcatenated shape: {concatenated.shape}") + expected_dim = len(conditioned_structures) * 3 # Each structure has 3D coordinates + print(f"Expected last dimension: {expected_dim}") + + # Verify that the last structure is the clean structure + if len(conditioned_structures) > 0: + last_structure = conditioned_structures[-1] + clean_structure = x_clean.pos + if torch.allclose(last_structure, clean_structure, atol=1e-6): + print("✅ Last conditioned structure matches x_clean.pos") + else: + print("❌ Last conditioned structure does NOT match x_clean.pos") + print(f" Max difference: {torch.max(torch.abs(last_structure - clean_structure)).item():.8f}") + raise AssertionError("Last conditioned structure should match x_clean.pos") + + assert concatenated.shape == (N_atoms, expected_dim) + print("✅ Shape test passed!") + + +if __name__ == "__main__": + # Run tests + try: + # Test data loading and noise_and_denoise + success = test_noise_and_denoise() + + # Test conditioning shapes + test_conditioning_shapes() + + if success: + print("\n" + "=" * 60) + print("🎉 ALL TESTS PASSED! 🎉") + print("DenoiserSpiked with ConditionerSpiked is working correctly!") + print("=" * 60) + else: + print("\n" + "=" * 60) + print("❌ SOME TESTS FAILED") + print("=" * 60) + + except KeyboardInterrupt: + print("\n⚠️ Tests interrupted by user") + except Exception as e: + print(f"\n❌ Unexpected error: {e}") + import traceback + + traceback.print_exc() diff --git a/src/jamun/model/conditioners/__init__.py b/src/jamun/model/conditioners/__init__.py new file mode 100644 index 0000000..1d5ecb2 --- /dev/null +++ b/src/jamun/model/conditioners/__init__.py @@ -0,0 +1,8 @@ +from .conditioners import ( + Conditioner, + ConditionerSpiked, + DenoisedConditioner, + MeanConditioner, + PositionConditioner, + SelfConditioner, +) diff --git a/src/jamun/model/conditioners/conditioners.py b/src/jamun/model/conditioners/conditioners.py new file mode 100644 index 0000000..e6ebc82 --- /dev/null +++ b/src/jamun/model/conditioners/conditioners.py @@ -0,0 +1,356 @@ +import logging + +import e3nn +import lightning.pytorch as pl +import torch +import torch_geometric + +from jamun.model.denoiser_conditional import Denoiser + +# Fix e3nn optimization for avoiding script issues +e3nn.set_optimization_defaults(jit_script_fx=False) + +from jamun.utils import mean_center, unsqueeze_trailing +from jamun.utils.align import kabsch_algorithm +from jamun.utils.checkpoint import find_checkpoint + + +class Conditioner(pl.LightningModule): + """ + Base class for conditioners. + """ + + def __init__(self, N_structures: int, **kwargs): + super().__init__() + self.N_structures = N_structures + + +class PositionConditioner(pl.LightningModule): + """ + Condition the hidden state on the position of the structure. + """ + + def __init__(self, N_structures: int, align_hidden_states: bool = True, **kwargs): + super().__init__() + self.N_structures = N_structures + self.align_hidden_states = align_hidden_states + + def forward(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + conditioned_structures = [y.pos] # Start with current position + for positions in y.hidden_state: + if self.align_hidden_states: + aligned_positions = kabsch_algorithm(positions, y.pos, y.batch, y.num_graphs) + conditioned_structures.append(aligned_positions) + else: + conditioned_structures.append(positions) + return conditioned_structures + + +class SelfConditioner(pl.LightningModule): + """ + No conditioning, but add the position of the structure to itself to make it compatible with the denoiser. + """ + + def __init__(self, N_structures: int, **kwargs): + super().__init__() + self.N_structures = N_structures + + def forward(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + conditioned_structures = [y.pos for _ in range(self.N_structures)] # Include current position + return conditioned_structures + + +class MeanConditioner(pl.LightningModule): + """ + Condition on the mean across time steps of positions and hidden states. + For each atom and coordinate, averages across all T+1 structures (current + hidden states). + """ + + def __init__(self, N_structures: int, **kwargs): + super().__init__() + self.N_structures = N_structures + + def forward(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + # Start with current position + all_positions = [y.pos] + + # Add all hidden states if they exist + if hasattr(y, "hidden_state") and y.hidden_state is not None: + all_positions.extend(y.hidden_state) + + # Stack all positions along a new dimension and compute mean across time steps + # Shape: (T+1, N, 3) -> (N, 3) where T is number of hidden states + stacked_positions = torch.stack(all_positions, dim=0) # (T+1, N, 3) + mean_positions = torch.mean(stacked_positions, dim=0) # (N, 3) + # mean center the mean positions + dummy_graph = y.clone() + dummy_graph.pos = mean_positions + # mean center the mean positions + mean_positions = mean_center(dummy_graph).pos + # align the mean positions to the current positions + aligned_mean_positions = kabsch_algorithm(mean_positions, y.pos, y.batch, y.num_graphs) + + # Return the mean repeated N_structures times + conditioned_structures = [aligned_mean_positions for _ in range(self.N_structures)] + + return conditioned_structures + + +class DenoisedConditioner(pl.LightningModule): + """ + Conditioner that uses a pretrained denoiser to denoise hidden states. + + Takes hidden states, unscales them using c_in, denoises each structure, + then recenters and aligns them to the current noisy positions. + """ + + def __init__(self, N_structures: int, pretrained_model_path: str, c_in: float, **kwargs): + super().__init__() + self.N_structures = N_structures + self.c_in = c_in + self.pretrained_model_path = pretrained_model_path + + # Load the pretrained denoiser + py_logger = logging.getLogger("jamun") + py_logger.info(f"Loading pretrained denoiser from wandb run: {pretrained_model_path}") + + # Find the checkpoint for the wandb run + checkpoint_path = find_checkpoint(wandb_train_run_path=pretrained_model_path, checkpoint_type="best_so_far") + + # Load the denoiser from checkpoint + self.pretrained_denoiser = Denoiser.load_from_checkpoint(checkpoint_path, strict=False) + self.pretrained_denoiser.eval() # Set to evaluation mode + + # Freeze the pretrained model parameters + for param in self.pretrained_denoiser.parameters(): + param.requires_grad = False + + # Extract sigma from the pretrained denoiser + self.denoiser_sigma = self._extract_sigma_from_denoiser() + py_logger.info(f"Extracted sigma from pretrained denoiser: {self.denoiser_sigma}") + py_logger.info(f"Successfully loaded pretrained denoiser with c_in={c_in}") + + def _extract_sigma_from_denoiser(self) -> float: + """Extract sigma value from the pretrained denoiser's sigma distribution.""" + sigma_distribution = self.pretrained_denoiser.sigma_distribution + + # Handle different types of sigma distributions + if hasattr(sigma_distribution, "sigma"): + # For ConstantSigma distribution + return float(sigma_distribution.sigma) + elif hasattr(sigma_distribution, "mean"): + # For other distributions that might have a mean + return float(sigma_distribution.mean) + else: + # Fallback - sample from the distribution + sample = sigma_distribution.sample() + return float(sample) + + def forward(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + """ + Forward pass that denoises hidden states and returns conditioned structures. + + Args: + y: Batch containing current positions and hidden states + + Returns: + List of tensors: [y.pos, *denoised_hidden_states] + """ + # Use the sigma from the pretrained denoiser + sigma_to_use = self.denoiser_sigma + + conditioned_structures = [y.pos] # Start with current position + + # Check if we have hidden states to process + if not hasattr(y, "hidden_state") or y.hidden_state is None: + # If no hidden states, just repeat current position + conditioned_structures.extend([y.pos for _ in range(self.N_structures - 1)]) + return conditioned_structures + + # # Move pretrained denoiser to same device as input + # device = y.pos.device + # self.pretrained_denoiser = self.pretrained_denoiser.to(device) + + # Process each hidden state + for i, hidden_positions in enumerate(y.hidden_state): + # Unscale the hidden state positions + unscaled_positions = hidden_positions / self.c_in + + # Create a batch for denoising + denoising_batch = y.clone() + denoising_batch.pos = unscaled_positions + + # Remove hidden states from the denoising batch to avoid recursion + if hasattr(denoising_batch, "hidden_state"): + delattr(denoising_batch, "hidden_state") + + # Denoise the unscaled positions using the denoiser's sigma + with torch.no_grad(): + denoised_batch = self.pretrained_denoiser.xhat(denoising_batch, sigma_to_use) + denoised_positions = denoised_batch.pos + + # Align the denoised positions to the current noisy positions + aligned_positions = kabsch_algorithm(denoised_positions, y.pos, y.batch, y.num_graphs) + + conditioned_structures.append(aligned_positions) + + # Break if we've processed enough structures + if len(conditioned_structures) >= self.N_structures: + break + + # If we don't have enough hidden states, pad with the last denoised structure + while len(conditioned_structures) < self.N_structures: + conditioned_structures.append(conditioned_structures[-1]) + + return conditioned_structures + + +class ConditionerSpiked(Conditioner): + """ + A conditioner that concatenates hidden states with the clean structure. + + The conditioning order is: + 1. Hidden states (y.hidden_state) - if present + 2. Clean structure positions (x_clean.pos) - if provided at the end + """ + + def __init__(self, N_structures: int, **kwargs): + super().__init__(N_structures, **kwargs) + + def forward(self, y: torch_geometric.data.Batch, x_clean: torch_geometric.data.Batch = None) -> list[torch.Tensor]: + """ + Create conditioning structures by concatenating hidden states with clean structure. + + Args: + y: The noisy sample batch containing positions and hidden states + x_clean: The clean sample batch containing ground truth positions + + Returns: + List of tensors to be concatenated for conditioning + """ + conditioned_structures = [y.pos] + + # Add hidden states if they exist + if hasattr(y, "hidden_state") and y.hidden_state is not None: + for hidden_pos in y.hidden_state: + conditioned_structures.append(hidden_pos) + + # Add clean structure positions at the end if provided + if x_clean is not None: + conditioned_structures.pop(-1) + conditioned_structures.append(x_clean.pos) + + return conditioned_structures + + +class SpatioTemporalConditioner(pl.LightningModule): + """ + Conditioner that uses a spatio-temporal model to process hidden states. + + This conditioner takes the current positions and hidden states, processes them + through a spatio-temporal model, and returns [y.pos, spatial_features]. + Always returns exactly 2 structures: the original positions and computed spatial features. + + By default, the spatiotemporal model is trainable. Set freeze_spatiotemporal_model=True + to freeze the parameters (e.g., when using a pretrained model). + """ + + def __init__( + self, + N_structures: int, + spatiotemporal_model: torch.nn.Module, + c_noise: float = 0.0, + freeze_spatiotemporal_model: bool = False, + **kwargs, + ): + """ + Initialize the SpatioTemporalConditioner. + + Args: + N_structures: Number of structures parameter (ignored - this conditioner always returns 1 structure) + spatiotemporal_model: The E3SpatioTemporal model to use for processing + c_noise: Noise conditioning parameter + freeze_spatiotemporal_model: Whether to freeze spatiotemporal model parameters + **kwargs: Additional arguments passed to parent class + """ + super().__init__() + self.N_structures = N_structures + self.spatiotemporal_model = spatiotemporal_model + self.c_noise = c_noise + self.freeze_spatiotemporal_model = freeze_spatiotemporal_model + + # Only freeze parameters if explicitly requested + if self.freeze_spatiotemporal_model: + self.freeze_spatiotemporal_parameters() + # Set to evaluation mode when frozen + self.spatiotemporal_model.eval() + + def freeze_spatiotemporal_parameters(self): + """Freeze the spatiotemporal model parameters.""" + for param in self.spatiotemporal_model.parameters(): + param.requires_grad = False + + def unfreeze_spatiotemporal_parameters(self): + """Unfreeze the spatiotemporal model parameters.""" + for param in self.spatiotemporal_model.parameters(): + param.requires_grad = True + + def configure_for_inference(self): + """Configure the conditioner for inference (freeze parameters and set eval mode).""" + self.freeze_spatiotemporal_model = True + self.freeze_spatiotemporal_parameters() + self.spatiotemporal_model.eval() + + def configure_for_training(self): + """Configure the conditioner for training (unfreeze parameters and set train mode).""" + self.freeze_spatiotemporal_model = False + self.unfreeze_spatiotemporal_parameters() + self.spatiotemporal_model.train() + + def forward(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + """ + Forward pass that processes the batch through the spatio-temporal model. + + Args: + y: Batch containing current positions and hidden states + + Returns: + List containing [y.pos, spatial_features] for concatenation by the denoiser + """ + # Align hidden positions with current position before processing + if hasattr(y, "hidden_state") and y.hidden_state is not None and len(y.hidden_state) > 0: + # Create a copy of the batch to avoid modifying the original + y_aligned = y.clone() + + # Align each hidden state to the current position + aligned_hidden_states = [] + for hidden_pos in y.hidden_state: + # Align hidden_pos to y.pos using Kabsch algorithm (same as PositionConditioner) + aligned_hidden_pos = kabsch_algorithm(hidden_pos, y.pos, y.batch, y.num_graphs) + aligned_hidden_states.append(aligned_hidden_pos) + + # Update the hidden states in the aligned batch + y_aligned.hidden_state = aligned_hidden_states + else: + # If no hidden states, use original batch + y_aligned = y + + # Prepare noise conditioning + device = y_aligned.pos.device + sigma = torch.tensor(self.c_noise, device=device) + sigma = unsqueeze_trailing( + sigma, 1 + ) # actually this is correct, but this is bad variable naming. the positional e3conv will take a c_noise, so this is right, but it is not right to call it sigma. + + # Process through spatio-temporal model with aligned hidden states + # Only disable gradients if the model is frozen + if self.freeze_spatiotemporal_model: + with torch.no_grad(): + spatial_features = self.spatiotemporal_model(y_aligned, sigma) + else: + # Allow gradients to flow when training + spatial_features = self.spatiotemporal_model(y_aligned, sigma) + + # Return list containing [y.pos, spatial_features] for concatenation + # The denoiser will concatenate these along the feature dimension + return [y.pos, spatial_features] diff --git a/src/jamun/model/denoiser.py b/src/jamun/model/denoiser.py index a0a8567..c0f5b0f 100644 --- a/src/jamun/model/denoiser.py +++ b/src/jamun/model/denoiser.py @@ -182,6 +182,13 @@ def __init__( if self.rotational_augmentation: py_logger.info("Rotational augmentation is enabled.") + def on_before_optimizer_step(self, optimizer): + # Log gradients and parameters. + for name, param in self.named_parameters(): + self.log(f"parameter_norms/{name}", param.norm(), sync_dist=True) + if param.grad is not None: + self.log(f"gradient_norms/{name}", param.grad.norm(), sync_dist=True) + def add_noise(self, x: torch.Tensor, sigma: float | torch.Tensor, num_graphs: int) -> torch.Tensor: # pos [B, ...] sigma = unsqueeze_trailing(sigma, x.ndim) @@ -317,8 +324,8 @@ def noise_and_denoise( x = align_A_to_B_batched_f( x, y, - topology.batch, - topology.num_graphs, + batch, + num_graphs, sigma=sigma, correction_order=self.alignment_correction_order, ) @@ -386,8 +393,8 @@ def noise_and_compute_loss( align_noisy_input: bool, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Add noise to the input and compute the loss.""" - xhat, x, _ = self.noise_and_denoise(x, topology, sigma, align_noisy_input=align_noisy_input) - return self.compute_loss(x, xhat, topology, sigma) + xhat, x, _ = self.noise_and_denoise(x, topology, batch, num_graphs, sigma, align_noisy_input=align_noisy_input) + return self.compute_loss(x, xhat, topology, batch, num_graphs, sigma) def training_step(self, data: torch_geometric.data.Batch, data_idx: int): """Called during training.""" diff --git a/src/jamun/model/denoiser_conditional.py b/src/jamun/model/denoiser_conditional.py new file mode 100644 index 0000000..c8008ee --- /dev/null +++ b/src/jamun/model/denoiser_conditional.py @@ -0,0 +1,492 @@ +import logging +import os +from collections.abc import Callable + +import e3tools +import lightning.pytorch as pl +import numpy as np +import torch +import torch_geometric +from e3tools import scatter + +from jamun.utils import mean_center, unsqueeze_trailing +from jamun.utils.align import kabsch_algorithm + + +class Denoiser(pl.LightningModule): + """The main denoiser mode with conditional architecture.""" + + def __init__( + self, + arch: Callable[..., torch.nn.Module], + optim: Callable[..., torch.optim.Optimizer], + sigma_distribution: torch.distributions.Distribution, + max_radius: float, + average_squared_distance: float, + add_fixed_noise: bool, + add_fixed_ones: bool, + align_noisy_input_during_training: bool, + align_noisy_input_during_evaluation: bool, + mean_center: bool, + mirror_augmentation_rate: float, + bond_loss_coefficient: float = 1.0, + normalization_type: str | None = "JAMUN", + sigma_data: float | None = None, # Only used if normalization_type is "EDM" + lr_scheduler_config: dict | None = None, + use_torch_compile: bool = True, + torch_compile_kwargs: dict | None = None, + conditioner: Callable[..., list[torch.Tensor]] = None, + rotational_augmentation: bool = False, + alignment_correction_order: int = 0, + pass_topology_as_atom_graphs: bool = False, + ): + super().__init__() + self.save_hyperparameters(logger=False) + + self.g = arch() + if use_torch_compile: + if torch_compile_kwargs is None: + torch_compile_kwargs = {} + + self.g = torch.compile(self.g, **torch_compile_kwargs) + + py_logger = logging.getLogger("jamun") + py_logger.info(self.g) + + self.optim_factory = optim + self.lr_scheduler_config = lr_scheduler_config + self.sigma_distribution = sigma_distribution + self.max_radius = max_radius + + self.add_fixed_noise = add_fixed_noise + self.add_fixed_ones = add_fixed_ones + if self.add_fixed_noise and self.add_fixed_ones: + raise ValueError("Can't add fixed noise and fixed ones at the same time") + if self.add_fixed_noise: + py_logger.info("Adding fixed noise") + if self.add_fixed_ones: + py_logger.info("Adding fixed ones") + + self.average_squared_distance = average_squared_distance + py_logger.info(f"Average squared distance = {self.average_squared_distance}") + + self.align_noisy_input_during_training = align_noisy_input_during_training + if self.align_noisy_input_during_training: + py_logger.info("Aligning noisy input during training.") + else: + py_logger.info("Not aligning noisy input during training.") + + self.align_noisy_input_during_evaluation = align_noisy_input_during_evaluation + if self.align_noisy_input_during_evaluation: + py_logger.info("Aligning noisy input during evaluation.") + else: + py_logger.info("Not aligning noisy input during evaluation.") + + self.mean_center = mean_center + if self.mean_center: + py_logger.info("Mean centering input and output.") + else: + py_logger.info("Not mean centering input and output.") + + self.mirror_augmentation_rate = mirror_augmentation_rate + py_logger.info(f"Mirror augmentation rate: {self.mirror_augmentation_rate}") + + self.normalization_type = normalization_type + if self.normalization_type is not None: + py_logger.info(f"Normalization type: {self.normalization_type}") + else: + py_logger.info("No normalization") + + self.sigma_data = sigma_data + if self.normalization_type == "EDM" and self.sigma_data is None: + raise ValueError("sigma_data must be provided when normalization_type is 'EDM'") + elif self.normalization_type != "EDM" and self.sigma_data is not None: + raise ValueError("sigma_data can only be used when normalization_type is 'EDM'") + + self.bond_loss_coefficient = bond_loss_coefficient + self.conditioning_module = conditioner + if self.conditioning_module is not None and not callable(self.conditioning_module): + raise ValueError("Conditioner must be a callable or None") + py_logger.info(f"Conditioner: {self.conditioning_module}") + + def on_before_optimizer_step(self, optimizer): + # Log gradients and parameters. + for name, param in self.named_parameters(): + self.log(f"parameter_norms/{name}", param.norm(), sync_dist=True) + if param.grad is not None: + self.log(f"gradient_norms/{name}", param.grad.norm(), sync_dist=True) + + def conditioner_default(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + conditioned_structures = [y.pos] # Return complete list starting with current position + return conditioned_structures + + def conditioner(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + if self.conditioning_module is None: + return self.conditioner_default(y) + elif callable(self.conditioning_module): + return self.conditioning_module(y) + else: + raise ValueError("Conditioner must be a callable or None") + + def _align_A_to_B_batched_with_hidden_states( + self, A: torch_geometric.data.Batch, B: torch_geometric.data.Batch + ) -> torch_geometric.data.Batch: + """Aligns each graph of A to the corresponding graph in B, including hidden states.""" + A_aligned = A.clone() + + # Align positions + A_aligned.pos = kabsch_algorithm(A.pos, B.pos, A.batch, A.num_graphs) + + # Align hidden states + if hasattr(A, "hidden_state") and A.hidden_state is not None: + A_aligned.hidden_state = [] + for i in range(len(A.hidden_state)): + A_aligned.hidden_state.append(kabsch_algorithm(A.hidden_state[i], B.pos, A.batch, A.num_graphs)) + return A_aligned + + def _mean_center_hidden_states(self, data: torch_geometric.data.Batch): + if hasattr(data, "hidden_state") and data.hidden_state is not None: + for i in range(len(data.hidden_state)): + mean = scatter(data.hidden_state[i], data.batch, dim=0, reduce="mean") + data.hidden_state[i] = data.hidden_state[i] - mean[data.batch] + return data + + def add_noise(self, x: torch_geometric.data.Batch, sigma: float | torch.Tensor) -> torch_geometric.data.Batch: + # pos [B, ...] + sigma = unsqueeze_trailing(sigma, x.pos.ndim) + + y = x.clone() + if self.add_fixed_ones: + noise = torch.ones_like(x.pos) + hidden_noise = [torch.randn_like(x.hidden_state[i]) for i in range(len(x.hidden_state))] + elif self.add_fixed_noise: + torch.manual_seed(0) + num_batches = x.batch.max().item() + 1 + if len(x.pos.shape) == 2: + num_nodes_per_batch = x.pos.shape[0] // num_batches + noise = torch.randn_like(x.pos[:num_nodes_per_batch]).repeat(num_batches, 1) + hidden_noise = [ + torch.randn_like(x.hidden_state[i][:num_nodes_per_batch]).repeat(num_batches, 1) + for i in range(len(x.hidden_state)) + ] + if len(x.pos.shape) == 3: + num_nodes_per_batch = x.pos.shape[1] + noise = torch.randn_like(x.pos[0]).repeat(num_batches, 1, 1) + hidden_noise = [ + torch.randn_like(x.hidden_state[i][0]).repeat(num_batches, 1, 1) for i in range(len(x.hidden_state)) + ] + else: + noise = torch.randn_like(x.pos) + hidden_noise = [torch.randn_like(x.hidden_state[i]) for i in range(len(x.hidden_state))] + y.pos = x.pos + sigma * noise + for i in range(len(y.hidden_state)): + y.hidden_state[i] = x.hidden_state[i] + sigma * hidden_noise[i] + if torch.rand(()) < self.mirror_augmentation_rate: + y.pos = -y.pos + return y + + def score(self, y: torch_geometric.data.Batch, sigma: float | torch.Tensor) -> torch_geometric.data.Batch: + """Compute the score function.""" + sigma = torch.as_tensor(sigma).to(y.pos) + return (self.xhat(y, sigma).pos - y.pos) / (unsqueeze_trailing(sigma, y.pos.ndim - 1) ** 2) + + def normalization_factors(self, sigma: float, D: int = 3) -> tuple[float, float, float, float]: + """Normalization factors for the input and output.""" + sigma = torch.as_tensor(sigma) + + if self.normalization_type is None: + return 1.0, 0.0, 1.0, sigma + + if self.normalization_type == "EDM": + c_skip = (self.sigma_data**2) / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2) + c_in = 1 / torch.sqrt(sigma**2 + self.sigma_data**2) + c_noise = torch.log(sigma / self.sigma_data) * 0.25 + return c_in, c_skip, c_out, c_noise + + if self.normalization_type == "JAMUN": + A = torch.as_tensor(self.average_squared_distance) + B = torch.as_tensor(2 * D * sigma**2) + + c_in = 1.0 / torch.sqrt(A + B) + c_skip = A / (A + B) + c_out = torch.sqrt((A * B) / (A + B)) + c_noise = torch.log(sigma) / 4 + return c_in, c_skip, c_out, c_noise + + raise ValueError(f"Unknown normalization type: {self.normalization_type}") + + def loss_weight(self, sigma: float, D: int = 3) -> float: + """Loss weight for this graph.""" + _, _, c_out, _ = self.normalization_factors(sigma, D) + return 1 / (c_out**2) + + def effective_radial_cutoff(self, sigma: float | torch.Tensor) -> torch.Tensor: + """Compute the effective radial cutoff for the noise level.""" + return torch.sqrt((self.max_radius**2) + 6 * (sigma**2)) + + def add_edges(self, y: torch_geometric.data.Batch, radial_cutoff: float) -> torch_geometric.data.Batch: + """Add edges to the graph based on the effective radial cutoff.""" + if y.get("edge_index") is not None: + return y + + y = y.clone() + if "batch" in y: + batch = y["batch"] + else: + batch = torch.zeros(y.num_nodes, dtype=torch.long, device=self.device) + + with torch.cuda.nvtx.range("radial_graph"): + radial_edge_index = e3tools.radius_graph(y.pos, radial_cutoff, batch) + + with torch.cuda.nvtx.range("concatenate_edges"): + edge_index = torch.cat((radial_edge_index, y.bonded_edge_index), dim=-1) + if y.bonded_edge_index.numel() == 0: + bond_mask = torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.pos.device) + else: + bond_mask = torch.cat( + ( + torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.pos.device), + torch.ones(y.bonded_edge_index.shape[1], dtype=torch.long, device=y.pos.device), + ), + dim=0, + ) + + y.edge_index = edge_index + y.bond_mask = bond_mask + return y + + def xhat_normalized(self, y: torch_geometric.data.Batch, sigma: float | torch.Tensor) -> torch_geometric.data.Batch: + """Compute the denoised prediction using the normalization factors from JAMUN.""" + sigma = torch.as_tensor(sigma).to(y.pos) + D = y.pos.shape[-1] + + # Compute the normalization factors. + with torch.cuda.nvtx.range("normalization_factors"): + c_in, c_skip, c_out, c_noise = self.normalization_factors(sigma, D) + radial_cutoff = self.effective_radial_cutoff(sigma) / c_in + + # Adjust dimensions. + c_in = unsqueeze_trailing(c_in, y.pos.ndim - 1) + c_skip = unsqueeze_trailing(c_skip, y.pos.ndim - 1) + c_out = unsqueeze_trailing(c_out, y.pos.ndim - 1) + c_noise = c_noise.unsqueeze(0) + + # Add edges to the graph. + with torch.cuda.nvtx.range("add_edges"): + y = self.add_edges(y, radial_cutoff) + + with torch.cuda.nvtx.range("scale_y"): + y_scaled = y.clone() + y_scaled.pos = y.pos * c_in + # Manually copy hidden state + if hasattr(y, "hidden_state") and y.hidden_state is not None: + y_scaled.hidden_state = [] + for positions in y.hidden_state: + y_scaled.hidden_state.append(positions * c_in) + + with torch.cuda.nvtx.range("clone_y"): + xhat = y.clone() + # Manually copy hidden state + if hasattr(y, "hidden_state") and y.hidden_state is not None: + xhat.hidden_state = [h.clone() for h in y.hidden_state] + + with torch.cuda.nvtx.range("conditioning"): + conditioned_structures = self.conditioner(y_scaled) + # print(f"Conditioner is working, number of conditioned structures: {len(conditioned_structures)}") + with torch.cuda.nvtx.range("g"): + g_pred = self.g( + torch.cat([*conditioned_structures], dim=-1), + topology=y_scaled, + c_noise=c_noise, + effective_radial_cutoff=radial_cutoff, + ) + + xhat.pos = c_skip * y.pos + c_out * g_pred + if hasattr(y, "hidden_state") and y.hidden_state is not None: + xhat.hidden_state = [y.pos, *y.hidden_state[:-1]] + return xhat + + def xhat(self, y: torch.Tensor, sigma: float | torch.Tensor): + """Compute the denoised prediction.""" + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_y"): + y = mean_center(y) + y = self._mean_center_hidden_states(y) + + with torch.cuda.nvtx.range("xhat_normalized"): + xhat = self.xhat_normalized(y, sigma) + + # Mean center the prediction. + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_xhat"): + xhat = mean_center(xhat) + + return xhat + + def noise_and_denoise( + self, + x: torch_geometric.data.Batch, + sigma: float | torch.Tensor, + align_noisy_input: bool, + ) -> tuple[torch_geometric.data.Batch, torch_geometric.data.Batch, torch_geometric.data.Batch]: + """ + Add noise to the input and denoise it. + Returns the target for the loss, the prediction, and the noisy input. + """ + with torch.no_grad(): + if self.mean_center: + # Operate on a clone to avoid side effects on the original batch object. + x_processed = mean_center(x) + x_processed = self._mean_center_hidden_states(x_processed) + else: + x_processed = x + + sigma = torch.as_tensor(sigma).to(x_processed.pos) + + with torch.cuda.nvtx.range("add_noise"): + y = self.add_noise(x_processed, sigma) + x_target = x_processed.clone() + # Manually copy hidden state + if hasattr(x_processed, "hidden_state") and x_processed.hidden_state is not None: + x_target.hidden_state = [h.clone() for h in x_processed.hidden_state] + + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_y"): + y = mean_center(y) + y = self._mean_center_hidden_states(y) + + # Aligning each batch. + if align_noisy_input: + with torch.cuda.nvtx.range("align_A_to_B_batched"): + y = self._align_A_to_B_batched_with_hidden_states(y, x_target) + + with torch.cuda.nvtx.range("xhat"): + xhat = self.xhat(y, sigma) + + return x_target, xhat, y + + def compute_loss( + self, + x: torch_geometric.data.Batch, + xhat: torch.Tensor, + sigma: float | torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the loss.""" + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_x"): + x = mean_center(x) + + D = xhat.pos.shape[-1] + + # Compute the raw loss. + with torch.cuda.nvtx.range("raw_coordinate_loss"): + raw_coordinate_loss = (xhat.pos - x.pos).pow(2).sum(dim=-1) + + # Take the mean over each graph. + with torch.cuda.nvtx.range("mean_over_graphs"): + mse = scatter(raw_coordinate_loss, x.batch, dim=0, dim_size=x.num_graphs, reduce="mean") + + # Compute the scaled RMSD. + with torch.cuda.nvtx.range("scaled_rmsd"): + rmsd = torch.sqrt(mse) + scaled_rmsd = rmsd / (sigma * np.sqrt(D)) + + # Account for the loss weight across graphs and noise levels. + with torch.cuda.nvtx.range("loss_weight"): + loss = mse * x.loss_weight + loss = loss * self.loss_weight(sigma, D) + + return loss, { + "mse": mse, + "rmsd": rmsd, + "scaled_rmsd": scaled_rmsd, + } + + def noise_and_compute_loss( + self, + x: torch_geometric.data.Batch, + sigma: float | torch.Tensor, + align_noisy_input: bool, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Add noise to the input and compute the loss.""" + x_target, xhat, _ = self.noise_and_denoise(x, sigma, align_noisy_input=align_noisy_input) + return self.compute_loss(x_target, xhat, sigma) + + def _automatic_step(self, batch: torch_geometric.data.Batch, stage: str): + """The standard step for automatic optimization.""" + align_noisy_input = ( + self.align_noisy_input_during_training if stage == "train" else self.align_noisy_input_during_evaluation + ) + sigma = self.sigma_distribution.sample().to(self.device) + + loss, aux = self.noise_and_compute_loss( + batch, + sigma, + align_noisy_input=align_noisy_input, + ) # check if the loss is nan. if nan then save the model, and the batch and see what went on. + if torch.isnan(loss.sum()): + print(f"Loss is nan at step {self.global_step}") + print(f"Batch: {batch}") + print(f"Sigma: {sigma}") + print(f"Align noisy input: {align_noisy_input}") + print(f"Loss: {loss}") + print(f"Aux: {aux}") + # Create debug directory if it doesn't exist + debug_dir = f"/homefs/home/sules/jamun/debug_nan_loss_step_{self.global_step}" + os.makedirs(debug_dir, exist_ok=True) + + # Save model checkpoint + checkpoint_path = os.path.join(debug_dir, "model_nan_loss.ckpt") + self.trainer.save_checkpoint(checkpoint_path) + print(f"Model saved to {checkpoint_path}") + + torch.save(batch, debug_dir + "/batch_nan_loss.pt") + + # Optionally raise an exception to stop training + raise RuntimeError(f"NaN loss detected at step {self.global_step}. Debug files saved to {debug_dir}") + + # Average the loss and other metrics over all graphs. + with torch.cuda.nvtx.range("mean_over_graphs"): + aux["loss"] = loss + for key in aux: + aux[key] = aux[key].mean() + if stage == "train": + self.log(f"train/{key}", aux[key], prog_bar=False, batch_size=batch.num_graphs, sync_dist=False) + elif stage == "val": + self.log( + f"val/{key}", + aux[key], + prog_bar=(key == "scaled_rmsd"), + batch_size=batch.num_graphs, + sync_dist=True, + ) + else: + continue + + return { + "sigma": sigma, + **aux, + } + + def training_step(self, batch: torch_geometric.data.Batch, batch_idx: int): + """Called during training.""" + return self._automatic_step(batch, "train") + + def validation_step(self, batch: torch_geometric.data.Batch, batch_idx: int): + """Called during validation.""" + self._automatic_step(batch, "val") + + def configure_optimizers(self): + """Set up the optimizer and learning rate scheduler.""" + optimizer = self.optim_factory(params=self.parameters()) + + out = {"optimizer": optimizer} + if self.lr_scheduler_config: + scheduler = self.lr_scheduler_config.pop("scheduler") + out["lr_scheduler"] = { + "scheduler": scheduler(optimizer), + **self.lr_scheduler_config, + } + + return out diff --git a/src/jamun/model/denoiser_multimeasurement.py b/src/jamun/model/denoiser_multimeasurement.py new file mode 100644 index 0000000..d003945 --- /dev/null +++ b/src/jamun/model/denoiser_multimeasurement.py @@ -0,0 +1,690 @@ +import logging +from collections.abc import Callable + +import e3tools +import lightning.pytorch as pl +import numpy as np +import torch +import torch_geometric +from e3tools import scatter + +from jamun.utils import mean_center, unsqueeze_trailing +from jamun.utils.align import kabsch_algorithm + + +class DenoiserMultimeasurement(pl.LightningModule): + """The main denoiser mode with conditional architecture.""" + + def __init__( + self, + arch: Callable[..., torch.nn.Module], + optim: Callable[..., torch.optim.Optimizer], + sigma_distribution: torch.distributions.Distribution, + max_radius: float, + average_squared_distance: float, + add_fixed_noise: bool, + add_fixed_ones: bool, + align_noisy_input_during_training: bool, + align_noisy_input_during_evaluation: bool, + mean_center: bool, + mirror_augmentation_rate: float, + bond_loss_coefficient: float = 1.0, + normalization_type: str | None = "JAMUN", + sigma_data: float | None = None, # Only used if normalization_type is "EDM" + lr_scheduler_config: dict | None = None, + use_torch_compile: bool = True, + torch_compile_kwargs: dict | None = None, + conditioner: Callable[..., list[torch.Tensor]] = None, + multimeasurement: bool = False, + N_measurements_hidden: int = 1, + N_measurements: int = 1, + max_graphs_per_batch: int = None, + rotational_augmentation: bool = False, + alignment_correction_order: int = 0, + pass_topology_as_atom_graphs: bool = False, + ): + super().__init__() + self.save_hyperparameters(logger=False) + + # Let us control the optimization process only if we need to chunk batches. + self.automatic_optimization = max_graphs_per_batch is None + + self.g = arch() + if use_torch_compile: + if torch_compile_kwargs is None: + torch_compile_kwargs = {} + + self.g = torch.compile(self.g, **torch_compile_kwargs) + + py_logger = logging.getLogger("jamun") + py_logger.info(self.g) + + self.optim_factory = optim + self.lr_scheduler_config = lr_scheduler_config + self.sigma_distribution = sigma_distribution + self.max_radius = max_radius + + self.add_fixed_noise = add_fixed_noise + self.add_fixed_ones = add_fixed_ones + if self.add_fixed_noise and self.add_fixed_ones: + raise ValueError("Can't add fixed noise and fixed ones at the same time") + if self.add_fixed_noise: + py_logger.info("Adding fixed noise") + if self.add_fixed_ones: + py_logger.info("Adding fixed ones") + + self.average_squared_distance = average_squared_distance + py_logger.info(f"Average squared distance = {self.average_squared_distance}") + + self.align_noisy_input_during_training = align_noisy_input_during_training + if self.align_noisy_input_during_training: + py_logger.info("Aligning noisy input during training.") + else: + py_logger.info("Not aligning noisy input during training.") + + self.align_noisy_input_during_evaluation = align_noisy_input_during_evaluation + if self.align_noisy_input_during_evaluation: + py_logger.info("Aligning noisy input during evaluation.") + else: + py_logger.info("Not aligning noisy input during evaluation.") + + self.mean_center = mean_center + if self.mean_center: + py_logger.info("Mean centering input and output.") + else: + py_logger.info("Not mean centering input and output.") + + self.mirror_augmentation_rate = mirror_augmentation_rate + py_logger.info(f"Mirror augmentation rate: {self.mirror_augmentation_rate}") + + self.normalization_type = normalization_type + if self.normalization_type is not None: + py_logger.info(f"Normalization type: {self.normalization_type}") + else: + py_logger.info("No normalization") + + self.sigma_data = sigma_data + if self.normalization_type == "EDM" and self.sigma_data is None: + raise ValueError("sigma_data must be provided when normalization_type is 'EDM'") + elif self.normalization_type != "EDM" and self.sigma_data is not None: + raise ValueError("sigma_data can only be used when normalization_type is 'EDM'") + + self.bond_loss_coefficient = bond_loss_coefficient + self.conditioning_module = conditioner + if self.conditioning_module is not None and not callable(self.conditioning_module): + raise ValueError("Conditioner must be a callable or None") + py_logger.info(f"Conditioner: {self.conditioning_module}") + + self.multimeasurement = multimeasurement + self.N_measurements_hidden = N_measurements_hidden + self.N_measurements = N_measurements + self.max_graphs_per_batch = max_graphs_per_batch + if not self.automatic_optimization: + py_logger.info(f"Manual optimization enabled with micro-batch size of {self.max_graphs_per_batch} graphs.") + + def on_before_optimizer_step(self, optimizer): + # Log gradients and parameters. + for name, param in self.named_parameters(): + self.log(f"parameter_norms/{name}", param.norm(), sync_dist=True) + if param.grad is not None: + self.log(f"gradient_norms/{name}", param.grad.norm(), sync_dist=True) + + def _align_A_to_B_batched_with_hidden_states( + self, A: torch_geometric.data.Batch, B: torch_geometric.data.Batch + ) -> torch_geometric.data.Batch: + """Aligns each graph of A to the corresponding graph in B, including hidden states.""" + A_aligned = A.clone() + + # Align positions + A_aligned.pos = kabsch_algorithm(A.pos, B.pos, A.batch, A.num_graphs) + + # Align hidden states + if hasattr(A, "hidden_state") and A.hidden_state is not None: + A_aligned.hidden_state = [] + for i in range(len(A.hidden_state)): + A_aligned.hidden_state.append(kabsch_algorithm(A.hidden_state[i], B.pos, A.batch, A.num_graphs)) + return A_aligned + + def _mean_center_hidden_states(self, data: torch_geometric.data.Batch): + if hasattr(data, "hidden_state") and data.hidden_state is not None: + for i in range(len(data.hidden_state)): + mean = scatter(data.hidden_state[i], data.batch, dim=0, reduce="mean") + data.hidden_state[i] = data.hidden_state[i] - mean[data.batch] + return data + + def _prepare_noisy_batch( + self, + x: torch_geometric.data.Batch, + sigma: float | torch.Tensor, + align_noisy_input: bool, + ): + """Prepare a batch of noisy graphs and their targets.""" + with torch.no_grad(): + if self.mean_center: + x_processed = mean_center(x) + x_processed = self._mean_center_hidden_states(x_processed) + else: + x_processed = x + + sigma_tensor = torch.as_tensor(sigma).to(x_processed.pos.device) + + y = self.add_noise_hiddens(x_processed, self.N_measurements_hidden, self.N_measurements, sigma_tensor) + + x_list = x_processed.to_data_list() + repeated_x_list = [ + graph.clone() for graph in x_list for _ in range(self.N_measurements_hidden * self.N_measurements) + ] + x_target = torch_geometric.data.Batch.from_data_list(repeated_x_list).to(x_processed.pos.device) + + if self.mean_center: + y = mean_center(y) + y = self._mean_center_hidden_states(y) + + if align_noisy_input: + y = self._align_A_to_B_batched_with_hidden_states(y, x_target) + + return y, x_target + + def conditioner_default(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + conditioned_structures = [y.pos] # Return complete list starting with current position + return conditioned_structures + + def conditioner(self, y: torch_geometric.data.Batch) -> list[torch.Tensor]: + if self.conditioning_module is None: + return self.conditioner_default(y) + elif callable(self.conditioning_module): + return self.conditioning_module(y) + else: + raise ValueError("Conditioner must be a callable or None") + + def add_noise(self, x: torch_geometric.data.Batch, sigma: float | torch.Tensor) -> torch_geometric.data.Batch: + # pos [B, ...] + sigma = unsqueeze_trailing(sigma, x.pos.ndim) + + y = x.clone() + if self.add_fixed_ones: + noise = torch.ones_like(x.pos) + if hasattr(x, "hidden_state") and x.hidden_state is not None: + hidden_noise = [torch.randn_like(x.hidden_state[i]) for i in range(len(x.hidden_state))] + else: + hidden_noise = [] + elif self.add_fixed_noise: + torch.manual_seed(0) + num_batches = x.batch.max().item() + 1 + if len(x.pos.shape) == 2: + num_nodes_per_batch = x.pos.shape[0] // num_batches + noise = torch.randn_like(x.pos[:num_nodes_per_batch]).repeat(num_batches, 1) + if hasattr(x, "hidden_state") and x.hidden_state is not None: + hidden_noise = [ + torch.randn_like(x.hidden_state[i][:num_nodes_per_batch]).repeat(num_batches, 1) + for i in range(len(x.hidden_state)) + ] + else: + hidden_noise = [] + if len(x.pos.shape) == 3: + num_nodes_per_batch = x.pos.shape[1] + noise = torch.randn_like(x.pos[0]).repeat(num_batches, 1, 1) + if hasattr(x, "hidden_state") and x.hidden_state is not None: + hidden_noise = [ + torch.randn_like(x.hidden_state[i][0]).repeat(num_batches, 1, 1) + for i in range(len(x.hidden_state)) + ] + else: + hidden_noise = [] + else: + noise = torch.randn_like(x.pos) + if hasattr(x, "hidden_state") and x.hidden_state is not None: + hidden_noise = [torch.randn_like(x.hidden_state[i]) for i in range(len(x.hidden_state))] + else: + hidden_noise = [] + y.pos = x.pos + sigma * noise + if hasattr(y, "hidden_state") and y.hidden_state is not None and hidden_noise: + for i in range(len(y.hidden_state)): + y.hidden_state[i] = x.hidden_state[i] + sigma * hidden_noise[i] + if torch.rand(()) < self.mirror_augmentation_rate: + y.pos = -y.pos + return y + + def add_noise_hiddens( + self, + x: torch_geometric.data.Batch, + N_measurements_hidden: int, + N_measurements: int, + sigma: float | torch.Tensor, + ) -> torch_geometric.data.Batch: + """ + Makes N_measurements_hidden number of noisy copies of the hidden states of x + and then for every noisy copy, makes N_measurements number of noisy copies of the positions of x. + + Args: + x (Batch): A torch_geometric Batch object. Must have `pos` and `hidden_state` attributes. + `hidden_state` is expected to be a list of tensors. + N_measurements_hidden (int): Number of noisy copies of hidden states. + N_measurements (int): Number of noisy copies of positions for each noisy hidden state. + sigma (float or torch.Tensor): The standard deviation of the Gaussian noise to add. + + Returns: + Batch: A new Batch object containing all the noisy copies. + """ + x_list = x.to_data_list() + noisy_y_list = [] + + for graph in x_list: + for _ in range(N_measurements_hidden): + # Create a noisy version of the hidden state + noisy_hidden_state = [] + if hasattr(graph, "hidden_state") and graph.hidden_state is not None: + for hs_tensor in graph.hidden_state: + noise = torch.randn_like(hs_tensor) * sigma + noisy_hidden_state.append(hs_tensor + noise) + + for _ in range(N_measurements): + noisy_graph = graph.clone() + + # Add noise to positions + pos_noise = torch.randn_like(graph.pos) * sigma + noisy_graph.pos = graph.pos + pos_noise + + # Assign the noisy hidden state + if hasattr(graph, "hidden_state") and graph.hidden_state is not None: + noisy_graph.hidden_state = [hs.clone() for hs in noisy_hidden_state] + + noisy_y_list.append(noisy_graph) + + return torch_geometric.data.Batch.from_data_list(noisy_y_list) + + def score(self, y: torch_geometric.data.Batch, sigma: float | torch.Tensor) -> torch_geometric.data.Batch: + """Compute the score function.""" + sigma = torch.as_tensor(sigma).to(y.pos) + return (self.xhat(y, sigma).pos - y.pos) / (unsqueeze_trailing(sigma, y.pos.ndim - 1) ** 2) + + def normalization_factors(self, sigma: float, D: int = 3) -> tuple[float, float, float, float]: + """Normalization factors for the input and output.""" + sigma = torch.as_tensor(sigma) + + if self.normalization_type is None: + return 1.0, 0.0, 1.0, sigma + + if self.normalization_type == "EDM": + c_skip = (self.sigma_data**2) / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2) + c_in = 1 / torch.sqrt(sigma**2 + self.sigma_data**2) + c_noise = torch.log(sigma / self.sigma_data) * 0.25 + return c_in, c_skip, c_out, c_noise + + if self.normalization_type == "JAMUN": + A = torch.as_tensor(self.average_squared_distance) + B = torch.as_tensor(2 * D * sigma**2) + + c_in = 1.0 / torch.sqrt(A + B) + c_skip = A / (A + B) + c_out = torch.sqrt((A * B) / (A + B)) + c_noise = torch.log(sigma) / 4 + return c_in, c_skip, c_out, c_noise + + raise ValueError(f"Unknown normalization type: {self.normalization_type}") + + def loss_weight(self, sigma: float, D: int = 3) -> float: + """Loss weight for this graph.""" + _, _, c_out, _ = self.normalization_factors(sigma, D) + return 1 / (c_out**2) + + def effective_radial_cutoff(self, sigma: float | torch.Tensor) -> torch.Tensor: + """Compute the effective radial cutoff for the noise level.""" + return torch.sqrt((self.max_radius**2) + 6 * (sigma**2)) + + def add_edges(self, y: torch_geometric.data.Batch, radial_cutoff: float) -> torch_geometric.data.Batch: + """Add edges to the graph based on the effective radial cutoff.""" + if y.get("edge_index") is not None: + return y + + y = y.clone() + if "batch" in y: + batch = y["batch"] + else: + batch = torch.zeros(y.num_nodes, dtype=torch.long, device=self.device) + + with torch.cuda.nvtx.range("radial_graph"): + radial_edge_index = e3tools.radius_graph(y.pos, radial_cutoff, batch) + + with torch.cuda.nvtx.range("concatenate_edges"): + edge_index = torch.cat((radial_edge_index, y.bonded_edge_index), dim=-1) + if y.bonded_edge_index.numel() == 0: + bond_mask = torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.pos.device) + else: + bond_mask = torch.cat( + ( + torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.pos.device), + torch.ones(y.bonded_edge_index.shape[1], dtype=torch.long, device=y.pos.device), + ), + dim=0, + ) + + y.edge_index = edge_index + y.bond_mask = bond_mask + return y + + def xhat_normalized(self, y: torch_geometric.data.Batch, sigma: float | torch.Tensor) -> torch_geometric.data.Batch: + """Compute the denoised prediction using the normalization factors from JAMUN.""" + sigma = torch.as_tensor(sigma).to(y.pos) + D = y.pos.shape[-1] + + # Compute the normalization factors. + with torch.cuda.nvtx.range("normalization_factors"): + c_in, c_skip, c_out, c_noise = self.normalization_factors(sigma, D) + radial_cutoff = self.effective_radial_cutoff(sigma) / c_in + + # Adjust dimensions. + c_in = unsqueeze_trailing(c_in, y.pos.ndim - 1) + c_skip = unsqueeze_trailing(c_skip, y.pos.ndim - 1) + c_out = unsqueeze_trailing(c_out, y.pos.ndim - 1) + c_noise = c_noise.unsqueeze(0) + + # Add edges to the graph. + with torch.cuda.nvtx.range("add_edges"): + y = self.add_edges(y, radial_cutoff) + + with torch.cuda.nvtx.range("scale_y"): + y_scaled = y.clone() + y_scaled.pos = y.pos * c_in + if hasattr(y, "hidden_state") and y.hidden_state is not None: + scaled_hidden_state = [] + for positions in y.hidden_state: + scaled_hidden_state.append(positions * c_in) + y_scaled.hidden_state = scaled_hidden_state + + with torch.cuda.nvtx.range("clone_y"): + xhat = y.clone() + + with torch.cuda.nvtx.range("conditioning"): + conditioned_structures = self.conditioner(y_scaled) + + with torch.cuda.nvtx.range("g"): + g_pred = self.g( + torch.cat([*conditioned_structures], dim=-1), + topology=y_scaled, + c_noise=c_noise, + effective_radial_cutoff=radial_cutoff, + ) + + xhat.pos = c_skip * y.pos + c_out * g_pred + if hasattr(y, "hidden_state") and y.hidden_state is not None: + xhat.hidden_state = [y.pos, *y.hidden_state[:-1]] # the hidden state updates! + return xhat + + def xhat(self, y: torch.Tensor, sigma: float | torch.Tensor): + """Compute the denoised prediction.""" + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_y"): + y = mean_center(y) + y = self._mean_center_hidden_states(y) + + with torch.cuda.nvtx.range("xhat_normalized"): + xhat = self.xhat_normalized(y, sigma) + + # Mean center the prediction. + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_xhat"): + xhat = mean_center(xhat) + xhat = self._mean_center_hidden_states(xhat) + + return xhat + + def noise_and_denoise( + self, + x: torch_geometric.data.Batch, + sigma: float | torch.Tensor, + align_noisy_input: bool, + ) -> tuple[torch_geometric.data.Batch, torch_geometric.data.Batch, torch_geometric.data.Batch]: + """ + Add noise to the input and denoise it. + Returns the target for the loss, the prediction, and the noisy input. + """ + with torch.no_grad(): + if self.mean_center: + # Operate on a clone to avoid side effects on the original batch object. + x_processed = mean_center(x) + x_processed = self._mean_center_hidden_states(x_processed) + else: + x_processed = x + + sigma = torch.as_tensor(sigma).to(x_processed.pos) + + if self.multimeasurement: + with torch.cuda.nvtx.range("add_noise_hiddens"): + y = self.add_noise_hiddens(x_processed, self.N_measurements_hidden, self.N_measurements, sigma) + + # Repeat x_processed to match y's batch size for alignment and loss calculation. + x_list = x_processed.to_data_list() + repeated_x_list = [ + graph.clone() for graph in x_list for _ in range(self.N_measurements_hidden * self.N_measurements) + ] + x_target = torch_geometric.data.Batch.from_data_list(repeated_x_list).to(x_processed.pos.device) + + else: + with torch.cuda.nvtx.range("add_noise"): + y = self.add_noise(x_processed, sigma) + x_target = x_processed.clone() + + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_y"): + y = mean_center(y) + y = self._mean_center_hidden_states(y) + + # Aligning each batch. + if align_noisy_input: + with torch.cuda.nvtx.range("align_A_to_B_batched"): + y = self._align_A_to_B_batched_with_hidden_states(y, x_target) + + with torch.cuda.nvtx.range("xhat"): + xhat = self.xhat(y, sigma) + + return x_target, xhat, y + + def compute_loss( + self, + x: torch_geometric.data.Batch, + xhat: torch.Tensor, + sigma: float | torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the loss.""" + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_x"): + x = mean_center(x) + x = self._mean_center_hidden_states(x) + + D = xhat.pos.shape[-1] + + # Compute the raw loss. + with torch.cuda.nvtx.range("raw_coordinate_loss"): + raw_coordinate_loss = (xhat.pos - x.pos).pow(2).sum(dim=-1) + + # Take the mean over each graph. + with torch.cuda.nvtx.range("mean_over_graphs"): + mse = scatter(raw_coordinate_loss, x.batch, dim=0, dim_size=x.num_graphs, reduce="mean") + + # Compute the scaled RMSD. + with torch.cuda.nvtx.range("scaled_rmsd"): + rmsd = torch.sqrt(mse) + scaled_rmsd = rmsd / (sigma * np.sqrt(D)) + + # Account for the loss weight across graphs and noise levels. + with torch.cuda.nvtx.range("loss_weight"): + loss = mse * x.loss_weight + loss = loss * self.loss_weight(sigma, D) + + return loss, { + "mse": mse, + "rmsd": rmsd, + "scaled_rmsd": scaled_rmsd, + } + + def noise_and_compute_loss( + self, + x: torch_geometric.data.Batch, + sigma: float | torch.Tensor, + align_noisy_input: bool, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Add noise to the input and compute the loss.""" + x_target, xhat, _ = self.noise_and_denoise(x, sigma, align_noisy_input=align_noisy_input) + return self.compute_loss(x_target, xhat, sigma) + + def _automatic_step(self, batch: torch_geometric.data.Batch, stage: str): + """The standard step for automatic optimization.""" + align_noisy_input = ( + self.align_noisy_input_during_training if stage == "train" else self.align_noisy_input_during_evaluation + ) + sigma = self.sigma_distribution.sample().to(self.device) + + loss, aux = self.noise_and_compute_loss( + batch, + sigma, + align_noisy_input=align_noisy_input, + ) # check if the loss is nan. if nan then save the model, and the batch and see what went on. + # if torch.isnan(loss.sum()): + # print(f"Loss is nan at step {self.global_step}") + # print(f"Batch: {batch}") + # print(f"Sigma: {sigma}") + # print(f"Align noisy input: {align_noisy_input}") + # print(f"Loss: {loss}") + # print(f"Aux: {aux}") + # # Create debug directory if it doesn't exist + # debug_dir = f"/homefs/home/sules/jamun/debug_nan_loss_step_{self.global_step}" + # os.makedirs(debug_dir, exist_ok=True) + # + # # Save model checkpoint + # checkpoint_path = os.path.join(debug_dir, "model_nan_loss.ckpt") + # self.trainer.save_checkpoint(checkpoint_path) + # print(f"Model saved to {checkpoint_path}") + # + # torch.save(batch, debug_dir + "/batch_nan_loss.pt") + # + # # Optionally raise an exception to stop training + # raise RuntimeError(f"NaN loss detected at step {self.global_step}. Debug files saved to {debug_dir}") + + # Average the loss and other metrics over all graphs. + with torch.cuda.nvtx.range("mean_over_graphs"): + aux["loss"] = loss + for key in aux: + aux[key] = aux[key].mean() + if stage == "train": + self.log( + f"train/{key}", + aux[key], + prog_bar=False, + batch_size=batch.num_graphs, + sync_dist=False, + ) + elif stage == "val": + self.log( + f"val/{key}", + aux[key], + prog_bar=(key == "scaled_rmsd"), + batch_size=batch.num_graphs, + sync_dist=True, + ) + else: + continue + + return { + "sigma": sigma, + **aux, + } + + def _manual_step(self, batch: torch_geometric.data.Batch, stage: str): + """A shared step for training and validation with manual optimization.""" + sigma = self.sigma_distribution.sample().to(self.device) + align_noisy_input = ( + self.align_noisy_input_during_training if stage == "train" else self.align_noisy_input_during_evaluation + ) + + y, x_target = self._prepare_noisy_batch(batch, sigma, align_noisy_input) + + y_list = y.to_data_list() + x_target_list = x_target.to_data_list() + + chunk_size = self.max_graphs_per_batch + num_chunks = (len(y_list) + chunk_size - 1) // chunk_size + + all_aux = [] + opt = self.optimizers() if stage == "train" else None + + # print(f"Processing {num_chunks} chunks of size {chunk_size} for {stage}...") + for i in range(num_chunks): + start_index = i * chunk_size + end_index = min(start_index + chunk_size, len(y_list)) + + y_micro_batch_list = y_list[start_index:end_index] + x_target_micro_batch_list = x_target_list[start_index:end_index] + + if not y_micro_batch_list: + continue + + y_micro_batch = torch_geometric.data.Batch.from_data_list(y_micro_batch_list) + x_target_micro_batch = torch_geometric.data.Batch.from_data_list(x_target_micro_batch_list) + + xhat_micro_batch = self.xhat(y_micro_batch, sigma) + + loss, aux = self.compute_loss(x_target_micro_batch, xhat_micro_batch, sigma) + + with torch.cuda.nvtx.range("mean_over_graphs"): + aux["loss"] = loss + for key in aux: + aux[key] = aux[key].mean() + if stage == "train": + # Scale loss by number of chunks to match automatic optimization gradients + scaled_loss = aux["loss"] / num_chunks + opt.zero_grad() + self.manual_backward(scaled_loss) + opt.step() + all_aux.append(aux) + + avg_aux = {} + with torch.no_grad(): + if all_aux: + for key in all_aux[0]: + avg_aux[key] = torch.tensor([d[key] for d in all_aux]).mean() + log_opts = { + "prog_bar": (stage == "val" and "scaled_rmsd" in avg_aux), + "batch_size": len(y_list), + "sync_dist": (stage == "val"), # Only sync for validation + } + + # Ensure training metrics are always logged + if stage == "train": + log_opts["on_step"] = True + log_opts["on_epoch"] = True + + for key, value in avg_aux.items(): + self.log(f"{stage}/{key}", value, **log_opts) + + return {"sigma": sigma, **avg_aux} + + def training_step(self, batch: torch_geometric.data.Batch, batch_idx: int): + """Called during training.""" + if self.automatic_optimization: + return self._automatic_step(batch, "train") + else: + # print(f"Manual optimization enabled for training step {batch_idx}.") + return self._manual_step(batch, "train") + + def validation_step(self, batch: torch_geometric.data.Batch, batch_idx: int): + """Called during validation.""" + if self.automatic_optimization: + return self._automatic_step(batch, "val") + else: + return self._manual_step(batch, "val") + + def configure_optimizers(self): + """Set up the optimizer and learning rate scheduler.""" + optimizer = self.optim_factory(params=self.parameters()) + + out = {"optimizer": optimizer} + if self.lr_scheduler_config: + scheduler = self.lr_scheduler_config.pop("scheduler") + out["lr_scheduler"] = { + "scheduler": scheduler(optimizer), + **self.lr_scheduler_config, + } + + return out diff --git a/src/jamun/model/denoiser_spiked.py b/src/jamun/model/denoiser_spiked.py new file mode 100644 index 0000000..401b20f --- /dev/null +++ b/src/jamun/model/denoiser_spiked.py @@ -0,0 +1,483 @@ +import logging +from collections.abc import Callable + +import e3tools +import lightning.pytorch as pl +import numpy as np +import torch +import torch_geometric +from e3tools import scatter + +from jamun.utils import mean_center, unsqueeze_trailing +from jamun.utils.align import kabsch_algorithm + + +class DenoiserSpiked(pl.LightningModule): + """The main denoiser model with conditional architecture that includes clean sample conditioning.""" + + def __init__( + self, + arch: Callable[..., torch.nn.Module], + optim: Callable[..., torch.optim.Optimizer], + sigma_distribution: torch.distributions.Distribution, + max_radius: float, + average_squared_distance: float, + add_fixed_noise: bool, + add_fixed_ones: bool, + align_noisy_input_during_training: bool, + align_noisy_input_during_evaluation: bool, + mean_center: bool, + mirror_augmentation_rate: float, + bond_loss_coefficient: float = 1.0, + normalization_type: str | None = "JAMUN", + sigma_data: float | None = None, # Only used if normalization_type is "EDM" + lr_scheduler_config: dict | None = None, + use_torch_compile: bool = True, + torch_compile_kwargs: dict | None = None, + conditioner: Callable[..., list[torch.Tensor]] = None, + ): + super().__init__() + self.save_hyperparameters(logger=False) + + self.g = arch() + if use_torch_compile: + if torch_compile_kwargs is None: + torch_compile_kwargs = {} + + self.g = torch.compile(self.g, **torch_compile_kwargs) + + py_logger = logging.getLogger("jamun") + py_logger.info(self.g) + + self.optim_factory = optim + self.lr_scheduler_config = lr_scheduler_config + self.sigma_distribution = sigma_distribution + self.max_radius = max_radius + + self.add_fixed_noise = add_fixed_noise + self.add_fixed_ones = add_fixed_ones + if self.add_fixed_noise and self.add_fixed_ones: + raise ValueError("Can't add fixed noise and fixed ones at the same time") + if self.add_fixed_noise: + py_logger.info("Adding fixed noise") + if self.add_fixed_ones: + py_logger.info("Adding fixed ones") + + self.average_squared_distance = average_squared_distance + py_logger.info(f"Average squared distance = {self.average_squared_distance}") + + self.align_noisy_input_during_training = align_noisy_input_during_training + if self.align_noisy_input_during_training: + py_logger.info("Aligning noisy input during training.") + else: + py_logger.info("Not aligning noisy input during training.") + + self.align_noisy_input_during_evaluation = align_noisy_input_during_evaluation + if self.align_noisy_input_during_evaluation: + py_logger.info("Aligning noisy input during evaluation.") + else: + py_logger.info("Not aligning noisy input during evaluation.") + + self.mean_center = mean_center + if self.mean_center: + py_logger.info("Mean centering input and output.") + else: + py_logger.info("Not mean centering input and output.") + + self.mirror_augmentation_rate = mirror_augmentation_rate + py_logger.info(f"Mirror augmentation rate: {self.mirror_augmentation_rate}") + + self.normalization_type = normalization_type + if self.normalization_type is not None: + py_logger.info(f"Normalization type: {self.normalization_type}") + else: + py_logger.info("No normalization") + + self.sigma_data = sigma_data + if self.normalization_type == "EDM" and self.sigma_data is None: + raise ValueError("sigma_data must be provided when normalization_type is 'EDM'") + elif self.normalization_type != "EDM" and self.sigma_data is not None: + raise ValueError("sigma_data can only be used when normalization_type is 'EDM'") + + self.bond_loss_coefficient = bond_loss_coefficient + self.conditioning_module = conditioner + if self.conditioning_module is not None and not callable(self.conditioning_module): + raise ValueError("Conditioner must be a callable or None") + py_logger.info(f"Conditioner: {self.conditioning_module}") + + def on_before_optimizer_step(self, optimizer): + # Log gradients and parameters. + for name, param in self.named_parameters(): + self.log(f"parameter_norms/{name}", param.norm(), sync_dist=True) + if param.grad is not None: + self.log(f"gradient_norms/{name}", param.grad.norm(), sync_dist=True) + + def conditioner_default( + self, y: torch_geometric.data.Batch, x_clean: torch_geometric.data.Batch = None + ) -> list[torch.Tensor]: + conditioned_structures = [y.pos] # Return complete list starting with current position + if x_clean is not None: + conditioned_structures.append(x_clean.pos) # Add clean sample positions + return conditioned_structures + + def conditioner( + self, y: torch_geometric.data.Batch, x_clean: torch_geometric.data.Batch = None + ) -> list[torch.Tensor]: + if self.conditioning_module is None: + return self.conditioner_default(y, x_clean) + elif callable(self.conditioning_module): + return self.conditioning_module(y, x_clean) + else: + raise ValueError("Conditioner must be a callable or None") + + def _align_A_to_B_batched_with_hidden_states( + self, A: torch_geometric.data.Batch, B: torch_geometric.data.Batch + ) -> torch_geometric.data.Batch: + """Aligns each graph of A to the corresponding graph in B, including hidden states.""" + A_aligned = A.clone() + + # Align positions + A_aligned.pos = kabsch_algorithm(A.pos, B.pos, A.batch, A.num_graphs) + + # Align hidden states + if hasattr(A, "hidden_state") and A.hidden_state is not None: + A_aligned.hidden_state = [] + for i in range(len(A.hidden_state)): + A_aligned.hidden_state.append(kabsch_algorithm(A.hidden_state[i], B.pos, A.batch, A.num_graphs)) + return A_aligned + + def _mean_center_hidden_states(self, data: torch_geometric.data.Batch): + if hasattr(data, "hidden_state") and data.hidden_state is not None: + for i in range(len(data.hidden_state)): + mean = scatter(data.hidden_state[i], data.batch, dim=0, reduce="mean") + data.hidden_state[i] = data.hidden_state[i] - mean[data.batch] + return data + + def add_noise(self, x: torch_geometric.data.Batch, sigma: float | torch.Tensor) -> torch_geometric.data.Batch: + # pos [B, ...] + sigma = unsqueeze_trailing(sigma, x.pos.ndim) + + y = x.clone() + if self.add_fixed_ones: + noise = torch.ones_like(x.pos) + hidden_noise = [torch.randn_like(x.hidden_state[i]) for i in range(len(x.hidden_state))] + elif self.add_fixed_noise: + torch.manual_seed(0) + num_batches = x.batch.max().item() + 1 + if len(x.pos.shape) == 2: + num_nodes_per_batch = x.pos.shape[0] // num_batches + noise = torch.randn_like(x.pos[:num_nodes_per_batch]).repeat(num_batches, 1) + hidden_noise = [ + torch.randn_like(x.hidden_state[i][:num_nodes_per_batch]).repeat(num_batches, 1) + for i in range(len(x.hidden_state)) + ] + if len(x.pos.shape) == 3: + num_nodes_per_batch = x.pos.shape[1] + noise = torch.randn_like(x.pos[0]).repeat(num_batches, 1, 1) + hidden_noise = [ + torch.randn_like(x.hidden_state[i][0]).repeat(num_batches, 1, 1) for i in range(len(x.hidden_state)) + ] + else: + noise = torch.randn_like(x.pos) + hidden_noise = [torch.randn_like(x.hidden_state[i]) for i in range(len(x.hidden_state))] + y.pos = x.pos + sigma * noise + for i in range(len(y.hidden_state)): + y.hidden_state[i] = x.hidden_state[i] + sigma * hidden_noise[i] + if torch.rand(()) < self.mirror_augmentation_rate: + y.pos = -y.pos + return y + + def score( + self, y: torch_geometric.data.Batch, sigma: float | torch.Tensor, x_clean: torch_geometric.data.Batch + ) -> torch_geometric.data.Batch: + """Compute the score function.""" + sigma = torch.as_tensor(sigma).to(y.pos) + return (self.xhat(y, sigma, x_clean).pos - y.pos) / (unsqueeze_trailing(sigma, y.pos.ndim - 1) ** 2) + + def normalization_factors(self, sigma: float, D: int = 3) -> tuple[float, float, float, float]: + """Normalization factors for the input and output.""" + sigma = torch.as_tensor(sigma) + + if self.normalization_type is None: + return 1.0, 0.0, 1.0, sigma + + if self.normalization_type == "EDM": + c_skip = (self.sigma_data**2) / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2) + c_in = 1 / torch.sqrt(sigma**2 + self.sigma_data**2) + c_noise = torch.log(sigma / self.sigma_data) * 0.25 + return c_in, c_skip, c_out, c_noise + + if self.normalization_type == "JAMUN": + A = torch.as_tensor(self.average_squared_distance) + B = torch.as_tensor(2 * D * sigma**2) + + c_in = 1.0 / torch.sqrt(A + B) + c_skip = A / (A + B) + c_out = torch.sqrt((A * B) / (A + B)) + c_noise = torch.log(sigma) / 4 + return c_in, c_skip, c_out, c_noise + + raise ValueError(f"Unknown normalization type: {self.normalization_type}") + + def loss_weight(self, sigma: float, D: int = 3) -> float: + """Loss weight for this graph.""" + _, _, c_out, _ = self.normalization_factors(sigma, D) + return 1 / (c_out**2) + + def effective_radial_cutoff(self, sigma: float | torch.Tensor) -> torch.Tensor: + """Compute the effective radial cutoff for the noise level.""" + return torch.sqrt((self.max_radius**2) + 6 * (sigma**2)) + + def add_edges(self, y: torch_geometric.data.Batch, radial_cutoff: float) -> torch_geometric.data.Batch: + """Add edges to the graph based on the effective radial cutoff.""" + if y.get("edge_index") is not None: + return y + + y = y.clone() + if "batch" in y: + batch = y["batch"] + else: + batch = torch.zeros(y.num_nodes, dtype=torch.long, device=self.device) + + with torch.cuda.nvtx.range("radial_graph"): + radial_edge_index = e3tools.radius_graph(y.pos, radial_cutoff, batch) + + with torch.cuda.nvtx.range("concatenate_edges"): + edge_index = torch.cat((radial_edge_index, y.bonded_edge_index), dim=-1) + if y.bonded_edge_index.numel() == 0: + bond_mask = torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.pos.device) + else: + bond_mask = torch.cat( + ( + torch.zeros(radial_edge_index.shape[1], dtype=torch.long, device=y.pos.device), + torch.ones(y.bonded_edge_index.shape[1], dtype=torch.long, device=y.pos.device), + ), + dim=0, + ) + + y.edge_index = edge_index + y.bond_mask = bond_mask + return y + + def xhat_normalized( + self, y: torch_geometric.data.Batch, sigma: float | torch.Tensor, x_clean: torch_geometric.data.Batch + ) -> torch_geometric.data.Batch: + """Compute the denoised prediction using the normalization factors from JAMUN.""" + sigma = torch.as_tensor(sigma).to(y.pos) + D = y.pos.shape[-1] + + # Compute the normalization factors. + with torch.cuda.nvtx.range("normalization_factors"): + c_in, c_skip, c_out, c_noise = self.normalization_factors(sigma, D) + radial_cutoff = self.effective_radial_cutoff(sigma) / c_in + + # Adjust dimensions. + c_in = unsqueeze_trailing(c_in, y.pos.ndim - 1) + c_skip = unsqueeze_trailing(c_skip, y.pos.ndim - 1) + c_out = unsqueeze_trailing(c_out, y.pos.ndim - 1) + c_noise = c_noise.unsqueeze(0) + + # Add edges to the graph. + with torch.cuda.nvtx.range("add_edges"): + y = self.add_edges(y, radial_cutoff) + + with torch.cuda.nvtx.range("scale_y"): + y_scaled = y.clone() + y_scaled.pos = y.pos * c_in + # Manually copy hidden state + if hasattr(y, "hidden_state") and y.hidden_state is not None: + y_scaled.hidden_state = [] + for positions in y.hidden_state: + y_scaled.hidden_state.append(positions * c_in) + + # Keep clean sample unscaled + with torch.cuda.nvtx.range("clone_y"): + xhat = y.clone() + # Manually copy hidden state + if hasattr(y, "hidden_state") and y.hidden_state is not None: + xhat.hidden_state = [h.clone() for h in y.hidden_state] + + with torch.cuda.nvtx.range("conditioning"): + conditioned_structures = self.conditioner(y_scaled, x_clean) + # print(f"Conditioner is working, number of conditioned structures: {len(conditioned_structures)}") + with torch.cuda.nvtx.range("g"): + g_pred = self.g( + torch.cat([*conditioned_structures], dim=-1), + topology=y_scaled, + c_noise=c_noise, + effective_radial_cutoff=radial_cutoff, + ) + + xhat.pos = c_skip * y.pos + c_out * g_pred + if hasattr(y, "hidden_state") and y.hidden_state is not None: + xhat.hidden_state = [y.pos, *y.hidden_state[:-1]] + return xhat + + def xhat(self, y: torch.Tensor, sigma: float | torch.Tensor, x_clean: torch_geometric.data.Batch): + """Compute the denoised prediction.""" + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_y"): + y = mean_center(y) + y = self._mean_center_hidden_states(y) + with torch.cuda.nvtx.range("mean_center_x_clean"): + x_clean = mean_center(x_clean) + x_clean = self._mean_center_hidden_states(x_clean) + + with torch.cuda.nvtx.range("xhat_normalized"): + xhat = self.xhat_normalized(y, sigma, x_clean) + + # Mean center the prediction. + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_xhat"): + xhat = mean_center(xhat) + + return xhat + + def noise_and_denoise( + self, + x: torch_geometric.data.Batch, + sigma: float | torch.Tensor, + align_noisy_input: bool, + ) -> tuple[torch_geometric.data.Batch, torch_geometric.data.Batch, torch_geometric.data.Batch]: + """ + Add noise to the input and denoise it. + Returns the target for the loss, the prediction, and the noisy input. + """ + with torch.no_grad(): + if self.mean_center: + # Operate on a clone to avoid side effects on the original batch object. + x_processed = mean_center(x) + x_processed = self._mean_center_hidden_states(x_processed) + else: + x_processed = x + + sigma = torch.as_tensor(sigma).to(x_processed.pos) + + with torch.cuda.nvtx.range("add_noise"): + y = self.add_noise(x_processed, sigma) + x_target = x_processed.clone() + # Manually copy hidden state + if hasattr(x_processed, "hidden_state") and x_processed.hidden_state is not None: + x_target.hidden_state = [h.clone() for h in x_processed.hidden_state] + + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_y"): + y = mean_center(y) + y = self._mean_center_hidden_states(y) + + # Aligning each batch. + if align_noisy_input: + with torch.cuda.nvtx.range("align_A_to_B_batched"): + y = self._align_A_to_B_batched_with_hidden_states(y, x_target) + + # KEY CHANGE: Pass both noisy sample (y) AND clean sample (x_target) to xhat + with torch.cuda.nvtx.range("xhat"): + xhat = self.xhat(y, sigma, x_target) + + return x_target, xhat, y + + def compute_loss( + self, + x: torch_geometric.data.Batch, + xhat: torch.Tensor, + sigma: float | torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the loss.""" + if self.mean_center: + with torch.cuda.nvtx.range("mean_center_x"): + x = mean_center(x) + + D = xhat.pos.shape[-1] + + # Compute the raw loss. + with torch.cuda.nvtx.range("raw_coordinate_loss"): + raw_coordinate_loss = (xhat.pos - x.pos).pow(2).sum(dim=-1) + + # Take the mean over each graph. + with torch.cuda.nvtx.range("mean_over_graphs"): + mse = scatter(raw_coordinate_loss, x.batch, dim=0, dim_size=x.num_graphs, reduce="mean") + + # Compute the scaled RMSD. + with torch.cuda.nvtx.range("scaled_rmsd"): + rmsd = torch.sqrt(mse) + scaled_rmsd = rmsd / (sigma * np.sqrt(D)) + + # Account for the loss weight across graphs and noise levels. + with torch.cuda.nvtx.range("loss_weight"): + loss = mse * x.loss_weight + loss = loss * self.loss_weight(sigma, D) + + return loss, { + "mse": mse, + "rmsd": rmsd, + "scaled_rmsd": scaled_rmsd, + } + + def noise_and_compute_loss( + self, + x: torch_geometric.data.Batch, + sigma: float | torch.Tensor, + align_noisy_input: bool, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Add noise to the input and compute the loss.""" + x_target, xhat, _ = self.noise_and_denoise(x, sigma, align_noisy_input=align_noisy_input) + return self.compute_loss(x_target, xhat, sigma) + + def _automatic_step(self, batch: torch_geometric.data.Batch, stage: str): + """The standard step for automatic optimization.""" + align_noisy_input = ( + self.align_noisy_input_during_training if stage == "train" else self.align_noisy_input_during_evaluation + ) + sigma = self.sigma_distribution.sample().to(self.device) + + loss, aux = self.noise_and_compute_loss( + batch, + sigma, + align_noisy_input=align_noisy_input, + ) + + # Average the loss and other metrics over all graphs. + with torch.cuda.nvtx.range("mean_over_graphs"): + aux["loss"] = loss + for key in aux: + aux[key] = aux[key].mean() + if stage == "train": + self.log(f"train/{key}", aux[key], prog_bar=False, batch_size=batch.num_graphs, sync_dist=False) + elif stage == "val": + self.log( + f"val/{key}", + aux[key], + prog_bar=(key == "scaled_rmsd"), + batch_size=batch.num_graphs, + sync_dist=True, + ) + else: + continue + + return { + "sigma": sigma, + **aux, + } + + def training_step(self, batch: torch_geometric.data.Batch, batch_idx: int): + """Called during training.""" + return self._automatic_step(batch, "train") + + def validation_step(self, batch: torch_geometric.data.Batch, batch_idx: int): + """Called during validation.""" + self._automatic_step(batch, "val") + + def configure_optimizers(self): + """Set up the optimizer and learning rate scheduler.""" + optimizer = self.optim_factory(params=self.parameters()) + + out = {"optimizer": optimizer} + if self.lr_scheduler_config: + scheduler = self.lr_scheduler_config.pop("scheduler") + out["lr_scheduler"] = { + "scheduler": scheduler(optimizer), + **self.lr_scheduler_config, + } + + return out diff --git a/src/jamun/model/noise_test.py b/src/jamun/model/noise_test.py new file mode 100644 index 0000000..1eae2c4 --- /dev/null +++ b/src/jamun/model/noise_test.py @@ -0,0 +1,106 @@ +import torch +from torch_geometric.data import Batch, Data + + +def add_noise_hiddens(x: Batch, N_measurements_hidden: int, N_measurements: int, sigma: float) -> Batch: + """ + Makes N_measurements_hidden number of noisy copies of the hidden states of x + and then for every noisy copy, makes N_measurements number of noisy copies of the positions of x. + + Args: + x (Batch): A torch_geometric Batch object. Must have `pos` and `hidden_state` attributes. + `hidden_state` is expected to be a list of tensors. + N_measurements_hidden (int): Number of noisy copies of hidden states. + N_measurements (int): Number of noisy copies of positions for each noisy hidden state. + sigma (float): The standard deviation of the Gaussian noise to add. + + Returns: + Batch: A new Batch object containing all the noisy copies. + """ + x_list = x.to_data_list() + noisy_y_list = [] + + for graph in x_list: + for _ in range(N_measurements_hidden): + # Create a noisy version of the hidden state + noisy_hidden_state = [] + if hasattr(graph, "hidden_state") and graph.hidden_state is not None: + for hs_tensor in graph.hidden_state: + noise = torch.randn_like(hs_tensor) * sigma + noisy_hidden_state.append(hs_tensor + noise) + + for _ in range(N_measurements): + noisy_graph = graph.clone() + + # Add noise to positions + pos_noise = torch.randn_like(graph.pos) * sigma + noisy_graph.pos = graph.pos + pos_noise + + # Assign the noisy hidden state + if hasattr(graph, "hidden_state") and graph.hidden_state is not None: + noisy_graph.hidden_state = [hs.clone() for hs in noisy_hidden_state] + + noisy_y_list.append(noisy_graph) + + return Batch.from_data_list(noisy_y_list) + + +# --- Testing Script --- +def run_test(): + print("Running test for add_noise_hiddens...") + + # 1. Create dummy data + num_nodes = 5 + # single data object + data1 = Data(pos=torch.randn(num_nodes, 3), hidden_state=[torch.randn(num_nodes, 4), torch.randn(num_nodes, 8)]) + # another data object + data2 = Data( + pos=torch.randn(num_nodes + 2, 3), hidden_state=[torch.randn(num_nodes + 2, 4), torch.randn(num_nodes + 2, 8)] + ) + + original_batch = Batch.from_data_list([data1, data2]) + + # 2. Set parameters + N_measurements_hidden = 2 + N_measurements = 3 + sigma = 0.1 + + # 3. Call the function + noisy_batch = add_noise_hiddens(original_batch, N_measurements_hidden, N_measurements, sigma) + + # 4. Assertions + # Check total number of graphs + expected_num_graphs = original_batch.num_graphs * N_measurements_hidden * N_measurements + assert noisy_batch.num_graphs == expected_num_graphs, ( + f"Expected {expected_num_graphs} graphs, but got {noisy_batch.num_graphs}" + ) + print(f"Correct number of graphs in output batch: {noisy_batch.num_graphs}") + + noisy_graphs = noisy_batch.to_data_list() + + # Check that noise was added + assert not torch.allclose(noisy_graphs[0].pos, data1.pos) + assert not torch.allclose(noisy_graphs[0].hidden_state[0], data1.hidden_state[0]) + print("Noise was added to pos and hidden_state.") + + # Check hidden state logic + # The first N_measurements graphs (for the first original graph) should have the same hidden state + first_hidden_state_set = noisy_graphs[0].hidden_state + for i in range(1, N_measurements): + assert torch.allclose(noisy_graphs[i].hidden_state[0], first_hidden_state_set[0]) + assert torch.allclose(noisy_graphs[i].hidden_state[1], first_hidden_state_set[1]) + + # But their positions should be different + assert not torch.allclose(noisy_graphs[0].pos, noisy_graphs[1].pos) + + # The (N_measurements+1)-th graph should have a different hidden state (from the second hidden measurement) + next_hidden_state_set = noisy_graphs[N_measurements].hidden_state + assert not torch.allclose(next_hidden_state_set[0], first_hidden_state_set[0]) + + print("Hidden state noise logic seems correct.") + + print("Test passed!") + + +if __name__ == "__main__": + run_test() diff --git a/src/jamun/model/pooling.py b/src/jamun/model/pooling.py new file mode 100644 index 0000000..831659c --- /dev/null +++ b/src/jamun/model/pooling.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +Lightning modules for converting node attributes between spatial and temporal representations. +""" + +import pytorch_lightning as pl +import torch +from e3tools.nn import LayerNorm + + +class SpatialToTemporalNodeAttr(pl.LightningModule): + """ + Lightning module to transfer node attributes from spatial nodes to temporal nodes + by repeating first temporal feature. + """ + + def __init__(self, irreps_out): + super().__init__() + self.irreps_out = irreps_out + self.layer_norm = LayerNorm(irreps_out) + + def forward(self, spatial_node_attr_temporal, temporal_batch): + """ + Transfer node attributes from spatial nodes to temporal nodes by repeating first temporal feature. + Takes the first temporal feature (t=0) and repeats it T times for each spatial node. + + Args: + spatial_node_attr_temporal (torch.Tensor): Node attributes [N_spatial, T, attr_dim] + temporal_batch (torch_geometric.data.Batch): Batch of temporal graphs + + Returns: + torch.Tensor: Node attributes for temporal nodes [N_temporal, attr_dim] + """ + num_spatial_nodes, temporal_length, attr_dim = spatial_node_attr_temporal.shape + num_temporal_graphs = temporal_batch.num_graphs + + # Verify consistency + assert num_spatial_nodes == num_temporal_graphs, ( + f"Mismatch: {num_spatial_nodes} spatial nodes vs {num_temporal_graphs} temporal graphs" + ) + + # Verify temporal length consistency + expected_temporal_nodes = temporal_batch.pos.shape[0] + expected_total_nodes = num_spatial_nodes * temporal_length + assert expected_total_nodes == expected_temporal_nodes, ( + f"Temporal length mismatch: {expected_total_nodes} vs {expected_temporal_nodes}" + ) + + # Extract first temporal feature (t=0) and repeat it T times for each spatial node + first_temporal_features = spatial_node_attr_temporal[:, 0, :] # [N, attr_dim] + + # Repeat each spatial node's first temporal feature T times + temporal_node_attr = first_temporal_features.repeat_interleave(temporal_length, dim=0) # [N*T, attr_dim] + + # Verify the output shape matches the temporal batch + assert temporal_node_attr.shape[0] == expected_temporal_nodes, ( + f"Output shape mismatch: {temporal_node_attr.shape[0]} vs expected {expected_temporal_nodes}" + ) + + # Apply layer normalization before returning + temporal_node_attr = self.layer_norm(temporal_node_attr) + + return temporal_node_attr + + +class TemporalToSpatialNodeAttr(pl.LightningModule): + """ + Lightning module to convert temporal node attributes back to spatial node attributes. + Takes the first temporal node attribute from each temporal graph. + """ + + def __init__(self, irreps_out): + super().__init__() + self.irreps_out = irreps_out + self.layer_norm = LayerNorm(irreps_out) + + def forward(self, temporal_node_attr, temporal_batch): + """ + Convert temporal node attributes back to spatial node attributes. + Takes the first temporal node attribute from each temporal graph. + + Args: + temporal_node_attr (torch.Tensor): Node attributes for temporal nodes [N_temporal, attr_dim] + temporal_batch (torch_geometric.data.Batch): Batch of temporal graphs + + Returns: + torch.Tensor: Node attributes for spatial nodes [N_spatial, attr_dim] + """ + num_temporal_graphs = temporal_batch.num_graphs + attr_dim = temporal_node_attr.shape[1] + + # Extract the first node attribute from each temporal graph + spatial_node_attr = [] + + for graph_idx in range(num_temporal_graphs): + # Get the node range for this temporal graph + start_idx = temporal_batch.ptr[graph_idx] + + # The 0th node of each temporal graph is at the start of its range + first_node_attr = temporal_node_attr[start_idx] + spatial_node_attr.append(first_node_attr) + + # Stack to create spatial node attribute tensor + spatial_node_attr = torch.stack(spatial_node_attr) + + # Verify output shape + assert spatial_node_attr.shape == (num_temporal_graphs, attr_dim), ( + f"Output shape mismatch: {spatial_node_attr.shape} vs expected ({num_temporal_graphs}, {attr_dim})" + ) + + # Apply layer normalization before returning + spatial_node_attr = self.layer_norm(spatial_node_attr) + + return spatial_node_attr + + +class TemporalToSpatialNodeAttrMean(pl.LightningModule): + """ + Lightning module to convert temporal node attributes back to spatial node attributes by averaging. + Takes the mean of all temporal node attributes for each temporal graph. + """ + + def __init__(self, irreps_out): + super().__init__() + self.irreps_out = irreps_out + self.layer_norm = LayerNorm(irreps_out) + + def forward(self, temporal_node_attr, temporal_batch): + """ + Convert temporal node attributes back to spatial node attributes by averaging. + Takes the mean of all temporal node attributes for each temporal graph. + + Args: + temporal_node_attr (torch.Tensor): Node attributes for temporal nodes [N_temporal, attr_dim] + temporal_batch (torch_geometric.data.Batch): Batch of temporal graphs + + Returns: + torch.Tensor: Node attributes for spatial nodes [N_spatial, attr_dim] + """ + num_temporal_graphs = temporal_batch.num_graphs + attr_dim = temporal_node_attr.shape[1] + + # Extract the mean node attributes from each temporal graph + spatial_node_attr = [] + + for graph_idx in range(num_temporal_graphs): + # Get the node range for this temporal graph + start_idx = temporal_batch.ptr[graph_idx] + end_idx = ( + temporal_batch.ptr[graph_idx + 1] + if graph_idx + 1 < len(temporal_batch.ptr) + else len(temporal_node_attr) + ) + + # Take the mean of all temporal nodes for this spatial node + temporal_nodes_attr = temporal_node_attr[start_idx:end_idx] # [temporal_length, attr_dim] + mean_node_attr = temporal_nodes_attr.mean(dim=0) # [attr_dim] + spatial_node_attr.append(mean_node_attr) + + # Stack to create spatial node attribute tensor + spatial_node_attr = torch.stack(spatial_node_attr) + + # Verify output shape + assert spatial_node_attr.shape == (num_temporal_graphs, attr_dim), ( + f"Output shape mismatch: {spatial_node_attr.shape} vs expected ({num_temporal_graphs}, {attr_dim})" + ) + + # Apply layer normalization before returning + spatial_node_attr = self.layer_norm(spatial_node_attr) + + return spatial_node_attr + + +class SpatialTemporalToTemporalNodeAttr(pl.LightningModule): + """ + Lightning module to convert spatial node attributes arranged temporally to temporal node attributes. + Converts from [N, T, features] to [NT, features] with correct temporal graph ordering. + """ + + def __init__(self, irreps_out): + super().__init__() + self.irreps_out = irreps_out + self.layer_norm = LayerNorm(irreps_out) + + def forward(self, spatial_node_attr_temporal, temporal_batch): + """ + Convert spatial node attributes arranged temporally to temporal node attributes. + Converts from [N, T, features] to [NT, features] with correct temporal graph ordering. + + Args: + spatial_node_attr_temporal (torch.Tensor): Node attributes [N_spatial, T, attr_dim] + temporal_batch (torch_geometric.data.Batch): Batch of temporal graphs for validation + + Returns: + torch.Tensor: Node attributes for temporal nodes [N_temporal, attr_dim] + """ + num_spatial_nodes, temporal_length, attr_dim = spatial_node_attr_temporal.shape + num_temporal_graphs = temporal_batch.num_graphs + + # Verify consistency with temporal batch + assert num_spatial_nodes == num_temporal_graphs, ( + f"Mismatch: {num_spatial_nodes} spatial nodes vs {num_temporal_graphs} temporal graphs" + ) + + # Verify temporal length consistency + expected_temporal_nodes = temporal_batch.pos.shape[0] + expected_total_nodes = num_spatial_nodes * temporal_length + assert expected_total_nodes == expected_temporal_nodes, ( + f"Temporal length mismatch: {expected_total_nodes} vs {expected_temporal_nodes}" + ) + + # Reshape to match temporal graph ordering: [N, T, features] -> [N*T, features] + # Temporal graph arranges nodes as: [node0_t0, node0_t1, ..., node0_tT-1, node1_t0, ...] + temporal_node_attr = spatial_node_attr_temporal.reshape(num_spatial_nodes * temporal_length, attr_dim) + + # Verify the output shape matches the temporal batch + assert temporal_node_attr.shape[0] == expected_temporal_nodes, ( + f"Output shape mismatch: {temporal_node_attr.shape[0]} vs expected {expected_temporal_nodes}" + ) + + # Apply layer normalization before returning + temporal_node_attr = self.layer_norm(temporal_node_attr) + + return temporal_node_attr + + +# Legacy function interfaces for backward compatibility +def spatial_to_temporal_node_attr(spatial_node_attr_temporal, temporal_batch, irreps_out): + """Legacy function interface for backward compatibility.""" + module = SpatialToTemporalNodeAttr(irreps_out) + return module(spatial_node_attr_temporal, temporal_batch) + + +def temporal_to_spatial_node_attr(temporal_node_attr, temporal_batch, irreps_out): + """Legacy function interface for backward compatibility.""" + module = TemporalToSpatialNodeAttr(irreps_out) + return module(temporal_node_attr, temporal_batch) + + +def temporal_to_spatial_node_attr_mean(temporal_node_attr, temporal_batch, irreps_out): + """Legacy function interface for backward compatibility.""" + module = TemporalToSpatialNodeAttrMean(irreps_out) + return module(temporal_node_attr, temporal_batch) + + +def spatial_temporal_to_temporal_node_attr(spatial_node_attr_temporal, temporal_batch, irreps_out): + """Legacy function interface for backward compatibility.""" + module = SpatialTemporalToTemporalNodeAttr(irreps_out) + return module(spatial_node_attr_temporal, temporal_batch) diff --git a/src/jamun/sampling/__init__.py b/src/jamun/sampling/__init__.py index 3f14de8..7e4059e 100644 --- a/src/jamun/sampling/__init__.py +++ b/src/jamun/sampling/__init__.py @@ -1,2 +1,2 @@ from . import diffusion, mcmc, walkjump -from ._sampler import Sampler +from ._sampler import Sampler, SamplerMemory diff --git a/src/jamun/sampling/_sampler.py b/src/jamun/sampling/_sampler.py index 4dc65a2..a803def 100644 --- a/src/jamun/sampling/_sampler.py +++ b/src/jamun/sampling/_sampler.py @@ -97,3 +97,57 @@ def sample( self.fabric.log("sampler/global_step", batch_idx) self.fabric.call("on_sample_end", sampler=self) + + +class SamplerMemory(Sampler): + """A sampler for molecular dynamics simulations that uses memory.""" + + def sample( + self, + model, + batch_sampler, + num_batches: int, + init_graphs: torch_geometric.data.Data, + continue_chain: bool = False, + ): + self.fabric.launch() + self.fabric.setup(model) + model.eval() + + init_graphs = init_graphs.to(self.fabric.device) + model_wrapped = utils.ModelSamplingWrapperMemory( + model=model, + init_graphs=init_graphs, + sigma=batch_sampler.sigma, + ) + + y_init = model_wrapped.sample_initial_noisy_positions() + y_hist_init = model_wrapped.sample_initial_noisy_history() + v_init = "gaussian" + + self.fabric.call("on_sample_start", sampler=self) + + batches = torch.arange(num_batches) + iterable = self.progbar_wrapper(batches, desc="Sampling", total=len(batches), leave=False) + + with torch.inference_mode(): + for batch_idx in iterable: + self.global_step = batch_idx + + out = batch_sampler.sample(model=model_wrapped, y_init=y_init, v_init=v_init, y_hist_init=y_hist_init) + samples = model_wrapped.unbatch_samples(out) + + # Start next chain from the end state of the previous chain? + if continue_chain: + y_init = out["y"] + v_init = out["v"] + y_hist_init = out["y_hist"] + else: + y_init = model_wrapped.sample_initial_noisy_positions() + y_hist_init = model_wrapped.sample_initial_noisy_history() + v_init = "gaussian" + + self.fabric.call("on_after_sample_batch", sample=samples, sampler=self) + self.fabric.log("sampler/global_step", batch_idx) + + self.fabric.call("on_sample_end", sampler=self) diff --git a/src/jamun/sampling/mcmc/__init__.py b/src/jamun/sampling/mcmc/__init__.py index b984d4a..70a59c5 100644 --- a/src/jamun/sampling/mcmc/__init__.py +++ b/src/jamun/sampling/mcmc/__init__.py @@ -1 +1 @@ -from ._splitting import ABOBA, BAOAB +from ._splitting import ABOBA, BAOAB, ABOBA_memory, BAOAB_memory diff --git a/src/jamun/sampling/mcmc/_splitting.py b/src/jamun/sampling/mcmc/_splitting.py index 2b02e3d..d53c252 100644 --- a/src/jamun/sampling/mcmc/_splitting.py +++ b/src/jamun/sampling/mcmc/_splitting.py @@ -5,7 +5,7 @@ import torch from torch import Tensor -from jamun.sampling.mcmc.functional import aboba, baoab +from jamun.sampling.mcmc.functional import aboba, aboba_memory, baoab, baoab_memory @dataclass @@ -56,3 +56,21 @@ def __post_init__(self): def __call__(self, y: torch.Tensor, score_fn: Callable, **kwargs): kwargs = dataclasses.asdict(self) | kwargs return baoab(y, score_fn, **kwargs) + + +@dataclass +class ABOBA_memory(ABOBA): + history_update_frequency: int = 1 + + def __call__(self, y: torch.Tensor, y_hist: list, score_fn: Callable, **kwargs): + kwargs = dataclasses.asdict(self) | kwargs + return aboba_memory(y=y, y_hist=y_hist, score_fn=score_fn, **kwargs) + + +@dataclass +class BAOAB_memory(BAOAB): + history_update_frequency: int = 1 + + def __call__(self, y: torch.Tensor, y_hist: list, score_fn: Callable, **kwargs): + kwargs = dataclasses.asdict(self) | kwargs + return baoab_memory(y=y, y_hist=y_hist, score_fn=score_fn, **kwargs) diff --git a/src/jamun/sampling/mcmc/functional/__init__.py b/src/jamun/sampling/mcmc/functional/__init__.py index 37ba462..0490fe6 100644 --- a/src/jamun/sampling/mcmc/functional/__init__.py +++ b/src/jamun/sampling/mcmc/functional/__init__.py @@ -1 +1 @@ -from ._splitting import aboba, baoab +from ._splitting import aboba, aboba_memory, baoab, baoab_memory diff --git a/src/jamun/sampling/mcmc/functional/_splitting.py b/src/jamun/sampling/mcmc/functional/_splitting.py index ca4a315..e5d126e 100644 --- a/src/jamun/sampling/mcmc/functional/_splitting.py +++ b/src/jamun/sampling/mcmc/functional/_splitting.py @@ -26,9 +26,9 @@ def initialize_velocity(v_init: str | torch.Tensor, y: torch.Tensor, u: float) - def create_score_fn(score_fn: Callable, inverse_temperature: float, score_fn_clip: float | None) -> Callable: """Create a score function that is clipped and scaled by the inverse temperature.""" - def score_fn_processed(y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def score_fn_processed(y: torch.Tensor, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: """Score function clipped and scaled by the inverse temperature.""" - orig_score = score_fn(y).to(dtype=y.dtype) + orig_score = score_fn(y, *args, **kwargs).to(dtype=y.dtype) # Clip the score by norm. score = orig_score if score_fn_clip is not None: @@ -176,3 +176,168 @@ def baoab( score_traj = torch.stack(score_traj) return y, v, y_traj, score_traj + + +def aboba_memory( + y: torch.Tensor, + y_hist: list, + score_fn: Callable, + steps: int, + v_init: str | torch.Tensor = "zero", + save_trajectory=False, + save_every_n_steps=1, + burn_in_steps=0, + history_update_frequency=1, + verbose=False, + cpu_offload=False, + delta: float = 1.0, + friction: float = 1.0, + M: float = 1.0, + inverse_temperature: float = 1.0, + score_fn_clip: float | None = None, + cleanup: bool | None = None, + sigma: float | None = None, + **_, +): + """ABOBA splitting scheme that updates a state history.""" + i = 0 + y_traj = [] if save_trajectory else None + score_traj = [] + y_hist_traj = [] + + # Initialize trajectory with initial state + if y_traj is not None and i >= burn_in_steps: + y_traj.append(y.detach().cpu() if cpu_offload else y.detach()) + score_traj.append(torch.zeros_like(y).detach().cpu() if cpu_offload else torch.zeros_like(y).detach()) + y_hist_traj.append(list(y_hist)) + + u = pow(M, -1) + zeta2 = math.sqrt(1 - math.exp(-2 * friction)) + v = initialize_velocity(v_init=v_init, y=y, u=u) + score_fn_processed = create_score_fn(score_fn, inverse_temperature, score_fn_clip) + + steps_iter = range(1, steps) + if verbose: + steps_iter = tqdm(steps_iter, leave=False, desc="ABOBA Memory") + + for i in steps_iter: + for j in range(1, history_update_frequency): + # inner aboba loop for equilibration to conditional density p(y_t | y_hist) + y_current = y.clone().detach() + y = y + (delta / 2) * v + psi, orig_score = score_fn_processed(y, y_hist=y_hist) + v = v + u * (delta / 2) * psi + R = torch.randn_like(y) + vhat = math.exp(-friction) * v + zeta2 * math.sqrt(u) * R + v = vhat + (delta / 2) * psi + y = y + (delta / 2) * v + + if save_trajectory and ((i % save_every_n_steps) == 0) and (i >= burn_in_steps): + y_traj.append(y.detach().cpu() if cpu_offload else y.detach()) + score_traj.append(orig_score.detach().cpu() if cpu_offload else orig_score.detach()) + y_hist_traj.append(list(y_hist)) + + if cleanup is not None and cleanup and sigma is not None: + y_current = y.clone().detach() + _, orig_score = score_fn_processed(y_current, y_hist=y_hist) + y_denoised_and_noised = y_current + (sigma**2) * orig_score + y_hist.pop(-1) + y_hist.insert(0, y_denoised_and_noised) + y = y_denoised_and_noised + else: + y_current = y.clone().detach() + y_hist.pop(-1) + y_hist.insert(0, y_current) + + return ( + y, + v, + y_hist, + torch.stack(y_traj) if y_traj else None, + torch.stack(score_traj) if score_traj else None, + y_hist_traj, + ) + + +def baoab_memory( + y: torch.Tensor, + y_hist: list, + score_fn: Callable, + steps: int, + v_init: str | torch.Tensor = "zero", + save_trajectory=False, + save_every_n_steps=1, + burn_in_steps=0, + history_update_frequency=1, + verbose=False, + cpu_offload=False, + delta: float = 1.0, + friction: float = 1.0, + M: float = 1.0, + inverse_temperature: float = 1.0, + score_fn_clip: float | None = None, + cleanup: bool | None = None, + sigma: float | None = None, + **_, +): + """BAOAB splitting scheme that updates a state history.""" + i = 0 + y_traj = [] if save_trajectory else None + score_traj = [] + y_hist_traj = [] + + u = pow(M, -1) + zeta2 = math.sqrt(1 - math.exp(-2 * friction)) + v = initialize_velocity(v_init=v_init, y=y, u=u) + score_fn_processed = create_score_fn(score_fn, inverse_temperature, score_fn_clip) + + steps_iter = range(1, steps) + if verbose: + steps_iter = tqdm(steps_iter, leave=False, desc="BAOAB Memory") + + psi, orig_score = score_fn_processed(y, y_hist=y_hist) + + # Initialize trajectory with initial state + if y_traj is not None and i >= burn_in_steps: + y_traj.append(y.detach().cpu() if cpu_offload else y.detach()) + score_traj.append(orig_score.detach().cpu() if cpu_offload else orig_score.detach()) + y_hist_traj.append(list(y_hist)) + + for i in steps_iter: + # print(f"Equilibrating to conditional density p(y_t | y_hist) for {history_update_frequency} steps...") + for j in range(1, history_update_frequency): + # inner baoab loop for equilibration to conditional density p(y_t | y_hist) + y_current = y.clone().detach() + v = v + u * (delta / 2) * psi # update with previous psi + y = y + (delta / 2) * v # update with previous v + R = torch.randn_like(y) + vhat = math.exp(-friction) * v + zeta2 * math.sqrt(u) * R + y = y + (delta / 2) * vhat + psi, orig_score = score_fn_processed(y, y_hist=y_hist) + v = vhat + (delta / 2) * psi + + if cleanup is not None and cleanup and sigma is not None: + y_current = y.clone().detach() + _, orig_score = score_fn_processed(y_current, y_hist=y_hist) + y_denoised_and_noised = ( + y_current + (sigma**2) * orig_score + sigma * torch.randn_like(y_current) + ) # clean and add noise + y_hist.pop(-1) + y_hist.insert(0, y_denoised_and_noised) + y = y_denoised_and_noised + else: + y_hist.pop(-1) # remove the last element of the history + y_hist.insert(0, y_current) # present point is the first element of the history + if save_trajectory and ((i % save_every_n_steps) == 0) and (i >= burn_in_steps): + y_traj.append(y.detach().cpu() if cpu_offload else y.detach()) + score_traj.append(orig_score.detach().cpu() if cpu_offload else orig_score.detach()) + y_hist_traj.append(list(y_hist)) + + return ( + y, + v, + y_hist, + torch.stack(y_traj) if y_traj else None, + torch.stack(score_traj) if score_traj else None, + y_hist_traj, + ) diff --git a/src/jamun/sampling/walkjump/__init__.py b/src/jamun/sampling/walkjump/__init__.py index 8c197ba..baf3c3c 100644 --- a/src/jamun/sampling/walkjump/__init__.py +++ b/src/jamun/sampling/walkjump/__init__.py @@ -1,2 +1,2 @@ from ._callbacks import InterpolateParametersCallback, MeasurementDependentParametersCallback -from ._single_measurement import SingleMeasurementSampler +from ._single_measurement import SingleMeasurementSampler, SingleMeasurementSamplerMemory diff --git a/src/jamun/sampling/walkjump/_single_measurement.py b/src/jamun/sampling/walkjump/_single_measurement.py index 0baf9f5..711b39a 100644 --- a/src/jamun/sampling/walkjump/_single_measurement.py +++ b/src/jamun/sampling/walkjump/_single_measurement.py @@ -85,3 +85,112 @@ def sample( out = self.walk_jump(model, batch_size=batch_size, y_init=y_init, v_init=v_init) out["sample"] = out["xhat"] return out + + +class SingleMeasurementSamplerMemory: + """Single Measurement Walk-Jump Sampler.""" + + def __init__(self, mcmc, sigma: float, y_init_distribution: torch.distributions.Distribution | None = None): + self.mcmc = mcmc + self.sigma = float(sigma) + self.y_init_distribution = y_init_distribution + + def walk( + self, + model, + batch_size: int | None = None, + y_init: torch.Tensor | None = None, + v_init: str | Tensor = "gaussian", + y_hist_init: list | None = None, + ): + if y_init is None: + if self.y_init_distribution is None: + raise RuntimeError("either y_init and y_init_distribution must be supplied") + y_init = self.y_init_distribution.sample(sample_shape=(batch_size,)).to(model.device) + if y_hist_init is None: + raise RuntimeError("y_hist_init must be supplied") + y, v, y_hist, y_traj, score_traj, y_hist_traj = self.mcmc( + y_init, + y_hist_init, + lambda y, y_hist: model.score(y, y_hist, self.sigma), + v_init=v_init, + cleanup=True, + sigma=self.sigma, + ) + + if y_traj is not None: + t_traj = torch.ones(y_traj.size(0), device=y_traj.device, dtype=int) + else: + t_traj = None + + return { + "y": y, + "v": v, + "y_hist": y_hist, + "y_traj": y_traj, + "t_traj": t_traj, + "score_traj": score_traj, + "y_hist_traj": y_hist_traj, + } + + def walk_jump( + self, + model, + batch_size: int | None = None, + y_init: torch.Tensor | None = None, + v_init: str | Tensor = "gaussian", + y_hist_init: list | None = None, + ): + out = self.walk( + model, + batch_size=batch_size, + y_init=y_init, + v_init=v_init, + y_hist_init=y_hist_init, + ) + y, v, y_hist, y_traj, t_traj, score_traj, y_hist_traj = ( + out["y"], + out["v"], + out["y_hist"], + out["y_traj"], + out["t_traj"], + out["score_traj"], + out["y_hist_traj"], + ) + + xhat = model.xhat(y, y_hist, sigma=self.sigma) + + if y_traj is not None: + xhat_traj = torch.stack( + [ + model.xhat(y_traj[i, :].to(model.device), y_hist_traj[i], sigma=self.sigma) + for i in tqdm(range(y_traj.size(0)), leave=False, desc="Jump") + ], + dim=0, + ) + else: + xhat_traj = None + + return { + "xhat": xhat, + "y": y, + "v": v, + "y_hist": y_hist, + "xhat_traj": xhat_traj, + "y_traj": y_traj, + "y_hist_traj": y_hist_traj, + "t_traj": t_traj, + "score_traj": score_traj, + } + + def sample( + self, + model, + batch_size: int | None = None, + y_init: torch.Tensor | None = None, + v_init: str | Tensor = "gaussian", + y_hist_init: list | None = None, + ): + out = self.walk_jump(model, batch_size=batch_size, y_init=y_init, v_init=v_init, y_hist_init=y_hist_init) + out["sample"] = out["xhat"] + return out diff --git a/src/jamun/utils/__init__.py b/src/jamun/utils/__init__.py index c0a3def..0790d7f 100644 --- a/src/jamun/utils/__init__.py +++ b/src/jamun/utils/__init__.py @@ -1,6 +1,10 @@ from .align import align_A_to_B, align_A_to_B_batched, align_A_to_B_batched_f, find_rigid_alignment from .atom_graphs import to_atom_graphs -from .average_squared_distance import compute_average_squared_distance, compute_average_squared_distance_from_datasets +from .average_squared_distance import ( + compute_average_squared_distance, + compute_average_squared_distance_from_datasets, + compute_temporal_average_squared_distance_from_datasets, +) from .checkpoint import find_checkpoint, find_checkpoint_directory, get_run_path_for_wandb_run, get_wandb_run_config from .data_with_residue_info import DataWithResidueInformation from .dist_log import dist_log, wandb_dist_log @@ -26,7 +30,7 @@ encode_atom_type, encode_residue, ) -from .sampling_wrapper import ModelSamplingWrapper +from .sampling_wrapper import ModelSamplingWrapper, ModelSamplingWrapperMemory from .scaled_rmsd import scaled_rmsd from .simple_ddp import SimpleDDPStrategy from .singleton import singleton diff --git a/src/jamun/utils/_normalizations.py b/src/jamun/utils/_normalizations.py new file mode 100644 index 0000000..7e00bcf --- /dev/null +++ b/src/jamun/utils/_normalizations.py @@ -0,0 +1,52 @@ +""" +Normalization utilities for jamun models. +""" + +import torch + + +def normalization_factors( + sigma: float, + average_squared_distance: float, + normalization_type: str = "JAMUN", + sigma_data: float = None, + D: int = 3, +) -> tuple[float, float, float, float]: + """ + Compute normalization factors for the input and output. + + Args: + sigma: Noise level + average_squared_distance: Average squared distance from the dataset + normalization_type: Type of normalization ("JAMUN", "EDM", or None) + sigma_data: Sigma data parameter (only used for EDM normalization) + D: Dimensionality (default: 3) + + Returns: + Tuple of (c_in, c_skip, c_out, c_noise) normalization factors + """ + sigma = torch.as_tensor(sigma) + + if normalization_type is None: + return 1.0, 0.0, 1.0, sigma + + if normalization_type == "EDM": + if sigma_data is None: + raise ValueError("sigma_data must be provided when normalization_type is 'EDM'") + c_skip = (sigma_data**2) / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / torch.sqrt(sigma_data**2 + sigma**2) + c_in = 1 / torch.sqrt(sigma**2 + sigma_data**2) + c_noise = torch.log(sigma / sigma_data) * 0.25 + return c_in, c_skip, c_out, c_noise + + if normalization_type == "JAMUN": + A = torch.as_tensor(average_squared_distance) + B = torch.as_tensor(2 * D * sigma**2) + + c_in = 1.0 / torch.sqrt(A + B) + c_skip = A / (A + B) + c_out = torch.sqrt((A * B) / (A + B)) + c_noise = torch.log(sigma) / 4 + return c_in, c_skip, c_out, c_noise + + raise ValueError(f"Unknown normalization type: {normalization_type}") diff --git a/src/jamun/utils/average_squared_distance.py b/src/jamun/utils/average_squared_distance.py index 43e8cd8..b11d287 100644 --- a/src/jamun/utils/average_squared_distance.py +++ b/src/jamun/utils/average_squared_distance.py @@ -3,6 +3,7 @@ import numpy as np import torch +import torch_geometric from jamun import utils @@ -71,3 +72,50 @@ def compute_average_squared_distance_from_datasets( ) return float(mean_avg_sq_dist) + + +def compute_temporal_average_squared_distance_from_datasets( + datasets, num_samples: int = 100, verbose: bool = False +) -> float: + """ + Compute average squared distance between neighboring vertices in temporal graphs. + + Args: + datasets: Collection of datasets containing spatial graphs with hidden states + num_samples: Number of samples to use for estimation + verbose: Whether to print verbose output + + Returns: + float: Average squared distance between temporal neighbors + """ + from jamun.model.arch.spatiotemporal import spatial_to_temporal_graphs + + avg_sq_dists = [] + num_graphs = 0 + + # Follow pattern from existing functions in this module + for item in datasets: + if num_graphs >= num_samples: + break + for graph in item: + if num_graphs >= num_samples: + break + # Convert to temporal graphs + temporal_batch = spatial_to_temporal_graphs(graph) + temporal_graphs = torch_geometric.data.Batch.to_data_list(temporal_batch) + graph_mean = 0.0 + num_nodes = graph.pos.shape[0] + for temporal_graph in temporal_graphs: + avg_sq_dist = compute_average_squared_distance(temporal_graph.pos, cutoff=None) + graph_mean += avg_sq_dist / num_nodes + avg_sq_dists.append(graph_mean) + num_graphs += 1 + mean_avg_sq_dist = sum(avg_sq_dists) / num_graphs + + if verbose: + print(f"Total graphs processed: {num_graphs}") + print(f"Total temporal graphs processed: {len(avg_sq_dists)}") + print(f"Mean average squared distance between temporal nodes: {mean_avg_sq_dist:.6f}") + print(f"Standard deviation: {np.std(avg_sq_dists):.6f}") + + return float(mean_avg_sq_dist) diff --git a/src/jamun/utils/checkpoint.py b/src/jamun/utils/checkpoint.py index 4b9d601..8a76084 100644 --- a/src/jamun/utils/checkpoint.py +++ b/src/jamun/utils/checkpoint.py @@ -16,7 +16,8 @@ def get_wandb_run_config(wandb_run_path: str) -> dict[str, Any]: run = wandb.Api().run(wandb_run_path) py_logger = logging.getLogger("jamun") py_logger.info(f"Loading checkpoint corresponding to wandb run {run.name} at {run.url}") - return run.config["cfg"] + key = next(iter(run.config)) # the key might be named differently in the future + return run.config[key] def get_run_path_for_wandb_run(wandb_run_path: str) -> str: diff --git a/src/jamun/utils/data_with_residue_info.py b/src/jamun/utils/data_with_residue_info.py index 66e79bd..d85708a 100644 --- a/src/jamun/utils/data_with_residue_info.py +++ b/src/jamun/utils/data_with_residue_info.py @@ -1,3 +1,5 @@ +from typing import Any + import torch import torch_geometric @@ -13,6 +15,7 @@ class DataWithResidueInformation(torch_geometric.data.Data): residue_index: torch.Tensor # batched version of residue_sequence_index num_residues: int loss_weight: float + hidden_state: Any def __inc__(self, key, value, *args, **kwargs): del value, args, kwargs @@ -24,6 +27,7 @@ def __inc__(self, key, value, *args, **kwargs): "residue_sequence_index", "num_residues", "loss_weight", + "hidden_state", ]: return 0 if key in ["edge_index", "bonded_edge_index"]: diff --git a/src/jamun/utils/inspect_pretrained.py b/src/jamun/utils/inspect_pretrained.py new file mode 100644 index 0000000..ec056c4 --- /dev/null +++ b/src/jamun/utils/inspect_pretrained.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +""" +Utility script for inspecting pretrained checkpoints and checking module compatibility. +""" + +import argparse +import sys +from pathlib import Path + +import hydra +import torch +from omegaconf import DictConfig + +from jamun.utils.checkpoint import find_checkpoint +from jamun.utils.pretrained import ( + check_model_compatibility, + extract_module_from_model, + inspect_model_structure, + load_checkpoint_state_dict, + load_pretrained_model_from_checkpoint, +) + + +def print_checkpoint_structure(checkpoint_path: str, max_depth: int = 2, use_model_loading: bool = True): + """Print the structure of a checkpoint file.""" + print(f"\n📁 Checkpoint structure: {checkpoint_path}") + print("=" * 60) + + if use_model_loading: + # Try to load as a complete model first + try: + model = load_pretrained_model_from_checkpoint(checkpoint_path=checkpoint_path) + if model is not None: + inspect_model_structure(model, max_depth) + return + else: + print("⚠️ Could not load as complete model, falling back to state_dict inspection") + except Exception as e: + print(f"⚠️ Model loading failed ({e}), falling back to state_dict inspection") + + # Fallback to state_dict inspection + try: + state_dict = load_checkpoint_state_dict(checkpoint_path) + + # Group keys by their prefixes + key_groups: dict[str, list] = {} + for key in state_dict.keys(): + parts = key.split(".") + if len(parts) >= max_depth: + prefix = ".".join(parts[:max_depth]) + else: + prefix = key + + if prefix not in key_groups: + key_groups[prefix] = [] + key_groups[prefix].append(key) + + # Print grouped structure + for prefix in sorted(key_groups.keys()): + keys = key_groups[prefix] + if len(keys) == 1 and keys[0] == prefix: + # Single parameter + tensor = state_dict[prefix] + print(f" {prefix}: {list(tensor.shape)} ({tensor.dtype})") + else: + # Group of parameters + total_params = sum(state_dict[key].numel() for key in keys) + print(f" {prefix}.* : {len(keys)} parameters ({total_params:,} total elements)") + + # Show a few example keys + if len(keys) <= 5: + for key in sorted(keys)[:5]: + tensor = state_dict[key] + sub_key = key[len(prefix) + 1 :] if key.startswith(prefix + ".") else key + print(f" └─ {sub_key}: {list(tensor.shape)}") + else: + for key in sorted(keys)[:3]: + tensor = state_dict[key] + sub_key = key[len(prefix) + 1 :] if key.startswith(prefix + ".") else key + print(f" └─ {sub_key}: {list(tensor.shape)}") + print(f" └─ ... and {len(keys) - 3} more") + + total_params = sum(tensor.numel() for tensor in state_dict.values()) + print(f"\n📊 Total parameters: {total_params:,}") + + except Exception as e: + print(f"❌ Error loading checkpoint: {e}") + + +def check_compatibility_with_config(checkpoint_path: str, config_path: str, module_path: str | None = None): + """Check if a checkpoint is compatible and can be loaded.""" + print("\n🔍 Checking compatibility...") + print(f"Checkpoint: {checkpoint_path}") + print(f"Config: {config_path}") + if module_path: + print(f"Module path: {module_path}") + print("=" * 60) + + try: + # Check if checkpoint can be loaded as a model + compatibility = check_model_compatibility(checkpoint_path=checkpoint_path) + + if compatibility["loadable"]: + print("✅ Checkpoint can be loaded as a complete model!") + print(f"Model class: {compatibility['model_class'].__name__}") + print(f"Total parameters: {compatibility['total_params']:,}") + print(f"Trainable parameters: {compatibility['trainable_params']:,}") + + # If a module path is specified, try to extract it + if module_path: + try: + model = load_pretrained_model_from_checkpoint(checkpoint_path=checkpoint_path) + extracted_module = extract_module_from_model(model, module_path) + if extracted_module is not None: + print(f"✅ Module at path '{module_path}' can be extracted!") + module_params = sum(p.numel() for p in extracted_module.parameters()) + print(f"Module parameters: {module_params:,}") + else: + print(f"❌ Module at path '{module_path}' not found in model") + except Exception as e: + print(f"❌ Error extracting module: {e}") + + # Try loading with the config if provided + if config_path and Path(config_path).exists(): + try: + with hydra.initialize_config_dir(config_dir=str(Path(config_path).parent.absolute())): + cfg = hydra.compose(config_name=Path(config_path).stem) + if isinstance(cfg, DictConfig): + target_model = hydra.utils.instantiate(cfg) + if isinstance(target_model, compatibility["model_class"]): + print("✅ Config model class matches checkpoint model class!") + else: + print( + f"⚠️ Config model class ({type(target_model).__name__}) differs from checkpoint ({compatibility['model_class'].__name__})" + ) + except Exception as e: + print(f"⚠️ Could not instantiate model from config: {e}") + else: + print("❌ Checkpoint cannot be loaded as a model") + if "error" in compatibility: + print(f"Error: {compatibility['error']}") + + except Exception as e: + print(f"❌ Error checking compatibility: {e}") + + +def extract_and_save_module(checkpoint_path: str, module_path: str, output_path: str): + """Extract a specific module from a checkpoint and save it separately.""" + print(f"\n📤 Extracting module: {module_path}") + print(f"From: {checkpoint_path}") + print(f"To: {output_path}") + print("=" * 60) + + try: + # Load the full model + model = load_pretrained_model_from_checkpoint(checkpoint_path=checkpoint_path) + if model is None: + print("❌ Could not load model from checkpoint") + return + + # Extract the specific module + extracted_module = extract_module_from_model(model, module_path) + if extracted_module is None: + print(f"❌ Module at path '{module_path}' not found in model") + return + + # Save the extracted module + # We'll save it as a state dict that can be loaded later + module_state_dict = extracted_module.state_dict() + + save_data = { + "state_dict": module_state_dict, + "module_class": type(extracted_module).__name__, + "module_path": module_path, + "source_checkpoint": checkpoint_path, + } + + torch.save(save_data, output_path) + + param_count = sum(tensor.numel() for tensor in module_state_dict.values()) + print(f"✅ Extracted {len(module_state_dict)} parameters ({param_count:,} elements)") + print(f"Module class: {type(extracted_module).__name__}") + print(f"Saved to: {output_path}") + + except Exception as e: + print(f"❌ Error extracting module: {e}") + + +def main(): + parser = argparse.ArgumentParser(description="Inspect pretrained checkpoints for spatiotemporal models") + parser.add_argument("command", choices=["inspect", "check", "extract"], help="Command to run") + + # Common arguments + parser.add_argument("--checkpoint", type=str, help="Path to checkpoint file") + parser.add_argument("--wandb_run", type=str, help="WandB run path (e.g., user/project/run_id)") + parser.add_argument("--checkpoint_type", type=str, default="best_so_far", help="Type of checkpoint to load") + + # Inspect command arguments + parser.add_argument("--max_depth", type=int, default=2, help="Maximum depth for structure inspection") + + # Check command arguments + parser.add_argument("--config", type=str, help="Path to model config file") + parser.add_argument( + "--module_path", + type=str, + help="Module path to extract (e.g., 'conditioner.spatiotemporal_model.spatial_module')", + ) + + # Extract command arguments + parser.add_argument("--output", type=str, help="Output path for extracted module") + + args = parser.parse_args() + + # Get checkpoint path + if args.checkpoint: + checkpoint_path = args.checkpoint + elif args.wandb_run: + checkpoint_path = find_checkpoint(wandb_train_run_path=args.wandb_run, checkpoint_type=args.checkpoint_type) + else: + print("❌ Must specify either --checkpoint or --wandb_run") + sys.exit(1) + + # Execute command + if args.command == "inspect": + print_checkpoint_structure(checkpoint_path, args.max_depth) + + elif args.command == "check": + if not args.config: + print("❌ --config is required for check command") + sys.exit(1) + check_compatibility_with_config(checkpoint_path, args.config, args.module_path) + + elif args.command == "extract": + if not args.module_path or not args.output: + print("❌ --module_path and --output are required for extract command") + sys.exit(1) + extract_and_save_module(checkpoint_path, args.module_path, args.output) + + +if __name__ == "__main__": + main() diff --git a/src/jamun/utils/pretrained.py b/src/jamun/utils/pretrained.py new file mode 100644 index 0000000..e40b862 --- /dev/null +++ b/src/jamun/utils/pretrained.py @@ -0,0 +1,222 @@ +"""Utilities for loading pretrained models from checkpoints.""" + +import logging +import os +from typing import Any + +import lightning.pytorch as pl +import torch +import torch.nn as nn + +from jamun.utils.checkpoint import find_checkpoint + +py_logger = logging.getLogger("jamun") + + +def load_checkpoint_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]: + """Load state dict from a checkpoint file.""" + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}") + + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + # Handle different checkpoint formats + if "state_dict" in checkpoint: + return checkpoint["state_dict"] + elif isinstance(checkpoint, dict) and any(k.startswith("model.") for k in checkpoint.keys()): + return checkpoint + else: + raise ValueError(f"Unrecognized checkpoint format in {checkpoint_path}") + + +def load_pretrained_model_from_checkpoint( + checkpoint_path: str | None = None, + wandb_run_path: str | None = None, + checkpoint_type: str = "best_so_far", + model_class: type | None = None, +) -> pl.LightningModule | None: + """ + Load an entire pretrained model from checkpoint. + + Args: + checkpoint_path: Direct path to checkpoint file (mutually exclusive with wandb_run_path) + wandb_run_path: WandB run path to find checkpoint (mutually exclusive with checkpoint_path) + checkpoint_type: Type of checkpoint to load ("best_so_far", "last", etc.) + model_class: Optional model class to use for loading (if checkpoint doesn't contain class info) + + Returns: + Loaded model or None if loading failed + """ + if not checkpoint_path and not wandb_run_path: + py_logger.warning("No checkpoint path or wandb run path provided, skipping pretrained loading") + return None + + if checkpoint_path and wandb_run_path: + raise ValueError("Cannot specify both checkpoint_path and wandb_run_path") + + try: + # Find the checkpoint file + if wandb_run_path: + checkpoint_path = find_checkpoint(wandb_train_run_path=wandb_run_path, checkpoint_type=checkpoint_type) + + py_logger.info(f"Loading pretrained model from: {checkpoint_path}") + + # Load the entire model from checkpoint + if model_class: + model = model_class.load_from_checkpoint(checkpoint_path, strict=False) + else: + # Try to auto-detect model class from checkpoint + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if "hyper_parameters" in checkpoint and "_target_" in checkpoint["hyper_parameters"]: + # Try to import and use the model class from checkpoint + import importlib + + target = checkpoint["hyper_parameters"]["_target_"] + module_path, class_name = target.rsplit(".", 1) + module = importlib.import_module(module_path) + model_class = getattr(module, class_name) + model = model_class.load_from_checkpoint(checkpoint_path, strict=False) + else: + py_logger.error("Cannot determine model class from checkpoint and no model_class provided") + return None + + py_logger.info(f"Successfully loaded pretrained model of type {type(model).__name__}") + return model + + except Exception as e: + py_logger.error(f"Error loading pretrained model: {e}") + return None + + +def extract_module_from_model(model: pl.LightningModule, module_path: str) -> nn.Module | None: + """ + Extract a specific module from a loaded model using dot notation. + + Args: + model: Loaded PyTorch Lightning model + module_path: Dot-separated path to module (e.g., "conditioner.spatiotemporal_model.spatial_module") + + Returns: + Extracted module or None if not found + """ + try: + current = model + for attr in module_path.split("."): + if hasattr(current, attr): + current = getattr(current, attr) + else: + py_logger.warning(f"Module path '{module_path}' not found in model") + return None + + py_logger.info(f"Successfully extracted module at path: {module_path}") + return current + + except Exception as e: + py_logger.error(f"Error extracting module '{module_path}': {e}") + return None + + +def load_pretrained_module_from_checkpoint( + checkpoint_path: str | None = None, + wandb_run_path: str | None = None, + checkpoint_type: str = "best_so_far", + module_path: str | None = None, + model_class: type | None = None, +) -> nn.Module | None: + """ + Load a specific module from a pretrained model checkpoint. + + Args: + checkpoint_path: Direct path to checkpoint file + wandb_run_path: WandB run path to find checkpoint + checkpoint_type: Type of checkpoint to load + module_path: Dot notation path to extract specific module (e.g., "conditioner.spatiotemporal_model.spatial_module") + model_class: Optional model class for loading + + Returns: + Extracted module or None if loading failed + """ + # Load the full model + model = load_pretrained_model_from_checkpoint( + checkpoint_path=checkpoint_path, + wandb_run_path=wandb_run_path, + checkpoint_type=checkpoint_type, + model_class=model_class, + ) + + if model is None: + return None + + # Extract the specific module if path provided + if module_path: + return extract_module_from_model(model, module_path) + else: + # Return the entire model if no specific module path + return model + + +def inspect_model_structure(model: pl.LightningModule, max_depth: int = 3) -> None: + """Print the structure of a loaded model.""" + print(f"\n📁 Model structure: {type(model).__name__}") + print("=" * 60) + + def print_module_tree(module, prefix="", depth=0): + if depth >= max_depth: + return + + for name, child in module.named_children(): + full_name = f"{prefix}.{name}" if prefix else name + param_count = sum(p.numel() for p in child.parameters()) + + print(f"{' ' * depth}├─ {name}: {type(child).__name__} ({param_count:,} params)") + + if depth < max_depth - 1: + print_module_tree(child, full_name, depth + 1) + + print_module_tree(model) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"\n📊 Total parameters: {total_params:,}") + print(f"📊 Trainable parameters: {trainable_params:,}") + + +def check_model_compatibility( + checkpoint_path: str | None = None, + wandb_run_path: str | None = None, + checkpoint_type: str = "best_so_far", + expected_model_class: type | None = None, +) -> dict[str, Any]: + """ + Check if a checkpoint can be loaded and optionally verify model class. + + Returns: + Dict with 'loadable', 'model_class', 'error' info + """ + try: + if wandb_run_path: + checkpoint_path = find_checkpoint(wandb_train_run_path=wandb_run_path, checkpoint_type=checkpoint_type) + + # Try loading the model + model = load_pretrained_model_from_checkpoint(checkpoint_path=checkpoint_path) + + if model is None: + return {"loadable": False, "model_class": None, "error": "Failed to load model from checkpoint"} + + model_class = type(model) + class_compatible = True + + if expected_model_class: + class_compatible = isinstance(model, expected_model_class) + + return { + "loadable": True, + "model_class": model_class, + "class_compatible": class_compatible, + "checkpoint_path": checkpoint_path, + "total_params": sum(p.numel() for p in model.parameters()), + "trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad), + } + + except Exception as e: + return {"loadable": False, "model_class": None, "error": str(e)} diff --git a/src/jamun/utils/pretrained_wrapper.py b/src/jamun/utils/pretrained_wrapper.py new file mode 100644 index 0000000..a434406 --- /dev/null +++ b/src/jamun/utils/pretrained_wrapper.py @@ -0,0 +1,232 @@ +""" +Pretrained model wrapper utilities for seamless integration with Hydra configs. +""" + +import logging + +import torch +import torch.nn as nn + +from jamun.model import Denoiser +from jamun.utils import find_checkpoint, mean_center_f, unsqueeze_trailing + + +def compute_normalization_factors( + sigma: float | torch.Tensor, + *, + average_squared_distance: float, + normalization_type: str | None, + sigma_data: float | None = None, + D: int = 3, + device: torch.device | None = None, +) -> tuple[float, float, float, float]: + """Compute the normalization factors for the input, skip connection, output, and noise.""" + sigma = torch.as_tensor(sigma, device=device) + + if normalization_type is None: + c_in = torch.as_tensor(1.0, device=device) + c_skip = torch.as_tensor(0.0, device=device) + c_out = torch.as_tensor(1.0, device=device) + c_noise = torch.as_tensor(sigma, device=device) + return c_in, c_skip, c_out, c_noise + + if normalization_type == "EDM": + c_skip = (sigma_data**2) / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / torch.sqrt(sigma_data**2 + sigma**2) + c_in = 1 / torch.sqrt(sigma**2 + sigma_data**2) + c_noise = torch.log(sigma / sigma_data) * 0.25 + return c_in, c_skip, c_out, c_noise + + if normalization_type == "JAMUN": + A = torch.as_tensor(average_squared_distance, device=device) + B = torch.as_tensor(2 * D * sigma**2, device=device) + + c_in = 1.0 / torch.sqrt(A + B) + c_skip = A / (A + B) + c_out = torch.sqrt((A * B) / (A + B)) + c_noise = torch.log(sigma) / 4 + return c_in, c_skip, c_out, c_noise + + raise ValueError(f"Unknown normalization type: {normalization_type}") + + +class DenoiserWrapper(nn.Module): + """ + Wrapper around a denoiser model that matches the spatial module interface. + + This allows pretrained denoiser models to be used as spatial/temporal modules + in the spatiotemporal architecture by replicating the full denoiser logic + including normalization factors computed from the denoiser's own parameters. + """ + + def __init__(self, denoiser_model: nn.Module, c_in: float = 1.0, trainable: bool = True): + """ + Initialize the wrapper. + + Args: + denoiser_model: The pretrained denoiser model + c_in: Rescaling factor to convert positions from overlaying model scale + trainable: Whether to keep the model trainable (default: True) + """ + super().__init__() + self.denoiser = denoiser_model + self.c_in = c_in + + # Set trainability + if not trainable: + for param in self.denoiser.parameters(): + param.requires_grad = False + + def forward(self, pos, topology, batch, num_graphs, c_noise, effective_radial_cutoff): + """ + Forward pass that replicates the denoiser's xhat and xhat_normalized methods. + + Args: + pos: Node positions (input to spatial module) + topology: Graph topology information (already contains bonded edges) + batch: Batch indices + num_graphs: Number of graphs in batch + c_noise: Noise conditioning parameter (already computed) + effective_radial_cutoff: Radial cutoff + + Returns: + Denoised positions from the pretrained model + """ + # Sample sigma from the denoiser's own sigma distribution + sigma = self.denoiser.sigma_distribution.sample().to(pos.device) + + # Rescale positions from overlaying model scale + y = pos / self.c_in + + # Replicate xhat logic + if self.denoiser.mean_center: + y = mean_center_f(y, batch, num_graphs) + + # Replicate xhat_normalized logic + # Compute the normalization factors for the rescaled positions + c_in, c_skip, c_out, _ = compute_normalization_factors( + sigma, + average_squared_distance=self.denoiser.average_squared_distance, + normalization_type=self.denoiser.normalization_type, + sigma_data=self.denoiser.sigma_data, + D=y.shape[-1], + device=y.device, + ) + + # Adjust dimensions + c_in = unsqueeze_trailing(c_in, y.ndim - 1) + c_skip = unsqueeze_trailing(c_skip, y.ndim - 1) + c_out = unsqueeze_trailing(c_out, y.ndim - 1) + c_noise = c_noise.unsqueeze(0) if c_noise.dim() == 0 else c_noise + + # Ensure c_noise is float type (fix for dtype mismatch) + c_noise = c_noise.float() + + # Scale input positions by c_in + y_scaled = y * c_in + + # # Call the denoiser's architecture (topology already has edges) + # # Add this right before line 129 in the pretrained wrapper call + # print("=== Debugging pretrained denoiser input types ===") + # print(f"y_scaled dtype: {y_scaled.dtype}, shape: {y_scaled.shape}") + # print(f"topology.edge_index dtype: {topology.edge_index.dtype if hasattr(topology, 'edge_index') else 'N/A'}") + # print(f"c_noise dtype: {c_noise.dtype}, shape: {c_noise.shape}") + # print(f"batch dtype: {batch.dtype}, shape: {batch.shape}") + # print(f"effective_radial_cutoff dtype: {type(effective_radial_cutoff)}") + + # # Check if topology has any Long tensors + # for attr_name in dir(topology): + # if not attr_name.startswith('_'): + # attr = getattr(topology, attr_name) + # if isinstance(attr, torch.Tensor): + # print(f"topology.{attr_name} dtype: {attr.dtype}") + g_pred = self.denoiser.g( + pos=y_scaled, + topology=topology, + c_noise=c_noise, + effective_radial_cutoff=effective_radial_cutoff, + batch=batch, + num_graphs=num_graphs, + ) + + # Compute final prediction with skip connection + xhat = c_skip * y + c_out * g_pred + + # Mean center the prediction if needed + if self.denoiser.mean_center: + xhat = mean_center_f(xhat, batch, num_graphs) + + return xhat + + +def return_wrapped_denoiser( + wandb_run_path: str | None = None, + checkpoint_dir: str | None = None, + checkpoint_type: str = "best_so_far", + c_in: float = 1.0, + trainable: bool = True, +) -> DenoiserWrapper: + """ + Load a pretrained denoiser model and return it wrapped for use in spatiotemporal architecture. + + This function is designed to be used directly as a _target_ in Hydra configs. + The wrapper replicates the full denoiser logic including normalization factors + computed from the denoiser's own training parameters. + + Args: + wandb_run_path: Path to wandb run (e.g., "entity/project/run_id") + checkpoint_path: Direct path to checkpoint file + checkpoint_type: Type of checkpoint to load ("best_so_far", "latest", etc.) + c_in: Rescaling factor to convert positions from overlaying model scale + trainable: Whether to keep the loaded model trainable + + Returns: + DenoiserWrapper containing the pretrained model + + Example usage in config: + spatial_module: + _target_: jamun.utils.pretrained_wrapper.return_wrapped_denoiser + wandb_run_path: "your_entity/your_project/run_id" + c_in: 1.0 + trainable: false + """ + py_logger = logging.getLogger("jamun") + + if not wandb_run_path and not checkpoint_dir: + raise ValueError("Either wandb_run_path or checkpoint_path must be provided") + + # Load the pretrained model + py_logger.info(f"Loading pretrained denoiser from: {wandb_run_path or checkpoint_dir}") + + # pretrained_model = load_pretrained_model_from_checkpoint( + # checkpoint_path=checkpoint_path, + # wandb_run_path=wandb_run_path, + # checkpoint_type=checkpoint_type + # ) + checkpoint_path = find_checkpoint( + wandb_train_run_path=wandb_run_path, checkpoint_dir=checkpoint_dir, checkpoint_type=checkpoint_type + ) + pretrained_model = Denoiser.load_from_checkpoint(checkpoint_path) + + if pretrained_model is None: + raise RuntimeError(f"Failed to load pretrained model from {wandb_run_path or checkpoint_path}") + + py_logger.info("✓ Successfully loaded pretrained denoiser") + + # Wrap the model + wrapped_model = DenoiserWrapper(pretrained_model, c_in=c_in, trainable=trainable) + + py_logger.info(f"✓ Using c_in rescaling factor: {c_in}") + py_logger.info("✓ Using denoiser's own normalization parameters:") + py_logger.info(f" - normalization_type: {pretrained_model.normalization_type}") + py_logger.info(f" - average_squared_distance: {pretrained_model.average_squared_distance}") + if hasattr(pretrained_model, "sigma_data") and pretrained_model.sigma_data is not None: + py_logger.info(f" - sigma_data: {pretrained_model.sigma_data}") + py_logger.info(f" - mean_center: {pretrained_model.mean_center}") + + if not trainable: + py_logger.info("✓ Frozen pretrained denoiser (not trainable)") + else: + py_logger.info("✓ Pretrained denoiser is trainable") + + return wrapped_model diff --git a/src/jamun/utils/sampling_wrapper.py b/src/jamun/utils/sampling_wrapper.py index 45561f9..663e00c 100644 --- a/src/jamun/utils/sampling_wrapper.py +++ b/src/jamun/utils/sampling_wrapper.py @@ -2,16 +2,25 @@ import torch import torch.nn as nn import torch_geometric +from e3tools import scatter + +from jamun.utils import mean_center class ModelSamplingWrapper: """Wrapper to sample positions from a model.""" - def __init__(self, model: nn.Module, init_graphs: torch_geometric.data.Data, sigma: float): + def __init__( + self, model: nn.Module, init_graphs: torch_geometric.data.Data, sigma: float, recenter_on_init: bool = True + ): self._model = model self.init_graphs = init_graphs self.sigma = sigma + # Apply mean centering if requested + if recenter_on_init: + self.init_graphs = mean_center(self.init_graphs) + @property def device(self) -> torch.device: return self._model.device @@ -28,8 +37,12 @@ def score(self, y, sigma, *args, **kwargs): return self._model.score(self.positions_to_graph(y), sigma) def xhat(self, y, sigma, *args, **kwargs): - xhat_graph = self._model.xhat(self.positions_to_graph(y), sigma) - return xhat_graph.pos + data = self.positions_to_graph(y) + y, topology, batch, num_graphs = data.pos, data.clone(), data.batch, data.num_graphs + del topology.batch, topology.num_graphs + sigma = torch.as_tensor(sigma).to(y) + xhat_pos = self._model.xhat(y, topology, batch, num_graphs, sigma) + return xhat_pos def positions_to_graph(self, positions: torch.Tensor) -> torch_geometric.data.Data: """Wraps a tensor of positions to a graph with these positions as an attribute.""" @@ -79,3 +92,105 @@ def unbatch_samples(self, samples: dict[str, torch.Tensor]) -> list[torch_geomet output_graph[key] = unbatched_value return output_graphs + + +class ModelSamplingWrapperMemory: + """Wrapper for models that depend on a memory of states.""" + + def __init__( + self, model: nn.Module, init_graphs: torch_geometric.data.Data, sigma: float, recenter_on_init: bool = True + ): + self._model = model + self.init_graphs = init_graphs + self.sigma = sigma + + # Apply mean centering if requested + if recenter_on_init: + # Mean center positions + self.init_graphs = mean_center(self.init_graphs) + + # Mean center hidden states if they exist and aren't empty + if hasattr(self.init_graphs, "hidden_state") and self.init_graphs.hidden_state: + for i in range(len(self.init_graphs.hidden_state)): + # Mean center each hidden state in-place + mean = scatter(self.init_graphs.hidden_state[i], self.init_graphs.batch, dim=0, reduce="mean") + self.init_graphs.hidden_state[i] = self.init_graphs.hidden_state[i] - mean[self.init_graphs.batch] + + @property + def device(self) -> torch.device: + return next(self._model.parameters()).device + + def sample_initial_noisy_positions(self) -> torch.Tensor: + pos = self.init_graphs.pos + pos = pos + torch.randn_like(pos) * self.sigma + return pos + + def sample_initial_noisy_history(self) -> list: + noisy_history = [] + for hidden_state in self.init_graphs.hidden_state: + noisy_history.append(hidden_state + torch.randn_like(hidden_state) * self.sigma) + return noisy_history + + def __getattr__(self, name): + return getattr(self._model, name) + + def score(self, y, y_hist, sigma): + graph = self.positions_to_graph(y, y_hist).to(self.device) + return self._model.score(graph, sigma) + + def xhat(self, y, y_hist, sigma): + graph = self.positions_to_graph(y, y_hist).to(self.device) + xhat_graph = self._model.xhat(graph, sigma) + return xhat_graph.pos + + def positions_to_graph(self, positions: torch.Tensor, y_hist: list) -> torch_geometric.data.Data: + """Wraps positions to a graph and attaches the historical states.""" + assert len(positions) == self.init_graphs.num_nodes + assert positions.shape[1] == 3 + input_graph = self.init_graphs.clone() + input_graph.pos = positions + input_graph.hidden_state = y_hist + return input_graph.to(positions.device) + + def unbatch_samples(self, samples: dict[str, torch.Tensor]) -> list[torch_geometric.data.Data]: + """Unbatch samples.""" + if "batch" not in self.init_graphs: + raise ValueError("The initial graph does not have a batch attribute.") + + # Copy off the input graphs, to update attributes later. + output_graphs = self.init_graphs.clone() + output_graphs = torch_geometric.data.Batch.to_data_list(output_graphs) + + for key, value in samples.items(): + if key == "y_hist" or key == "y_hist_traj": + if key == "y_hist": + value = [value] + value = torch.stack([torch.stack(traj, dim=1) for traj in value], dim=1) + else: + if hasattr(value, "ndim") and value.ndim not in [2, 3]: + # py_logger = logging.getLogger("jamun") + # py_logger.info(f"Skipping unbatching of key {key} with shape {value.shape} as it is not 2D or 3D.") + continue + if hasattr(value, "ndim") and value.ndim == 3: + value = einops.rearrange( + value, + "num_frames atoms coords -> atoms num_frames coords", + ) + + unbatched_values = torch_geometric.utils.unbatch(value, self.init_graphs.batch) + for output_graph, unbatched_value in zip(output_graphs, unbatched_values, strict=True): + if key in output_graph: + raise ValueError(f"Key {key} already exists in the output graph.") + + if unbatched_value.shape[0] != output_graph.num_nodes: + raise ValueError( + f"Number of nodes in unbatched value ({unbatched_value.shape[0]}) for key {key} does not match " + f"number of nodes in output graph ({output_graph.num_nodes})." + ) + if key == "y_hist": + unbatched_value = [t.squeeze(-2).squeeze(1) for t in torch.split(unbatched_value, 1, dim=-2)] + if key == "y_hist_traj": + unbatched_value = [t.squeeze(-2) for t in torch.split(unbatched_value, 1, dim=-2)] + output_graph[key] = unbatched_value + + return output_graphs diff --git a/uv.lock b/uv.lock index 2ee8e36..e6c5b66 100644 --- a/uv.lock +++ b/uv.lock @@ -1,9 +1,11 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform != 'linux'", + "python_full_version >= '3.13' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and sys_platform == 'linux'", + "python_full_version >= '3.13' and sys_platform != 'linux'", + "python_full_version == '3.12.*' and sys_platform != 'linux'", "python_full_version == '3.11.*' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux'", "python_full_version < '3.11' and sys_platform == 'linux'", @@ -238,6 +240,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3d/20/352b2bf99f93ba18986615841786cbd0d38f7856bd49d4e154a540f04afe/botocore-1.37.1-py3-none-any.whl", hash = "sha256:c1db1bfc5d8c6b3b6d1ca6794f605294b4264e82a7e727b88e0fef9c2b9fbb9c", size = 13359164, upload-time = "2025-02-25T20:32:52.347Z" }, ] +[[package]] +name = "cached-path" +version = "1.7.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "boto3" }, + { name = "filelock" }, + { name = "google-cloud-storage" }, + { name = "huggingface-hub" }, + { name = "requests" }, + { name = "rich" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/14/35/ce8d9b5821f83df1a412379623d1d365a42c9b65c2e9c96fb7b24ce521ba/cached_path-1.7.3.tar.gz", hash = "sha256:956d21b5ac92d64ae6d76b2a1a043c5d660e3421d513e735157d56aca9a31d8e", size = 32795, upload-time = "2025-05-07T16:31:27.516Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/2c/787a39669287d69b57bf150d01f13cf5daedea76ab69eae7e7677780acca/cached_path-1.7.3-py3-none-any.whl", hash = "sha256:fe2b396b4816205c95d6e961efb35c66288966c4f96b82cd87c4fb03d4d037d1", size = 36839, upload-time = "2025-05-07T16:31:25.483Z" }, +] + +[[package]] +name = "cachetools" +version = "5.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380, upload-time = "2025-02-20T21:01:19.524Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080, upload-time = "2025-02-20T21:01:16.647Z" }, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -558,6 +586,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973, upload-time = "2024-10-09T18:35:44.272Z" }, ] +[[package]] +name = "dm-tree" +version = "0.1.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/6d/f1997aac42e0f550c1e952a0b920eaa0bfc4d27d0421499881b934b969fc/dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430", size = 35384, upload-time = "2022-12-18T09:46:55.953Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/3b/d5ef06ee302ecea27351b18c28f2bde7ac982c774967d7bc82f7765fa0cb/dm_tree-0.1.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60", size = 167626, upload-time = "2022-12-18T09:46:03.126Z" }, + { url = "https://files.pythonhosted.org/packages/63/29/b7c77a2500742ebbc956c2e6c9c215abeb4348040ddda72a61c760999d64/dm_tree-0.1.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f", size = 115351, upload-time = "2022-12-18T09:46:05.517Z" }, + { url = "https://files.pythonhosted.org/packages/ab/b0/8bf47b99c302a01db55ec43645663a385b8d3dfeb94b5fe6adf03b1121dc/dm_tree-0.1.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2869228d9c619074de501a3c10dc7f07c75422f8fab36ecdcb859b6f1b1ec3ef", size = 110653, upload-time = "2022-12-18T09:46:07.869Z" }, + { url = "https://files.pythonhosted.org/packages/4c/4b/046c634913643333b1cf8f0dedd45683278013c0fb187fe36915b233ac7b/dm_tree-0.1.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d20f2faa3672b52e5013f4077117bfb99c4cfc0b445d3bde1584c34032b57436", size = 146732, upload-time = "2023-01-21T08:49:45.871Z" }, + { url = "https://files.pythonhosted.org/packages/ea/79/8f65fee71f3cf8bd993031578425fb10f42840b5d9a7298da0c1d52281f7/dm_tree-0.1.8-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410", size = 174704, upload-time = "2023-01-21T08:49:48.433Z" }, + { url = "https://files.pythonhosted.org/packages/3e/9e/20bdcf1953949d8aa1e614f5c6cc1f9b556d4d72e0731e5aa1d353423bb1/dm_tree-0.1.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1d7c26e431fc93cc7e0cba867eb000db6a05f6f2b25af11ac4e9dada88fc5bca", size = 150386, upload-time = "2023-01-21T08:49:50.439Z" }, + { url = "https://files.pythonhosted.org/packages/cc/2b/a13e3a44f9121ecab0057af462baeb64dc50eb269de52648db8823bc12ae/dm_tree-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144", size = 152844, upload-time = "2022-12-18T09:46:10.308Z" }, + { url = "https://files.pythonhosted.org/packages/f0/5d/86eb4e071ff395fed0783076e94c56ad9a97ba7b6e49b5aaf1b651a4fcd3/dm_tree-0.1.8-cp310-cp310-win_amd64.whl", hash = "sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee", size = 101319, upload-time = "2022-12-18T09:46:12.352Z" }, + { url = "https://files.pythonhosted.org/packages/e2/64/901b324804793743f0fdc9e47db893bf0ded9e074850fab2440af330fe83/dm_tree-0.1.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7", size = 167628, upload-time = "2022-12-18T09:46:14.195Z" }, + { url = "https://files.pythonhosted.org/packages/b1/65/4f10a68dde5fa0c91043c9c899e9bc79b1657ba932d39a5f8525c0058e68/dm_tree-0.1.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b", size = 115351, upload-time = "2022-12-18T09:46:16.467Z" }, + { url = "https://files.pythonhosted.org/packages/08/e2/4c29cb9876456517f21979ddcbb6048f28a3b52c61aa9d14d42adafcdca4/dm_tree-0.1.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5", size = 110661, upload-time = "2022-12-18T09:46:18.821Z" }, + { url = "https://files.pythonhosted.org/packages/fe/89/386332bbd7567c4ccc13aa2e58f733237503fc75fb389955d3b06b9fb967/dm_tree-0.1.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1607ce49aa42f010d1e5e616d92ce899d66835d4d8bea49679582435285515de", size = 146727, upload-time = "2023-01-21T08:49:52.992Z" }, + { url = "https://files.pythonhosted.org/packages/a3/e7/b0c04ea5af82c19fd5984bfe980f4012601c4708634c7c51a952b17c93b2/dm_tree-0.1.8-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:343a4a4ebaa127451ff971254a4be4084eb4bdc0b2513c32b46f6f728fd03f9e", size = 174689, upload-time = "2023-01-21T08:49:56.279Z" }, + { url = "https://files.pythonhosted.org/packages/13/0d/09a4ecb54c03db53d9eb5bbc81609d89de26e3762743f003282c1b48debb/dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d", size = 150338, upload-time = "2023-01-21T08:49:59.049Z" }, + { url = "https://files.pythonhosted.org/packages/4a/27/c5e3580a952a07e5a1428ae952874796870dc8db789f3d774e886160a9f4/dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393", size = 152800, upload-time = "2022-12-18T09:46:21.065Z" }, + { url = "https://files.pythonhosted.org/packages/e4/c1/522041457444b67125ac9527208bb3148f63d7dce0a86ffa589ec763a10e/dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80", size = 101336, upload-time = "2022-12-18T09:46:23.449Z" }, + { url = "https://files.pythonhosted.org/packages/72/2c/e33dfc96f974ae3cba82c9836371c93fcb4d59d5a82ebb853861618a0b0b/dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8", size = 169495, upload-time = "2024-02-06T09:09:13.276Z" }, + { url = "https://files.pythonhosted.org/packages/17/af/4030827253a5d50eb8da6f7189bc33d3c850c4109cf3414910e9af677cb7/dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22", size = 116525, upload-time = "2024-02-06T09:09:15.529Z" }, + { url = "https://files.pythonhosted.org/packages/10/10/5f9eed00b1186921e447960443f03cda6374cba8cd5cf7aff2b42ecb8a0e/dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b", size = 111436, upload-time = "2024-02-06T09:09:16.781Z" }, + { url = "https://files.pythonhosted.org/packages/4a/da/3d3d04f7a572f7649f48edc9402ff5836e2f90e18445ffde110fd6142889/dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760", size = 146828, upload-time = "2024-02-13T21:25:21.639Z" }, + { url = "https://files.pythonhosted.org/packages/c4/12/0a8c2152655ca39c1059c762ea1dc12784166c735126eb0ab929c518ef4e/dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb", size = 175054, upload-time = "2024-02-13T21:25:23.532Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d4/8cbb857612ca69763ee4f4f97c7b91659df1d373d62237cb9c772e55ae97/dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e", size = 152834, upload-time = "2024-02-06T09:09:18.536Z" }, + { url = "https://files.pythonhosted.org/packages/ad/e3/96f5267fe5a47c882dce7f3d06b26ddd756681fc4fbedd55d51b78b08bca/dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715", size = 101754, upload-time = "2024-02-06T09:09:20.962Z" }, +] + [[package]] name = "docker-pycreds" version = "0.4.0" @@ -587,16 +646,18 @@ wheels = [ [[package]] name = "e3tools" -version = "0.1.1" +version = "0.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "e3nn" }, + { name = "einops" }, { name = "jaxtyping" }, + { name = "setuptools" }, { name = "torch" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5a/af/e48b3be6fa30e83adb553869da760117eaa643b7c2af9946b805cf48b325/e3tools-0.1.1.tar.gz", hash = "sha256:23022378554b9ca73f1480e4c575088ca7a73ee4adee7c761eb29907e8cc2098", size = 75249, upload-time = "2025-03-07T00:05:40.901Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1a/03/d3560619bc9c7d7bb1548a260087201a217400f125f6b8c68e5c5532be3e/e3tools-0.1.3.tar.gz", hash = "sha256:a49d919b6f754767ca3c09eaa6a6e1c12fbddace156572b878bbe40ad70ceaa8", size = 106214, upload-time = "2025-08-04T23:29:52.313Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/fd/f3bada4eb07cbca953d403ca121cc8224554bdee905cb2bd176cd75929fa/e3tools-0.1.1-py3-none-any.whl", hash = "sha256:40346ef8e17d966e16d3754e8945c30013baf2c6c4b959c5dacc9b810946d169", size = 18236, upload-time = "2025-03-07T00:05:39.723Z" }, + { url = "https://files.pythonhosted.org/packages/78/c2/76b4ea3bb1426a13d769bf1ae5a13766e8398116f9f158ee5298970c6664/e3tools-0.1.3-py3-none-any.whl", hash = "sha256:39fd064c42f5fe2edd5b15955b5cc514bf64863a5178dd6c041e73423dd03239", size = 21918, upload-time = "2025-08-04T23:29:51.107Z" }, ] [[package]] @@ -792,6 +853,159 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599, upload-time = "2025-01-02T07:32:40.731Z" }, ] +[[package]] +name = "google-api-core" +version = "2.25.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/21/e9d043e88222317afdbdb567165fdbc3b0aad90064c7e0c9eb0ad9955ad8/google_api_core-2.25.1.tar.gz", hash = "sha256:d2aaa0b13c78c61cb3f4282c464c046e45fbd75755683c9c525e6e8f7ed0a5e8", size = 165443, upload-time = "2025-06-12T20:52:20.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/4b/ead00905132820b623732b175d66354e9d3e69fcf2a5dcdab780664e7896/google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7", size = 160807, upload-time = "2025-06-12T20:52:19.334Z" }, +] + +[[package]] +name = "google-auth" +version = "2.40.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/9b/e92ef23b84fa10a64ce4831390b7a4c2e53c0132568d99d4ae61d04c8855/google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77", size = 281029, upload-time = "2025-06-04T18:04:57.577Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137, upload-time = "2025-06-04T18:04:55.573Z" }, +] + +[[package]] +name = "google-cloud-core" +version = "2.4.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/b8/2b53838d2acd6ec6168fd284a990c76695e84c65deee79c9f3a4276f6b4f/google_cloud_core-2.4.3.tar.gz", hash = "sha256:1fab62d7102844b278fe6dead3af32408b1df3eb06f5c7e8634cbd40edc4da53", size = 35861, upload-time = "2025-03-10T21:05:38.948Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/86/bda7241a8da2d28a754aad2ba0f6776e35b67e37c36ae0c45d49370f1014/google_cloud_core-2.4.3-py2.py3-none-any.whl", hash = "sha256:5130f9f4c14b4fafdff75c79448f9495cfade0d8775facf1b09c3bf67e027f6e", size = 29348, upload-time = "2025-03-10T21:05:37.785Z" }, +] + +[[package]] +name = "google-cloud-storage" +version = "2.19.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, + { name = "google-cloud-core" }, + { name = "google-crc32c" }, + { name = "google-resumable-media" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/76/4d965702e96bb67976e755bed9828fa50306dca003dbee08b67f41dd265e/google_cloud_storage-2.19.0.tar.gz", hash = "sha256:cd05e9e7191ba6cb68934d8eb76054d9be4562aa89dbc4236feee4d7d51342b2", size = 5535488, upload-time = "2024-12-05T01:35:06.49Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/94/6db383d8ee1adf45dc6c73477152b82731fa4c4a46d9c1932cc8757e0fd4/google_cloud_storage-2.19.0-py2.py3-none-any.whl", hash = "sha256:aeb971b5c29cf8ab98445082cbfe7b161a1f48ed275822f59ed3f1524ea54fba", size = 131787, upload-time = "2024-12-05T01:35:04.736Z" }, +] + +[[package]] +name = "google-crc32c" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/19/ae/87802e6d9f9d69adfaedfcfd599266bf386a54d0be058b532d04c794f76d/google_crc32c-1.7.1.tar.gz", hash = "sha256:2bff2305f98846f3e825dbeec9ee406f89da7962accdb29356e4eadc251bd472", size = 14495, upload-time = "2025-03-26T14:29:13.32Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/69/b1b05cf415df0d86691d6a8b4b7e60ab3a6fb6efb783ee5cd3ed1382bfd3/google_crc32c-1.7.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:b07d48faf8292b4db7c3d64ab86f950c2e94e93a11fd47271c28ba458e4a0d76", size = 30467, upload-time = "2025-03-26T14:31:11.92Z" }, + { url = "https://files.pythonhosted.org/packages/44/3d/92f8928ecd671bd5b071756596971c79d252d09b835cdca5a44177fa87aa/google_crc32c-1.7.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:7cc81b3a2fbd932a4313eb53cc7d9dde424088ca3a0337160f35d91826880c1d", size = 30311, upload-time = "2025-03-26T14:53:14.161Z" }, + { url = "https://files.pythonhosted.org/packages/33/42/c2d15a73df79d45ed6b430b9e801d0bd8e28ac139a9012d7d58af50a385d/google_crc32c-1.7.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1c67ca0a1f5b56162951a9dae987988679a7db682d6f97ce0f6381ebf0fbea4c", size = 37889, upload-time = "2025-03-26T14:41:27.83Z" }, + { url = "https://files.pythonhosted.org/packages/57/ea/ac59c86a3c694afd117bb669bde32aaf17d0de4305d01d706495f09cbf19/google_crc32c-1.7.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc5319db92daa516b653600794d5b9f9439a9a121f3e162f94b0e1891c7933cb", size = 33028, upload-time = "2025-03-26T14:41:29.141Z" }, + { url = "https://files.pythonhosted.org/packages/60/44/87e77e8476767a4a93f6cf271157c6d948eacec63688c093580af13b04be/google_crc32c-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcdf5a64adb747610140572ed18d011896e3b9ae5195f2514b7ff678c80f1603", size = 38026, upload-time = "2025-03-26T14:41:29.921Z" }, + { url = "https://files.pythonhosted.org/packages/c8/bf/21ac7bb305cd7c1a6de9c52f71db0868e104a5b573a4977cd9d0ff830f82/google_crc32c-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:754561c6c66e89d55754106739e22fdaa93fafa8da7221b29c8b8e8270c6ec8a", size = 33476, upload-time = "2025-03-26T14:29:09.086Z" }, + { url = "https://files.pythonhosted.org/packages/f7/94/220139ea87822b6fdfdab4fb9ba81b3fff7ea2c82e2af34adc726085bffc/google_crc32c-1.7.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6fbab4b935989e2c3610371963ba1b86afb09537fd0c633049be82afe153ac06", size = 30468, upload-time = "2025-03-26T14:32:52.215Z" }, + { url = "https://files.pythonhosted.org/packages/94/97/789b23bdeeb9d15dc2904660463ad539d0318286d7633fe2760c10ed0c1c/google_crc32c-1.7.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:ed66cbe1ed9cbaaad9392b5259b3eba4a9e565420d734e6238813c428c3336c9", size = 30313, upload-time = "2025-03-26T14:57:38.758Z" }, + { url = "https://files.pythonhosted.org/packages/81/b8/976a2b843610c211e7ccb3e248996a61e87dbb2c09b1499847e295080aec/google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee6547b657621b6cbed3562ea7826c3e11cab01cd33b74e1f677690652883e77", size = 33048, upload-time = "2025-03-26T14:41:30.679Z" }, + { url = "https://files.pythonhosted.org/packages/c9/16/a3842c2cf591093b111d4a5e2bfb478ac6692d02f1b386d2a33283a19dc9/google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d68e17bad8f7dd9a49181a1f5a8f4b251c6dbc8cc96fb79f1d321dfd57d66f53", size = 32669, upload-time = "2025-03-26T14:41:31.432Z" }, + { url = "https://files.pythonhosted.org/packages/04/17/ed9aba495916fcf5fe4ecb2267ceb851fc5f273c4e4625ae453350cfd564/google_crc32c-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:6335de12921f06e1f774d0dd1fbea6bf610abe0887a1638f64d694013138be5d", size = 33476, upload-time = "2025-03-26T14:29:10.211Z" }, + { url = "https://files.pythonhosted.org/packages/dd/b7/787e2453cf8639c94b3d06c9d61f512234a82e1d12d13d18584bd3049904/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2d73a68a653c57281401871dd4aeebbb6af3191dcac751a76ce430df4d403194", size = 30470, upload-time = "2025-03-26T14:34:31.655Z" }, + { url = "https://files.pythonhosted.org/packages/ed/b4/6042c2b0cbac3ec3a69bb4c49b28d2f517b7a0f4a0232603c42c58e22b44/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:22beacf83baaf59f9d3ab2bbb4db0fb018da8e5aebdce07ef9f09fce8220285e", size = 30315, upload-time = "2025-03-26T15:01:54.634Z" }, + { url = "https://files.pythonhosted.org/packages/29/ad/01e7a61a5d059bc57b702d9ff6a18b2585ad97f720bd0a0dbe215df1ab0e/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19eafa0e4af11b0a4eb3974483d55d2d77ad1911e6cf6f832e1574f6781fd337", size = 33180, upload-time = "2025-03-26T14:41:32.168Z" }, + { url = "https://files.pythonhosted.org/packages/3b/a5/7279055cf004561894ed3a7bfdf5bf90a53f28fadd01af7cd166e88ddf16/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6d86616faaea68101195c6bdc40c494e4d76f41e07a37ffdef270879c15fb65", size = 32794, upload-time = "2025-03-26T14:41:33.264Z" }, + { url = "https://files.pythonhosted.org/packages/0f/d6/77060dbd140c624e42ae3ece3df53b9d811000729a5c821b9fd671ceaac6/google_crc32c-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:b7491bdc0c7564fcf48c0179d2048ab2f7c7ba36b84ccd3a3e1c3f7a72d3bba6", size = 33477, upload-time = "2025-03-26T14:29:10.94Z" }, + { url = "https://files.pythonhosted.org/packages/8b/72/b8d785e9184ba6297a8620c8a37cf6e39b81a8ca01bb0796d7cbb28b3386/google_crc32c-1.7.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:df8b38bdaf1629d62d51be8bdd04888f37c451564c2042d36e5812da9eff3c35", size = 30467, upload-time = "2025-03-26T14:36:06.909Z" }, + { url = "https://files.pythonhosted.org/packages/34/25/5f18076968212067c4e8ea95bf3b69669f9fc698476e5f5eb97d5b37999f/google_crc32c-1.7.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:e42e20a83a29aa2709a0cf271c7f8aefaa23b7ab52e53b322585297bb94d4638", size = 30309, upload-time = "2025-03-26T15:06:15.318Z" }, + { url = "https://files.pythonhosted.org/packages/92/83/9228fe65bf70e93e419f38bdf6c5ca5083fc6d32886ee79b450ceefd1dbd/google_crc32c-1.7.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:905a385140bf492ac300026717af339790921f411c0dfd9aa5a9e69a08ed32eb", size = 33133, upload-time = "2025-03-26T14:41:34.388Z" }, + { url = "https://files.pythonhosted.org/packages/c3/ca/1ea2fd13ff9f8955b85e7956872fdb7050c4ace8a2306a6d177edb9cf7fe/google_crc32c-1.7.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b211ddaf20f7ebeec5c333448582c224a7c90a9d98826fbab82c0ddc11348e6", size = 32773, upload-time = "2025-03-26T14:41:35.19Z" }, + { url = "https://files.pythonhosted.org/packages/89/32/a22a281806e3ef21b72db16f948cad22ec68e4bdd384139291e00ff82fe2/google_crc32c-1.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:0f99eaa09a9a7e642a61e06742856eec8b19fc0037832e03f941fe7cf0c8e4db", size = 33475, upload-time = "2025-03-26T14:29:11.771Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c5/002975aff514e57fc084ba155697a049b3f9b52225ec3bc0f542871dd524/google_crc32c-1.7.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32d1da0d74ec5634a05f53ef7df18fc646666a25efaaca9fc7dcfd4caf1d98c3", size = 33243, upload-time = "2025-03-26T14:41:35.975Z" }, + { url = "https://files.pythonhosted.org/packages/61/cb/c585282a03a0cea70fcaa1bf55d5d702d0f2351094d663ec3be1c6c67c52/google_crc32c-1.7.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e10554d4abc5238823112c2ad7e4560f96c7bf3820b202660373d769d9e6e4c9", size = 32870, upload-time = "2025-03-26T14:41:37.08Z" }, + { url = "https://files.pythonhosted.org/packages/0b/43/31e57ce04530794917dfe25243860ec141de9fadf4aa9783dffe7dac7c39/google_crc32c-1.7.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8e9afc74168b0b2232fb32dd202c93e46b7d5e4bf03e66ba5dc273bb3559589", size = 28242, upload-time = "2025-03-26T14:41:42.858Z" }, + { url = "https://files.pythonhosted.org/packages/eb/f3/8b84cd4e0ad111e63e30eb89453f8dd308e3ad36f42305cf8c202461cdf0/google_crc32c-1.7.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa8136cc14dd27f34a3221c0f16fd42d8a40e4778273e61a3c19aedaa44daf6b", size = 28049, upload-time = "2025-03-26T14:41:44.651Z" }, + { url = "https://files.pythonhosted.org/packages/16/1b/1693372bf423ada422f80fd88260dbfd140754adb15cbc4d7e9a68b1cb8e/google_crc32c-1.7.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85fef7fae11494e747c9fd1359a527e5970fc9603c90764843caabd3a16a0a48", size = 28241, upload-time = "2025-03-26T14:41:45.898Z" }, + { url = "https://files.pythonhosted.org/packages/fd/3c/2a19a60a473de48717b4efb19398c3f914795b64a96cf3fbe82588044f78/google_crc32c-1.7.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6efb97eb4369d52593ad6f75e7e10d053cf00c48983f7a973105bc70b0ac4d82", size = 28048, upload-time = "2025-03-26T14:41:46.696Z" }, +] + +[[package]] +name = "google-resumable-media" +version = "2.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-crc32c" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/5a/0efdc02665dca14e0837b62c8a1a93132c264bd02054a15abb2218afe0ae/google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0", size = 2163099, upload-time = "2024-08-07T22:20:38.555Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/35/b8d3baf8c46695858cb9d8835a53baa1eeb9906ddaf2f728a5f5b640fd1e/google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa", size = 81251, upload-time = "2024-08-07T22:20:36.409Z" }, +] + +[[package]] +name = "googleapis-common-protos" +version = "1.70.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/24/33db22342cf4a2ea27c9955e6713140fedd51e8b141b5ce5260897020f1a/googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257", size = 145903, upload-time = "2025-04-14T10:17:02.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, +] + +[[package]] +name = "hf-xet" +version = "1.1.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ed/d4/7685999e85945ed0d7f0762b686ae7015035390de1161dcea9d5276c134c/hf_xet-1.1.5.tar.gz", hash = "sha256:69ebbcfd9ec44fdc2af73441619eeb06b94ee34511bbcf57cd423820090f5694", size = 495969, upload-time = "2025-06-20T21:48:38.007Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/89/a1119eebe2836cb25758e7661d6410d3eae982e2b5e974bcc4d250be9012/hf_xet-1.1.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f52c2fa3635b8c37c7764d8796dfa72706cc4eded19d638331161e82b0792e23", size = 2687929, upload-time = "2025-06-20T21:48:32.284Z" }, + { url = "https://files.pythonhosted.org/packages/de/5f/2c78e28f309396e71ec8e4e9304a6483dcbc36172b5cea8f291994163425/hf_xet-1.1.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:9fa6e3ee5d61912c4a113e0708eaaef987047616465ac7aa30f7121a48fc1af8", size = 2556338, upload-time = "2025-06-20T21:48:30.079Z" }, + { url = "https://files.pythonhosted.org/packages/6d/2f/6cad7b5fe86b7652579346cb7f85156c11761df26435651cbba89376cd2c/hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc874b5c843e642f45fd85cda1ce599e123308ad2901ead23d3510a47ff506d1", size = 3102894, upload-time = "2025-06-20T21:48:28.114Z" }, + { url = "https://files.pythonhosted.org/packages/d0/54/0fcf2b619720a26fbb6cc941e89f2472a522cd963a776c089b189559447f/hf_xet-1.1.5-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dbba1660e5d810bd0ea77c511a99e9242d920790d0e63c0e4673ed36c4022d18", size = 3002134, upload-time = "2025-06-20T21:48:25.906Z" }, + { url = "https://files.pythonhosted.org/packages/f3/92/1d351ac6cef7c4ba8c85744d37ffbfac2d53d0a6c04d2cabeba614640a78/hf_xet-1.1.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ab34c4c3104133c495785d5d8bba3b1efc99de52c02e759cf711a91fd39d3a14", size = 3171009, upload-time = "2025-06-20T21:48:33.987Z" }, + { url = "https://files.pythonhosted.org/packages/c9/65/4b2ddb0e3e983f2508528eb4501288ae2f84963586fbdfae596836d5e57a/hf_xet-1.1.5-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:83088ecea236d5113de478acb2339f92c95b4fb0462acaa30621fac02f5a534a", size = 3279245, upload-time = "2025-06-20T21:48:36.051Z" }, + { url = "https://files.pythonhosted.org/packages/f0/55/ef77a85ee443ae05a9e9cba1c9f0dd9241eb42da2aeba1dc50f51154c81a/hf_xet-1.1.5-cp37-abi3-win_amd64.whl", hash = "sha256:73e167d9807d166596b4b2f0b585c6d5bd84a26dea32843665a8b58f6edba245", size = 2738931, upload-time = "2025-06-20T21:48:39.482Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "0.34.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/cd/841bc8e0550d69f632a15cdd70004e95ba92cd0fbe13087d6669e2bb5f44/huggingface_hub-0.34.1.tar.gz", hash = "sha256:6978ed89ef981de3c78b75bab100a214843be1cc9d24f8e9c0dc4971808ef1b1", size = 456783, upload-time = "2025-07-25T14:54:54.758Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/cf/dd53c0132f50f258b06dd37a4616817b1f1f6a6b38382c06effd04bb6881/huggingface_hub-0.34.1-py3-none-any.whl", hash = "sha256:60d843dcb7bc335145b20e7d2f1dfe93910f6787b2b38a936fb772ce2a83757c", size = 558788, upload-time = "2025-07-25T14:54:52.957Z" }, +] + [[package]] name = "hydra-core" version = "1.3.2" @@ -889,8 +1103,10 @@ name = "ipython" version = "9.0.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform != 'linux'", + "python_full_version >= '3.13' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and sys_platform == 'linux'", + "python_full_version >= '3.13' and sys_platform != 'linux'", + "python_full_version == '3.12.*' and sys_platform != 'linux'", "python_full_version == '3.11.*' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux'", ] @@ -940,6 +1156,8 @@ dependencies = [ { name = "ninja" }, { name = "numpy" }, { name = "omegaconf" }, + { name = "optree" }, + { name = "orb-models" }, { name = "pandas" }, { name = "plotly" }, { name = "posebusters" }, @@ -961,6 +1179,13 @@ dependencies = [ { name = "wandb" }, ] +[package.optional-dependencies] +analysis = [ + { name = "polars" }, + { name = "pyarrow" }, + { name = "seaborn" }, +] + [package.dev-dependencies] dev = [ { name = "ipykernel" }, @@ -976,7 +1201,7 @@ dev = [ requires-dist = [ { name = "ase", specifier = ">=3.23.0" }, { name = "e3nn", specifier = ">=0.5.6" }, - { name = "e3tools", specifier = ">=0.1.1" }, + { name = "e3tools", specifier = ">=0.1.2" }, { name = "einops", specifier = ">=0.8.0" }, { name = "hydra-core", specifier = ">=1.3.2" }, { name = "lightning", specifier = ">=2.4.0" }, @@ -986,16 +1211,21 @@ requires-dist = [ { name = "ninja", specifier = ">=1.11.1.3" }, { name = "numpy", specifier = ">=2" }, { name = "omegaconf", specifier = ">=2.3.0" }, + { name = "optree", specifier = ">=0.17.0" }, + { name = "orb-models", specifier = ">=0.5.4" }, { name = "pandas", specifier = ">=2.1.0" }, { name = "plotly", specifier = ">=5.24.1" }, + { name = "polars", marker = "extra == 'analysis'", specifier = ">=1.32.0" }, { name = "posebusters", specifier = ">=0.3.1" }, { name = "pot", specifier = ">=0.9.5" }, { name = "py3dmol", specifier = ">=2.4.2" }, + { name = "pyarrow", marker = "extra == 'analysis'", specifier = ">=21.0.0" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "rdkit", specifier = ">=2024.3.6" }, { name = "requests", specifier = ">=2.32.3" }, { name = "s3fs", extras = ["boto3"], specifier = ">=2024.10.0" }, { name = "scipy", specifier = ">=1.13.1" }, + { name = "seaborn", marker = "extra == 'analysis'", specifier = ">=0.13.2" }, { name = "statsmodels", specifier = ">=0.14.0" }, { name = "tabulate", specifier = ">=0.9.0" }, { name = "torch", specifier = ">=2.5.1" }, @@ -1006,6 +1236,7 @@ requires-dist = [ { name = "universal-pathlib", specifier = ">=0.2.6" }, { name = "wandb", specifier = ">=0.19.1" }, ] +provides-extras = ["analysis"] [package.metadata.requires-dev] dev = [ @@ -1252,6 +1483,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/c8/dd7dd77f5022b0ae6d377d28dd49b470d820dc2fc51592c2c4860aa6cb2d/lovelyplots-1.0.2-py3-none-any.whl", hash = "sha256:3d778300203e546d7ff642b2ee16c91fa47f0980574b409148e1cb8976fe6832", size = 12968, upload-time = "2024-03-06T05:18:49.258Z" }, ] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596, upload-time = "2023-06-03T06:41:14.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528, upload-time = "2023-06-03T06:41:11.019Z" }, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -1405,6 +1648,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/26/b7/cd45c6bae1566572d96bda6e749c63886c9c6ded079e34615376de5fe26e/mdtraj-1.10.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c083e080d1ddf3eb25acec343f4efe93671e1508e17f61b656db8c3a50a38d1", size = 7800597, upload-time = "2025-02-07T18:11:16.853Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -1685,81 +1937,77 @@ wheels = [ [[package]] name = "nvidia-cublas-cu12" -version = "12.6.4.1" +version = "12.8.4.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322, upload-time = "2024-11-20T17:40:25.65Z" }, + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, ] [[package]] name = "nvidia-cuda-cupti-cu12" -version = "12.6.80" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132", size = 8917980, upload-time = "2024-11-20T17:36:04.019Z" }, - { url = "https://files.pythonhosted.org/packages/a5/24/120ee57b218d9952c379d1e026c4479c9ece9997a4fb46303611ee48f038/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73", size = 8917972, upload-time = "2024-10-01T16:58:06.036Z" }, + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, ] [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.6.77" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380, upload-time = "2024-10-01T17:00:14.643Z" }, + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, ] [[package]] name = "nvidia-cuda-runtime-cu12" -version = "12.6.77" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/23/e717c5ac26d26cf39a27fbc076240fad2e3b817e5889d671b67f4f9f49c5/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7", size = 897690, upload-time = "2024-11-20T17:35:30.697Z" }, - { url = "https://files.pythonhosted.org/packages/f0/62/65c05e161eeddbafeca24dc461f47de550d9fa8a7e04eb213e32b55cfd99/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8", size = 897678, upload-time = "2024-10-01T16:57:33.821Z" }, + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, ] [[package]] name = "nvidia-cudnn-cu12" -version = "9.5.1.17" +version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386, upload-time = "2024-10-25T19:54:26.39Z" }, + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, ] [[package]] name = "nvidia-cufft-cu12" -version = "11.3.0.4" +version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632, upload-time = "2024-11-20T17:41:32.357Z" }, - { url = "https://files.pythonhosted.org/packages/60/de/99ec247a07ea40c969d904fc14f3a356b3e2a704121675b75c366b694ee1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca", size = 200221622, upload-time = "2024-10-01T17:03:58.79Z" }, + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, ] [[package]] name = "nvidia-cufile-cu12" -version = "1.11.1.6" +version = "1.13.1.3" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/66/cc9876340ac68ae71b15c743ddb13f8b30d5244af344ec8322b449e35426/nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159", size = 1142103, upload-time = "2024-11-20T17:42:11.83Z" }, + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, ] [[package]] name = "nvidia-curand-cu12" -version = "10.3.7.77" +version = "10.3.9.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010, upload-time = "2024-11-20T17:42:50.958Z" }, - { url = "https://files.pythonhosted.org/packages/4a/aa/2c7ff0b5ee02eaef890c0ce7d4f74bc30901871c5e45dee1ae6d0083cd80/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117", size = 56279000, upload-time = "2024-10-01T17:04:45.274Z" }, + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, ] [[package]] name = "nvidia-cusolver-cu12" -version = "11.7.1.2" +version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, @@ -1767,53 +2015,50 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790, upload-time = "2024-11-20T17:43:43.211Z" }, - { url = "https://files.pythonhosted.org/packages/9f/81/baba53585da791d043c10084cf9553e074548408e04ae884cfe9193bd484/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6", size = 158229780, upload-time = "2024-10-01T17:05:39.875Z" }, + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, ] [[package]] name = "nvidia-cusparse-cu12" -version = "12.5.4.2" +version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367, upload-time = "2024-11-20T17:44:54.824Z" }, - { url = "https://files.pythonhosted.org/packages/43/ac/64c4316ba163e8217a99680c7605f779accffc6a4bcd0c778c12948d3707/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f", size = 216561357, upload-time = "2024-10-01T17:06:29.861Z" }, + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, ] [[package]] name = "nvidia-cusparselt-cu12" -version = "0.6.3" +version = "0.7.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796, upload-time = "2024-10-15T21:29:17.709Z" }, + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, ] [[package]] name = "nvidia-nccl-cu12" -version = "2.26.2" +version = "2.27.3" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/ca/f42388aed0fddd64ade7493dbba36e1f534d4e6fdbdd355c6a90030ae028/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6", size = 201319755, upload-time = "2025-03-13T00:29:55.296Z" }, + { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" }, ] [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.6.85" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971, upload-time = "2024-11-20T17:46:53.366Z" }, + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, ] [[package]] name = "nvidia-nvtx-cu12" -version = "12.6.77" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2", size = 89276, upload-time = "2024-11-20T17:38:27.621Z" }, - { url = "https://files.pythonhosted.org/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265, upload-time = "2024-10-01T17:00:38.172Z" }, + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, ] [[package]] @@ -1852,6 +2097,115 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8d/4c/e0370709aaf9d7ceb68f975cac559751e75954429a77e83202e680606560/opt_einsum_fx-0.1.4-py3-none-any.whl", hash = "sha256:85f489f4c7c31fd88d5faf9669c09e61ec37a30098809fdcfe2a08a9e42f23c9", size = 13213, upload-time = "2021-11-07T20:49:32.395Z" }, ] +[[package]] +name = "optree" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/56/c7/0853e0c59b135dff770615d2713b547b6b3b5cde7c10995b4a5825244612/optree-0.17.0.tar.gz", hash = "sha256:5335a5ec44479920620d72324c66563bd705ab2a698605dd4b6ee67dbcad7ecd", size = 163111, upload-time = "2025-07-25T11:26:11.586Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/a0/d5795ac13390b04822f1c61699f684cde682b57bf0a2d6b406019e1762ae/optree-0.17.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:85ec183b8eec6efc9a5572c2a84c62214c949555efbc69ca2381aca6048d08df", size = 622371, upload-time = "2025-07-25T11:24:23.345Z" }, + { url = "https://files.pythonhosted.org/packages/53/8b/ae8ddb511e680eb9d61edd2f5245be88ce050456658fb165550144f9a509/optree-0.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6e77b6e0b7bb3ecfeb9a92ba605ef21b39bff38829b745af993e2e2b474322e2", size = 337260, upload-time = "2025-07-25T11:24:25.291Z" }, + { url = "https://files.pythonhosted.org/packages/91/f9/6ca076fd4c6f16be031afdc711a2676c1ff15bd1717ee2e699179b1a29bc/optree-0.17.0-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98990201f352dba253af1a995c1453818db5f08de4cae7355d85aa6023676a52", size = 350398, upload-time = "2025-07-25T11:24:26.672Z" }, + { url = "https://files.pythonhosted.org/packages/95/4c/81344cbdcf8ea8525a21c9d65892d7529010ee2146c53423b2e9a84441ba/optree-0.17.0-cp310-cp310-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:e1a40adf6bb78a6a4b4f480879de2cb6b57d46d680a4d9834aa824f41e69c0d9", size = 404834, upload-time = "2025-07-25T11:24:28.988Z" }, + { url = "https://files.pythonhosted.org/packages/e5/c4/ac1880372a89f5c21514a7965dfa23b1afb2ad683fb9804d366727de9ecf/optree-0.17.0-cp310-cp310-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:78a113436a0a440f900b2799584f3cc2b2eea1b245d81c3583af42ac003e333c", size = 402116, upload-time = "2025-07-25T11:24:30.396Z" }, + { url = "https://files.pythonhosted.org/packages/ff/72/ad6be4d6a03805cf3921b492494cb3371ca28060d5ad19d5a36e10c4d67d/optree-0.17.0-cp310-cp310-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0e45c16018f4283f028cf839b707b7ac734e8056a31b7198a1577161fcbe146d", size = 398491, upload-time = "2025-07-25T11:24:31.725Z" }, + { url = "https://files.pythonhosted.org/packages/d9/c1/6827fb504351f9a3935699b0eb31c8a6af59d775ee78289a25e0ba54f732/optree-0.17.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b698613d821d80cc216a2444ebc3145c8bf671b55a2223058a6574c1483a65f6", size = 387957, upload-time = "2025-07-25T11:24:32.759Z" }, + { url = "https://files.pythonhosted.org/packages/21/3d/44b3cbe4c9245a13b2677e30db2aafadf00bda976a551d64a31dc92f4977/optree-0.17.0-cp310-cp310-win32.whl", hash = "sha256:d07bfd8ce803dbc005502a89fda5f5e078e237342eaa36fb0c46cfbdf750bc76", size = 280064, upload-time = "2025-07-25T11:24:33.875Z" }, + { url = "https://files.pythonhosted.org/packages/74/fa/83d4cd387043483ee23617b048829a1289bf54afe2f6cb98ec7b27133369/optree-0.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:d009d368ef06b8757891b772cad24d4f84122bd1877f7674fb8227d6e15340b4", size = 304398, upload-time = "2025-07-25T11:24:34.844Z" }, + { url = "https://files.pythonhosted.org/packages/21/4f/752522f318683efa7bba1895667c9841165d0284f6dfadf601769f6398ce/optree-0.17.0-cp310-cp310-win_arm64.whl", hash = "sha256:3571085ed9a5f39ff78ef57def0e9607c6b3f0099b6910524a0b42f5d58e481e", size = 308260, upload-time = "2025-07-25T11:24:36.144Z" }, + { url = "https://files.pythonhosted.org/packages/d8/eb/389a7dae8b113064f53909707aea9d72372fdc2eb918c48783c443cb3438/optree-0.17.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:09fbc0e5e42b20cab11851dffb7abe2fdf289c45d29e5be2b50b4ea93d069a9f", size = 640773, upload-time = "2025-07-25T11:24:37.25Z" }, + { url = "https://files.pythonhosted.org/packages/2b/bb/2d78b524989cabb5720e85ea366addc8589b4bbd0ce3f5ea58e370e5636a/optree-0.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:90a5864689268eda75d90abded5d474ae0a7ae2608d510626724fb78a1955948", size = 346402, upload-time = "2025-07-25T11:24:38.25Z" }, + { url = "https://files.pythonhosted.org/packages/73/5c/13a2a864b0c0b39c3c193be534a195a3ab2463c7d0443d4a76e749e3ff83/optree-0.17.0-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3080c564c9760711aa72d1b4d700ce1417f99ad087136f415c4eb8221169e2a3", size = 362797, upload-time = "2025-07-25T11:24:39.509Z" }, + { url = "https://files.pythonhosted.org/packages/da/f5/ff7dcb5a0108ee89c2be09aed2ebd26a7e1333d8122031aa9d9322b24ee6/optree-0.17.0-cp311-cp311-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:834a8fb358b608240b3a38706a09b43974675624485fad64c8ee641dae2eb57d", size = 419450, upload-time = "2025-07-25T11:24:40.555Z" }, + { url = "https://files.pythonhosted.org/packages/1b/e6/48a97aefd18770b55e5ed456d8183891f325cdb6d90592e5f072ed6951f8/optree-0.17.0-cp311-cp311-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:1a2bd263e6b5621d000d0f94de1f245414fd5dbce365a24b7b89b1ed0ef56cf9", size = 417557, upload-time = "2025-07-25T11:24:42.396Z" }, + { url = "https://files.pythonhosted.org/packages/c4/b1/4e280edab8a86be47ec1f9bd9ed4b685d2e15f0950ae62b613b26d12a1da/optree-0.17.0-cp311-cp311-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:9b37daca4ad89339b1f5320cc61ac600dcf976adbb060769d36d5542d6ebfedf", size = 414174, upload-time = "2025-07-25T11:24:43.51Z" }, + { url = "https://files.pythonhosted.org/packages/db/3b/49a9a1986215dd342525974deeb17c260a83fee8fad147276fd710ac8718/optree-0.17.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a146a6917f3e28cfdc268ff1770aa696c346482dd3da681c3ff92153d94450ea", size = 402000, upload-time = "2025-07-25T11:24:44.819Z" }, + { url = "https://files.pythonhosted.org/packages/00/8d/13b79d3394b83f4b1c93daac336f0eca5cb1cd5f58e10618f2c2db779cb7/optree-0.17.0-cp311-cp311-win32.whl", hash = "sha256:6b0446803d08f6aaae84f82f03c51527f36dfa15850873fc0183792247bc0071", size = 285777, upload-time = "2025-07-25T11:24:45.976Z" }, + { url = "https://files.pythonhosted.org/packages/90/32/da5191a347e33a78c2804a0cbfaed8eecb758818efda4b4d70bfd9b9b38d/optree-0.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:e39f4f00b2967116badd9617ad6aa9845d8327fe13b6dbf5bc36d8c7b4a5ea03", size = 313761, upload-time = "2025-07-25T11:24:47.047Z" }, + { url = "https://files.pythonhosted.org/packages/e1/ea/7cae17a37a8ef67a33c354fce6f136d5f253d5afa40f68701252b1b2c2a0/optree-0.17.0-cp311-cp311-win_arm64.whl", hash = "sha256:50d4dbcbca3e379cc6b374f9b5a5626ff7ea41df8373e26c3af41d89d8a4b3d5", size = 318242, upload-time = "2025-07-25T11:24:48.708Z" }, + { url = "https://files.pythonhosted.org/packages/79/ce/471ff57336630f2434238a8cb8401e0d714ee7d54a6117823fd85de5f656/optree-0.17.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:09156e2ea62cde66dcbd9a450a5517ad6bad07d4ffc98fab0982c1e4f538341a", size = 654627, upload-time = "2025-07-25T11:24:49.754Z" }, + { url = "https://files.pythonhosted.org/packages/aa/ef/3143b7840dd2daedf1257643119c0f3addd23cf90cc9d2efc88f8166931e/optree-0.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:750f24304d1d437c8b235d4bc9e4afda17d85950706c34a875c16049f707eeb4", size = 351124, upload-time = "2025-07-25T11:24:50.813Z" }, + { url = "https://files.pythonhosted.org/packages/41/90/e12dea2cb5d8a5e17bbe3011ed4e972b89c027272a816db4897589751cad/optree-0.17.0-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e13ae51a63d69db445f269a3a4fd1d6edb064a705188d007ea47c9f034788fc5", size = 365869, upload-time = "2025-07-25T11:24:51.807Z" }, + { url = "https://files.pythonhosted.org/packages/76/ee/21af214663960a479863cd6c03d7a0abc8123ea22a6ea34689c2eed88ccd/optree-0.17.0-cp312-cp312-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:5958f58423cc7870cb011c8c8f92687397380886e8c9d33adac752147e7bbc3f", size = 424465, upload-time = "2025-07-25T11:24:53.124Z" }, + { url = "https://files.pythonhosted.org/packages/54/a3/64b184a79373753f4f46a5cd301ea581f71d6dc1a5c103bd2394f0925d40/optree-0.17.0-cp312-cp312-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:970ae4e47727b4c5526fc583b87d29190e576f6a2b6c19e8671589b73d256250", size = 420686, upload-time = "2025-07-25T11:24:54.212Z" }, + { url = "https://files.pythonhosted.org/packages/6c/6d/b6051b0b1ef9a49df96a66e9e62fc02620d2115d1ba659888c94e67fcfc9/optree-0.17.0-cp312-cp312-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:54177fd3e6e05c08b66329e26d7d44b85f24125f25c6b74c921499a1b31b8f70", size = 421225, upload-time = "2025-07-25T11:24:55.213Z" }, + { url = "https://files.pythonhosted.org/packages/f6/f1/940bc959aaef9eede8bb1b1127833b0929c6ffa9268ec0f6cb19877e2027/optree-0.17.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1959cfbc38c228c8195354967cda64887b96219924b7b3759e5ee355582c1ec", size = 408819, upload-time = "2025-07-25T11:24:56.315Z" }, + { url = "https://files.pythonhosted.org/packages/56/52/ce527556e27dbf77266c1b1bb313ca446c94bc6edd6d7a882dbded028197/optree-0.17.0-cp312-cp312-win32.whl", hash = "sha256:039ea98c0cd94a64040d6f6d21dbe5cd9731bb380d7893f78d6898672080a232", size = 289107, upload-time = "2025-07-25T11:24:57.357Z" }, + { url = "https://files.pythonhosted.org/packages/eb/f1/aecb0199d269ad8ea41a86182474f98378a72681facbd6a06e94c23a2d02/optree-0.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:c3a21109f635ce353d116ed1d77a7dfd77b898bcdaccef3bf74881ce7d6d54d8", size = 314074, upload-time = "2025-07-25T11:24:58.499Z" }, + { url = "https://files.pythonhosted.org/packages/3a/20/615ad64d24318709a236163dd8620fa7879a7720bfd0c755604d3dceeb76/optree-0.17.0-cp312-cp312-win_arm64.whl", hash = "sha256:1a39f957299426d2d4aa36cbc1acd71edb198ff0f28ddb43029bf58efe34a9a1", size = 316409, upload-time = "2025-07-25T11:24:59.855Z" }, + { url = "https://files.pythonhosted.org/packages/21/04/9706d11b880186e9e9d66d7c21ce249b2ce0212645137cc13fdd18247c26/optree-0.17.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:b5995a3efce4b00a14049268a81ab0379656a41ddf3c3761e3b88937fca44d48", size = 348177, upload-time = "2025-07-25T11:25:00.999Z" }, + { url = "https://files.pythonhosted.org/packages/ae/4b/0415c18816818ac871c9f3d5c7c5f4ceb83baff03ed511c9c94591ace4bc/optree-0.17.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:d06e8143d16fe6c0708f3cc2807b5b65f815d60ee2b52f3d79e4022c95563482", size = 354389, upload-time = "2025-07-25T11:25:02.337Z" }, + { url = "https://files.pythonhosted.org/packages/88/4d/5ce687b3945a34f0f0e17765745f146473b47177badd93b5979374d6e29c/optree-0.17.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:9537c4f82fe454a689e124462f252c4911cd7c78c6277334e7132f8157fb85e8", size = 661629, upload-time = "2025-07-25T11:25:03.429Z" }, + { url = "https://files.pythonhosted.org/packages/45/17/52ec65b80b6a17a9b7242e4cbf569c3d8035e72c49b6a3baba73aed6aa16/optree-0.17.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:79e8a594002509163d218827476f522d4f9ee6436438d90251d28d413af6740c", size = 354967, upload-time = "2025-07-25T11:25:04.523Z" }, + { url = "https://files.pythonhosted.org/packages/dd/12/24d4a417fd325ec06cfbce52716ac4f816ef696653b868960ac2ccb28436/optree-0.17.0-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dfeea4aa0fd354d27922aba63ff9d86e4e126c6bf89cfb02849e68515519f1a5", size = 368513, upload-time = "2025-07-25T11:25:05.548Z" }, + { url = "https://files.pythonhosted.org/packages/30/e2/34e392209933e2c582c67594a7a6b4851bca4015c83b51c7508384b616b4/optree-0.17.0-cp313-cp313-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:6b2ff8999a9b84d00f23a032b6b3f13678894432a335d024e0670b9880f238ca", size = 430378, upload-time = "2025-07-25T11:25:06.918Z" }, + { url = "https://files.pythonhosted.org/packages/5f/16/0a0d6139022e9a53ecb1212fb6fbc5b60eff824371071ef5f5fa481d8167/optree-0.17.0-cp313-cp313-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ea8bef525432b38a84e7448348da1a2dc308375bce79c77675cc50a501305851", size = 423294, upload-time = "2025-07-25T11:25:08.043Z" }, + { url = "https://files.pythonhosted.org/packages/ef/60/2e083dabb6aff6d939d8aab16ba3dbe6eee9429597a13f3fca57b33cdcde/optree-0.17.0-cp313-cp313-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f95b81aa67538d38316b184a6ff39a3725ee5c8555fba21dcb692f8d7c39302e", size = 424633, upload-time = "2025-07-25T11:25:09.141Z" }, + { url = "https://files.pythonhosted.org/packages/af/fd/0e4229b5fa3fd9d3c779a606c0f358ffbdfee717f49b3477facd04de2cec/optree-0.17.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e808a1125169ae90de623456ef2423eb84a8578a74f03fe48b06b8561c2cc31d", size = 414866, upload-time = "2025-07-25T11:25:10.214Z" }, + { url = "https://files.pythonhosted.org/packages/e7/81/976082e979d42d36f9f81ee300d8fe7e86ca87588b70e372a40cb9203c9b/optree-0.17.0-cp313-cp313-win32.whl", hash = "sha256:4f3e0c5b20a4ef5b5a2688b5a07221cf1d2a8b2a57f82cf0c601f9d16f71450b", size = 289505, upload-time = "2025-07-25T11:25:11.616Z" }, + { url = "https://files.pythonhosted.org/packages/fb/ab/5b2c75c262c106747b5fbf1603a94ca8047896e719c3219ca85cb2d9c300/optree-0.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:057f95213e403ff3a975f287aef6b687299d0c4512d211de24b1b98050cd4fbf", size = 316703, upload-time = "2025-07-25T11:25:12.638Z" }, + { url = "https://files.pythonhosted.org/packages/68/d6/78c0c927867b60d9b010bac84eae4046c761084bf2ed8a8d25521965ab4f/optree-0.17.0-cp313-cp313-win_arm64.whl", hash = "sha256:749dbecfd04edd50493b35bfb1f5be350f31b384533301e2257d4b0d0132544c", size = 318098, upload-time = "2025-07-25T11:25:13.755Z" }, + { url = "https://files.pythonhosted.org/packages/98/fd/6b5fdf3430157eced42d193bb49805668a380c672cc40317efe1dea3d739/optree-0.17.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:98c11fae09c5861f42c400f0fa3851f3d58ceba347267d458332710f094d5f75", size = 750506, upload-time = "2025-07-25T11:25:15.267Z" }, + { url = "https://files.pythonhosted.org/packages/19/0a/d8acb03fbf2edfd240a55363d903fad577e880a30a3117b60545a2a31aa5/optree-0.17.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:0b9f25c47de72044d7e1f42e9ed4c765f0867d321a2e6d194bc5facf69316417", size = 399106, upload-time = "2025-07-25T11:25:16.671Z" }, + { url = "https://files.pythonhosted.org/packages/39/df/b8882f5519c85af146de3a79a08066a56fe634b23052c593fcedc70bfcd7/optree-0.17.0-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8e45a13b35873712e095fe0f7fd6e9c4f98f3bd5af6f5dc33c17b80357bc97fc", size = 386945, upload-time = "2025-07-25T11:25:17.728Z" }, + { url = "https://files.pythonhosted.org/packages/ca/d7/91f4efb509bda601a1591465c4a5bd55320e4bafe06b294bf80754127b0e/optree-0.17.0-cp313-cp313t-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:bfaf04d833dc53e5cfccff3b564e934a49086158472e31d84df31fce6d4f7b1c", size = 444177, upload-time = "2025-07-25T11:25:18.749Z" }, + { url = "https://files.pythonhosted.org/packages/84/17/a4833006e925c6ed5c45ceb02e65c9e9a260e70da6523858fcf628481847/optree-0.17.0-cp313-cp313t-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b4c1d030ac1c881803f5c8e23d241159ae403fd00cdf57625328f282fc671ebd", size = 439198, upload-time = "2025-07-25T11:25:19.865Z" }, + { url = "https://files.pythonhosted.org/packages/ef/d1/c08fc60f6dfcb1b86ca1fdc0add08a98412a1596cd45830acbdc309f2cdb/optree-0.17.0-cp313-cp313t-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:bd7738709970acab5d963896192b63b2718be93bb6c0bcea91895ea157fa2b13", size = 439391, upload-time = "2025-07-25T11:25:20.942Z" }, + { url = "https://files.pythonhosted.org/packages/05/8f/461e10201003e6ad6bff3c594a29a7e044454aba68c5f795f4c8386ce47c/optree-0.17.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1644bc24b6e93cafccfdeee44157c3d4ae9bb0af3e861300602d716699865b1a", size = 426555, upload-time = "2025-07-25T11:25:21.968Z" }, + { url = "https://files.pythonhosted.org/packages/b5/4a/334d579dcb1ecea722ad37b7a8b7b29bb05ab7fe4464479862932ffd1869/optree-0.17.0-cp313-cp313t-win32.whl", hash = "sha256:f6be1f6f045f326bd419285ee92ebb13f1317149cbea84ca73c5bf06109a61bb", size = 319949, upload-time = "2025-07-25T11:25:23.127Z" }, + { url = "https://files.pythonhosted.org/packages/c8/96/5879944aee653471ad2a1ca5194ece0ca5d59de7c1d1fc5682ea3fb42057/optree-0.17.0-cp313-cp313t-win_amd64.whl", hash = "sha256:9d06b89803b1c72044fa5f07c708e33af7fe38ca2f5001cc9b6463894105b052", size = 352862, upload-time = "2025-07-25T11:25:24.214Z" }, + { url = "https://files.pythonhosted.org/packages/0d/de/cc600c216db4caa5b9ec5372e0c7fa05cd38eacde7e519c969ceab8712b6/optree-0.17.0-cp313-cp313t-win_arm64.whl", hash = "sha256:43f243d04fdba644647b1cabbfe4d7ca5fdb16c02e6d7d56e638d3e0b73566e8", size = 352101, upload-time = "2025-07-25T11:25:25.318Z" }, + { url = "https://files.pythonhosted.org/packages/d5/f7/cc6e920faaf96f78e373bf4ca83f806a40892104c0d437ab03402afeb94d/optree-0.17.0-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:8808e0b6bd9d0288b76cac6ed5d589532c9c4f3f2b88157c70591e8a0cc9aa3b", size = 662838, upload-time = "2025-07-25T11:25:26.439Z" }, + { url = "https://files.pythonhosted.org/packages/22/fd/a8859f401de8305bd09f6f0f7491e6153cf8e50a8390eaa2b9d0e1f1fc95/optree-0.17.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:80c9dd735e7990a48f3da981125df6c10c9990d1876be7a034357aece600e07f", size = 355857, upload-time = "2025-07-25T11:25:27.55Z" }, + { url = "https://files.pythonhosted.org/packages/3c/21/6480d23b52b2e23b976fe254b9fbdc4b514e90a349b1ee73565b185c69f1/optree-0.17.0-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dd21e0a89806cc3b86aaa578a73897d56085038fe432043534a23b2e559d7691", size = 369929, upload-time = "2025-07-25T11:25:28.897Z" }, + { url = "https://files.pythonhosted.org/packages/b3/29/69bb26473ff862a1792f5568c977e7a2580e08afe0fdcd7a7b3e1e4d6933/optree-0.17.0-cp314-cp314-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:9211c61285b8b3e42fd0e803cebd6e2b0987d8b2edffe45b42923debca09a9df", size = 430381, upload-time = "2025-07-25T11:25:29.984Z" }, + { url = "https://files.pythonhosted.org/packages/c8/8b/2c0a38c0d0c2396d698b97216cd6814d6754d11997b6ac66c57d87d71bae/optree-0.17.0-cp314-cp314-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:87938255749a45979c4e331627cb33d81aa08b0a09d024368b3e25ff67f0e9f2", size = 424461, upload-time = "2025-07-25T11:25:31.116Z" }, + { url = "https://files.pythonhosted.org/packages/a7/77/08fda3f97621190d50762225ee8bad87463a8b3a55fba451a999971ff130/optree-0.17.0-cp314-cp314-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3432858145fd1955a3be12207507466ac40a6911f428bf5d2d6c7f67486530a2", size = 427234, upload-time = "2025-07-25T11:25:32.289Z" }, + { url = "https://files.pythonhosted.org/packages/ea/b5/b4f19952c36d6448c85a6ef6be5f916dd13548de2b684ab123f04b450850/optree-0.17.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5afe3e9e2f6da0a0a5c0892f32f675eb88965036b061aa555b74e6c412a05e17", size = 413863, upload-time = "2025-07-25T11:25:33.379Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8c/1da744bb0cc550aed105f8a252fa8d8270067c5e21db7b95e457f76701da/optree-0.17.0-cp314-cp314-win32.whl", hash = "sha256:db6ce8e0d8585621230446736fa99c2883b34f9e56784957f69c47e2de34bdb4", size = 294314, upload-time = "2025-07-25T11:25:34.49Z" }, + { url = "https://files.pythonhosted.org/packages/84/05/5865e2a33c535c6b47378a43605de17cc286de59b93dc7814eb122861963/optree-0.17.0-cp314-cp314-win_amd64.whl", hash = "sha256:aa963de4146fa1b5cdffb479d324262f245c957df0bb9a9b37f6fd559d027acc", size = 323848, upload-time = "2025-07-25T11:25:35.511Z" }, + { url = "https://files.pythonhosted.org/packages/f1/01/55321c0d7b6bb60d88e5f5927216bcdc03e99f1f42567a0bcc23e786554e/optree-0.17.0-cp314-cp314-win_arm64.whl", hash = "sha256:855bfc78eba74748f931be6d6b739a9b03ac82a5c96511d66f310659903f6812", size = 325642, upload-time = "2025-07-25T11:25:36.649Z" }, + { url = "https://files.pythonhosted.org/packages/ee/be/24ef1e0d4212aedb087ff7b7a324426a093172327ecf9c33d2cf4cb6a69c/optree-0.17.0-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:0ac9626a51148c8497e82e9a9c21746795e179fbdec0b01c1644031e25f0d97e", size = 750484, upload-time = "2025-07-25T11:25:37.897Z" }, + { url = "https://files.pythonhosted.org/packages/4e/80/fc26e7c120849297992b0ecf8e435f213a379cc7923ea6ab1bad7b7d9c3f/optree-0.17.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:769c74ac289cdf108986fad2a36f24f4dd5ac6cf62919f99facdce943cd37359", size = 399067, upload-time = "2025-07-25T11:25:38.953Z" }, + { url = "https://files.pythonhosted.org/packages/88/42/6003f13e66cfbe7f0011bf8509da2479aba93068cdb9d79bf46010255089/optree-0.17.0-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5739c03a3362be42cb7649e82457c90aa818aa3e82af9681d3100c3346f4a90f", size = 386975, upload-time = "2025-07-25T11:25:40.376Z" }, + { url = "https://files.pythonhosted.org/packages/d0/53/621642abd76eda5a941b47adc98be81f0052683160be776499d11b4af83d/optree-0.17.0-cp314-cp314t-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:ee07b59a08bd45aedd5252241a98841f1a5082a7b9b73df2dae6a433aa2a91d8", size = 444173, upload-time = "2025-07-25T11:25:41.474Z" }, + { url = "https://files.pythonhosted.org/packages/5b/d3/8819a2d5105a240d6793d11a61d597db91756ce84da5cee08808c6b8f61f/optree-0.17.0-cp314-cp314t-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:875c017890a4b5d566af5593cab67fe3c4845544942af57e6bb9dea17e060297", size = 439080, upload-time = "2025-07-25T11:25:42.605Z" }, + { url = "https://files.pythonhosted.org/packages/c6/ef/9dbd34dfd1ad89feb239ca9925897a14ac94f190379a3bd991afdfd94186/optree-0.17.0-cp314-cp314t-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ffa5686191139f763e13445a169765c83517164bc28e60dbedb19bed2b2655f1", size = 439422, upload-time = "2025-07-25T11:25:43.672Z" }, + { url = "https://files.pythonhosted.org/packages/86/ca/a7a7549af2951925a692df508902ed2a6a94a51bc846806d2281b1029ef9/optree-0.17.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:575cf48cc2190acb565bd2b26b6f9b15c4e3b60183e86031215badc9d5441345", size = 426579, upload-time = "2025-07-25T11:25:44.765Z" }, + { url = "https://files.pythonhosted.org/packages/e6/0c/eb4d8ef38f1b51116095985b350ac9eede7a71d40c2ffaa283e9646b04e0/optree-0.17.0-cp314-cp314t-win32.whl", hash = "sha256:f1897de02364b7ef4a5bb56ae352b674ebf2cdd33da2b0f3543340282dc1f3e1", size = 329053, upload-time = "2025-07-25T11:25:45.845Z" }, + { url = "https://files.pythonhosted.org/packages/18/c6/f8e8c339e384578e3300215c732c20033f97d5ceb4c3d23a38bdb3527d98/optree-0.17.0-cp314-cp314t-win_amd64.whl", hash = "sha256:08df33cf74518f74b1c1f4ac0b760f544796a0b1cede91191c4daea0df3f314c", size = 367555, upload-time = "2025-07-25T11:25:46.95Z" }, + { url = "https://files.pythonhosted.org/packages/97/6f/1358550954dbbbb93b23fc953800e1ff2283024505255b0f9ba901f25e0e/optree-0.17.0-cp314-cp314t-win_arm64.whl", hash = "sha256:93d08d17b7b1d82b51ee7dd3a5a21ae2391fb30fc65a1369d4855c484923b967", size = 359135, upload-time = "2025-07-25T11:25:48.062Z" }, + { url = "https://files.pythonhosted.org/packages/ca/52/350c58dce327257afd77b92258e43d0bfe00416fc167b0c256ec86dcf9e7/optree-0.17.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:f365328450c1072e7a707dce67eaa6db3f63671907c866e3751e317b27ea187e", size = 342845, upload-time = "2025-07-25T11:26:01.651Z" }, + { url = "https://files.pythonhosted.org/packages/ed/d7/3036d15c028c447b1bd65dcf8f66cfd775bfa4e52daa74b82fb1d3c88faf/optree-0.17.0-pp310-pypy310_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:adde1427e0982cfc5f56939c26b4ebbd833091a176734c79fb95c78bdf833dff", size = 350952, upload-time = "2025-07-25T11:26:02.692Z" }, + { url = "https://files.pythonhosted.org/packages/71/45/e710024ef77324e745de48efd64f6270d8c209f14107a48ffef4049ac57a/optree-0.17.0-pp310-pypy310_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a80b7e5de5dd09b9c8b62d501e29a3850b047565c336c9d004b07ee1c01f4ae1", size = 389568, upload-time = "2025-07-25T11:26:04.094Z" }, + { url = "https://files.pythonhosted.org/packages/a8/63/b5cd1309f76f53e8a3cfbc88642647e58b1d3dd39f7cb0daf60ec516a252/optree-0.17.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3c2c79652c45d82f23cbe08349456b1067ea513234a086b9a6bf1bcf128962a9", size = 306686, upload-time = "2025-07-25T11:26:05.511Z" }, + { url = "https://files.pythonhosted.org/packages/ca/40/afec131d9dd7a18d129190d407d97c95994f42b70c3d8ab897092d4de1d9/optree-0.17.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:bd92011cd0f2de40d28a95842819e778c476ab25c12731bfef1d1a0225554f83", size = 353955, upload-time = "2025-07-25T11:26:06.75Z" }, + { url = "https://files.pythonhosted.org/packages/69/c4/94a187ed3ca71194b9da6a276790e1703c7544c8f695ac915214ae8ce934/optree-0.17.0-pp311-pypy311_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f87f6f39015fc82d7adeee19900d246b89911319726e93cb2dbd4d1a809899bd", size = 363728, upload-time = "2025-07-25T11:26:07.959Z" }, + { url = "https://files.pythonhosted.org/packages/cd/99/23b7a484da8dfb814107b20ef2c93ef27c04f36aeb83bd976964a5b69e06/optree-0.17.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:58b0a83a967d2ef0f343db7182f0ad074eb1166bcaea909ae33909462013f151", size = 404649, upload-time = "2025-07-25T11:26:09.463Z" }, + { url = "https://files.pythonhosted.org/packages/bc/1f/7eca6da47eadb9ff2183bc9169eadde3dda0518e9a0187b99d5926fb2994/optree-0.17.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:e1ae8cbbcfaa45c57f5e51c544afa554cefbbb9fe9586c108aaf2aebfadf5899", size = 316368, upload-time = "2025-07-25T11:26:10.572Z" }, +] + +[[package]] +name = "orb-models" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ase" }, + { name = "cached-path" }, + { name = "dm-tree" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "torch" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ec/63/385c78f164a8062fac89ef631a414463a6029d310d78cff9ee949ef2a9cd/orb_models-0.5.4.tar.gz", hash = "sha256:bc4e7b11eac16e9b1681cb667ccbdd263edf9702433a1eb106969dcc29ce7916", size = 87763, upload-time = "2025-04-30T09:08:39.804Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/f8/6c6d12caf9a6fd25e7cb4aaeedf01ebda88cbc5fdc4a9d1db48e2683d15d/orb_models-0.5.4-py3-none-any.whl", hash = "sha256:af096f30c39cb11965aee792092b00b8cc350fd7dfbc13d53528f223f1953fe7", size = 92237, upload-time = "2025-04-30T09:08:38.328Z" }, +] + [[package]] name = "packaging" version = "24.2" @@ -2040,6 +2394,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556, upload-time = "2024-04-20T21:34:40.434Z" }, ] +[[package]] +name = "polars" +version = "1.32.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/23/6a5f151981f3ac409bed6dc48a3eaecd0592a03eb382693d4c7e749eda8b/polars-1.32.0.tar.gz", hash = "sha256:b01045981c0f23eeccfbfc870b782f93e73b74b29212fdfc8aae0be9024bc1fb", size = 4761045, upload-time = "2025-08-01T01:43:22.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/40/5b27067d10b5a77ab4094932118e16629ffb20ea9ae5f7d1178e04087891/polars-1.32.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:94f7c6a3b30bc99bc6b682ea42bb1ae983e33a302ca21aacbac50ae19e34fcf2", size = 37479518, upload-time = "2025-08-01T01:42:18.603Z" }, + { url = "https://files.pythonhosted.org/packages/08/b7/ca28ac10d340fb91bffb2751efd52aebc9799ae161b867214c6299c8f75b/polars-1.32.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:8bf14c16164839e62c741a863942a94a9a463db21e797452fca996c8afaf8827", size = 34214196, upload-time = "2025-08-01T01:42:22.667Z" }, + { url = "https://files.pythonhosted.org/packages/61/97/fe3797e8e1d4f9eadab32ffe218a841b8874585b6c9bd0f1a26469fb2992/polars-1.32.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f4c15adb97d44766d30c759f5cebbdb64d361e8349ef10b5afc7413f71bf4b72", size = 37985353, upload-time = "2025-08-01T01:42:26.033Z" }, + { url = "https://files.pythonhosted.org/packages/a0/7e/2baa2858556e970cc6a35c0d8ad34b2f9d982f1766c0a1fec20ca529a947/polars-1.32.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:13af55890734f89b76016a395fb2e7460e7d9feecf50ed2f55cf0f05a1c0c991", size = 35183912, upload-time = "2025-08-01T01:42:30.446Z" }, + { url = "https://files.pythonhosted.org/packages/ef/41/0e6821dccc5871186a9b95af3990404aa283318263918d33ac974b35cb37/polars-1.32.0-cp39-abi3-win_amd64.whl", hash = "sha256:0397fc2501a5d5f1bb3fe8d27e0c26c7a5349b4110157c0fb7833cd3f5921c9e", size = 37747905, upload-time = "2025-08-01T01:42:33.975Z" }, + { url = "https://files.pythonhosted.org/packages/c2/93/d06df0817da93f922a67e27e9e0f407856991374daa62687e2a45a18935c/polars-1.32.0-cp39-abi3-win_arm64.whl", hash = "sha256:dd84e24422509e1ec9be46f67f758d0bd9944d1ae4eacecee4f53adaa8ecd822", size = 33978543, upload-time = "2025-08-01T01:42:36.779Z" }, +] + [[package]] name = "posebusters" version = "0.3.6" @@ -2205,6 +2573,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/35/6c4c6fc8774a9e3629cd750dc24a7a4fb090a25ccd5c3246d127b70f9e22/propcache-0.3.0-py3-none-any.whl", hash = "sha256:67dda3c7325691c2081510e92c561f465ba61b975f481735aefdfc845d2cd043", size = 12101, upload-time = "2025-02-20T19:03:27.202Z" }, ] +[[package]] +name = "proto-plus" +version = "1.26.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142, upload-time = "2025-03-10T15:54:38.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163, upload-time = "2025-03-10T15:54:37.335Z" }, +] + [[package]] name = "protobuf" version = "5.29.4" @@ -2261,6 +2641,70 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/20/923885064f4e4d4392eb2be798532d91b315f9e60ef44f49f4800ba3c57a/py3Dmol-2.4.2-py2.py3-none-any.whl", hash = "sha256:bec23d9a015d692279a5f7d4db92803e4e82ba3bdcc1434a5b6a2be98a347856", size = 7046, upload-time = "2024-11-08T22:19:21.631Z" }, ] +[[package]] +name = "pyarrow" +version = "21.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ef/c2/ea068b8f00905c06329a3dfcd40d0fcc2b7d0f2e355bdb25b65e0a0e4cd4/pyarrow-21.0.0.tar.gz", hash = "sha256:5051f2dccf0e283ff56335760cbc8622cf52264d67e359d5569541ac11b6d5bc", size = 1133487, upload-time = "2025-07-18T00:57:31.761Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/d9/110de31880016e2afc52d8580b397dbe47615defbf09ca8cf55f56c62165/pyarrow-21.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:e563271e2c5ff4d4a4cbeb2c83d5cf0d4938b891518e676025f7268c6fe5fe26", size = 31196837, upload-time = "2025-07-18T00:54:34.755Z" }, + { url = "https://files.pythonhosted.org/packages/df/5f/c1c1997613abf24fceb087e79432d24c19bc6f7259cab57c2c8e5e545fab/pyarrow-21.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fee33b0ca46f4c85443d6c450357101e47d53e6c3f008d658c27a2d020d44c79", size = 32659470, upload-time = "2025-07-18T00:54:38.329Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ed/b1589a777816ee33ba123ba1e4f8f02243a844fed0deec97bde9fb21a5cf/pyarrow-21.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:7be45519b830f7c24b21d630a31d48bcebfd5d4d7f9d3bdb49da9cdf6d764edb", size = 41055619, upload-time = "2025-07-18T00:54:42.172Z" }, + { url = "https://files.pythonhosted.org/packages/44/28/b6672962639e85dc0ac36f71ab3a8f5f38e01b51343d7aa372a6b56fa3f3/pyarrow-21.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:26bfd95f6bff443ceae63c65dc7e048670b7e98bc892210acba7e4995d3d4b51", size = 42733488, upload-time = "2025-07-18T00:54:47.132Z" }, + { url = "https://files.pythonhosted.org/packages/f8/cc/de02c3614874b9089c94eac093f90ca5dfa6d5afe45de3ba847fd950fdf1/pyarrow-21.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:bd04ec08f7f8bd113c55868bd3fc442a9db67c27af098c5f814a3091e71cc61a", size = 43329159, upload-time = "2025-07-18T00:54:51.686Z" }, + { url = "https://files.pythonhosted.org/packages/a6/3e/99473332ac40278f196e105ce30b79ab8affab12f6194802f2593d6b0be2/pyarrow-21.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9b0b14b49ac10654332a805aedfc0147fb3469cbf8ea951b3d040dab12372594", size = 45050567, upload-time = "2025-07-18T00:54:56.679Z" }, + { url = "https://files.pythonhosted.org/packages/7b/f5/c372ef60593d713e8bfbb7e0c743501605f0ad00719146dc075faf11172b/pyarrow-21.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:9d9f8bcb4c3be7738add259738abdeddc363de1b80e3310e04067aa1ca596634", size = 26217959, upload-time = "2025-07-18T00:55:00.482Z" }, + { url = "https://files.pythonhosted.org/packages/94/dc/80564a3071a57c20b7c32575e4a0120e8a330ef487c319b122942d665960/pyarrow-21.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c077f48aab61738c237802836fc3844f85409a46015635198761b0d6a688f87b", size = 31243234, upload-time = "2025-07-18T00:55:03.812Z" }, + { url = "https://files.pythonhosted.org/packages/ea/cc/3b51cb2db26fe535d14f74cab4c79b191ed9a8cd4cbba45e2379b5ca2746/pyarrow-21.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:689f448066781856237eca8d1975b98cace19b8dd2ab6145bf49475478bcaa10", size = 32714370, upload-time = "2025-07-18T00:55:07.495Z" }, + { url = "https://files.pythonhosted.org/packages/24/11/a4431f36d5ad7d83b87146f515c063e4d07ef0b7240876ddb885e6b44f2e/pyarrow-21.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:479ee41399fcddc46159a551705b89c05f11e8b8cb8e968f7fec64f62d91985e", size = 41135424, upload-time = "2025-07-18T00:55:11.461Z" }, + { url = "https://files.pythonhosted.org/packages/74/dc/035d54638fc5d2971cbf1e987ccd45f1091c83bcf747281cf6cc25e72c88/pyarrow-21.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:40ebfcb54a4f11bcde86bc586cbd0272bac0d516cfa539c799c2453768477569", size = 42823810, upload-time = "2025-07-18T00:55:16.301Z" }, + { url = "https://files.pythonhosted.org/packages/2e/3b/89fced102448a9e3e0d4dded1f37fa3ce4700f02cdb8665457fcc8015f5b/pyarrow-21.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8d58d8497814274d3d20214fbb24abcad2f7e351474357d552a8d53bce70c70e", size = 43391538, upload-time = "2025-07-18T00:55:23.82Z" }, + { url = "https://files.pythonhosted.org/packages/fb/bb/ea7f1bd08978d39debd3b23611c293f64a642557e8141c80635d501e6d53/pyarrow-21.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:585e7224f21124dd57836b1530ac8f2df2afc43c861d7bf3d58a4870c42ae36c", size = 45120056, upload-time = "2025-07-18T00:55:28.231Z" }, + { url = "https://files.pythonhosted.org/packages/6e/0b/77ea0600009842b30ceebc3337639a7380cd946061b620ac1a2f3cb541e2/pyarrow-21.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:555ca6935b2cbca2c0e932bedd853e9bc523098c39636de9ad4693b5b1df86d6", size = 26220568, upload-time = "2025-07-18T00:55:32.122Z" }, + { url = "https://files.pythonhosted.org/packages/ca/d4/d4f817b21aacc30195cf6a46ba041dd1be827efa4a623cc8bf39a1c2a0c0/pyarrow-21.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3a302f0e0963db37e0a24a70c56cf91a4faa0bca51c23812279ca2e23481fccd", size = 31160305, upload-time = "2025-07-18T00:55:35.373Z" }, + { url = "https://files.pythonhosted.org/packages/a2/9c/dcd38ce6e4b4d9a19e1d36914cb8e2b1da4e6003dd075474c4cfcdfe0601/pyarrow-21.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:b6b27cf01e243871390474a211a7922bfbe3bda21e39bc9160daf0da3fe48876", size = 32684264, upload-time = "2025-07-18T00:55:39.303Z" }, + { url = "https://files.pythonhosted.org/packages/4f/74/2a2d9f8d7a59b639523454bec12dba35ae3d0a07d8ab529dc0809f74b23c/pyarrow-21.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e72a8ec6b868e258a2cd2672d91f2860ad532d590ce94cdf7d5e7ec674ccf03d", size = 41108099, upload-time = "2025-07-18T00:55:42.889Z" }, + { url = "https://files.pythonhosted.org/packages/ad/90/2660332eeb31303c13b653ea566a9918484b6e4d6b9d2d46879a33ab0622/pyarrow-21.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b7ae0bbdc8c6674259b25bef5d2a1d6af5d39d7200c819cf99e07f7dfef1c51e", size = 42829529, upload-time = "2025-07-18T00:55:47.069Z" }, + { url = "https://files.pythonhosted.org/packages/33/27/1a93a25c92717f6aa0fca06eb4700860577d016cd3ae51aad0e0488ac899/pyarrow-21.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:58c30a1729f82d201627c173d91bd431db88ea74dcaa3885855bc6203e433b82", size = 43367883, upload-time = "2025-07-18T00:55:53.069Z" }, + { url = "https://files.pythonhosted.org/packages/05/d9/4d09d919f35d599bc05c6950095e358c3e15148ead26292dfca1fb659b0c/pyarrow-21.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:072116f65604b822a7f22945a7a6e581cfa28e3454fdcc6939d4ff6090126623", size = 45133802, upload-time = "2025-07-18T00:55:57.714Z" }, + { url = "https://files.pythonhosted.org/packages/71/30/f3795b6e192c3ab881325ffe172e526499eb3780e306a15103a2764916a2/pyarrow-21.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:cf56ec8b0a5c8c9d7021d6fd754e688104f9ebebf1bf4449613c9531f5346a18", size = 26203175, upload-time = "2025-07-18T00:56:01.364Z" }, + { url = "https://files.pythonhosted.org/packages/16/ca/c7eaa8e62db8fb37ce942b1ea0c6d7abfe3786ca193957afa25e71b81b66/pyarrow-21.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e99310a4ebd4479bcd1964dff9e14af33746300cb014aa4a3781738ac63baf4a", size = 31154306, upload-time = "2025-07-18T00:56:04.42Z" }, + { url = "https://files.pythonhosted.org/packages/ce/e8/e87d9e3b2489302b3a1aea709aaca4b781c5252fcb812a17ab6275a9a484/pyarrow-21.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:d2fe8e7f3ce329a71b7ddd7498b3cfac0eeb200c2789bd840234f0dc271a8efe", size = 32680622, upload-time = "2025-07-18T00:56:07.505Z" }, + { url = "https://files.pythonhosted.org/packages/84/52/79095d73a742aa0aba370c7942b1b655f598069489ab387fe47261a849e1/pyarrow-21.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:f522e5709379d72fb3da7785aa489ff0bb87448a9dc5a75f45763a795a089ebd", size = 41104094, upload-time = "2025-07-18T00:56:10.994Z" }, + { url = "https://files.pythonhosted.org/packages/89/4b/7782438b551dbb0468892a276b8c789b8bbdb25ea5c5eb27faadd753e037/pyarrow-21.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:69cbbdf0631396e9925e048cfa5bce4e8c3d3b41562bbd70c685a8eb53a91e61", size = 42825576, upload-time = "2025-07-18T00:56:15.569Z" }, + { url = "https://files.pythonhosted.org/packages/b3/62/0f29de6e0a1e33518dec92c65be0351d32d7ca351e51ec5f4f837a9aab91/pyarrow-21.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:731c7022587006b755d0bdb27626a1a3bb004bb56b11fb30d98b6c1b4718579d", size = 43368342, upload-time = "2025-07-18T00:56:19.531Z" }, + { url = "https://files.pythonhosted.org/packages/90/c7/0fa1f3f29cf75f339768cc698c8ad4ddd2481c1742e9741459911c9ac477/pyarrow-21.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:dc56bc708f2d8ac71bd1dcb927e458c93cec10b98eb4120206a4091db7b67b99", size = 45131218, upload-time = "2025-07-18T00:56:23.347Z" }, + { url = "https://files.pythonhosted.org/packages/01/63/581f2076465e67b23bc5a37d4a2abff8362d389d29d8105832e82c9c811c/pyarrow-21.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:186aa00bca62139f75b7de8420f745f2af12941595bbbfa7ed3870ff63e25636", size = 26087551, upload-time = "2025-07-18T00:56:26.758Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ab/357d0d9648bb8241ee7348e564f2479d206ebe6e1c47ac5027c2e31ecd39/pyarrow-21.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:a7a102574faa3f421141a64c10216e078df467ab9576684d5cd696952546e2da", size = 31290064, upload-time = "2025-07-18T00:56:30.214Z" }, + { url = "https://files.pythonhosted.org/packages/3f/8a/5685d62a990e4cac2043fc76b4661bf38d06efed55cf45a334b455bd2759/pyarrow-21.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:1e005378c4a2c6db3ada3ad4c217b381f6c886f0a80d6a316fe586b90f77efd7", size = 32727837, upload-time = "2025-07-18T00:56:33.935Z" }, + { url = "https://files.pythonhosted.org/packages/fc/de/c0828ee09525c2bafefd3e736a248ebe764d07d0fd762d4f0929dbc516c9/pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:65f8e85f79031449ec8706b74504a316805217b35b6099155dd7e227eef0d4b6", size = 41014158, upload-time = "2025-07-18T00:56:37.528Z" }, + { url = "https://files.pythonhosted.org/packages/6e/26/a2865c420c50b7a3748320b614f3484bfcde8347b2639b2b903b21ce6a72/pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:3a81486adc665c7eb1a2bde0224cfca6ceaba344a82a971ef059678417880eb8", size = 42667885, upload-time = "2025-07-18T00:56:41.483Z" }, + { url = "https://files.pythonhosted.org/packages/0a/f9/4ee798dc902533159250fb4321267730bc0a107d8c6889e07c3add4fe3a5/pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:fc0d2f88b81dcf3ccf9a6ae17f89183762c8a94a5bdcfa09e05cfe413acf0503", size = 43276625, upload-time = "2025-07-18T00:56:48.002Z" }, + { url = "https://files.pythonhosted.org/packages/5a/da/e02544d6997037a4b0d22d8e5f66bc9315c3671371a8b18c79ade1cefe14/pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6299449adf89df38537837487a4f8d3bd91ec94354fdd2a7d30bc11c48ef6e79", size = 44951890, upload-time = "2025-07-18T00:56:52.568Z" }, + { url = "https://files.pythonhosted.org/packages/e5/4e/519c1bc1876625fe6b71e9a28287c43ec2f20f73c658b9ae1d485c0c206e/pyarrow-21.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:222c39e2c70113543982c6b34f3077962b44fca38c0bd9e68bb6781534425c10", size = 26371006, upload-time = "2025-07-18T00:56:56.379Z" }, +] + +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + [[package]] name = "pycparser" version = "2.22" @@ -2639,6 +3083,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928, upload-time = "2024-05-29T15:37:47.027Z" }, ] +[[package]] +name = "rich" +version = "13.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149, upload-time = "2024-11-01T16:43:57.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424, upload-time = "2024-11-01T16:43:55.817Z" }, +] + [[package]] name = "rpds-py" version = "0.23.1" @@ -2724,6 +3182,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/bb/e45f51c4e1327dea3c72b846c6de129eebacb7a6cb309af7af35d0578c80/rpds_py-0.23.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:75307599f0d25bf6937248e5ac4e3bde5ea72ae6618623b86146ccc7845ed00b", size = 233827, upload-time = "2025-02-21T15:03:56.853Z" }, ] +[[package]] +name = "rsa" +version = "4.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, +] + [[package]] name = "ruff" version = "0.11.2" @@ -2836,6 +3306,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/c8/b3f566db71461cabd4b2d5b39bcc24a7e1c119535c8361f81426be39bb47/scipy-1.15.2-cp313-cp313t-win_amd64.whl", hash = "sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db", size = 40477705, upload-time = "2025-02-17T00:34:43.619Z" }, ] +[[package]] +name = "seaborn" +version = "0.13.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "pandas" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/59/a451d7420a77ab0b98f7affa3a1d78a313d2f7281a57afb1a34bae8ab412/seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7", size = 1457696, upload-time = "2024-01-25T13:21:52.551Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" }, +] + [[package]] name = "sentry-sdk" version = "2.24.0" @@ -2911,11 +3395,11 @@ wheels = [ [[package]] name = "setuptools" -version = "77.0.3" +version = "80.9.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/81/ed/7101d53811fd359333583330ff976e5177c5e871ca8b909d1d6c30553aa3/setuptools-77.0.3.tar.gz", hash = "sha256:583b361c8da8de57403743e756609670de6fb2345920e36dc5c2d914c319c945", size = 1367236, upload-time = "2025-03-20T14:38:08.777Z" } +sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a9/07/99f2cefae815c66eb23148f15d79ec055429c38fa8986edcc712ab5f3223/setuptools-77.0.3-py3-none-any.whl", hash = "sha256:67122e78221da5cf550ddd04cf8742c8fe12094483749a792d56cd669d6cf58c", size = 1255678, upload-time = "2025-03-20T14:38:06.621Z" }, + { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, ] [[package]] @@ -3051,7 +3535,7 @@ wheels = [ [[package]] name = "torch" -version = "2.7.0" +version = "2.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -3078,26 +3562,26 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/46/c2/3fb87940fa160d956ee94d644d37b99a24b9c05a4222bf34f94c71880e28/torch-2.7.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c9afea41b11e1a1ab1b258a5c31afbd646d6319042bfe4f231b408034b51128b", size = 99158447, upload-time = "2025-04-23T14:35:10.557Z" }, - { url = "https://files.pythonhosted.org/packages/cc/2c/91d1de65573fce563f5284e69d9c56b57289625cffbbb6d533d5d56c36a5/torch-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0b9960183b6e5b71239a3e6c883d8852c304e691c0b2955f7045e8a6d05b9183", size = 865164221, upload-time = "2025-04-23T14:33:27.864Z" }, - { url = "https://files.pythonhosted.org/packages/7f/7e/1b1cc4e0e7cc2666cceb3d250eef47a205f0821c330392cf45eb08156ce5/torch-2.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:2ad79d0d8c2a20a37c5df6052ec67c2078a2c4e9a96dd3a8b55daaff6d28ea29", size = 212521189, upload-time = "2025-04-23T14:34:53.898Z" }, - { url = "https://files.pythonhosted.org/packages/dc/0b/b2b83f30b8e84a51bf4f96aa3f5f65fdf7c31c591cc519310942339977e2/torch-2.7.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:34e0168ed6de99121612d72224e59b2a58a83dae64999990eada7260c5dd582d", size = 68559462, upload-time = "2025-04-23T14:35:39.889Z" }, - { url = "https://files.pythonhosted.org/packages/40/da/7378d16cc636697f2a94f791cb496939b60fb8580ddbbef22367db2c2274/torch-2.7.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2b7813e904757b125faf1a9a3154e1d50381d539ced34da1992f52440567c156", size = 99159397, upload-time = "2025-04-23T14:35:35.304Z" }, - { url = "https://files.pythonhosted.org/packages/0e/6b/87fcddd34df9f53880fa1f0c23af7b6b96c935856473faf3914323588c40/torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fd5cfbb4c3bbadd57ad1b27d56a28008f8d8753733411a140fcfb84d7f933a25", size = 865183681, upload-time = "2025-04-23T14:34:21.802Z" }, - { url = "https://files.pythonhosted.org/packages/13/85/6c1092d4b06c3db1ed23d4106488750917156af0b24ab0a2d9951830b0e9/torch-2.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:58df8d5c2eeb81305760282b5069ea4442791a6bbf0c74d9069b7b3304ff8a37", size = 212520100, upload-time = "2025-04-23T14:35:27.473Z" }, - { url = "https://files.pythonhosted.org/packages/aa/3f/85b56f7e2abcfa558c5fbf7b11eb02d78a4a63e6aeee2bbae3bb552abea5/torch-2.7.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:0a8d43caa342b9986101ec5feb5bbf1d86570b5caa01e9cb426378311258fdde", size = 68569377, upload-time = "2025-04-23T14:35:20.361Z" }, - { url = "https://files.pythonhosted.org/packages/aa/5e/ac759f4c0ab7c01feffa777bd68b43d2ac61560a9770eeac074b450f81d4/torch-2.7.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:36a6368c7ace41ad1c0f69f18056020b6a5ca47bedaca9a2f3b578f5a104c26c", size = 99013250, upload-time = "2025-04-23T14:35:15.589Z" }, - { url = "https://files.pythonhosted.org/packages/9c/58/2d245b6f1ef61cf11dfc4aceeaacbb40fea706ccebac3f863890c720ab73/torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:15aab3e31c16feb12ae0a88dba3434a458874636f360c567caa6a91f6bfba481", size = 865042157, upload-time = "2025-04-23T14:32:56.011Z" }, - { url = "https://files.pythonhosted.org/packages/44/80/b353c024e6b624cd9ce1d66dcb9d24e0294680f95b369f19280e241a0159/torch-2.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f56d4b2510934e072bab3ab8987e00e60e1262fb238176168f5e0c43a1320c6d", size = 212482262, upload-time = "2025-04-23T14:35:03.527Z" }, - { url = "https://files.pythonhosted.org/packages/ee/8d/b2939e5254be932db1a34b2bd099070c509e8887e0c5a90c498a917e4032/torch-2.7.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:30b7688a87239a7de83f269333651d8e582afffce6f591fff08c046f7787296e", size = 68574294, upload-time = "2025-04-23T14:34:47.098Z" }, - { url = "https://files.pythonhosted.org/packages/14/24/720ea9a66c29151b315ea6ba6f404650834af57a26b2a04af23ec246b2d5/torch-2.7.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:868ccdc11798535b5727509480cd1d86d74220cfdc42842c4617338c1109a205", size = 99015553, upload-time = "2025-04-23T14:34:41.075Z" }, - { url = "https://files.pythonhosted.org/packages/4b/27/285a8cf12bd7cd71f9f211a968516b07dcffed3ef0be585c6e823675ab91/torch-2.7.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b52347118116cf3dff2ab5a3c3dd97c719eb924ac658ca2a7335652076df708", size = 865046389, upload-time = "2025-04-23T14:32:01.16Z" }, - { url = "https://files.pythonhosted.org/packages/74/c8/2ab2b6eadc45554af8768ae99668c5a8a8552e2012c7238ded7e9e4395e1/torch-2.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:434cf3b378340efc87c758f250e884f34460624c0523fe5c9b518d205c91dd1b", size = 212490304, upload-time = "2025-04-23T14:33:57.108Z" }, - { url = "https://files.pythonhosted.org/packages/28/fd/74ba6fde80e2b9eef4237fe668ffae302c76f0e4221759949a632ca13afa/torch-2.7.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:edad98dddd82220465b106506bb91ee5ce32bd075cddbcf2b443dfaa2cbd83bf", size = 68856166, upload-time = "2025-04-23T14:34:04.012Z" }, - { url = "https://files.pythonhosted.org/packages/cb/b4/8df3f9fe6bdf59e56a0e538592c308d18638eb5f5dc4b08d02abb173c9f0/torch-2.7.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a885fc25afefb6e6eb18a7d1e8bfa01cc153e92271d980a49243b250d5ab6d9", size = 99091348, upload-time = "2025-04-23T14:33:48.975Z" }, - { url = "https://files.pythonhosted.org/packages/9d/f5/0bd30e9da04c3036614aa1b935a9f7e505a9e4f1f731b15e165faf8a4c74/torch-2.7.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:176300ff5bc11a5f5b0784e40bde9e10a35c4ae9609beed96b4aeb46a27f5fae", size = 865104023, upload-time = "2025-04-23T14:30:40.537Z" }, - { url = "https://files.pythonhosted.org/packages/d1/b7/2235d0c3012c596df1c8d39a3f4afc1ee1b6e318d469eda4c8bb68566448/torch-2.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d0ca446a93f474985d81dc866fcc8dccefb9460a29a456f79d99c29a78a66993", size = 212750916, upload-time = "2025-04-23T14:32:22.91Z" }, - { url = "https://files.pythonhosted.org/packages/90/48/7e6477cf40d48cc0a61fa0d41ee9582b9a316b12772fcac17bc1a40178e7/torch-2.7.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:27f5007bdf45f7bb7af7f11d1828d5c2487e030690afb3d89a651fd7036a390e", size = 68575074, upload-time = "2025-04-23T14:32:38.136Z" }, + { url = "https://files.pythonhosted.org/packages/63/28/110f7274254f1b8476c561dada127173f994afa2b1ffc044efb773c15650/torch-2.8.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:0be92c08b44009d4131d1ff7a8060d10bafdb7ddcb7359ef8d8c5169007ea905", size = 102052793, upload-time = "2025-08-06T14:53:15.852Z" }, + { url = "https://files.pythonhosted.org/packages/70/1c/58da560016f81c339ae14ab16c98153d51c941544ae568da3cb5b1ceb572/torch-2.8.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:89aa9ee820bb39d4d72b794345cccef106b574508dd17dbec457949678c76011", size = 888025420, upload-time = "2025-08-06T14:54:18.014Z" }, + { url = "https://files.pythonhosted.org/packages/70/87/f69752d0dd4ba8218c390f0438130c166fa264a33b7025adb5014b92192c/torch-2.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e8e5bf982e87e2b59d932769938b698858c64cc53753894be25629bdf5cf2f46", size = 241363614, upload-time = "2025-08-06T14:53:31.496Z" }, + { url = "https://files.pythonhosted.org/packages/ef/d6/e6d4c57e61c2b2175d3aafbfb779926a2cfd7c32eeda7c543925dceec923/torch-2.8.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:a3f16a58a9a800f589b26d47ee15aca3acf065546137fc2af039876135f4c760", size = 73611154, upload-time = "2025-08-06T14:53:10.919Z" }, + { url = "https://files.pythonhosted.org/packages/8f/c4/3e7a3887eba14e815e614db70b3b529112d1513d9dae6f4d43e373360b7f/torch-2.8.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:220a06fd7af8b653c35d359dfe1aaf32f65aa85befa342629f716acb134b9710", size = 102073391, upload-time = "2025-08-06T14:53:20.937Z" }, + { url = "https://files.pythonhosted.org/packages/5a/63/4fdc45a0304536e75a5e1b1bbfb1b56dd0e2743c48ee83ca729f7ce44162/torch-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c12fa219f51a933d5f80eeb3a7a5d0cbe9168c0a14bbb4055f1979431660879b", size = 888063640, upload-time = "2025-08-06T14:55:05.325Z" }, + { url = "https://files.pythonhosted.org/packages/84/57/2f64161769610cf6b1c5ed782bd8a780e18a3c9d48931319f2887fa9d0b1/torch-2.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8c7ef765e27551b2fbfc0f41bcf270e1292d9bf79f8e0724848b1682be6e80aa", size = 241366752, upload-time = "2025-08-06T14:53:38.692Z" }, + { url = "https://files.pythonhosted.org/packages/a4/5e/05a5c46085d9b97e928f3f037081d3d2b87fb4b4195030fc099aaec5effc/torch-2.8.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:5ae0524688fb6707c57a530c2325e13bb0090b745ba7b4a2cd6a3ce262572916", size = 73621174, upload-time = "2025-08-06T14:53:25.44Z" }, + { url = "https://files.pythonhosted.org/packages/49/0c/2fd4df0d83a495bb5e54dca4474c4ec5f9c62db185421563deeb5dabf609/torch-2.8.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e2fab4153768d433f8ed9279c8133a114a034a61e77a3a104dcdf54388838705", size = 101906089, upload-time = "2025-08-06T14:53:52.631Z" }, + { url = "https://files.pythonhosted.org/packages/99/a8/6acf48d48838fb8fe480597d98a0668c2beb02ee4755cc136de92a0a956f/torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2aca0939fb7e4d842561febbd4ffda67a8e958ff725c1c27e244e85e982173c", size = 887913624, upload-time = "2025-08-06T14:56:44.33Z" }, + { url = "https://files.pythonhosted.org/packages/af/8a/5c87f08e3abd825c7dfecef5a0f1d9aa5df5dd0e3fd1fa2f490a8e512402/torch-2.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f4ac52f0130275d7517b03a33d2493bab3693c83dcfadf4f81688ea82147d2e", size = 241326087, upload-time = "2025-08-06T14:53:46.503Z" }, + { url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478, upload-time = "2025-08-06T14:53:57.144Z" }, + { url = "https://files.pythonhosted.org/packages/10/4e/469ced5a0603245d6a19a556e9053300033f9c5baccf43a3d25ba73e189e/torch-2.8.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b2f96814e0345f5a5aed9bf9734efa913678ed19caf6dc2cddb7930672d6128", size = 101936856, upload-time = "2025-08-06T14:54:01.526Z" }, + { url = "https://files.pythonhosted.org/packages/16/82/3948e54c01b2109238357c6f86242e6ecbf0c63a1af46906772902f82057/torch-2.8.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:65616ca8ec6f43245e1f5f296603e33923f4c30f93d65e103d9e50c25b35150b", size = 887922844, upload-time = "2025-08-06T14:55:50.78Z" }, + { url = "https://files.pythonhosted.org/packages/e3/54/941ea0a860f2717d86a811adf0c2cd01b3983bdd460d0803053c4e0b8649/torch-2.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:659df54119ae03e83a800addc125856effda88b016dfc54d9f65215c3975be16", size = 241330968, upload-time = "2025-08-06T14:54:45.293Z" }, + { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128, upload-time = "2025-08-06T14:54:34.769Z" }, + { url = "https://files.pythonhosted.org/packages/15/0e/8a800e093b7f7430dbaefa80075aee9158ec22e4c4fc3c1a66e4fb96cb4f/torch-2.8.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:83c13411a26fac3d101fe8035a6b0476ae606deb8688e904e796a3534c197def", size = 102020139, upload-time = "2025-08-06T14:54:39.047Z" }, + { url = "https://files.pythonhosted.org/packages/4a/15/5e488ca0bc6162c86a33b58642bc577c84ded17c7b72d97e49b5833e2d73/torch-2.8.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8f0a9d617a66509ded240add3754e462430a6c1fc5589f86c17b433dd808f97a", size = 887990692, upload-time = "2025-08-06T14:56:18.286Z" }, + { url = "https://files.pythonhosted.org/packages/b4/a8/6a04e4b54472fc5dba7ca2341ab219e529f3c07b6941059fbf18dccac31f/torch-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a7242b86f42be98ac674b88a4988643b9bc6145437ec8f048fea23f72feb5eca", size = 241603453, upload-time = "2025-08-06T14:55:22.945Z" }, + { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" }, ] [[package]] @@ -3136,7 +3620,7 @@ wheels = [ [[package]] name = "torchvision" -version = "0.22.0" +version = "0.23.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, @@ -3144,26 +3628,26 @@ dependencies = [ { name = "torch" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/03/a514766f068b088180f273913e539d08e830be3ae46ef8577ea62584a27c/torchvision-0.22.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:72256f1d7ff510b16c9fb4dd488584d0693f40c792f286a9620674438a81ccca", size = 1947829, upload-time = "2025-04-23T14:42:04.652Z" }, - { url = "https://files.pythonhosted.org/packages/a3/e5/ec4b52041cd8c440521b75864376605756bd2d112d6351ea6a1ab25008c1/torchvision-0.22.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:810ea4af3bc63cf39e834f91f4218ff5999271caaffe2456247df905002bd6c0", size = 2512604, upload-time = "2025-04-23T14:41:56.515Z" }, - { url = "https://files.pythonhosted.org/packages/e7/9e/e898a377e674da47e95227f3d7be2c49550ce381eebd8c7831c1f8bb7d39/torchvision-0.22.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6fbca169c690fa2b9b8c39c0ad76d5b8992296d0d03df01e11df97ce12b4e0ac", size = 7446399, upload-time = "2025-04-23T14:41:49.793Z" }, - { url = "https://files.pythonhosted.org/packages/c7/ec/2cdb90c6d9d61410b3df9ca67c210b60bf9b07aac31f800380b20b90386c/torchvision-0.22.0-cp310-cp310-win_amd64.whl", hash = "sha256:8c869df2e8e00f7b1d80a34439e6d4609b50fe3141032f50b38341ec2b59404e", size = 1716700, upload-time = "2025-04-23T14:42:03.562Z" }, - { url = "https://files.pythonhosted.org/packages/b1/43/28bc858b022f6337326d75f4027d2073aad5432328f01ee1236d847f1b82/torchvision-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:191ea28321fc262d8aa1a7fe79c41ff2848864bf382f9f6ea45c41dde8313792", size = 1947828, upload-time = "2025-04-23T14:42:00.439Z" }, - { url = "https://files.pythonhosted.org/packages/7e/71/ce9a303b94e64fe25d534593522ffc76848c4e64c11e4cbe9f6b8d537210/torchvision-0.22.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6c5620e10ffe388eb6f4744962106ed7cf1508d26e6fdfa0c10522d3249aea24", size = 2514016, upload-time = "2025-04-23T14:41:48.566Z" }, - { url = "https://files.pythonhosted.org/packages/09/42/6908bff012a1dcc4fc515e52339652d7f488e208986542765c02ea775c2f/torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ce292701c77c64dd3935e3e31c722c3b8b176a75f76dc09b804342efc1db5494", size = 7447546, upload-time = "2025-04-23T14:41:47.297Z" }, - { url = "https://files.pythonhosted.org/packages/e4/cf/8f9305cc0ea26badbbb3558ecae54c04a245429f03168f7fad502f8a5b25/torchvision-0.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:e4017b5685dbab4250df58084f07d95e677b2f3ed6c2e507a1afb8eb23b580ca", size = 1716472, upload-time = "2025-04-23T14:42:01.999Z" }, - { url = "https://files.pythonhosted.org/packages/cb/ea/887d1d61cf4431a46280972de665f350af1898ce5006cd046326e5d0a2f2/torchvision-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31c3165418fe21c3d81fe3459e51077c2f948801b8933ed18169f54652796a0f", size = 1947826, upload-time = "2025-04-23T14:41:59.188Z" }, - { url = "https://files.pythonhosted.org/packages/72/ef/21f8b6122e13ae045b8e49658029c695fd774cd21083b3fa5c3f9c5d3e35/torchvision-0.22.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8f116bc82e0c076e70ba7776e611ed392b9666aa443662e687808b08993d26af", size = 2514571, upload-time = "2025-04-23T14:41:53.458Z" }, - { url = "https://files.pythonhosted.org/packages/7c/48/5f7617f6c60d135f86277c53f9d5682dfa4e66f4697f505f1530e8b69fb1/torchvision-0.22.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ce4dc334ebd508de2c534817c9388e928bc2500cf981906ae8d6e2ca3bf4727a", size = 7446522, upload-time = "2025-04-23T14:41:34.9Z" }, - { url = "https://files.pythonhosted.org/packages/99/94/a015e93955f5d3a68689cc7c385a3cfcd2d62b84655d18b61f32fb04eb67/torchvision-0.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:24b8c9255c209ca419cc7174906da2791c8b557b75c23496663ec7d73b55bebf", size = 1716664, upload-time = "2025-04-23T14:41:58.019Z" }, - { url = "https://files.pythonhosted.org/packages/e1/2a/9b34685599dcb341d12fc2730055155623db7a619d2415a8d31f17050952/torchvision-0.22.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ece17995857dd328485c9c027c0b20ffc52db232e30c84ff6c95ab77201112c5", size = 1947823, upload-time = "2025-04-23T14:41:39.956Z" }, - { url = "https://files.pythonhosted.org/packages/77/77/88f64879483d66daf84f1d1c4d5c31ebb08e640411139042a258d5f7dbfe/torchvision-0.22.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:471c6dd75bb984c6ebe4f60322894a290bf3d4b195e769d80754f3689cd7f238", size = 2471592, upload-time = "2025-04-23T14:41:54.991Z" }, - { url = "https://files.pythonhosted.org/packages/f7/82/2f813eaae7c1fae1f9d9e7829578f5a91f39ef48d6c1c588a8900533dd3d/torchvision-0.22.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2b839ac0610a38f56bef115ee5b9eaca5f9c2da3c3569a68cc62dbcc179c157f", size = 7446333, upload-time = "2025-04-23T14:41:36.603Z" }, - { url = "https://files.pythonhosted.org/packages/58/19/ca7a4f8907a56351dfe6ae0a708f4e6b3569b5c61d282e3e7f61cf42a4ce/torchvision-0.22.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ada1c08b2f761443cd65b7c7b4aec9e2fc28f75b0d4e1b1ebc9d3953ebccc4d", size = 1716693, upload-time = "2025-04-23T14:41:41.031Z" }, - { url = "https://files.pythonhosted.org/packages/6f/a7/f43e9c8d13118b4ffbaebea664c9338ab20fa115a908125afd2238ff16e7/torchvision-0.22.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:cdc96daa4658b47ce9384154c86ed1e70cba9d972a19f5de6e33f8f94a626790", size = 2137621, upload-time = "2025-04-23T14:41:51.427Z" }, - { url = "https://files.pythonhosted.org/packages/6a/9a/2b59f5758ba7e3f23bc84e16947493bbce97392ec6d18efba7bdf0a3b10e/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:753d3c84eeadd5979a33b3b73a25ecd0aa4af44d6b45ed2c70d44f5e0ac68312", size = 2476555, upload-time = "2025-04-23T14:41:38.357Z" }, - { url = "https://files.pythonhosted.org/packages/7d/40/a7bc2ab9b1e56d10a7fd9ae83191bb425fa308caa23d148f1c568006e02c/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b30e3ed29e4a61f7499bca50f57d8ebd23dfc52b14608efa17a534a55ee59a03", size = 7617924, upload-time = "2025-04-23T14:41:42.709Z" }, - { url = "https://files.pythonhosted.org/packages/c1/7b/30d423bdb2546250d719d7821aaf9058cc093d165565b245b159c788a9dd/torchvision-0.22.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e5d680162694fac4c8a374954e261ddfb4eb0ce103287b0f693e4e9c579ef957", size = 1638621, upload-time = "2025-04-23T14:41:46.06Z" }, + { url = "https://files.pythonhosted.org/packages/4d/49/5ad5c3ff4920be0adee9eb4339b4fb3b023a0fc55b9ed8dbc73df92946b8/torchvision-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7266871daca00ad46d1c073e55d972179d12a58fa5c9adec9a3db9bbed71284a", size = 1856885, upload-time = "2025-08-06T14:57:55.024Z" }, + { url = "https://files.pythonhosted.org/packages/25/44/ddd56d1637bac42a8c5da2c8c440d8a28c431f996dd9790f32dd9a96ca6e/torchvision-0.23.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:31c583ba27426a3a04eca8c05450524105c1564db41be6632f7536ef405a6de2", size = 2394251, upload-time = "2025-08-06T14:58:01.725Z" }, + { url = "https://files.pythonhosted.org/packages/93/f3/3cdf55bbf0f737304d997561c34ab0176222e0496b6743b0feab5995182c/torchvision-0.23.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:3932bf67256f2d095ce90a9f826f6033694c818856f4bb26794cf2ce64253e53", size = 8627497, upload-time = "2025-08-06T14:58:09.317Z" }, + { url = "https://files.pythonhosted.org/packages/97/90/02afe57c3ef4284c5cf89d3b7ae203829b3a981f72b93a7dd2a3fd2c83c1/torchvision-0.23.0-cp310-cp310-win_amd64.whl", hash = "sha256:83ee5bf827d61a8af14620c0a61d8608558638ac9c3bac8adb7b27138e2147d1", size = 1600760, upload-time = "2025-08-06T14:57:56.783Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d7/15d3d7bd8d0239211b21673d1bac7bc345a4ad904a8e25bb3fd8a9cf1fbc/torchvision-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:49aa20e21f0c2bd458c71d7b449776cbd5f16693dd5807195a820612b8a229b7", size = 1856884, upload-time = "2025-08-06T14:58:00.237Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/7b44fe766b7d11e064c539d92a172fa9689a53b69029e24f2f1f51e7dc56/torchvision-0.23.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:01dc33ee24c79148aee7cdbcf34ae8a3c9da1674a591e781577b716d233b1fa6", size = 2395543, upload-time = "2025-08-06T14:58:04.373Z" }, + { url = "https://files.pythonhosted.org/packages/79/9c/fcb09aff941c8147d9e6aa6c8f67412a05622b0c750bcf796be4c85a58d4/torchvision-0.23.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:35c27941831b653f5101edfe62c03d196c13f32139310519e8228f35eae0e96a", size = 8628388, upload-time = "2025-08-06T14:58:07.802Z" }, + { url = "https://files.pythonhosted.org/packages/93/40/3415d890eb357b25a8e0a215d32365a88ecc75a283f75c4e919024b22d97/torchvision-0.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:09bfde260e7963a15b80c9e442faa9f021c7e7f877ac0a36ca6561b367185013", size = 1600741, upload-time = "2025-08-06T14:57:59.158Z" }, + { url = "https://files.pythonhosted.org/packages/df/1d/0ea0b34bde92a86d42620f29baa6dcbb5c2fc85990316df5cb8f7abb8ea2/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0e2c04a91403e8dd3af9756c6a024a1d9c0ed9c0d592a8314ded8f4fe30d440", size = 1856885, upload-time = "2025-08-06T14:58:06.503Z" }, + { url = "https://files.pythonhosted.org/packages/e2/00/2f6454decc0cd67158c7890364e446aad4b91797087a57a78e72e1a8f8bc/torchvision-0.23.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6dd7c4d329a0e03157803031bc856220c6155ef08c26d4f5bbac938acecf0948", size = 2396614, upload-time = "2025-08-06T14:58:03.116Z" }, + { url = "https://files.pythonhosted.org/packages/e4/b5/3e580dcbc16f39a324f3dd71b90edbf02a42548ad44d2b4893cc92b1194b/torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4e7d31c43bc7cbecbb1a5652ac0106b436aa66e26437585fc2c4b2cf04d6014c", size = 8627108, upload-time = "2025-08-06T14:58:12.956Z" }, + { url = "https://files.pythonhosted.org/packages/82/c1/c2fe6d61e110a8d0de2f94276899a2324a8f1e6aee559eb6b4629ab27466/torchvision-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:a2e45272abe7b8bf0d06c405e78521b5757be1bd0ed7e5cd78120f7fdd4cbf35", size = 1600723, upload-time = "2025-08-06T14:57:57.986Z" }, + { url = "https://files.pythonhosted.org/packages/91/37/45a5b9407a7900f71d61b2b2f62db4b7c632debca397f205fdcacb502780/torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1c37e325e09a184b730c3ef51424f383ec5745378dc0eca244520aca29722600", size = 1856886, upload-time = "2025-08-06T14:58:05.491Z" }, + { url = "https://files.pythonhosted.org/packages/ac/da/a06c60fc84fc849377cf035d3b3e9a1c896d52dbad493b963c0f1cdd74d0/torchvision-0.23.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2f7fd6c15f3697e80627b77934f77705f3bc0e98278b989b2655de01f6903e1d", size = 2353112, upload-time = "2025-08-06T14:58:26.265Z" }, + { url = "https://files.pythonhosted.org/packages/a0/27/5ce65ba5c9d3b7d2ccdd79892ab86a2f87ac2ca6638f04bb0280321f1a9c/torchvision-0.23.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a76fafe113b2977be3a21bf78f115438c1f88631d7a87203acb3dd6ae55889e6", size = 8627658, upload-time = "2025-08-06T14:58:15.999Z" }, + { url = "https://files.pythonhosted.org/packages/1f/e4/028a27b60aa578a2fa99d9d7334ff1871bb17008693ea055a2fdee96da0d/torchvision-0.23.0-cp313-cp313-win_amd64.whl", hash = "sha256:07d069cb29691ff566e3b7f11f20d91044f079e1dbdc9d72e0655899a9b06938", size = 1600749, upload-time = "2025-08-06T14:58:10.719Z" }, + { url = "https://files.pythonhosted.org/packages/05/35/72f91ad9ac7c19a849dedf083d347dc1123f0adeb401f53974f84f1d04c8/torchvision-0.23.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2df618e1143805a7673aaf82cb5720dd9112d4e771983156aaf2ffff692eebf9", size = 2047192, upload-time = "2025-08-06T14:58:11.813Z" }, + { url = "https://files.pythonhosted.org/packages/1d/9d/406cea60a9eb9882145bcd62a184ee61e823e8e1d550cdc3c3ea866a9445/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a3299d2b1d5a7aed2d3b6ffb69c672ca8830671967eb1cee1497bacd82fe47b", size = 2359295, upload-time = "2025-08-06T14:58:17.469Z" }, + { url = "https://files.pythonhosted.org/packages/2b/f4/34662f71a70fa1e59de99772142f22257ca750de05ccb400b8d2e3809c1d/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:76bc4c0b63d5114aa81281390f8472a12a6a35ce9906e67ea6044e5af4cab60c", size = 8800474, upload-time = "2025-08-06T14:58:22.53Z" }, + { url = "https://files.pythonhosted.org/packages/6e/f5/b5a2d841a8d228b5dbda6d524704408e19e7ca6b7bb0f24490e081da1fa1/torchvision-0.23.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b9e2dabf0da9c8aa9ea241afb63a8f3e98489e706b22ac3f30416a1be377153b", size = 1527667, upload-time = "2025-08-06T14:58:14.446Z" }, ] [[package]] @@ -3207,17 +3691,17 @@ wheels = [ [[package]] name = "triton" -version = "3.3.0" +version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "setuptools", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/76/04/d54d3a6d077c646624dc9461b0059e23fd5d30e0dbe67471e3654aec81f9/triton-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fad99beafc860501d7fcc1fb7045d9496cbe2c882b1674640304949165a916e7", size = 156441993, upload-time = "2025-04-09T20:27:25.107Z" }, - { url = "https://files.pythonhosted.org/packages/3c/c5/4874a81131cc9e934d88377fbc9d24319ae1fb540f3333b4e9c696ebc607/triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3161a2bf073d6b22c4e2f33f951f3e5e3001462b2570e6df9cd57565bdec2984", size = 156528461, upload-time = "2025-04-09T20:27:32.599Z" }, - { url = "https://files.pythonhosted.org/packages/11/53/ce18470914ab6cfbec9384ee565d23c4d1c55f0548160b1c7b33000b11fd/triton-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b68c778f6c4218403a6bd01be7484f6dc9e20fe2083d22dd8aef33e3b87a10a3", size = 156504509, upload-time = "2025-04-09T20:27:40.413Z" }, - { url = "https://files.pythonhosted.org/packages/7d/74/4bf2702b65e93accaa20397b74da46fb7a0356452c1bb94dbabaf0582930/triton-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47bc87ad66fa4ef17968299acacecaab71ce40a238890acc6ad197c3abe2b8f1", size = 156516468, upload-time = "2025-04-09T20:27:48.196Z" }, - { url = "https://files.pythonhosted.org/packages/0a/93/f28a696fa750b9b608baa236f8225dd3290e5aff27433b06143adc025961/triton-3.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce4700fc14032af1e049005ae94ba908e71cd6c2df682239aed08e49bc71b742", size = 156580729, upload-time = "2025-04-09T20:27:55.424Z" }, + { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, + { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138, upload-time = "2025-07-30T19:58:29.908Z" }, + { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" }, + { url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223, upload-time = "2025-07-30T19:58:44.017Z" }, + { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780, upload-time = "2025-07-30T19:58:51.171Z" }, ] [[package]]