diff --git a/dinov2/layers/dino_head.py b/dinov2/layers/dino_head.py index 0ace8ff..fb66072 100644 --- a/dinov2/layers/dino_head.py +++ b/dinov2/layers/dino_head.py @@ -19,13 +19,18 @@ def __init__( hidden_dim=2048, bottleneck_dim=256, mlp_bias=True, + normalize=True, + remove_last_layer=False ): super().__init__() nlayers = max(nlayers, 1) self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) self.apply(self._init_weights) - self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) - self.last_layer.weight_g.data.fill_(1) + if not remove_last_layer: + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + self.normalize = normalize + self.remove_last_layer = remove_last_layer def _init_weights(self, m): if isinstance(m, nn.Linear): @@ -36,8 +41,10 @@ def _init_weights(self, m): def forward(self, x): x = self.mlp(x) eps = 1e-6 if x.dtype == torch.float16 else 1e-12 - x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) - x = self.last_layer(x) + if self.normalize: + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + if not self.remove_last_layer: + x = self.last_layer(x) return x diff --git a/dinov2/loss/__init__.py b/dinov2/loss/__init__.py index c9da02a..2bb01fa 100644 --- a/dinov2/loss/__init__.py +++ b/dinov2/loss/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -from .dino_clstoken_loss import DINOLoss -from .ibot_patch_loss import iBOTPatchLoss +from .dino_clstoken_loss import DINOLoss, MCRLoss +from .ibot_patch_loss import iBOTPatchLoss, CosinePatchLoss from .koleo_loss import KoLeoLoss from .kde_loss import KDELoss diff --git a/dinov2/loss/dino_clstoken_loss.py b/dinov2/loss/dino_clstoken_loss.py index c31808e..c979596 100644 --- a/dinov2/loss/dino_clstoken_loss.py +++ b/dinov2/loss/dino_clstoken_loss.py @@ -2,24 +2,23 @@ # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. - import torch import torch.distributed as dist import torch.nn.functional as F from torch import nn - -class DINOLoss(nn.Module): +class DINOCenter(nn.Module): def __init__( self, out_dim, - student_temp=0.1, + enable=True, center_momentum=0.9, ): super().__init__() - self.student_temp = student_temp - self.center_momentum = center_momentum + if not enable: + return self.register_buffer("center", torch.zeros(1, out_dim)) + self.center_momentum = center_momentum self.updated = True self.reduce_handle = None self.len_teacher_output = None @@ -44,35 +43,17 @@ def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): if dist.is_initialized(): dist.all_reduce(sum_Q) Q /= sum_Q - for it in range(n_iterations): - # normalize each row: total weight per prototype must be 1/K sum_of_rows = torch.sum(Q, dim=1, keepdim=True) if dist.is_initialized(): dist.all_reduce(sum_of_rows) Q /= sum_of_rows Q /= K - - # normalize each column: total weight per sample must be 1/B Q /= torch.sum(Q, dim=0, keepdim=True) Q /= B - - Q *= B # the columns must sum to 1 so that Q is an assignment + Q *= B return Q.t() - def forward(self, student_output_list, teacher_out_softmaxed_centered_list): - """ - Cross-entropy between softmax outputs of the teacher and student networks. 
- """ - # TODO: Use cross_entropy_distribution here - total_loss = 0 - for s in student_output_list: - lsm = F.log_softmax(s / self.student_temp, dim=-1) - for t in teacher_out_softmaxed_centered_list: - loss = torch.sum(t * lsm, dim=-1) - total_loss -= loss.mean() - return total_loss - @torch.no_grad() def update_center(self, teacher_output): self.reduce_center_update(teacher_output) @@ -89,7 +70,6 @@ def reduce_center_update(self, teacher_output): def apply_center_update(self): if self.updated is False: world_size = dist.get_world_size() if dist.is_initialized() else 1 - if self.reduce_handle is not None: self.reduce_handle.wait() _t = self.async_batch_center / (self.len_teacher_output * world_size) @@ -97,3 +77,109 @@ def apply_center_update(self): self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) self.updated = True + +def half_logdet(X): + # Ensure Matrix is Float32 for Cholesky Stability + return torch.linalg.cholesky_ex(X.float())[0].diagonal().log().sum() + +class MCRLoss(DINOCenter): + def __init__(self, out_dim, expa_type=0, reduce_cov=0, eps=0.05, coeff=1, center=False, *args, **kwargs): + super().__init__(out_dim, enable=center) + self.eps = eps + self.coeff = coeff + self.expa_type = expa_type + self.reduce_cov = reduce_cov + + def forward(self, student_feat_list, teacher_feat_list, no_diag=True, normalized=True): + """ + Expansion Loss and Compression Loss between features of the teacher and student networks. + """ + # Convert lists of tensors to a single tensor for vectorized operations + student_feat = torch.stack(student_feat_list) #ncrops,N,D + teacher_feat = torch.stack(teacher_feat_list) #2,N,D + if not normalized: + student_feat = F.normalize(student_feat, p=2, dim=-1) + teacher_feat = F.normalize(teacher_feat, p=2, dim=-1) + comp_loss, global_comp_loss = self.calc_compression(student_feat, teacher_feat, no_diag=no_diag) + match self.expa_type: + case 0: # only compute expansion on global views + expa_feat = student_feat[:len(teacher_feat)] + case 1: # center with teacher + expa_feat = (student_feat[:len(teacher_feat)] + teacher_feat) / 2 + expa_loss = self.calc_expansion(expa_feat) + loss = - self.coeff * comp_loss - expa_loss + return loss, {"loss": loss.detach(), "comp_loss":comp_loss.detach(), "global_comp_loss":global_comp_loss.detach(), "expa_loss":expa_loss.detach()} + + def calc_compression(self, student_feat_list, teacher_feat_list, no_diag=True): + """ + Compute compression loss between student and teacher features. + """ + + # Compute cosine similarity for all pairs + comp_loss = 0 + sim = (teacher_feat_list.unsqueeze(1)*student_feat_list.unsqueeze(0)).sum(-1).mean(-1) + # Mask out the diagonal elements where student and teacher operate on the same view + #mask = torch.eye(len(teacher_feat_list), len(student_feat_list), dtype=torch.bool,device=cosine_sim.device).unsqueeze_(2) + #sim = cosine_sim.masked_fill(mask, 0) + if no_diag: + sim.view(-1)[:: (len(student_feat_list) + 1)].fill_(0) # Trick to fill diagonal + + n_loss_terms = len(teacher_feat_list)* len(student_feat_list) - min(len(teacher_feat_list), len(student_feat_list)) + # Sum the cosine similarities + comp_loss = sim.sum()/n_loss_terms + global_comp_loss = sim[:, :len(teacher_feat_list)].detach().sum().div_(len(teacher_feat_list)) + return comp_loss, global_comp_loss + + def calc_expansion(self, feat_list, cross_list=None) -> torch.Tensor: + """ + Compute expansion loss using Coding Rate estimation. 
+ """ + cov = [] + num_views = len(feat_list) + m, p = feat_list[0].shape + cov = torch.einsum('nbc,nbd->ncd', feat_list, cross_list or feat_list) + N=1 + if dist.is_initialized(): + N = dist.get_world_size() + if self.reduce_cov == 1: + cov = dist.nn.all_reduce(cov) + loss = 0 + scalar = p / (m * N * self.eps) + I = torch.eye(p, device=cov[0].device) + loss = sum([half_logdet(I + scalar * cov[i]) for i in range(num_views)]) + #loss = torch.logdet(I + scalar * cov).sum()/2 + loss /= num_views + loss *= (p+N*m)/(p*N*m) # the balancing factor gamma, you can also use the next line. This is ultimately a heuristic, so feel free to experiment. + # loss *= ((self.eps * N * m) ** 0.5 / p) + return loss + +class DINOLoss(DINOCenter): + def __init__( + self, + out_dim, + student_temp=0.1, + ): + super().__init__(out_dim) + self.student_temp = student_temp + self.reduce_handle = None + self.len_teacher_output = None + self.async_batch_center = None + + def forward(self, student_output_list, teacher_out_softmaxed_centered_list, no_diag=False): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + # TODO: Use cross_entropy_distribution here + total_loss = 0 + global_loss = 0 + for i, s in enumerate(student_output_list): + lsm = F.log_softmax(s / self.student_temp, dim=-1) + for j, t in enumerate(teacher_out_softmaxed_centered_list): + if no_diag and i == j: + continue + loss = torch.sum(t * lsm, dim=-1).mean() + total_loss -= loss + if no_diag and i < len(teacher_out_softmaxed_centered_list): + global_loss -= loss.detach() + return total_loss, global_loss + \ No newline at end of file diff --git a/dinov2/loss/fused_ce_loss.py b/dinov2/loss/fused_ce_loss.py new file mode 100644 index 0000000..12bc85e --- /dev/null +++ b/dinov2/loss/fused_ce_loss.py @@ -0,0 +1,192 @@ +"""This is a fused cross-entropy and linear layer. 
Idea is copied +from https://github.com/linkedin/Liger-Kernel who just copied it from +https://github.com/mgmalek/efficient_cross_entropy +""" + +import torch +from torch.autograd import Function +from torch.nn import functional as F + + +class EfficientCrossEntropy(Function): + @staticmethod + def forward(ctx, input: torch.Tensor, target: torch.Tensor, chunksize = 2048, + reduction: str = "mean", label_smoothing: float = 0.0, inplace_backward: bool = True): + if label_smoothing > 0.0: + raise NotImplementedError("Label smoothing is not implemented yet.") + bs = input.shape[0] + needs_grad = ctx.needs_input_grad[0] + if needs_grad: + act_grad = torch.empty_like(input) + if reduction == "none": + out_loss = torch.empty(bs, device=input.device) + else: + out_loss = torch.tensor(0.0, device=input.device) + is_label = len(target.shape) == 1 + for b in range(0, bs, chunksize): + end_idx = min(b + chunksize, bs) + + # Get current batch chunks + logits_chunk = input[b:end_idx] # [chunk_size, V] + target_chunk = target[b:end_idx] # [chunk_size] if is_label else [chunk_size, V] + + # Compute softmax and loss + max_logits = torch.max(logits_chunk, dim=-1, keepdim=True)[0] + logits_chunk -= max_logits + exp_logits = torch.exp(logits_chunk) + sum_exp = torch.sum(exp_logits, dim=-1, keepdim=True) + LSE_minus_max = torch.log(sum_exp) + + if is_label: + # Compute loss using gather + correct_logits = torch.gather( + logits_chunk, 1, target_chunk.unsqueeze(1) + ) # [chunk_size, 1] + if reduction == "none": + out_loss[b:end_idx] = LSE_minus_max.squeeze() + correct_logits.squeeze() + else: + out_loss += torch.sum(LSE_minus_max.squeeze() - correct_logits.squeeze()) + else: #target is probs + #out_loss -= torch.sum(target_chunk * torch.log(probs)) + #torch.log(probs)=logits_chunk-LSE + if reduction == "none": + out_loss[b:end_idx] = torch.sum(target_chunk * (LSE_minus_max-logits_chunk), dim=-1) + else: + out_loss += torch.sum(target_chunk * (LSE_minus_max-logits_chunk)) + + # Compute gradients + if needs_grad: + probs = exp_logits / sum_exp # [chunk_size, V] + if is_label: + grad = probs.clone() # [chunk_size, V] + grad.scatter_( + 1, + target_chunk.unsqueeze(1), + grad.gather(1, target_chunk.unsqueeze(1)) - 1, + ) + else: #target is probs, grad to input is: + grad = - target_chunk + torch.sum(target_chunk, dim=-1, keepdim=True) * probs + + # Accumulate gradients + act_grad[b:end_idx] = grad # [chunk_size, V] + + # Scale + if reduction == "mean": + scale = 1.0 / bs + else: + scale = 1.0 + if needs_grad: + act_grad *= scale + ctx.save_for_backward(act_grad) + ctx.inplace_backward = inplace_backward + return scale * out_loss + + @staticmethod + def backward(ctx, grad_output:torch.Tensor): # type: ignore + (act_grad,) = ctx.saved_tensors + #make sure grad_output have same dim as act_grad, or unsqueeze it + if grad_output.dim() == 1: + return act_grad.mul_(grad_output.unsqueeze(-1)), None, None, None, None, None + return act_grad.mul_(grad_output), None, None, None, None, None + + +class EfficientCrossEntropyFused(Function): + @staticmethod + def forward(ctx, weight: torch.Tensor, act: torch.Tensor, labels: torch.Tensor): + bs = act.shape[0] + weight_grad = torch.zeros_like(weight) + act_grad = torch.empty_like(act) + out_loss = torch.tensor(0.0, device=act.device) + chunksize = 2048 + + for b in range(0, bs, chunksize): + end_idx = min(b + chunksize, bs) + + # Get current batch chunks + act_chunk = act[b:end_idx] # [chunk_size, H] + labels_chunk = labels[b:end_idx] # [chunk_size] + + # Compute logits + logits = 
F.linear(act_chunk, weight) # [chunk_size, V] + + # Compute softmax and loss + max_logits = torch.max(logits, dim=-1, keepdim=True)[0] + exp_logits = torch.exp(logits - max_logits) + sum_exp = torch.sum(exp_logits, dim=-1, keepdim=True) + probs = exp_logits / sum_exp # [chunk_size, V] + + # Compute loss using gather + correct_logits = torch.gather( + logits, 1, labels_chunk.unsqueeze(1) + ) # [chunk_size, 1] + out_loss += torch.sum( + max_logits.squeeze() + + torch.log(sum_exp.squeeze()) + - correct_logits.squeeze() + ) + + # Compute gradients + dprobs = probs.clone() # [chunk_size, V] + dprobs.scatter_( + 1, + labels_chunk.unsqueeze(1), + dprobs.gather(1, labels_chunk.unsqueeze(1)) - 1, + ) + + # Accumulate gradients + weight_grad += dprobs.T @ act_chunk # [H, V] + act_grad[b:end_idx] = dprobs @ weight # [chunk_size, H] + + # Scale gradients + scale = 1.0 / bs + weight_grad *= scale + act_grad *= scale + + ctx.save_for_backward(weight_grad, act_grad) + return scale * out_loss + + @staticmethod + def backward(ctx, grad_output): # type: ignore + ( + weight_grad, + act_grad, + ) = ctx.saved_tensors + return grad_output * weight_grad, grad_output * act_grad, None + + +# torch.compile does a good enough job with the kernel here + +def cross_entropy(input, target, chunksize = 2048, + reduction: str = "mean", label_smoothing: float = 0.0, inplace_backward: bool = True): + return EfficientCrossEntropy.apply(input, target, chunksize,reduction,label_smoothing,inplace_backward) + +def fused_cross_entropy(lm_head_weight, act, labels): + return EfficientCrossEntropyFused.apply(lm_head_weight, act, labels) + +if __name__ == "__main__": + # Test if the forward pass is correct + ### + torch.manual_seed(0) + logits = torch.randn(4, 3, requires_grad=True) + labels = torch.tensor([0, 1, 2, 1]) + print("Logits:", logits, "Labels:", labels, "exprected gradient:", cross_entropy(logits.detach(), labels)) + loss = F.cross_entropy(logits, labels) + loss.backward() + print("Loss:", loss.item(), "Grad:", logits.grad) + logits.grad.zero_() + loss = cross_entropy(logits, labels) + loss.backward() + print("Loss:", loss.item(), "Grad:", logits.grad) + + print("######") + logits = torch.randn(4, 2, 3, requires_grad=True) + labels = torch.randn(4, 2, 3) + logits.grad.zero_() + loss = -torch.sum(labels * F.log_softmax(logits, dim=-1), dim=-1) + loss.mean().backward() + print("Loss:", loss, "Grad:", logits.grad) + logits.grad.zero_() + loss = cross_entropy(logits, labels, reduction="none") + loss.mean().backward() + print("Loss:", loss, "Grad:", logits.grad) + \ No newline at end of file diff --git a/dinov2/loss/ibot_patch_loss.py b/dinov2/loss/ibot_patch_loss.py index 6732cda..3330e2b 100644 --- a/dinov2/loss/ibot_patch_loss.py +++ b/dinov2/loss/ibot_patch_loss.py @@ -15,28 +15,38 @@ try: - from xformers.ops import cross_entropy - - def lossfunc(t, s, temp): + from xformers.ops.common import cross_entropy + #from flash_attn.ops.triton.cross_entropy import cross_entropy_loss + def lossfunc(s:torch.Tensor, t:torch.Tensor, temp): s = s.float() t = t.float() if s.ndim == 2: return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0) elif s.ndim == 3: return -cross_entropy(s, t, temp, bw_inplace=True) - except ImportError: - - def lossfunc(t, s, temp): - return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) - - -class iBOTPatchLoss(nn.Module): - def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9): + try: + from .fused_ce_loss import cross_entropy + def lossfunc(s, t, 
temp): + return -cross_entropy(s/temp, t, reduction="none", chunksize=256) + except ImportError: + def lossfunc(s, t, temp): + return F.cross_entropy(s/temp, t, reduction="none") + #return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) + + +class PatchDINOCenter(nn.Module): + def __init__( + self, + patch_out_dim, + enable=True, + center_momentum=0.9, + ): super().__init__() - self.student_temp = student_temp - self.center_momentum = center_momentum + if not enable: + return self.register_buffer("center", torch.zeros(1, 1, patch_out_dim)) + self.center_momentum = center_momentum self.updated = True self.reduce_handle = None self.len_teacher_patch_tokens = None @@ -89,6 +99,48 @@ def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_ Q *= B # the columns must sum to 1 so that Q is an assignment return Q.t() +class CosinePatchLoss(PatchDINOCenter): + def __init__(self, patch_out_dim, center=False, center_momentum=0.9, **kwargs): + super().__init__(patch_out_dim, center, center_momentum) + + def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + student_patch_tokens: (B, N, D) tensor + teacher_patch_tokens: (B, N, D) tensor + student_masks_flat: (B, N) tensor + """ + loss = F.cosine_similarity(teacher_patch_tokens, student_patch_tokens, dim=-1) + loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0) + comp_loss= -loss.mean() + return comp_loss, {"comp_loss": comp_loss.detach()} + + def forward_masked( + self, + student_patch_tokens_masked, + teacher_patch_tokens_masked, + student_masks_flat, + n_masked_patches=None, + masks_weight=None, + ): + loss = F.cosine_similarity(teacher_patch_tokens_masked, student_patch_tokens_masked, dim=-1) + if masks_weight is None: + masks_weight = ( + (1 / student_masks_flat.sum(-1).clamp(min=1.0)) + .unsqueeze(-1) + .expand_as(student_masks_flat)[student_masks_flat] + ) + if n_masked_patches is not None: + loss = loss[:n_masked_patches] + loss = loss * masks_weight + comp_loss = -loss.sum() / student_masks_flat.shape[0] + return comp_loss, {"comp_loss": comp_loss.detach()} + +class iBOTPatchLoss(PatchDINOCenter): + def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9, **kwargs): + super().__init__(patch_out_dim, center_momentum) + self.student_temp = student_temp + def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): """ Cross-entropy between softmax outputs of the teacher and student networks. 
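# A standalone sketch (not part of the patch) of the masked cosine objective that
# CosinePatchLoss.forward implements above. Shapes assumed: (B, N, D) patch tokens
# and a boolean (B, N) mask; the helper name is only for illustration.
import torch
import torch.nn.functional as F

def masked_cosine_loss(student, teacher, mask):
    sim = F.cosine_similarity(teacher, student, dim=-1)                      # (B, N)
    per_image = (sim * mask.float()).sum(-1) / mask.float().sum(-1).clamp(min=1.0)
    return -per_image.mean()

if __name__ == "__main__":
    B, N, D = 2, 16, 8
    teacher = torch.randn(B, N, D)
    mask = torch.zeros(B, N, dtype=torch.bool)
    mask[:, ::2] = True
    print(masked_cosine_loss(teacher.clone(), teacher, mask))        # -1.0: student aligned with teacher
    print(masked_cosine_loss(torch.randn(B, N, D), teacher, mask))   # close to 0 for a random student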
@@ -113,7 +165,7 @@ def forward_masked( t = teacher_patch_tokens_masked s = student_patch_tokens_masked # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) - loss = lossfunc(t, s, self.student_temp) + loss = lossfunc(s, t, self.student_temp) if masks_weight is None: masks_weight = ( (1 / student_masks_flat.sum(-1).clamp(min=1.0)) diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py index db95104..e94f1d8 100644 --- a/dinov2/train/ssl_meta_arch.py +++ b/dinov2/train/ssl_meta_arch.py @@ -9,7 +9,7 @@ import torch from torch import nn -from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss, KDELoss +from dinov2.loss import MCRLoss, DINOLoss, CosinePatchLoss, iBOTPatchLoss, KoLeoLoss from dinov2.models import build_model_from_cfg from dinov2.layers import DINOHead from dinov2.utils.utils import has_batchnorms @@ -18,7 +18,7 @@ from dinov2.models.vision_transformer import BlockChunk - +XFORMERS_AVAILABLE = True try: from xformers.ops import fmha except ImportError: @@ -41,27 +41,33 @@ def __init__(self, cfg): student_model_dict["backbone"] = student_backbone teacher_model_dict["backbone"] = teacher_backbone logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") - - if cfg.student.pretrained_weights: - chkpt = torch.load(cfg.student.pretrained_weights) - logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}") - student_backbone.load_state_dict(chkpt["model"], strict=False) - self.embed_dim = embed_dim - self.dino_out_dim = cfg.dino.head_n_prototypes self.do_dino = cfg.dino.loss_weight > 0 self.do_koleo = cfg.dino.koleo_loss_weight > 0 - self.do_kde = cfg.dino.kde_loss_weight > 0 self.do_ibot = cfg.ibot.loss_weight > 0 self.ibot_separate_head = cfg.ibot.separate_head + + self.dino_use_mcr = cfg.dino.use_mcr + self.ibot_use_mcr = cfg.ibot.use_mcr + self.drop_masks = cfg.student.drop_masks + n_global_crops = 2 + #assert n_global_crops == 2 + n_local_crops = self.cfg.crops.local_crops_number + #ncrops = n_global_crops + n_local_crops + self.n_global_crops =n_global_crops + self.n_local_crops = n_local_crops + self.n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops + self.n_total_crops_loss_terms = n_local_crops * n_global_crops + self.n_global_crops_loss_terms - logger.info("OPTIONS -- DINO") if self.do_dino: + logger.info("OPTIONS -- DINO") logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}") + head_normalize = getattr(cfg.dino, "head_normalize", True) + remove_last_layer = getattr(cfg.dino, "remove_last_layer", True) self.dino_loss_weight = cfg.dino.loss_weight dino_head = partial( DINOHead, @@ -70,35 +76,35 @@ def __init__(self, cfg): hidden_dim=cfg.dino.head_hidden_dim, bottleneck_dim=cfg.dino.head_bottleneck_dim, nlayers=cfg.dino.head_nlayers, + normalize=cfg.dino.head_normalize, + remove_last_layer=cfg.dino.remove_last_layer ) - self.dino_loss = DINOLoss(self.dino_out_dim) + dino_out_dim = cfg.dino.head_bottleneck_dim if cfg.dino.remove_last_layer else cfg.dino.head_n_prototypes + self.dino_loss = MCRLoss(dino_out_dim, **cfg.dino.mcr) if self.dino_use_mcr else DINOLoss(dino_out_dim) if self.do_koleo: logger.info("OPTIONS -- DINO -- applying KOLEO regularization") self.koleo_loss = KoLeoLoss() - if self.do_kde: - 
logger.info("OPTIONS -- DINO -- apply KDE regularization") - self.kde_loss = KDELoss() - else: logger.info("OPTIONS -- DINO -- not using DINO") - - if self.do_dino or self.do_ibot: + + if self.do_dino or (self.do_ibot and not self.ibot_separate_head): student_model_dict["dino_head"] = dino_head() teacher_model_dict["dino_head"] = dino_head() - - logger.info("OPTIONS -- IBOT") - logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") - logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") - logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") if self.do_ibot: + logger.info("OPTIONS -- IBOT") + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") self.ibot_loss_weight = cfg.ibot.loss_weight assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot" assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot" - self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes - self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim) + ibot_out_dim = (cfg.ibot.head_bottleneck_dim if cfg.dino.remove_last_layer else cfg.ibot.head_n_prototypes) if self.ibot_separate_head else dino_out_dim + self.ibot_patch_loss = CosinePatchLoss(ibot_out_dim, **cfg.ibot.mcr) if self.ibot_use_mcr else iBOTPatchLoss(ibot_out_dim) if self.ibot_separate_head: - logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") - logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") + if cfg.ibot.remove_last_layer: + logger.info("OPTIONS -- IBOT -- remove last layer") + else: + logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") ibot_head = partial( @@ -108,6 +114,8 @@ def __init__(self, cfg): hidden_dim=cfg.ibot.head_hidden_dim, bottleneck_dim=cfg.ibot.head_bottleneck_dim, nlayers=cfg.ibot.head_nlayers, + normalize=cfg.ibot.head_normalize, + remove_last_layer=cfg.ibot.remove_last_layer ) student_model_dict["ibot_head"] = ibot_head() teacher_model_dict["ibot_head"] = ibot_head() @@ -119,10 +127,17 @@ def __init__(self, cfg): self.student = nn.ModuleDict(student_model_dict) self.teacher = nn.ModuleDict(teacher_model_dict) + if cfg.compile: + self.teacher.compile() + self.student.compile() + getattr(self, "dino_loss", None) and self.dino_loss.compile() + getattr(self, "koleo_loss", None) and self.koleo_loss.compile() + getattr(self, "ibot_patch_loss", None) and self.ibot_patch_loss.compile() # there is no backpropagation through the teacher, so no need for gradients for p in self.teacher.parameters(): p.requires_grad = False - logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.") + + logger.info(f"Student and Teacher are built: {cfg.student.arch}") def forward(self, inputs): raise NotImplementedError @@ -133,77 +148,55 @@ def backprop_loss(self, loss): else: loss.backward() - def forward_backward(self, images, teacher_temp): - n_global_crops = 2 - assert n_global_crops == 2 - n_local_crops = 
self.cfg.crops.local_crops_number - + def forward_backward(self, images, teacher_temp, activate_ibot=True): + n_global_crops = self.n_global_crops global_crops = images["collated_global_crops"].cuda(non_blocking=True) local_crops = images["collated_local_crops"].cuda(non_blocking=True) masks = images["collated_masks"].cuda(non_blocking=True) mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) - n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) n_masked_patches = mask_indices_list.shape[0] - upperbound = images["upperbound"] masks_weight = images["masks_weight"].cuda(non_blocking=True) - - n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) - n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops - - do_dino = self.do_dino + do_ibot = self.do_ibot - - # loss scales - ibot_loss_scale = 1.0 / n_global_crops - - # teacher output @torch.no_grad() def get_teacher_output(): - x, n_global_crops_teacher = global_crops, n_global_crops - teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True) + teacher_backbone_output_dict = self.teacher.backbone(global_crops, is_training=True) teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"] - teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher) - # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss - teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0])) - ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] - _dim = ibot_teacher_patch_tokens.shape[-1] + teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] + _dim = teacher_patch_tokens.shape[-1] n_cls_tokens = teacher_cls_tokens.shape[0] if do_ibot and not self.ibot_separate_head: - buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim) + buffer_tensor_teacher = teacher_patch_tokens.new_zeros(n_masked_patches + n_cls_tokens, _dim) buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens) torch.index_select( - ibot_teacher_patch_tokens.flatten(0, 1), + teacher_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list, out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches], ) tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher) - teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens] - masked_teacher_patch_tokens_after_head = tokens_after_head[ - n_cls_tokens : n_cls_tokens + n_masked_patches - ] + teacher_cls_tokens_after_head, masked_teacher_patch_tokens_after_head = tokens_after_head.split([n_cls_tokens, n_masked_patches]) elif do_ibot and self.ibot_separate_head: - buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim) + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + buffer_tensor_teacher = teacher_patch_tokens.new_zeros(n_masked_patches, _dim) torch.index_select( - ibot_teacher_patch_tokens.flatten(0, 1), + teacher_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list, - out=buffer_tensor_teacher[:n_masked_patches], + out=buffer_tensor_teacher, ) - teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) - masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[ - :n_masked_patches - ] + masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher) else: teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) - 
masked_teacher_ibot_softmaxed_centered = None + masked_teacher_patch_tokens_after_head = None + masked_teacher_ibot_softmaxed_centered = None if self.cfg.train.centering == "centering": teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher( teacher_cls_tokens_after_head, teacher_temp=teacher_temp - ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + ).view(n_global_crops, -1, *teacher_cls_tokens_after_head.shape[1:]) self.dino_loss.update_center(teacher_cls_tokens_after_head) if do_ibot: masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0) @@ -212,11 +205,10 @@ def get_teacher_output(): ) masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0) self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches]) - elif self.cfg.train.centering == "sinkhorn_knopp": teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher( teacher_cls_tokens_after_head, teacher_temp=teacher_temp - ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + ).view(n_global_crops, -1, *teacher_cls_tokens_after_head.shape[1:]) if do_ibot: masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( @@ -224,146 +216,125 @@ def get_teacher_output(): teacher_temp=teacher_temp, n_masked_patches_tensor=n_masked_patches_tensor, ) - else: - raise NotImplementedError - - return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered + teacher_dino_softmaxed_centered_list = teacher_cls_tokens_after_head + masked_teacher_ibot_softmaxed_centered = masked_teacher_patch_tokens_after_head + return teacher_cls_tokens, teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered - teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() + teacher_cls_tokens, teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() reshard_fsdp_model(self.teacher) loss_dict = {} + loss_accumulator = 0 - loss_accumulator = 0 # for backprop - student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone( - [global_crops, local_crops], masks=[masks, None], is_training=True - ) - - inputs_for_student_head_list = [] + # --- STUDENT FORWARD --- + student_global_backbone_output_dict = self.student.backbone(global_crops, masks=masks, is_training=True) + student_local_backbone_output_dict = self.student.backbone(local_crops, is_training=True) + inputs_for_student_head = [] + # 1a: local crops cls tokens # 1a: local crops cls tokens student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"] - inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0)) + inputs_for_student_head.append(student_local_cls_tokens) # 1b: global crops cls tokens student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"] - inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0)) + inputs_for_student_head.append(student_global_cls_tokens) # 1c: global crops patch tokens if do_ibot: - _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1] ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"] - buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim) - buffer_tensor_patch_tokens[:n_masked_patches].copy_( - torch.index_select(ibot_student_patch_tokens.flatten(0, 1), 
dim=0, index=mask_indices_list) - ) + buffer_tensor_patch_tokens=torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list) if not self.ibot_separate_head: - inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0)) + inputs_for_student_head.append(buffer_tensor_patch_tokens) else: - student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[ - :n_masked_patches - ] - - # 2: run - _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list) - outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs)) - - # 3a: local crops cls tokens - student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) - - # 3b: global crops cls tokens - student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens) + del student_global_backbone_output_dict, student_local_backbone_output_dict + + if self.do_dino and self.do_koleo: + koleo_loss = self.cfg.dino.koleo_loss_weight * sum( + self.koleo_loss(p) for p in student_global_cls_tokens.chunk(2) + ) # we don't apply koleo loss between cls tokens of a same image + loss_accumulator += koleo_loss + loss_dict["koleo_loss"] = ( + koleo_loss #/ n_global_crops + ) # this is to display the same losses as before but we can remove eventually + + + if XFORMERS_AVAILABLE: + _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list([x.unsqueeze(0) for x in inputs_for_student_head]) + outputs_list = [x.squeeze(0) for x in _attn_bias.split(self.student.dino_head(cat_inputs))] + del _attn_bias, cat_inputs + else: + seqs = [x.shape[0] for x in inputs_for_student_head] + inputs_for_student_head = torch.cat(inputs_for_student_head) + outputs_list = self.student.dino_head(inputs_for_student_head).split(seqs) + del inputs_for_student_head - # 3c: global crops patch tokens if do_ibot and not self.ibot_separate_head: - student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches] - - if n_local_crops > 0: - dino_local_crops_loss = self.dino_loss( - student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops), - teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list, - ) / (n_global_crops_loss_terms + n_local_crops_loss_terms) - - # store for display - loss_dict["dino_local_crops_loss"] = dino_local_crops_loss - - # accumulate loss - loss_accumulator += self.dino_loss_weight * dino_local_crops_loss - - # process global crops - loss_scales = 2 # this is here since we process global crops together + student_local_cls_tokens_after_head, student_global_cls_tokens_after_head, student_global_masked_patch_tokens_after_head = outputs_list + else: + student_local_cls_tokens_after_head, student_global_cls_tokens_after_head = outputs_list + del outputs_list - if do_dino: + # --- LOSS CALCULATION --- + if self.do_dino: # compute loss - dino_global_crops_loss = ( - self.dino_loss( - student_output_list=[student_global_cls_tokens_after_head], - teacher_out_softmaxed_centered_list=[ - teacher_dino_softmaxed_centered_list.flatten(0, 1) - ], # these were chunked and stacked in reverse so A is matched to B - ) - * loss_scales - / (n_global_crops_loss_terms + n_local_crops_loss_terms) + dino_crops_loss, dino_loss_dict = self.dino_loss( + student_global_cls_tokens_after_head.chunk(2) + student_local_cls_tokens_after_head.chunk(self.n_local_crops), + 
teacher_dino_softmaxed_centered_list.chunk(2), no_diag=True ) - - loss_dict["dino_global_crops_loss"] = dino_global_crops_loss - + if self.dino_use_mcr: + dino_loss_dict = {"dino_mcr_"+k: v for k, v in dino_loss_dict.items()} + loss_dict |= dino_loss_dict + else: + #dino loss averaged over the number of crops + dino_crops_loss /= self.n_total_crops_loss_terms + loss_dict["dino_loss"] = dino_crops_loss + loss_dict["dino_global_crops_loss"] = dino_loss_dict / self.n_global_crops_loss_terms # accumulate loss - loss_accumulator += self.dino_loss_weight * dino_global_crops_loss - - student_cls_tokens = student_global_cls_tokens - - if self.do_koleo: - print("doing koleo") - koleo_loss = self.cfg.dino.koleo_loss_weight * sum( - self.koleo_loss(p) for p in student_cls_tokens.chunk(2) - ) # we don't apply koleo loss between cls tokens of a same image - loss_accumulator += koleo_loss - loss_dict["koleo_loss"] = ( - koleo_loss / loss_scales - ) # this is to display the same losses as before but we can remove eventually - print(self.cfg.dino.koleo_loss_weight) - - if self.do_kde: - kde_loss = self.cfg.dino.kde_loss_weight * sum( - self.kde_loss(p) for p in student_cls_tokens.chunk(2) - ) # we don't apply koleo loss between cls tokens of a same image - loss_accumulator += kde_loss - loss_dict["kde_loss"] = ( - kde_loss / loss_scales - ) # this is to display the same losses as before but we can remove eventually - + loss_accumulator += self.dino_loss_weight * dino_crops_loss + del student_global_cls_tokens_after_head, student_local_cls_tokens_after_head + del teacher_dino_softmaxed_centered_list if do_ibot: # compute loss - ibot_patch_loss = ( - self.ibot_patch_loss.forward_masked( + ibot_patch_loss = self.ibot_patch_loss.forward_masked( student_global_masked_patch_tokens_after_head, masked_teacher_ibot_softmaxed_centered, student_masks_flat=masks, n_masked_patches=n_masked_patches, masks_weight=masks_weight, ) - * loss_scales - * ibot_loss_scale - ) - # store for display - loss_dict["ibot_loss"] = ibot_patch_loss / 2 + if self.ibot_use_mcr: + ibot_patch_loss, ibot_loss_dict = ibot_patch_loss + ibot_loss_dict = {"ibot_"+k: v for k, v in ibot_loss_dict.items()} + loss_dict |= ibot_loss_dict + else: + loss_dict["ibot_loss"] = ibot_patch_loss # / n_global_crops # accumulate loss loss_accumulator += self.ibot_loss_weight * ibot_patch_loss + loss_dict["total_loss"] = loss_accumulator.detach() self.backprop_loss(loss_accumulator) self.fsdp_synchronize_streams() - + if torch.isnan(loss_accumulator): + print(f"loss_accumulator NaN detected: {loss_dict}") + import debugpy + debugpy.breakpoint() + return loss_dict def fsdp_synchronize_streams(self): if self.need_to_synchronize_fsdp_streams: torch.cuda.synchronize() + # self.student.dino_head._streams = ( + # self.teacher.dino_head._streams + # ) = self.student.backbone._streams = self.teacher.backbone._streams + for attr in {"_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream", "_default_stream"}: stream = getattr(self.teacher.backbone, attr) setattr(self.student.dino_head, attr, stream) @@ -372,6 +343,11 @@ def fsdp_synchronize_streams(self): self.need_to_synchronize_fsdp_streams = False def update_teacher(self, m): + if m == 1.0: + return + elif m == 0.0: + self.teacher.load_state_dict(self.student.state_dict()) + return student_param_list = [] teacher_param_list = [] with torch.no_grad(): @@ -379,8 +355,11 @@ def update_teacher(self, m): for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])): 
student_param_list += ms.params teacher_param_list += mt.params - torch._foreach_mul_(teacher_param_list, m) - torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) + if hasattr(torch, '_foreach_lerp_'): + torch._foreach_lerp_(teacher_param_list, student_param_list, weight=1. - m) + else: + torch._foreach_mul_(teacher_param_list, m) + torch._foreach_add_(teacher_param_list, student_param_list, alpha=1. - m) def train(self): super().train() @@ -399,10 +378,14 @@ def get_maybe_fused_params_for_submodel(self, m): g["foreach"] = True return fused_params_groups - def get_params_groups(self): + def get_params_groups(self, fused=True): all_params_groups = [] for m in self.student.values(): - all_params_groups += self.get_maybe_fused_params_for_submodel(m) + all_params_groups += self.get_maybe_fused_params_for_submodel(m) if fused else get_params_groups_with_decay( + model=m, + lr_decay_rate=self.cfg.optim.layerwise_decay, + patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult, + ) return all_params_groups def prepare_for_distributed_training(self): diff --git a/dinov2/train/train.py b/dinov2/train/train.py index e9bf557..0c79de9 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -838,6 +838,10 @@ def _worker_init(_): torch.distributed.all_reduce(v) loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} + if "total_loss" in loss_dict_reduced: + losses_reduced = loss_dict_reduced.pop("total_loss") + else: + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) if math.isnan(sum(loss_dict_reduced.values())): print(sum(loss_dict_reduced.values())) logger.info("NaN detected")
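For reference, the interplay between remove_last_layer and the loss input dimension can be checked in isolation: with remove_last_layer=True the head emits (optionally normalized) bottleneck features, which is why ssl_meta_arch picks head_bottleneck_dim as dino_out_dim in the MCR configuration. A minimal sketch, assuming it is run inside the repo with the patched DINOHead:

import torch
from dinov2.layers import DINOHead

# Patched head with the prototype projection removed: the output dimension is
# bottleneck_dim, and with normalize=True the features are L2-normalized.
head = DINOHead(in_dim=384, out_dim=65536, bottleneck_dim=256,
                normalize=True, remove_last_layer=True)
x = torch.randn(4, 384)
out = head(x)
print(out.shape)          # torch.Size([4, 256])
print(out.norm(dim=-1))   # ~1.0 for every row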
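The expansion term of MCRLoss is a coding-rate estimate, and half_logdet is used as a cheaper, numerically stabler 0.5 * logdet of the SPD matrix I + scalar * cov. A minimal single-view, single-process (N = 1) sketch for sanity checking; the coding_rate helper here is only illustrative and not part of the patch:

import torch
import torch.nn.functional as F

def half_logdet(X):
    # 0.5 * logdet of an SPD matrix via its Cholesky factor
    return torch.linalg.cholesky_ex(X.float())[0].diagonal().log().sum()

def coding_rate(Z, eps=0.05):
    # Z: (m, p) L2-normalized features of one view on one process
    m, p = Z.shape
    cov = Z.T @ Z                                        # (p, p)
    return half_logdet(torch.eye(p) + (p / (m * eps)) * cov)

if __name__ == "__main__":
    Z = F.normalize(torch.randn(64, 16), dim=-1)
    X = torch.eye(16) + Z.T @ Z
    # half_logdet agrees with 0.5 * slogdet on an SPD matrix
    assert torch.allclose(half_logdet(X), 0.5 * torch.linalg.slogdet(X).logabsdet, atol=1e-4)
    print(coding_rate(Z))   # larger when features spread out, smaller when they collapse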
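The _foreach_lerp_ branch in update_teacher performs the same EMA update as the mul_/add_ fallback, since lerp(t, s, w) = t * (1 - w) + s * w and the weight passed is 1 - m. A quick standalone check (assumes a PyTorch build that provides the private _foreach_lerp_ op, which is what the hasattr guard tests):

import torch

m = 0.996
teacher = [torch.randn(4, 4), torch.randn(3)]
student = [torch.randn(4, 4), torch.randn(3)]

expected = [t * m + s * (1 - m) for t, s in zip(teacher, student)]
torch._foreach_lerp_(teacher, student, weight=1.0 - m)   # in-place on `teacher`
assert all(torch.allclose(a, b) for a, b in zip(teacher, expected))
print("lerp update matches teacher * m + student * (1 - m)")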