15 changes: 11 additions & 4 deletions dinov2/layers/dino_head.py
@@ -19,13 +19,18 @@ def __init__(
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
        normalize=True,
        remove_last_layer=False
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if not remove_last_layer:
            self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
            self.last_layer.weight_g.data.fill_(1)
        self.normalize = normalize
        self.remove_last_layer = remove_last_layer

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
@@ -36,8 +41,10 @@ def _init_weights(self, m):
    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        if self.normalize:
            x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        if not self.remove_last_layer:
            x = self.last_layer(x)
        return x

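With the two new flags, the head can be configured to return the bottleneck MLP output directly, skipping the L2 normalization and the weight-normalized prototype layer, which appears to be the feature format the MCR-style losses below consume. A minimal sketch of the resulting behavior, using a hypothetical TinyHead stand-in rather than the actual DINOHead class:

# Hypothetical stand-in (not part of this PR) illustrating the effect of the new flags.
import torch
import torch.nn as nn

class TinyHead(nn.Module):
    def __init__(self, in_dim=8, out_dim=16, bottleneck_dim=4, normalize=True, remove_last_layer=False):
        super().__init__()
        self.mlp = nn.Linear(in_dim, bottleneck_dim)  # stands in for _build_mlp
        self.normalize = normalize
        self.remove_last_layer = remove_last_layer
        if not remove_last_layer:
            self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)

    def forward(self, x):
        x = self.mlp(x)
        if self.normalize:
            x = nn.functional.normalize(x, dim=-1, p=2)
        if not self.remove_last_layer:
            x = self.last_layer(x)
        return x

feats = TinyHead(normalize=True, remove_last_layer=True)(torch.randn(2, 8))
print(feats.shape)  # torch.Size([2, 4]): L2-normalized bottleneck features, no prototype logits
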
4 changes: 2 additions & 2 deletions dinov2/loss/__init__.py
@@ -3,7 +3,7 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_clstoken_loss import DINOLoss
from .ibot_patch_loss import iBOTPatchLoss
from .dino_clstoken_loss import DINOLoss, MCRLoss
from .ibot_patch_loss import iBOTPatchLoss, CosinePatchLoss
from .koleo_loss import KoLeoLoss
from .kde_loss import KDELoss
138 changes: 112 additions & 26 deletions dinov2/loss/dino_clstoken_loss.py
@@ -2,24 +2,23 @@
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn


class DINOLoss(nn.Module):
class DINOCenter(nn.Module):
    def __init__(
        self,
        out_dim,
        student_temp=0.1,
        enable=True,
        center_momentum=0.9,
    ):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        if not enable:
            return
        self.register_buffer("center", torch.zeros(1, out_dim))
        self.center_momentum = center_momentum
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_output = None
@@ -44,35 +43,17 @@ def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for it in range(n_iterations):
            # normalize each row: total weight per prototype must be 1/K
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K

            # normalize each column: total weight per sample must be 1/B
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        Q *= B  # the columns must sum to 1 so that Q is an assignment
        Q *= B
        return Q.t()

    def forward(self, student_output_list, teacher_out_softmaxed_centered_list):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        # TODO: Use cross_entropy_distribution here
        total_loss = 0
        for s in student_output_list:
            lsm = F.log_softmax(s / self.student_temp, dim=-1)
            for t in teacher_out_softmaxed_centered_list:
                loss = torch.sum(t * lsm, dim=-1)
                total_loss -= loss.mean()
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        self.reduce_center_update(teacher_output)
@@ -89,11 +70,116 @@ def reduce_center_update(self, teacher_output):
    def apply_center_update(self):
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            _t = self.async_batch_center / (self.len_teacher_output * world_size)

            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True

def half_logdet(X):
    # Ensure the matrix is float32 for Cholesky stability
    return torch.linalg.cholesky_ex(X.float())[0].diagonal().log().sum()

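As a sanity check on the helper above (a hypothetical snippet, not part of the PR): for a symmetric positive-definite X = L Lᵀ, the log-determinant equals twice the sum of the log-diagonal of its Cholesky factor, so half_logdet(X) should match 0.5 * torch.logdet(X):

import torch

def half_logdet(X):
    # Same helper as in the diff: half the log-determinant via the Cholesky factor's diagonal
    return torch.linalg.cholesky_ex(X.float())[0].diagonal().log().sum()

A = torch.randn(5, 5)
X = A @ A.T + 5 * torch.eye(5)  # make X symmetric positive-definite
print(torch.allclose(half_logdet(X), 0.5 * torch.logdet(X), atol=1e-5))  # True
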
class MCRLoss(DINOCenter):
    def __init__(self, out_dim, expa_type=0, reduce_cov=0, eps=0.05, coeff=1, center=False, *args, **kwargs):
        super().__init__(out_dim, enable=center)
        self.eps = eps
        self.coeff = coeff
        self.expa_type = expa_type
        self.reduce_cov = reduce_cov

    def forward(self, student_feat_list, teacher_feat_list, no_diag=True, normalized=True):
        """
        Expansion Loss and Compression Loss between features of the teacher and student networks.
        """
        # Convert lists of tensors to a single tensor for vectorized operations
        student_feat = torch.stack(student_feat_list)  # (ncrops, N, D)
        teacher_feat = torch.stack(teacher_feat_list)  # (2, N, D)
        if not normalized:
            student_feat = F.normalize(student_feat, p=2, dim=-1)
            teacher_feat = F.normalize(teacher_feat, p=2, dim=-1)
        comp_loss, global_comp_loss = self.calc_compression(student_feat, teacher_feat, no_diag=no_diag)
        match self.expa_type:
            case 0:  # only compute expansion on global views
                expa_feat = student_feat[:len(teacher_feat)]
            case 1:  # center with teacher
                expa_feat = (student_feat[:len(teacher_feat)] + teacher_feat) / 2
        expa_loss = self.calc_expansion(expa_feat)
        loss = -self.coeff * comp_loss - expa_loss
        return loss, {"loss": loss.detach(), "comp_loss": comp_loss.detach(), "global_comp_loss": global_comp_loss.detach(), "expa_loss": expa_loss.detach()}

    def calc_compression(self, student_feat_list, teacher_feat_list, no_diag=True):
        """
        Compute compression loss between student and teacher features.
        """
        # Compute the mean cosine similarity for all (teacher view, student view) pairs
        sim = (teacher_feat_list.unsqueeze(1) * student_feat_list.unsqueeze(0)).sum(-1).mean(-1)
        # Mask out the diagonal elements where student and teacher operate on the same view
        # mask = torch.eye(len(teacher_feat_list), len(student_feat_list), dtype=torch.bool, device=cosine_sim.device).unsqueeze_(2)
        # sim = cosine_sim.masked_fill(mask, 0)
        if no_diag:
            sim.view(-1)[:: (len(student_feat_list) + 1)].fill_(0)  # Trick to zero the diagonal

        n_loss_terms = len(teacher_feat_list) * len(student_feat_list) - min(len(teacher_feat_list), len(student_feat_list))
        # Sum the cosine similarities
        comp_loss = sim.sum() / n_loss_terms
        global_comp_loss = sim[:, :len(teacher_feat_list)].detach().sum().div_(len(teacher_feat_list))
        return comp_loss, global_comp_loss

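A small, self-contained illustration of the compression term above (hypothetical shapes, not taken from the PR's configs): it is the pairwise mean cosine similarity between teacher and student views, with the same-view "diagonal" pairs excluded from the average.

import torch
import torch.nn.functional as F

T, S, n, d = 2, 4, 16, 8                       # teacher views, student views, batch size, feature dim
teacher = F.normalize(torch.randn(T, n, d), dim=-1)
student = F.normalize(torch.randn(S, n, d), dim=-1)

sim = (teacher.unsqueeze(1) * student.unsqueeze(0)).sum(-1).mean(-1)  # (T, S) mean cosine similarities
sim.view(-1)[:: S + 1].fill_(0)                                       # zero same-view pairs (the "diagonal")
comp_loss = sim.sum() / (T * S - min(T, S))                           # average over the remaining pairs
print(sim.shape, comp_loss.item())
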
    def calc_expansion(self, feat_list, cross_list=None) -> torch.Tensor:
        """
        Compute expansion loss using Coding Rate estimation.
        """
        num_views = len(feat_list)
        m, p = feat_list[0].shape  # m: samples per view, p: feature dimension
        # Per-view p x p covariance matrices (optionally crossed with a second feature set)
        cov = torch.einsum('nbc,nbd->ncd', feat_list, cross_list if cross_list is not None else feat_list)
        N = 1
        if dist.is_initialized():
            N = dist.get_world_size()
            if self.reduce_cov == 1:
                cov = dist.nn.all_reduce(cov)
        scalar = p / (m * N * self.eps)
        I = torch.eye(p, device=cov[0].device)
        # Coding-rate estimate per view: 0.5 * logdet(I + p / (m * N * eps) * cov)
        loss = sum([half_logdet(I + scalar * cov[i]) for i in range(num_views)])
        # loss = torch.logdet(I + scalar * cov).sum() / 2
        loss /= num_views
        loss *= (p + N * m) / (p * N * m)  # the balancing factor gamma; you can also use the next line. This is ultimately a heuristic, so feel free to experiment.
        # loss *= ((self.eps * N * m) ** 0.5 / p)
        return loss

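A compact numerical sketch of the expansion term (a hypothetical example, with 0.5 * torch.logdet standing in for the half_logdet helper): per view it is the coding-rate estimate 0.5 * logdet(I + p / (m * eps) * ZᵀZ) of the unit-norm features Z, averaged over views; the distributed world size N is taken as 1 here and the gamma balancing factor is omitted.

import torch
import torch.nn.functional as F

eps = 0.05
feats = F.normalize(torch.randn(2, 64, 32), dim=-1)  # (views, m, p), unit-norm feature rows
views, m, p = feats.shape
cov = torch.einsum('vbc,vbd->vcd', feats, feats)      # per-view p x p covariance
I = torch.eye(p)
expansion = sum(0.5 * torch.logdet(I + (p / (m * eps)) * cov[v]) for v in range(views)) / views
print(expansion)  # grows as features spread out; MCRLoss maximizes it via the minus sign in forward
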
class DINOLoss(DINOCenter):
    def __init__(
        self,
        out_dim,
        student_temp=0.1,
    ):
        super().__init__(out_dim)
        self.student_temp = student_temp
        self.reduce_handle = None
        self.len_teacher_output = None
        self.async_batch_center = None

    def forward(self, student_output_list, teacher_out_softmaxed_centered_list, no_diag=False):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        # TODO: Use cross_entropy_distribution here
        total_loss = 0
        global_loss = 0
        for i, s in enumerate(student_output_list):
            lsm = F.log_softmax(s / self.student_temp, dim=-1)
            for j, t in enumerate(teacher_out_softmaxed_centered_list):
                if no_diag and i == j:
                    continue
                loss = torch.sum(t * lsm, dim=-1).mean()
                total_loss -= loss
                if no_diag and i < len(teacher_out_softmaxed_centered_list):
                    global_loss -= loss.detach()
        return total_loss, global_loss
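To make the no_diag pairing concrete (a hypothetical illustration, not part of the PR): with 2 teacher views and 4 student crops, every (student i, teacher j) pair contributes except i == j, and global_loss only tracks the terms coming from the global student crops.

pairs = [(i, j) for i in range(4) for j in range(2) if i != j]
print(pairs)       # [(0, 1), (1, 0), (2, 0), (2, 1), (3, 0), (3, 1)]
print(len(pairs))  # 6 cross-view terms; global_loss accumulates only those with i < 2
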
