diff --git a/src/models/commands.py b/src/models/commands.py
index c5b337b..14806be 100644
--- a/src/models/commands.py
+++ b/src/models/commands.py
@@ -159,7 +159,7 @@ def train_neural_baseline_command(*args,**kwargs):
 def train_autoencoder_command(*args, **kwargs):
     train_autoencoder(*args,**kwargs)
 
-@click.command(cls=NeuralCommand,name="train-sand")
+@click.command(cls=CNNTransformer,name="train-sand")
 def train_sand_command(*args,**kwargs):
     train_sand(*args,**kwargs)
 
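Note on the hunk above: in click, the `cls=` argument of `@click.command` must name a `click.Command` subclass, so swapping `NeuralCommand` for `CNNTransformer` (a model name) is worth double-checking. A minimal sketch of the pattern, with an illustrative command class and option names that are not the repo's real ones:

```python
import click

class SharedTrainCommand(click.Command):
    """Illustrative stand-in for a command class such as NeuralCommand."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Options appended here show up on every command built with cls=SharedTrainCommand.
        self.params.append(click.Option(["--n-epochs"], type=int, default=10))
        self.params.append(click.Option(["--learning-rate"], type=float, default=1e-3))

@click.command(cls=SharedTrainCommand, name="train-sand")
def train_sand_command(**kwargs):
    # The real command forwards everything to train_sand(**kwargs).
    click.echo(f"train_sand would be called with {kwargs}")

if __name__ == "__main__":
    train_sand_command()
```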
diff --git a/src/models/models.py b/src/models/models.py
index c29a093..96a74a1 100644
--- a/src/models/models.py
+++ b/src/models/models.py
@@ -191,68 +191,21 @@ def from_inverse_of_encoder(encoder):
                       max_pool_stride_size=None
     )
 
-class CNNToTransformerEncoder(pl.LightningModule):
+class CNNTransformerBase(pl.LightningModule):
     def __init__(self, input_features, num_attention_heads, num_hidden_layers, n_timesteps,
                  kernel_sizes=[5,3,1], out_channels = [256,128,64], stride_sizes=[2,2,2], dropout_rate=0.3,
                  num_labels=2, learning_rate=1e-3, warmup_steps=100,
-                 max_positional_embeddings = 1440*5, factor=64, inital_batch_size=100, clf_dropout_rate=0.0,
-                 train_mix_positives_back_in=False, train_mixin_batch_size=3, skip_cnn=False, wandb_id=None,
-                 positional_encoding = False, model_head="classification", no_bootstrap=False,
+                 max_positional_embeddings = 1440*5, inital_batch_size=100,
+                 train_mix_positives_back_in=False, train_mixin_batch_size=3, skip_cnn=False, wandb_id=None,  # wandb_id is the run id, saved as a hyperparameter
+                 positional_encoding = False, factor=64, model_head="classification", no_bootstrap=False, clf_dropout_rate=0.0,
                  **model_specific_kwargs) -> None:
 
+        super().__init__()
         self.config = get_config_from_locals(locals())
-        super(CNNToTransformerEncoder, self).__init__()
         self.learning_rate = learning_rate
         self.warmup_steps = warmup_steps
         self.input_dim = (n_timesteps,input_features)
 
-        self.input_embedding = CNNEncoder(input_features, n_timesteps=n_timesteps, kernel_sizes=kernel_sizes,
-                                          out_channels=out_channels, stride_sizes=stride_sizes)
-
-        if not skip_cnn:
-            self.d_model = out_channels[-1]
-            final_length = self.input_embedding.final_output_length
-        else:
-            self.d_model = input_features
-            final_length = n_timesteps
-
-        if self.input_embedding.final_output_length < 1:
-            raise ValueError("CNN final output dim is <1 ")
-
-        if positional_encoding:
-            self.positional_encoding = modules.PositionalEncoding(self.d_model, final_length)
-        else:
-            self.positional_encoding = None
-
-        self.blocks = nn.ModuleList([
-            modules.EncoderBlock(self.d_model, num_attention_heads, dropout_rate) for _ in range(num_hidden_layers)
-        ])
-
-        # self.dense_interpolation = modules.DenseInterpolation(final_length, factor)
-        self.is_classification = model_head == "classification"
-        if self.is_classification:
-            self.head = modules.ClassificationModule(self.d_model, final_length, num_labels,
-                                                     dropout_p=clf_dropout_rate)
-            metric_class = TorchMetricClassification
-
-        else:
-            self.head = modules.RegressionModule(self.d_model, final_length, num_labels)
-            metric_class = TorchMetricRegression
-
-        self.train_metrics = metric_class(bootstrap_cis=False, prefix="train/")
-        self.eval_metrics = metric_class(bootstrap_cis=not no_bootstrap, prefix="eval/")
-        self.test_metrics = metric_class(bootstrap_cis=not no_bootstrap, prefix="test/")
-
-        self.provided_train_dataloader = None
-
-        self.criterion = build_loss_fn(model_specific_kwargs, task_type=model_head)
-        if num_attention_heads > 0:
-            self.name = "CNNToTransformerEncoder"
-        else:
-            self.name = "CNN"
-
-        self.base_model_prefix = self.name
-
         self.train_probs = []
         self.train_labels = []
         self.train_mix_positives_back_in = train_mix_positives_back_in
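The `final_output_length < 1` guard removed above (and retained in the subclass below) exists because every convolution shrinks the time axis. A rough sketch of that arithmetic, assuming `CNNEncoder` stacks plain `Conv1d` layers with no padding or pooling (the repo's implementation may differ), using the default kernel and stride values:

```python
# L_out = (L_in - kernel_size) // stride + 1 for a Conv1d with padding=0, dilation=1.
def conv1d_out_len(length, kernel_size, stride):
    return (length - kernel_size) // stride + 1

length = 1440                             # n_timesteps, illustrative value
for k, s in zip([5, 3, 1], [2, 2, 2]):    # default kernel_sizes / stride_sizes
    length = conv1d_out_len(length, k, s)
print(length)                             # analogue of final_output_length; must stay >= 1
```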
@@ -277,29 +230,6 @@ def __init__(self, input_features, num_attention_heads, num_hidden_layers, n_tim
 
         self.save_hyperparameters()
 
-    def forward(self, inputs_embeds,labels):
-        encoding = self.encode(inputs_embeds)
-        preds = self.head(encoding)
-        loss = self.criterion(preds,labels)
-        return loss, preds
-
-    def encode(self, inputs_embeds):
-        if not self.skip_cnn:
-            x = inputs_embeds.transpose(1, 2)
-            x = self.input_embedding(x)
-            x = x.transpose(1, 2)
-        else:
-            x = inputs_embeds
-
-        if self.positional_encoding:
-            x = self.positional_encoding(x)
-
-        for l in self.blocks:
-            x = l(x)
-
-        # x = self.dense_interpolation(x)
-        return x
-
     def set_train_dataset(self,dataset):
         self.train_dataset = dataset
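The `forward`/`encode` pair removed here reappears unchanged in the new subclass below; the two transposes are needed because `Conv1d` expects a channel-first layout while the attention blocks expect time-first. A standalone shape sketch with illustrative sizes and a single `Conv1d` standing in for `CNNEncoder`:

```python
import torch
import torch.nn as nn

batch, n_timesteps, input_features = 8, 1440, 40                 # illustrative sizes
inputs_embeds = torch.randn(batch, n_timesteps, input_features)

x = inputs_embeds.transpose(1, 2)                                # (8, 40, 1440): channel-first for Conv1d
x = nn.Conv1d(input_features, 64, kernel_size=5, stride=2)(x)    # (8, 64, 718)
x = x.transpose(1, 2)                                            # (8, 718, 64): time-first for the EncoderBlocks
print(x.shape)
```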
@@ -480,7 +410,6 @@ def optimizer_step(
     ):
         optimizer.step(closure=optimizer_closure)
 
-
     def on_load_checkpoint(self, checkpoint: dict) -> None:
         state_dict = checkpoint["state_dict"]
         model_state_dict = self.state_dict()
@@ -500,6 +429,162 @@ def on_load_checkpoint(self, checkpoint: dict) -> None:
         if is_changed:
             checkpoint.pop("optimizer_states", None)
 
+class CNNToTransformerEncoder(CNNTransformerBase):
+    def __init__(self, clf_dropout_rate=0.0, **kwargs):
+
+        super().__init__(**kwargs)
+        if kwargs.get('num_attention_heads') > 0:
+            self.name = "CNNToTransformerEncoder"
+        else:
+            self.name = "CNN"
+
+        self.base_model_prefix = self.name
+
+        self.input_embedding = CNNEncoder(kwargs.get('input_features'), n_timesteps=kwargs.get('n_timesteps'), kernel_sizes=kwargs.get('kernel_sizes'),
+                                          out_channels=kwargs.get('out_channels'), stride_sizes=kwargs.get('stride_sizes'))
+
+        if not kwargs.get('skip_cnn'):
+            self.d_model = kwargs.get('out_channels')[-1]
+            final_length = self.input_embedding.final_output_length
+        else:
+            self.d_model = kwargs.get('input_features')
+            final_length = kwargs.get('n_timesteps')
+
+        if self.input_embedding.final_output_length < 1:
+            raise ValueError("CNN final output dim is <1 ")
+
+        self.is_classification = kwargs.get('model_head', 'classification') == "classification"
+        if self.is_classification:
+            self.head = modules.ClassificationModule(self.d_model, final_length, kwargs.get('num_labels',2),
+                                                     dropout_p=clf_dropout_rate)
+            metric_class = TorchMetricClassification
+
+        else:
+            self.head = modules.RegressionModule(self.d_model, final_length, kwargs.get('num_labels',2))
+            metric_class = TorchMetricRegression
+
+        self.train_metrics = metric_class(bootstrap_cis=False, prefix="train/")
+        self.eval_metrics = metric_class(bootstrap_cis=not kwargs.get('no_bootstrap',False), prefix="eval/")
+        self.test_metrics = metric_class(bootstrap_cis=not kwargs.get('no_bootstrap',False), prefix="test/")
+
+
+        if kwargs.get('positional_encoding'):
+            self.positional_encoding = modules.PositionalEncoding(self.d_model, final_length)
+        else:
+            self.positional_encoding = None
+
+        self.blocks = nn.ModuleList([
+            modules.EncoderBlock(self.d_model, kwargs.get('num_attention_heads'), kwargs.get('dropout_rate')) for _ in range(kwargs.get('num_hidden_layers'))
+        ])
+
+        # self.dense_interpolation = modules.DenseInterpolation(final_length, factor)
+        self.clf = modules.ClassificationModule(self.d_model, final_length, kwargs.get('num_labels'),
+                                                dropout_p=clf_dropout_rate)
+        self.provided_train_dataloader = None
+        self.criterion = build_loss_fn(kwargs)
+
+        self.save_hyperparameters()
+
+
+    def forward(self, inputs_embeds,labels):
+        encoding = self.encode(inputs_embeds)
+        preds = self.head(encoding)
+        loss = self.criterion(preds,labels)
+        return loss, preds
+
+    def encode(self, inputs_embeds):
+        if not self.skip_cnn:
+            x = inputs_embeds.transpose(1, 2)
+            x = self.input_embedding(x)
+            x = x.transpose(1, 2)
+        else:
+            x = inputs_embeds
+
+        if self.positional_encoding:
+            x = self.positional_encoding(x)
+
+        for l in self.blocks:
+            x = l(x)
+
+        # x = self.dense_interpolation(x)
+        return x
+
+
+
+class EncoderLayerForSAnD(nn.Module):
+    def __init__(self, input_features, n_heads, n_layers, d_model=128, dropout_rate=0.2) -> None:
+        super(EncoderLayerForSAnD, self).__init__()
+        self.d_model = d_model
+
+        self.input_embedding = nn.Conv1d(input_features, d_model, 1)
+        self.positional_encoding = modules.PositionalEncoding(d_model, input_features)
+        self.blocks = nn.ModuleList([
+            modules.EncoderBlock(d_model, n_heads, dropout_rate) for _ in range(n_layers)
+        ])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x = x.transpose(1, 2)
+        x = self.input_embedding(x)
+        x = x.transpose(1, 2)
+
+        x = self.positional_encoding(x)
+
+        for l in self.blocks:
+            x = l(x)
+
+        return x
+class LightningSAnD(CNNTransformerBase):
+    """
+    Simply Attend and Diagnose model
+
+    The Thirty-Second AAAI Conference on Artificial Intelligence (AAAI-18)
+
+    `Attend and Diagnose: Clinical Time Series Analysis Using Attention Models `_
+    Huan Song, Deepta Rajan, Jayaraman J. Thiagarajan, Andreas Spanias
+
+    NOTE: dropout_rate default value changes from 0.2 to 0.5
+    """
+    def __init__(
+            self,
+            **kwargs) -> None:  # NOTE: a default value declared for an argument here is ignored when the argument is also read with kwargs.get(argument)
+
+        super().__init__(**kwargs)
+        self.name = "SAnD"
+        self.base_model_prefix = self.name
+
+        self.encoder = EncoderLayerForSAnD(kwargs.get('n_timesteps'),  # if any of these arguments is not passed, kwargs.get returns None and execution breaks
+                                           kwargs.get('num_attention_heads'),
+                                           kwargs.get('num_hidden_layers'),
+                                           kwargs.get('d_model', 128),
+                                           kwargs.get('dropout_rate'))
+        self.dense_interpolation = modules.DenseInterpolation(kwargs.get('input_features'), kwargs.get('factor', 256))
+
+        self.is_classification = kwargs.get('model_head', 'classification') == "classification"
+        if self.is_classification:
+            self.head = modules.ClassificationModule(kwargs.get('d_model', 128), kwargs.get('factor', 256), kwargs.get('num_labels'))
+            metric_class = TorchMetricClassification
+        else:
+            self.head = modules.RegressionModule(kwargs.get('d_model', 128), kwargs.get('factor', 256), kwargs.get('num_labels'))
+            metric_class = TorchMetricRegression
+
+        loss_weights = torch.tensor([float(kwargs.get('neg_class_weight',1)),float(kwargs.get('pos_class_weight',1))])
+        self.criterion = nn.CrossEntropyLoss(weight=loss_weights)
+        self.n_class = kwargs.get('num_labels')
+
+        self.train_metrics = metric_class(bootstrap_cis=False, prefix="train/")
+        self.eval_metrics = metric_class(bootstrap_cis=not kwargs.get('no_bootstrap',False), prefix="eval/")
+        self.test_metrics = metric_class(bootstrap_cis=not kwargs.get('no_bootstrap',False), prefix="test/")
+
+        self.save_hyperparameters()
+
+    def forward(self, inputs_embeds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        x = self.encoder(inputs_embeds)
+        x = self.dense_interpolation(x)
+        x = self.head(x)
+        loss = self.criterion(x.view(-1,self.n_class),labels.view(-1))
+        return loss, x
+
+
 class CNNToTransformerAutoEncoder(pl.LightningModule):
     def __init__(self, input_features, num_attention_heads, num_hidden_layers, n_timesteps,
                  kernel_sizes=[5, 3, 1], out_channels=[256, 128, 64],
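The NOTE comments in `LightningSAnD` above describe a real pitfall of this `**kwargs` pattern: a default declared in the base `__init__` signature never reaches a subclass that re-reads the same argument with `kwargs.get`. A minimal sketch of the behaviour:

```python
class Base:
    def __init__(self, dropout_rate=0.3, **kwargs):
        self.dropout_rate = dropout_rate      # the signature default applies here

class Child(Base):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # None unless the caller passes dropout_rate explicitly; the 0.3 default above is invisible here.
        self.block_dropout = kwargs.get("dropout_rate")

c = Child()
print(c.dropout_rate)   # 0.3
print(c.block_dropout)  # None
```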
diff --git a/src/models/train_model.py b/src/models/train_model.py
index 319a245..2cc961e 100644
--- a/src/models/train_model.py
+++ b/src/models/train_model.py
@@ -57,6 +57,7 @@ from src.models.neural_baselines import create_neural_model
 from src.models.models import CNNToTransformerEncoder
 from src.models.trainer import FluTrainer
+from src.models.models import LightningSAnD
 from src.SAnD.core.model import SAnD
 from src.utils import (get_logger, load_dotenv, render_network_plot,
                        set_gpus_automatically, visualize_model)
@@ -483,72 +484,160 @@ def train_sand(
                task_config=None, data_location=None, no_eval_during_training=False, auto_set_gpu=None,
-               **_):
+               use_huggingface=False,
+               backend="petastorm",
+               train_path=None,
+               eval_path=None,
+               test_path=None,
+               only_with_lab_results=False,
+               downsample_negative_frac=None,
+               reload_dataloaders = 0,
+               log_steps=50,
+               val_epochs=10,
+               model_config={},
+               dropout_rate=0.5,
+               train_mixin_batch_size=3,
+               train_mix_positives_back_in=False,
+               pl_seed=2494,
+               resume_model_from_ckpt=None,
+               **model_specific_kwargs):
 
-    if model_path:
-        raise NotImplementedError()
+    if auto_set_gpu:
+        set_gpus_automatically(auto_set_gpu)
+
+    if pl_seed:
+        pl.seed_everything(pl_seed)
 
     logger.info(f"Training SAnD")
+
     if task_config:
-        task_name = task_config["task_name"]
-        dataset_args = task_config["dataset_args"]
+        task_name = task_config.get("task_name")            # task_name = task_config["task_name"]
+        dataset_args = task_config.get("dataset_args",{})   # dataset_args = task_config["dataset_args"]
+        task_args = task_config.get("task_args",{})         # task_args = task_config["task_args"]
+    else:
+        task_name = None
+        task_args = None
 
     if not eval_frac is None:
         dataset_args["eval_frac"] = eval_frac
 
     dataset_args["return_dict"] = True
     dataset_args["data_location"] = data_location
 
-    task = get_task_with_name(task_name)(dataset_args=dataset_args,
-                                         activity_level=activity_level,
-                                         look_for_cached_datareader=look_for_cached_datareader,
-                                         datareader_ray_obj_ref=datareader_ray_obj_ref)
+
+    task = get_task_with_name(task_name)(**task_args,
+                                         downsample_negative_frac=downsample_negative_frac,
+                                         dataset_args=dataset_args,
+                                         activity_level=activity_level,
+                                         look_for_cached_datareader=look_for_cached_datareader,
+                                         only_with_lab_results = only_with_lab_results,
+                                         datareader_ray_obj_ref=datareader_ray_obj_ref,
+                                         backend=backend,
+                                         train_path=train_path,
+                                         eval_path=eval_path,
+                                         test_path=test_path)
 
     if sinu_position_encoding:
         dataset_args["add_absolute_embedding"] = True
 
+    if not use_huggingface:
+        pl_training_args = dict(
+            max_epochs=n_epochs,
+            check_val_every_n_epoch=val_epochs,
+            log_every_n_steps=log_steps
+        )
 
-    train_dataset = task.get_train_dataset()
-    infer_example = train_dataset[0]["inputs_embeds"]
-    n_timesteps, n_features = infer_example.shape
-
-    model = SAnD(input_features=n_timesteps,
-                 seq_len = n_features,
-                 n_heads = num_attention_heads,
-                 factor=256,
-                 n_layers = num_hidden_layers,
-                 n_class=2,
-                 pos_class_weight=pos_class_weight,
-                 neg_class_weight=neg_class_weight)
-
-
-    training_args = TrainingArguments(
-        output_dir='./results',                        # output directorz
-        num_train_epochs=n_epochs,                     # total # of training epochs
-        per_device_train_batch_size=train_batch_size,  # batch size per device during training
-        per_device_eval_batch_size=eval_batch_size,    # batch size for evaluation
-        warmup_steps=warmup_steps,                     # number of warmup steps for learning rate scheduler
-        weight_decay=weight_decay,
-        learning_rate=learning_rate,                   # strength of weight decay
-        logging_dir='./logs',
-        logging_steps=10,
-        do_eval=not no_eval_during_training,
-        dataloader_num_workers=16,
-        dataloader_pin_memory=True,
-        prediction_loss_only=False,
-        evaluation_strategy="epoch",
-        report_to=["wandb"]                            # directory for storing logs
-    )
+        """
+        train_dataset = task.get_train_dataset()  # better to use task.data_shape
+        infer_example = train_dataset[0]["inputs_embeds"]
+        n_timesteps, n_features = infer_example.shape
+        """
+        n_timesteps, n_features = task.data_shape
+        model_kwargs = dict(input_features=n_features,
+                            n_timesteps=n_timesteps,
+                            num_attention_heads = num_attention_heads,
+                            num_hidden_layers = num_hidden_layers,
+                            num_labels=2,
+                            learning_rate =learning_rate,
+                            warmup_steps = warmup_steps,
+                            inital_batch_size=train_batch_size,
+                            dropout_rate=dropout_rate,
+                            train_mixin_batch_size = train_mixin_batch_size,
+                            train_mix_positives_back_in = train_mix_positives_back_in,
+                            d_model=128,  # dimensionality of output sequence
+                            factor=256,   # encoder output dimension
+                            **model_specific_kwargs)
+        if model_config:
+            # model_kwargs.update(model_config)
+            raise NotImplementedError
+
+        if model_path:
+            model = LightningSAnD.load_from_checkpoint(model_path,
+                                                       strict=False,
+                                                       **model_specific_kwargs)
+            model.hparams.wandb_id = None
+
+        elif resume_model_from_ckpt:
+            model = LightningSAnD.load_from_checkpoint(resume_model_from_ckpt,
                                                       strict=False,
+                                                       **model_specific_kwargs)
+        else:
+            model = LightningSAnD(**model_kwargs)
 
-    if task.is_classification:
-        metrics = task.get_huggingface_metrics(threshold=classification_threshold)
-    else:
-        metrics=None
+        run_pytorch_lightning(model = model,
+                              task = task,
+                              training_args=pl_training_args,
+                              no_wandb=False,
+                              notes=notes,
+                              backend=backend,
+                              reload_dataloaders = reload_dataloaders)
+
+    else:
+
+        train_dataset = task.get_train_dataset()
+        infer_example = train_dataset[0]["inputs_embeds"]
+        n_timesteps, n_features = infer_example.shape
+        if model_path:
+            model = load_model_from_huggingface_checkpoint(model_path)
+        else:
+            model = SAnD(input_features=n_timesteps,
+                         seq_len = n_features,
+                         n_heads = num_attention_heads,
+                         factor=256,
+                         n_layers = num_hidden_layers,
+                         n_class=2,
+                         pos_class_weight=pos_class_weight,
+                         neg_class_weight=neg_class_weight)
+
+        if task.is_classification:
+            metrics = task.get_huggingface_metrics(threshold=classification_threshold)
+        else:
+            metrics=None
 
-    run_huggingface(model=model, base_trainer=FluTrainer,
+        training_args = TrainingArguments(
+            output_dir='./results',                         # output directory
+            num_train_epochs=n_epochs,                      # total # of training epochs
+            per_device_train_batch_size=train_batch_size,   # batch size per device during training
+            per_device_eval_batch_size=eval_batch_size,     # batch size for evaluation
+            warmup_steps=warmup_steps,                      # number of warmup steps for learning rate scheduler
+            weight_decay=weight_decay,                      # strength of weight decay
+            learning_rate=learning_rate,
+            logging_dir='./logs',                           # directory for storing logs
+            logging_steps=10,
+            do_eval=not no_eval_during_training,
+            dataloader_num_workers=16,
+            dataloader_pin_memory=True,
+            prediction_loss_only=False,
+            evaluation_strategy="epoch",
+            report_to=["wandb"]                             # log metrics to W&B
+        )
+
+        run_huggingface(model=model, base_trainer=FluTrainer,
                     training_args=training_args, metrics = metrics,
                     task=task, no_wandb=no_wandb, notes=notes)
+
+
 def train_bert(task_config=None, task_name=None, n_epochs=10,