-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtsne.py
More file actions
60 lines (52 loc) · 1.71 KB
/
tsne.py
File metadata and controls
60 lines (52 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
from sklearn.manifold import TSNE
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor
from torchvision import models
class Identity(nn.Module):
    """Pass-through layer that hands back exactly the tensor it receives.

    ``main`` installs it as ``model.fc`` so the ResNet's output skips the
    final classification layer.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No computation: the input object itself is the output.
        return x
def _extract_features(model, dataset, batch_size=64):
    """Run every item of *dataset* through *model* in batches.

    Args:
        model: network applied to each batch; switched to eval mode here.
        dataset: map-style dataset yielding ``(image_tensor, label)`` pairs.
        batch_size: number of images per forward pass.

    Returns:
        ``(features, labels)``: a 2-D float array with one row per dataset
        item, and an array of the matching integer targets.

    Raises:
        ValueError: if *dataset* is empty (t-SNE needs at least one sample).
    """
    from torch.utils.data import DataLoader

    total = len(dataset)
    if total == 0:
        raise ValueError("dataset is empty; nothing to embed")

    model.eval()
    # shuffle=False keeps feature rows aligned with dataset order.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    feature_chunks = []
    label_chunks = []
    done = 0
    last_percent = -1
    with torch.no_grad():
        for images, targets in loader:
            output = model(images)
            # Flatten to (batch, features) so rows can be stacked regardless
            # of the backbone's exact output shape.
            feature_chunks.append(output.reshape(output.shape[0], -1).cpu().numpy())
            label_chunks.append(targets.numpy())
            done += images.shape[0]
            # Integer percentage: robust for any dataset length (the old
            # float-modulo check only fired when len was a multiple of 100).
            percent = done * 100 // total
            if percent != last_percent:
                print("{}%".format(percent))
                last_percent = percent
    return np.concatenate(feature_chunks, axis=0), np.concatenate(label_chunks, axis=0)


def _plot_embedding(embedding, labels, class_names, output_path):
    """Scatter a 2-D *embedding* coloured by *labels* and save it to *output_path*.

    Args:
        embedding: array of shape (n_samples, 2) from t-SNE.
        labels: integer label per row of *embedding*.
        class_names: optional sequence mapping label index -> display name;
            labels outside it (or when it is None) are shown as numbers.
        output_path: file path the figure is written to.
    """
    fig, ax = plt.subplots(figsize=(30, 20))
    ax.set_xlabel('Dimension 1', fontsize=15)
    ax.set_ylabel('Dimension 2', fontsize=15)
    ax.set_title('TSNE', fontsize=20)
    # One scatter call per class: matplotlib cycles colours automatically and
    # the `label=` entries build the legend, replacing the undefined
    # cs_colors / colors / cell_ids of the original.
    for label in np.unique(labels):
        mask = labels == label
        if class_names is not None and 0 <= label < len(class_names):
            name = class_names[label]
        else:
            name = str(label)
        ax.scatter(embedding[mask, 0], embedding[mask, 1], label=name)
    ax.legend(title="class")
    ax.grid()
    fig.savefig(output_path)
    plt.close(fig)


def main(root="data", output_path="tsne.png", batch_size=64):
    """Embed STL-10 training images with ResNet-18 features and plot a 2-D t-SNE.

    Args:
        root: directory holding the STL-10 data. ``download=False`` is passed,
            so the data must already be present there.
        output_path: where the scatter-plot image is saved.
        batch_size: images per forward pass during feature extraction.
    """
    dataset = STL10(root=root, split="train", download=False, transform=ToTensor())

    # NOTE(review): pretrained=False means randomly initialised weights, so
    # the features carry no learned semantics; load trained weights (and
    # apply ImageNet normalisation) if meaningful clusters are expected.
    model = models.resnet18(pretrained=False)
    model.fc = Identity()  # remove the classifier head; keep backbone features

    print("Transforming images to features...")
    features, labels = _extract_features(model, dataset, batch_size)

    embedding = TSNE(n_components=2).fit_transform(features)
    # The dataset may expose human-readable class names; fall back to numbers.
    class_names = getattr(dataset, "classes", None)
    _plot_embedding(embedding, labels, class_names, output_path)
    print("Saved t-SNE plot to {}".format(output_path))
if __name__ == "__main__":
    # Only run the pipeline when executed as a script, not on import.
    main()