-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexpert_data.py
More file actions
69 lines (60 loc) · 2.25 KB
/
expert_data.py
File metadata and controls
69 lines (60 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import json
import torch
from models import selected_ids_list
from download import find_model
from models import DiT_models
from diffusion import create_diffusion
def _load_dit_model(model_name, latent_size, num_classes, num_experts,
                    num_experts_per_tok, ckpt_path, device):
    """Build a DiT model from ``DiT_models``, optionally load a checkpoint, and put it in eval mode."""
    model = DiT_models[model_name](
        input_size=latent_size,
        num_classes=num_classes,
        num_experts=num_experts,
        num_experts_per_tok=num_experts_per_tok,
    ).to(device)
    if ckpt_path is not None:
        print('load from: ', ckpt_path)
        state_dict = find_model(ckpt_path)
        model.load_state_dict(state_dict)
    model.eval()
    return model


def _sample_and_collect_ids(model, diffusion, class_label, latent_size, num_classes,
                            cfg_scale, device, progress):
    """Run one CFG sampling pass for ``class_label`` and return the entries it added to ``selected_ids_list``.

    NOTE(review): assumes the model appends its expert-routing records to the
    module-level ``models.selected_ids_list`` (a mutable list) during forward
    passes — the original script relied on this by slicing the list's tail
    after each sample. Confirm against ``models``.
    """
    # Remember where this sample's records start instead of assuming a fixed
    # count (the original hard-coded the last 3000 entries).
    start = len(selected_ids_list)

    z = torch.randn(1, 4, latent_size, latent_size, device=device)
    y = torch.tensor([class_label], device=device)
    # Classifier-free guidance: duplicate the noise and pair the copy with the
    # null label. Presumably DiT reserves index ``num_classes`` for it (the
    # original hard-coded 1000 with num_classes == 1000) — confirm if changed.
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([num_classes], device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=cfg_scale)

    # The generated samples are not needed; only the routing side effect is.
    diffusion.p_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False,
        model_kwargs=model_kwargs, progress=progress, device=device,
    )

    new_ids = selected_ids_list[start:]
    # Drop the consumed records so the shared list does not grow for the whole
    # run (1000 classes x 50 samples would otherwise be kept in memory).
    del selected_ids_list[start:]
    return new_ids


def image_class_expert_ratio(
    *,
    image_size=256,
    model_name="DiT-S/2",
    num_classes=1000,
    device="cuda",
    ckpt_path="results/002-DiT-S-2/checkpoints/ckpt.pt",
    num_sampling_steps=250,
    cfg_scale=1.5,
    every_class_sample=50,
    num_experts=8,
    num_experts_per_tok=2,
    seed=1234,
    output_dir="experts",
    progress=True,
):
    """Sample every class with a MoE DiT and dump the expert ids logged per sample.

    For each class label in ``range(num_classes)``, this draws
    ``every_class_sample`` classifier-free-guided samples. For each sample it
    collects the entries added to ``selected_ids_list`` while that sample was
    generated. Each class's collection (one element per sample) is written as
    JSON to ``<output_dir>/<class>.json``.

    All arguments are keyword-only and default to the values this script
    originally hard-coded, so ``image_class_expert_ratio()`` behaves as before.

    Args:
        image_size: Pixel resolution; the latent side length is ``image_size // 8``.
        model_name: Key into ``DiT_models``.
        num_classes: Number of class labels to iterate over and pass to the model.
        device: Torch device for the model and tensors.
        ckpt_path: Checkpoint loaded via ``find_model``; ``None`` skips loading.
        num_sampling_steps: Passed (as a string) to ``create_diffusion``.
        cfg_scale: Classifier-free guidance scale forwarded to ``forward_with_cfg``.
        every_class_sample: Number of samples drawn per class.
        num_experts: Forwarded to the ``DiT_models`` constructor.
        num_experts_per_tok: Forwarded to the ``DiT_models`` constructor.
        seed: Seed for ``torch.manual_seed``, set before model construction.
        output_dir: Directory receiving the per-class JSON files; created if missing.
        progress: Forwarded to ``p_sample_loop``.

    Returns:
        None. Results are written to disk.
    """
    torch.manual_seed(seed)
    latent_size = image_size // 8
    os.makedirs(output_dir, exist_ok=True)

    # Scope autograd disabling to this call instead of flipping global state.
    with torch.no_grad():
        model = _load_dit_model(
            model_name, latent_size, num_classes, num_experts,
            num_experts_per_tok, ckpt_path, device,
        )
        diffusion = create_diffusion(str(num_sampling_steps))

        for class_idx in range(num_classes):
            experts_ids = []
            for sample_idx in range(every_class_sample):
                experts_ids.append(_sample_and_collect_ids(
                    model, diffusion, class_idx, latent_size, num_classes,
                    cfg_scale, device, progress,
                ))
                print(class_idx, sample_idx)
            print(len(experts_ids))
            tgt_path = os.path.join(output_dir, str(class_idx) + '.json')
            with open(tgt_path, 'w', encoding='utf-8') as f:
                json.dump(experts_ids, f)