Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions inversion.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
from PIL import Image
from omegaconf import OmegaConf
import argparse
from collections import defaultdict
from diffusion_core import diffusion_models_registry
from diffusion_core.utils import load_512
from diffusion_core.guiders.guidance_editing import GuidanceEditing
from diffusion_core.utils import use_deterministic
from diffusion_core.load_consistency import load_consistency


def get_model(model_name, scheduler, device):
model = diffusion_models_registry[model_name](device)
return model

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -23,9 +16,13 @@ def parse_args():
'--src_prompt',
default='a cup of coffee with drawing of tulip putted on the wooden table',
)
parser.add_argument(
'--trg_prompt',
default='a cup of coffee with drawing of lion putted on the wooden table',
)
parser.add_argument(
'--output_path',
default='inverted_image.jpg',
default='simple_edited_image.jpg',
)
args = parser.parse_args()
return args
Expand All @@ -37,13 +34,16 @@ def parse_args():
device = 'cuda:0'
solver = load_consistency(device)

# Minimal config without attention guidance to save memory
config = OmegaConf.create()
config['cfg_schedule'] = [0, 0, 0, 0]
config['guiders'] = []
config['noise_rescaling_setup'] = {"type": 'identity_rescaler',
'init_setup': None}
config['cfg_schedule'] = [0, 7, 7, 7] # Simple CFG schedule
config['guiders'] = [] # No attention guidance
config['noise_rescaling_setup'] = {"type": 'identity_rescaler', 'init_setup': None}

image = load_512(args.image_path)
guidance = GuidanceEditing(solver, config, device)
result = guidance(image, args.src_prompt,
args.src_prompt)
Image.fromarray(result).save(args.output_path)

# Now using different source and target prompts
result = guidance(image, args.src_prompt, args.trg_prompt)
Image.fromarray(result).save(args.output_path)
print(f"Edited image saved to {args.output_path}")