diff --git a/flue-core/src/flux/sampling.rs b/flue-core/src/flux/sampling.rs index afa0008..540ca57 100644 --- a/flue-core/src/flux/sampling.rs +++ b/flue-core/src/flux/sampling.rs @@ -8,8 +8,8 @@ pub fn get_noise( width: usize, device: &Device, ) -> Result { - let height = (height + 15) / 16 * 2; - let width = (width + 15) / 16 * 2; + let height = height.div_ceil(16) * 2; + let width = width.div_ceil(16) * 2; Tensor::randn(0f32, 1., (num_samples, 16, height, width), device) } @@ -86,8 +86,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec Result { let (b, _h_w, c_ph_pw) = xs.dims3()?; - let height = (height + 15) / 16; - let width = (width + 15) / 16; + let height = height.div_ceil(16); + let width = width.div_ceil(16); xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw) .permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw) .reshape((b, c_ph_pw / 4, height * 2, width * 2)) diff --git a/flue-server/Cargo.toml b/flue-server/Cargo.toml index 306f242..b0cede0 100644 --- a/flue-server/Cargo.toml +++ b/flue-server/Cargo.toml @@ -11,7 +11,7 @@ license.workspace = true homepage.workspace = true [dependencies] -flue-core = { version = "0.1.0", path = "../flue-core" } +flue-core = { path = "../flue-core" } anyhow = { workspace = true } axum = { workspace = true } base64 = { workspace = true } @@ -21,12 +21,3 @@ image = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true } - -[features] -cuda = ["flue-core/cuda"] -cudnn = ["flue-core/cudnn"] -metal = ["flue-core/metal"] -flash-attn-v2 = ["cuda", "flue-core/flash-attn-v2"] -flash-attn-v3 = ["cuda", "flue-core/flash-attn-v3"] -accelerate = ["flue-core/accelerate"] -mkl = ["flue-core/mkl"] diff --git a/flue-server/src/main.rs b/flue-server/src/main.rs index 8d70e05..ad19b99 100644 --- a/flue-server/src/main.rs +++ b/flue-server/src/main.rs @@ -46,6 +46,7 @@ fn image_to_base64_png(img: &DynamicImage) -> Result { #[derive(Serialize)] struct GenerationResponse { image: String, + gen_time: f64, // Time in seconds } // Application state containing the preloaded models and device settings. @@ -57,7 +58,11 @@ async fn generate_image_handler( Json(req): Json, ) -> impl IntoResponse { match generate_image(req, &state).await { - Ok(img_base64) => Json(GenerationResponse { image: img_base64 }).into_response(), + Ok((img_base64, gen_time)) => Json(GenerationResponse { + image: img_base64, + gen_time, + }) + .into_response(), Err(e) => { eprintln!("Error generating image: {:?}", e); (StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {:?}", e)).into_response() @@ -66,9 +71,13 @@ async fn generate_image_handler( } /// This function uses the preloaded models from `state` to generate an image (base64). -async fn generate_image(params: GenerationRequest, state: &AppState) -> Result { +/// Returns both the base64 image and the generation time in seconds. +async fn generate_image(params: GenerationRequest, state: &AppState) -> Result<(String, f64)> { + let start_time = std::time::Instant::now(); let image = state.0.run(params)?; - image_to_base64_png(&image) + let gen_time = start_time.elapsed().as_secs_f64(); + let base64_image = image_to_base64_png(&image)?; + Ok((base64_image, gen_time)) } #[tokio::main]