Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ lg-gpu = [
mgm = [
"selfies>=2.1.2",
]
calm = [
"calm @ git+https://github.com/oxpig/CaLM.git",
]
flash = [
"flash-attn>=2.8.0.post2; sys_platform == 'linux'",
"rotary-embedding-torch",
Expand Down
4 changes: 2 additions & 2 deletions src/lobster/callbacks/_calm_linear_probe_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils.data import ConcatDataset, DataLoader, Subset
from tqdm import tqdm

from lobster.constants import CALM_TASKS
from lobster.constants import CALM_TASKS, Modality

Check failure on line 12 in src/lobster/callbacks/_calm_linear_probe_callback.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/lobster/callbacks/_calm_linear_probe_callback.py:12:43: F401 `lobster.constants.Modality` imported but unused
from lobster.datasets import CalmPropertyDataset

from ._linear_probe_callback import LinearProbeCallback
Expand Down Expand Up @@ -210,7 +210,7 @@
sequences = list(x)

# Both CaLM and UME models have embed_sequences method
batch_embeddings = model.embed_sequences(sequences=sequences)
batch_embeddings = model.embed_sequences(sequences=sequences, modality=Modality.NUCLEOTIDE)

Check failure on line 213 in src/lobster/callbacks/_calm_linear_probe_callback.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

src/lobster/callbacks/_calm_linear_probe_callback.py:213:88: F821 Undefined name `MODALITY`

embeddings.append(batch_embeddings.cpu())
targets.append(y.cpu())
Expand Down
1 change: 1 addition & 0 deletions src/lobster/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._calm import CaLM
from ._cbmlm import LobsterCBMPMLM
from ._clip import MiniCLIP
from ._clm import LobsterPCLM
Expand Down
144 changes: 144 additions & 0 deletions src/lobster/model/_calm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import pytorch_lightning as L
import torch

from lobster._ensure_package import ensure_package

# Ensure CaLM package is available
ensure_package("calm", group="calm")

from calm.pretrained import CaLM as PretrainedCaLM

Check failure on line 9 in src/lobster/model/_calm.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E402)

src/lobster/model/_calm.py:9:1: E402 Module level import not at top of file


class CaLM(L.LightningModule):
    """
    Minimal LightningModule wrapper for the CaLM pretrained model.

    Only supports inference via `embed_sequences`.

    Parameters
    ----------
    device : str, optional
        Device for returned tensors. Default is "cpu".
    """

    # Nucleotide alphabet used to distinguish DNA/RNA input from amino acids.
    _NUCLEOTIDE_LETTERS = frozenset("ACGTU")

    # Canonical back-translation map (RNA, one common codon per amino acid).
    # Class-level constant so it is built once, not on every call.
    _AA_TO_CODON = {
        "A": "GCU",
        "C": "UGC",
        "D": "GAU",
        "E": "GAA",
        "F": "UUU",
        "G": "GGU",
        "H": "CAU",
        "I": "AUU",
        "K": "AAA",
        "L": "CUU",
        "M": "AUG",
        "N": "AAU",
        "P": "CCU",
        "Q": "CAA",
        "R": "CGU",
        "S": "UCU",
        "T": "ACU",
        "V": "GUU",
        "W": "UGG",
        "Y": "UAU",
    }

    # Sequences longer than this many characters risk OOM and are truncated.
    _MAX_SEQ_CHARS = 50_000
    # Length of the middle window kept when truncating (multiple of 3).
    _TRUNCATE_TO_CHARS = 30_000

    def __init__(self, device: str = "cpu"):
        super().__init__()

        # Target device for the returned embeddings; the underlying CaLM
        # model stays on whatever device it initializes itself to.
        self.device_type = device

        # Initialize pretrained CaLM model (kept on its default device)
        self.calm = PretrainedCaLM()
        # Put underlying torch model in eval mode (inference only)
        if hasattr(self.calm, "model") and hasattr(self.calm.model, "eval"):
            self.calm.model.eval()

    def _normalize_to_codon_sequences(self, sequences: list[str]) -> list[str]:
        """
        Normalize input sequences to RNA codon strings that CaLM expects.

        - If input looks like nucleotides (A/C/G/T/U), convert T->U and trim
          any trailing incomplete codon chunk.
        - If input looks like amino acids, back-translate each residue to a
          canonical codon (most common option) to form an RNA sequence.

        Parameters
        ----------
        sequences : list[str]
            Raw input sequences (DNA, RNA, or amino acids).

        Returns
        -------
        list[str]
            RNA sequences whose lengths are multiples of 3.
        """
        normalized: list[str] = []
        for seq in sequences:
            if not isinstance(seq, str):
                seq = str(seq)
            seq = seq.strip().upper().replace(" ", "")

            # Decide modality by character set: a non-empty sequence made up
            # solely of A/C/G/T/U is treated as nucleotides.
            unique_chars = set(seq)
            is_nucleotide = bool(unique_chars) and unique_chars <= self._NUCLEOTIDE_LETTERS

            if is_nucleotide:
                # Convert DNA to RNA, then trim any trailing incomplete codon
                rna = seq.replace("T", "U")
                rna = rna[: len(rna) - (len(rna) % 3)]
                normalized.append(rna)
            else:
                # Treat as amino acid sequence and back-translate to RNA.
                # Any non-standard residue falls back to glycine ("GGU").
                codons = [self._AA_TO_CODON.get(aa, "GGU") for aa in seq]
                normalized.append("".join(codons))

        return normalized

    def embed_sequences(self, sequences: str | list[str], **kwargs) -> torch.Tensor:
        """
        Embed cDNA sequences using CaLM and return averaged embeddings.

        Parameters
        ----------
        sequences : str | list[str]
            One sequence or list of sequences to embed.
        **kwargs
            Ignored; accepted for interface compatibility with other models
            that take extra keywords (e.g. `modality`).

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch, embedding_dim) on `self.device_type`.
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        # Normalize to RNA codon strings in multiples of 3
        normalized_sequences = self._normalize_to_codon_sequences(sequences)

        with torch.no_grad():
            # Process sequences individually to avoid OOM with very long
            # sequences. This is less efficient but more memory-safe.
            embeddings_list = []

            for seq in normalized_sequences:
                # Check if sequence is extremely long and might cause OOM
                if len(seq) > self._MAX_SEQ_CHARS:
                    # Keep a window from the middle to avoid losing important
                    # regions. Align the window start to a codon boundary so
                    # the reading frame is preserved (the original started at
                    # len//4, which could frame-shift the codons).
                    original_len = len(seq)
                    start_idx = original_len // 4
                    start_idx -= start_idx % 3  # codon-boundary alignment
                    seq = seq[start_idx : start_idx + self._TRUNCATE_TO_CHARS]
                    # Report the length of *this* sequence before truncation
                    # (the original mistakenly reported the first sequence's).
                    print(
                        f"Warning: Truncated very long sequence from {original_len} to {len(seq)} characters"
                    )

                # Process single sequence
                single_embedding = self.calm.embed_sequences([seq])
                embeddings_list.append(single_embedding)

            # Concatenate per-sequence embeddings into (batch, embedding_dim)
            embeddings = torch.cat(embeddings_list, dim=0)

        # Ensure returned tensor is on the requested device
        return embeddings.to(self.device_type)
Loading