Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions bionemo-recipes/recipes/codonfm_ptl_te/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,14 @@ RUN chown -R ${USERNAME:-vscode}:${USERNAME:-vscode} /workspace/codonfm

# Switch to the non-root user
USER $USERNAME

# ----------------- For benchmarking only -----------------
# Warning: I was only able to build this image in an instance with 2TB of memory.
# Otherwise, a segmentation fault occurs during the build process:
# /bin/bash: line 1: 13517 Segmentation fault (core dumped) ptxas -arch=sm_90 -m64 -v --generate-line-info "/tmp/tmpxft_00002f58_00000000-6_flash_fwd_hdim64_256_fp16_paged_split_sm90.ptx" -o "/tmp/tmpxft_00002f58_00000000-8_flash_fwd_hdim64_256_fp16_paged_split_sm90.cubin" > /tmp/tmpxft_00002f58_00000000-10_2eb7d280_stdout 2> /tmp/tmpxft_00002f58_00000000-10_2eb7d280_stderr
#
# Could have also been caused by CUDA-compatibility issues.

FROM production AS benchmarking

RUN pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@v0.0.32.post2#egg=xformers --no-deps
2 changes: 2 additions & 0 deletions bionemo-recipes/recipes/codonfm_ptl_te/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ The table below summarizes the set of open source pre-trained weights currently
| EnCodon 600M | MLM (random p=0.15) | 2048 | 12 | 16 | 8192 | `mlm/encodon_600m.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-600M-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-600M-v1) |
| EnCodon 1B | MLM (random p=0.15) | 2048 | 18 | 16 | 8192 | `mlm/encodon_1b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-1B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-1B-v1) |
| EnCodon 1B (CDSWT) | MLM (codon frequency-weighted) | 2048 | 18 | 16 | 8192 | `cdswt/encodon_1b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-Cdwt-1B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-Cdwt-1B-v1) |
| EnCodon 5B | MLM (codon p=0.15) | 4096 | 24 | 32 | 16384 | `mlm/encodon_5b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-5B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-5B-v1) |
| EnCodon 5B (CDSWT) | MLM (codon frequency-weighted) | 4096 | 24 | 32 | 16384 | `cdswt/encodon_5b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-Cdwt-5B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-Cdwt-5B-v1) |

## Repository Structure

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,39 @@

import argparse
import logging
import os

import torch
from safetensors.torch import save_file as safetensors_save_file

from src.utils.load_checkpoint import load_checkpoint


logger = logging.getLogger(__name__)

# Hyperparameter keys carried over by the PyTorch <-> TE checkpoint
# conversion; filter_hyper_parameters() drops every key of the source
# checkpoint's "hyper_parameters" dict that is not listed here.
ALLOWED_HYPERPARAMETER_KEYS = (
    "vocab_size",
    "hidden_size",
    "num_hidden_layers",
    "num_attention_heads",
    "intermediate_size",
    "hidden_act",
    "hidden_dropout_prob",
    "attention_probs_dropout_prob",
    "initializer_range",
    "layer_norm_eps",
    "pad_token_id",
    "position_embedding_type",
    "classifier_dropout",
    "rotary_theta",
    "ignore_index",
    "loss_type",
    # LoRA adapter settings
    "lora",
    "lora_alpha",
    "lora_r",
    "lora_dropout",
)

# PyTorch -> TE keymap
PYTORCH_TO_TE_KEYMAP = {
"model.layers.*.pre_attn_layer_norm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
Expand Down Expand Up @@ -300,6 +325,11 @@ def convert_state_dict(src: dict, keymap: dict):
return dst_state_dict


def filter_hyper_parameters(hyper_parameters: dict) -> dict:
    """Return a copy containing only keys listed in ALLOWED_HYPERPARAMETER_KEYS."""
    filtered = {}
    for name, setting in hyper_parameters.items():
        if name in ALLOWED_HYPERPARAMETER_KEYS:
            filtered[name] = setting
    return filtered


def main():
"""Main function."""
logging.basicConfig(level=logging.INFO)
Expand All @@ -325,6 +355,7 @@ def main():
# Load source checkpoint (automatically detects format)
logger.info(f"Loading checkpoint from {args.src}")
src_checkpoint = load_checkpoint(args.src, map_location="cpu")
src_checkpoint["hyper_parameters"] = filter_hyper_parameters(src_checkpoint["hyper_parameters"])

# Perform conversion based on direction
if args.direction == "pytorch2te":
Expand All @@ -341,11 +372,19 @@ def main():
dst_state_dict = split_qkv(converted_state_dict, src_checkpoint["hyper_parameters"])

# Prepare final checkpoint
dst_checkpoint = {"state_dict": dst_state_dict, "hyper_parameters": src_checkpoint["hyper_parameters"]}
dst_checkpoint = {
"state_dict": dst_state_dict,
"hyper_parameters": src_checkpoint["hyper_parameters"],
}

# Save the converted checkpoint in pickled format
torch.save(dst_checkpoint, args.dst)
logger.info(f"Successfully converted checkpoint from {args.src} to {args.dst}")
logger.info(f"Successfully converted checkpoint saved to {args.dst}")

# Save the state_dict in safetensors format alongside the .ckpt file
safetensors_path = os.path.splitext(args.dst)[0] + ".safetensors"
safetensors_save_file(dst_state_dict, safetensors_path)
logger.info(f"Successfully saved safetensors checkpoint to {safetensors_path}")


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


# %%
import argparse
import json
import sys
from pathlib import Path
Expand All @@ -23,41 +24,52 @@
from tqdm import tqdm


sys.path.append("/workspace/codon_fm")
sys.path.append("/workspace/codonfm")
from src.tokenizer import Tokenizer


data_path = Path("/data/ncbi/processed_unfiltered")
tax_ids_to_remove = json.load(open("/data/ncbi/taxids_to_remove.json"))
metadata = json.load(open(data_path / "metadata.json"))
tokenizer = Tokenizer()


groups = set([x["file_name"][:-4] for x in metadata["file_metadata"]]) # noqa: C403
counts = {g: np.zeros(tokenizer.vocab_size) for g in groups}
for fm, cm in tqdm(zip(metadata["file_metadata"], metadata["chunks"]), total=len(metadata["file_metadata"])):
group = fm["file_name"][:-4]
if group in tax_ids_to_remove:
curr_taxids_to_remove = set(tax_ids_to_remove[group])
else:
curr_taxids_to_remove = set()
mmap = np.memmap(
data_path / cm["sequences"]["path"],
dtype=cm["sequences"]["dtype"],
mode="r",
shape=tuple(cm["sequences"]["shape"]),
)
idx_mmap = np.memmap(
data_path / cm["index"]["path"], dtype=cm["index"]["dtype"], mode="r", shape=tuple(cm["index"]["shape"])
)
for start, end, taxid in idx_mmap:
if taxid in curr_taxids_to_remove:
continue
seq = mmap[start:end]
idx, count = np.unique(seq, return_counts=True)
counts[group][idx] += count
def main(pretraining_processed_data_dir: Path, data_dir: Path):
    """Count per-group codon (token) frequencies over the pretraining corpus.

    Reads the dataset ``metadata.json``, walks every memory-mapped sequence
    chunk, skips sequences whose taxid is listed in ``taxids_to_remove.json``
    for that group, and writes one token-count vector per file group to
    ``<data_dir>/codon_counts_nopathogen.json``.

    Args:
        pretraining_processed_data_dir: Directory containing ``metadata.json``
            and the memory-mapped sequence/index chunk files it references.
        data_dir: Directory containing ``taxids_to_remove.json``; the output
            JSON is written here as well.
    """
    # Use context managers so file handles are closed deterministically
    # (json.load(open(...)) leaked the handles to the garbage collector).
    with open(data_dir / "taxids_to_remove.json") as f:
        tax_ids_to_remove = json.load(f)
    with open(pretraining_processed_data_dir / "metadata.json") as f:
        metadata = json.load(f)
    tokenizer = Tokenizer()

    # A "group" is the chunk file name minus its 4-character extension.
    groups = {x["file_name"][:-4] for x in metadata["file_metadata"]}
    counts = {g: np.zeros(tokenizer.vocab_size) for g in groups}
    for fm, cm in tqdm(zip(metadata["file_metadata"], metadata["chunks"]), total=len(metadata["file_metadata"])):
        group = fm["file_name"][:-4]
        # Taxids to exclude for this group; empty set when none are listed.
        curr_taxids_to_remove = set(tax_ids_to_remove.get(group, ()))
        mmap = np.memmap(
            pretraining_processed_data_dir / cm["sequences"]["path"],
            dtype=cm["sequences"]["dtype"],
            mode="r",
            shape=tuple(cm["sequences"]["shape"]),
        )
        idx_mmap = np.memmap(
            pretraining_processed_data_dir / cm["index"]["path"],
            dtype=cm["index"]["dtype"],
            mode="r",
            shape=tuple(cm["index"]["shape"]),
        )
        # Each index row is (start, end, taxid) into the sequence memmap.
        for start, end, taxid in idx_mmap:
            if taxid in curr_taxids_to_remove:
                continue
            seq = mmap[start:end]
            idx, count = np.unique(seq, return_counts=True)
            counts[group][idx] += count

    # JSON cannot serialize numpy arrays; convert to plain lists first.
    for g in counts:
        counts[g] = counts[g].tolist()
    with open(data_dir / "codon_counts_nopathogen.json", "w") as f:
        json.dump(counts, f)


if __name__ == "__main__":
    # CLI entry point: both directory arguments are mandatory.
    cli = argparse.ArgumentParser(description="Check codon frequency")
    for flag in ("--pretraining_processed_data_dir", "--data_dir"):
        cli.add_argument(flag, type=str, required=True)
    parsed = cli.parse_args()
    main(Path(parsed.pretraining_processed_data_dir), Path(parsed.data_dir))
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
import argparse
import json
import os
import sys
from multiprocessing import Pool, cpu_count

import numpy as np
import polars as pl
import pyarrow.parquet as pq
from tqdm import tqdm


sys.path.append("/workspace/codonfm")
from src.tokenizer import Tokenizer


Expand Down
Loading