diff --git a/piano/model/mil_factory/abmil/abmil.py b/piano/model/mil_factory/abmil/abmil.py
index 9affcb9..1106422 100644
--- a/piano/model/mil_factory/abmil/abmil.py
+++ b/piano/model/mil_factory/abmil/abmil.py
@@ -1,25 +1,38 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from piano.model.mil_factory.layers.layers import create_mlp
+from piano.model.mil_factory.layers.layers import GlobalAttention, GlobalGatedAttention
 from piano.utils.wsi_finetune_tools import NLLSurvLoss
 
 
 class ABMIL(nn.Module):
-    def __init__(self, dim_in, dim_hidden=None, dropout=0.25, num_classes=1000, survival=False):
+    def __init__(self,
+                 in_dim: int = 1024,
+                 embed_dim: int = 512,
+                 num_fc_layers: int = 1,
+                 dropout: float = 0.25,
+                 attn_dim: int = 384,
+                 num_classes: int = 2,
+                 survival: bool = False):
         super().__init__()
-        if dim_hidden is None:
-            self.dim_hidden = dim_in // 2
-        else:
-            self.dim_hidden = dim_hidden
-        self.survival = survival
-
-        self.attn_module = nn.Sequential(
-            nn.Linear(dim_in, self.dim_hidden),
-            nn.Tanh(),
-            nn.Dropout(dropout),
-            nn.Linear(self.dim_hidden, 1)
+        self.survival = survival
+
+        self.patch_embed = create_mlp(
+            in_dim=in_dim,
+            hid_dims=[embed_dim] * (num_fc_layers - 1),
+            dropout=dropout,
+            out_dim=embed_dim,
+            end_with_fc=False
         )
-        self.fc = nn.Linear(dim_in, num_classes)
+
+        self.global_attn = GlobalAttention(
+            L=embed_dim,
+            D=attn_dim,
+            dropout=dropout,
+            num_classes=1
+        )
+
+        self.classifier = nn.Linear(embed_dim, num_classes)
 
         # Automatically select loss function based on survival or classification
         if self.survival:
@@ -29,13 +42,12 @@ def __init__(self, dim_in, dim_hidden=None, dropout=0.25, num_classes=1000, surv
 
     def forward(self, input_dict, return_loss=True):
         x = input_dict['features'] # input_dict contain features, coords, and labels (optional)
-        attn = self.attn_module(x) # [B, N, 1]
-        A = torch.transpose(attn, -1, -2) # [B, 1, N]
-        A = torch.softmax(A, dim=-1) # [B, 1, N]
-        output = torch.matmul(A, x).squeeze(1) # [B, C]
-        logits = self.fc(output)
-
-        # Initialize output dictionary
+        h = self.patch_embed(x) # patch embedding MLP: [B, N, in_dim] -> [B, N, embed_dim]
+        attn = self.global_attn(h) # per-patch attention scores: [B, N, 1]
+        A = torch.transpose(attn, -2, -1) # [B, 1, N]
+        A = F.softmax(A, dim=-1) # normalize attention over patches
+        h = torch.bmm(A, h).squeeze(dim=1) # attention-weighted bag embedding: [B, embed_dim]
+        logits = self.classifier(h) # classify the aggregated bag embedding
         output_dict = {
             'logits': logits,
             'raw_attn': attn,
@@ -72,30 +84,32 @@ def forward(self, input_dict, return_loss=True):
 
 
 class GatedABMIL(nn.Module):
-    def __init__(self, dim_in, dim_hidden=None, dropout=0.25, num_classes=1000, survival=False):
+    def __init__(self,
+                 in_dim: int = 1024,
+                 embed_dim: int = 512,
+                 num_fc_layers: int = 1,
+                 dropout: float = 0.25,
+                 attn_dim: int = 384,
+                 num_classes: int = 2,
+                 survival: bool = False):
         super().__init__()
-        if dim_hidden is None:
-            self.dim_hidden = dim_in // 2
-        else:
-            self.dim_hidden = dim_hidden
-
-        self.survival = survival
-
-        self.attn_1 = nn.Sequential(
-            nn.Linear(dim_in, self.dim_hidden),
-            nn.Tanh(),
-            nn.Dropout(dropout)
+        self.survival = survival
+
+        self.patch_embed = create_mlp(
+            in_dim=in_dim,
+            hid_dims=[embed_dim] * (num_fc_layers - 1),
+            dropout=dropout,
+            out_dim=embed_dim,
+            end_with_fc=False
         )
-        self.attn_2 = nn.Sequential(
-            nn.Linear(dim_in, self.dim_hidden),
-            nn.Sigmoid(),
-            nn.Dropout(dropout)
+        self.global_attn = GlobalGatedAttention(
+            L=embed_dim,
+            D=attn_dim,
+            dropout=dropout,
+            num_classes=1
         )
-
-        self.attn_3 = nn.Linear(self.dim_hidden, 1)
-
-        self.fc = nn.Linear(dim_in, num_classes)
+
+        self.classifier = nn.Linear(embed_dim, num_classes)
 
         # Automatically select loss function based on survival or classification
         if self.survival:
@@ -104,17 +118,13 @@ def __init__(self, dim_in, dim_hidden=None, dropout=0.25, num_classes=1000, surv
             self.loss_fn = nn.CrossEntropyLoss()
 
     def forward(self, input_dict, return_loss=True):
-        x = input_dict['features']
-        attn_1 = self.attn_1(x)
-        attn_2 = self.attn_2(x)
-        attn = attn_1.mul(attn_2)
-        attn = self.attn_3(attn)
-        A = torch.transpose(attn, -1, -2)
-        A = torch.softmax(A, dim=-1)
-        output = torch.matmul(A, x).squeeze(1)
-        logits = self.fc(output)
-
-        # Initialize output dictionary
+        x = input_dict['features'] # input_dict contain features, coords, and labels (optional)
+        h = self.patch_embed(x) # patch embedding MLP: [B, N, in_dim] -> [B, N, embed_dim]
+        attn = self.global_attn(h) # per-patch gated attention scores: [B, N, 1]
+        A = torch.transpose(attn, -2, -1) # [B, 1, N]
+        A = F.softmax(A, dim=-1) # normalize attention over patches
+        h = torch.bmm(A, h).squeeze(dim=1) # attention-weighted bag embedding: [B, embed_dim]
+        logits = self.classifier(h) # classify the aggregated bag embedding
         output_dict = {
             'logits': logits,
             'raw_attn': attn,
@@ -131,7 +141,7 @@ def forward(self, input_dict, return_loss=True):
             'hazards': hazards,
             'S': S
         })
-
+
         # Loss calculation
         if return_loss and 'labels' in input_dict:
             if self.survival and 'events' in input_dict:
@@ -145,6 +155,8 @@ def forward(self, input_dict, return_loss=True):
                 loss = self.loss_fn(logits, input_dict['labels'])
             else:
                 loss = None
-
+
         output_dict['loss'] = loss
         return output_dict
+
+
diff --git a/piano/model/mil_factory/layers/layers.py b/piano/model/mil_factory/layers/layers.py
new file mode 100644
index 0000000..393b34a
--- /dev/null
+++ b/piano/model/mil_factory/layers/layers.py
@@ -0,0 +1,95 @@
+import torch.nn as nn
+
+
+def create_mlp(
+        in_dim=768,
+        hid_dims=[512, 512],
+        out_dim=512,
+        act=nn.ReLU(),
+        dropout=0.,
+        end_with_fc=True,
+        end_with_dropout=False,
+        bias=True
+    ):
+    """Stack Linear -> activation -> dropout blocks into a simple MLP."""
+    layers = []
+    for hid_dim in hid_dims:
+        layers.append(nn.Linear(in_dim, hid_dim, bias=bias))
+        layers.append(act)
+        layers.append(nn.Dropout(dropout))
+        in_dim = hid_dim
+    layers.append(nn.Linear(in_dim, out_dim))
+    if not end_with_fc:
+        layers.append(act)
+    if end_with_dropout:
+        layers.append(nn.Dropout(dropout))
+    return nn.Sequential(*layers)
+
+
+#
+# Attention networks
+#
+class GlobalAttention(nn.Module):
+    """
+    Attention Network without Gating (2 fc layers)
+    args:
+        L: input feature dimension
+        D: hidden layer dimension
+        dropout: dropout
+        num_classes: number of classes
+    """
+
+    def __init__(self, L=1024, D=256, dropout=0., num_classes=1):
+        super().__init__()
+        self.module = [
+            nn.Linear(L, D),
+            nn.Tanh(),
+            nn.Dropout(dropout),
+            nn.Linear(D, num_classes)]
+
+        self.module = nn.Sequential(*self.module)
+
+    def forward(self, x):
+        return self.module(x) # N x num_classes
+
+
+class GlobalGatedAttention(nn.Module):
+    """
+    Attention Network with Sigmoid Gating (3 fc layers)
+    args:
+        L: input feature dimension
+        D: hidden layer dimension
+        dropout: dropout
+        num_classes: number of classes
+    """
+
+    def __init__(self, L=1024, D=256, dropout=0., num_classes=1):
+        super().__init__()
+
+        self.attention_a = [
+            nn.Linear(L, D),
+            nn.Tanh(),
+            nn.Dropout(dropout)
+        ]
+
+        self.attention_b = [
+            nn.Linear(L, D),
+            nn.Sigmoid(),
+            nn.Dropout(dropout)
+        ]
+
+        self.attention_a = nn.Sequential(*self.attention_a)
+        self.attention_b = nn.Sequential(*self.attention_b)
+        self.attention_c = nn.Linear(D, num_classes)
+
+    def forward(self, x):
+        a = self.attention_a(x)
+        b = self.attention_b(x)
+        A = a.mul(b)
+        A = self.attention_c(A) # N x num_classes
+        return A
+
\ No newline at end of file
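
Usage sketch for the refactored ABMIL above, assuming a single-slide bag of randomly generated patch features; the tensor sizes, the example label, and the printed shapes are illustrative assumptions, while the constructor arguments and the input_dict forward contract come from the diff:

    import torch
    from piano.model.mil_factory.abmil.abmil import ABMIL

    # Hypothetical bag: one slide with 500 patches of 1024-d features (shapes are assumptions).
    features = torch.randn(1, 500, 1024)
    labels = torch.tensor([1])

    model = ABMIL(in_dim=1024, embed_dim=512, num_fc_layers=1,
                  attn_dim=384, num_classes=2, survival=False)
    out = model({'features': features, 'labels': labels}, return_loss=True)

    print(out['logits'].shape)    # torch.Size([1, 2])
    print(out['raw_attn'].shape)  # torch.Size([1, 500, 1]), per-patch attention scores
    print(out['loss'])            # CrossEntropyLoss value, since survival=False

GatedABMIL is called the same way; it only swaps GlobalAttention for GlobalGatedAttention when scoring patches.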