diff --git a/CHANGELOG.md b/CHANGELOG.md index f53eb99..8d18970 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## 1.0.6 +- FIX: DDP training stability and config handling (single init, rank/world-size propagation, robust setup/token flow) +- NEW: tqdm progress bars for training/validation and interactive embedding extraction +- FIX: interactive plots alignment bugs and compute_embeddings device argument handling +- FIX: Streamlit compatibility updates for model explorer (`use_container_width`, watcher workaround) +- FIX: safer core behavior (no-grad embedding extraction, CPU fallbacks in train/SWA/LR finder, clearer missing global config error) +- IMPROVED: metric selection logic with configurable `target_metric_mode` (`auto` / `max` / `min`) + ## 1.0.5 - FIX: class order in model explorer was broken! - FIX: interactive plot was looking for a missing argument diff --git a/README.md b/README.md index 908ab66..52ecbd3 100755 --- a/README.md +++ b/README.md @@ -32,7 +32,10 @@ Read the paper: [https://onlinelibrary.wiley.com/doi/10.1111/ele.14495](https:// [>> Comprehensive help files <<](help) 1\. Install BioEncoder (into a virtual environment with pytorch/CUDA): + +Install `torch` / `torchvision` first for your platform (CPU or CUDA), then install BioEncoder: ```` +pip install torch torchvision pip install bioencoder ```` diff --git a/bioencoder/__init__.py b/bioencoder/__init__.py index ff66c10..98929be 100644 --- a/bioencoder/__init__.py +++ b/bioencoder/__init__.py @@ -2,15 +2,49 @@ from .core import utils # from .scripts import * -from .scripts.archive import archive -from .scripts.configure import configure -from .scripts.split_dataset import split_dataset -from .scripts.train import train -from .scripts.swa import swa -from .scripts.lr_finder import lr_finder -from .scripts.interactive_plots import interactive_plots -from .scripts.inference import inference -from .scripts.model_explorer_wrapper import model_explorer_wrapper as model_explorer +def archive(*args, **kwargs): + from .scripts.archive import archive as _archive + return _archive(*args, **kwargs) + + +def configure(*args, **kwargs): + from .scripts.configure import configure as _configure + return _configure(*args, **kwargs) + + +def split_dataset(*args, **kwargs): + from .scripts.split_dataset import split_dataset as _split_dataset + return _split_dataset(*args, **kwargs) + + +def train(*args, **kwargs): + from .scripts.train import train as _train + return _train(*args, **kwargs) + + +def swa(*args, **kwargs): + from .scripts.swa import swa as _swa + return _swa(*args, **kwargs) + + +def lr_finder(*args, **kwargs): + from .scripts.lr_finder import lr_finder as _lr_finder + return _lr_finder(*args, **kwargs) + + +def interactive_plots(*args, **kwargs): + from .scripts.interactive_plots import interactive_plots as _interactive_plots + return _interactive_plots(*args, **kwargs) + + +def inference(*args, **kwargs): + from .scripts.inference import inference as _inference + return _inference(*args, **kwargs) + + +def model_explorer(*args, **kwargs): + from .scripts.model_explorer_wrapper import model_explorer_wrapper as _model_explorer + return _model_explorer(*args, **kwargs) from importlib.metadata import version __version__ = version("bioencoder") \ No newline at end of file diff --git a/bioencoder/core/losses.py b/bioencoder/core/losses.py index 6bcf073..f59b048 100644 --- a/bioencoder/core/losses.py +++ b/bioencoder/core/losses.py @@ -173,7 +173,7 @@ def forward(self, pred:torch.Tensor, target:torch.Tensor): with torch.no_grad(): true_dist = torch.zeros_like(pred) true_dist.fill_(self.smoothing / (self.cls - 1)) - true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + true_dist.scatter_(class_dim, target.unsqueeze(class_dim), self.confidence) return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) diff --git a/bioencoder/core/utils.py b/bioencoder/core/utils.py index 048131b..0b36feb 100644 --- a/bioencoder/core/utils.py +++ b/bioencoder/core/utils.py @@ -11,8 +11,10 @@ from functools import wraps import torch +import torch.distributed as dist from torchvision import transforms from torchvision.datasets import ImageFolder +from torch.utils.data.distributed import DistributedSampler from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator from sklearn.metrics import f1_score #, accuracy_score @@ -24,6 +26,34 @@ from .augmentations import get_transforms from bioencoder.vis import helpers + +def is_distributed(): + return dist.is_available() and dist.is_initialized() + + +def get_rank(): + return dist.get_rank() if is_distributed() else 0 + + +def get_world_size(): + return dist.get_world_size() if is_distributed() else 1 + + +def is_main_process(): + return get_rank() == 0 + +def init_distributed(backend="nccl", local_rank=None): + if local_rank is None: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend=backend, init_method="env://") + + + +def teardown_distributed(): + if is_distributed(): + dist.destroy_process_group() + def save_yaml(dic, yaml_path): with open(yaml_path, 'w') as file: yaml.dump(dic, file, default_flow_style=False) @@ -56,6 +86,11 @@ def restore_config(func): @wraps(func) def wrapper(*args, **kwargs): config_path = os.path.expanduser("~/.bioencoder.yaml") # Updated to load from YAML + if not os.path.isfile(config_path): + raise FileNotFoundError( + f"Global BioEncoder config not found at '{config_path}'. " + "Run the configure CLI first (e.g., bioencoder_configure --root-dir --run-name )." + ) config = load_yaml(config_path) # Import the bioencoder config module and update its attributes @@ -74,10 +109,11 @@ def load_model( stage, cuda_device ): + device = cuda_device if isinstance(cuda_device, torch.device) else torch.device(cuda_device) model = build_model( backbone, second_stage=(stage == 'second'), num_classes=num_classes, ckpt_pretrained=ckpt_pretrained, - cuda_device=cuda_device).cuda(cuda_device) + cuda_device=device).to(device) model.use_projection_head((stage=='second')) model.eval() @@ -92,21 +128,24 @@ def update_config(config, config_path=None): yaml.dump(config.__dict__, file, default_flow_style=False) -def set_seed(seed=42): +def set_seed(seed=42, rank_offset=0): """Set the random seed for the entire pipeline. Parameters: seed (int, optional): The seed value to set for all random number generators. Default is 42. """ - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_value = int(seed) + int(rank_offset) + random.seed(seed_value) + os.environ["PYTHONHASHSEED"] = str(seed_value) + np.random.seed(seed_value) + torch.manual_seed(seed_value) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed_value) + torch.cuda.manual_seed_all(seed_value) torch.backends.cudnn.deterministic = True - return seed + return seed_value def pprint_fill_hbar(message, symbol="-", ret=True): try: @@ -151,7 +190,8 @@ def add_to_tensorboard_logs(writer, message, tag, index): index (int): The global step at which to log the scalar value. """ - writer.add_scalar(tag, message, index) + if writer is not None: + writer.add_scalar(tag, message, index) class TwoCropTransform: @@ -198,7 +238,9 @@ def build_transforms(config): def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=False, is_supcon=False, - shuffle_train=True, drop_last=True): + shuffle_train=True, drop_last=True, + train_sampler=None, valid_sampler=None, train_supcon_sampler=None, + distributed=False, rank=0, world_size=1): """ Build data loaders for training and validation. @@ -229,23 +271,45 @@ def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=True ) + if distributed: + train_sampler = DistributedSampler( + train_features_dataset, + num_replicas=world_size, + rank=rank, + shuffle=shuffle_train, + drop_last=drop_last, + ) + valid_sampler = DistributedSampler( + valid_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + train_loader = torch.utils.data.DataLoader( train_features_dataset, batch_size=batch_sizes['train_batch_size'], - shuffle=shuffle_train, + shuffle=shuffle_train if train_sampler is None else False, + sampler=train_sampler, num_workers=num_workers, - pin_memory=True, - drop_last=drop_last and batch_sizes['train_batch_size'] is not None + pin_memory=torch.cuda.is_available(), + drop_last=drop_last and batch_sizes['train_batch_size'] is not None, + multiprocessing_context="spawn" if num_workers > 0 else None, + persistent_workers=(num_workers > 0), ) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=batch_sizes['valid_batch_size'], shuffle=False, + sampler=valid_sampler, num_workers=num_workers, - pin_memory=True, + pin_memory=torch.cuda.is_available(), # Keep all validation samples for unbiased validation metrics. - drop_last=False + drop_last=False, + multiprocessing_context="spawn" if num_workers > 0 else None, + persistent_workers=(num_workers > 0), ) loaders = { @@ -261,13 +325,25 @@ def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=False if is_supcon else True ) + if distributed: + train_supcon_sampler = DistributedSampler( + train_supcon_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=drop_last, + ) + train_supcon_loader = torch.utils.data.DataLoader( train_supcon_dataset, batch_size=batch_sizes['train_batch_size'], - shuffle=True, + shuffle=True if train_supcon_sampler is None else False, + sampler=train_supcon_sampler, num_workers=num_workers, - pin_memory=True, - drop_last=drop_last and batch_sizes['train_batch_size'] is not None + pin_memory=torch.cuda.is_available(), + drop_last=drop_last and batch_sizes['train_batch_size'] is not None, + multiprocessing_context="spawn" if num_workers > 0 else None, + persistent_workers=(num_workers > 0), ) loaders['train_supcon_loader'] = train_supcon_loader @@ -293,7 +369,8 @@ def build_model(backbone, second_stage=False, num_classes=None, ckpt_pretrained= model = BioEncoderModel(backbone=backbone, second_stage=second_stage, num_classes=num_classes) if ckpt_pretrained: - model.load_state_dict(torch.load(ckpt_pretrained, map_location=torch.device(cuda_device))['model_state_dict'], strict=False) + map_location = cuda_device if isinstance(cuda_device, torch.device) else torch.device(cuda_device) + model.load_state_dict(torch.load(ckpt_pretrained, map_location=map_location)['model_state_dict'], strict=False) return model @@ -359,7 +436,31 @@ def create_optimizer(parameters, spec): return {"criterion": criterion, "optimizer": optimizer, "scheduler": scheduler, "loss_optimizer": loss_optimizer} -def compute_embeddings(loader, model, scaler=None): +def _all_gather_cat(tensor): + world_size = get_world_size() + if world_size == 1: + return tensor + + local_size = torch.tensor([tensor.shape[0]], device=tensor.device, dtype=torch.long) + size_list = [torch.zeros_like(local_size) for _ in range(world_size)] + dist.all_gather(size_list, local_size) + max_size = int(torch.stack(size_list).max().item()) + + if tensor.shape[0] < max_size: + pad_shape = (max_size - tensor.shape[0],) + tensor.shape[1:] + pad = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor = torch.cat([tensor, pad], dim=0) + + gathered = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(gathered, tensor) + + outputs = [] + for idx, chunk in enumerate(gathered): + outputs.append(chunk[: int(size_list[idx].item())]) + return torch.cat(outputs, dim=0) + + +def compute_embeddings(loader, model, device, scaler=None, progress_bar=False, progress_desc=None): """Computes the embeddings and corresponding labels for a dataset. Parameters: @@ -375,27 +476,45 @@ def compute_embeddings(loader, model, scaler=None): total_embeddings = None total_labels = None - for images, labels in loader: - images = images.cuda() - if scaler: - with torch.amp.autocast("cuda"): - embed = model(images) - else: - embed = model(images) - if total_embeddings is None: - total_embeddings = embed.detach().cpu() - total_labels = labels.detach().cpu() - else: - total_embeddings = torch.cat((total_embeddings, embed.detach().cpu())) - total_labels = torch.cat((total_labels, labels.detach().cpu())) + pbar = None + if progress_bar and is_main_process(): + pbar = tqdm(total=len(loader), desc=progress_desc or "Validation", dynamic_ncols=True, leave=False) - del images, labels, embed + try: + for images, labels in loader: + with torch.no_grad(): + images = images.to(device, non_blocking=True) + if scaler: + with torch.amp.autocast("cuda"): + embed = model(images) + else: + embed = model(images) + if total_embeddings is None: + total_embeddings = embed.detach().cpu() + total_labels = labels.detach().cpu() + else: + total_embeddings = torch.cat((total_embeddings, embed.detach().cpu())) + total_labels = torch.cat((total_labels, labels.detach().cpu())) + + if pbar is not None: + pbar.update(1) - torch.cuda.empty_cache() + del images, labels, embed + finally: + if pbar is not None: + pbar.close() - return np.float32(total_embeddings), np.uint8(total_labels) + #torch.cuda.empty_cache() + emb = np.float32(total_embeddings) + lbl = np.uint8(total_labels) + if is_distributed(): + emb_t = torch.from_numpy(emb).to(device) + lbl_t = torch.from_numpy(lbl).to(device) + emb = _all_gather_cat(emb_t).detach().cpu().numpy().astype(np.float32) + lbl = _all_gather_cat(lbl_t).detach().cpu().numpy().astype(np.uint8) + return emb, lbl def train_epoch_constructive( train_loader, model, @@ -406,82 +525,116 @@ def train_epoch_constructive( loss_optimizer, scheduler=None, scheduler_step_per_batch=False, + device=torch.device("cuda"), + grad_accum_steps=1, + progress_bar=False, + epoch=None, ): - """ - Trains the `model` on the data from the `train_loader` for one epoch. The loss function is defined by `criterion` and - the optimization algorithm is defined by `optimizer`. The training process can also be scaled using the `scaler` and - the `ema` (exponential moving average) can be applied to the model's parameters. - - Parameters: - - train_loader (torch.utils.data.DataLoader): The data loader that provides the training data. - - model (torch.nn.Module): The model that will be trained. - - criterion (torch.nn.Module): The loss function to be used for training. - - optimizer (torch.optim.Optimizer): The optimization algorithm to be used for training. - - scaler (torch.amp.GradScaler, optional): The scaler used for gradient scaling in case of mixed precision training. - - ema (ExponentialMovingAverage, optional): If provided, the exponential moving average to be applied to the model's parameters. - - Returns: - - dict: A dictionary containing the mean loss over all training batches. - """ model.train() train_loss = [] - loss_optimization = False if loss_optimizer is None else True - - for idx, (images, labels) in enumerate(train_loader): - if loss_optimization: - images, labels = images.cuda(), labels.cuda() - else: - images = torch.cat([images[0]['image'], images[1]['image']], dim=0).cuda() - labels = labels.cuda() - bsz = labels.shape[0] - - if scaler: - with torch.amp.autocast("cuda"): + loss_optimization = loss_optimizer is not None + grad_accum_steps = max(1, int(grad_accum_steps)) + last_accum = len(train_loader) % grad_accum_steps + + optimizer.zero_grad(set_to_none=True) + if loss_optimization: + loss_optimizer.zero_grad(set_to_none=True) + + pbar = None + if progress_bar and is_main_process(): + epoch_str = f"{epoch + 1}" if epoch is not None else "?" + pbar = tqdm(total=len(train_loader), desc=f"Train Epoch {epoch_str}", dynamic_ncols=True, leave=False) + + try: + for idx, (images, labels) in enumerate(train_loader): + if loss_optimization: + images = images.to(device, non_blocking=True) + labels = labels.to(device, non_blocking=True) + else: + images = torch.cat([images[0]["image"], images[1]["image"]], dim=0).to(device, non_blocking=True) + labels = labels.to(device, non_blocking=True) + bsz = labels.shape[0] + + if scaler is not None: + with torch.amp.autocast("cuda"): + embed = model(images) + if not loss_optimization: + f1, f2 = torch.split(embed, [bsz, bsz], dim=0) + embed = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) + loss = criterion(embed, labels) + else: embed = model(images) if not loss_optimization: f1, f2 = torch.split(embed, [bsz, bsz], dim=0) embed = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) loss = criterion(embed, labels) - else: - embed = model(images) - if not loss_optimization: - f1, f2 = torch.split(embed, [bsz, bsz], dim=0) - embed = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) - loss = criterion(embed, labels) + train_loss.append(loss.item()) + step_now = ((idx + 1) % grad_accum_steps == 0) or ((idx + 1) == len(train_loader)) + accum_denom = grad_accum_steps + if ((idx + 1) == len(train_loader)) and (last_accum != 0): + accum_denom = last_accum - del images, labels, embed - torch.cuda.empty_cache() + loss_to_backprop = loss / accum_denom - train_loss.append(loss.item()) + if scaler is not None: + scaler.scale(loss_to_backprop).backward() + else: + loss_to_backprop.backward() - optimizer.zero_grad() - if loss_optimization: - loss_optimizer.zero_grad() + if step_now: + if scaler is not None: + scaler.unscale_(optimizer) - if scaler: - scaler.scale(loss).backward() - scaler.step(optimizer) - if loss_optimization: - scaler.step(loss_optimizer) - scaler.update() - else: - loss.backward() - optimizer.step() - if loss_optimization: - loss_optimizer.step() + if loss_optimization: + scaler.unscale_(loss_optimizer) + + scaler.step(optimizer) - if scheduler_step_per_batch and scheduler is not None: - scheduler.step() + if loss_optimization: + scaler.step(loss_optimizer) - if ema: - ema.update(model.parameters()) + scaler.update() + else: + optimizer.step() - return {'loss': np.mean(train_loss)} + if loss_optimization: + loss_optimizer.step() + optimizer.zero_grad(set_to_none=True) -def validation_constructive(valid_loader, train_loader, model, scaler): + if loss_optimization: + loss_optimizer.zero_grad(set_to_none=True) + + if step_now and scheduler_step_per_batch and scheduler is not None: + scheduler.step() + + if ema and step_now: + ema.update(model.parameters()) + + if pbar is not None: + pbar.update(1) + if step_now: + pbar.set_postfix(loss=f"{np.mean(train_loss):.4f}") + + del images, labels, embed, loss, loss_to_backprop + finally: + if pbar is not None: + pbar.close() + + return {"loss": np.mean(train_loss)} + +def validation_constructive( + valid_loader, + train_loader, + model, + device, + scaler, + progress_bar=False, + epoch=None, + split_name="projection", +): """ This function performs the validation step of the constructive learning algorithm. @@ -499,19 +652,42 @@ def validation_constructive(valid_loader, train_loader, model, scaler): calculator = AccuracyCalculator(k=1, exclude=["r_precision","mean_average_precision_at_r"]) model.eval() - query_embeddings, query_labels = compute_embeddings(valid_loader, model, scaler) - reference_embeddings, reference_labels = compute_embeddings(train_loader, model, scaler) + epoch_str = f"{epoch + 1}" if epoch is not None else "?" + query_embeddings, query_labels = compute_embeddings( + valid_loader, + model, + device, + scaler, + progress_bar=progress_bar, + progress_desc=f"Valid E{epoch_str} ({split_name}) query", + ) + reference_embeddings, reference_labels = compute_embeddings( + train_loader, + model, + device, + scaler, + progress_bar=progress_bar, + progress_desc=f"Valid E{epoch_str} ({split_name}) ref", + ) - acc_dict = calculator.get_accuracy( - query_embeddings, - query_labels, - reference_embeddings, - reference_labels, - ) + if is_main_process(): + acc_dict = calculator.get_accuracy( + query_embeddings, + query_labels, + reference_embeddings, + reference_labels, + ) + else: + acc_dict = None + + if is_distributed(): + obj = [acc_dict] + dist.broadcast_object_list(obj, src=0) + acc_dict = obj[0] del query_embeddings, query_labels, reference_embeddings, reference_labels - torch.cuda.empty_cache() + #torch.cuda.empty_cache() return acc_dict @@ -525,6 +701,10 @@ def train_epoch_ce( ema, scheduler=None, scheduler_step_per_batch=False, + device=torch.device("cuda"), + grad_accum_steps=1, + progress_bar=False, + epoch=None, ): """ Train the model for one epoch using cross-entropy loss. @@ -543,75 +723,136 @@ def train_epoch_ce( model.train() train_loss = [] + grad_accum_steps = max(1, int(grad_accum_steps)) + last_accum = len(train_loader) % grad_accum_steps + + optimizer.zero_grad() - for batch_i, (data, target) in enumerate(train_loader): - data, target = data.cuda(), target.cuda() - optimizer.zero_grad() - if scaler: - with torch.amp.autocast("cuda"): + pbar = None + if progress_bar and is_main_process(): + epoch_str = f"{epoch + 1}" if epoch is not None else "?" + pbar = tqdm(total=len(train_loader), desc=f"Train Epoch {epoch_str}", dynamic_ncols=True, leave=False) + + try: + for batch_i, (data, target) in enumerate(train_loader): + data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True) + step_now = ((batch_i + 1) % grad_accum_steps == 0) or ((batch_i + 1) == len(train_loader)) + if scaler: + with torch.amp.autocast("cuda"): + output = model(data) + loss = criterion(output, target) + train_loss.append(loss.item()) + accum_denom = grad_accum_steps + if ((batch_i + 1) == len(train_loader)) and (last_accum != 0): + accum_denom = last_accum + loss_to_backprop = loss / accum_denom + scaler.scale(loss_to_backprop).backward() + if step_now: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: output = model(data) loss = criterion(output, target) train_loss.append(loss.item()) - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - output = model(data) - loss = criterion(output, target) - train_loss.append(loss.item()) - loss.backward() - optimizer.step() + accum_denom = grad_accum_steps + if ((batch_i + 1) == len(train_loader)) and (last_accum != 0): + accum_denom = last_accum + loss_to_backprop = loss / accum_denom + loss_to_backprop.backward() + if step_now: + optimizer.step() + optimizer.zero_grad() + + if step_now and scheduler_step_per_batch and scheduler is not None: + scheduler.step() + + if ema and step_now: + ema.update(model.parameters()) + + if pbar is not None: + pbar.update(1) + if step_now: + pbar.set_postfix(loss=f"{np.mean(train_loss):.4f}") - if scheduler_step_per_batch and scheduler is not None: - scheduler.step() - - if ema: - ema.update(model.parameters()) - - del data, target, output - torch.cuda.empty_cache() + del data, target, output + #torch.cuda.empty_cache() + finally: + if pbar is not None: + pbar.close() return {"loss": np.mean(train_loss)} -def validation_ce(model, criterion, valid_loader, scaler): +def validation_ce(model, criterion, valid_loader, device, scaler, progress_bar=False, epoch=None): model.eval() val_loss = [] y_pred, y_true = [], [] correct_samples = 0 total_samples = 0 - for batch_i, (data, target) in enumerate(valid_loader): - with torch.no_grad(): - data, target = data.cuda(), target.cuda() - if scaler: - with torch.amp.autocast("cuda"): + pbar = None + if progress_bar and is_main_process(): + epoch_str = f"{epoch + 1}" if epoch is not None else "?" + pbar = tqdm(total=len(valid_loader), desc=f"Valid Epoch {epoch_str}", dynamic_ncols=True, leave=False) + + try: + for batch_i, (data, target) in enumerate(valid_loader): + with torch.no_grad(): + data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True) + if scaler: + with torch.amp.autocast("cuda"): + output = model(data) + if criterion: + loss = criterion(output, target) + val_loss.append(loss.item()) + else: output = model(data) if criterion: loss = criterion(output, target) val_loss.append(loss.item()) - else: - output = model(data) - if criterion: - loss = criterion(output, target) - val_loss.append(loss.item()) - target_np = target.detach().cpu().numpy() - pred_np = np.argmax(output.detach().cpu().numpy(), axis=1) - correct_samples += (target_np == pred_np).sum() - total_samples += target_np.shape[0] - y_pred.append(pred_np) - y_true.append(target_np) + target_np = target.detach().cpu().numpy() + pred_np = np.argmax(output.detach().cpu().numpy(), axis=1) + correct_samples += (target_np == pred_np).sum() + total_samples += target_np.shape[0] + y_pred.append(pred_np) + y_true.append(target_np) - del data, target, output - torch.cuda.empty_cache() + if pbar is not None: + pbar.update(1) + if len(val_loss) > 0: + pbar.set_postfix(loss=f"{np.mean(val_loss):.4f}") + + del data, target, output + #torch.cuda.empty_cache() + finally: + if pbar is not None: + pbar.close() y_pred = np.concatenate(y_pred) if y_pred else np.array([], dtype=np.int64) y_true = np.concatenate(y_true) if y_true else np.array([], dtype=np.int64) - valid_loss = np.mean(val_loss) if val_loss else np.nan - f1_scores = f1_score(y_true, y_pred, average=None) if total_samples > 0 else np.array([]) - f1_score_macro = f1_score(y_true, y_pred, average='macro') if total_samples > 0 else np.nan - accuracy_score = correct_samples / total_samples if total_samples > 0 else np.nan + + if is_distributed(): + y_pred_t = torch.from_numpy(y_pred).to(device) + y_true_t = torch.from_numpy(y_true).to(device) + y_pred = _all_gather_cat(y_pred_t).detach().cpu().numpy() + y_true = _all_gather_cat(y_true_t).detach().cpu().numpy() + + stats = torch.tensor( + [sum(val_loss), len(val_loss), float(correct_samples), float(total_samples)], + device=device, + dtype=torch.float64, + ) + dist.all_reduce(stats, op=dist.ReduceOp.SUM) + loss_sum, loss_count, correct_samples, total_samples = stats.tolist() + valid_loss = (loss_sum / loss_count) if loss_count > 0 else np.nan + else: + valid_loss = np.mean(val_loss) if val_loss else np.nan + + f1_scores = f1_score(y_true, y_pred, average=None) if len(y_true) > 0 else np.array([]) + f1_score_macro = f1_score(y_true, y_pred, average='macro') if len(y_true) > 0 else np.nan + accuracy_score = (correct_samples / total_samples) if total_samples > 0 else np.nan metrics = {"loss": valid_loss, "accuracy": accuracy_score, "f1_scores": f1_scores, 'f1_score_macro': f1_score_macro} return metrics @@ -696,5 +937,3 @@ def save_augmented_sample(data_dir, transform, n_samples, seed): augmented_image = to_pil_image(postprocessing(augmented_image)) sample_path = os.path.join(save_dir, f"{class_label_str}_{image_name}_augmented.png") augmented_image.save(sample_path) - - diff --git a/bioencoder/scripts/interactive_plots.py b/bioencoder/scripts/interactive_plots.py index 10fa606..68fd286 100644 --- a/bioencoder/scripts/interactive_plots.py +++ b/bioencoder/scripts/interactive_plots.py @@ -8,6 +8,35 @@ from bioencoder.vis import helpers from bioencoder import config + +def _build_split_embeddings_df(rel_paths, embeddings, dataset_name): + """ + Build metadata + embeddings DataFrame for a split with strict length alignment. + """ + n_meta = len(rel_paths) + n_embed = len(embeddings) + n = min(n_meta, n_embed) + if n == 0: + raise ValueError(f"No samples available for split '{dataset_name}' (meta={n_meta}, embeddings={n_embed}).") + if n_meta != n_embed: + print( + f"Warning: split '{dataset_name}' metadata/embedding length mismatch " + f"(meta={n_meta}, embeddings={n_embed}); truncating to {n}.", + flush=True, + ) + + rel_paths = rel_paths[:n] + embeddings = embeddings[:n] + df_meta = pd.DataFrame( + { + "image_name": [os.path.basename(p) for p in rel_paths], + "class_str": [os.path.basename(os.path.dirname(p)) for p in rel_paths], + "dataset": dataset_name, + } + ) + return pd.concat([df_meta, pd.DataFrame(embeddings)], axis=1) + + def interactive_plots( config_path, overwrite=False, @@ -62,6 +91,7 @@ def interactive_plots( } num_workers = hyperparams.get("dataloaders", {}).get("num_workers", 4) perplexity = hyperparams.get("perplexity") + progress_bar = hyperparams.get("progress_bar", True) plot_config = { "color_classes": hyperparams.get("color_classes", None), @@ -84,7 +114,14 @@ def interactive_plots( print(f"Checkpoint: using {checkpoint} of {stage} stage") ckpt_pretrained = os.path.join(root_dir, "weights", run_name, stage, checkpoint) seed = utils.set_seed() - model = utils.build_model(backbone, second_stage=(stage == "second"), num_classes=num_classes, ckpt_pretrained=ckpt_pretrained).cuda() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = utils.build_model( + backbone, + second_stage=(stage == "second"), + num_classes=num_classes, + ckpt_pretrained=ckpt_pretrained, + cuda_device=device, + ).to(device) model.use_projection_head(False) model.eval() @@ -95,28 +132,27 @@ def interactive_plots( second_stage=(stage == "second"), drop_last=False, shuffle_train=False) ## val set (always computed) - embeddings_val, labels_val = utils.compute_embeddings(loaders["valid_loader"], model) - rel_paths_val = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.imgs] - # Build validation DataFrame (meta + embeddings) - df_val_meta = pd.DataFrame({ - "image_name": [os.path.basename(p) for p in rel_paths_val], - "class_str": [os.path.basename(os.path.dirname(p)) for p in rel_paths_val], - "dataset": "val", - }) - df_embeddings = pd.concat([df_val_meta, pd.DataFrame(embeddings_val)], axis=1) + embeddings_val, labels_val = utils.compute_embeddings( + loaders["valid_loader"], + model, + device, + progress_bar=progress_bar, + progress_desc="Embeddings (val)", + ) + rel_paths_val = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.samples] + df_embeddings = _build_split_embeddings_df(rel_paths_val, embeddings_val, "val") ## train set - skipped if zero batch size if batch_sizes["train_batch_size"] is not None: - embeddings_train, labels_train = utils.compute_embeddings(loaders["train_loader"], model) - rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["train_loader"].dataset.imgs] - - # Build training DataFrame (meta + embeddings) - df_train_meta = pd.DataFrame({ - "image_name": [os.path.basename(p) for p in rel_paths_train], - "class_str": [os.path.basename(os.path.dirname(p)) for p in rel_paths_train], - "dataset": "train", - }) - df_train = pd.concat([df_train_meta, pd.DataFrame(embeddings_train)], axis=1) + embeddings_train, labels_train = utils.compute_embeddings( + loaders["train_loader"], + model, + device, + progress_bar=progress_bar, + progress_desc="Embeddings (train)", + ) + rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["train_loader"].dataset.samples] + df_train = _build_split_embeddings_df(rel_paths_train, embeddings_train, "train") df_embeddings = pd.concat([df_embeddings, df_train], ignore_index=True) ## Stable order before reduction diff --git a/bioencoder/scripts/lr_finder.py b/bioencoder/scripts/lr_finder.py index 665cba2..947efaa 100644 --- a/bioencoder/scripts/lr_finder.py +++ b/bioencoder/scripts/lr_finder.py @@ -6,6 +6,7 @@ import argparse import os import matplotlib.pyplot as plt +import torch from torch_lr_finder import LRFinder @@ -102,12 +103,14 @@ def lr_finder( data_dir, transforms, batch_sizes, num_workers, second_stage=True ) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = utils.build_model( backbone, second_stage=True, num_classes=num_classes, ckpt_pretrained=ckpt_pretrained, - ).cuda() + cuda_device=device, + ).to(device) optim = utils.build_optim( model, optimizer_params, scheduler_params, criterion_params @@ -117,12 +120,13 @@ def lr_finder( optim["optimizer"], optim["scheduler"], ) - lr_finder = LRFinder(model, optimizer, criterion, device="cuda") + lr_finder = LRFinder(model, optimizer, criterion, device=str(device)) lr_finder.range_test(loaders["train_loader"], end_lr=1, num_iter=num_iter) fig, ax = plt.subplots() ax, lr = lr_finder.plot(ax=ax, skip_start=skip_start, skip_end=skip_end) config.lr = round(lr, 6) + config.second_lr = config.lr fig.suptitle(f"Suggested LR: {config.lr}" , fontsize=20) if save_figure: @@ -144,5 +148,3 @@ def cli(): if __name__ == "__main__": cli() - - diff --git a/bioencoder/scripts/model_explorer.py b/bioencoder/scripts/model_explorer.py index 11b9432..d88d5e1 100644 --- a/bioencoder/scripts/model_explorer.py +++ b/bioencoder/scripts/model_explorer.py @@ -88,7 +88,7 @@ def model_explorer( # Sidebar img_path = "https://github.com/agporto/BioEncoder/raw/main/assets/bioencoder_logo.png" - st.sidebar.image(img_path, width='stretch') + st.sidebar.image(img_path, use_container_width=True) st.sidebar.title("BioEncoder Model Explorer") # Image upload @@ -116,7 +116,7 @@ def model_explorer( # Display the uploaded image image = Image.open(uploaded_file).convert('RGB') - st.sidebar.image(image, caption="Input Image", width='stretch') + st.sidebar.image(image, caption="Input Image", use_container_width=True) # resize image image_resized = image.resize((img_size, img_size)) diff --git a/bioencoder/scripts/model_explorer_wrapper.py b/bioencoder/scripts/model_explorer_wrapper.py index 7f97154..ff99605 100644 --- a/bioencoder/scripts/model_explorer_wrapper.py +++ b/bioencoder/scripts/model_explorer_wrapper.py @@ -12,8 +12,19 @@ def model_explorer_wrapper(config_path): script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_explorer.py") - process = ["streamlit", "run", script_path , "--", "--config-path", config_path] - subprocess.run(process, check=True) + process = [ + "streamlit", + "run", + script_path, + "--server.fileWatcherType", + "none", + "--", + "--config-path", + config_path, + ] + env = os.environ.copy() + env.setdefault("STREAMLIT_SERVER_FILE_WATCHER_TYPE", "none") + subprocess.run(process, check=True, env=env) def cli(): diff --git a/bioencoder/scripts/swa.py b/bioencoder/scripts/swa.py index 6c24c4d..3c04d71 100644 --- a/bioencoder/scripts/swa.py +++ b/bioencoder/scripts/swa.py @@ -71,10 +71,9 @@ def swa( os.remove(os.path.join(weights_dir, "swa")) ## scaler - scaler = torch.amp.GradScaler("cuda") - if not amp: - scaler = None + scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else None utils.set_seed() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") transforms = utils.build_transforms(hyperparams) loaders = utils.build_loaders( @@ -85,7 +84,8 @@ def swa( second_stage=(stage == "second"), num_classes=num_classes, ckpt_pretrained=None, - ).cuda() + cuda_device=device, + ).to(device) ## inspect available checkpoints (epochN files only) epoch_files = [ @@ -110,7 +110,7 @@ def swa( state_dicts = [] for path in checkpoints_paths: - state_dicts.append(torch.load(path)["model_state_dict"]) + state_dicts.append(torch.load(path, map_location=device)["model_state_dict"]) average_dict = OrderedDict() for k in state_dicts[0].keys(): @@ -120,16 +120,16 @@ def swa( torch.save({"model_state_dict": average_dict}, os.path.join(weights_dir, "swa")) model.load_state_dict( - torch.load(os.path.join(weights_dir, "swa"))["model_state_dict"], + torch.load(os.path.join(weights_dir, "swa"), map_location=device)["model_state_dict"], ) if stage == "first": valid_metrics = utils.validation_constructive( - loaders["valid_loader"], loaders["train_loader"], model, scaler + loaders["valid_loader"], loaders["train_loader"], model, device, scaler ) else: valid_metrics = utils.validation_ce( - model, None, loaders["valid_loader"], scaler + model, None, loaders["valid_loader"], device, scaler ) print("swa stage {} validation metrics: {}".format(stage, valid_metrics)) @@ -150,4 +150,4 @@ def cli(): if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/bioencoder/scripts/train.py b/bioencoder/scripts/train.py index 7d7788d..662c030 100644 --- a/bioencoder/scripts/train.py +++ b/bioencoder/scripts/train.py @@ -13,69 +13,34 @@ from rich.pretty import pretty_repr import torch +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp + from torch.utils.tensorboard import SummaryWriter from torch_ema import ExponentialMovingAverage from bioencoder import config, utils #%% function - def train( config_path, dry_run=False, overwrite=False, **kwargs, ): - """ - Trains the BioEncoder model based on the provided configuration settings in the yaml files. - - Parameters - ---------- - config_path : str - Path to the YAML configuration file that specifies detailed training and optimizer parameters. - This file controls various aspects of the training process including but not limited to model architecture, - optimizer settings, scheduler details, and data augmentation strategies. - overwrite : bool, optional - If True, existing directories for logs, runs, and weights will be removed and recreated, allowing for a clean training start. - If False, the training process will append to existing directories and files, which could lead to mixed results if not managed properly. - Default is False. - - Raises - ------ - FileNotFoundError - If the configuration file specified by `config_path` does not exist. - AssertionError - If certain conditions in the configuration (like minimum image count per class) are not met. - ValueError - If incompatible or inconsistent parameters are detected during the setup or training processes. - - Notes - ----- - There are two separate files for the first and second stage - make sure you set them up appropriately. E.g., for stage two, - specify the number of classes, and a different learning rate. - - Examples - -------- - To start a new training session with overwriting previous outputs: - bioencoder.train("/path/to/config.yaml", overwrite=True) - - """ - - ## load bioencoer config root_dir = config.root_dir run_name = config.run_name - - ## load config hyperparams = utils.load_yaml(config_path) - ## parse config backbone = hyperparams["model"]["backbone"] amp = hyperparams["train"]["amp"] ema = hyperparams["train"]["ema"] ema_decay_per_epoch = hyperparams["train"]["ema_decay_per_epoch"] + progress_bar = hyperparams["train"].get("progress_bar", True) n_epochs = hyperparams["train"]["n_epochs"] target_metric = hyperparams["train"]["target_metric"] min_improvement = hyperparams["train"].get("min_improvement", 0.01) + target_metric_mode = hyperparams["train"].get("target_metric_mode", "auto") stage = hyperparams["train"]["stage"] optimizer_params = hyperparams["optimizer"] scheduler_params = hyperparams.get("scheduler", None) @@ -89,20 +54,60 @@ def train( aug_sample = aug_config.get("sample_save", False) aug_sample_n = aug_config.get("sample_n", 5) aug_sample_seed = aug_config.get("sample_seed", 42) + dist_config = hyperparams.get("distributed", {}) + distributed_enabled = kwargs.get("distributed", dist_config.get("enabled", False)) + distributed_backend = kwargs.get("backend", dist_config.get("backend", "nccl")) + find_unused_parameters = dist_config.get("find_unused_parameters", False) + sync_bn = dist_config.get("sync_bn", False) + grad_accum_steps = dist_config.get("grad_accum_steps", 1) + seed = dist_config.get("seed", 42) + env_rank = int(os.environ.get("RANK", "0")) + env_world_size = int(os.environ.get("WORLD_SIZE", "1")) + + if distributed_enabled: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + utils.init_distributed(backend=distributed_backend, local_rank=local_rank) + env_rank = utils.get_rank() + env_world_size = utils.get_world_size() + device = torch.device(f"cuda:{local_rank}") + else: + local_rank = 0 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - ## manage directories and paths data_dir = os.path.join(root_dir, "data", run_name) log_dir = os.path.join(root_dir, "logs", run_name, stage) run_dir = os.path.join(root_dir, "runs", run_name, stage) weights_dir = os.path.join(root_dir, "weights", run_name, stage) - for directory in [log_dir, run_dir, weights_dir]: - if os.path.exists(directory) and overwrite==True: - print(f"removing {directory} (overwrite=True)") - shutil.rmtree(directory) - os.makedirs(directory, exist_ok=True) - - ## collect information on data - train_dir, val_dir = os.path.join(data_dir, "train"), os.path.join(data_dir, "val") + + setup_token = os.path.join(root_dir, f".setup_{run_name}_{stage}.done") + + if distributed_enabled: + if local_rank == 0: + if os.path.exists(setup_token): + os.remove(setup_token) + for directory in [log_dir, run_dir, weights_dir]: + if os.path.exists(directory) and overwrite: + shutil.rmtree(directory) + os.makedirs(directory, exist_ok=True) + with open(setup_token, "w") as f: + f.write("ok\n") + else: + t0 = time.time() + while not os.path.exists(setup_token): + if time.time() - t0 > 300: + raise TimeoutError(f"Timed out waiting for setup token: {setup_token}") + time.sleep(0.1) + for directory in [log_dir, run_dir, weights_dir]: + os.makedirs(directory, exist_ok=True) + else: + for directory in [log_dir, run_dir, weights_dir]: + if os.path.exists(directory) and overwrite: + shutil.rmtree(directory) + os.makedirs(directory, exist_ok=True) + + train_dir = os.path.join(data_dir, "train") + val_dir = os.path.join(data_dir, "val") class_names = sorted( [ class_name @@ -110,280 +115,346 @@ def train( if os.path.isdir(os.path.join(train_dir, class_name)) ] ) - data_stats = {"data_dir": data_dir} - data_stats["train"], data_stats["val"] = {},{} + + data_stats = {"data_dir": data_dir, "train": {}, "val": {}} for class_name in class_names: train_class_dir = os.path.join(train_dir, class_name) val_class_dir = os.path.join(val_dir, class_name) data_stats["train"][class_name] = len( - [ - file_name - for file_name in os.listdir(train_class_dir) - if os.path.isfile(os.path.join(train_class_dir, file_name)) - ] + [f for f in os.listdir(train_class_dir) if os.path.isfile(os.path.join(train_class_dir, f))] ) data_stats["val"][class_name] = len( - [ - file_name - for file_name in os.listdir(val_class_dir) - if os.path.isfile(os.path.join(val_class_dir, file_name)) - ] + [f for f in os.listdir(val_class_dir) if os.path.isfile(os.path.join(val_class_dir, f))] ) - ## set up logging and tensorboard writer - writer = SummaryWriter(run_dir) logger = logging.getLogger() logger.setLevel(logging.INFO) - if (logger.hasHandlers()): + if logger.hasHandlers(): logger.handlers.clear() + log_file_path = os.path.join(log_dir, f"{run_name}_{stage}.log") - - ## logging: stdout handler + stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler.setLevel(logging.DEBUG) - stdout_formatter = logging.Formatter('%(asctime)s: %(message)s', "%H:%M:%S") - stdout_handler.setFormatter(stdout_formatter) + stdout_handler.setFormatter(logging.Formatter("%(asctime)s: %(message)s", "%H:%M:%S")) logger.addHandler(stdout_handler) - ## logging: logfile handler - if os.path.isfile(log_file_path): - os.remove(log_file_path) - file_handler = logging.FileHandler(log_file_path) - file_handler.setLevel(logging.INFO) - file_formatter = logging.Formatter('%(asctime)s: %(message)s', "%Y-%m-%d %H:%M:%S") - file_handler.setFormatter(file_formatter) - logger.addHandler(file_handler) - - ## manage second stage - if stage == "second": + if (not distributed_enabled) or local_rank == 0: + if os.path.isfile(log_file_path): + os.remove(log_file_path) + file_handler = logging.FileHandler(log_file_path) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter("%(asctime)s: %(message)s", "%Y-%m-%d %H:%M:%S")) + logger.addHandler(file_handler) + + logger.info(utils.pprint_fill_hbar(f"Training {stage} stage ", symbol="#")) + logger.info(f"Dataset:\n{pretty_repr(data_stats)}") + logger.info(f"Hyperparameters:\n{pretty_repr(hyperparams)}") - ## number of classes + if stage == "second": num_classes = hyperparams["model"]["num_classes"] - - ## add learning rate - if not "params" in optimizer_params: + if "params" not in optimizer_params: optimizer_params["params"] = {} - if kwargs.get("lr"): + if kwargs.get("lr") is not None: optimizer_params["params"]["lr"] = kwargs.get("lr") - if not "lr" in optimizer_params["params"].keys(): - if "second_lr" in config.__dict__.keys(): + if "lr" not in optimizer_params["params"]: + if "second_lr" in config.__dict__: optimizer_params["params"]["lr"] = float(config.second_lr) logger.info(f"Using LR value from global bioencoder config: {config.second_lr}") + elif "lr" in config.__dict__: + optimizer_params["params"]["lr"] = float(config.lr) + logger.info(f"Using LR value from global bioencoder config: {config.lr}") else: - lr = optimizer_params["params"]["lr"] - logger.info(f"Using LR value from local bioencoder config: {lr}") - + logger.info(f"Using LR value from local bioencoder config: {optimizer_params['params']['lr']}") assert "lr" in optimizer_params["params"], "no learning rate specified" - - ## fetch checkpoints from first stage - ckpt_pretrained = os.path.join(root_dir, "weights", run_name, 'first', "swa") - else: - ckpt_pretrained = None + ckpt_pretrained = os.path.join(root_dir, "weights", run_name, "first", "swa") + else: num_classes = None + ckpt_pretrained = None - ## add hyperparams to log - logger.info(utils.pprint_fill_hbar(f"Training {stage} stage ", symbol="#")) - logger.info(f"Dataset:\n{pretty_repr(data_stats)}") - logger.info(f"Hyperparameters:\n{pretty_repr(hyperparams)}") - - ## scaler - scaler = torch.amp.GradScaler("cuda") - if not amp: - scaler = None - - ## set seed for entire pipeline - utils.set_seed() - - ## configure GPU before moving model to CUDA - assert torch.cuda.device_count() > 0, "No GPUs detected on this System (check your CUDA setup) - aborting." - if torch.cuda.device_count() == 1: - logger.info(f"Found one GPU: {torch.cuda.get_device_name(0)} (device {torch.cuda.current_device()})") - else: - logger.info(f"Found {torch.cuda.device_count()} GPUs, but unfortunately multi-GPU use isn't implemented yet.") - logger.info(f"Using GPU {torch.cuda.get_device_name(0)} (device {torch.cuda.current_device()})") + transforms = utils.build_transforms(hyperparams) - # create model, loaders, optimizer, etc - transforms = utils.build_transforms(hyperparams) loaders = utils.build_loaders( - data_dir, - transforms, - batch_sizes, - num_workers, - second_stage=(stage == "second"), + data_dir=data_dir, + transforms=transforms, + batch_sizes=batch_sizes, + num_workers=num_workers, + second_stage=(stage == "second"), is_supcon=(criterion_params["name"] == "SupCon"), + distributed=distributed_enabled, + rank=env_rank, + world_size=env_world_size, ) + + train_sampler = loaders["train_loader"].sampler if distributed_enabled else None + train_supcon_sampler = ( + loaders["train_supcon_loader"].sampler + if distributed_enabled and "train_supcon_loader" in loaders + else None + ) + + if distributed_enabled: + logger.info( + f"DDP initialized (rank={utils.get_rank()}/{utils.get_world_size()}, local_rank={local_rank}, backend={distributed_backend})" + ) + else: + if torch.cuda.is_available(): + torch.cuda.set_device(0) + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + utils.set_seed(seed, rank_offset=(local_rank if distributed_enabled else 0)) + + scaler = torch.amp.GradScaler("cuda") if amp else None + + gpu_count = torch.cuda.device_count() + if gpu_count == 0: + logger.info("No GPU found. Using CPU.") + elif gpu_count == 1: + logger.info(f"Found one GPU: {torch.cuda.get_device_name(0)} (device {torch.cuda.current_device()})") + else: + if distributed_enabled: + logger.info(f"Found {gpu_count} GPUs and using DDP across {utils.get_world_size()} ranks") + else: + logger.info(f"Found {gpu_count} GPUs, but distributed mode is disabled.") + logger.info(f"Using GPU {torch.cuda.get_device_name(0)} (device {torch.cuda.current_device()})") + + writer = SummaryWriter(run_dir) if ((not distributed_enabled) or utils.is_main_process()) else None + model = utils.build_model( backbone, second_stage=(stage == "second"), num_classes=num_classes, ckpt_pretrained=ckpt_pretrained, - ).cuda() - - ## save a sample of augmented images - if aug_sample: + cuda_device=device, + ).to(device) + model = model.to(device) + + if distributed_enabled and sync_bn: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + + if distributed_enabled: + model = DDP( + model, + device_ids=[local_rank], + output_device=local_rank, + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters, + ) + + if aug_sample and ((not distributed_enabled) or utils.is_main_process()): utils.save_augmented_sample(data_dir, transforms["train_transforms"], aug_sample_n, seed=aug_sample_seed) logger.info(f"Saving augmentation samples: {aug_sample_n} per class to data/{run_name}/aug_sample") logger.info(f"Using backbone: {backbone}") - ## configure optimizer - optim = utils.build_optim( - model, optimizer_params, scheduler_params, criterion_params - ) - criterion, optimizer, scheduler, loss_optimizer = ( - optim["criterion"], - optim["optimizer"], - optim["scheduler"], - optim["loss_optimizer"], - ) + optim = utils.build_optim(model, optimizer_params, scheduler_params, criterion_params) + + criterion = optim["criterion"] + optimizer = optim["optimizer"] + scheduler = optim["scheduler"] + loss_optimizer = optim["loss_optimizer"] + scheduler_step_per_batch = isinstance(scheduler, torch.optim.lr_scheduler.CyclicLR) scheduler_requires_metric = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + if ema: active_train_loader = loaders["train_supcon_loader"] if stage == "first" else loaders["train_loader"] iters = len(active_train_loader) ema_decay = ema_decay_per_epoch ** (1 / iters) ema = ExponentialMovingAverage(model.parameters(), decay=ema_decay) - if loss_optimizer is not None and stage == 'second': - raise ValueError('Loss optimizers should only be present for stage 1 training. Check your config file.') - - # epoch loop - metric_best = 0 - if not dry_run: - for epoch in range(n_epochs): - logger.info(utils.pprint_fill_hbar(f"START - Epoch {epoch}")) - start_training_time = time.time() - if stage == "first": - train_metrics = utils.train_epoch_constructive( - loaders["train_supcon_loader"], - model, - criterion, - optimizer, - scaler, - ema, - loss_optimizer, - scheduler=scheduler, - scheduler_step_per_batch=scheduler_step_per_batch, - ) - else: - train_metrics = utils.train_epoch_ce( - loaders["train_loader"], - model, - criterion, - optimizer, - scaler, - ema, - scheduler=scheduler, - scheduler_step_per_batch=scheduler_step_per_batch, - ) - end_training_time = time.time() - - if ema: - copy_of_model_parameters = utils.copy_parameters_from_model(model) - ema.copy_to(model.parameters()) - - start_validation_time = time.time() - - if stage == "first": - valid_metrics_projection_head = utils.validation_constructive( - loaders["valid_loader"], loaders["train_loader"], model, scaler - ) - - ## check for GPU parallelization - #model_copy = model.module if isinstance(model, torch.nn.DataParallel) else model - - #model_copy.use_projection_head(False) - model.use_projection_head(False) - valid_metrics_encoder = utils.validation_constructive( - loaders["valid_loader"], loaders["train_loader"], model, scaler - ) - model.use_projection_head(True) - #model_copy.use_projection_head(True) parser.add_argument("--dry_run", action='store_true', help="Run without making any changes.") - - - ## epoch summary - message = "Summary epoch {}:\ntrain time {:.2f}\nvalid time {:.2f}\ntrain loss {:.2f}\nvalid acc projection head {}\nvalid acc encoder {}".format( - epoch, - end_training_time - start_training_time, - time.time() - start_validation_time, - train_metrics["loss"], - pretty_repr(valid_metrics_projection_head), - pretty_repr(valid_metrics_encoder), - ) - logger.info("\n".join(line if i == 0 else " " + line for i, line in enumerate(message.split("\n")))) - valid_metrics = valid_metrics_projection_head - else: - valid_metrics = utils.validation_ce( - model, criterion, loaders["valid_loader"], scaler - ) - ## epoch summary - message = "Summary epoch {}:\ntrain time {:.2f}\nvalid time {:.2f}\ntrain loss {:.2f}\nvalid acc dict {}".format( + if loss_optimizer is not None and stage == "second": + raise ValueError("Loss optimizers should only be present for stage 1 training. Check your config file.") + + if target_metric_mode not in {"auto", "max", "min"}: + raise ValueError("target_metric_mode must be one of: auto, max, min") + if target_metric_mode == "auto": + mode = "min" if "loss" in str(target_metric).lower() else "max" + else: + mode = target_metric_mode + + metric_best = float("inf") if mode == "min" else float("-inf") + + try: + if not dry_run: + for epoch in range(n_epochs): + if distributed_enabled: + if train_sampler is not None: + train_sampler.set_epoch(epoch) + if train_supcon_sampler is not None: + train_supcon_sampler.set_epoch(epoch) + + logger.info(utils.pprint_fill_hbar(f"START - Epoch {epoch}")) + start_training_time = time.time() + + if stage == "first": + train_metrics = utils.train_epoch_constructive( + loaders["train_supcon_loader"], + model, + criterion, + optimizer, + scaler, + ema, + loss_optimizer, + scheduler=scheduler, + scheduler_step_per_batch=scheduler_step_per_batch, + device=device, + grad_accum_steps=grad_accum_steps, + progress_bar=progress_bar, + epoch=epoch, + ) + else: + train_metrics = utils.train_epoch_ce( + loaders["train_loader"], + model, + criterion, + optimizer, + scaler, + ema, + scheduler=scheduler, + scheduler_step_per_batch=scheduler_step_per_batch, + device=device, + grad_accum_steps=grad_accum_steps, + progress_bar=progress_bar, + epoch=epoch, + ) + + end_training_time = time.time() + + if ema: + copy_of_model_parameters = utils.copy_parameters_from_model(model) + ema.copy_to(model.parameters()) + + start_validation_time = time.time() + + if stage == "first": + valid_metrics_projection_head = utils.validation_constructive( + loaders["valid_loader"], + loaders["train_loader"], + model, + device, + scaler, + progress_bar=progress_bar, + epoch=epoch, + split_name="projection", + ) + model_ref = model.module if isinstance(model, DDP) else model + model_ref.use_projection_head(False) + valid_metrics_encoder = utils.validation_constructive( + loaders["valid_loader"], + loaders["train_loader"], + model, + device, + scaler, + progress_bar=progress_bar, + epoch=epoch, + split_name="encoder", + ) + model_ref.use_projection_head(True) + + message = ( + "Summary epoch {}:\ntrain time {:.2f}\nvalid time {:.2f}\ntrain loss {:.2f}\n" + "valid acc projection head {}\nvalid acc encoder {}" + ).format( epoch, end_training_time - start_training_time, time.time() - start_validation_time, train_metrics["loss"], - pretty_repr(valid_metrics), - ) - logger.info("\n".join(line if i == 0 else " " + line for i, line in enumerate(message.split("\n")))) - - if target_metric not in valid_metrics: - raise ValueError( - f"target_metric='{target_metric}' not found in validation metrics. " - f"Available metrics: {list(valid_metrics.keys())}" - ) - - # write train and valid metrics to the logs - utils.add_to_tensorboard_logs( - writer, train_metrics["loss"], "Loss/train", epoch - ) - for valid_metric in valid_metrics: - try: - utils.add_to_tensorboard_logs( - writer, - valid_metrics[valid_metric], - "{}/validation".format(valid_metric), + pretty_repr(valid_metrics_projection_head), + pretty_repr(valid_metrics_encoder), + ) + logger.info("\n".join(line if i == 0 else " " + line for i, line in enumerate(message.split("\n")))) + valid_metrics = valid_metrics_projection_head + else: + valid_metrics = utils.validation_ce( + model, + criterion, + loaders["valid_loader"], + device, + scaler, + progress_bar=progress_bar, + epoch=epoch, + ) + message = ( + "Summary epoch {}:\ntrain time {:.2f}\nvalid time {:.2f}\ntrain loss {:.2f}\nvalid acc dict {}" + ).format( epoch, + end_training_time - start_training_time, + time.time() - start_validation_time, + train_metrics["loss"], + pretty_repr(valid_metrics), ) - except AssertionError: - # in case valid metric is a listhyperparams - pass - - # check if the best value of metric changed. If so -> save the model - current_metric = valid_metrics[target_metric] - if metric_best == 0 or current_metric > metric_best * (1 + min_improvement): - logger.info( - "{} improved by ≥{:.2%} ({:.6f} --> {:.6f}). Saving model ...".format( - target_metric, min_improvement, metric_best, current_metric + logger.info("\n".join(line if i == 0 else " " + line for i, line in enumerate(message.split("\n")))) + + if target_metric not in valid_metrics: + raise ValueError( + f"target_metric='{target_metric}' not found in validation metrics. " + f"Available metrics: {list(valid_metrics.keys())}" + ) + + if (not distributed_enabled) or utils.is_main_process(): + utils.add_to_tensorboard_logs(writer, train_metrics["loss"], "Loss/train", epoch) + for valid_metric in valid_metrics: + try: + utils.add_to_tensorboard_logs( + writer, + valid_metrics[valid_metric], + f"{valid_metric}/validation", + epoch, + ) + except AssertionError: + pass + + current_metric = valid_metrics[target_metric] + if mode == "max": + improved = current_metric > metric_best * (1 + min_improvement) + else: + improved = current_metric < metric_best * (1 - min_improvement) + if improved: + logger.info( + "{} improved by ≥{:.2%} ({:.6f} --> {:.6f}). Saving model ...".format( + target_metric, min_improvement, metric_best, current_metric + ) + ) + if (not distributed_enabled) or utils.is_main_process(): + model_state = model.module.state_dict() if isinstance(model, DDP) else model.state_dict() + torch.save( + { + "epoch": epoch, + "model_state_dict": model_state, + "optimizer_state_dict": optimizer.state_dict(), + }, + os.path.join(weights_dir, f"epoch{epoch}"), + ) + metric_best = current_metric + else: + logger.info( + f"Metric {target_metric} did not improve by ≥{min_improvement:.2%} " + f"(best: {metric_best:.6f}, current: {current_metric:.6f})" ) - ) - - torch.save( - { - "epoch": epoch, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - }, - os.path.join(weights_dir, f"epoch{epoch}"), - ) - metric_best = current_metric - else: - logger.info(f"Metric {target_metric} did not improve by ≥{min_improvement:.2%} (best: {metric_best:.6f}, current: {current_metric:.6f})") - - # if ema is used, go back to regular weights without ema - if ema: - utils.copy_parameters_to_model(copy_of_model_parameters, model) - - if scheduler is not None: - if scheduler_requires_metric: - scheduler.step(valid_metrics[target_metric]) - elif not scheduler_step_per_batch: - scheduler.step() - logger.info(utils.pprint_fill_hbar(f"END - Epoch {epoch}")) - else: - logger.info(utils.pprint_fill_hbar("DRY-RUN ONLY - NO TRAINING")) - writer.close() - logging.shutdown() + if ema: + utils.copy_parameters_to_model(copy_of_model_parameters, model) + + if scheduler is not None: + if scheduler_requires_metric: + scheduler.step(valid_metrics[target_metric]) + elif not scheduler_step_per_batch: + scheduler.step() + + logger.info(utils.pprint_fill_hbar(f"END - Epoch {epoch}")) + else: + logger.info(utils.pprint_fill_hbar("DRY-RUN ONLY - NO TRAINING")) + finally: + if writer is not None: + writer.close() + if distributed_enabled and utils.is_distributed(): + utils.teardown_distributed() + if distributed_enabled and local_rank == 0 and os.path.exists(setup_token): + os.remove(setup_token) + logging.shutdown() def cli(): @@ -391,11 +462,21 @@ def cli(): parser.add_argument("--config-path",type=str, required=True, help="Path to the YAML configuration file that specifies detailed training and optimizer parameters.") parser.add_argument("--dry-run", action='store_true', help="Run without starting the training to inspect config and augmentations.") parser.add_argument("--overwrite", action='store_true', help="Overwrite existing files without asking.") + parser.add_argument("--distributed", action='store_true', help="Enable Distributed Data Parallel training.") + parser.add_argument("--backend", type=str, default="nccl", help="Distributed backend.") + parser.add_argument("--local-rank", "--local_rank", dest="local_rank", type=int, default=None, help="Local rank set by torchrun.") args = parser.parse_args() train_cli = utils.restore_config(train) - train_cli(args.config_path, overwrite=args.overwrite, dry_run=args.dry_run) + train_cli( + args.config_path, + overwrite=args.overwrite, + dry_run=args.dry_run, + distributed=args.distributed, + backend=args.backend, + local_rank=args.local_rank, + ) if __name__ == "__main__": - + mp.set_start_method("spawn", force=True) cli() diff --git a/bioencoder/vis/classes.py b/bioencoder/vis/classes.py index 9eeca0e..4d7b0cd 100644 --- a/bioencoder/vis/classes.py +++ b/bioencoder/vis/classes.py @@ -191,6 +191,15 @@ def backward(self, grad_output): return grad_input +class GuidedBackpropReLUModule(nn.Module): + """ + nn.Module wrapper for GuidedBackpropReLU autograd function. + """ + + def forward(self, input_img): + return GuidedBackpropReLU.apply(input_img) + + class GuidedBackpropReLUModel: """ A class that creates a model with GuidedBackpropReLU activation functions instead of standard ReLU activations. @@ -213,7 +222,7 @@ def recursive_relu_apply(module_top): for idx, module in module_top._modules.items(): recursive_relu_apply(module) if module.__class__.__name__ == 'ReLU': - module_top._modules[idx] = GuidedBackpropReLU.apply + module_top._modules[idx] = GuidedBackpropReLUModule() # replace ReLU with GuidedBackpropReLU recursive_relu_apply(self.model) diff --git a/bioencoder/vis/helpers.py b/bioencoder/vis/helpers.py index 13e99c2..4098f6d 100644 --- a/bioencoder/vis/helpers.py +++ b/bioencoder/vis/helpers.py @@ -150,8 +150,9 @@ def gen_coords(i, patch_size, stride, dim1, dim2): tuple Tuple containing the (x0, y0, x1, y1) coordinates of the patch. """ - x0 = int(stride * (i % dim1)) - y0 = int(stride * int(i / dim2)) + # dim1 is number of rows (y), dim2 is number of columns (x). + x0 = int(stride * (i % dim2)) + y0 = int(stride * (i // dim2)) x1 = x0 + patch_size y1 = y0 + patch_size @@ -285,8 +286,13 @@ class labels of the images). if not all(col in df.columns for col in ['paths', 'class']): raise ValueError("The dataframe must have columns 'paths' and 'class'") - - unique_classes = df['class'].unique() + df = df.copy() + if "class_str" not in df.columns: + df["class_str"] = df["class"].astype(str) + if "dataset" not in df.columns: + df["dataset"] = "dataset" + + unique_classes = df['class_str'].unique() unique_datasets = df['dataset'].astype(str).unique() markers = ['circle', 'square'] # Define markers for each group @@ -303,7 +309,8 @@ class labels of the images). else: num_classes = len(unique_classes) cmap = plt.cm.get_cmap(color_map, num_classes) - colors_raw = cmap(df['class'], bytes=True) + class_ids, _ = pd.factorize(df['class_str']) + colors_raw = cmap(class_ids, bytes=True) colors_str = ['#%02x%02x%02x' % tuple(c[:3]) for c in colors_raw] df['color'] = colors_str diff --git a/bioencoder/vis/methods.py b/bioencoder/vis/methods.py index 2da3dba..be6e2c7 100644 --- a/bioencoder/vis/methods.py +++ b/bioencoder/vis/methods.py @@ -69,7 +69,8 @@ def hook_fn(self, input, output): acts = acts[0][0].cpu().detach().numpy() # Subset the output for the first copy if acts.shape[0] > max_acts: - acts = acts[torch.randperm(acts.shape[0])[:max_acts]] + idx = np.random.choice(acts.shape[0], size=max_acts, replace=False) + acts = acts[idx] sqrt = int(acts.shape[0]**0.5) fig, axs = plt.subplots(sqrt, sqrt) diff --git a/bioencoder_configs/explore_stage1.yml b/bioencoder_configs/explore_stage1.yml index e962eba..7e49501 100644 --- a/bioencoder_configs/explore_stage1.yml +++ b/bioencoder_configs/explore_stage1.yml @@ -1,3 +1,5 @@ model: backbone: timm_tf_efficientnet_b5.ns_jft_in1k # Model architecture and pre-trained weights to use stage: first # Training stage: 'first' for embeddings, 'second' for classification + +img_size: 256 # Input image size used by the stage-1 model diff --git a/bioencoder_configs/explore_stage2.yml b/bioencoder_configs/explore_stage2.yml index c2e8ed9..cf2a43e 100644 --- a/bioencoder_configs/explore_stage2.yml +++ b/bioencoder_configs/explore_stage2.yml @@ -3,3 +3,4 @@ model: stage: second # Training stage: 'first' for embeddings, 'second' for classification num_classes: 4 # Number of output classes for classification +img_size: 256 # Input image size used by the stage-2 model diff --git a/bioencoder_configs/inference.yml b/bioencoder_configs/inference.yml index 839ed09..fdb1f5a 100644 --- a/bioencoder_configs/inference.yml +++ b/bioencoder_configs/inference.yml @@ -5,6 +5,6 @@ model: stage: second # Training stage: 'first' for embeddings, 'second' for classification num_classes: 4 # Number of output classes for classification -img_size: 384 # Image size for training and validation +img_size: 256 # Image size for training and validation -return_probs: false # Whether to return the probabilities for each class +return_probs: true # Whether to return the probabilities for each class diff --git a/bioencoder_configs/lr_finder.yml b/bioencoder_configs/lr_finder.yml index ef25a25..8cfc27f 100644 --- a/bioencoder_configs/lr_finder.yml +++ b/bioencoder_configs/lr_finder.yml @@ -3,14 +3,16 @@ model: num_classes: 4 # Number of output classes for classification dataloaders: - train_batch_size: 50 # Batch size for training data; larger sizes utilize GPU memory better - valid_batch_size: 50 # Batch size for validation data - num_workers: 32 # Number of CPU threads for data loading; set to the number of CPU cores available + train_batch_size: 40 # Batch size for training data; larger sizes utilize GPU memory better + valid_batch_size: 40 # Batch size for validation data + num_workers: 4 # Number of CPU threads for data loading; set to the number of CPU cores available optimizer: name: SGD # Optimizer type params: lr: 0.001 # Learning rate +img_size: 256 # Input image size used by the stage-1 model + criterion: name: 'CrossEntropy' # Loss function for multi-class classification diff --git a/bioencoder_configs/plot_stage1.yml b/bioencoder_configs/plot_stage1.yml index 4d8b1e0..cedf12b 100644 --- a/bioencoder_configs/plot_stage1.yml +++ b/bioencoder_configs/plot_stage1.yml @@ -6,11 +6,12 @@ model: dataloaders: train_batch_size: 20 # Larger is faster; no value or removing this line will not include training data valid_batch_size: 20 # Larger is faster; val data is always plotted - num_workers: 32 # Should not exceed available CPU cores + num_workers: 4 # Should not exceed available CPU cores -img_size: 384 # image size used for training +img_size: 256 # image size used for training perplexity: 30 # for tSNE<; cannot be larger than dataset +progress_bar: True # Show progress while computing train/val embeddings plot_style: 2 # (1: pictogram above point, 2: pictogram next to plot panel) point_size: 10 ## size of points in scatter plot diff --git a/bioencoder_configs/swa_stage1.yml b/bioencoder_configs/swa_stage1.yml index c8ee234..24cfc50 100644 --- a/bioencoder_configs/swa_stage1.yml +++ b/bioencoder_configs/swa_stage1.yml @@ -7,8 +7,8 @@ train: stage: first # Training stage: 'first' for SupCon, 'second' for fine-tuning classification dataloaders: - train_batch_size: 40 # Batch size for training data - valid_batch_size: 40 # Batch size for validation data - num_workers: 16 # Number of CPU threads for data loading + train_batch_size: 44 # Batch size for training data + valid_batch_size: 44 # Batch size for validation data + num_workers: 4 # Number of CPU threads for data loading -img_size: 384 # Image size for training and validation +img_size: 256 # Image size for training and validation diff --git a/bioencoder_configs/swa_stage2.yml b/bioencoder_configs/swa_stage2.yml index f3ddc41..2cd2d72 100644 --- a/bioencoder_configs/swa_stage2.yml +++ b/bioencoder_configs/swa_stage2.yml @@ -10,6 +10,6 @@ train: dataloaders: train_batch_size: 40 # Batch size for training data valid_batch_size: 40 # Batch size for validation data - num_workers: 16 # Number of CPU threads for data loading + num_workers: 4 # Number of CPU threads for data loading -img_size: 384 # Image size for training and validation +img_size: 256 # Image size for training and validation diff --git a/bioencoder_configs/train_stage1.yml b/bioencoder_configs/train_stage1.yml index 3260a1e..3d61ab4 100644 --- a/bioencoder_configs/train_stage1.yml +++ b/bioencoder_configs/train_stage1.yml @@ -6,19 +6,32 @@ train: n_epochs: &epochs 100 # Number of training epochs amp: True # Enable Automatic Mixed Precision (AMP) for faster training on compatible GPUs ema: True # Use Exponential Moving Average to stabilize training - ema_decay_per_epoch: 0.4 # EMA decay rate; adjust based on dataset size + progress_bar: True # Show per-batch tqdm progress during each epoch (rank 0 only in DDP) + ema_decay_per_epoch: 0.9 # EMA decay rate; adjust based on dataset size target_metric: precision_at_1 # Metric to optimize during training stage: first # Training stage: 'first' for SupCon, 'second' for fine-tuning classification +distributed: + enabled: True # Enable Distributed Data Parallel training with torchrun + backend: nccl # Distributed backend for single-node multi-GPU + find_unused_parameters: False # Set True only if you have intentionally unused branches + sync_bn: False # Convert BatchNorm to SyncBatchNorm across GPUs + grad_accum_steps: 1 # Gradient accumulation steps per rank + seed: 42 # Random seed used as base for reproducibility + dataloaders: - train_batch_size: 40 # Batch size for training data - valid_batch_size: 40 # Batch size for validation data - num_workers: 16 # Number of CPU threads for data loading + train_batch_size: 44 # Batch size for training data + valid_batch_size: 44 # Batch size for validation data + num_workers: 4 # Number of CPU threads for data loading optimizer: - name: SGD # Optimizer type - see https://github.com/agporto/BioEncoder/blob/main/help/05-options.md#optimizers + name: SGD params: - lr: 0.003 # Learning rate + lr: 0.003 + momentum: 0.9 + weight_decay: 0.0 + foreach: False + fused: False scheduler: name: CosineAnnealingLR # Learning rate scheduler - see https://github.com/agporto/BioEncoder/blob/main/help/05-options.md#schedulers @@ -31,21 +44,17 @@ criterion: params: temperature: 0.1 # Temperature parameter for contrastive loss -img_size: &size 384 # Image size for training and validation +img_size: &size 256 # Image size for training and validation -augmentations: # augmentations to be applied - see https://github.com/agporto/BioEncoder/blob/main/help/05-options.md#augmentations - sample_save: True # Whether to save a sample of augmented images - sample_n: 10 # Number of augmented image samples per class to save - sample_seed: 42 # Seed for random sample - transforms: - - RandomResizedCrop: # Randomly resize and crop the image +augmentations: + sample_save: False + sample_n: 10 + sample_seed: 42 + transforms: + - RandomResizedCrop: height: *size width: *size - scale: !!python/tuple [0.7,1] - - Flip: # Randomly flip the image horizontally - - RandomRotate90: # Randomly rotate the image by 90 degrees - - MedianBlur: # Apply median blur with a probability - - ShiftScaleRotate: # Randomly apply affine transformations - - OpticalDistortion: # Apply optical distortion to the image - - GridDistortion: # Apply grid distortion to the image - - HueSaturationValue: # Randomly change hue, saturation, and value of the image + scale: !!python/tuple [0.8, 1.0] + - Flip: + - ShiftScaleRotate: + - HueSaturationValue: diff --git a/bioencoder_configs/train_stage2.yml b/bioencoder_configs/train_stage2.yml index c4c8984..d493928 100644 --- a/bioencoder_configs/train_stage2.yml +++ b/bioencoder_configs/train_stage2.yml @@ -6,19 +6,32 @@ train: n_epochs: &epochs 30 # Number of training epochs amp: True # Enable Automatic Mixed Precision (AMP) for faster training on compatible GPUs ema: True # Use Exponential Moving Average to stabilize training + progress_bar: True # Show per-batch tqdm progress during each epoch (rank 0 only in DDP) ema_decay_per_epoch: 0.4 # EMA decay rate; adjust based on dataset size target_metric: accuracy # Metric to optimize during training stage: second # Training stage: 'first' for SupCon, 'second' for fine-tuning classification +distributed: + enabled: True # Enable Distributed Data Parallel training with torchrun + backend: nccl # Distributed backend for single-node multi-GPU + find_unused_parameters: False # Set True only if you have intentionally unused branches + sync_bn: False # Convert BatchNorm to SyncBatchNorm across GPUs + grad_accum_steps: 1 # Gradient accumulation steps per rank + seed: 42 # Random seed used as base for reproducibility + dataloaders: train_batch_size: 40 # Batch size for training data valid_batch_size: 40 # Batch size for validation data - num_workers: 16 # Number of CPU threads for data loading + num_workers: 4 # Number of CPU threads for data loading optimizer: name: SGD # Optimizer type - see https://github.com/agporto/BioEncoder/blob/main/help/05-options.md#optimizers params: - lr: 0.3 # Learning rate + lr: 0.003 # Learning rate + momentum: 0.9 + weight_decay: 0.0 + foreach: False + fused: False scheduler: name: CosineAnnealingLR # Learning rate scheduler - see https://github.com/agporto/BioEncoder/blob/main/help/05-options.md#schedulers @@ -29,25 +42,21 @@ scheduler: criterion: name: 'LabelSmoothing' # Loss function - see https://github.com/agporto/BioEncoder/blob/main/help/05-options.md#losses params: - classes: 100 # Number of classes (adjust based on actual number of classes) + classes: 4 # Number of classes (adjust based on actual number of classes) smoothing: 0.01 # Smoothing factor for label smoothing -img_size: &size 384 # Image size for training and validation +img_size: &size 256 # Image size for training and validation -augmentations: # augmentations to be applied - see https://github.com/agporto/BioEncoder/blob/main/help/05-options.md#augmentations - sample_save: True # Whether to save a sample of augmented images - sample_n: 10 # Number of augmented image samples per class to save - sample_seed: 42 # Seed for random sample +augmentations: + sample_save: False + sample_n: 10 + sample_seed: 42 transforms: - - RandomResizedCrop: # Randomly resize and crop the image + - RandomResizedCrop: height: *size width: *size - scale: !!python/tuple [0.7,1] - - Flip: # Randomly flip the image horizontally - - RandomRotate90: # Randomly rotate the image by 90 degrees - - MedianBlur: # Apply median blur with a probability - - ShiftScaleRotate: # Randomly apply affine transformations - - OpticalDistortion: # Apply optical distortion to the image - - GridDistortion: # Apply grid distortion to the image - - HueSaturationValue: # Randomly change hue, saturation, and value of the image + scale: !!python/tuple [0.8, 1.0] + - Flip: + - ShiftScaleRotate: + - HueSaturationValue: diff --git a/help/03-training.md b/help/03-training.md index 403a75d..68a8df6 100644 --- a/help/03-training.md +++ b/help/03-training.md @@ -46,3 +46,24 @@ To train stage 2 and do SWA, run the following command: bioencoder.train(config_path=r"bioencoder_configs/train_stage2.yml", overwrite=True) bioencoder.swa(config_path=r"bioencoder_configs/swa_stage2.yml") ``` + +# Single-node multi-GPU (DDP) + +BioEncoder supports single-node multi-GPU training via PyTorch DDP (`torchrun`). + +Example with 8 GPUs: + +```bash +torchrun --standalone --nproc_per_node=8 -m bioencoder.scripts.train \ + --config-path bioencoder_configs/train_stage2.yml \ + --distributed --backend nccl +``` + +You can also enable distributed mode directly in the YAML using: + +- `distributed.enabled` +- `distributed.backend` +- `distributed.find_unused_parameters` +- `distributed.sync_bn` +- `distributed.grad_accum_steps` +- `distributed.seed` diff --git a/pyproject.toml b/pyproject.toml index ec5035b..2f63b0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ readme = "README.md" requires-python = ">=3.9" keywords = ["metric learning", "biology"] dynamic = ["dependencies"] -version = "1.0.5" +version = "1.0.6" [project.urls] "Homepage" = "https://github.com/agporto/BioEncoder" diff --git a/requirements.txt b/requirements.txt index 616a395..4566f2a 100755 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ pandas pytorch-metric-learning==2.0.1 rich scikit-learn +streamlit streamlit-option-menu tensorboard timm diff --git a/run_train_ddp_2gpu.sh b/run_train_ddp_2gpu.sh new file mode 100755 index 0000000..9c1a953 --- /dev/null +++ b/run_train_ddp_2gpu.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +ENV_NAME="${BIOENCODER_CONDA_ENV:-bioencoder_dev}" +NPROC_PER_NODE="${NPROC_PER_NODE:-2}" +BACKEND="${BACKEND:-nccl}" +CONFIG_PATH="${1:-bioencoder_configs/train_stage1.yml}" + +if [[ $# -gt 0 ]]; then + shift +fi + +if [[ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]]; then + # Typical local Miniconda install. + source "$HOME/miniconda3/etc/profile.d/conda.sh" +elif command -v conda >/dev/null 2>&1; then + # Fallback to whatever conda is on PATH. + source "$(conda info --base)/etc/profile.d/conda.sh" +else + echo "Could not find conda initialization script." >&2 + exit 1 +fi + +conda activate "$ENV_NAME" + +# Workaround for local NCCL transport issues seen on this machine. +export NCCL_P2P_DISABLE="${NCCL_P2P_DISABLE:-1}" +export NCCL_SHM_DISABLE="${NCCL_SHM_DISABLE:-1}" +export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}" + +exec torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node="$NPROC_PER_NODE" \ + -m bioencoder.scripts.train \ + --config-path "$CONFIG_PATH" \ + --distributed \ + --backend "$BACKEND" \ + "$@"