Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 108 additions & 64 deletions Cargo.lock

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@ accelerate-src = { version = "0.3.2" }
anyhow = "1.0.97"
axum = "0.8.3"
base64 = "0.22.1"
# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io
# candle-core = { version = "0.9.0-alpha.2" }
# candle-flash-attn = { version = "0.9.0-alpha.2" }
# candle-nn = { version = "0.9.0-alpha.2" }
# candle-transformers = { version = "0.9.0-alpha.2" }
candle-core = { git = "https://github.com/huggingface/candle", branch = "main" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", branch = "main" }
candle-nn = { git = "https://github.com/huggingface/candle", branch = "main" }
candle-transformers = { git = "https://github.com/huggingface/candle", branch = "main" }
clap = { version = "4.5.34", features = ["derive"] }
flue-core = { path = "./flue-core", version = "0.1.0" }
candle-core = { version = "0.8.4" }
flue-flash-attn-v2 = { path = "./flue-flash-attn-v2", version = "0.8.0" }
flue-flash-attn-v3 = { path = "./flue-flash-attn-v3", version = "0.8.0" }
candle-nn = { version = "0.8.4" }
candle-transformers = { version = "0.8.4" }
clap = { version = "4.5.34", features = ["derive"] }
hf-hub = { version = "0.4.2", default-features = false, features = ["ureq", "tokio", "rustls-tls"] }
image = "0.25.6"
intel-mkl-src = { version = "0.8.1" }
Expand Down
7 changes: 4 additions & 3 deletions flue-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ homepage.workspace = true
accelerate-src = { workspace = true, optional = true }
anyhow = { workspace = true }
candle-core = { workspace = true }
flue-flash-attn-v2 = { workspace = true, optional = true }
flue-flash-attn-v3 = { workspace = true, optional = true }
candle-flash-attn = { workspace = true, optional = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
flue-flash-attn-v2 = { workspace = true, optional = true }
flue-flash-attn-v3 = { workspace = true, optional = true }
hf-hub = { workspace = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
Expand All @@ -32,8 +33,8 @@ serde_plain = { workspace = true }
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle-core/cudnn"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
flash-attn = ["candle-flash-attn"]
flash-attn-v2 = ["cuda", "flue-flash-attn-v2"]
flash-attn-v3 = ["cuda", "flue-flash-attn-v3"]
accelerate = ["candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate", "dep:accelerate-src"]
mkl = ["candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl", "dep:intel-mkl-src"]

28 changes: 19 additions & 9 deletions flue-core/src/flux/model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use candle_core::{DType, IndexOp, Result, Tensor, D};
use candle_core::{DType, IndexOp, Module, Result, Tensor, D};
use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder};

// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
Expand All @@ -18,7 +18,6 @@ pub struct Config {
pub guidance_embed: bool,
}

#[allow(dead_code)]
impl Config {
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
pub fn dev() -> Self {
Expand Down Expand Up @@ -62,7 +61,11 @@ fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new_no_bias(ws, 1e-6))
}

#[cfg(any(feature = "flash-attn-v2", feature = "flash-attn-v3"))]
#[cfg(any(
feature = "flash-attn-v2",
feature = "flash-attn-v3",
feature = "flash-attn"
))]
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let dim = q.dim(D::Minus1)?;
let scale_factor = 1.0 / (dim as f64).sqrt();
Expand All @@ -74,6 +77,9 @@ fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Te
let v = v.transpose(1, 2)?;

// Use appropriate flash attention function
#[cfg(feature = "flash-attn")]
let attn = candle_flash_attn::flash_attn(&q, &k, &v, scale_factor as f32, false)?;

#[cfg(feature = "flash-attn-v2")]
let attn = flue_flash_attn_v2::flash_attn(&q, &k, &v, scale_factor as f32, false)?;

Expand All @@ -84,7 +90,11 @@ fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Te
attn.transpose(1, 2)?.reshape(batch_dims)
}

#[cfg(not(any(feature = "flash-attn-v2", feature = "flash-attn-v3")))]
#[cfg(not(any(
feature = "flash-attn-v2",
feature = "flash-attn-v3",
feature = "flash-attn"
)))]
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let dim = q.dim(D::Minus1)?;
let scale_factor = 1.0 / (dim as f64).sqrt();
Expand Down Expand Up @@ -149,11 +159,11 @@ pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result
let dev = t.device();
let half = dim / 2;
let t = (t * TIME_FACTOR)?;
let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle_core::DType::F32)?;
let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(DType::F32)?;
let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
let args = t
.unsqueeze(1)?
.to_dtype(candle_core::DType::F32)?
.to_dtype(DType::F32)?
.broadcast_mul(&freqs.unsqueeze(0)?)?;
let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
Ok(emb)
Expand All @@ -177,7 +187,7 @@ impl EmbedNd {
}
}

impl candle_core::Module for EmbedNd {
impl Module for EmbedNd {
fn forward(&self, ids: &Tensor) -> Result<Tensor> {
let n_axes = ids.dim(D::Minus1)?;
let mut emb = Vec::with_capacity(n_axes);
Expand Down Expand Up @@ -211,7 +221,7 @@ impl MlpEmbedder {
}
}

impl candle_core::Module for MlpEmbedder {
impl Module for MlpEmbedder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
}
Expand Down Expand Up @@ -370,7 +380,7 @@ impl Mlp {
}
}

impl candle_core::Module for Mlp {
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
}
Expand Down
8 changes: 6 additions & 2 deletions flue-flash-attn-v2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ readme = "README.md"
repository = "https://github.com/Apsu/flue"

[dependencies]
candle-core = { version = "0.8.4", features = ["cuda"] }
# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io
# candle-core = { version = "0.9.0-alpha.2", features = ["cuda"] }
candle-core = { git = "https://github.com/huggingface/candle", branch = "main", features = ["cuda"] }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
Expand All @@ -21,4 +23,6 @@ anyhow = { version = "1", features = ["backtrace"] }

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { version = "0.8.4", features = ["cuda"] }
# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io
# candle-nn = { version = "0.9.0-alpha.2", features = ["cuda"] }
candle-nn = { git = "https://github.com/huggingface/candle", branch = "main", features = ["cuda"] }
59 changes: 31 additions & 28 deletions flue-flash-attn-v2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ mod ffi;

use candle_core::backend::BackendStorage;
use candle_core::cuda_backend::cudarc::driver::DevicePtr;
use candle_core::cuda_backend::WrapErr;
use candle_core::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16};

Expand Down Expand Up @@ -91,6 +90,7 @@ impl FlashAttn {
candle_core::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}

let stream = dev.cuda_stream();
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle_core::bail!(
Expand All @@ -117,7 +117,8 @@ impl FlashAttn {

let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);

*alibi_slopes.device_ptr() as *const core::ffi::c_void
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
ptr as *const core::ffi::c_void
} else {
std::ptr::null()
};
Expand Down Expand Up @@ -164,17 +165,17 @@ impl FlashAttn {
}

unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
let (q_ptr, _guard) = q.device_ptr(&stream);
let (k_ptr, _guard) = k.device_ptr(&stream);
let (v_ptr, _guard) = v.device_ptr(&stream);
let (dst_ptr, _guard) = dst.device_ptr(&stream);
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
ffi::run_mha(
q_ptr,
k_ptr,
v_ptr,
dst_ptr,
softmax_lse_ptr,
q_ptr as *const core::ffi::c_void,
k_ptr as *const core::ffi::c_void,
v_ptr as *const core::ffi::c_void,
dst_ptr as *const core::ffi::c_void,
softmax_lse_ptr as *const core::ffi::c_void,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ std::ptr::null(),
/* cu_seqlens_k_ptr */ std::ptr::null(),
Expand Down Expand Up @@ -556,6 +557,7 @@ impl FlashAttnVarLen {

let batch_size = nseqlens_q - 1;

let stream = dev.cuda_stream();
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle_core::bail!(
Expand All @@ -582,7 +584,8 @@ impl FlashAttnVarLen {

let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);

*alibi_slopes.device_ptr() as *const core::ffi::c_void
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
ptr as *const core::ffi::c_void
} else {
std::ptr::null()
};
Expand Down Expand Up @@ -629,22 +632,22 @@ impl FlashAttnVarLen {
}

unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
let (q_ptr, _guard) = q.device_ptr(&stream);
let (k_ptr, _guard) = k.device_ptr(&stream);
let (v_ptr, _guard) = v.device_ptr(&stream);
let (dst_ptr, _guard) = dst.device_ptr(&stream);
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
ffi::run_mha(
q_ptr,
k_ptr,
v_ptr,
dst_ptr,
softmax_lse_ptr,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
q_ptr as *const core::ffi::c_void,
k_ptr as *const core::ffi::c_void,
v_ptr as *const core::ffi::c_void,
dst_ptr as *const core::ffi::c_void,
softmax_lse_ptr as *const core::ffi::c_void,
/* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
/* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
/* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
/* q_batch_stride */ 0,
/* k_batch_stride */ 0,
/* v_batch_stride */ 0,
Expand Down
10 changes: 7 additions & 3 deletions flue-flash-attn-v3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ readme = "README.md"
repository = "https://github.com/Apsu/flue"

[dependencies]
candle-core = { version = "0.8.4", features = ["cuda"] }
# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io
# candle-core = { version = "0.9.0-alpha.2", features = ["cuda"] }
candle-core = { git = "https://github.com/huggingface/candle", branch = "main", features = ["cuda"] }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
Expand All @@ -21,5 +23,7 @@ rayon = "1.7.0"

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { version = "0.8.4", features = ["cuda"] }
rstest = "0.23"
# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io
# candle-nn = { version = "0.9.0-alpha.2", features = ["cuda"] }
candle-nn = { git = "https://github.com/huggingface/candle", branch = "main", features = ["cuda"] }
rstest = "0.23"
Loading
Loading