From 5e438c231bf5d11b298c55e16c2401b515372fce Mon Sep 17 00:00:00 2001
From: YankoFelipe
Date: Thu, 7 Sep 2023 13:14:40 +0200
Subject: [PATCH] Adding support for PyTorch 2.x and PyTorch Lightning 2.0.7

Removing deprecated params in pl.Trainer
Setting pytorch_lightning version
---
 deepethogram/base.py      | 13 +++++--------
 deepethogram/callbacks.py | 16 ++++++++--------
 setup.py                  |  2 +-
 3 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/deepethogram/base.py b/deepethogram/base.py
index 040a442..6dd8d06 100644
--- a/deepethogram/base.py
+++ b/deepethogram/base.py
@@ -9,6 +9,7 @@
 import numpy as np
 from omegaconf import DictConfig, OmegaConf
 import pytorch_lightning as pl
+from pytorch_lightning.callbacks.progress import TQDMProgressBar
 try:
     from ray.tune.integration.pytorch_lightning import TuneReportCallback, \
         TuneReportCheckpointCallback
@@ -275,8 +276,7 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
     # learning rate schedule.

     if cfg.compute.batch_size == 'auto' or cfg.train.lr == 'auto':
-        trainer = pl.Trainer(gpus=[cfg.compute.gpu_id],
-                             precision=16 if cfg.compute.fp16 else 32,
+        trainer = pl.Trainer(precision=16 if cfg.compute.fp16 else 32,
                              limit_train_batches=1.0,
                              limit_val_batches=1.0,
                              limit_test_batches=1.0,
@@ -378,13 +378,13 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
     else:
         tensorboard_logger = pl.loggers.tensorboard.TensorBoardLogger(os.getcwd())
         refresh_rate = 1
+    callback_list.append(TQDMProgressBar(refresh_rate=refresh_rate))

     # tuning messes with the callbacks
     try:
         # will be deprecated in the future; pytorch lightning updated their kwargs for this function
         # don't like how they keep updating the api without proper deprecation warnings, etc.
-        trainer = pl.Trainer(gpus=[cfg.compute.gpu_id],
-                             precision=16 if cfg.compute.fp16 else 32,
+        trainer = pl.Trainer(precision=16 if cfg.compute.fp16 else 32,
                              limit_train_batches=steps_per_epoch['train'],
                              limit_val_batches=steps_per_epoch['val'],
                              limit_test_batches=steps_per_epoch['test'],
@@ -393,13 +393,11 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
                              num_sanity_val_steps=0,
                              callbacks=callback_list,
                              reload_dataloaders_every_epoch=True,
-                             progress_bar_refresh_rate=refresh_rate,
                              profiler=profiler,
                              log_every_n_steps=1)

     except TypeError:
-        trainer = pl.Trainer(gpus=[cfg.compute.gpu_id],
-                             precision=16 if cfg.compute.fp16 else 32,
+        trainer = pl.Trainer(precision=16 if cfg.compute.fp16 else 32,
                              limit_train_batches=steps_per_epoch['train'],
                              limit_val_batches=steps_per_epoch['val'],
                              limit_test_batches=steps_per_epoch['test'],
@@ -408,7 +406,6 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
                              num_sanity_val_steps=0,
                              callbacks=callback_list,
                              reload_dataloaders_every_n_epochs=1,
-                             progress_bar_refresh_rate=refresh_rate,
                              profiler=profiler,
                              log_every_n_steps=1)
     torch.cuda.empty_cache()
diff --git a/deepethogram/callbacks.py b/deepethogram/callbacks.py
index 9c2a41c..e33ad2f 100644
--- a/deepethogram/callbacks.py
+++ b/deepethogram/callbacks.py
@@ -20,10 +20,10 @@ def __init__(self):
     def on_init_end(self, trainer):
         log.info('on init start')

-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         log.debug('on train batch start')

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         log.debug('on train batch end')

     def on_train_epoch_start(self, trainer, pl_module):
@@ -94,16 +94,16 @@ def end_batch(self, split, batch, pl_module, eps: float = 1e-7):

         pl_module.metrics.buffer.append(split, {'fps': fps})

-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         self.start_timer('train')

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         self.end_batch('train', batch, pl_module)

-    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
         self.start_timer('val')

-    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         self.end_batch('val', batch, pl_module)

     def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
@@ -204,10 +204,10 @@ def on_validation_epoch_end(self, trainer, pl_module):
     def on_test_epoch_end(self, trainer, pl_module):
         self.reset_cnt(pl_module, 'test')

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         pl_module.viz_cnt['train'] += 1

-    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         pl_module.viz_cnt['val'] += 1

     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
diff --git a/setup.py b/setup.py
index 7fa7687..4988973 100644
--- a/setup.py
+++ b/setup.py
@@ -18,5 +18,5 @@
       install_requires=[
          'chardet<4.0', 'h5py', 'kornia>=0.5', 'matplotlib', 'numpy', 'omegaconf>=2',
          'opencv-python-headless', 'opencv-transforms', 'pandas<1.4', 'PySide2', 'scikit-learn<1.1',
-         'scipy<1.8', 'tqdm', 'vidio', 'pytorch_lightning>=1.5.10'
+         'scipy<1.8', 'tqdm', 'vidio', 'pytorch_lightning>=2.0.7'
      ])
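
---
A note on device selection: Lightning 2.x removed the `gpus=` argument that this
patch deletes, and in the hunks shown nothing replaces it, so the Trainer falls
back to Lightning's defaults (accelerator="auto", devices="auto") and
`cfg.compute.gpu_id` is no longer honored. If pinning a single GPU is still
wanted, a minimal sketch of the Lightning 2.x equivalent looks like the
following; the cfg values below are stand-ins for deepethogram's real config,
not part of this patch:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks.progress import TQDMProgressBar
    from omegaconf import OmegaConf

    # stand-in for the project's DictConfig; real values come from deepethogram's cfg
    cfg = OmegaConf.create({'compute': {'gpu_id': 0, 'fp16': True}})

    trainer = pl.Trainer(
        accelerator='gpu',             # 2.x replacement for gpus=[...]
        devices=[cfg.compute.gpu_id],  # pin the configured GPU index
        precision='16-mixed' if cfg.compute.fp16 else '32-true',  # 2.x precision strings
        callbacks=[TQDMProgressBar(refresh_rate=1)],  # replaces progress_bar_refresh_rate=
    )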