From 5896940b74ff69e46e76528e95bb2f65411be060 Mon Sep 17 00:00:00 2001
From: Bruno Alvisio
Date: Fri, 22 Aug 2025 10:38:27 +0000
Subject: [PATCH] Add lora-checkpointing option to evo2 predict

Signed-off-by: Bruno Alvisio
---
 .../src/bionemo/evo2/run/predict.py        | 150 ++++++++++--------
 .../src/bionemo/evo2/run/train.py          |  23 ++-
 .../tests/bionemo/evo2/run/common.py       |  11 ++
 .../tests/bionemo/evo2/run/test_predict.py |  85 ++++++++++
 4 files changed, 205 insertions(+), 64 deletions(-)

diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
index b2105e541b..33d9686bd9 100644
--- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
+++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
@@ -41,6 +41,7 @@
 
 # Add import for Mamba models
 from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel
+from bionemo.evo2.models.peft import Evo2LoRA
 from bionemo.llm.data import collate
 from bionemo.llm.lightning import LightningPassthroughPredictionMixin
 from bionemo.llm.utils.callbacks import PredictionWriter
@@ -159,6 +160,13 @@ def parse_args():
         "know a model was trained with a specific interpolation factor for ROPE, provide it here, it can make a big "
         "difference in accuracy.",
     )
+    ap.add_argument(
+        "--lora-checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the lora states to restore from.",
+    )
     return ap.parse_args()
 
 
@@ -261,6 +269,11 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor | dict[str
 class HyenaPredictor(BasePredictor, HyenaModel):
     """A predictor for the Hyena model.
 
     This adds in the predict step and the passthrough method."""
 
+    def configure_model(self, *args, **kwargs) -> None:
+        """Configure the model."""
+        super().configure_model(*args, **kwargs)
+        self.trainer.strategy._init_model_parallel = True
+
 
 class MambaPredictor(BasePredictor, MambaModel):
     """Mamba model for prediction with additional metrics."""
@@ -397,6 +410,7 @@ def predict(
     num_layers: int | None = None,
     seq_len_interpolation_factor: int | None = None,
     files_per_subdir: int | None = None,
+    lora_checkpoint_path: Path | None = None,
 ):
     """Inference workflow for Evo2.
 
@@ -424,6 +438,77 @@ def predict(
     )
     global_batch_size = micro_batch_size * world_size // model_parallel_size
 
+    callbacks = [
+        PredictionWriter(
+            output_dir=output_dir,
+            write_interval=write_interval,
+            batch_dim_key_defaults={"token_logits": 0},
+            seq_dim_key_defaults={"token_logits": 1},
+            files_per_subdir=files_per_subdir,
+            save_all_model_parallel_ranks=False,  # only write one copy of predictions.
+        )
+    ]
+
+    # The following two config options are really only used for testing, but may also be useful for getting output from
+    # specific layers of the model.
+    config_modifiers_init = {}
+    if hybrid_override_pattern is not None:
+        config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
+    if num_layers is not None:
+        config_modifiers_init["num_layers"] = num_layers
+
+    tokenizer = get_nmt_tokenizer("byte-level")
+
+    # Select model config based on model type
+    if model_type == "hyena":
+        if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
+            # TODO remove this override once we add this as a default upstream in NeMo.
+            # if you see this, just check the pointed to model option for the 1m model in nemo and see if it already
+            # has this option set.
+            config_modifiers_init["seq_len_interpolation_factor"] = 128
+
+        if model_size not in HYENA_MODEL_OPTIONS:
+            raise ValueError(f"Invalid model size for Hyena: {model_size}")
+        config = HYENA_MODEL_OPTIONS[model_size](
+            forward_step_fn=hyena_predict_forward_step,
+            data_step_fn=hyena_predict_data_step,  # , attention_backend=AttnBackend.fused,
+            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
+            # Only use vortex style FP8 in the model config if using FP8 and not full FP8. This will only apply FP8 to
+            # the projection layer of the hyena mixer.
+            vortex_style_fp8=fp8 and not full_fp8,
+            **config_modifiers_init,
+        )
+
+        if lora_checkpoint_path:
+            model_transform = Evo2LoRA(peft_ckpt_path=str(lora_checkpoint_path))
+            callbacks.append(model_transform)
+        else:
+            model_transform = None
+
+        model = HyenaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+            model_transform=model_transform,
+        )
+    else:  # mamba
+        if model_size not in MAMBA_MODEL_OPTIONS:
+            raise ValueError(f"Invalid model size for Mamba: {model_size}")
+        config = MAMBA_MODEL_OPTIONS[model_size](
+            forward_step_fn=hyena_predict_forward_step,  # Can reuse the same forward steps
+            data_step_fn=hyena_predict_data_step,
+            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
+            **config_modifiers_init,
+        )
+
+        model = MambaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+        )
+
     # Create PTL trainer.
     trainer = nl.Trainer(
         accelerator="gpu",
@@ -451,16 +536,7 @@ def predict(
         log_every_n_steps=1,
         limit_val_batches=10,
         num_sanity_val_steps=0,
-        callbacks=[
-            PredictionWriter(
-                output_dir=output_dir,
-                write_interval=write_interval,
-                batch_dim_key_defaults={"token_logits": 0},
-                seq_dim_key_defaults={"token_logits": 1},
-                files_per_subdir=files_per_subdir,
-                save_all_model_parallel_ranks=False,  # only write one copy of predictions.
-            )
-        ],
+        callbacks=callbacks,
         plugins=nl.MegatronMixedPrecision(
             precision="bf16-mixed",
             params_dtype=torch.bfloat16,
@@ -471,42 +547,6 @@ def predict(
             fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
         ),
     )
-    # The following two config options are really only used for testing, but may also be useful for getting output from
-    # specific layers of the model.
-    config_modifiers_init = {}
-    if hybrid_override_pattern is not None:
-        config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
-    if num_layers is not None:
-        config_modifiers_init["num_layers"] = num_layers
-    # Select model config based on model type
-    if model_type == "hyena":
-        if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
-            # TODO remove this override once we add this as a default upstream in NeMo.
-            # if you see this, just check the pointed to model option for the 1m model in nemo and see if it already
-            # has this option set.
-            config_modifiers_init["seq_len_interpolation_factor"] = 128
-
-        if model_size not in HYENA_MODEL_OPTIONS:
-            raise ValueError(f"Invalid model size for Hyena: {model_size}")
-        config = HYENA_MODEL_OPTIONS[model_size](
-            forward_step_fn=hyena_predict_forward_step,
-            data_step_fn=hyena_predict_data_step,  # , attention_backend=AttnBackend.fused,
-            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
-            # Only use vortex style FP8 in the model config if using FP8 and not full FP8. This will only apply FP8 to
-            # the projection layer of the hyena mixer.
-            vortex_style_fp8=fp8 and not full_fp8,
-            **config_modifiers_init,
-        )
-    else:  # mamba
-        if model_size not in MAMBA_MODEL_OPTIONS:
-            raise ValueError(f"Invalid model size for Mamba: {model_size}")
-        config = MAMBA_MODEL_OPTIONS[model_size](
-            forward_step_fn=hyena_predict_forward_step,  # Can reuse the same forward steps
-            data_step_fn=hyena_predict_data_step,
-            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
-            **config_modifiers_init,
-        )
-
     trainer.strategy._setup_optimizers = False
 
     nemo_logger = NeMoLogger(log_dir=work_dir)
@@ -518,23 +558,6 @@ def predict(
         resume_from_path=str(ckpt_dir),
         restore_config=None,
     )
-    tokenizer = get_nmt_tokenizer("byte-level")
-
-    # Create appropriate model based on type
-    if model_type == "hyena":
-        model = HyenaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
-    else:  # mamba
-        model = MambaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
     resume.setup(trainer, model)  # this pulls weights from the starting checkpoint.
 
@@ -573,6 +596,7 @@ def main():
         num_layers=args.num_layers,
         files_per_subdir=args.files_per_subdir,
         write_interval=args.write_interval,
+        lora_checkpoint_path=args.lora_checkpoint_path,
     )
 
 
diff --git a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py
index 87ae5bddf2..957eef91e0 100644
--- a/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py
+++ b/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py
@@ -636,6 +636,18 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         default=False,
         help="Enable CUDA memory cleanup before validation to prevent initialization errors.",
     )
+    parser.add_argument(
+        "--lora-alpha",
+        type=int,
+        default=None,
+        help="Alpha parameter for LoRA fine-tuning.",
+    )
+    parser.add_argument(
+        "--lora-dim",
+        type=int,
+        default=None,
+        help="Dim parameter for LoRA fine-tuning.",
+    )
 
     recompute_group = parser.add_mutually_exclusive_group(required=False)
     recompute_group.add_argument("--no-activation-checkpointing", action="store_true", default=False)
@@ -801,7 +813,16 @@ def train(args: argparse.Namespace) -> nl.Trainer:
         # Lora adaptors configuration
         lora_transform = None
         if args.lora_finetune:
-            lora_transform = Evo2LoRA(peft_ckpt_path=args.lora_checkpoint_path)
+            lora_kwargs = {
+                k: v
+                for k, v in {
+                    "alpha": args.lora_alpha,
+                    "dim": args.lora_dim,
+                }.items()
+                if v is not None
+            }
+
+            lora_transform = Evo2LoRA(peft_ckpt_path=args.lora_checkpoint_path, **lora_kwargs)
 
         model = llm.HyenaModel(model_config, tokenizer=data_module.tokenizer, model_transform=lora_transform)
     elif model_type == "mamba":
         # mamba
diff --git a/sub-packages/bionemo-evo2/tests/bionemo/evo2/run/common.py b/sub-packages/bionemo-evo2/tests/bionemo/evo2/run/common.py
index 580e0f2b57..92ebd07afc 100644
--- a/sub-packages/bionemo-evo2/tests/bionemo/evo2/run/common.py
+++ b/sub-packages/bionemo-evo2/tests/bionemo/evo2/run/common.py
@@ -58,3 +58,14 @@ def small_training_finetune_cmd(
         f"{'--global-batch-size ' + str(global_batch_size) if global_batch_size is not None else ''}"
     )
     return cmd
+
+
+def predict_cmd(ckpt_dir: str, output_dir: str, fasta_file_path: str, additional_args: str = ""):
+    """Command for predict."""
+    cmd = (
+        f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {ckpt_dir} --output-dir {output_dir} "
+        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --tensor-parallel-size 1 "
+        f"--pipeline-model-parallel-size 1 --context-parallel-size 1 {additional_args}"
+    )
+
+    return cmd
diff --git a/sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py b/sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py
index d9d84b8d2e..744d0db372 100644
--- a/sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py
+++ b/sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py
@@ -31,8 +31,11 @@
 from bionemo.core.data.load import load
 from bionemo.llm.lightning import batch_collator
 from bionemo.testing.data.fasta import ALU_SEQUENCE, create_fasta_file
+from bionemo.testing.subprocess_utils import run_command_in_subprocess
 from bionemo.testing.torch import check_fp8_support
 
+from .common import predict_cmd, small_training_finetune_cmd
+
 
 def is_a6000_gpu() -> bool:
     # Check if any of the visible GPUs is an A6000
@@ -364,3 +367,85 @@ def test_predict_evo2_equivalent_with_log_probs(
     else:
         rel = 1e-6
     assert log_probs.item() == pytest.approx(baseline_predictions_7b_1m_results[original_idx.item()], rel=rel)
+
+
+@pytest.mark.timeout(512)
+@pytest.mark.slow
+def test_different_results_with_without_peft(tmp_path):
+    try:
+        base_model_checkpoint_path = load("evo2/1b-8k:1.0")
+    except ValueError as e:
+        if e.args[0].endswith("does not have an NGC URL."):
+            raise ValueError(
+                "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, "
+                "one or more files are missing from ngc."
+            )
+        else:
+            raise e
+
+    num_steps = 2
+
+    result_dir = tmp_path / "lora_finetune"
+
+    # Note: The command assumes that `train_evo2` is in your PATH.
+    command_finetune = small_training_finetune_cmd(
+        result_dir,
+        max_steps=num_steps,
+        val_check=num_steps,
+        prev_ckpt=base_model_checkpoint_path,
+        create_tflops_callback=False,
+        additional_args="--lora-finetune",
+    )
+    stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
+    assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune
+    assert "Loading adapters from" not in stdout_finetune
+
+    # Check if checkpoints dir exists
+    checkpoints_dir = result_dir / "evo2" / "checkpoints"
+    assert checkpoints_dir.exists(), "Checkpoints folder does not exist."
+
+    # Create a sample FASTA file to run predictions
+    fasta_file_path = tmp_path / "test.fasta"
+    create_fasta_file(fasta_file_path, 3, sequence_lengths=[32, 65, 129], repeating_dna_pattern=ALU_SEQUENCE)
+
+    result_dir_original = tmp_path / "results_original"
+    cmd_predict = predict_cmd(base_model_checkpoint_path, result_dir_original, fasta_file_path)
+    stdout_predict: str = run_command_in_subprocess(command=cmd_predict, path=str(tmp_path))
+
+    # Assert that the output directory was created.
+    pred_files_original = glob.glob(str(result_dir_original / "predictions__rank_*.pt"))
+    assert len(pred_files_original) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_original)}"
+
+    # Find the checkpoint dir generated by finetuning
+    expected_checkpoint_suffix = f"{num_steps}.0-last"
+    # Check if any subfolder ends with the expected suffix
+    matching_subfolders = [
+        p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
+    ]
+
+    assert matching_subfolders, (
+        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
+    )
+
+    result_dir_peft = tmp_path / "results_peft"
+    additional_args = f"--lora-checkpoint-path {matching_subfolders[0]}"
+    cmd_predict = predict_cmd(base_model_checkpoint_path, result_dir_peft, fasta_file_path, additional_args)
+    stdout_predict: str = run_command_in_subprocess(command=cmd_predict, path=str(tmp_path))
+    assert "Loading adapters from" in stdout_predict
+
+    pred_files_peft = glob.glob(str(result_dir_peft / "predictions__rank_*.pt"))
+    assert len(pred_files_peft) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_peft)}"
+
+    results_original = torch.load(f"{result_dir_original}/predictions__rank_0__dp_rank_0.pt")
+    results_peft = torch.load(f"{result_dir_peft}/predictions__rank_0__dp_rank_0.pt")
+
+    seq_idx_original = results_original["seq_idx"]
+    seq_idx_peft = results_peft["seq_idx"]
+    assert torch.equal(seq_idx_original, seq_idx_peft), f"Tensors differ: {seq_idx_original} vs {seq_idx_peft}"
+
+    logits_original = results_original["token_logits"]
+    logits_peft = results_peft["token_logits"]
+    assert (logits_original != logits_peft).any()
+    assert logits_original.shape == logits_peft.shape, (
+        f"Shapes don't match: {logits_original.shape} vs {logits_peft.shape}"
+    )