From 230999769165bcd018d04e1b7eb83ac410d9fe2a Mon Sep 17 00:00:00 2001 From: theoakbear <138836705+theoakbear@users.noreply.github.com> Date: Tue, 17 Feb 2026 16:53:01 -0800 Subject: [PATCH 1/3] Added support for reference images and img2img scenarios --- src/api.rs | 203 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 194 insertions(+), 9 deletions(-) diff --git a/src/api.rs b/src/api.rs index 41d9390..d9c4a16 100644 --- a/src/api.rs +++ b/src/api.rs @@ -34,6 +34,7 @@ use diffusion_rs_sys::sd_set_progress_callback; use diffusion_rs_sys::sd_slg_params_t; use diffusion_rs_sys::sd_tiling_params_t; use diffusion_rs_sys::upscaler_ctx_t; +use image::DynamicImage; use image::ImageBuffer; use image::ImageError; use image::RgbImage; @@ -699,6 +700,17 @@ impl ModelConfig { unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t { unsafe { + // This is required to support img2img after text2img generation + // otherwise the context is cached and won't have a decode graph + // leading to an assertion error in sdcpp + if self.diffusion_ctx.is_some() { + let (sd_ctx, sd_ctx_params) = self.diffusion_ctx.unwrap(); + if sd_ctx_params.vae_decode_only != vae_decode_only { + sd_set_progress_callback(None, null_mut()); + free_sd_ctx(sd_ctx); + } + self.diffusion_ctx = None; + } if self.diffusion_ctx.is_none() { let sd_ctx_params = sd_ctx_params_t { model_path: self.model.as_ptr(), @@ -852,10 +864,18 @@ pub struct Config { #[builder(default = "Default::default()")] init_img: CLibPath, + /// Path to the image used as a mask for img2img + #[builder(default = "Default::default()")] + mask_img: CLibPath, + /// Path to image condition, control net #[builder(default = "Default::default()")] control_image: CLibPath, + /// Paths to reference images for in-context conditioning (e.g. for Flux2) + #[builder(default = "Default::default()")] + ref_images: CLibPathVec, + /// Path to write result image to (default: ./output.png) #[builder(default = "PathBuf::from(\"./output.png\")")] output: PathBuf, @@ -1084,7 +1104,9 @@ impl From for ConfigBuilder { builder .pm_id_images_dir(value.pm_id_images_dir) .init_img(value.init_img) + .mask_img(value.mask_img) .control_image(value.control_image) + .ref_images(value.ref_images) .output(value.output) .prompt(value.prompt) .negative_prompt(value.negative_prompt) @@ -1144,6 +1166,33 @@ impl CLibPath { self.0.as_ptr() } } +#[derive(Debug, Clone, Default)] +struct CLibPathVec(Vec); + +impl FromIterator for CLibPathVec { + fn from_iter>(iter: T) -> Self { + let inner_vec: Vec = iter.into_iter().collect(); + + CLibPathVec(inner_vec) + } +} + +impl From> for CLibPathVec { + fn from(value: Vec<&Path>) -> CLibPathVec { + value.iter() + .map(|&p| CLibPath::from(p)) + .collect() + } +} + +impl From> for CLibPathVec { + fn from(value: Vec) -> CLibPathVec { + value.iter() + .map(|p| CLibPath::from(p.as_path())) + .collect() + } +} + impl From for CLibPath { fn from(value: PathBuf) -> Self { @@ -1224,15 +1273,28 @@ fn gen_img_maybe_progress( let prompt: CLibString = CLibString::from(config.prompt.as_str()); let files = output_files(&config.output, config.batch_count); unsafe { - let sd_ctx = model_config.diffusion_ctx(true); + let init_img_path_str = config.init_img.0.to_string_lossy(); + let mask_img_path_str = config.mask_img.0.to_string_lossy(); + let init_img_ref = init_img_path_str.as_ref(); + let mask_img_ref = mask_img_path_str.as_ref(); + let init_img_path = Path::new(&init_img_ref); + let mask_img_path = Path::new(&mask_img_ref); + + let has_init_image = init_img_path.exists(); + let has_mask_image = mask_img_path.exists(); + let has_ref_images = config.ref_images.0.len() > 0; + + let is_decode_only = !has_init_image; + let sd_ctx = model_config.diffusion_ctx(is_decode_only); let upscaler_ctx = model_config.upscaler_ctx(); - let init_image = sd_image_t { - width: 0, - height: 0, - channel: 3, - data: null_mut(), + + let mut init_image = sd_image_t { + width: 0, + height: 0, + channel: 3, + data: std::ptr::null_mut(), }; - let mask_image = sd_image_t { + let mut mask_image = sd_image_t { width: config.width as u32, height: config.height as u32, channel: 1, @@ -1292,6 +1354,83 @@ fn gen_img_maybe_progress( style_strength: config.pm_style_strength, }; + + // Declare the buffer in the function scope so it outlives the match block + let mut image_buffer: Vec = Vec::new(); + let mut mask_buffer: Vec = Vec::new(); + + if has_init_image { + let img = image::open(&init_img_path)?; + let (width, height) = (img.width(), img.height()); + image_buffer = img.to_rgb8().into_raw(); + + init_image = sd_image_t { + width: width, + height: height, + channel: 3, + data: image_buffer.as_mut_ptr(), + } + } + + if has_mask_image { + let img = image::open(&mask_img_path)?; + let (width, height) = (img.width(), img.height()); + // Masks have to have single channel luminosity information only + mask_buffer = img.to_luma8().into_raw(); + + mask_image = sd_image_t { + width: width, + height: height, + channel: 1, + data: mask_buffer.as_mut_ptr(), + } + } + + // stable-diffusion.cpp assumes that a mask is also included with the img2img flow + // if a mask is not provided, we use a flat white mask meaning all of the image is in scope + // otherwise generate_image throws a sigsegv when it tries to assign the mask + if !image_buffer.is_empty() && mask_buffer.is_empty() { + let mut img: ImageBuffer, Vec> = + ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255])); + mask_buffer = img.into_raw(); + mask_image = sd_image_t { + width: init_image.width, + height: init_image.height, + channel: 1, + data: mask_buffer.as_mut_ptr() + } + } + + let mut ref_image_list = Vec::new(); + let mut ref_pixel_storage = Vec::new(); + for ref_path in &config.ref_images.0 { + let ref_img_path_str = ref_path.0.to_string_lossy(); + let ref_img_ref = ref_img_path_str.as_ref(); + let ref_img_path = Path::new(&ref_img_ref); + + if ref_img_path.exists() { + let img = image::open(&ref_img_path)?; + let (width, height) = (img.width(), img.height()); + let image_data = img.to_rgb8().into_raw(); + + ref_pixel_storage.push(image_data); + let storage_ref = ref_pixel_storage.last_mut().unwrap(); + ref_image_list.push(sd_image_t { + width: width, + height: height, + channel: 3, + data: storage_ref.as_mut_ptr(), + }); + } + } + + let num_ref_images = ref_image_list.len(); + let ref_image_ptr = if num_ref_images > 0 { + ref_image_list.as_mut_ptr() + } else { + null_mut() + }; + unsafe extern "C" fn save_preview_local( _step: ::std::os::raw::c_int, _frame_count: ::std::os::raw::c_int, @@ -1342,8 +1481,8 @@ fn gen_img_maybe_progress( negative_prompt: config.negative_prompt.as_ptr(), clip_skip: config.clip_skip as i32, init_image, - ref_images: null_mut(), - ref_images_count: 0, + ref_images: ref_image_ptr, + ref_images_count: num_ref_images as i32, increase_ref_index: false, mask_image, width: config.width, @@ -1412,6 +1551,7 @@ fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), Di #[cfg(test)] mod tests { use std::path::PathBuf; + use image::{Rgba, ImageBuffer, DynamicImage}; use crate::{ api::{ConfigBuilderError, ModelConfigBuilder}, @@ -1455,6 +1595,51 @@ mod tests { .unwrap(); } + #[ignore] + #[test] + fn test_img2img_gen() { + let model_path = + download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt") + .unwrap(); + let gen_img_output = "./output_img.png"; + let config = ConfigBuilder::default() + .prompt("A high quality 3d texture") + .output(PathBuf::from(gen_img_output)) + .batch_count(1) + .build() + .unwrap(); + + let mut model_config = ModelConfigBuilder::default() + .model(model_path) + .build() + .unwrap(); + + gen_img(&config, &mut model_config); + + // 2. Create conditioning image: Gradient square + let mut cond = ImageBuffer::new(512, 512); + for (x, y, pixel) in cond.enumerate_pixels_mut() { + let r = (x as f32 / 512.0 * 255.0) as u8; + let g = (y as f32 / 512.0 * 255.0) as u8; + let b = 127; + *pixel = Rgba([r, g, b, 255]); + } + let cond_path = "test_cond_image.png"; + DynamicImage::ImageRgba8(cond).save(cond_path).expect("Failed to save reference image"); + + // 3. Call refinement using the generated image as input + let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image."; + let img2img_config = ConfigBuilder::default() + .prompt(refine_prompt) + .output(PathBuf::from("./output_img_ref.png")) + .ref_images(vec![PathBuf::from(cond_path)]) + .init_img(PathBuf::from(gen_img_output)) + .batch_count(1) + .build() + .unwrap(); + gen_img(&img2img_config, &mut model_config); + } + #[ignore] #[test] fn test_img_gen() { From 873f14da2c5af6d2987a1a6ddf40d5360e88f03c Mon Sep 17 00:00:00 2001 From: theoakbear <138836705+theoakbear@users.noreply.github.com> Date: Thu, 19 Feb 2026 10:53:33 -0800 Subject: [PATCH 2/3] Changed parts of the API to use PathBuf per feedback. Addressed compiler warnings and updated tests with new use case. --- src/api.rs | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/api.rs b/src/api.rs index d9c4a16..c157bdc 100644 --- a/src/api.rs +++ b/src/api.rs @@ -34,7 +34,6 @@ use diffusion_rs_sys::sd_set_progress_callback; use diffusion_rs_sys::sd_slg_params_t; use diffusion_rs_sys::sd_tiling_params_t; use diffusion_rs_sys::upscaler_ctx_t; -use image::DynamicImage; use image::ImageBuffer; use image::ImageError; use image::RgbImage; @@ -862,11 +861,11 @@ pub struct Config { /// Path to the input image, required by img2img #[builder(default = "Default::default()")] - init_img: CLibPath, + init_img: PathBuf, /// Path to the image used as a mask for img2img #[builder(default = "Default::default()")] - mask_img: CLibPath, + mask_img: PathBuf, /// Path to image condition, control net #[builder(default = "Default::default()")] @@ -874,7 +873,7 @@ pub struct Config { /// Paths to reference images for in-context conditioning (e.g. for Flux2) #[builder(default = "Default::default()")] - ref_images: CLibPathVec, + ref_images: Vec, /// Path to write result image to (default: ./output.png) #[builder(default = "PathBuf::from(\"./output.png\")")] @@ -1166,6 +1165,7 @@ impl CLibPath { self.0.as_ptr() } } + #[derive(Debug, Clone, Default)] struct CLibPathVec(Vec); @@ -1273,16 +1273,11 @@ fn gen_img_maybe_progress( let prompt: CLibString = CLibString::from(config.prompt.as_str()); let files = output_files(&config.output, config.batch_count); unsafe { - let init_img_path_str = config.init_img.0.to_string_lossy(); - let mask_img_path_str = config.mask_img.0.to_string_lossy(); - let init_img_ref = init_img_path_str.as_ref(); - let mask_img_ref = mask_img_path_str.as_ref(); - let init_img_path = Path::new(&init_img_ref); - let mask_img_path = Path::new(&mask_img_ref); + let init_img_path = Path::new(&config.init_img); + let mask_img_path = Path::new(&config.mask_img); let has_init_image = init_img_path.exists(); let has_mask_image = mask_img_path.exists(); - let has_ref_images = config.ref_images.0.len() > 0; let is_decode_only = !has_init_image; let sd_ctx = model_config.diffusion_ctx(is_decode_only); @@ -1390,7 +1385,7 @@ fn gen_img_maybe_progress( // if a mask is not provided, we use a flat white mask meaning all of the image is in scope // otherwise generate_image throws a sigsegv when it tries to assign the mask if !image_buffer.is_empty() && mask_buffer.is_empty() { - let mut img: ImageBuffer, Vec> = + let img: ImageBuffer, Vec> = ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255])); mask_buffer = img.into_raw(); mask_image = sd_image_t { @@ -1403,10 +1398,8 @@ fn gen_img_maybe_progress( let mut ref_image_list = Vec::new(); let mut ref_pixel_storage = Vec::new(); - for ref_path in &config.ref_images.0 { - let ref_img_path_str = ref_path.0.to_string_lossy(); - let ref_img_ref = ref_img_path_str.as_ref(); - let ref_img_path = Path::new(&ref_img_ref); + for ref_path_str in &config.ref_images { + let ref_img_path = Path::new(&ref_path_str); if ref_img_path.exists() { let img = image::open(&ref_img_path)?; @@ -1614,7 +1607,7 @@ mod tests { .build() .unwrap(); - gen_img(&config, &mut model_config); + gen_img(&config, &mut model_config).unwrap(); // 2. Create conditioning image: Gradient square let mut cond = ImageBuffer::new(512, 512); @@ -1637,7 +1630,10 @@ mod tests { .batch_count(1) .build() .unwrap(); - gen_img(&img2img_config, &mut model_config); + gen_img(&img2img_config, &mut model_config).unwrap(); + + // 4. Ensure decoder only mode works after img2img generation + gen_img(&config, &mut model_config).unwrap(); } #[ignore] From cdb7e87d3dc093bcfb33ef43e27e333ed7aa22f5 Mon Sep 17 00:00:00 2001 From: newfla Date: Fri, 20 Feb 2026 10:15:08 +0100 Subject: [PATCH 3/3] chore: removed unused CLibPathVec and avoided rewrapping PathBuf to Path --- src/api.rs | 120 +++++++++++++++++++---------------------------------- 1 file changed, 42 insertions(+), 78 deletions(-) diff --git a/src/api.rs b/src/api.rs index c157bdc..9dfc439 100644 --- a/src/api.rs +++ b/src/api.rs @@ -699,15 +699,14 @@ impl ModelConfig { unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t { unsafe { - // This is required to support img2img after text2img generation - // otherwise the context is cached and won't have a decode graph + // This is required to support img2img after text2img generation + // otherwise the context is cached and won't have a decode graph // leading to an assertion error in sdcpp - if self.diffusion_ctx.is_some() { - let (sd_ctx, sd_ctx_params) = self.diffusion_ctx.unwrap(); - if sd_ctx_params.vae_decode_only != vae_decode_only { - sd_set_progress_callback(None, null_mut()); - free_sd_ctx(sd_ctx); - } + if let Some((sd_ctx, sd_ctx_params)) = self.diffusion_ctx.as_ref() + && sd_ctx_params.vae_decode_only != vae_decode_only + { + sd_set_progress_callback(None, null_mut()); + free_sd_ctx(*sd_ctx); self.diffusion_ctx = None; } if self.diffusion_ctx.is_none() { @@ -1166,34 +1165,6 @@ impl CLibPath { } } -#[derive(Debug, Clone, Default)] -struct CLibPathVec(Vec); - -impl FromIterator for CLibPathVec { - fn from_iter>(iter: T) -> Self { - let inner_vec: Vec = iter.into_iter().collect(); - - CLibPathVec(inner_vec) - } -} - -impl From> for CLibPathVec { - fn from(value: Vec<&Path>) -> CLibPathVec { - value.iter() - .map(|&p| CLibPath::from(p)) - .collect() - } -} - -impl From> for CLibPathVec { - fn from(value: Vec) -> CLibPathVec { - value.iter() - .map(|p| CLibPath::from(p.as_path())) - .collect() - } -} - - impl From for CLibPath { fn from(value: PathBuf) -> Self { Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) @@ -1273,21 +1244,18 @@ fn gen_img_maybe_progress( let prompt: CLibString = CLibString::from(config.prompt.as_str()); let files = output_files(&config.output, config.batch_count); unsafe { - let init_img_path = Path::new(&config.init_img); - let mask_img_path = Path::new(&config.mask_img); - - let has_init_image = init_img_path.exists(); - let has_mask_image = mask_img_path.exists(); - + let has_init_image = config.init_img.exists(); + let has_mask_image = config.mask_img.exists(); + let is_decode_only = !has_init_image; let sd_ctx = model_config.diffusion_ctx(is_decode_only); let upscaler_ctx = model_config.upscaler_ctx(); - + let mut init_image = sd_image_t { - width: 0, - height: 0, - channel: 3, - data: std::ptr::null_mut(), + width: 0, + height: 0, + channel: 3, + data: std::ptr::null_mut(), }; let mut mask_image = sd_image_t { width: config.width as u32, @@ -1349,33 +1317,30 @@ fn gen_img_maybe_progress( style_strength: config.pm_style_strength, }; - // Declare the buffer in the function scope so it outlives the match block let mut image_buffer: Vec = Vec::new(); let mut mask_buffer: Vec = Vec::new(); if has_init_image { - let img = image::open(&init_img_path)?; - let (width, height) = (img.width(), img.height()); - image_buffer = img.to_rgb8().into_raw(); - + let img = image::open(&config.init_img)?; + image_buffer = img.to_rgb8().into_raw(); + init_image = sd_image_t { - width: width, - height: height, + width: img.width(), + height: img.height(), channel: 3, data: image_buffer.as_mut_ptr(), } } if has_mask_image { - let img = image::open(&mask_img_path)?; - let (width, height) = (img.width(), img.height()); + let img = image::open(&config.mask_img)?; // Masks have to have single channel luminosity information only - mask_buffer = img.to_luma8().into_raw(); - + mask_buffer = img.to_luma8().into_raw(); + mask_image = sd_image_t { - width: width, - height: height, + width: img.width(), + height: img.height(), channel: 1, data: mask_buffer.as_mut_ptr(), } @@ -1385,32 +1350,29 @@ fn gen_img_maybe_progress( // if a mask is not provided, we use a flat white mask meaning all of the image is in scope // otherwise generate_image throws a sigsegv when it tries to assign the mask if !image_buffer.is_empty() && mask_buffer.is_empty() { - let img: ImageBuffer, Vec> = - ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255])); + let img: ImageBuffer, Vec> = + ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255])); mask_buffer = img.into_raw(); mask_image = sd_image_t { width: init_image.width, height: init_image.height, channel: 1, - data: mask_buffer.as_mut_ptr() + data: mask_buffer.as_mut_ptr(), } } let mut ref_image_list = Vec::new(); let mut ref_pixel_storage = Vec::new(); - for ref_path_str in &config.ref_images { - let ref_img_path = Path::new(&ref_path_str); - - if ref_img_path.exists() { - let img = image::open(&ref_img_path)?; - let (width, height) = (img.width(), img.height()); - let image_data = img.to_rgb8().into_raw(); - + for ref_path in &config.ref_images { + if ref_path.exists() { + let img = image::open(ref_path)?; + let image_data = img.to_rgb8().into_raw(); + ref_pixel_storage.push(image_data); let storage_ref = ref_pixel_storage.last_mut().unwrap(); ref_image_list.push(sd_image_t { - width: width, - height: height, + width: img.width(), + height: img.height(), channel: 3, data: storage_ref.as_mut_ptr(), }); @@ -1543,8 +1505,8 @@ fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), Di #[cfg(test)] mod tests { + use image::{DynamicImage, ImageBuffer, Rgba}; use std::path::PathBuf; - use image::{Rgba, ImageBuffer, DynamicImage}; use crate::{ api::{ConfigBuilderError, ModelConfigBuilder}, @@ -1606,7 +1568,7 @@ mod tests { .model(model_path) .build() .unwrap(); - + gen_img(&config, &mut model_config).unwrap(); // 2. Create conditioning image: Gradient square @@ -1618,11 +1580,13 @@ mod tests { *pixel = Rgba([r, g, b, 255]); } let cond_path = "test_cond_image.png"; - DynamicImage::ImageRgba8(cond).save(cond_path).expect("Failed to save reference image"); + DynamicImage::ImageRgba8(cond) + .save(cond_path) + .expect("Failed to save reference image"); // 3. Call refinement using the generated image as input let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image."; - let img2img_config = ConfigBuilder::default() + let img2img_config = ConfigBuilder::default() .prompt(refine_prompt) .output(PathBuf::from("./output_img_ref.png")) .ref_images(vec![PathBuf::from(cond_path)]) @@ -1631,7 +1595,7 @@ mod tests { .build() .unwrap(); gen_img(&img2img_config, &mut model_config).unwrap(); - + // 4. Ensure decoder only mode works after img2img generation gen_img(&config, &mut model_config).unwrap(); }