Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 152 additions & 7 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,16 @@ impl ModelConfig {

unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
unsafe {
// This is required to support img2img after text2img generation (and vice versa).
// Otherwise the cached context, created with a different `vae_decode_only` value,
// is reused without the graph the new mode needs, leading to an assertion error in sdcpp.
if let Some((sd_ctx, sd_ctx_params)) = self.diffusion_ctx.as_ref()
&& sd_ctx_params.vae_decode_only != vae_decode_only
{
sd_set_progress_callback(None, null_mut());
free_sd_ctx(*sd_ctx);
self.diffusion_ctx = None;
}
if self.diffusion_ctx.is_none() {
let sd_ctx_params = sd_ctx_params_t {
model_path: self.model.as_ptr(),
Expand Down Expand Up @@ -850,12 +860,20 @@ pub struct Config {

/// Path to the input image, required by img2img
#[builder(default = "Default::default()")]
init_img: CLibPath,
init_img: PathBuf,

/// Path to the image used as a mask for img2img
#[builder(default = "Default::default()")]
mask_img: PathBuf,

/// Path to image condition, control net
#[builder(default = "Default::default()")]
control_image: CLibPath,

/// Paths to reference images for in-context conditioning (e.g. for Flux2)
#[builder(default = "Default::default()")]
ref_images: Vec<PathBuf>,

/// Path to write result image to (default: ./output.png)
#[builder(default = "PathBuf::from(\"./output.png\")")]
output: PathBuf,
Expand Down Expand Up @@ -1084,7 +1102,9 @@ impl From<Config> for ConfigBuilder {
builder
.pm_id_images_dir(value.pm_id_images_dir)
.init_img(value.init_img)
.mask_img(value.mask_img)
.control_image(value.control_image)
.ref_images(value.ref_images)
.output(value.output)
.prompt(value.prompt)
.negative_prompt(value.negative_prompt)
Expand Down Expand Up @@ -1224,15 +1244,20 @@ fn gen_img_maybe_progress(
let prompt: CLibString = CLibString::from(config.prompt.as_str());
let files = output_files(&config.output, config.batch_count);
unsafe {
let sd_ctx = model_config.diffusion_ctx(true);
let has_init_image = config.init_img.exists();
let has_mask_image = config.mask_img.exists();

let is_decode_only = !has_init_image;
let sd_ctx = model_config.diffusion_ctx(is_decode_only);
let upscaler_ctx = model_config.upscaler_ctx();
let init_image = sd_image_t {

let mut init_image = sd_image_t {
width: 0,
height: 0,
channel: 3,
data: null_mut(),
data: std::ptr::null_mut(),
};
let mask_image = sd_image_t {
let mut mask_image = sd_image_t {
width: config.width as u32,
height: config.height as u32,
channel: 1,
Expand Down Expand Up @@ -1292,6 +1317,75 @@ fn gen_img_maybe_progress(
style_strength: config.pm_style_strength,
};

// Declare the buffers in the function scope so they outlive the `if` blocks below
// (the `sd_image_t` structs hold raw pointers into them)
let mut image_buffer: Vec<u8> = Vec::new();
let mut mask_buffer: Vec<u8> = Vec::new();

if has_init_image {
let img = image::open(&config.init_img)?;
image_buffer = img.to_rgb8().into_raw();

init_image = sd_image_t {
width: img.width(),
height: img.height(),
channel: 3,
data: image_buffer.as_mut_ptr(),
}
}

if has_mask_image {
let img = image::open(&config.mask_img)?;
// Masks must carry single-channel luminance information only
mask_buffer = img.to_luma8().into_raw();

mask_image = sd_image_t {
width: img.width(),
height: img.height(),
channel: 1,
data: mask_buffer.as_mut_ptr(),
}
}

// stable-diffusion.cpp assumes that a mask is also included with the img2img flow.
// If a mask is not provided, we use a flat white mask, meaning all of the image is in scope;
// otherwise generate_image throws a SIGSEGV when it tries to assign the mask.
if !image_buffer.is_empty() && mask_buffer.is_empty() {
let img: ImageBuffer<image::Luma<u8>, Vec<u8>> =
ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255]));
mask_buffer = img.into_raw();
mask_image = sd_image_t {
width: init_image.width,
height: init_image.height,
channel: 1,
data: mask_buffer.as_mut_ptr(),
}
}

let mut ref_image_list = Vec::new();
let mut ref_pixel_storage = Vec::new();
for ref_path in &config.ref_images {
if ref_path.exists() {
let img = image::open(ref_path)?;
let image_data = img.to_rgb8().into_raw();

ref_pixel_storage.push(image_data);
let storage_ref = ref_pixel_storage.last_mut().unwrap();
ref_image_list.push(sd_image_t {
width: img.width(),
height: img.height(),
channel: 3,
data: storage_ref.as_mut_ptr(),
});
}
}

let num_ref_images = ref_image_list.len();
let ref_image_ptr = if num_ref_images > 0 {
ref_image_list.as_mut_ptr()
} else {
null_mut()
};

unsafe extern "C" fn save_preview_local(
_step: ::std::os::raw::c_int,
_frame_count: ::std::os::raw::c_int,
Expand Down Expand Up @@ -1342,8 +1436,8 @@ fn gen_img_maybe_progress(
negative_prompt: config.negative_prompt.as_ptr(),
clip_skip: config.clip_skip as i32,
init_image,
ref_images: null_mut(),
ref_images_count: 0,
ref_images: ref_image_ptr,
ref_images_count: num_ref_images as i32,
increase_ref_index: false,
mask_image,
width: config.width,
Expand Down Expand Up @@ -1411,6 +1505,7 @@ fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), Di

#[cfg(test)]
mod tests {
use image::{DynamicImage, ImageBuffer, Rgba};
use std::path::PathBuf;

use crate::{
Expand Down Expand Up @@ -1455,6 +1550,56 @@ mod tests {
.unwrap();
}

/// Runs three generations against one shared model configuration:
/// a prompt-only run, an img2img run that feeds the first output back in as the
/// init image together with a synthetic reference image, and finally the
/// prompt-only run again.
///
/// Marked `#[ignore]`; it fetches a checkpoint through `download_file_hf_hub`
/// and writes image files into the current working directory.
#[ignore]
#[test]
fn test_img2img_gen() {
    let checkpoint =
        download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
            .unwrap();
    let mut model_config = ModelConfigBuilder::default()
        .model(checkpoint)
        .build()
        .unwrap();

    // 1. Plain prompt-only generation; its output becomes the init image of step 3.
    let first_pass_output = "./output_img.png";
    let txt2img_config = ConfigBuilder::default()
        .prompt("A high quality 3d texture")
        .output(PathBuf::from(first_pass_output))
        .batch_count(1)
        .build()
        .unwrap();
    gen_img(&txt2img_config, &mut model_config).unwrap();

    // 2. Build the conditioning image: a 512x512 gradient where red follows x,
    //    green follows y, and blue/alpha are constant.
    let ramp = |coord: u32| (coord as f32 / 512.0 * 255.0) as u8;
    let gradient = ImageBuffer::from_fn(512, 512, |x, y| Rgba([ramp(x), ramp(y), 127, 255]));
    let cond_path = "test_cond_image.png";
    DynamicImage::ImageRgba8(gradient)
        .save(cond_path)
        .expect("Failed to save reference image");

    // 3. Refinement pass: the image generated in step 1 is the init image and the
    //    gradient is passed as a reference image.
    let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image.";
    let refine_config = ConfigBuilder::default()
        .prompt(refine_prompt)
        .output(PathBuf::from("./output_img_ref.png"))
        .init_img(PathBuf::from(first_pass_output))
        .ref_images(vec![PathBuf::from(cond_path)])
        .batch_count(1)
        .build()
        .unwrap();
    gen_img(&refine_config, &mut model_config).unwrap();

    // 4. Re-run the prompt-only config on the same model config to make sure
    //    generation without an init image still works after img2img.
    gen_img(&txt2img_config, &mut model_config).unwrap();
}

#[ignore]
#[test]
fn test_img_gen() {
Expand Down