Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions src/ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,12 +491,16 @@ __STATIC_INLINE__ void ggml_ext_tensor_split_2d(struct ggml_tensor* input,
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
int64_t ne3 = output->ne[3];

int64_t input_width = input->ne[0];
int64_t input_height = input->ne[1];

GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
for (int l = 0; l < ne3; l++) {
float value = ggml_ext_tensor_get_f32(input, ix + x, iy + y, k, l);
float value = ggml_ext_tensor_get_f32(input, (ix + x) % input_width, (iy + y) % input_height, k, l);
ggml_ext_tensor_set_f32(output, value, ix, iy, k, l);
}
}
Expand All @@ -516,6 +520,8 @@ __STATIC_INLINE__ void ggml_ext_tensor_merge_2d(struct ggml_tensor* input,
int y,
int overlap_x,
int overlap_y,
bool circular_x,
bool circular_y,
int x_skip = 0,
int y_skip = 0) {
int64_t width = input->ne[0];
Expand All @@ -533,22 +539,22 @@ __STATIC_INLINE__ void ggml_ext_tensor_merge_2d(struct ggml_tensor* input,
for (int l = 0; l < ne3; l++) {
float new_value = ggml_ext_tensor_get_f32(input, ix, iy, k, l);
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
float old_value = ggml_ext_tensor_get_f32(output, x + ix, y + iy, k, l);
float old_value = ggml_ext_tensor_get_f32(output, (x + ix) % img_width, (y + iy) % img_height, k, l);

const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
const float x_f_0 = (circular_x || (overlap_x > 0 && x > 0)) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (circular_x || (overlap_x > 0 && x < (img_width - width))) ? (width - ix) / float(overlap_x) : 1;
const float y_f_0 = (circular_y || (overlap_y > 0 && y > 0)) ? (iy - y_skip) / float(overlap_y) : 1;
const float y_f_1 = (circular_y || (overlap_y > 0 && y < (img_height - height))) ? (height - iy) / float(overlap_y) : 1;

const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);

ggml_ext_tensor_set_f32(
output,
old_value + new_value * smootherstep_f32(y_f) * smootherstep_f32(x_f),
x + ix, y + iy, k, l);
(x + ix) % img_width, (y + iy) % img_height, k, l);
} else {
ggml_ext_tensor_set_f32(output, new_value, x + ix, y + iy, k, l);
ggml_ext_tensor_set_f32(output, new_value, (x + ix) % img_width, (y + iy) % img_height, k, l);
}
}
}
Expand Down Expand Up @@ -773,10 +779,31 @@ __STATIC_INLINE__ void sd_tiling_calc_tiles(int& num_tiles_dim,
float& tile_overlap_factor_dim,
int small_dim,
int tile_size,
const float tile_overlap_factor) {
const float tile_overlap_factor,
bool circular) {
int tile_overlap = static_cast<int>(tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap;

if (circular) {
// circular means the last and first tile are overlapping (wraping around)
num_tiles_dim = small_dim / non_tile_overlap;

if (num_tiles_dim < 1) {
num_tiles_dim = 1;
}

tile_overlap_factor_dim = (tile_size - small_dim / num_tiles_dim) / (float)tile_size;

// if single tile and tile_overlap_factor is not 0, add one to ensure we have at least two overlapping tiles
if (num_tiles_dim == 1 && tile_overlap_factor_dim > 0) {
num_tiles_dim++;
tile_overlap_factor_dim = 0.5;
}

return;
}
// else, non-circular means the last and first tile are not overlapping

num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;

Expand Down Expand Up @@ -805,6 +832,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
const int p_tile_size_x,
const int p_tile_size_y,
const float tile_overlap_factor,
const bool circular_x,
const bool circular_y,
on_tile_process on_processing) {
output = ggml_set_f32(output, 0);

Expand All @@ -829,11 +858,11 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,

int num_tiles_x;
float tile_overlap_factor_x;
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor, circular_x);

int num_tiles_y;
float tile_overlap_factor_y;
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor, circular_y);

LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
Expand Down Expand Up @@ -887,7 +916,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
float last_time = 0.0f;
for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) {
int dy = 0;
if (y + tile_size_y >= small_height) {
if (!circular_y && y + tile_size_y >= small_height) {
int _y = y;
y = small_height - tile_size_y;
dy = _y - y;
Expand All @@ -898,7 +927,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
}
for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) {
int dx = 0;
if (x + tile_size_x >= small_width) {
if (!circular_x && x + tile_size_x >= small_width) {
int _x = x;
x = small_width - tile_size_x;
dx = _x - x;
Expand All @@ -919,7 +948,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
int64_t t1 = ggml_time_ms();
ggml_ext_tensor_split_2d(input, input_tile, x_in, y_in);
if (on_processing(input_tile, output_tile, false)) {
ggml_ext_tensor_merge_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
ggml_ext_tensor_merge_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, circular_x, circular_y, dx, dy);

int64_t t2 = ggml_time_ms();
last_time = (t2 - t1) / 1000.0f;
Expand All @@ -942,8 +971,10 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input,
const int scale,
const int tile_size,
const float tile_overlap_factor,
const bool circular_x,
const bool circular_y,
on_tile_process on_processing) {
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, circular_x, circular_y, on_processing);
}

__STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context* ctx,
Expand Down
64 changes: 51 additions & 13 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ class StableDiffusionGGML {
bool external_vae_is_invalid = false;
bool free_params_immediately = false;

bool circular_x = false;
bool circular_y = false;

std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
std::shared_ptr<RNG> sampler_rng = nullptr;
int n_threads = -1;
Expand Down Expand Up @@ -749,12 +752,8 @@ class StableDiffusionGGML {
if (control_net) {
control_net->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
if (first_stage_model) {
first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
if (tae_first_stage) {
tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
circular_x = sd_ctx_params->circular_x;
circular_y = sd_ctx_params->circular_y;
}

struct ggml_init_params params;
Expand Down Expand Up @@ -1456,7 +1455,7 @@ class StableDiffusionGGML {
sd_progress_cb_t cb = sd_get_progress_callback();
void* cbd = sd_get_progress_callback_data();
sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr);
sd_tiling(input, output, scale, tile_size, tile_overlap_factor, on_processing);
sd_tiling(input, output, scale, tile_size, tile_overlap_factor, circular_x, circular_y, on_processing);
sd_set_progress_callback(cb, cbd);
}

Expand Down Expand Up @@ -2547,7 +2546,7 @@ class StableDiffusionGGML {
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return first_stage_model->compute(n_threads, in, false, &out, work_ctx);
};
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
} else {
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
}
Expand All @@ -2558,7 +2557,7 @@ class StableDiffusionGGML {
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return tae_first_stage->compute(n_threads, in, false, &out, nullptr);
};
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
} else {
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
}
Expand Down Expand Up @@ -2676,7 +2675,7 @@ class StableDiffusionGGML {
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
};
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
} else {
if (!first_stage_model->compute(n_threads, x, true, &result, work_ctx)) {
LOG_ERROR("Failed to decode latetnts");
Expand All @@ -2692,7 +2691,7 @@ class StableDiffusionGGML {
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return tae_first_stage->compute(n_threads, in, true, &out);
};
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
} else {
if (!tae_first_stage->compute(n_threads, x, true, &result)) {
LOG_ERROR("Failed to decode latetnts");
Expand Down Expand Up @@ -3495,8 +3494,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,

sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
int width = sd_img_gen_params->width;
int height = sd_img_gen_params->height;

int width = sd_img_gen_params->width;
int height = sd_img_gen_params->height;

int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
Expand All @@ -3510,6 +3510,40 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_img_gen_params->width, sd_img_gen_params->height, width, height, spatial_multiple);
}

bool circular_x = sd_ctx->sd->circular_x;
bool circular_y = sd_ctx->sd->circular_y;

if (!sd_img_gen_params->vae_tiling_params.enabled) {
if (sd_ctx->sd->first_stage_model) {
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
if (sd_ctx->sd->tae_first_stage) {
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
} else {
int tile_size_x, tile_size_y;
float _overlap;
int latent_size_x = width / sd_ctx->sd->get_vae_scale_factor();
int latent_size_y = height / sd_ctx->sd->get_vae_scale_factor();
sd_ctx->sd->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y);

// force disable circular padding for vae if tiling is enabled unless latent is smaller than tile size
// otherwise it will cause artifacts at the edges of the tiles
sd_ctx->sd->circular_x = sd_ctx->sd->circular_x && (tile_size_x >= latent_size_x);
sd_ctx->sd->circular_y = sd_ctx->sd->circular_y && (tile_size_y >= latent_size_y);

if (sd_ctx->sd->first_stage_model) {
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
if (sd_ctx->sd->tae_first_stage) {
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}

// disable circular tiling if it's enabled for the VAE
sd_ctx->sd->circular_x = circular_x && (tile_size_x < latent_size_x);
sd_ctx->sd->circular_y = circular_y && (tile_size_y < latent_size_y);
}

LOG_DEBUG("generate_image %dx%d", width, height);
if (sd_ctx == nullptr || sd_img_gen_params == nullptr) {
return nullptr;
Expand Down Expand Up @@ -3779,6 +3813,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
denoise_mask,
&sd_img_gen_params->cache);

// restore circular params
sd_ctx->sd->circular_x = circular_x;
sd_ctx->sd->circular_y = circular_y;

size_t t2 = ggml_time_ms();

LOG_INFO("generate_image completed in %.2fs", (t2 - t0) * 1.0f / 1000);
Expand Down
3 changes: 2 additions & 1 deletion src/upscaler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ struct UpscalerGGML {
return esrgan_upscaler->compute(n_threads, in, &out);
};
int64_t t0 = ggml_time_ms();
sd_tiling(input_image_tensor, upscaled, esrgan_upscaler->scale, esrgan_upscaler->tile_size, 0.25f, on_tiling);
// TODO: circular upscaling?
sd_tiling(input_image_tensor, upscaled, esrgan_upscaler->scale, esrgan_upscaler->tile_size, 0.25f, false, false, on_tiling);
esrgan_upscaler->free_compute_buffer();
ggml_ext_tensor_clamp_inplace(upscaled, 0.f, 1.f);
uint8_t* upscaled_data = ggml_tensor_to_sd_image(upscaled);
Expand Down