-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsd3_sampler.py
More file actions
512 lines (407 loc) · 20.1 KB
/
sd3_sampler.py
File metadata and controls
512 lines (407 loc) · 20.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
from typing import List, Tuple, Optional
import math
import torch
from tqdm import tqdm
from diffusers import StableDiffusion3Pipeline
# =======================================================================
# Factory
# =======================================================================
__SOLVER__ = {}
def register_solver(name:str):
def wrapper(cls):
if __SOLVER__.get(name, None) is not None:
raise ValueError(f"Solver {name} already registered.")
__SOLVER__[name] = cls
return cls
return wrapper
def get_solver(name:str, **kwargs):
if name not in __SOLVER__:
raise ValueError(f"Solver {name} does not exist.")
return __SOLVER__[name](**kwargs)
# =======================================================================
class StableDiffusion3Base():
def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda', dtype=torch.float16):
self.device = device
self.dtype = dtype
pipe = StableDiffusion3Pipeline.from_pretrained(model_key, torch_dtype=self.dtype)
self.scheduler = pipe.scheduler
self.tokenizer_1 = pipe.tokenizer
self.tokenizer_2 = pipe.tokenizer_2
self.tokenizer_3 = pipe.tokenizer_3
self.text_enc_1 = pipe.text_encoder
self.text_enc_2 = pipe.text_encoder_2
self.text_enc_3 = pipe.text_encoder_3
self.vae=pipe.vae
self.transformer = pipe.transformer
self.transformer.eval()
self.transformer.requires_grad_(False)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)-1) if hasattr(self, "vae") and self.vae is not None else 8
)
del pipe
def encode_prompt(self, prompt: List[str], batch_size:int=1) -> List[torch.Tensor]:
'''
We assume that
1. number of tokens < max_length
2. one prompt for one image
'''
# CLIP encode (used for modulation of adaLN-zero)
# now, we have two CLIPs
text_clip1_ids = self.tokenizer_1(prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors='pt').input_ids
text_clip1_emb = self.text_enc_1(text_clip1_ids.to(self.text_enc_1.device), output_hidden_states=True)
pool_clip1_emb = text_clip1_emb[0].to(dtype=self.dtype, device=self.text_enc_1.device)
text_clip1_emb = text_clip1_emb.hidden_states[-2].to(dtype=self.dtype, device=self.text_enc_1.device)
text_clip2_ids = self.tokenizer_2(prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors='pt').input_ids
text_clip2_emb = self.text_enc_2(text_clip2_ids.to(self.text_enc_2.device), output_hidden_states=True)
pool_clip2_emb = text_clip2_emb[0].to(dtype=self.dtype, device=self.text_enc_2.device)
text_clip2_emb = text_clip2_emb.hidden_states[-2].to(dtype=self.dtype, device=self.text_enc_2.device)
# T5 encode (used for text condition)
text_t5_ids = self.tokenizer_3(prompt,
padding="max_length",
max_length=77,
truncation=True,
add_special_tokens=True,
return_tensors='pt').input_ids
text_t5_emb = self.text_enc_3(text_t5_ids.to(self.text_enc_3.device))[0]
text_t5_emb = text_t5_emb.to(dtype=self.dtype, device=self.text_enc_3.device)
# Merge
clip_prompt_emb = torch.cat([text_clip1_emb, text_clip2_emb], dim=-1)
clip_prompt_emb = torch.nn.functional.pad(
clip_prompt_emb, (0, text_t5_emb.shape[-1] - clip_prompt_emb.shape[-1])
)
prompt_emb = torch.cat([clip_prompt_emb, text_t5_emb], dim=-2)
pooled_prompt_emb = torch.cat([pool_clip1_emb, pool_clip2_emb], dim=-1)
return prompt_emb, pooled_prompt_emb
def initialize_latent(self, img_size:Tuple[int], batch_size:int=1, **kwargs):
H, W = img_size
lH, lW = H//self.vae_scale_factor, W//self.vae_scale_factor
lC = self.transformer.config.in_channels
latent_shape = (batch_size, lC, lH, lW)
z = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
return z
def encode(self, image: torch.Tensor) -> torch.Tensor:
z = self.vae.encode(image).latent_dist.sample()
z = (z-self.vae.config.shift_factor) * self.vae.config.scaling_factor
return z
def decode(self, z: torch.Tensor) -> torch.Tensor:
z = (z/self.vae.config.scaling_factor) + self.vae.config.shift_factor
return self.vae.decode(z, return_dict=False)[0]
def predict_vector(self, z, t, prompt_emb, pooled_emb):
v = self.transformer(hidden_states=z,
timestep=t,
pooled_projections=pooled_emb,
encoder_hidden_states=prompt_emb,
return_dict=False)[0]
return v
class SD3Euler(StableDiffusion3Base):
def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda'):
super().__init__(model_key=model_key, device=device)
def inversion(self, src_img, prompts: List[str], NFE:int, cfg_scale: float=1.0, batch_size: int=1,
prompt_emb:Optional[List[torch.Tensor]]=None,
null_emb:Optional[List[torch.Tensor]]=None):
# encode text prompts
with torch.no_grad():
if prompt_emb is None:
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
else:
prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1]
prompt_emb = prompt_emb.to(self.transformer.device)
pooled_emb = pooled_emb.to(self.transformer.device)
if null_emb is None:
null_prompt_emb, null_pooled_emb = self.encode_prompt([""])
else:
null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1]
null_prompt_emb = null_prompt_emb.to(self.transformer.device)
null_pooled_emb = null_pooled_emb.to(self.transformer.device)
# initialize latent
src_img = src_img.to(device=self.vae.device, dtype=self.dtype)
with torch.no_grad():
z = self.encode(src_img).to(self.transformer.device)
# timesteps (default option. You can make your custom here.)
self.scheduler.set_timesteps(NFE, device=self.transformer.device)
timesteps = self.scheduler.timesteps
timesteps = torch.cat([timesteps, torch.zeros(1, device=self.transformer.device)])
timesteps = reversed(timesteps)
sigmas = timesteps / self.scheduler.config.num_train_timesteps
# Solve ODE
pbar = tqdm(timesteps[:-1], total=NFE, desc='SD3 Euler Inversion')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.transformer.device)
with torch.no_grad():
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
sigma = sigmas[i]
sigma_next = sigmas[i+1]
z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v))
return z
def sample(self, prompts: List[str], NFE:int, img_shape: Optional[Tuple[int]]=None,
cfg_scale: float=1.0, batch_size: int = 1,
latent:Optional[List[torch.Tensor]]=None,
prompt_emb:Optional[List[torch.Tensor]]=None,
null_emb:Optional[List[torch.Tensor]]=None):
imgH, imgW = img_shape if img_shape is not None else (1024, 1024)
# encode text prompts
with torch.no_grad():
if prompt_emb is None:
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
else:
prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1]
prompt_emb.to(self.transformer.device)
pooled_emb.to(self.transformer.device)
if null_emb is None:
null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)
else:
null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1]
null_prompt_emb.to(self.transformer.device)
null_pooled_emb.to(self.transformer.device)
# initialize latent
if latent is None:
z = self.initialize_latent((imgH, imgW), batch_size)
else:
z = latent
# timesteps (default option. You can make your custom here.)
self.scheduler.set_timesteps(NFE, device=self.device)
timesteps = self.scheduler.timesteps
sigmas = timesteps / self.scheduler.config.num_train_timesteps
# Solve ODE
pbar = tqdm(timesteps, total=NFE, desc='SD3 Euler')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.device)
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
sigma = sigmas[i]
sigma_next = sigmas[i+1] if i+1 < NFE else 0.0
z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v))
# decode
with torch.no_grad():
img = self.decode(z)
return img
@register_solver("flowdps")
class SD3FlowDPS(SD3Euler):
def data_consistency(self, z0t, operator, measurement, task, stepsize:int=30.0):
z0t = z0t.requires_grad_(True)
num_iters = 3
for _ in range(num_iters):
x0t = self.decode(z0t).float()
if "sr" in task:
loss = torch.linalg.norm((operator.A_pinv(measurement) - operator.A_pinv(operator.A(x0t))).view(1, -1))
else:
loss = torch.linalg.norm((operator.At(measurement) - operator.At(operator.A(x0t))).view(1, -1))
grad = torch.autograd.grad(loss, z0t)[0].half()
z0t = z0t - stepsize*grad
return z0t.detach()
def sample(self, measurement, operator, task,
prompts: List[str], NFE:int,
img_shape: Optional[Tuple[int]]=None,
cfg_scale: float=1.0, batch_size: int = 1,
step_size: float=30.0,
latent:Optional[List[torch.Tensor]]=None,
prompt_emb:Optional[List[torch.Tensor]]=None,
null_emb:Optional[List[torch.Tensor]]=None):
imgH, imgW = img_shape if img_shape is not None else (1024, 1024)
# encode text prompts
with torch.no_grad():
if prompt_emb is None:
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
else:
prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1]
prompt_emb.to(self.transformer.device)
pooled_emb.to(self.transformer.device)
if null_emb is None:
null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)
else:
null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1]
null_prompt_emb.to(self.transformer.device)
null_pooled_emb.to(self.transformer.device)
# initialize latent
if latent is None:
z = self.initialize_latent((imgH, imgW), batch_size)
else:
z = latent
# timesteps (default option. You can make your custom here.)
self.scheduler.config.shift = 4.0
self.scheduler.set_timesteps(NFE, device=self.device)
timesteps = self.scheduler.timesteps
sigmas = timesteps / self.scheduler.config.num_train_timesteps
# Solve ODE
pbar = tqdm(timesteps, total=NFE, desc='SD3-FlowDPS')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.device)
with torch.no_grad():
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
sigma = sigmas[i]
sigma_next = sigmas[i+1] if i+1 < NFE else 0.0
# denoising
z0t = z - sigma * (pred_null_v + cfg_scale * (pred_v-pred_null_v))
z1t = z + (1-sigma) * (pred_null_v + cfg_scale * (pred_v-pred_null_v))
delta = sigma - sigma_next
if i < NFE:
z0y = self.data_consistency(z0t, operator, measurement, task=task, stepsize=step_size)
z0y = (1-sigma) * z0t + sigma * z0y
# renoising
noise = math.sqrt(sigma_next) * z1t + math.sqrt(1-sigma_next) * torch.randn_like(z1t)
z = z0y + (sigma-delta) * (noise - z0y)
# decode
with torch.no_grad():
img = self.decode(z)
return img
@register_solver("flowchef")
class SD3FlowChef(SD3Euler):
def data_consistency(self, z0t, operator, measurement, task):
z0t = z0t.requires_grad_(True)
x0t = self.decode(z0t).float()
if "sr" in task:
loss = torch.linalg.norm((operator.A_pinv(measurement) - operator.A_pinv(operator.A(x0t))).view(1, -1))
else:
loss = torch.linalg.norm((operator.At(measurement) - operator.At(operator.A(x0t))).view(1, -1))
grad = torch.autograd.grad(loss, z0t)[0].half()
return grad.detach()
def sample(self, measurement, operator, task,
prompts: List[str], NFE:int,
img_shape: Optional[Tuple[int]]=None,
cfg_scale: float=1.0, batch_size: int = 1,
step_size: float=50.0,
latent:Optional[List[torch.Tensor]]=None,
prompt_emb:Optional[List[torch.Tensor]]=None,
null_emb:Optional[List[torch.Tensor]]=None):
imgH, imgW = img_shape if img_shape is not None else (1024, 1024)
# encode text prompts
with torch.no_grad():
if prompt_emb is None:
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
else:
prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1]
prompt_emb.to(self.transformer.device)
pooled_emb.to(self.transformer.device)
if null_emb is None:
null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)
else:
null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1]
null_prompt_emb.to(self.transformer.device)
null_pooled_emb.to(self.transformer.device)
# initialize latent
if latent is None:
z = self.initialize_latent((imgH, imgW), batch_size)
else:
z = latent
# timesteps (default option. You can make your custom here.)
self.scheduler.config.shift = 4.0
self.scheduler.set_timesteps(NFE, device=self.device)
timesteps = self.scheduler.timesteps
sigmas = timesteps / self.scheduler.config.num_train_timesteps
# Solve ODE
pbar = tqdm(timesteps, total=NFE, desc='SD3-FlowChef')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.device)
with torch.no_grad():
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
sigma = sigmas[i]
sigma_next = sigmas[i+1] if i+1 < NFE else 0.0
# denoising
z0t = z - sigma * (pred_null_v + cfg_scale * (pred_v-pred_null_v))
z1t = z + (1-sigma) * (pred_null_v + cfg_scale * (pred_v-pred_null_v))
delta = sigma - sigma_next
if i < NFE:
grad = self.data_consistency(z0t, operator, measurement, task=task)
# renoising
z = z0t + (sigma-delta) * (z1t - z0t) - step_size*grad
# decode
with torch.no_grad():
img = self.decode(z)
return img
@register_solver('psld')
class SD3PSLD(SD3Euler):
def sample(self, measurement, operator, task,
prompts: List[str], NFE:int,
img_shape: Optional[Tuple[int]]=None,
cfg_scale: float=1.0, batch_size: int = 1,
step_size: float=50.0,
latent:Optional[List[torch.Tensor]]=None,
prompt_emb:Optional[List[torch.Tensor]]=None,
null_emb:Optional[List[torch.Tensor]]=None):
imgH, imgW = img_shape if img_shape is not None else (1024, 1024)
# encode text prompts
with torch.no_grad():
if prompt_emb is None:
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
else:
prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1]
prompt_emb.to(self.transformer.device)
pooled_emb.to(self.transformer.device)
if null_emb is None:
null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)
else:
null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1]
null_prompt_emb.to(self.transformer.device)
null_pooled_emb.to(self.transformer.device)
# initialize latent
if latent is None:
z = self.initialize_latent((imgH, imgW), batch_size)
else:
z = latent
# timesteps (default option. You can make your custom here.)
self.scheduler.config.shift = 4.0
self.scheduler.set_timesteps(NFE, device=self.device)
timesteps = self.scheduler.timesteps
sigmas = timesteps / self.scheduler.config.num_train_timesteps
# Solve ODE
pbar = tqdm(timesteps, total=NFE, desc='SD3-PSLD')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.device)
z = z.requires_grad_(True)
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
pred_v = pred_null_v + cfg_scale * (pred_v - pred_null_v)
sigma = sigmas[i]
sigma_next = sigmas[i+1] if i+1 < NFE else 0.0
# denoising
z0t = z - sigma * pred_v
z1t = z + (1-sigma) * pred_v
delta = sigma - sigma_next
# DC & goodness of z0t
x_pred = self.decode(z0t).float()
y_pred = operator.A(x_pred)
y_residue = torch.linalg.norm((y_pred-measurement).view(1, -1))
if "sr" in task:
ortho_proj = x_pred.reshape(1, -1) - operator.A_pinv(y_pred).reshape(1, -1)
parallel_proj = operator.A_pinv(measurement).reshape(1, -1)
else:
ortho_proj = x_pred.reshape(1, -1) - operator.At(y_pred).reshape(1, -1)
parallel_proj = operator.At(measurement).reshape(1, -1)
proj = parallel_proj + ortho_proj
recon_z = self.encode(proj.reshape(1, 3, imgH, imgW).half())
z0_residue = torch.linalg.norm((z0t - recon_z).view(1, -1))
residue = 1.0 * y_residue + 0.1 * z0_residue
grad = torch.autograd.grad(residue, z)[0]
# renoising
z = z0t + (sigma-delta) * (z1t - z0t) - step_size*grad
z.detach()
# decode
with torch.no_grad():
img = self.decode(z)
return img