diff --git a/configs/local_config.yaml b/configs/local_config.yaml
index a69831d..7e57ed9 100644
--- a/configs/local_config.yaml
+++ b/configs/local_config.yaml
@@ -21,6 +21,11 @@ training:
   validation_frequency: 100
   logging_frequency: 10
   save_frequency: 5_000
+  # TODO: allow multiprocess with unimol?
+  num_workers: 0
+  ce_weight: 1.0
+  cls_enc_weight: 1.0
+  cls_dec_weight: 0.05
 
 # Scheduler
 scheduler:
diff --git a/configs/local_config_nocat.yaml b/configs/local_config_nocat.yaml
index e20bb6e..6da84a1 100644
--- a/configs/local_config_nocat.yaml
+++ b/configs/local_config_nocat.yaml
@@ -21,6 +21,7 @@ training:
   validation_frequency: 100
   logging_frequency: 10
   save_frequency: 5_000
+  num_workers: 8
 
 # Scheduler
 scheduler:
diff --git a/configs/test_config.yaml b/configs/test_config.yaml
index 85778bd..99f3c23 100644
--- a/configs/test_config.yaml
+++ b/configs/test_config.yaml
@@ -1,7 +1,7 @@
 # Model Architecture
 model:
   use_concat: True
-  max_seq_length: 1
+  max_seq_length: 4
   embed_dim: 3
   num_heads: 1
   num_layers: 1
@@ -18,6 +18,7 @@ training:
   validation_frequency: 50
   logging_frequency: 10
   save_frequency: 1000
+  num_workers: 8
 
 # Scheduler
 scheduler:
diff --git a/models/multimodal_to_smiles.py b/models/multimodal_to_smiles.py
index acb8d3e..9cc2528 100644
--- a/models/multimodal_to_smiles.py
+++ b/models/multimodal_to_smiles.py
@@ -17,12 +17,14 @@ def __init__(
         resample_size: int = 1000,
         use_concat: bool = True,
         verbose: bool = False,
-        domain_ranges: list | None = None
+        domain_ranges: list | None = None,
+        cls_dim: int = 512
     ):
         super().__init__()
 
         self.use_concat = use_concat
         self.verbose = verbose
+        self.cls_dim = cls_dim
 
         # Spectral encoder with verbose off
         memory_dim = 2046 if use_concat else embed_dim
@@ -52,7 +54,15 @@ def __init__(
             verbose=verbose
         )
 
+        # extra projections for cls
+        self.to_cls_enc = th.nn.Linear(memory_dim, cls_dim)
+        self.to_cls_dec = th.nn.Linear(embed_dim, cls_dim)
+
     def forward(self, nmr_data: tuple | th.Tensor | None, ir_data: tuple | th.Tensor | None, c_nmr_data: tuple | th.Tensor | None, target_seq: Any | None = None, target_mask: th.Tensor | None = None):
+        """Returns the CLS embeddings (dense molecule encodings) and the logits used to sample SMILES.
+        * (cls_enc, cls_dec): 2x (B, D)
+        * logits: (B, L, vocab_size)
+        """
         if self.verbose:
             print("\n=== Starting Forward Pass ===")
             print("\nSpectroscopic data shapes inside forward:")
@@ -77,11 +87,15 @@ def forward(self, nmr_data: tuple | th.Tensor | None, ir_data: tuple | th.Tensor
         if self.verbose:
             print("\n=== Starting Decoding ===")
-
+
         # Decode to SMILES
-        logits = self.decoder(target_seq, memory, target_mask)
+        cls, logits = self.decoder(target_seq, memory, target_mask)
 
         if self.verbose:
             print("\n=== Forward Pass Complete ===")
+
+        # [CLS] by mean pooling: (B, D)
+        cls_enc = self.to_cls_enc(memory).mean(dim=-2)
+        cls_dec = self.to_cls_dec(cls)
 
-        return logits
\ No newline at end of file
+        return (cls_enc, cls_dec), logits
\ No newline at end of file
diff --git a/models/transformer_decoder.py b/models/transformer_decoder.py
index 1ccbdfb..a38f779 100644
--- a/models/transformer_decoder.py
+++ b/models/transformer_decoder.py
@@ -234,7 +234,9 @@ def __init__(
         # original:
         # self.memory_proj = nn.Linear(memory_dim, memory_dim)
         self.memory_proj = nn.Linear(memory_dim, embed_dim) if memory_dim != embed_dim else nn.Identity()
-
+        # absolute posemb (will also replace memory posemb)
+        self.pos_token = th.nn.Parameter(1e-3*th.randn(1, max_seq_length + 256, embed_dim))
+
         # Create decoder layers with updated memory dimension
         self.layers = nn.ModuleList([
             # DecoderLayer(
@@ -284,14 +286,17 @@ def forward(self, tgt: th.Tensor, memory: th.Tensor, tgt_mask: th.Tensor | None
                 print(f"memory.shape: {memory.shape}")
             memory = memory.unsqueeze(1)
 
-        # Expand memory batch dimension if needed
-        # if memory.size(0) == 1 and x.size(0) > 1:
-        #     memory = memory.expand(x.size(0), -1, -1)
+        # Add cls token:
+        cls_token = th.zeros_like(memory[:, -1:, :])
+        memory = th.cat([memory, cls_token], dim=1)
 
         # mix `memory` and `x` as prompt + answer, and add causal mask
         M = memory.shape[-2]
         x = th.cat([memory, x], dim=1)
+        # add absolute posemb
+        x = x + self.pos_token[:, :T+M].repeat(B, 1, 1)
+
         mask = th.ones(B, T+M, T+M, device=x.device).tril().bool()
         # memory can attend to all of itself. Unlike causal decoding
         mask[:, :M, :M] = True
@@ -306,4 +311,5 @@ def forward(self, tgt: th.Tensor, memory: th.Tensor, tgt_mask: th.Tensor | None
 
         # remove memory prompt from output tokens
         out = self.out(x[:, M:])
-        return out
\ No newline at end of file
+        cls = x[:, M-1:M]
+        return cls, out
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index e14d183..c5aee26 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,3 +12,4 @@ scikit-learn
 rdkit
 tqdm
 seaborn
+unimol_tools
diff --git a/train_autoregressive.py b/train_autoregressive.py
index 1281288..148376b 100644
--- a/train_autoregressive.py
+++ b/train_autoregressive.py
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import torch.nn.functional as F
 from torch.utils.data import Dataset, DataLoader
 import wandb
 from datetime import datetime
@@ -49,6 +50,70 @@
 vocab_path = os.path.join(current_dir, 'vocab.txt')
 tokenizer = SmilesTokenizer(vocab_file=vocab_path)
 
+
+# -------------------------------------------------------------------------
+# suppress unimol output
+# -------------------------------------------------------------------------
+
+from unimol_tools import UniMolRepr
+clf = UniMolRepr(data_type='molecule', model_size="84m", model_name="unimolv1", use_gpu=True)
+
+import os
+import sys
+from contextlib import contextmanager
+from typing import TextIO, Optional
+
+
+@contextmanager
+def suppress_output(suppress_stdout: bool = True, suppress_stderr: bool = True):
+    """
+    Context manager to temporarily suppress stdout and/or stderr output.
+
+    Args:
+        suppress_stdout (bool): Whether to suppress stdout. Defaults to True.
+        suppress_stderr (bool): Whether to suppress stderr. Defaults to True.
+
+    Example:
+        with suppress_output():
+            print("This won't be displayed")
+            sys.stderr.write("This error won't be displayed")
+    """
+    # Save the original file descriptors
+    stdout_fd: Optional[int] = None
+    stderr_fd: Optional[int] = None
+    null_fd: Optional[int] = None
+
+    try:
+        # Open null device for redirecting output
+        null_fd = os.open(os.devnull, os.O_RDWR)
+
+        # Handle stdout suppression
+        if suppress_stdout:
+            stdout_fd = os.dup(sys.stdout.fileno())
+            os.dup2(null_fd, sys.stdout.fileno())
+
+        # Handle stderr suppression
+        if suppress_stderr:
+            stderr_fd = os.dup(sys.stderr.fileno())
+            os.dup2(null_fd, sys.stderr.fileno())
+
+        yield
+
+    finally:
+        # Restore stdout if it was suppressed
+        if stdout_fd is not None:
+            os.dup2(stdout_fd, sys.stdout.fileno())
+            os.close(stdout_fd)
+
+        # Restore stderr if it was suppressed
+        if stderr_fd is not None:
+            os.dup2(stderr_fd, sys.stderr.fileno())
+            os.close(stderr_fd)
+
+        # Close null device
+        if null_fd is not None:
+            os.close(null_fd)
+
 # -------------------------------------------------------------------------
 # Linear Warmup + Constant LR Scheduler
 # -------------------------------------------------------------------------
@@ -208,7 +273,7 @@ def __getitem__(self, idx):
         )
         tokens = torch.tensor(tokens, dtype=torch.long)
 
-        return tokens, ir_tuple, h_nmr_tuple, c_nmr_tuple
+        return tokens, ir_tuple, h_nmr_tuple, c_nmr_tuple, smiles_str
 
 
 # -------------------------------------------------------------------------
@@ -261,6 +326,13 @@ def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx):
+        """
+        tokens: (L,) longtensor
+        ir_spectra: (N, L_ir) float32 tensor
+        h_nmr_spectra: (N, L_h_nmr) float32 tensor
+        c_nmr_spectra: (N, L_c_nmr) float32 tensor
+        smiles: str
+        """
         row = self.data.iloc[idx]
 
         # Tokenize SMILES
@@ -294,18 +366,19 @@ def to_tensor(x, spectrum_type):
         h_nmr_spectra = to_tensor(row['h_nmr_spectra'], 'h_nmr')
         c_nmr_spectra = to_tensor(row['c_nmr_spectra'], 'c_nmr')
 
-        return tokens, ir_spectra, h_nmr_spectra, c_nmr_spectra
+        return tokens, ir_spectra, h_nmr_spectra, c_nmr_spectra, smiles
 
 
 # -------------------------------------------------------------------------
 # Collate Function - Moved outside to be picklable
 # -------------------------------------------------------------------------
+
 def collate_fn(batch):
     """
     Custom collate: pad tokens, preserve spectral data tuples.
""" # Unzip the batch into separate lists - all_tokens, all_ir, all_h_nmr, all_c_nmr = zip(*batch) + all_tokens, all_ir, all_h_nmr, all_c_nmr, smiles_list = zip(*batch) # Helper function to stack spectral data tuples def maybe_stack_with_domain(items): @@ -323,18 +396,18 @@ def maybe_stack_with_domain(items): c_nmr_batch = maybe_stack_with_domain(all_c_nmr) if all_c_nmr[0] is not None else None # Pad tokens + num_batch = len(all_tokens) max_len = max(len(t) for t in all_tokens) - padded_tokens = [] - for seq in all_tokens: - pad_amount = max_len - len(seq) - seq_tensor = torch.tensor(seq, dtype=torch.long) - if pad_amount > 0: - pad_tensor = torch.full((pad_amount,), tokenizer.pad_token_id, dtype=torch.long) - seq_tensor = torch.cat([seq_tensor, pad_tensor], dim=0) - padded_tokens.append(seq_tensor) - token_batch = torch.stack(padded_tokens, dim=0) + padded_tokens = torch.full((num_batch, max_len), tokenizer.pad_token_id, dtype=torch.long) + for i,seq in enumerate(all_tokens): + padded_tokens[i, :len(seq)] = torch.tensor(seq, dtype=torch.long) + + # TODO: need to suppress this output + with suppress_output(): + embedds = clf.get_repr(list(smiles_list), return_atomic_reprs=False) + embedds = torch.tensor(embedds["cls_repr"], dtype=torch.float32) - return token_batch, ir_batch, h_nmr_batch, c_nmr_batch + return padded_tokens, ir_batch, h_nmr_batch, c_nmr_batch, embedds def create_data_loaders(tokenizer, config): @@ -378,7 +451,7 @@ def create_data_loaders(tokenizer, config): torch.utils.data.Subset(dataset, train_indices), batch_size=config['training']['batch_size'], shuffle=True, - num_workers=0, + num_workers=config['training']['num_workers'], collate_fn=collate_fn ) @@ -509,6 +582,9 @@ def main(): num_layers = config['model']['num_layers'] dropout = config['model']['dropout'] resample_size = config['model']['resample_size'] + ce_weight = config["training"]["ce_weight"] + cls_enc_weight = config["training"]["cls_enc_weight"] + cls_dec_weight = config["training"]["cls_dec_weight"] PAD_TOKEN_ID = tokenizer.pad_token_id BOS_TOKEN_ID = tokenizer.cls_token_id @@ -814,14 +890,15 @@ def greedy_decode(model, nmr_data, ir_data, c_nmr_data, max_len=128): def validate(model, val_loader, criterion, tokenizer, device=device): """Validation using teacher forcing and comparing exact matches""" model.eval() - total_loss = 0 + loss_dict = {"total_loss": 0, "ce_loss": 0, "cls_enc_loss": 0, "cls_dec_loss": 0} num_batches = 0 detailed_results = [] with torch.no_grad(): - for tgt_tokens, ir, h_nmr, c_nmr in val_loader: + for tgt_tokens, ir, h_nmr, c_nmr, embedds in val_loader: target_cpu = tgt_tokens.clone().cpu() tgt_tokens = tgt_tokens.to(device) + embedds = embedds.to(device) # Handle spectral data tuples if ir is not None: @@ -846,9 +923,14 @@ def validate(model, val_loader, criterion, tokenizer, device=device): T = tgt_tokens.shape[1] mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=tgt_tokens.device), 1) - logits = model(h_nmr, ir, c_nmr, target_seq=tgt_tokens[:, :-1], target_mask=mask[:-1, :-1]) - loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_tokens[:, 1:].reshape(-1)) - + cls, logits = model(h_nmr, ir, c_nmr, target_seq=tgt_tokens[:, :-1], target_mask=mask[:-1, :-1]) + cls_enc, cls_dec = cls + ce_loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_tokens[:, 1:].reshape(-1)) + cls_enc_loss = (cls_enc - embedds).abs().mean() + cls_dec_loss = (cls_dec - embedds).abs().mean() + loss = ce_loss * ce_weight + cls_enc_weight * cls_enc_loss + cls_dec_weight * cls_dec_loss + + # 
Decoding molecules # Get predictions from logits and compare with targets pred_tokens = logits.argmax(dim=-1) # Shape: [batch_size, seq_len] @@ -863,7 +945,11 @@ def validate(model, val_loader, criterion, tokenizer, device=device): details = evaluate_predictions(preds_decoded, targets_decoded) detailed_results += details - total_loss += loss.item() + loss_dict["total_loss"] = loss_dict["total_loss"] + loss + loss_dict["ce_loss"] = loss_dict["ce_loss"] + ce_loss + loss_dict["cls_enc_loss"] = loss_dict["cls_enc_loss"] + cls_enc_loss + loss_dict["cls_dec_loss"] = loss_dict["cls_dec_loss"] + cls_dec_loss + num_batches += 1 detailed_metrics = aggregate_metrics(detailed_results) @@ -871,7 +957,10 @@ def validate(model, val_loader, criterion, tokenizer, device=device): sample_valid_set = np.random.choice(valid_set, size=min(len(valid_set), 100)).tolist() return { - f'val_loss': total_loss / num_batches, + f'val_total_loss': loss_dict["total_loss"].item() / num_batches, + f'val_loss': loss_dict["ce_loss"].item() / num_batches, + f'val_cls_enc_loss': loss_dict["cls_enc_loss"].item() / num_batches, + f'val_cls_enc_loss': loss_dict["cls_enc_loss"].item() / num_batches, # Store only a couple matches to avoid excessive logging 'valid_set': sample_valid_set, **detailed_metrics @@ -919,10 +1008,11 @@ def log_results(val_metrics, global_step, table: wandb.Table | None = None, pref batch_start_time = time.time() # Unpack the batch data correctly - tgt_tokens, ir, h_nmr, c_nmr = batch + tgt_tokens, ir, h_nmr, c_nmr, embedds = batch # Get the batch data tgt_tokens = tgt_tokens.to(device) + embedds = embedds.to(device) # Handle spectral data tuples if ir is not None: @@ -946,8 +1036,12 @@ def log_results(val_metrics, global_step, table: wandb.Table | None = None, pref # Forward pass T = tgt_tokens.shape[1] mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=tgt_tokens.device), 1) - logits = model(h_nmr, ir, c_nmr, target_seq=tgt_tokens[:, :-1], target_mask=mask[:-1, :-1]) - loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_tokens[:, 1:].reshape(-1)) + cls, logits = model(h_nmr, ir, c_nmr, target_seq=tgt_tokens[:, :-1], target_mask=mask[:-1, :-1]) + cls_enc, cls_dec = cls + ce_loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_tokens[:, 1:].reshape(-1)) + cls_enc_loss = (cls_enc - embedds).abs().mean() + cls_dec_loss = (cls_dec - embedds).abs().mean() + loss = ce_loss * ce_weight + cls_enc_weight * cls_enc_loss + cls_dec_weight * cls_dec_loss # Backward pass optimizer.zero_grad() @@ -968,7 +1062,11 @@ def log_results(val_metrics, global_step, table: wandb.Table | None = None, pref pbar.set_description(train_log) wandb.log({ - "batch_loss": loss.item(), + "batch_loss": ce_loss.item(), + "train_total_loss": loss.item(), + "train_cls_enc_loss": cls_enc_loss.item(), + "train_cls_dec_loss": cls_dec_loss.item(), + # others "learning_rate": current_lr, "epoch": epoch + 1, "global_step": global_step,
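
Note on the new loss wiring: both the training loop and validate() combine the token-level cross entropy with L1 distances between the two projected CLS vectors and the UniMol cls_repr targets produced in collate_fn, using the weights added to configs/local_config.yaml (ce_weight, cls_enc_weight, cls_dec_weight). Below is a minimal standalone sketch of that combination, not part of the patch; it assumes cls_enc, cls_dec, and embedds share the shape (B, 512), criterion is the script's CrossEntropyLoss, and the helper name combined_loss is hypothetical.

import torch

def combined_loss(logits, targets, cls_enc, cls_dec, embedds, criterion,
                  ce_weight=1.0, cls_enc_weight=1.0, cls_dec_weight=0.05):
    # Token-level cross entropy over the SMILES vocabulary (flatten batch and time).
    ce_loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    # L1 distance between each projected CLS vector and the UniMol target embedding.
    cls_enc_loss = (cls_enc - embedds).abs().mean()
    cls_dec_loss = (cls_dec - embedds).abs().mean()
    return ce_weight * ce_loss + cls_enc_weight * cls_enc_loss + cls_dec_weight * cls_dec_loss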