diff --git a/Cargo.lock b/Cargo.lock index b0fedc2..ab2bad0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -376,9 +376,8 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "candle-core" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1" +version = "0.9.0-alpha.2" +source = "git+https://github.com/huggingface/candle?branch=main#d7b7ce16e47072c5561debbcd8c8cef07bbfbc86" dependencies = [ "accelerate-src", "byteorder", @@ -405,20 +404,29 @@ dependencies = [ "zip", ] +[[package]] +name = "candle-flash-attn" +version = "0.9.0-alpha.2" +source = "git+https://github.com/huggingface/candle?branch=main#d7b7ce16e47072c5561debbcd8c8cef07bbfbc86" +dependencies = [ + "anyhow", + "bindgen_cuda", + "candle-core", + "half", +] + [[package]] name = "candle-kernels" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a10885bd902fad1b8518ba2b22369aaed88a3d94e123533ad3ca73db33b1c8ca" +version = "0.9.0-alpha.2" +source = "git+https://github.com/huggingface/candle?branch=main#d7b7ce16e47072c5561debbcd8c8cef07bbfbc86" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-metal-kernels" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c85c21827c28db94e7112e364abe7e0cf8d2b022c014edf08642be6b94f21e" +version = "0.9.0-alpha.2" +source = "git+https://github.com/huggingface/candle?branch=main#d7b7ce16e47072c5561debbcd8c8cef07bbfbc86" dependencies = [ "metal 0.27.0", "once_cell", @@ -428,9 +436,8 @@ dependencies = [ [[package]] name = "candle-nn" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be1160c3b63f47d40d91110a3e1e1e566ae38edddbbf492a60b40ffc3bc1ff38" +version = "0.9.0-alpha.2" +source = "git+https://github.com/huggingface/candle?branch=main#d7b7ce16e47072c5561debbcd8c8cef07bbfbc86" dependencies = [ "accelerate-src", "candle-core", @@ -447,9 +454,8 @@ dependencies = [ [[package]] name = "candle-transformers" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94a0900d49f8605e0e7e6693a1f560e6271279de98e5fa369e7abf3aac245020" +version = "0.9.0-alpha.2" +source = "git+https://github.com/huggingface/candle?branch=main#d7b7ce16e47072c5561debbcd8c8cef07bbfbc86" dependencies = [ "accelerate-src", "byteorder", @@ -468,9 +474,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.17" +version = "1.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a" +checksum = "8e3a13707ac958681c13b39b458c073d0d9bc8a22cb1b2f4c8e55eb72c13f362" dependencies = [ "jobserver", "libc", @@ -515,9 +521,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.34" +version = "4.5.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e958897981290da2a852763fe9cdb89cd36977a5d729023127095fa94d95e2ff" +checksum = "2df961d8c8a0d08aa9945718ccf584145eee3f3aa06cddbeac12933781102e04" dependencies = [ "clap_builder", "clap_derive", @@ -525,9 +531,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.34" +version = "4.5.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83b0f35019843db2160b5bb19ae09b4e6411ac33fc6a712003c33e03090e2489" +checksum = "132dbda40fb6753878316a489d5a1242a8ef2f0d9e47ba01c951ea8aa7d013a5" dependencies = [ "anstream", "anstyle", @@ -666,9 +672,9 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.13.9" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "486c221362668c63a1636cfa51463b09574433b39029326cff40864b3ba12b6e" +checksum = "547fbfa56792e073e3368051016a6f3619462be4715e67855d4d07eec4176570" dependencies = [ "half", "libloading", @@ -862,9 +868,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" dependencies = [ "libc", "windows-sys 0.59.0", @@ -928,9 +934,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" dependencies = [ "crc32fast", "miniz_oxide", @@ -943,6 +949,7 @@ dependencies = [ "accelerate-src", "anyhow", "candle-core", + "candle-flash-attn", "candle-nn", "candle-transformers", "flue-flash-attn-v2", @@ -1455,9 +1462,9 @@ dependencies = [ [[package]] name = "half" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "bytemuck", "cfg-if", @@ -1595,9 +1602,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" dependencies = [ "bytes", "futures-channel", @@ -1605,6 +1612,7 @@ dependencies = [ "http", "http-body", "hyper", + "libc", "pin-project-lite", "socket2", "tokio", @@ -1614,9 +1622,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.62" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2fd658b06e56721792c5df4475705b6cda790e9298d19d2f8af083457bcd127" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1822,9 +1830,9 @@ checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] name = "indexmap" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", "hashbrown", @@ -1923,10 +1931,11 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.2", "libc", ] @@ -2003,9 +2012,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" @@ -2129,9 +2138,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.5" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", "simd-adler32", @@ -2629,7 +2638,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2824,9 +2833,9 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "redox_syscall" -version = "0.5.10" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1" +checksum = "d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" dependencies = [ "bitflags 2.9.0", ] @@ -2953,9 +2962,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustix" -version = "1.0.3" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e56a18552996ac8d29ecc3b190b4fdbb2d91ca4ec396de7bbffaf43f3d637e96" +checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf" dependencies = [ "bitflags 2.9.0", "errno", @@ -2966,9 +2975,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.25" +version = "0.23.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c" +checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" dependencies = [ "log", "once_cell", @@ -3170,9 +3179,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "socket2" @@ -3465,9 +3474,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.44.1" +version = "1.44.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a" +checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" dependencies = [ "backtrace", "bytes", @@ -3621,9 +3630,9 @@ checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "ug" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03719c61a91b51541f076dfdba45caacf750b230cefaa4b32d6f5411c3f7f437" +checksum = "d4bf09b7bd6c2b9a516a1918ebb7605705b5d1f852d0b4932b41164cae9437d8" dependencies = [ "gemm 0.18.2", "half", @@ -3642,9 +3651,9 @@ dependencies = [ [[package]] name = "ug-cuda" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50758486d7941f8b0a636ba7e29455c07071f41590beac1fd307ec893e8db69a" +checksum = "a68dfb95c5051313c3ff6e97e88b4b6f13c28c743dcf9a1a2048bb0ae344b753" dependencies = [ "cudarc", "half", @@ -3655,9 +3664,9 @@ dependencies = [ [[package]] name = "ug-metal" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a02ddc17bf32f7dcaaf016b6735f7198082b82f122df7b3ca15d8ead5911ccef" +checksum = "d1780751807948eecd462fa822dcde6f4e22e22bfbe086ade2c8c39aad756e7f" dependencies = [ "half", "metal 0.29.0", @@ -3972,11 +3981,37 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-core" -version = "0.52.0" +version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" dependencies = [ - "windows-targets 0.52.6", + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings 0.4.0", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -3992,7 +4027,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ "windows-result", - "windows-strings", + "windows-strings 0.3.1", "windows-targets 0.53.0", ] @@ -4014,6 +4049,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-strings" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -4228,9 +4272,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "winnow" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36" +checksum = "63d3fcd9bba44b03821e7d699eeee959f3126dcc4aa8e4ae18ec617c2a5cea10" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 9b6bb53..dc8bc17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,13 +24,19 @@ accelerate-src = { version = "0.3.2" } anyhow = "1.0.97" axum = "0.8.3" base64 = "0.22.1" +# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io +# candle-core = { version = "0.9.0-alpha.2" } +# candle-flash-attn = { version = "0.9.0-alpha.2" } +# candle-nn = { version = "0.9.0-alpha.2" } +# candle-transformers = { version = "0.9.0-alpha.2" } +candle-core = { git = "https://github.com/huggingface/candle", branch = "main" } +candle-flash-attn = { git = "https://github.com/huggingface/candle", branch = "main" } +candle-nn = { git = "https://github.com/huggingface/candle", branch = "main" } +candle-transformers = { git = "https://github.com/huggingface/candle", branch = "main" } +clap = { version = "4.5.34", features = ["derive"] } flue-core = { path = "./flue-core", version = "0.1.0" } -candle-core = { version = "0.8.4" } flue-flash-attn-v2 = { path = "./flue-flash-attn-v2", version = "0.8.0" } flue-flash-attn-v3 = { path = "./flue-flash-attn-v3", version = "0.8.0" } -candle-nn = { version = "0.8.4" } -candle-transformers = { version = "0.8.4" } -clap = { version = "4.5.34", features = ["derive"] } hf-hub = { version = "0.4.2", default-features = false, features = ["ureq", "tokio", "rustls-tls"] } image = "0.25.6" intel-mkl-src = { version = "0.8.1" } diff --git a/flue-core/Cargo.toml b/flue-core/Cargo.toml index d8741b5..d2d12da 100644 --- a/flue-core/Cargo.toml +++ b/flue-core/Cargo.toml @@ -14,10 +14,11 @@ homepage.workspace = true accelerate-src = { workspace = true, optional = true } anyhow = { workspace = true } candle-core = { workspace = true } -flue-flash-attn-v2 = { workspace = true, optional = true } -flue-flash-attn-v3 = { workspace = true, optional = true } +candle-flash-attn = { workspace = true, optional = true } candle-nn = { workspace = true } candle-transformers = { workspace = true } +flue-flash-attn-v2 = { workspace = true, optional = true } +flue-flash-attn-v3 = { workspace = true, optional = true } hf-hub = { workspace = true } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } @@ -32,8 +33,8 @@ serde_plain = { workspace = true } cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cudnn = ["candle-core/cudnn"] metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"] +flash-attn = ["candle-flash-attn"] flash-attn-v2 = ["cuda", "flue-flash-attn-v2"] flash-attn-v3 = ["cuda", "flue-flash-attn-v3"] accelerate = ["candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate", "dep:accelerate-src"] mkl = ["candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl", "dep:intel-mkl-src"] - diff --git a/flue-core/src/flux/model.rs b/flue-core/src/flux/model.rs index 29cc422..fbd59bd 100644 --- a/flue-core/src/flux/model.rs +++ b/flue-core/src/flux/model.rs @@ -1,4 +1,4 @@ -use candle_core::{DType, IndexOp, Result, Tensor, D}; +use candle_core::{DType, IndexOp, Module, Result, Tensor, D}; use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder}; // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12 @@ -18,7 +18,6 @@ pub struct Config { pub guidance_embed: bool, } -#[allow(dead_code)] impl Config { // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32 pub fn dev() -> Self { @@ -62,7 +61,11 @@ fn layer_norm(dim: usize, vb: VarBuilder) -> Result { Ok(LayerNorm::new_no_bias(ws, 1e-6)) } -#[cfg(any(feature = "flash-attn-v2", feature = "flash-attn-v3"))] +#[cfg(any( + feature = "flash-attn-v2", + feature = "flash-attn-v3", + feature = "flash-attn" +))] fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { let dim = q.dim(D::Minus1)?; let scale_factor = 1.0 / (dim as f64).sqrt(); @@ -74,6 +77,9 @@ fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result Result Result { let dim = q.dim(D::Minus1)?; let scale_factor = 1.0 / (dim as f64).sqrt(); @@ -149,11 +159,11 @@ pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result let dev = t.device(); let half = dim / 2; let t = (t * TIME_FACTOR)?; - let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle_core::DType::F32)?; + let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(DType::F32)?; let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?; let args = t .unsqueeze(1)? - .to_dtype(candle_core::DType::F32)? + .to_dtype(DType::F32)? .broadcast_mul(&freqs.unsqueeze(0)?)?; let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?; Ok(emb) @@ -177,7 +187,7 @@ impl EmbedNd { } } -impl candle_core::Module for EmbedNd { +impl Module for EmbedNd { fn forward(&self, ids: &Tensor) -> Result { let n_axes = ids.dim(D::Minus1)?; let mut emb = Vec::with_capacity(n_axes); @@ -211,7 +221,7 @@ impl MlpEmbedder { } } -impl candle_core::Module for MlpEmbedder { +impl Module for MlpEmbedder { fn forward(&self, xs: &Tensor) -> Result { xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer) } @@ -370,7 +380,7 @@ impl Mlp { } } -impl candle_core::Module for Mlp { +impl Module for Mlp { fn forward(&self, xs: &Tensor) -> Result { xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2) } diff --git a/flue-flash-attn-v2/Cargo.toml b/flue-flash-attn-v2/Cargo.toml index d8bf93a..279fbc5 100644 --- a/flue-flash-attn-v2/Cargo.toml +++ b/flue-flash-attn-v2/Cargo.toml @@ -11,7 +11,9 @@ readme = "README.md" repository = "https://github.com/Apsu/flue" [dependencies] -candle-core = { version = "0.8.4", features = ["cuda"] } +# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io +# candle-core = { version = "0.9.0-alpha.2", features = ["cuda"] } +candle-core = { git = "https://github.com/huggingface/candle", branch = "main", features = ["cuda"] } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] @@ -21,4 +23,6 @@ anyhow = { version = "1", features = ["backtrace"] } [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } -candle-nn = { version = "0.8.4", features = ["cuda"] } +# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io +# candle-nn = { version = "0.9.0-alpha.2", features = ["cuda"] } +candle-nn = { git = "https://github.com/huggingface/candle", branch = "main", features = ["cuda"] } diff --git a/flue-flash-attn-v2/src/lib.rs b/flue-flash-attn-v2/src/lib.rs index f4f213b..f05d2d9 100644 --- a/flue-flash-attn-v2/src/lib.rs +++ b/flue-flash-attn-v2/src/lib.rs @@ -2,7 +2,6 @@ mod ffi; use candle_core::backend::BackendStorage; use candle_core::cuda_backend::cudarc::driver::DevicePtr; -use candle_core::cuda_backend::WrapErr; use candle_core::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; @@ -91,6 +90,7 @@ impl FlashAttn { candle_core::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") } + let stream = dev.cuda_stream(); let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { if alibi_slopes.dtype() != DType::F32 { candle_core::bail!( @@ -117,7 +117,8 @@ impl FlashAttn { let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); - *alibi_slopes.device_ptr() as *const core::ffi::c_void + let (ptr, _guard) = alibi_slopes.device_ptr(&stream); + ptr as *const core::ffi::c_void } else { std::ptr::null() }; @@ -164,17 +165,17 @@ impl FlashAttn { } unsafe { - let q_ptr = *q.device_ptr() as *const core::ffi::c_void; - let k_ptr = *k.device_ptr() as *const core::ffi::c_void; - let v_ptr = *v.device_ptr() as *const core::ffi::c_void; - let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; - let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void; + let (q_ptr, _guard) = q.device_ptr(&stream); + let (k_ptr, _guard) = k.device_ptr(&stream); + let (v_ptr, _guard) = v.device_ptr(&stream); + let (dst_ptr, _guard) = dst.device_ptr(&stream); + let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream); ffi::run_mha( - q_ptr, - k_ptr, - v_ptr, - dst_ptr, - softmax_lse_ptr, + q_ptr as *const core::ffi::c_void, + k_ptr as *const core::ffi::c_void, + v_ptr as *const core::ffi::c_void, + dst_ptr as *const core::ffi::c_void, + softmax_lse_ptr as *const core::ffi::c_void, /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ std::ptr::null(), /* cu_seqlens_k_ptr */ std::ptr::null(), @@ -556,6 +557,7 @@ impl FlashAttnVarLen { let batch_size = nseqlens_q - 1; + let stream = dev.cuda_stream(); let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { if alibi_slopes.dtype() != DType::F32 { candle_core::bail!( @@ -582,7 +584,8 @@ impl FlashAttnVarLen { let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); - *alibi_slopes.device_ptr() as *const core::ffi::c_void + let (ptr, _guard) = alibi_slopes.device_ptr(&stream); + ptr as *const core::ffi::c_void } else { std::ptr::null() }; @@ -629,22 +632,22 @@ impl FlashAttnVarLen { } unsafe { - let q_ptr = *q.device_ptr() as *const core::ffi::c_void; - let k_ptr = *k.device_ptr() as *const core::ffi::c_void; - let v_ptr = *v.device_ptr() as *const core::ffi::c_void; - let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; - let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void; - let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int; - let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int; + let (q_ptr, _guard) = q.device_ptr(&stream); + let (k_ptr, _guard) = k.device_ptr(&stream); + let (v_ptr, _guard) = v.device_ptr(&stream); + let (dst_ptr, _guard) = dst.device_ptr(&stream); + let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream); + let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream); + let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream); ffi::run_mha( - q_ptr, - k_ptr, - v_ptr, - dst_ptr, - softmax_lse_ptr, - /* alibi_slopes_ptr */ alibi_slopes_ptr, - /* cu_seqlens_q_ptr */ seqlens_q_ptr, - /* cu_seqlens_k_ptr */ seqlens_k_ptr, + q_ptr as *const core::ffi::c_void, + k_ptr as *const core::ffi::c_void, + v_ptr as *const core::ffi::c_void, + dst_ptr as *const core::ffi::c_void, + softmax_lse_ptr as *const core::ffi::c_void, + /* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void, + /* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32, + /* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32, /* q_batch_stride */ 0, /* k_batch_stride */ 0, /* v_batch_stride */ 0, diff --git a/flue-flash-attn-v3/Cargo.toml b/flue-flash-attn-v3/Cargo.toml index dcc6b5a..1bce457 100644 --- a/flue-flash-attn-v3/Cargo.toml +++ b/flue-flash-attn-v3/Cargo.toml @@ -11,7 +11,9 @@ readme = "README.md" repository = "https://github.com/Apsu/flue" [dependencies] -candle-core = { version = "0.8.4", features = ["cuda"] } +# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io +# candle-core = { version = "0.9.0-alpha.2", features = ["cuda"] } +candle-core = { git = "https://github.com/huggingface/candle", branch = "main", features = ["cuda"] } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] @@ -21,5 +23,7 @@ rayon = "1.7.0" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } -candle-nn = { version = "0.8.4", features = ["cuda"] } -rstest = "0.23" \ No newline at end of file +# TODO: Uncomment when candle 0.9.0-alpha.2 is pushed to crates.io +# candle-nn = { version = "0.9.0-alpha.2", features = ["cuda"] } +candle-nn = { git = "https://github.com/huggingface/candle", branch = "main", features = ["cuda"] } +rstest = "0.23" diff --git a/flue-flash-attn-v3/src/lib.rs b/flue-flash-attn-v3/src/lib.rs index 6bc42c0..6508b66 100644 --- a/flue-flash-attn-v3/src/lib.rs +++ b/flue-flash-attn-v3/src/lib.rs @@ -2,7 +2,6 @@ mod ffi; use candle_core::backend::BackendStorage; use candle_core::cuda_backend::cudarc::driver::DevicePtr; -use candle_core::cuda_backend::WrapErr; use candle_core::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; @@ -98,6 +97,7 @@ impl FlashAttn { _ => 0, }; + let stream = dev.cuda_stream(); let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { if alibi_slopes.dtype() != DType::F32 { candle_core::bail!( @@ -121,10 +121,10 @@ impl FlashAttn { candle_core::Storage::Cuda(c) => c.as_cuda_slice::()?, _ => candle_core::bail!("alibi_slopes must be a cuda tensor"), }; - let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); - *alibi_slopes.device_ptr() as *const core::ffi::c_void + let (ptr, _guard) = alibi_slopes.device_ptr(&stream); + ptr as *const core::ffi::c_void } else { std::ptr::null() }; @@ -149,10 +149,8 @@ impl FlashAttn { let seqlen_k_rounded = round_multiple(seqlen_k, 128); let elem_count = out_shape.elem_count(); - let dst = unsafe { dev.alloc::(elem_count) }.w()?; - let softmax_lse = dev - .alloc_zeros::(b_sz * 128 * num_heads * seqlen_q) - .w()?; + let dst = unsafe { dev.alloc::(elem_count) }?; + let softmax_lse = dev.alloc_zeros::(b_sz * 128 * num_heads * seqlen_q)?; let is_bf16 = if is_bf16 { 1 } else { 0 }; @@ -171,17 +169,17 @@ impl FlashAttn { } unsafe { - let q_ptr = *q.device_ptr() as *const core::ffi::c_void; - let k_ptr = *k.device_ptr() as *const core::ffi::c_void; - let v_ptr = *v.device_ptr() as *const core::ffi::c_void; - let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; - let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void; + let (q_ptr, _guard) = q.device_ptr(&stream); + let (k_ptr, _guard) = k.device_ptr(&stream); + let (v_ptr, _guard) = v.device_ptr(&stream); + let (dst_ptr, _guard) = dst.device_ptr(&stream); + let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream); ffi::run_mha( - q_ptr, - k_ptr, - v_ptr, - dst_ptr, - softmax_lse_ptr, + q_ptr as *const core::ffi::c_void, + k_ptr as *const core::ffi::c_void, + v_ptr as *const core::ffi::c_void, + dst_ptr as *const core::ffi::c_void, + softmax_lse_ptr as *const core::ffi::c_void, /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ std::ptr::null(), /* cu_seqlens_k_ptr */ std::ptr::null(), @@ -535,6 +533,7 @@ impl FlashAttnVarLen { let batch_size = nseqlens_q - 1; + let stream = dev.cuda_stream(); let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { if alibi_slopes.dtype() != DType::F32 { candle_core::bail!( @@ -561,7 +560,8 @@ impl FlashAttnVarLen { let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); - *alibi_slopes.device_ptr() as *const core::ffi::c_void + let (ptr, _guard) = alibi_slopes.device_ptr(&stream); + ptr as *const core::ffi::c_void } else { std::ptr::null() }; @@ -592,8 +592,8 @@ impl FlashAttnVarLen { let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128); let elem_count = out_shape.elem_count(); - let dst = unsafe { dev.alloc::(elem_count) }.w()?; - let softmax_lse = dev.alloc_zeros::(num_heads * total_q).w()?; + let dst = unsafe { dev.alloc::(elem_count) }?; + let softmax_lse = dev.alloc_zeros::(num_heads * total_q)?; let is_bf16 = if is_bf16 { 1 } else { 0 }; @@ -611,22 +611,22 @@ impl FlashAttnVarLen { window_size_right = self.max_seqlen_k as i32; } unsafe { - let q_ptr = *q.device_ptr() as *const core::ffi::c_void; - let k_ptr = *k.device_ptr() as *const core::ffi::c_void; - let v_ptr = *v.device_ptr() as *const core::ffi::c_void; - let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; - let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void; - let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int; - let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int; + let (q_ptr, _guard) = q.device_ptr(&stream); + let (k_ptr, _guard) = k.device_ptr(&stream); + let (v_ptr, _guard) = v.device_ptr(&stream); + let (dst_ptr, _guard) = dst.device_ptr(&stream); + let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream); + let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream); + let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream); ffi::run_mha( - q_ptr, - k_ptr, - v_ptr, - dst_ptr, - softmax_lse_ptr, - /* alibi_slopes_ptr */ alibi_slopes_ptr, - /* cu_seqlens_q_ptr */ seqlens_q_ptr, - /* cu_seqlens_k_ptr */ seqlens_k_ptr, + q_ptr as *const core::ffi::c_void, + k_ptr as *const core::ffi::c_void, + v_ptr as *const core::ffi::c_void, + dst_ptr as *const core::ffi::c_void, + softmax_lse_ptr as *const core::ffi::c_void, + /* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void, + /* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32, + /* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32, /* q_batch_stride */ 0, /* k_batch_stride */ 0, /* v_batch_stride */ 0,