2 changes: 1 addition & 1 deletion src/models/commands.py
@@ -159,7 +159,7 @@ def train_neural_baseline_command(*args,**kwargs):
def train_autoencoder_command(*args, **kwargs):
train_autoencoder(*args,**kwargs)

@click.command(cls=NeuralCommand,name="train-sand")
@click.command(cls=CNNTransformer,name="train-sand")
def train_sand_command(*args,**kwargs):
train_sand(*args,**kwargs)
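# For context on the cls= change above: click's cls= hook swaps in a custom
# Command subclass, which is how these train-* commands share their model
# options. A minimal sketch of that pattern, assuming the real
# NeuralCommand/CNNTransformer classes (defined elsewhere in commands.py)
# follow the usual click convention; hypothetical names and options only:
import click

class SharedOptionsCommand(click.Command):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Every command created with cls=SharedOptionsCommand picks these up.
        self.params.append(click.Option(["--learning-rate"], type=float, default=1e-3))
        self.params.append(click.Option(["--n-timesteps"], type=int, default=1440))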

237 changes: 161 additions & 76 deletions src/models/models.py
@@ -191,68 +191,21 @@ def from_inverse_of_encoder(encoder):
max_pool_stride_size=None
)

class CNNToTransformerEncoder(pl.LightningModule):
class CNNTransformerBase(pl.LightningModule):
def __init__(self, input_features, num_attention_heads, num_hidden_layers, n_timesteps, kernel_sizes=[5,3,1], out_channels = [256,128,64],
stride_sizes=[2,2,2], dropout_rate=0.3, num_labels=2, learning_rate=1e-3, warmup_steps=100,
max_positional_embeddings = 1440*5, factor=64, inital_batch_size=100, clf_dropout_rate=0.0,
train_mix_positives_back_in=False, train_mixin_batch_size=3, skip_cnn=False, wandb_id=None,
positional_encoding = False, model_head="classification", no_bootstrap=False,
max_positional_embeddings = 1440*5, inital_batch_size=100,
train_mix_positives_back_in=False, train_mixin_batch_size=3, skip_cnn=False, wandb_id=None,  # wandb_id is the run id, saved as a hyperparameter
positional_encoding = False, factor=64, model_head="classification", no_bootstrap=False, clf_dropout_rate=0.0,
**model_specific_kwargs) -> None:

super().__init__()
self.config = get_config_from_locals(locals())
super(CNNToTransformerEncoder, self).__init__()

self.learning_rate = learning_rate
self.warmup_steps = warmup_steps
self.input_dim = (n_timesteps,input_features)

self.input_embedding = CNNEncoder(input_features, n_timesteps=n_timesteps, kernel_sizes=kernel_sizes,
out_channels=out_channels, stride_sizes=stride_sizes)

if not skip_cnn:
self.d_model = out_channels[-1]
final_length = self.input_embedding.final_output_length
else:
self.d_model = input_features
final_length = n_timesteps

if self.input_embedding.final_output_length < 1:
raise ValueError("CNN final output dim is <1 ")

if positional_encoding:
self.positional_encoding = modules.PositionalEncoding(self.d_model, final_length)
else:
self.positional_encoding = None

self.blocks = nn.ModuleList([
modules.EncoderBlock(self.d_model, num_attention_heads, dropout_rate) for _ in range(num_hidden_layers)
])

# self.dense_interpolation = modules.DenseInterpolation(final_length, factor)
self.is_classification = model_head == "classification"
if self.is_classification:
self.head = modules.ClassificationModule(self.d_model, final_length, num_labels,
dropout_p=clf_dropout_rate)
metric_class = TorchMetricClassification

else:
self.head = modules.RegressionModule(self.d_model, final_length, num_labels)
metric_class = TorchMetricRegression

self.train_metrics = metric_class(bootstrap_cis=False, prefix="train/")
self.eval_metrics = metric_class(bootstrap_cis=not no_bootstrap, prefix="eval/")
self.test_metrics = metric_class(bootstrap_cis=not no_bootstrap, prefix="test/")

self.provided_train_dataloader = None

self.criterion = build_loss_fn(model_specific_kwargs, task_type=model_head)
if num_attention_heads > 0:
self.name = "CNNToTransformerEncoder"
else:
self.name = "CNN"

self.base_model_prefix = self.name

self.train_probs = []
self.train_labels = []
self.train_mix_positives_back_in = train_mix_positives_back_in
@@ -277,29 +230,6 @@ def __init__(self, input_features, num_attention_heads, num_hidden_layers, n_tim
self.save_hyperparameters()
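# save_hyperparameters() snapshots every __init__ argument (wandb_id included)
# into checkpoint["hyper_parameters"], so a resumed run can reattach to the
# same W&B run (per the wandb_id comment above).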


def forward(self, inputs_embeds,labels):
encoding = self.encode(inputs_embeds)
preds = self.head(encoding)
loss = self.criterion(preds,labels)
return loss, preds

def encode(self, inputs_embeds):
if not self.skip_cnn:
x = inputs_embeds.transpose(1, 2)
x = self.input_embedding(x)
x = x.transpose(1, 2)
else:
x = inputs_embeds

if self.positional_encoding:
x = self.positional_encoding(x)

for l in self.blocks:
x = l(x)

# x = self.dense_interpolation(x)
return x

def set_train_dataset(self,dataset):
self.train_dataset = dataset

@@ -480,7 +410,6 @@ def optimizer_step(
):
optimizer.step(closure=optimizer_closure)


def on_load_checkpoint(self, checkpoint: dict) -> None:
state_dict = checkpoint["state_dict"]
model_state_dict = self.state_dict()
@@ -500,6 +429,162 @@ def on_load_checkpoint(self, checkpoint: dict) -> None:
if is_changed:
checkpoint.pop("optimizer_states", None)
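# The folded middle of this hunk is the usual shape-compatibility filter.
# A hedged sketch of that idiom (hypothetical helper, not the PR's exact lines):
def _drop_mismatched_shapes(state_dict, model_state_dict):
    is_changed = False
    for k in list(state_dict):
        if k in model_state_dict and state_dict[k].shape != model_state_dict[k].shape:
            state_dict[k] = model_state_dict[k]  # keep the freshly initialized tensor
            is_changed = True
    return is_changed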

class CNNToTransformerEncoder(CNNTransformerBase):
def __init__(self, clf_dropout_rate=0.0, **kwargs):

super().__init__(**kwargs)
if kwargs.get('num_attention_heads') > 0:
self.name = "CNNToTransformerEncoder"
else:
self.name = "CNN"

self.base_model_prefix = self.name

self.input_embedding = CNNEncoder(kwargs.get('input_features'), n_timesteps=kwargs.get('n_timesteps'), kernel_sizes=kwargs.get('kernel_sizes'),
out_channels=kwargs.get('out_channels'), stride_sizes=kwargs.get('stride_sizes'))

if not kwargs.get('skip_cnn'):
self.d_model = kwargs.get('out_channels')[-1]
final_length = self.input_embedding.final_output_length
else:
self.d_model = kwargs.get('input_features')
final_length = kwargs.get('n_timesteps')

if self.input_embedding.final_output_length < 1:
raise ValueError("CNN final output length is < 1")

self.is_classification = kwargs.get('model_head', 'classification') == "classification"
if self.is_classification:
self.head = modules.ClassificationModule(self.d_model, final_length, kwargs.get('num_labels',2),
dropout_p=clf_dropout_rate)
metric_class = TorchMetricClassification

else:
self.head = modules.RegressionModule(self.d_model, final_length, kwargs.get('num_labels',2))
metric_class = TorchMetricRegression

self.train_metrics = metric_class(bootstrap_cis=False, prefix="train/")
self.eval_metrics = metric_class(bootstrap_cis=not kwargs.get('no_bootstrap',False), prefix="eval/")
self.test_metrics = metric_class(bootstrap_cis=not kwargs.get('no_bootstrap',False), prefix="test/")


if kwargs.get('positional_encoding'):
self.positional_encoding = modules.PositionalEncoding(self.d_model, final_length)
else:
self.positional_encoding = None

self.blocks = nn.ModuleList([
modules.EncoderBlock(self.d_model, kwargs.get('num_attention_heads'), kwargs.get('dropout_rate')) for _ in range(kwargs.get('num_hidden_layers'))
])

# self.dense_interpolation = modules.DenseInterpolation(final_length, factor)
self.clf = modules.ClassificationModule(self.d_model, final_length, kwargs.get('num_labels'),
dropout_p=clf_dropout_rate)
self.provided_train_dataloader = None
self.criterion = build_loss_fn(kwargs, task_type=kwargs.get('model_head', 'classification'))

self.save_hyperparameters()


def forward(self, inputs_embeds,labels):
encoding = self.encode(inputs_embeds)
preds = self.head(encoding)
loss = self.criterion(preds,labels)
return loss, preds

def encode(self, inputs_embeds):
if not self.skip_cnn:
x = inputs_embeds.transpose(1, 2)
x = self.input_embedding(x)
x = x.transpose(1, 2)
else:
x = inputs_embeds

if self.positional_encoding:
x = self.positional_encoding(x)

for l in self.blocks:
x = l(x)

# x = self.dense_interpolation(x)
return x
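# A hedged usage sketch for the refactored encoder (illustrative values, not
# part of the diff). Most constructor settings are read via kwargs.get(...)
# without defaults, so they must be passed explicitly:
import torch

model = CNNToTransformerEncoder(
    input_features=8, n_timesteps=1440,
    num_attention_heads=4, num_hidden_layers=2,
    kernel_sizes=[5, 3, 1], out_channels=[256, 128, 64], stride_sizes=[2, 2, 2],
    dropout_rate=0.3, num_labels=2, skip_cnn=False, positional_encoding=False,
)
inputs = torch.randn(16, 1440, 8)      # (batch, n_timesteps, input_features)
labels = torch.randint(0, 2, (16,))    # binary targets
loss, preds = model(inputs, labels)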



class EncoderLayerForSAnD(nn.Module):
def __init__(self, input_features, n_heads, n_layers, d_model=128, dropout_rate=0.2) -> None:
super(EncoderLayerForSAnD, self).__init__()
self.d_model = d_model

self.input_embedding = nn.Conv1d(input_features, d_model, 1)
self.positional_encoding = modules.PositionalEncoding(d_model, input_features)
self.blocks = nn.ModuleList([
modules.EncoderBlock(d_model, n_heads, dropout_rate) for _ in range(n_layers)
])

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x = x.transpose(1, 2)
x = self.input_embedding(x)
x = x.transpose(1, 2)

x = self.positional_encoding(x)

for l in self.blocks:
x = l(x)

return x
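# Shape trace for EncoderLayerForSAnD.forward, inferred from the calls above
# (the transpose at the top is commented out, so inputs arrive channels-first):
#   x: (batch, input_features, L)
#   input_embedding (1x1 Conv1d)  -> (batch, d_model, L)
#   transpose(1, 2)               -> (batch, L, d_model)
#   positional_encoding + blocks  -> shape preserved: (batch, L, d_model)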
class LightningSAnD(CNNTransformerBase):
"""
Simply Attend and Diagnose model

The Thirty-Second AAAI Conference on Artificial Intelligence (AAAI-18)

`Attend and Diagnose: Clinical Time Series Analysis Using Attention Models <https://arxiv.org/abs/1711.03905>`_
Huan Song, Deepta Rajan, Jayaraman J. Thiagarajan, Andreas Spanias

NOTE: dropout_rate default value changes from 0.2 to 0.5
"""
def __init__(
self,
**kwargs) -> None:  # NOTE: defaults declared in this signature are ignored for any argument that is read back with kwargs.get(argument)

super().__init__(**kwargs)
self.name = "SAnD"
self.base_model_prefix = self.name

self.encoder = EncoderLayerForSAnD(kwargs.get('n_timesteps'),  # if any of these arguments is not passed, kwargs.get returns None and execution breaks
kwargs.get('num_attention_heads'),
kwargs.get('num_hidden_layers'),
kwargs.get('d_model', 128),
kwargs.get('dropout_rate'))
self.dense_interpolation = modules.DenseInterpolation(kwargs.get('input_features'), kwargs.get('factor', 256))

self.is_classification = kwargs.get('model_head', 'classification') == "classification"
if self.is_classification:
self.head = modules.ClassificationModule(kwargs.get('d_model', 128), kwargs.get('factor', 256), kwargs.get('num_labels'))
metric_class = TorchMetricClassification
else:
self.head = modules.RegressionModule(kwargs.get('d_model', 128), kwargs.get('factor', 256), kwargs.get('num_labels'))
metric_class = TorchMetricRegression

loss_weights = torch.tensor([float(kwargs.get('neg_class_weight',1)),float(kwargs.get('pos_class_weight',1))])
self.criterion = nn.CrossEntropyLoss(weight=loss_weights)
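# In the weight tensor above, weight[i] rescales the loss for class index i:
# 0 = negative class, 1 = positive class.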
self.n_class = kwargs.get('num_labels')

self.train_metrics = metric_class(bootstrap_cis=False, prefix="train/")
self.eval_metrics = metric_class(bootstrap_cis=not kwargs.get('no_bootstrap',False), prefix="eval/")
self.test_metrics = metric_class(bootstrap_cis=not kwargs.get('no_bootstrap',False), prefix="test/")

self.save_hyperparameters()

def forward(self, inputs_embeds: torch.Tensor, labels: torch.Tensor):
x = self.encoder(inputs_embeds)
x = self.dense_interpolation(x)
x = self.head(x)
loss = self.criterion(x.view(-1,self.n_class),labels.view(-1))
return loss, x
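# A hedged usage sketch echoing the NOTE above: kwargs.get supplies its own
# fall-backs here (d_model 128, factor 256), so signature defaults elsewhere
# (e.g. factor=64 in CNNTransformerBase) never apply; passing everything
# explicitly is the safe route (illustrative values only):
model = LightningSAnD(
    input_features=76, n_timesteps=1440,
    num_attention_heads=4, num_hidden_layers=2,
    d_model=128, factor=256, num_labels=2, dropout_rate=0.5,
)
inputs = torch.randn(16, 1440, 76)     # the time axis is consumed as Conv1d channels (see EncoderLayerForSAnD)
labels = torch.randint(0, 2, (16,))
loss, logits = model(inputs, labels)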


class CNNToTransformerAutoEncoder(pl.LightningModule):
def __init__(self, input_features, num_attention_heads, num_hidden_layers,
n_timesteps, kernel_sizes=[5, 3, 1], out_channels=[256, 128, 64],