118 changes: 65 additions & 53 deletions piano/model/mil_factory/abmil/abmil.py
@@ -1,25 +1,38 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from piano.model.mil_factory.layers.layers import create_mlp
from piano.model.mil_factory.layers.layers import GlobalAttention, GlobalGatedAttention
from piano.utils.wsi_finetune_tools import NLLSurvLoss


class ABMIL(nn.Module):
def __init__(self, dim_in, dim_hidden=None, dropout=0.25, num_classes=1000, survival=False):
def __init__(self,
in_dim: int = 1024,
embed_dim: int = 512,
num_fc_layers: int = 1,
dropout: float = 0.25,
attn_dim: int = 384,
num_classes: int = 2,
survival: bool = False):
super().__init__()
if dim_hidden is None:
self.dim_hidden = dim_in // 2
else:
self.dim_hidden = dim_hidden
self.survival = survival

self.attn_module = nn.Sequential(
nn.Linear(dim_in, self.dim_hidden),
nn.Tanh(),
nn.Dropout(dropout),
nn.Linear(self.dim_hidden, 1)
self.patch_embed = create_mlp(
in_dim=in_dim,
hid_dims=[embed_dim] * (num_fc_layers - 1),
dropout=dropout,
out_dim=embed_dim,
end_with_fc=False
)
self.fc = nn.Linear(dim_in, num_classes)

self.global_attn = GlobalAttention(
L=embed_dim,
D=attn_dim,
dropout=dropout,
num_classes=1
)

self.classifier = nn.Linear(embed_dim, num_classes)

# Automatically select loss function based on survival or classification
if self.survival:
@@ -29,13 +42,12 @@ def __init__(self, dim_in, dim_hidden=None, dropout=0.25, num_classes=1000, surv

def forward(self, input_dict, return_loss=True):
x = input_dict['features'] # input_dict contains features, coords, and labels (optional)
attn = self.attn_module(x) # [B, N, 1]
A = torch.transpose(attn, -1, -2) # [B, 1, N]
A = torch.softmax(A, dim=-1) # [B, 1, N]
output = torch.matmul(A, x).squeeze(1) # [B, C]
logits = self.fc(output)

# Initialize output dictionary
h = self.patch_embed(x) # Embed patches: [B, N, in_dim] -> [B, N, embed_dim]
attn = self.global_attn(h) # Per-patch attention scores over the embedded features: [B, N, 1]
A = torch.transpose(attn, -2, -1) # [B, N, 1] -> [B, 1, N]
A = F.softmax(A, dim=-1) # Normalize attention weights across the N patches
h = torch.bmm(A, h).squeeze(dim=1) # Attention-weighted pooling: [B, 1, N] x [B, N, embed_dim] -> [B, embed_dim]
logits = self.classifier(h) # Classify the aggregated slide-level embedding
output_dict = {
'logits': logits,
'raw_attn': attn,
@@ -72,30 +84,32 @@ def forward(self, input_dict, return_loss=True):


class GatedABMIL(nn.Module):
def __init__(self, dim_in, dim_hidden=None, dropout=0.25, num_classes=1000, survival=False):
def __init__(self,
in_dim: int = 1024,
embed_dim: int = 512,
num_fc_layers: int = 1,
dropout: float = 0.25,
attn_dim: int = 384,
num_classes: int = 2,
survival: bool = False):
super().__init__()
if dim_hidden is None:
self.dim_hidden = dim_in // 2
else:
self.dim_hidden = dim_hidden

self.survival = survival

self.attn_1 = nn.Sequential(
nn.Linear(dim_in, self.dim_hidden),
nn.Tanh(),
nn.Dropout(dropout)
self.patch_embed = create_mlp(
in_dim=in_dim,
hid_dims=[embed_dim] * (num_fc_layers - 1),
dropout=dropout,
out_dim=embed_dim,
end_with_fc=False
)

self.attn_2 = nn.Sequential(
nn.Linear(dim_in, self.dim_hidden),
nn.Sigmoid(),
nn.Dropout(dropout)
self.global_attn = GlobalGatedAttention(
L=embed_dim,
D=attn_dim,
dropout=dropout,
num_classes=1
)

self.attn_3 = nn.Linear(self.dim_hidden, 1)

self.fc = nn.Linear(dim_in, num_classes)

self.classifier = nn.Linear(embed_dim, num_classes)

# Automatically select loss function based on survival or classification
if self.survival:
@@ -104,17 +118,13 @@ def __init__(self, dim_in, dim_hidden=None, dropout=0.25, num_classes=1000, surv
self.loss_fn = nn.CrossEntropyLoss()

def forward(self, input_dict, return_loss=True):
x = input_dict['features']
attn_1 = self.attn_1(x)
attn_2 = self.attn_2(x)
attn = attn_1.mul(attn_2)
attn = self.attn_3(attn)
A = torch.transpose(attn, -1, -2)
A = torch.softmax(A, dim=-1)
output = torch.matmul(A, x).squeeze(1)
logits = self.fc(output)

# Initialize output dictionary
x = input_dict['features'] # input_dict contains features, coords, and labels (optional)
h = self.patch_embed(x) # Embed patches: [B, N, in_dim] -> [B, N, embed_dim]
attn = self.global_attn(h) # Per-patch gated attention scores over the embedded features: [B, N, 1]
A = torch.transpose(attn, -2, -1) # [B, N, 1] -> [B, 1, N]
A = F.softmax(A, dim=-1) # Normalize attention weights across the N patches
h = torch.bmm(A, h).squeeze(dim=1) # Attention-weighted pooling: [B, 1, N] x [B, N, embed_dim] -> [B, embed_dim]
logits = self.classifier(h) # Classify the aggregated slide-level embedding
output_dict = {
'logits': logits,
'raw_attn': attn,
@@ -131,7 +141,7 @@ def forward(self, input_dict, return_loss=True):
'hazards': hazards,
'S': S
})

# Loss calculation
if return_loss and 'labels' in input_dict:
if self.survival and 'events' in input_dict:
@@ -145,6 +155,8 @@ def forward(self, input_dict, return_loss=True):
loss = self.loss_fn(logits, input_dict['labels'])
else:
loss = None

output_dict['loss'] = loss
return output_dict
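For orientation (not part of the diff), a minimal usage sketch of the refactored head under the default constructor arguments: the import path is inferred from the file location above, the bag size, feature dimension, and label are made up, and the attention is assumed to be computed over the embedded patches as in the forward methods shown here.

import torch
from piano.model.mil_factory.abmil.abmil import ABMIL  # path inferred from this diff

# One slide as a bag of 500 pre-extracted patch features of dimension 1024 (illustrative numbers)
input_dict = {
    'features': torch.randn(1, 500, 1024),
    'labels': torch.tensor([1]),
}

model = ABMIL(in_dim=1024, embed_dim=512, num_classes=2)
out = model(input_dict)
print(out['logits'].shape)    # torch.Size([1, 2])
print(out['raw_attn'].shape)  # torch.Size([1, 500, 1]), one attention score per patch
print(out['loss'])            # cross-entropy loss against the provided label

GatedABMIL exposes the same input_dict interface, swapping GlobalAttention for GlobalGatedAttention.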


95 changes: 95 additions & 0 deletions piano/model/mil_factory/layers/layers.py
@@ -0,0 +1,95 @@
import torch.nn as nn

def create_mlp(
in_dim=768,
hid_dims=[512, 512],
out_dim=512,
act=nn.ReLU(),
dropout=0.,
end_with_fc=True,
end_with_dropout=False,
bias=True
):

# Stack (Linear -> act -> Dropout) blocks for each hidden dim, then project to out_dim
layers = []
for hid_dim in hid_dims:
    layers.append(nn.Linear(in_dim, hid_dim, bias=bias))
    layers.append(act)
    layers.append(nn.Dropout(dropout))
    in_dim = hid_dim
layers.append(nn.Linear(in_dim, out_dim))
if not end_with_fc:
    # End with the activation instead of a bare fully-connected layer
    layers.append(act)
if end_with_dropout:
    layers.append(nn.Dropout(dropout))
return nn.Sequential(*layers)
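As a rough illustration (not part of the diff) of what create_mlp assembles: with end_with_fc=False, as the ABMIL heads pass, the stack ends with the activation rather than a bare linear layer. The argument values below mirror the ABMIL defaults with num_fc_layers=2; the input shape is illustrative.

import torch
from piano.model.mil_factory.layers.layers import create_mlp  # path inferred from this diff

# In ABMIL, num_fc_layers=2 means hid_dims=[embed_dim] * (2 - 1) = [512]
mlp = create_mlp(in_dim=1024, hid_dims=[512], out_dim=512, dropout=0.25, end_with_fc=False)
print(mlp)
# Sequential(
#   (0): Linear(in_features=1024, out_features=512, bias=True)
#   (1): ReLU()
#   (2): Dropout(p=0.25, inplace=False)
#   (3): Linear(in_features=512, out_features=512, bias=True)
#   (4): ReLU()
# )
x = torch.randn(1, 500, 1024)  # a bag of 500 patch features
print(mlp(x).shape)            # torch.Size([1, 500, 512])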


#
# Attention networks
#
class GlobalAttention(nn.Module):
"""
Attention Network without Gating (2 fc layers)
args:
L: input feature dimension
D: hidden layer dimension
dropout: dropout probability
num_classes: number of classes
"""

def __init__(self, L=1024, D=256, dropout=0., num_classes=1):
super().__init__()
self.module = [
nn.Linear(L, D),
nn.Tanh(),
nn.Dropout(dropout),
nn.Linear(D, num_classes)]

self.module = nn.Sequential(*self.module)

def forward(self, x):
return self.module(x) # N x num_classes
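To connect this module to the pooling in the ABMIL forward pass above: it emits one unnormalized score per patch, and the MIL head softmaxes those scores across the patches of a bag to obtain the weights of a convex combination. A short sketch with illustrative dimensions:

import torch
import torch.nn.functional as F

attn_net = GlobalAttention(L=512, D=384, num_classes=1)
h = torch.randn(1, 500, 512)                            # 500 embedded patches from one slide
scores = attn_net(h)                                    # [1, 500, 1], unnormalized
weights = F.softmax(scores.transpose(-2, -1), dim=-1)   # [1, 1, 500], sums to 1 over patches
slide_embedding = torch.bmm(weights, h).squeeze(1)      # [1, 512], attention-weighted average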


class GlobalGatedAttention(nn.Module):
"""
Attention Network with Sigmoid Gating (3 fc layers)
args:
L: input feature dimension
D: hidden layer dimension
dropout: dropout probability
num_classes: number of classes
"""

def __init__(self, L=1024, D=256, dropout=0., num_classes=1):
super().__init__()

self.attention_a = [
nn.Linear(L, D),
nn.Tanh(),
nn.Dropout(dropout)
]

self.attention_b = [
nn.Linear(L, D),
nn.Sigmoid(),
nn.Dropout(dropout)
]

self.attention_a = nn.Sequential(*self.attention_a)
self.attention_b = nn.Sequential(*self.attention_b)
self.attention_c = nn.Linear(D, num_classes)

def forward(self, x):
a = self.attention_a(x)
b = self.attention_b(x)
A = a.mul(b)
A = self.attention_c(A) # N x num_classes
return A
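For reference (not part of the diff): the gated variant follows the gated attention of Ilse et al.'s attention-based MIL, where a tanh branch and a sigmoid gate are multiplied element-wise before the final linear layer produces one score per patch. A quick shape check with illustrative sizes:

import torch

gated_attn = GlobalGatedAttention(L=512, D=384, dropout=0.25, num_classes=1)
h = torch.randn(2, 300, 512)   # 2 bags of 300 embedded patches each
scores = gated_attn(h)
print(scores.shape)            # torch.Size([2, 300, 1]), one gated score per patch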