-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsample.py
More file actions
161 lines (133 loc) · 6.32 KB
/
sample.py
File metadata and controls
161 lines (133 loc) · 6.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import argparse
import torch
from src.sd3 import SD3ControllerPipeline
from src.flux import FluxControllerPipeline
from src.sd1 import SD1ControllerPipeline
from src.controller import Controller
from peft import PeftModel
import yaml
import pandas as pd
def main(args):
    """Sample images from a diffusion pipeline (SD3 / FLUX / SD1) over a prompt dataset.

    Creates a fresh numbered experiment directory under ``args.exp_dir``, dumps the
    run configuration to ``config.yaml``, generates one image per (seed, prompt)
    pair, and writes a ``prompts.csv`` index mapping each image path to its prompt.

    Args:
        args: Parsed command-line namespace (see the argparse definition in this
            file for the full set of options).

    Raises:
        ValueError: If ``args.model`` is not one of "SD3", "FLUX", "SD1".
    """
    # Create experiment directory. Ensure the parent exists first: on a fresh run
    # with a not-yet-created --exp-dir, os.listdir would otherwise raise
    # FileNotFoundError before the first experiment folder is ever made.
    os.makedirs(args.exp_dir, exist_ok=True)
    exp_idx = len([name for name in os.listdir(args.exp_dir)
                   if os.path.isdir(os.path.join(args.exp_dir, name))])
    exp_dir = os.path.join(args.exp_dir, f"exp_{exp_idx}")
    os.makedirs(exp_dir, exist_ok=True)
    # Save the configuration so the run is reproducible.
    with open(os.path.join(exp_dir, "config.yaml"), "w") as f:
        yaml.dump(vars(args), f)
    # Load the model together with its per-model default sampling hyper-parameters.
    if args.model == "SD3":
        pipe = SD3ControllerPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16
        )
        NUM_INFERENCE_STEPS = 28
        GUIDANCE_SCALE = 4.5
        MAX_SEQ_LENGTH = 77
    elif args.model == "FLUX":
        pipe = FluxControllerPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16
        )
        NUM_INFERENCE_STEPS = 28
        GUIDANCE_SCALE = 3.5
        MAX_SEQ_LENGTH = 256
    elif args.model == "SD1":
        pipe = SD1ControllerPipeline.from_pretrained(
            "sd-legacy/stable-diffusion-v1-5",
            torch_dtype=torch.float16
        )
        NUM_INFERENCE_STEPS = 50
        GUIDANCE_SCALE = 7.5
        MAX_SEQ_LENGTH = 77
    else:
        # argparse `choices` normally prevents this, but fail loudly if main() is
        # called programmatically with an unknown model instead of dying later
        # with a NameError on `pipe`.
        raise ValueError(f"Unknown model: {args.model!r}")
    if args.path:
        # Merge fine-tuned LoRA/PEFT weights into the base denoiser.
        # NOTE(review): assumes every pipeline exposes its denoiser as
        # `transformer`; SD1-style pipelines conventionally name it `unet` —
        # confirm SD1ControllerPipeline aliases it.
        pipe.transformer = PeftModel.from_pretrained(pipe.transformer, args.path).merge_and_unload()
    # For weaker GPUs: pick exactly one memory-saving strategy.
    if args.save_memory == 0:
        # Everything to GPU
        pipe.to("cuda")
    elif args.save_memory == 1:
        # Moves larges modules to CPU
        pipe.enable_model_cpu_offload()
    elif args.save_memory == 2:
        # Sequentially moves modules to CPU
        pipe.enable_sequential_cpu_offload()
    elif args.save_memory == 3:
        # Sequentially moves modules to CPU + gradient checkpointing
        pipe.enable_sequential_cpu_offload()
        from functools import partial
        pipe.transformer.enable_gradient_checkpointing(
            gradient_checkpointing_func=partial(torch.utils.checkpoint.checkpoint,
                                                use_reentrant=False)
        )
    # Disable gradients for all parameters — this script only runs inference.
    # Not every pipeline has every sub-module, hence the getattr guards.
    if getattr(pipe, "transformer", None) is not None: pipe.transformer.requires_grad_(False)
    if getattr(pipe, "vae", None) is not None: pipe.vae.requires_grad_(False)
    if getattr(pipe, "text_encoder", None) is not None: pipe.text_encoder.requires_grad_(False)
    if getattr(pipe, "text_encoder_2", None) is not None: pipe.text_encoder_2.requires_grad_(False)
    if getattr(pipe, "text_encoder_3", None) is not None: pipe.text_encoder_3.requires_grad_(False)
    print("Model loaded...")
    # Load dataset (a YAML list of {prompt, t5, clip} records — see Controller usage below).
    with open(args.dataset, "r") as file:
        dataset = yaml.safe_load(file)
    # Set seed(s): either an inclusive-exclusive range or the single --seed value.
    if args.seed_range:
        SEEDS = list(range(args.seed_range[0], args.seed_range[1]))
    else:
        SEEDS = [args.seed]
    print(f"Sampling {len(dataset)} prompts with {len(SEEDS)} seeds each...")
    # Override default parameters when given on the command line.
    NUM_INFERENCE_STEPS = args.num_steps if args.num_steps is not None else NUM_INFERENCE_STEPS
    GUIDANCE_SCALE = args.guidance_scale if args.guidance_scale is not None else GUIDANCE_SCALE
    print(f"Using {NUM_INFERENCE_STEPS} inference steps and guidance scale {GUIDANCE_SCALE}...")
    image_counter = 0
    output_paths = []
    for seed in SEEDS:
        for datum in dataset:
            print(f"Sampling image {image_counter} with seed {seed} and prompt: {datum['prompt']}")
            # Fresh Controller per prompt: it carries the token ids to steer on.
            optim = Controller(
                t5_ids=datum["t5"],
                clip_ids=datum["clip"],
                lambda_scale=args.lambda_scale,
                heuristic=args.heuristic,
                model=args.model
            )
            # CPU generator keeps the sampled noise identical across devices.
            image = pipe(
                datum["prompt"],
                num_inference_steps=NUM_INFERENCE_STEPS,
                guidance_scale=GUIDANCE_SCALE,
                max_sequence_length=MAX_SEQ_LENGTH,
                height=args.image_size,
                width=args.image_size,
                generator=torch.Generator("cpu").manual_seed(seed),
                controller=optim if args.lambda_scale != 0 else None,
            ).images[0]
            image_path = os.path.join(exp_dir, f"image_{image_counter}.png")
            output_paths.append({
                "image_path" : image_path,
                "prompt": datum["prompt"],
            })
            image.save(image_path)
            image_counter += 1
    # Persist the path -> prompt mapping for downstream evaluation.
    df = pd.DataFrame(output_paths)
    df.to_csv(os.path.join(exp_dir, "prompts.csv"), index=False)
def _build_parser():
    """Construct the command-line interface for the sampling script."""
    p = argparse.ArgumentParser(description="Sample from the model")
    # Conditioning parameters
    p.add_argument("--exp-dir", type=str, default="images", help="Experiment directory")
    p.add_argument("--dataset", type=str, required=True, help="Path to the dataset YAML file")
    p.add_argument("--path", type=str, default=None, help="Path to the JEDI model weights")
    p.add_argument("--image-size", type=int, default=512, help="Size of the generated images")
    p.add_argument("--num-steps", type=int, default=None, help="Number of diffusion steps, if not specified, use model default")
    p.add_argument("--guidance-scale", type=float, default=None, help="Guidance scale for classifier-free guidance")
    # Sampling parameters
    p.add_argument("--seed", type=int, default=0, help="Random seed for sampling")
    p.add_argument("--seed-range", type=int, nargs=2, help="Range of seeds for sampling")
    # JEDI parameters
    p.add_argument("--lambda-scale", type=float, default=1, help="Lambda scaling factor for Controller, set to 0 to disable Controller")
    p.add_argument("--heuristic", type=str, choices=["focus", "conform", "attend_and_excite", "divide_and_bind", "jedi"], default="focus", help="Controller method to use")
    p.add_argument("--model", type=str, choices=["SD3", "FLUX", "SD1"], default="SD3", help="Base model to use")
    p.add_argument("--save-memory", type=int, choices=[0, 1, 2, 3], default=0, help="Memory saving mode: 0 (none), 1 (offload), 2 (sequential offload), 3 (sequential + grad checkpointing)")
    return p


if __name__ == "__main__":
    main(_build_parser().parse_args())