llama : add high-throughput mode #14363

Draft
ggerganov wants to merge 11 commits into gg/kv-cache-use-set-rows from gg/llama-high-throughput
Conversation

@ggerganov (Member) commented Jun 24, 2025

target #14285

Overview

Improve multi-sequence decoding performance by avoiding the cross-sequence attention compute of the unified KV cache.

Still WIP, but initial results are promising. The functionality is temporarily gated behind the LLAMA_HT environment variable and also requires LLAMA_SET_ROWS from #14285.

A detailed description will be added once I make more progress and am more confident that the approach is viable.
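
To illustrate the idea, a conceptual sketch only (not code from this PR; the function name and the cache layout below are hypothetical): with a unified KV cache a query attends over the whole cache and cross-sequence entries are merely masked out, so their dot products are still computed; keeping each sequence's KV cells in their own slice lets the attention loop skip that work entirely.

#include <stddef.h>

// scores[j] = dot(q, k_j) computed only over the KV cells owned by seq_id.
// Hypothetical layout: kv_base[s] is the first cache cell of sequence s,
// kv_used[s] is the number of cells that sequence currently occupies.
static void attn_scores_single_seq(const float * q, const float * k_cache,
                                   size_t n_embd, int seq_id,
                                   const size_t * kv_base, const size_t * kv_used,
                                   float * scores) {
    const size_t begin = kv_base[seq_id];
    const size_t end   = begin + kv_used[seq_id];
    for (size_t i = begin; i < end; ++i) { // only this sequence's cells
        float dot = 0.0f;
        for (size_t d = 0; d < n_embd; ++d) {
            dot += q[d]*k_cache[i*n_embd + d];
        }
        scores[i - begin] = dot;
    }
}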

Testing

# master
make -j && ./bin/llama-batched-bench -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -c 100000 -b 2048 -ub 512 -npp 0,0,512,1024,2048 -ntg 32 -npl 32 -fa

main: n_kv_max = 100096, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.387 |   738.54 |    1.387 |   738.51 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.396 |   733.74 |    1.396 |   733.69 |
|   512 |     32 |   32 |  17408 |    6.054 |  2706.51 |    2.021 |   506.62 |    8.075 |  2155.85 |
|  1024 |     32 |   32 |  33792 |   13.002 |  2520.20 |    2.668 |   383.80 |   15.670 |  2156.45 |
|  2048 |     32 |   32 |  66560 |   29.922 |  2190.20 |    3.957 |   258.81 |   33.879 |  1964.64 |

# PR
make -j && LLAMA_HT=1 LLAMA_SET_ROWS=1 ./bin/llama-batched-bench -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -c 10000 -b 2048 -ub 512 -npp 0,0,512,1024,2048 -ntg 32 -npl 32 -fa

main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.322 |   774.85 |    1.322 |   774.71 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.315 |   778.66 |    1.315 |   778.49 |
|   512 |     32 |   32 |  17408 |    5.630 |  2910.32 |    1.474 |   694.51 |    7.104 |  2450.44 |
|  1024 |     32 |   32 |  33792 |   11.577 |  2830.46 |    1.549 |   661.13 |   13.126 |  2574.48 |
|  2048 |     32 |   32 |  66560 |   24.376 |  2688.52 |    1.729 |   592.27 |   26.105 |  2549.69 |

A more realistic example, using llama-parallel:

# master
make -j && ./bin/llama-parallel -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -np 8 -ns 128 -s 1 -c 16384 -fa

# PR
make -j && LLAMA_HT=1 LLAMA_SET_ROWS=1 ./bin/llama-parallel -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -np 32 -ns 128 -s 1 -c 4096 -fa

TODO

  • FA path
  • Non-FA path
  • Metal FA
  • Metal non-FA
  • CPU FA
  • CPU non-FA
  • ggml_soft_max_ext() support for virtual sequences
  • llama_memory_seq_cp support for virtual sequences
  • iSWA
  • split_equal: support sequential ids
  • CUDA
  • Vulkan
  • etc.
  • more consistent sequence/virtual sequence naming
  • better term than "virtual sequence"?
  • make the LLAMA_HT env variable a regular compute parameter
  • Fix the meaning of n_ctx (total vs. per-sequence)
  • Check that the input batch contains no coupled sequences when HT is on
  • Require n_embd_v_gqa(il) == const when FA is off
@JohannesGaessler (Collaborator)

Right now I am comparatively less busy with my PhD, so it would be a good time for me to write whatever CUDA code is still missing, if there is any.

@ggerganov (Member, Author) commented Jun 24, 2025

For now, these are the necessary CUDA changes:

  • Add ggml_set_rows() support (needs a PR targeting ggml : add ggml_set_rows #14274; implementation can already be started)
  • Extend ggml_flash_attn_ext() to support an n_seq dimension if it does not already (a caller-side sketch of the proposed shapes follows at the end of this comment):
// old
    // q:    [n_embd_k, n_batch,     n_head,    1]
    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            ...);

// new - supports `n_seq` dimension:
    // q:    [n_embd_k, n_batch,     n_head,    n_seq]
    // k:    [n_embd_k, n_kv,        n_head_kv, n_seq]
    // v:    [n_embd_v, n_kv,        n_head_kv, n_seq] !! not transposed !!
    // mask: [n_kv,     n_batch_pad, n_seq,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
    // res:  [n_embd_v, n_head,      n_batch,   n_seq] !! permuted !!
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            ...);

CPU might also need to be extended (not sure yet)

  • Extend ggml_soft_max_ext() in a similar way to support the n_seq dim if it does not already. Also not sure yet about the state of the CPU implementation.

Edit: the CPU versions of ggml_soft_max_ext() and ggml_flash_attn_ext() are now correct and can be used as a reference.
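
For reference, a minimal caller-side sketch of the n_seq-aware shapes (assuming the ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap) signature; all sizes below are illustrative, not taken from this PR, and on current master the op may still assert on a mask with ne[2] > 1, since that is exactly the extension being discussed):

#include "ggml.h"
#include <math.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 256*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // illustrative sizes only
    const int64_t n_embd_k = 128, n_embd_v  = 128;
    const int64_t n_head   = 8,   n_head_kv = 8;
    const int64_t n_batch  = 32,  n_kv      = 256;
    const int64_t n_seq    = 4; // number of (virtual) sequences
    const int64_t n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD);

    // tensor shapes as in the "new" comment above
    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_k, n_batch,     n_head,    n_seq);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_k, n_kv,        n_head_kv, n_seq);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_v, n_kv,        n_head_kv, n_seq);
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv,     n_batch_pad, n_seq,     1);

    // res: [n_embd_v, n_head, n_batch, n_seq] (permuted) - one attention per sequence
    struct ggml_tensor * res = ggml_flash_attn_ext(ctx, q, k, v, mask,
            1.0f/sqrtf((float) n_embd_k), /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, res);
    // ... fill q/k/v/mask and compute the graph with the backend of choice ...

    ggml_free(ctx);
    return 0;
}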

@ggerganov force-pushed the gg/llama-high-throughput branch from ab2a2bb to 1b74b9d on June 24, 2025 17:24