diff --git a/cvivit.py b/cvivit.py index 9ce2d74..9312ef7 100644 --- a/cvivit.py +++ b/cvivit.py @@ -142,7 +142,7 @@ def __init__(self, patch_size=(5, 8, 8), compressed_frames=20, latent_size=32, c super().__init__() self.encoder = Encoder(patch_size=patch_size, hidden_channels=c_hidden, size=latent_size, compressed_frames=compressed_frames, num_layers=num_layers_enc, num_heads=num_heads) - self.cod_mapper = nn.Linear(c_hidden, c_codebook) + self.cod_mapper = nn.Linear(c_hidden, c_codebook, bias=False) self.batchnorm = nn.BatchNorm2d(c_codebook) self.cod_unmapper = nn.Linear(c_codebook, c_hidden) diff --git a/vivq.py b/vivq.py index f82463f..024cd27 100644 --- a/vivq.py +++ b/vivq.py @@ -210,7 +210,7 @@ def __init__(self, base_channels=3, c_hidden=512, c_codebook=16, codebook_size=1 super().__init__() self.encoder = Encoder(base_channels, c_hidden=c_hidden) self.cod_mapper = nn.Sequential( - nn.Conv3d(c_hidden, c_codebook, kernel_size=1), + nn.Conv3d(c_hidden, c_codebook, kernel_size=1, bias=False), nn.BatchNorm3d(c_codebook), ) self.cod_unmapper = nn.Conv3d(c_codebook, c_hidden, kernel_size=1)