Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions configs/local_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ training:
validation_frequency: 100
logging_frequency: 10
save_frequency: 5_000
# TODO: allow multi-process data loading (num_workers > 0) with unimol?
num_workers: 0
ce_weight: 1.0
cls_enc_weight: 1.0
cls_dec_weight: 0.05

# Scheduler
scheduler:
Expand Down
1 change: 1 addition & 0 deletions configs/local_config_nocat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ training:
validation_frequency: 100
logging_frequency: 10
save_frequency: 5_000
num_workers: 8

# Scheduler
scheduler:
Expand Down
3 changes: 2 additions & 1 deletion configs/test_config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Model Architecture
model:
use_concat: True
max_seq_length: 1
max_seq_length: 4
embed_dim: 3
num_heads: 1
num_layers: 1
Expand All @@ -18,6 +18,7 @@ training:
validation_frequency: 50
logging_frequency: 10
save_frequency: 1000
num_workers: 8

# Scheduler
scheduler:
Expand Down
22 changes: 18 additions & 4 deletions models/multimodal_to_smiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ def __init__(
resample_size: int = 1000,
use_concat: bool = True,
verbose: bool = False,
domain_ranges: list | None = None
domain_ranges: list | None = None,
cls_dim: int = 512
):
super().__init__()

self.use_concat = use_concat
self.verbose = verbose
self.cls_dim = cls_dim

# Spectral encoder with verbose off
memory_dim = 2046 if use_concat else embed_dim
Expand Down Expand Up @@ -52,7 +54,15 @@ def __init__(
verbose=verbose
)

# extra projections for cls
self.to_cls_enc = th.nn.Linear(memory_dim, cls_dim)
self.to_cls_dec = th.nn.Linear(embed_dim, cls_dim)

def forward(self, nmr_data: tuple | th.Tensor | None, ir_data: tuple | th.Tensor | None, c_nmr_data: tuple | th.Tensor | None, target_seq: Any | None = None, target_mask: th.Tensor | None = None):
""" Returns the cls (dense encoding for molecule) and logits to sample smiles from
* (cls_enc, cls_dec) = 2x (B, D)
* logits: (B, L, D)
"""
if self.verbose:
print("\n=== Starting Forward Pass ===")
print("\nSpectroscopic data shapes inside forward:")
Expand All @@ -77,11 +87,15 @@ def forward(self, nmr_data: tuple | th.Tensor | None, ir_data: tuple | th.Tensor

if self.verbose:
print("\n=== Starting Decoding ===")

# Decode to SMILES
logits = self.decoder(target_seq, memory, target_mask)
cls, logits = self.decoder(target_seq, memory, target_mask)

if self.verbose:
print("\n=== Forward Pass Complete ===")

# [CLS] by mean pooling: (B, D)
cls_enc = self.to_cls_enc(memory).mean(dim=-2)
cls_dec = self.to_cls_dec(cls)

return logits
return (cls_enc, cls_dec), logits
16 changes: 11 additions & 5 deletions models/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def __init__(
# original:
# self.memory_proj = nn.Linear(memory_dim, memory_dim)
self.memory_proj = nn.Linear(memory_dim, embed_dim) if memory_dim != embed_dim else nn.Identity()

# learned absolute positional embedding (will also replace the memory positional embedding)
self.pos_token = th.nn.Parameter(1e-3*th.randn(1, max_seq_length + 256, embed_dim))

# Create decoder layers with updated memory dimension
self.layers = nn.ModuleList([
# DecoderLayer(
Expand Down Expand Up @@ -284,14 +286,17 @@ def forward(self, tgt: th.Tensor, memory: th.Tensor, tgt_mask: th.Tensor | None
print(f"memory.shape: {memory.shape}")
memory = memory.unsqueeze(1)

# Expand memory batch dimension if needed
# if memory.size(0) == 1 and x.size(0) > 1:
# memory = memory.expand(x.size(0), -1, -1)
# Add cls token:
cls_token = th.zeros_like(memory[:, -1:, :])
memory = th.cat([memory, cls_token], dim=1)

# mix `memory` and `x` as prompt + answer, and add causal mask
M = memory.shape[-2]
x = th.cat([memory, x], dim=1)

# add absolute posemb
x = x + self.pos_token[:, :T+M].repeat(B, 1, 1)

mask = th.ones(B, T+M, T+M, device=x.device).tril().bool()
# unlike causal decoding, the memory prefix can attend to all of itself
mask[:, :M, :M] = True
Expand All @@ -306,4 +311,5 @@ def forward(self, tgt: th.Tensor, memory: th.Tensor, tgt_mask: th.Tensor | None

# remove memory prompt from output tokens
out = self.out(x[:, M:])
return out
cls = x[:, M-1:M]
return cls, out
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ scikit-learn
rdkit
tqdm
seaborn
unimol_tools
Loading