diff --git a/inversion.py b/inversion.py index 8ab53f7..cdce15f 100644 --- a/inversion.py +++ b/inversion.py @@ -1,18 +1,11 @@ from PIL import Image from omegaconf import OmegaConf import argparse -from collections import defaultdict -from diffusion_core import diffusion_models_registry from diffusion_core.utils import load_512 from diffusion_core.guiders.guidance_editing import GuidanceEditing from diffusion_core.utils import use_deterministic from diffusion_core.load_consistency import load_consistency - -def get_model(model_name, scheduler, device): - model = diffusion_models_registry[model_name](device) - return model - def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -23,9 +16,13 @@ def parse_args(): '--src_prompt', default='a cup of coffee with drawing of tulip putted on the wooden table', ) + parser.add_argument( + '--trg_prompt', + default='a cup of coffee with drawing of lion putted on the wooden table', + ) parser.add_argument( '--output_path', - default='inverted_image.jpg', + default='simple_edited_image.jpg', ) args = parser.parse_args() return args @@ -37,13 +34,16 @@ def parse_args(): device = 'cuda:0' solver = load_consistency(device) + # Minimal config without attention guidance to save memory config = OmegaConf.create() - config['cfg_schedule'] = [0, 0, 0, 0] - config['guiders'] = [] - config['noise_rescaling_setup'] = {"type": 'identity_rescaler', - 'init_setup': None} + config['cfg_schedule'] = [0, 7, 7, 7] # Simple CFG schedule + config['guiders'] = [] # No attention guidance + config['noise_rescaling_setup'] = {"type": 'identity_rescaler', 'init_setup': None} + image = load_512(args.image_path) guidance = GuidanceEditing(solver, config, device) - result = guidance(image, args.src_prompt, - args.src_prompt) - Image.fromarray(result).save(args.output_path) \ No newline at end of file + + # Now using different source and target prompts + result = guidance(image, args.src_prompt, args.trg_prompt) + Image.fromarray(result).save(args.output_path) + print(f"Edited image saved to {args.output_path}")