diff --git a/src/api.rs b/src/api.rs index 41d9390..9dfc439 100644 --- a/src/api.rs +++ b/src/api.rs @@ -699,6 +699,16 @@ impl ModelConfig { unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t { unsafe { + // This is required to support img2img after text2img generation + // otherwise the context is cached and won't have a decode graph + // leading to an assertion error in sdcpp + if let Some((sd_ctx, sd_ctx_params)) = self.diffusion_ctx.as_ref() + && sd_ctx_params.vae_decode_only != vae_decode_only + { + sd_set_progress_callback(None, null_mut()); + free_sd_ctx(*sd_ctx); + self.diffusion_ctx = None; + } if self.diffusion_ctx.is_none() { let sd_ctx_params = sd_ctx_params_t { model_path: self.model.as_ptr(), @@ -850,12 +860,20 @@ pub struct Config { /// Path to the input image, required by img2img #[builder(default = "Default::default()")] - init_img: CLibPath, + init_img: PathBuf, + + /// Path to the image used as a mask for img2img + #[builder(default = "Default::default()")] + mask_img: PathBuf, /// Path to image condition, control net #[builder(default = "Default::default()")] control_image: CLibPath, + /// Paths to reference images for in-context conditioning (e.g. for Flux2) + #[builder(default = "Default::default()")] + ref_images: Vec, + /// Path to write result image to (default: ./output.png) #[builder(default = "PathBuf::from(\"./output.png\")")] output: PathBuf, @@ -1084,7 +1102,9 @@ impl From for ConfigBuilder { builder .pm_id_images_dir(value.pm_id_images_dir) .init_img(value.init_img) + .mask_img(value.mask_img) .control_image(value.control_image) + .ref_images(value.ref_images) .output(value.output) .prompt(value.prompt) .negative_prompt(value.negative_prompt) @@ -1224,15 +1244,20 @@ fn gen_img_maybe_progress( let prompt: CLibString = CLibString::from(config.prompt.as_str()); let files = output_files(&config.output, config.batch_count); unsafe { - let sd_ctx = model_config.diffusion_ctx(true); + let has_init_image = config.init_img.exists(); + let has_mask_image = config.mask_img.exists(); + + let is_decode_only = !has_init_image; + let sd_ctx = model_config.diffusion_ctx(is_decode_only); let upscaler_ctx = model_config.upscaler_ctx(); - let init_image = sd_image_t { + + let mut init_image = sd_image_t { width: 0, height: 0, channel: 3, - data: null_mut(), + data: std::ptr::null_mut(), }; - let mask_image = sd_image_t { + let mut mask_image = sd_image_t { width: config.width as u32, height: config.height as u32, channel: 1, @@ -1292,6 +1317,75 @@ fn gen_img_maybe_progress( style_strength: config.pm_style_strength, }; + // Declare the buffer in the function scope so it outlives the match block + let mut image_buffer: Vec = Vec::new(); + let mut mask_buffer: Vec = Vec::new(); + + if has_init_image { + let img = image::open(&config.init_img)?; + image_buffer = img.to_rgb8().into_raw(); + + init_image = sd_image_t { + width: img.width(), + height: img.height(), + channel: 3, + data: image_buffer.as_mut_ptr(), + } + } + + if has_mask_image { + let img = image::open(&config.mask_img)?; + // Masks have to have single channel luminosity information only + mask_buffer = img.to_luma8().into_raw(); + + mask_image = sd_image_t { + width: img.width(), + height: img.height(), + channel: 1, + data: mask_buffer.as_mut_ptr(), + } + } + + // stable-diffusion.cpp assumes that a mask is also included with the img2img flow + // if a mask is not provided, we use a flat white mask meaning all of the image is in scope + // otherwise generate_image throws a sigsegv when it tries to assign the mask + if !image_buffer.is_empty() && mask_buffer.is_empty() { + let img: ImageBuffer, Vec> = + ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255])); + mask_buffer = img.into_raw(); + mask_image = sd_image_t { + width: init_image.width, + height: init_image.height, + channel: 1, + data: mask_buffer.as_mut_ptr(), + } + } + + let mut ref_image_list = Vec::new(); + let mut ref_pixel_storage = Vec::new(); + for ref_path in &config.ref_images { + if ref_path.exists() { + let img = image::open(ref_path)?; + let image_data = img.to_rgb8().into_raw(); + + ref_pixel_storage.push(image_data); + let storage_ref = ref_pixel_storage.last_mut().unwrap(); + ref_image_list.push(sd_image_t { + width: img.width(), + height: img.height(), + channel: 3, + data: storage_ref.as_mut_ptr(), + }); + } + } + + let num_ref_images = ref_image_list.len(); + let ref_image_ptr = if num_ref_images > 0 { + ref_image_list.as_mut_ptr() + } else { + null_mut() + }; + unsafe extern "C" fn save_preview_local( _step: ::std::os::raw::c_int, _frame_count: ::std::os::raw::c_int, @@ -1342,8 +1436,8 @@ fn gen_img_maybe_progress( negative_prompt: config.negative_prompt.as_ptr(), clip_skip: config.clip_skip as i32, init_image, - ref_images: null_mut(), - ref_images_count: 0, + ref_images: ref_image_ptr, + ref_images_count: num_ref_images as i32, increase_ref_index: false, mask_image, width: config.width, @@ -1411,6 +1505,7 @@ fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), Di #[cfg(test)] mod tests { + use image::{DynamicImage, ImageBuffer, Rgba}; use std::path::PathBuf; use crate::{ @@ -1455,6 +1550,56 @@ mod tests { .unwrap(); } + #[ignore] + #[test] + fn test_img2img_gen() { + let model_path = + download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt") + .unwrap(); + let gen_img_output = "./output_img.png"; + let config = ConfigBuilder::default() + .prompt("A high quality 3d texture") + .output(PathBuf::from(gen_img_output)) + .batch_count(1) + .build() + .unwrap(); + + let mut model_config = ModelConfigBuilder::default() + .model(model_path) + .build() + .unwrap(); + + gen_img(&config, &mut model_config).unwrap(); + + // 2. Create conditioning image: Gradient square + let mut cond = ImageBuffer::new(512, 512); + for (x, y, pixel) in cond.enumerate_pixels_mut() { + let r = (x as f32 / 512.0 * 255.0) as u8; + let g = (y as f32 / 512.0 * 255.0) as u8; + let b = 127; + *pixel = Rgba([r, g, b, 255]); + } + let cond_path = "test_cond_image.png"; + DynamicImage::ImageRgba8(cond) + .save(cond_path) + .expect("Failed to save reference image"); + + // 3. Call refinement using the generated image as input + let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image."; + let img2img_config = ConfigBuilder::default() + .prompt(refine_prompt) + .output(PathBuf::from("./output_img_ref.png")) + .ref_images(vec![PathBuf::from(cond_path)]) + .init_img(PathBuf::from(gen_img_output)) + .batch_count(1) + .build() + .unwrap(); + gen_img(&img2img_config, &mut model_config).unwrap(); + + // 4. Ensure decoder only mode works after img2img generation + gen_img(&config, &mut model_config).unwrap(); + } + #[ignore] #[test] fn test_img_gen() {