Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 226 additions & 1 deletion flux/dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from functools import partial
from pathlib import Path
import os

import mlx.core as mx
import mlx.nn as nn
Expand All @@ -13,8 +14,106 @@
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image

from huggingface_hub import HfApi, interpreter_login
from huggingface_hub.utils import HfFolder

from flux import FluxPipeline, Trainer, load_dataset, save_config

class FinetuningDataset:
def __init__(self, flux, args):
self.args = args
self.flux = flux
self.dataset_base = Path(args.dataset)
dataset_index = self.dataset_base / "index.json"
if not dataset_index.exists():
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
with open(dataset_index, "r") as f:
self.index = json.load(f)

self.latents = []
self.t5_features = []
self.clip_features = []

def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size

a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()

# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)

# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)

# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)

return mx.array(np.array(img))

def encode_images(self):
"""Encode the images in the latent space to prepare for training."""
self.flux.ae.eval()
for sample in tqdm(self.index["data"]):
input_img = Image.open(self.dataset_base / sample["image"])
for i in range(self.args.num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)

def encode_prompts(self):
"""Pre-encode the prompts so that we don't recompute them during
training (doesn't allow finetuning the text encoders)."""
for sample in tqdm(self.index["data"]):
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)

def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]


def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning."""
Expand Down Expand Up @@ -58,6 +157,108 @@ def save_adapters(adapter_name, flux, args):
},
)

def push_to_hub(args):
if args.hf_token is None:
interpreter_login(new_session=False, write_permission=True)
else:
HfFolder.save_token(args.hf_token)

repo_id = args.hf_repo_id or f"{HfFolder.get_token_username()}/{args.output_dir}"

readme_content = generate_readme(args, repo_id)
readme_path = os.path.join(args.output_dir, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
f.write(readme_content)

api = HfApi()

api.create_repo(
repo_id,
private=args.hf_private,
exist_ok=True
)

api.upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
ignore_patterns=["*.yaml", "*.pt"],
repo_type="model",
)

def generate_readme(args, repo_id):
import yaml
import re
base_model = f"flux-{args.model}"
tags = [
"text-to-image",
"flux",
"lora",
"diffusers",
"template:sd-lora",
"mlx",
"mlx-trainer"
]

widgets = []
sample_image_paths = []
# Look for progress images directly in the output directory
for filename in os.listdir(args.output_dir):
match = re.search(r"(\d+)_progress\.png$", filename)
if match:
iteration = int(match.group(1))
sample_image_paths.append((iteration, filename))

sample_image_paths.sort(key=lambda x: x[0], reverse=True)

if sample_image_paths:
widgets.append(
{
"text": args.progress_prompt,
"output": {
"url": sample_image_paths[0][1]
},
}
)

readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if sample_image_paths else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model}
license: other
---

# {os.path.basename(args.output_dir)}
Model trained with the MLX Flux Dreambooth script

<Gallery />

## Use it with [MLX](https://github.com/ml-explore/mlx-examples)
```py
from flux import FluxPipeline
import mlx.core as mx
flux = FluxPipeline("flux-{args.model}")
flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks})
flux.flow.load_weights("{repo_id}")
image = flux.generate_images("{args.progress_prompt}", n_images=1, num_steps={args.progress_steps})
image.save("my_image.png")
```

## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/{args.model}', torch_dtype=torch.bfloat16).to('cuda')
pipeline.load_lora_weights('{repo_id}')
image = pipeline({args.progress_prompt}').images[0]
image.save("my_image.png")
```

For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux).
"""
return readme_content

def setup_arg_parser():
"""Set up and return the argument parser."""
Expand Down Expand Up @@ -148,7 +349,28 @@ def setup_arg_parser():
parser.add_argument(
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
)

parser.add_argument(
"--push_to_hub",
action="store_true",
help="Push the model to Hugging Face Hub after training",
)
parser.add_argument(
"--hf_token",
type=str,
default=None,
help="Hugging Face token for pushing to Hub",
)
parser.add_argument(
"--hf_repo_id",
type=str,
default=None,
help="Hugging Face repository ID for pushing to Hub",
)
parser.add_argument(
"--hf_private",
action="store_true",
help="Make the Hugging Face repository private",
)
parser.add_argument("dataset")
return parser

Expand Down Expand Up @@ -287,6 +509,9 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
if (i + 1) % 10 == 0:
losses = []
tic = time.time()

if args.push_to_hub:
push_to_hub(args)

save_adapters("final_adapters.safetensors", flux, args)
print("Training successful.")