-
Notifications
You must be signed in to change notification settings - Fork 12.2k
llama : add high-throughput mode #14363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
ggerganov
wants to merge
11
commits into
gg/kv-cache-use-set-rows
Choose a base branch
from
gg/llama-high-throughput
base: gg/kv-cache-use-set-rows
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
+799
−317
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ggml-ci
ggml-ci
Right now I am comparatively less busy with my PhD so it would be a good time for me to write CUDA code that is still missing, if there is any. |
For now, these are the necessary CUDA changes:
// old
// q: [n_embd_k, n_batch, n_head, 1]
// k: [n_embd_k, n_kv, n_head_kv, 1]
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
...);
// new - supports `n_seq` dimension:
// q: [n_embd_k, n_batch, n_head, n_seq]
// k: [n_embd_k, n_kv, n_head_kv, n_seq]
// v: [n_embd_v, n_kv, n_head_kv, n_seq] !! not transposed !!
// mask: [n_kv, n_batch_pad, n_seq, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, n_seq] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
...); CPU might also need to be extended (not sure yet)
Edit: the CPU versions of |
ab2a2bb
to
1b74b9d
Compare
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
Apple Metal
https://en.wikipedia.org/wiki/Metal_(API)
examples
ggml
changes relating to the ggml tensor library for machine learning
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
target #14285
Overview
Improve multi-sequence decoding performance by avoiding the cross-sequence attention compute of the unified KV cache.
Still WIP, but initial results are promising. The functionality is temporarily gated with env
LLAMA_HT
and also requires theLLAMA_SET_ROWS
from #14285.Detailed description will be added when I make some more progress and am more convinced that the approach is viable.
Testing
Using a more real-world example with
llama-parallel
:TODO
ggml_soft_max_ext()
support for virtual sequencesllama_memory_seq_cp
support for virtual sequencessplit_equal
support sequential idsLLAMA_HT
become regular compute parametern_ctx
meaning (total vs per-sequence)n_embd_v_gqa(il) == const
when FA is off