From aaf479a2a8440e6e59941fbdd64d327d9e2259a8 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 1 Mar 2026 10:33:36 -0500 Subject: [PATCH] improved qkv speed by removing cont op --- src/flux.hpp | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/flux.hpp b/src/flux.hpp index 37cbb12..f2b0ded 100644 --- a/src/flux.hpp +++ b/src/flux.hpp @@ -103,11 +103,13 @@ namespace Flux { auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto qkv = qkv_proj->forward(ctx, x); - auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true); - int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; - auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); - auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); - auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); + int64_t head_dim = qkv->ne[0] / 3 / num_heads; + auto q = ggml_view_4d(ctx->ggml_ctx, qkv, head_dim, num_heads, qkv->ne[1], qkv->ne[2], + qkv->nb[0]*head_dim, qkv->nb[1], qkv->nb[2], 0); + auto k = ggml_view_4d(ctx->ggml_ctx, qkv, head_dim, num_heads, qkv->ne[1], qkv->ne[2], + qkv->nb[0]*head_dim, qkv->nb[1], qkv->nb[2], (qkv->nb[0])*qkv->ne[0]/3); + auto v = ggml_view_4d(ctx->ggml_ctx, qkv, head_dim, num_heads, qkv->ne[1], qkv->ne[2], + qkv->nb[0]*head_dim, qkv->nb[1], qkv->nb[2], (qkv->nb[0])*2*qkv->ne[0]/3); q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); return {q, k, v}; @@ -491,15 +493,15 @@ namespace Flux { auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor] - auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0); - auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]); - auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]); - int64_t head_dim = hidden_size / num_heads; - q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]); // [N, n_token, n_head, d_head] - k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]); // [N, n_token, n_head, d_head] - v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]); // [N, n_token, n_head, d_head] + auto q = ggml_view_4d(ctx->ggml_ctx, qkv_mlp, head_dim, num_heads, qkv_mlp->ne[1], qkv_mlp->ne[2], + qkv_mlp->nb[0]*head_dim, qkv_mlp->nb[1], qkv_mlp->nb[2], 0); + auto k = ggml_view_4d(ctx->ggml_ctx, qkv_mlp, head_dim, num_heads, qkv_mlp->ne[1], qkv_mlp->ne[2], + qkv_mlp->nb[0]*head_dim, qkv_mlp->nb[1], qkv_mlp->nb[2], (qkv_mlp->nb[0])*hidden_size); + auto v = ggml_view_4d(ctx->ggml_ctx, qkv_mlp, head_dim, num_heads, qkv_mlp->ne[1], qkv_mlp->ne[2], + qkv_mlp->nb[0]*head_dim, qkv_mlp->nb[1], qkv_mlp->nb[2], (qkv_mlp->nb[0])*2*hidden_size); + q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k);