Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 25 additions & 50 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,7 @@ pub enum ClipSkip {
}

type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);

#[derive(Default, Debug, Clone)]
struct LoraStorage {
lora_model_dir: CLibPath,
data: Vec<(CLibPath, String, f32)>,
loras_t: Vec<sd_lora_t>,
}
type LoraStorage = Vec<(CLibPath, LoraSpec)>;

/// Specify the instructions for a Lora model
#[derive(Default, Debug, Clone)]
Expand Down Expand Up @@ -573,7 +567,7 @@ impl ModelConfigBuilder {
))
}

fn filter_valid_extensions(&self, path: &Path) -> impl Iterator<Item = DirEntry> {
fn filter_valid_extensions(path: &Path) -> impl Iterator<Item = DirEntry> {
WalkDir::new(path)
.into_iter()
.filter_map(|entry| entry.ok())
Expand All @@ -586,26 +580,17 @@ impl ModelConfigBuilder {
.unwrap_or(false)
})
}

fn build_single_lora_storage(
spec: &LoraSpec,
is_high_noise: bool,
valid_loras: &HashMap<String, PathBuf>,
) -> ((CLibPath, String, f32), sd_lora_t) {
) -> (CLibPath, LoraSpec) {
let path = valid_loras.get(&spec.file_name).unwrap().as_path();
let c_path = CLibPath::from(path);
let lora = sd_lora_t {
is_high_noise,
multiplier: spec.multiplier,
path: c_path.as_ptr(),
};
let data = (c_path, spec.file_name.clone(), spec.multiplier);
(data, lora)
(c_path, spec.clone())
}

pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
let data: Vec<(CLibString, CLibPath)> = self
.filter_valid_extensions(embeddings_dir)
let data: Vec<(CLibString, CLibPath)> = Self::filter_valid_extensions(embeddings_dir)
.map(|entry| {
let file_stem = entry
.path()
Expand All @@ -628,8 +613,7 @@ impl ModelConfigBuilder {
}

pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
let valid_loras: HashMap<String, PathBuf> = self
.filter_valid_extensions(lora_model_dir)
let valid_loras: HashMap<String, PathBuf> = Self::filter_valid_extensions(lora_model_dir)
.map(|entry| {
let path = entry.path();
(
Expand All @@ -645,24 +629,17 @@ impl ModelConfigBuilder {
let standard = specs
.iter()
.filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
.map(|s| Self::build_single_lora_storage(s, false, &valid_loras));
.map(|s| Self::build_single_lora_storage(s, &valid_loras));
let high_noise = specs
.iter()
.filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
.map(|s| Self::build_single_lora_storage(s, true, &valid_loras));
.map(|s| Self::build_single_lora_storage(s, &valid_loras));

let mut data = Vec::new();
let mut loras_t = Vec::new();
for lora in standard.chain(high_noise) {
data.push(lora.0);
loras_t.push(lora.1);
}
self.lora_models_internal(standard.chain(high_noise).collect())
}

self.lora_models = Some(LoraStorage {
lora_model_dir: lora_model_dir.into(),
data,
loras_t,
});
fn lora_models_internal(&mut self, lora_storage: LoraStorage) -> &mut Self {
self.lora_models = Some(lora_storage);
self
}

Expand Down Expand Up @@ -829,19 +806,7 @@ impl From<ModelConfig> for ModelConfigBuilder {
.circular_y(value.circular_y)
.use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);

let lora_model_dir = Into::<PathBuf>::into(&value.lora_models.lora_model_dir);
let lora_specs = value
.lora_models
.data
.iter()
.map(|(_, name, multiplier)| LoraSpec {
file_name: name.clone(),
is_high_noise: false,
multiplier: *multiplier,
})
.collect();

builder.lora_models(&lora_model_dir, lora_specs);
builder.lora_models_internal(value.lora_models.clone());

if let Some(model) = &value.upscale_model {
builder.upscale_model(model.clone());
Expand Down Expand Up @@ -1431,6 +1396,16 @@ fn gen_img_maybe_progress(
sd_set_progress_callback(Some(progress_callback), sender_ptr);
}

let loras: Vec<sd_lora_t> = model_config
.lora_models
.iter()
.map(|(c_path, spec)| sd_lora_t {
is_high_noise: spec.is_high_noise,
multiplier: spec.multiplier,
path: c_path.as_ptr(),
})
.collect();

let sd_img_gen_params = sd_img_gen_params_t {
prompt: prompt.as_ptr(),
negative_prompt: config.negative_prompt.as_ptr(),
Expand All @@ -1452,8 +1427,8 @@ fn gen_img_maybe_progress(
vae_tiling_params,
auto_resize_ref_image: config.disable_auto_resize_ref_image,
cache: config.cache,
loras: model_config.lora_models.loras_t.as_ptr(),
lora_count: model_config.lora_models.loras_t.len() as u32,
loras: loras.as_ptr(),
lora_count: loras.len() as u32,
};

let params_str = CString::from_raw(sd_img_gen_params_to_str(&sd_img_gen_params))
Expand Down
Loading