Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flue-core/src/flux/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ pub fn get_noise(
width: usize,
device: &Device,
) -> Result<Tensor> {
let height = (height + 15) / 16 * 2;
let width = (width + 15) / 16 * 2;
let height = height.div_ceil(16) * 2;
let width = width.div_ceil(16) * 2;
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
}

Expand Down Expand Up @@ -86,8 +86,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f

pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
let (b, _h_w, c_ph_pw) = xs.dims3()?;
let height = (height + 15) / 16;
let width = (width + 15) / 16;
let height = height.div_ceil(16);
let width = width.div_ceil(16);
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
Expand Down
11 changes: 1 addition & 10 deletions flue-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license.workspace = true
homepage.workspace = true

[dependencies]
flue-core = { version = "0.1.0", path = "../flue-core" }
flue-core = { path = "../flue-core" }
anyhow = { workspace = true }
axum = { workspace = true }
base64 = { workspace = true }
Expand All @@ -21,12 +21,3 @@ image = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }

[features]
cuda = ["flue-core/cuda"]
cudnn = ["flue-core/cudnn"]
metal = ["flue-core/metal"]
flash-attn-v2 = ["cuda", "flue-core/flash-attn-v2"]
flash-attn-v3 = ["cuda", "flue-core/flash-attn-v3"]
accelerate = ["flue-core/accelerate"]
mkl = ["flue-core/mkl"]
15 changes: 12 additions & 3 deletions flue-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ fn image_to_base64_png(img: &DynamicImage) -> Result<String> {
#[derive(Serialize)]
struct GenerationResponse {
image: String,
gen_time: f64, // Time in seconds
}

// Application state containing the preloaded models and device settings.
Expand All @@ -57,7 +58,11 @@ async fn generate_image_handler(
Json(req): Json<GenerationRequest>,
) -> impl IntoResponse {
match generate_image(req, &state).await {
Ok(img_base64) => Json(GenerationResponse { image: img_base64 }).into_response(),
Ok((img_base64, gen_time)) => Json(GenerationResponse {
image: img_base64,
gen_time,
})
.into_response(),
Err(e) => {
eprintln!("Error generating image: {:?}", e);
(StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {:?}", e)).into_response()
Expand All @@ -66,9 +71,13 @@ async fn generate_image_handler(
}

/// This function uses the preloaded models from `state` to generate an image (base64).
async fn generate_image(params: GenerationRequest, state: &AppState) -> Result<String> {
/// Returns both the base64 image and the generation time in seconds.
async fn generate_image(params: GenerationRequest, state: &AppState) -> Result<(String, f64)> {
let start_time = std::time::Instant::now();
let image = state.0.run(params)?;
image_to_base64_png(&image)
let gen_time = start_time.elapsed().as_secs_f64();
let base64_image = image_to_base64_png(&image)?;
Ok((base64_image, gen_time))
}

#[tokio::main]
Expand Down
Loading