-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
105 lines (79 loc) · 2.32 KB
/
utils.py
File metadata and controls
105 lines (79 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
from functools import partial
from pathlib import Path
from typing import Callable, Optional, Tuple

from dotenv import load_dotenv
from torch import nn, Tensor
from torch.nn import functional as F
from torch.optim import Optimizer
from torch.utils.data import Dataset

from dataset import ImagePuzzle
from model import Transformer, ImageSolver, TRM
def vanilla_step(
    model: nn.Module,
    optim: Optional[Optimizer],
    inputs: Tensor,
    labels: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
    """Run one forward pass and, when possible, one optimisation step.

    Args:
        model: Module called as ``model(inputs)``; it must return a
            one-element sequence whose only item is the logits.
        optim: Optimizer to step, or ``None`` to skip the parameter update.
        inputs: Batch passed to the model.
        labels: Targets for the cross-entropy loss, or ``None`` to skip the
            loss (and therefore the update) entirely.

    Returns:
        ``(logits, loss)`` where ``loss`` is ``None`` when no labels were given.
    """
    (logits,) = model(inputs)

    # Without labels there is no loss, hence nothing to back-propagate.
    if labels is None:
        return logits, None

    # Axes 1 and 2 are swapped before the loss; cross_entropy expects the
    # class axis at position 1.
    loss = F.cross_entropy(logits.transpose(1, 2), labels)

    if optim is not None:
        loss.backward()
        optim.step()
        optim.zero_grad()

    return logits, loss
def vanilla(
    puzzle_size: int,
    model_dim: int,
    ff_dim: int,
    head_num: int,
    layer_num: int,
    **_,
) -> Tuple[nn.Module, Callable]:
    """Build the plain Transformer solver together with its step function.

    Args:
        puzzle_size: Side length of the puzzle; the solver is created for
            ``puzzle_size * puzzle_size`` pieces.
        model_dim, ff_dim, head_num, layer_num: Forwarded positionally, in
            this order, to ``Transformer``.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        ``(model, step)`` where ``model`` is the ``Transformer`` wrapped in an
        ``ImageSolver`` and ``step`` is ``vanilla_step``.
    """
    backbone = Transformer(model_dim, ff_dim, head_num, layer_num)
    solver = ImageSolver(puzzle_size * puzzle_size, backbone)
    return solver, vanilla_step
def trm_step(
    model: nn.Module,
    optim: Optional[Optimizer],
    inputs: Tensor,
    labels: Optional[Tensor],
    s: int,
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Run ``s`` consecutive forward passes, optionally updating after each.

    The model is called as ``model(inputs, y, z)`` and must return
    ``(logits, y, z)``; the returned ``y`` and ``z`` are fed back into the next
    call. Both start as ``None`` on the first call.

    Args:
        model: Module with the call signature described above.
        optim: Optimizer stepped after every pass that produced a loss, or
            ``None`` to skip parameter updates.
        inputs: Batch passed unchanged to every call.
        labels: Targets for the cross-entropy loss, or ``None`` to skip the
            loss (and therefore the updates).
        s: Number of passes to run.

    Returns:
        ``(logits, loss)`` from the last pass. ``loss`` is ``None`` when no
        labels were given, and both are ``None`` when ``s <= 0``.
    """
    logits, loss = None, None
    y, z = None, None

    for _ in range(s):
        logits, y, z = model(inputs, y, z)

        # Without labels there is no loss, hence nothing to back-propagate.
        if labels is None:
            continue

        # Axes 1 and 2 are swapped before the loss; cross_entropy expects the
        # class axis at position 1.
        loss = F.cross_entropy(logits.transpose(1, 2), labels)

        if optim is not None:
            # NOTE(review): backward() runs once per pass, so the model is
            # presumably expected to return y/z detached from the previous
            # graph - confirm against the model implementation.
            loss.backward()
            optim.step()
            optim.zero_grad()

    return logits, loss
def trm(
    puzzle_size: int,
    model_dim: int,
    ff_dim: int,
    head_num: int,
    layer_num: int,
    s: int,
    t: int,
    n: int,
    **_,
) -> Tuple[nn.Module, Callable]:
    """Build the TRM solver together with its step function.

    Args:
        puzzle_size: Side length of the puzzle; the piece count is
            ``puzzle_size * puzzle_size``.
        model_dim, ff_dim, head_num, layer_num: Forwarded positionally to
            ``TRM``, followed by the piece count, ``n`` and ``t``.
        s: Bound as the ``s`` argument of ``trm_step``.
        t, n: Forwarded to ``TRM`` (note the order there is ``n`` then ``t``).
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        ``(model, step)`` where ``model`` is the ``TRM`` wrapped in an
        ``ImageSolver`` and ``step`` is ``trm_step`` with ``s`` pre-bound.
    """
    pieces = puzzle_size * puzzle_size
    core = TRM(model_dim, ff_dim, head_num, layer_num, pieces, n, t)
    solver = ImageSolver(pieces, core)
    return solver, partial(trm_step, s=s)
# Registry of model factories, looked up by name in get_model.
MODELS = dict(
    trm=trm,
    vanilla=vanilla,
)
def get_model(name: str, **kwargs):
    """Look up the factory registered under ``name`` and call it.

    Args:
        name: Key into ``MODELS``; an unknown name raises ``KeyError``.
        **kwargs: Forwarded unchanged to the selected factory.

    Returns:
        Whatever the factory returns.
    """
    factory = MODELS[name]
    return factory(**kwargs)
def image(path: Path, puzzle_size: int, tile_size: int, **_) -> Tuple[Dataset, Dataset]:
    """Create the train and test ``ImagePuzzle`` datasets found under ``path``.

    Args:
        path: Directory containing the ``train`` and ``test`` sub-directories.
        puzzle_size: Forwarded to ``ImagePuzzle``.
        tile_size: Forwarded to ``ImagePuzzle``.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        ``(trainset, testset)``.
    """

    def _split(folder: str) -> Dataset:
        # Both splits share every setting except the sub-directory.
        return ImagePuzzle(path / folder, puzzle_size, tile_size)

    return _split("train"), _split("test")
# Load variables from a .env file into the environment before reading them.
load_dotenv()

# Root directory of all datasets, configured through the DATASETS environment
# variable. os.getenv returns a plain string (or None), so it is converted to
# a Path before being joined with "/" below.
_DATASETS_ROOT = os.getenv("DATASETS")
if _DATASETS_ROOT is None:
    raise RuntimeError(
        "The DATASETS environment variable is not set; "
        "point it at the directory that contains the datasets."
    )
_DATASETS_ROOT = Path(_DATASETS_ROOT)

# Registry used by get_dataset: name -> (factory, default keyword arguments).
DATASETS = {
    "coco": (image, {"path": _DATASETS_ROOT / "COCO" / "2017"}),
}
def get_dataset(name: str, **kwargs):
    """Look up the dataset registered under ``name`` and build it.

    Args:
        name: Key into ``DATASETS``; an unknown name raises ``KeyError``.
        **kwargs: Merged over the registered default keyword arguments; on a
            key clash the caller's value wins.

    Returns:
        Whatever the registered factory returns.
    """
    factory, defaults = DATASETS[name]
    merged = {**defaults, **kwargs}
    return factory(**merged)