diff --git a/configs/local_config.yaml b/configs/local_config.yaml index a69831d..8a72d17 100644 --- a/configs/local_config.yaml +++ b/configs/local_config.yaml @@ -2,16 +2,16 @@ model: use_concat: True max_seq_length: 80 - embed_dim: 2048 - num_heads: 16 - num_layers: 8 + embed_dim: 1024 + num_heads: 8 + num_layers: 6 dropout: 0.1 # dropout: 0.0 resample_size: 1000 - + use_mlp_for_nmr: True #default is true # Training Parameters training: - batch_size: 64 + batch_size: 32 # batch_size: 32 test_batch_size: 1 num_epochs: 5 diff --git a/models/multimodal_to_smiles.py b/models/multimodal_to_smiles.py index acb8d3e..32a63b5 100644 --- a/models/multimodal_to_smiles.py +++ b/models/multimodal_to_smiles.py @@ -17,7 +17,8 @@ def __init__( resample_size: int = 1000, use_concat: bool = True, verbose: bool = False, - domain_ranges: list | None = None + domain_ranges: list | None = None, + use_mlp_for_nmr: bool = True ): super().__init__() @@ -34,7 +35,8 @@ def __init__( resample_size=resample_size, use_concat=use_concat, verbose=verbose, - domain_ranges=domain_ranges + domain_ranges=domain_ranges, + use_mlp_for_nmr=use_mlp_for_nmr ) # Calculate decoder input dimension diff --git a/models/spectral_encoder.py b/models/spectral_encoder.py index 7442178..6d37080 100644 --- a/models/spectral_encoder.py +++ b/models/spectral_encoder.py @@ -57,6 +57,7 @@ def __call__(self, nmr_data, ir_data, c_nmr_data): x_nmr = torch.from_numpy(x_nmr).float() x_nmr = x_nmr.unsqueeze(1) # Add channel dimension else: + # Pass through raw data when not processing x_nmr = nmr_data[0] if isinstance(nmr_data, tuple) else nmr_data # Process IR - expects [batch, length] -> outputs [batch, 1, length] @@ -95,32 +96,41 @@ def __call__(self, nmr_data, ir_data, c_nmr_data): x_c_nmr = torch.from_numpy(x_c_nmr).float() x_c_nmr = x_c_nmr.unsqueeze(1) # Add channel dimension else: + # Pass through raw data when not processing x_c_nmr = c_nmr_data[0] if isinstance(c_nmr_data, tuple) else c_nmr_data # Move to device if inputs are on device - if isinstance(nmr_data[0], torch.Tensor): - device = nmr_data[0].device - x_nmr = x_nmr.to(device) - x_ir = x_ir.to(device) - x_c_nmr = x_c_nmr.to(device) + device = None + if ir_data is not None and isinstance(ir_data, tuple) and isinstance(ir_data[0], torch.Tensor): + device = ir_data[0].device + + if device is not None: + if x_ir is not None and isinstance(x_ir, torch.Tensor): + x_ir = x_ir.to(device) + if x_nmr is not None and isinstance(x_nmr, torch.Tensor): + x_nmr = x_nmr.to(device) + if x_c_nmr is not None and isinstance(x_c_nmr, torch.Tensor): + x_c_nmr = x_c_nmr.to(device) return x_nmr, x_ir, x_c_nmr class MultimodalSpectralEncoder(nn.Module): def __init__(self, embed_dim=768, num_heads=8, dropout=0.1, resample_size=1000, - use_concat=True, verbose=True, domain_ranges=None): + use_concat=True, verbose=True, domain_ranges=None, use_mlp_for_nmr=True): super().__init__() - # Only check divisibility by 3 if using concatenation - if use_concat and embed_dim % 3 != 0: + self.verbose = verbose + self.use_concat = use_concat + self.use_mlp_for_nmr = use_mlp_for_nmr + + # Only check divisibility by 3 if using concatenation and not using MLP + if use_concat and not use_mlp_for_nmr and embed_dim % 3 != 0: raise ValueError( f"When using concatenation (use_concat=True), embed_dim ({embed_dim}) " f"must be divisible by 3 (number of modalities) to ensure equal " f"dimension distribution across modalities." ) - self.verbose = verbose - # Unpack domain ranges if provided if domain_ranges: ir_range, h_nmr_range, c_nmr_range, _, _ = domain_ranges @@ -129,142 +139,141 @@ def __init__(self, embed_dim=768, num_heads=8, dropout=0.1, resample_size=1000, h_nmr_range = None c_nmr_range = None + # Create preprocessor only for IR if using MLP for NMR self.preprocessor = SpectralPreprocessor( resample_size=resample_size, - process_nmr=True, + process_nmr=not use_mlp_for_nmr, process_ir=True, - process_c_nmr=True, + process_c_nmr=not use_mlp_for_nmr, nmr_window=h_nmr_range, ir_window=ir_range, c_nmr_window=c_nmr_range ) - # Calculate individual backbone output dimensions - n_modalities = 3 - backbone_dim = embed_dim // n_modalities if use_concat else embed_dim + # Calculate backbone dimensions + if use_mlp_for_nmr: + backbone_dim = embed_dim # IR backbone uses full dimension + + # Create bottleneck MLPs for NMRs + def create_bottleneck_mlp(): + layers = [] + current_dim = 10000 + reduction_dims = [4096, 2048, 1024] + + # Add reduction layers until we get close to embed_dim + for dim in reduction_dims: + if dim < embed_dim: + break + layers.extend([ + nn.Linear(current_dim, dim), + nn.LayerNorm(dim), + nn.ReLU(), + nn.Dropout(dropout) + ]) + current_dim = dim + + # Final projection to embed_dim + layers.extend([ + nn.Linear(current_dim, embed_dim), + nn.LayerNorm(embed_dim) + ]) + + return nn.Sequential(*layers) + + self.h_nmr_mlp = create_bottleneck_mlp() + self.c_nmr_mlp = create_bottleneck_mlp() + + else: + n_modalities = 3 + backbone_dim = embed_dim // n_modalities if use_concat else embed_dim # Calculate the final sequence length after all downsampling final_seq_len = resample_size // 32 - # Base ConvNeXt config with appropriate final dimension + if verbose: + print(f"Input sequence length: {resample_size}") + print(f"Final sequence length: {final_seq_len}") + print(f"Backbone output dimension: {backbone_dim}") + print(f"Using concatenation: {use_concat}") + print(f"Using MLP for NMR: {use_mlp_for_nmr}") + + # Base ConvNeXt config base_config = { 'depths': [3, 3, 6, 3], - 'dims': [64, 128, 256, backbone_dim], # Final dim is embed_dim//3 if concat, else embed_dim + 'dims': [64, 128, 256, backbone_dim], 'drop_path_rate': 0.1, 'layer_scale_init_value': 1e-6, 'regression': True, - 'regression_dim': backbone_dim # Each backbone outputs embed_dim//3 if concat, else embed_dim + 'regression_dim': backbone_dim } - if verbose: - print(f"Input sequence length: {resample_size}") - print(f"Final sequence length: {final_seq_len}") - print(f"Backbone output dimension: {backbone_dim}") - print(f"Using concatenation: {use_concat}") - - # Create 1D backbones for all spectra - self.nmr_backbone = ConvNeXt1D(in_chans=1, **base_config) + # Create backbones + if not use_mlp_for_nmr: + self.nmr_backbone = ConvNeXt1D(in_chans=1, **base_config) + self.c_nmr_backbone = ConvNeXt1D(in_chans=1, **base_config) self.ir_backbone = ConvNeXt1D(in_chans=1, **base_config) - self.c_nmr_backbone = ConvNeXt1D(in_chans=1, **base_config) - - # Ensure all backbones are on the same device as the parent module - device = next(self.parameters()).device - self.nmr_backbone = self.nmr_backbone.to(device) - self.ir_backbone = self.ir_backbone.to(device) - self.c_nmr_backbone = self.c_nmr_backbone.to(device) - - self.use_concat = use_concat # Only create cross attention components if not using concatenation - if not use_concat: - # Add higher-order cross attention + if not use_concat and not use_mlp_for_nmr: cross_attn_config = type('Config', (), { 'n_head': num_heads, - 'n_embd': embed_dim, # Use full embed_dim instead of backbone_dim + 'n_embd': embed_dim, 'order': 3, 'dropout': dropout, 'bias': True }) self.cross_attention = HigherOrderMultiInputCrossAttention(cross_attn_config) - - # Add final layer norm self.final_norm = nn.LayerNorm(embed_dim) def forward(self, nmr_data, ir_data, c_nmr_data): if self.verbose: print("\nEncoder Processing:") print("Processing input data...") - - # Get the device of the model - device = next(self.parameters()).device - - # Preprocess the input data - x_nmr, x_ir, x_c_nmr = self.preprocessor(nmr_data, ir_data, c_nmr_data) - - if self.verbose: - print(f"Preprocessed shapes:") - print(f"NMR: {x_nmr.shape}") - print(f"IR: {x_ir.shape}") - print(f"C-NMR: {x_c_nmr.shape}") - - # Reshape inputs to [batch, channels, sequence] - if isinstance(x_nmr, tuple): - x_nmr = x_nmr[0] - if isinstance(x_ir, tuple): - x_ir = x_ir[0] - if isinstance(x_c_nmr, tuple): - x_c_nmr = x_c_nmr[0] - - # Add channel dimension and transpose if needed - if x_nmr.dim() == 2: - x_nmr = x_nmr.unsqueeze(1) # [batch, 1, sequence] - elif x_nmr.dim() == 3 and x_nmr.size(1) > x_nmr.size(2): # if [batch, sequence, 1] - x_nmr = x_nmr.transpose(1, 2) # [batch, 1, sequence] - if x_ir.dim() == 2: - x_ir = x_ir.unsqueeze(1) - elif x_ir.dim() == 3 and x_ir.size(1) > x_ir.size(2): - x_ir = x_ir.transpose(1, 2) - - if x_c_nmr.dim() == 2: - x_c_nmr = x_c_nmr.unsqueeze(1) - elif x_c_nmr.dim() == 3 and x_c_nmr.size(1) > x_c_nmr.size(2): - x_c_nmr = x_c_nmr.transpose(1, 2) - - if self.verbose: - print(f"Reshaped input shapes:") - print(f"NMR: {x_nmr.shape}") - print(f"IR: {x_ir.shape}") - print(f"C-NMR: {x_c_nmr.shape}") - - # Pass through backbones - emb_nmr = self.nmr_backbone(x_nmr, keep_sequence=True) # [B, seq_len, embed_dim//3] - emb_ir = self.ir_backbone(x_ir, keep_sequence=True) # [B, seq_len, embed_dim//3] - emb_c_nmr = self.c_nmr_backbone(x_c_nmr, keep_sequence=True) # [B, seq_len, embed_dim//3] + if self.use_mlp_for_nmr: + # Get raw NMR data + h_nmr_data = nmr_data[0] if isinstance(nmr_data, tuple) else nmr_data + c_nmr_data = c_nmr_data[0] if isinstance(c_nmr_data, tuple) else c_nmr_data + + # Process only IR through preprocessor + _, x_ir, _ = self.preprocessor(None, ir_data, None) + else: + # Process all data through preprocessor + x_nmr, x_ir, x_c_nmr = self.preprocessor(nmr_data, ir_data, c_nmr_data) - if self.verbose: - print(f"\nBackbone outputs:") - print(f"NMR embedding: {emb_nmr.shape}") - print(f"IR embedding: {emb_ir.shape}") - print(f"C-NMR embedding: {emb_c_nmr.shape}") + # Process through backbones/MLPs + if self.use_mlp_for_nmr: + # Process NMRs through MLPs + emb_nmr = self.h_nmr_mlp(h_nmr_data) # [B, embed_dim] + emb_nmr = emb_nmr.unsqueeze(1) # [B, 1, embed_dim] + + emb_c_nmr = self.c_nmr_mlp(c_nmr_data) # [B, embed_dim] + emb_c_nmr = emb_c_nmr.unsqueeze(1) # [B, 1, embed_dim] - if self.use_concat: - # All sequences should have same length after backbone processing - assert emb_nmr.size(1) == emb_ir.size(1) == emb_c_nmr.size(1), "Sequence lengths must match" + # Process IR through backbone + emb_ir = self.ir_backbone(x_ir, keep_sequence=True) # [B, seq_len, embed_dim] - # Concatenate along embedding dimension - result = torch.cat([emb_nmr, emb_ir, emb_c_nmr], dim=-1) # [B, seq_len, embed_dim] + # Concatenate along sequence dimension + result = torch.cat([emb_ir, emb_nmr, emb_c_nmr], dim=1) # [B, seq_len+2, embed_dim] if self.verbose: - print(f"\nFinal concatenated output: {result.shape}") + print(f"\nFinal output (MLP mode): {result.shape}") return result - else: - # Apply higher-order cross attention - fused = self.cross_attention(emb_nmr, emb_ir, emb_c_nmr) - # Apply final normalization - fused = self.final_norm(fused) + else: + # Original processing logic + emb_nmr = self.nmr_backbone(x_nmr, keep_sequence=True) + emb_ir = self.ir_backbone(x_ir, keep_sequence=True) + emb_c_nmr = self.c_nmr_backbone(x_c_nmr, keep_sequence=True) - if self.verbose: - print(f"\nFinal fused output: {fused.shape}") - return fused \ No newline at end of file + if self.use_concat: + result = torch.cat([emb_nmr, emb_ir, emb_c_nmr], dim=-1) + if self.verbose: + print(f"\nFinal concatenated output: {result.shape}") + return result + else: + fused = self.cross_attention(emb_nmr, emb_ir, emb_c_nmr) + fused = self.final_norm(fused) + if self.verbose: + print(f"\nFinal fused output: {fused.shape}") + return fused \ No newline at end of file diff --git a/train_autoregressive.py b/train_autoregressive.py index 1281288..76c5009 100644 --- a/train_autoregressive.py +++ b/train_autoregressive.py @@ -415,7 +415,8 @@ def load_config(config_path=None): 'num_layers': 6, 'dropout': 0.1, 'resample_size': 1000, - 'use_concat': True + 'use_concat': True, + 'use_mlp_for_nmr': True }, 'training': { 'batch_size': 32, @@ -545,7 +546,8 @@ def main(): resample_size=resample_size, domain_ranges=domain_ranges, verbose=False, - use_concat=config['model']['use_concat'] + use_concat=config['model']['use_concat'], + use_mlp_for_nmr=config['model'].get('use_mlp_for_nmr', True) ).to(device) print("[Main] Model initialized successfully")