forked from Swastik3/DenseTeX
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencoder.py
More file actions
76 lines (65 loc) · 3.13 KB
/
encoder.py
File metadata and controls
76 lines (65 loc) · 3.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
import math
from model import GPTConfig
class PositionalEncoding2D(nn.Module):
    """Add fixed 2-D sinusoidal positional encodings to a feature map.

    The first ``d_model // 2`` channels encode the column (width) position and
    the remaining channels encode the row (height) position; within each half,
    even offsets hold ``sin`` terms and odd offsets hold ``cos`` terms.

    Args:
        d_model: Number of input channels; must be divisible by 4.
        height: Largest feature-map height the table is precomputed for.
        width: Largest feature-map width the table is precomputed for.
    """

    def __init__(self, d_model, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.d_model = d_model
        # Non-persistent buffer: follows the module through .to()/.cuda() but
        # stays out of state_dict (as the previous plain attribute did), so
        # existing checkpoints still load. Shape: (1, d_model, height, width).
        self.register_buffer(
            "pe",
            self._get_positional_encoding(d_model, height, width).unsqueeze(0),
            persistent=False,
        )

    def _get_positional_encoding(self, d_model, height, width):
        """Build the sinusoidal encoding table.

        :param d_model: number of channels; must be divisible by 4
        :param height: number of row positions
        :param width: number of column positions
        :return: tensor of shape (d_model, height, width)
        :raises ValueError: if ``d_model`` is not divisible by 4
        """
        if d_model % 4 != 0:
            # Each spatial axis gets d_model / 2 channels, split evenly between
            # sin and cos, so d_model must be a multiple of 4 (not just even).
            raise ValueError(
                "Cannot use sin/cos positional encoding with "
                f"dimension not divisible by 4 (got dim={d_model})"
            )
        pe = torch.zeros(d_model, height, width)
        half = d_model // 2  # channels devoted to each spatial axis
        div_term = torch.exp(
            torch.arange(0.0, half, 2) * -(math.log(10000.0) / half)
        )
        pos_w = torch.arange(0.0, width).unsqueeze(1)  # (width, 1)
        pos_h = torch.arange(0.0, height).unsqueeze(1)  # (height, 1)
        # Angles laid out so slice assignment broadcasts them over the other axis:
        # column angles (half/2, 1, width), row angles (half/2, height, 1).
        w_angles = (pos_w * div_term).transpose(0, 1).unsqueeze(1)
        h_angles = (pos_h * div_term).transpose(0, 1).unsqueeze(2)
        pe[0:half:2] = torch.sin(w_angles)
        pe[1:half:2] = torch.cos(w_angles)
        pe[half::2] = torch.sin(h_angles)
        pe[half + 1::2] = torch.cos(h_angles)
        return pe

    def forward(self, x):
        """Add the positional encoding to ``x``.

        Args:
            x: Tensor of shape (batch_size, channels, height, width) with
                ``channels == d_model``, ``height <= self.height`` and
                ``width <= self.width``.

        Returns:
            Tensor of the same shape as ``x`` with the encoding added.

        Raises:
            ValueError: if the channel count differs from ``d_model`` or the
                spatial size exceeds the precomputed table.
        """
        _, channels, height, width = x.shape
        if channels != self.d_model:
            raise ValueError(
                "Dimension mismatch: d_model and input channels must be the same "
                f"(d_model={self.d_model}, channels={channels})"
            )
        if height > self.height or width > self.width:
            raise ValueError(
                f"Input spatial size ({height}, {width}) exceeds the precomputed "
                f"positional encoding size ({self.height}, {self.width})"
            )
        # A sinusoid's value at a position does not depend on the table size,
        # so the top-left crop is exactly the encoding for a smaller input.
        pe = self.pe[:, :, :height, :width].to(x.device)
        return x + pe
class InputEmbeddings(nn.Module):
    """Flatten a CNN feature map into a token sequence and project it.

    Each spatial location of a (batch, in_channels, H, W) feature map becomes
    one token; its ``in_channels`` features are linearly projected to
    ``out_dim``, giving a (batch, H * W, out_dim) sequence.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_dim: Width of the output embeddings (defaults to
            ``GPTConfig.n_embd``).
    """

    def __init__(self, in_channels=1664, out_dim=GPTConfig.n_embd):
        super().__init__()
        self.in_channels = in_channels
        self.out_dim = out_dim
        self.projection = nn.Linear(in_channels, out_dim)

    def forward(self, x):
        """Project a feature map into a sequence of embeddings.

        Args:
            x: Tensor whose dim 0 is the batch and dim 1 the channels (must
                equal ``in_channels``); all remaining dims are flattened into
                the sequence axis. The original notes describe a typical input
                of [batch_size, 1664, 12, 25].

        Returns:
            Tensor of shape [batch_size, sequence_length, out_dim].

        Raises:
            ValueError: if ``x`` has fewer than 2 dims or the wrong channel count.
        """
        if x.dim() < 2 or x.size(1) != self.in_channels:
            # Without this check a mismatched channel count whose numel happens
            # to fit would be silently re-laid out into meaningless tokens.
            raise ValueError(
                f"Expected input of shape (batch, {self.in_channels}, ...), "
                f"got {tuple(x.shape)}"
            )
        if x.dim() == 2:
            x = x.unsqueeze(-1)  # (B, C) -> (B, C, 1): a single token
        # (B, C, H, W) -> (B, C, H*W). Unlike view(), flatten copies when
        # needed, so non-contiguous (e.g. channels_last) inputs are accepted.
        x = x.flatten(start_dim=2)
        # (B, C, L) -> (B, L, C): one token per spatial position
        x = x.transpose(1, 2)
        # (B, L, C) -> (B, L, out_dim)
        return self.projection(x)