Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 87 additions & 63 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

# Add import for Mamba models
from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel
from bionemo.evo2.models.peft import Evo2LoRA
from bionemo.llm.data import collate
from bionemo.llm.lightning import LightningPassthroughPredictionMixin
from bionemo.llm.utils.callbacks import PredictionWriter
Expand Down Expand Up @@ -159,6 +160,13 @@ def parse_args():
"know a model was trained with a specific interpolation factor for ROPE, provide it here, it can make a big "
"difference in accuracy.",
)
ap.add_argument(
"--lora-checkpoint-path",
type=Path,
required=False,
default=None,
help="Path to the lora states to restore from.",
)
return ap.parse_args()


Expand Down Expand Up @@ -261,6 +269,11 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor | dict[str
class HyenaPredictor(BasePredictor, HyenaModel):
"""A predictor for the Hyena model. This adds in the predict step and the passthrough method."""

def configure_model(self, *args, **kwargs) -> None:
    """Configure the model, then re-enable strategy-side model-parallel init.

    NOTE(review): this flips the private ``_init_model_parallel`` flag on the
    trainer strategy after the standard configuration runs — presumably so the
    strategy (re)initializes megatron model parallelism for this predictor
    (e.g. when a model transform such as a LoRA adapter is attached). Confirm
    against the strategy implementation before relying on this.
    """
    super().configure_model(*args, **kwargs)
    # Private attribute of the Megatron strategy; not part of its public API.
    self.trainer.strategy._init_model_parallel = True


class MambaPredictor(BasePredictor, MambaModel):
"""Mamba model for prediction with additional metrics."""
Expand Down Expand Up @@ -397,6 +410,7 @@ def predict(
num_layers: int | None = None,
seq_len_interpolation_factor: int | None = None,
files_per_subdir: int | None = None,
lora_checkpoint_path: Path | None = None,
):
"""Inference workflow for Evo2.

Expand Down Expand Up @@ -424,6 +438,77 @@ def predict(
)
global_batch_size = micro_batch_size * world_size // model_parallel_size

callbacks = [
PredictionWriter(
output_dir=output_dir,
write_interval=write_interval,
batch_dim_key_defaults={"token_logits": 0},
seq_dim_key_defaults={"token_logits": 1},
files_per_subdir=files_per_subdir,
save_all_model_parallel_ranks=False, # only write one copy of predictions.
)
]

# The following two config options are really only used for testing, but may also be useful for getting output from
# specific layers of the model.
config_modifiers_init = {}
if hybrid_override_pattern is not None:
config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
if num_layers is not None:
config_modifiers_init["num_layers"] = num_layers

tokenizer = get_nmt_tokenizer("byte-level")

# Select model config based on model type
if model_type == "hyena":
if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
# TODO remove this override once we add this as a default upstream in NeMo.
# if you see this, just check the pointed to model option for the 1m model in nemo and see if it already
# has this option set.
config_modifiers_init["seq_len_interpolation_factor"] = 128

if model_size not in HYENA_MODEL_OPTIONS:
raise ValueError(f"Invalid model size for Hyena: {model_size}")
config = HYENA_MODEL_OPTIONS[model_size](
forward_step_fn=hyena_predict_forward_step,
data_step_fn=hyena_predict_data_step, # , attention_backend=AttnBackend.fused,
distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
# Only use vortex style FP8 in the model config if using FP8 and not full FP8. This will only apply FP8 to
# the projection layer of the hyena mixer.
vortex_style_fp8=fp8 and not full_fp8,
**config_modifiers_init,
)

if lora_checkpoint_path:
model_transform = Evo2LoRA(peft_ckpt_path=str(lora_checkpoint_path))
callbacks.append(model_transform)
else:
model_transform = None

model = HyenaPredictor(
config,
tokenizer=tokenizer,
output_log_prob_seqs=output_log_prob_seqs,
log_prob_collapse_option=log_prob_collapse_option,
model_transform=model_transform,
)
else: # mamba
if model_size not in MAMBA_MODEL_OPTIONS:
raise ValueError(f"Invalid model size for Mamba: {model_size}")
config = MAMBA_MODEL_OPTIONS[model_size](
forward_step_fn=hyena_predict_forward_step, # Can reuse the same forward steps
data_step_fn=hyena_predict_data_step,
distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
**config_modifiers_init,
)

model = MambaPredictor(
config,
tokenizer=tokenizer,
output_log_prob_seqs=output_log_prob_seqs,
log_prob_collapse_option=log_prob_collapse_option,
)

# Create PTL trainer.
trainer = nl.Trainer(
accelerator="gpu",
Expand Down Expand Up @@ -451,16 +536,7 @@ def predict(
log_every_n_steps=1,
limit_val_batches=10,
num_sanity_val_steps=0,
callbacks=[
PredictionWriter(
output_dir=output_dir,
write_interval=write_interval,
batch_dim_key_defaults={"token_logits": 0},
seq_dim_key_defaults={"token_logits": 1},
files_per_subdir=files_per_subdir,
save_all_model_parallel_ranks=False, # only write one copy of predictions.
)
],
callbacks=callbacks,
plugins=nl.MegatronMixedPrecision(
precision="bf16-mixed",
params_dtype=torch.bfloat16,
Expand All @@ -471,42 +547,6 @@ def predict(
fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
),
)
# The following two config options are really only used for testing, but may also be useful for getting output from
# specific layers of the model.
config_modifiers_init = {}
if hybrid_override_pattern is not None:
config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
if num_layers is not None:
config_modifiers_init["num_layers"] = num_layers
# Select model config based on model type
if model_type == "hyena":
if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
# TODO remove this override once we add this as a default upstream in NeMo.
# if you see this, just check the pointed to model option for the 1m model in nemo and see if it already
# has this option set.
config_modifiers_init["seq_len_interpolation_factor"] = 128

if model_size not in HYENA_MODEL_OPTIONS:
raise ValueError(f"Invalid model size for Hyena: {model_size}")
config = HYENA_MODEL_OPTIONS[model_size](
forward_step_fn=hyena_predict_forward_step,
data_step_fn=hyena_predict_data_step, # , attention_backend=AttnBackend.fused,
distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
# Only use vortex style FP8 in the model config if using FP8 and not full FP8. This will only apply FP8 to
# the projection layer of the hyena mixer.
vortex_style_fp8=fp8 and not full_fp8,
**config_modifiers_init,
)
else: # mamba
if model_size not in MAMBA_MODEL_OPTIONS:
raise ValueError(f"Invalid model size for Mamba: {model_size}")
config = MAMBA_MODEL_OPTIONS[model_size](
forward_step_fn=hyena_predict_forward_step, # Can reuse the same forward steps
data_step_fn=hyena_predict_data_step,
distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
**config_modifiers_init,
)

trainer.strategy._setup_optimizers = False

nemo_logger = NeMoLogger(log_dir=work_dir)
Expand All @@ -518,23 +558,6 @@ def predict(
resume_from_path=str(ckpt_dir),
restore_config=None,
)
tokenizer = get_nmt_tokenizer("byte-level")

# Create appropriate model based on type
if model_type == "hyena":
model = HyenaPredictor(
config,
tokenizer=tokenizer,
output_log_prob_seqs=output_log_prob_seqs,
log_prob_collapse_option=log_prob_collapse_option,
)
else: # mamba
model = MambaPredictor(
config,
tokenizer=tokenizer,
output_log_prob_seqs=output_log_prob_seqs,
log_prob_collapse_option=log_prob_collapse_option,
)

resume.setup(trainer, model) # this pulls weights from the starting checkpoint.

Expand Down Expand Up @@ -573,6 +596,7 @@ def main():
num_layers=args.num_layers,
files_per_subdir=args.files_per_subdir,
write_interval=args.write_interval,
lora_checkpoint_path=args.lora_checkpoint_path,
)


Expand Down
23 changes: 22 additions & 1 deletion sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,18 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
default=False,
help="Enable CUDA memory cleanup before validation to prevent initialization errors.",
)
parser.add_argument(
"--lora-alpha",
type=int,
default=None,
help="Alpha parameter for LoRA fine-tuning.",
)
parser.add_argument(
"--lora-dim",
type=int,
default=None,
help="Dim parameter for LoRA fine-tuning.",
)

recompute_group = parser.add_mutually_exclusive_group(required=False)
recompute_group.add_argument("--no-activation-checkpointing", action="store_true", default=False)
Expand Down Expand Up @@ -801,7 +813,16 @@ def train(args: argparse.Namespace) -> nl.Trainer:
# Lora adaptors configuration
lora_transform = None
if args.lora_finetune:
lora_transform = Evo2LoRA(peft_ckpt_path=args.lora_checkpoint_path)
lora_kwargs = {
k: v
for k, v in {
"alpha": args.lora_alpha,
"dim": args.lora_dim,
}.items()
if v is not None
}

lora_transform = Evo2LoRA(peft_ckpt_path=args.lora_checkpoint_path, **lora_kwargs)

model = llm.HyenaModel(model_config, tokenizer=data_module.tokenizer, model_transform=lora_transform)
elif model_type == "mamba": # mamba
Expand Down
11 changes: 11 additions & 0 deletions sub-packages/bionemo-evo2/tests/bionemo/evo2/run/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,14 @@ def small_training_finetune_cmd(
f"{'--global-batch-size ' + str(global_batch_size) if global_batch_size is not None else ''}"
)
return cmd


def predict_cmd(ckpt_dir: str, output_dir: str, fasta_file_path: str, additional_args: str = "") -> str:
    """Build the ``predict_evo2`` shell command used by the prediction tests.

    The command is pinned to a tiny test configuration: a 4-layer ``1b_nv``
    model with hybrid override pattern ``SDH*`` and all parallelism sizes
    set to 1, so it runs on a single GPU.

    Args:
        ckpt_dir: Path to the model checkpoint directory to load.
        output_dir: Directory where prediction files will be written.
        fasta_file_path: FASTA file containing the input sequences.
        additional_args: Extra CLI flags appended verbatim at the end of the
            command (e.g. ``--lora-checkpoint-path ...``).

    Returns:
        The full ``predict_evo2`` command string.
    """
    return (
        f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {ckpt_dir} --output-dir {output_dir} "
        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --tensor-parallel-size 1 "
        f"--pipeline-model-parallel-size 1 --context-parallel-size 1 {additional_args}"
    )
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@
from bionemo.core.data.load import load
from bionemo.llm.lightning import batch_collator
from bionemo.testing.data.fasta import ALU_SEQUENCE, create_fasta_file
from bionemo.testing.subprocess_utils import run_command_in_subprocess
from bionemo.testing.torch import check_fp8_support

from .common import predict_cmd, small_training_finetune_cmd


def is_a6000_gpu() -> bool:
# Check if any of the visible GPUs is an A6000
Expand Down Expand Up @@ -364,3 +367,85 @@ def test_predict_evo2_equivalent_with_log_probs(
else:
rel = 1e-6
assert log_probs.item() == pytest.approx(baseline_predictions_7b_1m_results[original_idx.item()], rel=rel)


@pytest.mark.timeout(512)
@pytest.mark.slow
def test_different_results_with_without_peft(tmp_path):
    """LoRA-finetuned predictions must differ from the base model's predictions.

    Workflow:
      1. Finetune the ``evo2/1b-8k`` base checkpoint for 2 steps with
         ``--lora-finetune``.
      2. Run ``predict_evo2`` with the base checkpoint alone, then again with
         the produced LoRA adapter checkpoint attached.
      3. Verify both runs processed the same sequences (same ``seq_idx``) but
         produced different ``token_logits``.
    """
    try:
        base_model_checkpoint_path = load("evo2/1b-8k:1.0")
    except ValueError as e:
        if e.args[0].endswith("does not have an NGC URL."):
            # Chain the original error so the root cause stays visible.
            raise ValueError(
                "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, "
                "one or more files are missing from ngc."
            ) from e
        else:
            # Bare re-raise preserves the original traceback.
            raise

    num_steps = 2
    result_dir = tmp_path / "lora_finetune"

    # Note: the command assumes that `train_evo2` is in your PATH.
    command_finetune = small_training_finetune_cmd(
        result_dir,
        max_steps=num_steps,
        val_check=num_steps,
        prev_ckpt=base_model_checkpoint_path,
        create_tflops_callback=False,
        additional_args="--lora-finetune",
    )
    stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
    # Base weights are restored; no pre-existing adapters are loaded on a fresh finetune.
    assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune
    assert "Loading adapters from" not in stdout_finetune

    # Check if checkpoints dir exists
    checkpoints_dir = result_dir / "evo2" / "checkpoints"
    assert checkpoints_dir.exists(), "Checkpoints folder does not exist."

    # Create a sample FASTA file to run predictions
    fasta_file_path = tmp_path / "test.fasta"
    create_fasta_file(fasta_file_path, 3, sequence_lengths=[32, 65, 129], repeating_dna_pattern=ALU_SEQUENCE)

    # Baseline prediction with the un-adapted base checkpoint. The stdout of
    # this run is not asserted on — only its prediction files are used.
    result_dir_original = tmp_path / "results_original"
    run_command_in_subprocess(
        command=predict_cmd(base_model_checkpoint_path, result_dir_original, fasta_file_path),
        path=str(tmp_path),
    )

    # Assert that the output directory was created.
    pred_files_original = glob.glob(str(result_dir_original / "predictions__rank_*.pt"))
    assert len(pred_files_original) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_original)}"

    # Find the checkpoint dir generated by finetuning (name contains the suffix).
    expected_checkpoint_suffix = f"{num_steps}.0-last"
    matching_subfolders = [
        p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
    ]
    assert matching_subfolders, (
        f"No checkpoint subfolder containing '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
    )

    # Prediction again, now attaching the LoRA adapter checkpoint.
    result_dir_peft = tmp_path / "results_peft"
    additional_args = f"--lora-checkpoint-path {matching_subfolders[0]}"
    stdout_predict: str = run_command_in_subprocess(
        command=predict_cmd(base_model_checkpoint_path, result_dir_peft, fasta_file_path, additional_args),
        path=str(tmp_path),
    )
    assert "Loading adapters from" in stdout_predict

    pred_files_peft = glob.glob(str(result_dir_peft / "predictions__rank_*.pt"))
    assert len(pred_files_peft) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_peft)}"

    results_original = torch.load(f"{result_dir_original}/predictions__rank_0__dp_rank_0.pt")
    results_peft = torch.load(f"{result_dir_peft}/predictions__rank_0__dp_rank_0.pt")

    # Both runs must have processed the same sequences, in the same order.
    seq_idx_original = results_original["seq_idx"]
    seq_idx_peft = results_peft["seq_idx"]
    assert torch.equal(seq_idx_original, seq_idx_peft), f"Tensors differ: {seq_idx_original} vs {seq_idx_peft}"

    logits_original = results_original["token_logits"]
    logits_peft = results_peft["token_logits"]
    # Check shapes first so the element-wise comparison below is well-defined.
    assert logits_original.shape == logits_peft.shape, (
        f"Shapes don't match: {logits_original.shape} vs {logits_peft.shape}"
    )
    # The adapter must actually change the model output.
    assert (logits_original != logits_peft).any()