Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions configs/local_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
model:
use_concat: True
max_seq_length: 80
embed_dim: 2048
num_heads: 16
num_layers: 8
embed_dim: 1024
num_heads: 8
num_layers: 6
dropout: 0.1
# dropout: 0.0
resample_size: 1000

use_mlp_for_nmr: True #default is true
# Training Parameters
training:
batch_size: 64
batch_size: 32
# batch_size: 32
test_batch_size: 1
num_epochs: 5
Expand Down
6 changes: 4 additions & 2 deletions models/multimodal_to_smiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(
resample_size: int = 1000,
use_concat: bool = True,
verbose: bool = False,
domain_ranges: list | None = None
domain_ranges: list | None = None,
use_mlp_for_nmr: bool = True
):
super().__init__()

Expand All @@ -34,7 +35,8 @@ def __init__(
resample_size=resample_size,
use_concat=use_concat,
verbose=verbose,
domain_ranges=domain_ranges
domain_ranges=domain_ranges,
use_mlp_for_nmr=use_mlp_for_nmr
)

# Calculate decoder input dimension
Expand Down
219 changes: 114 additions & 105 deletions models/spectral_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __call__(self, nmr_data, ir_data, c_nmr_data):
x_nmr = torch.from_numpy(x_nmr).float()
x_nmr = x_nmr.unsqueeze(1) # Add channel dimension
else:
# Pass through raw data when not processing
x_nmr = nmr_data[0] if isinstance(nmr_data, tuple) else nmr_data

# Process IR - expects [batch, length] -> outputs [batch, 1, length]
Expand Down Expand Up @@ -95,32 +96,41 @@ def __call__(self, nmr_data, ir_data, c_nmr_data):
x_c_nmr = torch.from_numpy(x_c_nmr).float()
x_c_nmr = x_c_nmr.unsqueeze(1) # Add channel dimension
else:
# Pass through raw data when not processing
x_c_nmr = c_nmr_data[0] if isinstance(c_nmr_data, tuple) else c_nmr_data

# Move to device if inputs are on device
if isinstance(nmr_data[0], torch.Tensor):
device = nmr_data[0].device
x_nmr = x_nmr.to(device)
x_ir = x_ir.to(device)
x_c_nmr = x_c_nmr.to(device)
device = None
if ir_data is not None and isinstance(ir_data, tuple) and isinstance(ir_data[0], torch.Tensor):
device = ir_data[0].device

if device is not None:
if x_ir is not None and isinstance(x_ir, torch.Tensor):
x_ir = x_ir.to(device)
if x_nmr is not None and isinstance(x_nmr, torch.Tensor):
x_nmr = x_nmr.to(device)
if x_c_nmr is not None and isinstance(x_c_nmr, torch.Tensor):
x_c_nmr = x_c_nmr.to(device)

return x_nmr, x_ir, x_c_nmr

class MultimodalSpectralEncoder(nn.Module):
def __init__(self, embed_dim=768, num_heads=8, dropout=0.1, resample_size=1000,
use_concat=True, verbose=True, domain_ranges=None):
use_concat=True, verbose=True, domain_ranges=None, use_mlp_for_nmr=True):
super().__init__()

# Only check divisibility by 3 if using concatenation
if use_concat and embed_dim % 3 != 0:
self.verbose = verbose
self.use_concat = use_concat
self.use_mlp_for_nmr = use_mlp_for_nmr

# Only check divisibility by 3 if using concatenation and not using MLP
if use_concat and not use_mlp_for_nmr and embed_dim % 3 != 0:
raise ValueError(
f"When using concatenation (use_concat=True), embed_dim ({embed_dim}) "
f"must be divisible by 3 (number of modalities) to ensure equal "
f"dimension distribution across modalities."
)

self.verbose = verbose

# Unpack domain ranges if provided
if domain_ranges:
ir_range, h_nmr_range, c_nmr_range, _, _ = domain_ranges
Expand All @@ -129,142 +139,141 @@ def __init__(self, embed_dim=768, num_heads=8, dropout=0.1, resample_size=1000,
h_nmr_range = None
c_nmr_range = None

# Create preprocessor only for IR if using MLP for NMR
self.preprocessor = SpectralPreprocessor(
resample_size=resample_size,
process_nmr=True,
process_nmr=not use_mlp_for_nmr,
process_ir=True,
process_c_nmr=True,
process_c_nmr=not use_mlp_for_nmr,
nmr_window=h_nmr_range,
ir_window=ir_range,
c_nmr_window=c_nmr_range
)

# Calculate individual backbone output dimensions
n_modalities = 3
backbone_dim = embed_dim // n_modalities if use_concat else embed_dim
# Calculate backbone dimensions
if use_mlp_for_nmr:
backbone_dim = embed_dim # IR backbone uses full dimension

# Create bottleneck MLPs for NMRs
def create_bottleneck_mlp():
layers = []
current_dim = 10000
reduction_dims = [4096, 2048, 1024]

# Add reduction layers until we get close to embed_dim
for dim in reduction_dims:
if dim < embed_dim:
break
layers.extend([
nn.Linear(current_dim, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Dropout(dropout)
])
current_dim = dim

# Final projection to embed_dim
layers.extend([
nn.Linear(current_dim, embed_dim),
nn.LayerNorm(embed_dim)
])

return nn.Sequential(*layers)

self.h_nmr_mlp = create_bottleneck_mlp()
self.c_nmr_mlp = create_bottleneck_mlp()

else:
n_modalities = 3
backbone_dim = embed_dim // n_modalities if use_concat else embed_dim

# Calculate the final sequence length after all downsampling
final_seq_len = resample_size // 32

# Base ConvNeXt config with appropriate final dimension
if verbose:
print(f"Input sequence length: {resample_size}")
print(f"Final sequence length: {final_seq_len}")
print(f"Backbone output dimension: {backbone_dim}")
print(f"Using concatenation: {use_concat}")
print(f"Using MLP for NMR: {use_mlp_for_nmr}")

# Base ConvNeXt config
base_config = {
'depths': [3, 3, 6, 3],
'dims': [64, 128, 256, backbone_dim], # Final dim is embed_dim//3 if concat, else embed_dim
'dims': [64, 128, 256, backbone_dim],
'drop_path_rate': 0.1,
'layer_scale_init_value': 1e-6,
'regression': True,
'regression_dim': backbone_dim # Each backbone outputs embed_dim//3 if concat, else embed_dim
'regression_dim': backbone_dim
}

if verbose:
print(f"Input sequence length: {resample_size}")
print(f"Final sequence length: {final_seq_len}")
print(f"Backbone output dimension: {backbone_dim}")
print(f"Using concatenation: {use_concat}")

# Create 1D backbones for all spectra
self.nmr_backbone = ConvNeXt1D(in_chans=1, **base_config)
# Create backbones
if not use_mlp_for_nmr:
self.nmr_backbone = ConvNeXt1D(in_chans=1, **base_config)
self.c_nmr_backbone = ConvNeXt1D(in_chans=1, **base_config)
self.ir_backbone = ConvNeXt1D(in_chans=1, **base_config)
self.c_nmr_backbone = ConvNeXt1D(in_chans=1, **base_config)

# Ensure all backbones are on the same device as the parent module
device = next(self.parameters()).device
self.nmr_backbone = self.nmr_backbone.to(device)
self.ir_backbone = self.ir_backbone.to(device)
self.c_nmr_backbone = self.c_nmr_backbone.to(device)

self.use_concat = use_concat

# Only create cross attention components if not using concatenation
if not use_concat:
# Add higher-order cross attention
if not use_concat and not use_mlp_for_nmr:
cross_attn_config = type('Config', (), {
'n_head': num_heads,
'n_embd': embed_dim, # Use full embed_dim instead of backbone_dim
'n_embd': embed_dim,
'order': 3,
'dropout': dropout,
'bias': True
})
self.cross_attention = HigherOrderMultiInputCrossAttention(cross_attn_config)

# Add final layer norm
self.final_norm = nn.LayerNorm(embed_dim)

def forward(self, nmr_data, ir_data, c_nmr_data):
if self.verbose:
print("\nEncoder Processing:")
print("Processing input data...")

# Get the device of the model
device = next(self.parameters()).device

# Preprocess the input data
x_nmr, x_ir, x_c_nmr = self.preprocessor(nmr_data, ir_data, c_nmr_data)

if self.verbose:
print(f"Preprocessed shapes:")
print(f"NMR: {x_nmr.shape}")
print(f"IR: {x_ir.shape}")
print(f"C-NMR: {x_c_nmr.shape}")

# Reshape inputs to [batch, channels, sequence]
if isinstance(x_nmr, tuple):
x_nmr = x_nmr[0]
if isinstance(x_ir, tuple):
x_ir = x_ir[0]
if isinstance(x_c_nmr, tuple):
x_c_nmr = x_c_nmr[0]

# Add channel dimension and transpose if needed
if x_nmr.dim() == 2:
x_nmr = x_nmr.unsqueeze(1) # [batch, 1, sequence]
elif x_nmr.dim() == 3 and x_nmr.size(1) > x_nmr.size(2): # if [batch, sequence, 1]
x_nmr = x_nmr.transpose(1, 2) # [batch, 1, sequence]

if x_ir.dim() == 2:
x_ir = x_ir.unsqueeze(1)
elif x_ir.dim() == 3 and x_ir.size(1) > x_ir.size(2):
x_ir = x_ir.transpose(1, 2)

if x_c_nmr.dim() == 2:
x_c_nmr = x_c_nmr.unsqueeze(1)
elif x_c_nmr.dim() == 3 and x_c_nmr.size(1) > x_c_nmr.size(2):
x_c_nmr = x_c_nmr.transpose(1, 2)

if self.verbose:
print(f"Reshaped input shapes:")
print(f"NMR: {x_nmr.shape}")
print(f"IR: {x_ir.shape}")
print(f"C-NMR: {x_c_nmr.shape}")

# Pass through backbones
emb_nmr = self.nmr_backbone(x_nmr, keep_sequence=True) # [B, seq_len, embed_dim//3]
emb_ir = self.ir_backbone(x_ir, keep_sequence=True) # [B, seq_len, embed_dim//3]
emb_c_nmr = self.c_nmr_backbone(x_c_nmr, keep_sequence=True) # [B, seq_len, embed_dim//3]
if self.use_mlp_for_nmr:
# Get raw NMR data
h_nmr_data = nmr_data[0] if isinstance(nmr_data, tuple) else nmr_data
c_nmr_data = c_nmr_data[0] if isinstance(c_nmr_data, tuple) else c_nmr_data

# Process only IR through preprocessor
_, x_ir, _ = self.preprocessor(None, ir_data, None)
else:
# Process all data through preprocessor
x_nmr, x_ir, x_c_nmr = self.preprocessor(nmr_data, ir_data, c_nmr_data)

if self.verbose:
print(f"\nBackbone outputs:")
print(f"NMR embedding: {emb_nmr.shape}")
print(f"IR embedding: {emb_ir.shape}")
print(f"C-NMR embedding: {emb_c_nmr.shape}")
# Process through backbones/MLPs
if self.use_mlp_for_nmr:
# Process NMRs through MLPs
emb_nmr = self.h_nmr_mlp(h_nmr_data) # [B, embed_dim]
emb_nmr = emb_nmr.unsqueeze(1) # [B, 1, embed_dim]

emb_c_nmr = self.c_nmr_mlp(c_nmr_data) # [B, embed_dim]
emb_c_nmr = emb_c_nmr.unsqueeze(1) # [B, 1, embed_dim]

if self.use_concat:
# All sequences should have same length after backbone processing
assert emb_nmr.size(1) == emb_ir.size(1) == emb_c_nmr.size(1), "Sequence lengths must match"
# Process IR through backbone
emb_ir = self.ir_backbone(x_ir, keep_sequence=True) # [B, seq_len, embed_dim]

# Concatenate along embedding dimension
result = torch.cat([emb_nmr, emb_ir, emb_c_nmr], dim=-1) # [B, seq_len, embed_dim]
# Concatenate along sequence dimension
result = torch.cat([emb_ir, emb_nmr, emb_c_nmr], dim=1) # [B, seq_len+2, embed_dim]

if self.verbose:
print(f"\nFinal concatenated output: {result.shape}")
print(f"\nFinal output (MLP mode): {result.shape}")
return result
else:
# Apply higher-order cross attention
fused = self.cross_attention(emb_nmr, emb_ir, emb_c_nmr)

# Apply final normalization
fused = self.final_norm(fused)
else:
# Original processing logic
emb_nmr = self.nmr_backbone(x_nmr, keep_sequence=True)
emb_ir = self.ir_backbone(x_ir, keep_sequence=True)
emb_c_nmr = self.c_nmr_backbone(x_c_nmr, keep_sequence=True)

if self.verbose:
print(f"\nFinal fused output: {fused.shape}")
return fused
if self.use_concat:
result = torch.cat([emb_nmr, emb_ir, emb_c_nmr], dim=-1)
if self.verbose:
print(f"\nFinal concatenated output: {result.shape}")
return result
else:
fused = self.cross_attention(emb_nmr, emb_ir, emb_c_nmr)
fused = self.final_norm(fused)
if self.verbose:
print(f"\nFinal fused output: {fused.shape}")
return fused
6 changes: 4 additions & 2 deletions train_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,8 @@ def load_config(config_path=None):
'num_layers': 6,
'dropout': 0.1,
'resample_size': 1000,
'use_concat': True
'use_concat': True,
'use_mlp_for_nmr': True
},
'training': {
'batch_size': 32,
Expand Down Expand Up @@ -545,7 +546,8 @@ def main():
resample_size=resample_size,
domain_ranges=domain_ranges,
verbose=False,
use_concat=config['model']['use_concat']
use_concat=config['model']['use_concat'],
use_mlp_for_nmr=config['model'].get('use_mlp_for_nmr', True)
).to(device)
print("[Main] Model initialized successfully")

Expand Down