llama : add high-throughput mode #14363


Draft: wants to merge 11 commits into base gg/kv-cache-use-set-rows
examples/parallel/parallel.cpp (2 additions & 1 deletion)

@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {

// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);

int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@@ -289,6 +289,7 @@ int main(int argc, char ** argv) {
// all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) {
llama_memory_seq_rm(mem, i, -1, -1);

// but keep the system prompt
llama_memory_seq_cp(mem, 0, i, -1, -1);
}
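The batch is now sized for the worst case where every client submits a full-context prompt in the same decode round; the main loop still chunks it into params.n_batch pieces. A minimal sketch of that sizing, assuming the public llama_batch_init(n_tokens, embd, n_seq_max) signature (helper name and the example values are hypothetical):

```cpp
// Sketch, not code from the PR: reserve one token slot per (client, position)
// pair so all clients' prompts can be queued before the decode loop chunks them.
#include "llama.h"

static llama_batch make_shared_batch(int32_t n_ctx, int32_t n_clients) {
    // embd = 0      -> the batch carries token ids, not embeddings
    // n_seq_max = 1 -> each queued token belongs to exactly one sequence
    return llama_batch_init(n_ctx * n_clients, /*embd =*/ 0, /*n_seq_max =*/ 1);
}

// usage (hypothetical values):
//   llama_batch batch = make_shared_batch(/*n_ctx =*/ 4096, /*n_clients =*/ 8);
//   ... fill batch.token[i], batch.pos[i], batch.seq_id[i][0] per client ...
//   llama_batch_free(batch);
```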
ggml/src/ggml-cpu/ops.cpp (9 additions & 4 deletions)

@@ -4853,7 +4853,8 @@ static void ggml_compute_forward_soft_max_f32(

GGML_TENSOR_UNARY_OP_LOCALS

//const int64_t ne11 = src1 ? src1->ne[1] : 1;
const int64_t nb11 = src1 ? src1->nb[1] : 1;
const int64_t nb12 = src1 ? src1->nb[2] : 1;

// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -4878,6 +4879,10 @@ static void ggml_compute_forward_soft_max_f32(
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

for (int i1 = ir0; i1 < ir1; i1++) {
const int64_t i11 = (i1%ne01);
//const int64_t i12 = (i1/ne01)%ne02;
const int64_t i13 = (i1/ne01)/ne02;

// ALiBi
const uint32_t h = (i1/ne01)%ne02; // head
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
@@ -4886,8 +4891,8 @@
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);

// broadcast the mask across rows
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL;
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL;

ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
@@ -7227,7 +7232,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
memset(VKQ32, 0, DV*sizeof(float));
}

const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + iq3*mask->nb[2]) : NULL;

// k indices
const int ik3 = iq3 / rk3;
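The soft-max mask is no longer treated as a single 2D matrix: the new nb11/nb12 strides let each sequence in the batch read its own mask slab while all heads share it. A small restatement of the index math used above (illustrative only, using ggml's existing tensor fields):

```cpp
// Sketch of the row -> mask mapping (not a new API). The dst/src0 rows are
// flattened over [ne01, ne02, ne03]; i11 picks the row inside one mask slab,
// i13 picks the sequence's slab, and the head index is skipped because the
// mask is broadcast across heads.
#include "ggml.h"

static inline const char * soft_max_mask_row(const struct ggml_tensor * mask,
                                             int64_t i1, int64_t ne01, int64_t ne02) {
    const int64_t i11 = i1 % ne01;            // row within the mask
    const int64_t i13 = (i1 / ne01) / ne02;   // sequence (batch) index
    return (const char *) mask->data + i11*mask->nb[1] + i13*mask->nb[2];
}
```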
ggml/src/ggml-metal/ggml-metal-impl.h (3 additions & 0 deletions)

@@ -230,6 +230,7 @@ typedef struct {
uint64_t nb22;
uint64_t nb23;
uint64_t nb31;
uint64_t nb32;
int32_t ne1;
int32_t ne2;
float scale;
@@ -453,6 +454,8 @@ typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
uint64_t nb11;
uint64_t nb12;
float scale;
float max_bias;
float m0;
ggml/src/ggml-metal/ggml-metal.m (4 additions & 4 deletions)

@@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node(
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));

const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];

const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node(
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
@@ -4882,6 +4881,7 @@ static bool ggml_metal_encode_node(
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,
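The old derivation n_head = ggml_nrows(src0)/src0->ne[1] equals ne02*ne03, so once the batch dimension ne03 can be larger than 1 it would overcount the heads and distort the ALiBi slopes; taking src0->ne[2] directly avoids that. A small worked illustration under assumed sizes (32 heads, 4 sequences, max_bias = 8):

```cpp
// Illustration only: how an inflated n_head changes the ALiBi slope bases
// m0/m1 (formulas as used by the ggml soft-max path).
#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const float    max_bias = 8.0f;          // assumed value
    const uint32_t heads[2] = { 32, 32*4 };  // new (ne02) vs old (ne02*ne03)
    for (int i = 0; i < 2; ++i) {
        const uint32_t n_head      = heads[i];
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        printf("n_head = %3u -> m0 = %.4f, m1 = %.4f\n", n_head, m0, m1);
    }
    return 0;
}
```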
ggml/src/ggml-metal/ggml-metal.metal (4 additions & 4 deletions)

@@ -1263,7 +1263,7 @@ kernel void kernel_soft_max(
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);

device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);

float slope = 1.0f;
@@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4(
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);

device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;

float slope = 1.0f;
@@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + iq3*args.nb32);

const float m = pm[ic + tiisg];

@@ -4131,7 +4131,7 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q;

// pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31);
device const half * pm = (device const half *) (mask + iq1*args.nb31 + iq3*args.nb32);

float slope = 1.0f;

ggml/src/ggml.c (2 additions & 2 deletions)

@@ -3526,7 +3526,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(ggml_is_3d(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
}
@@ -4504,7 +4504,7 @@ struct ggml_tensor * ggml_flash_attn_ext(

if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[2] == q->ne[3]);
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
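With these asserts the KQ mask becomes a 3D tensor: one [n_kv, n_tokens] slab per sequence, and the flash-attention path requires mask->ne[2] to match q->ne[3]. A hedged sketch of a compatible mask allocation (the helper name and parameters are hypothetical, not from the PR):

```cpp
// Sketch only: allocate a mask that satisfies the updated asserts.
#include "ggml.h"

static struct ggml_tensor * build_kq_mask(struct ggml_context * ctx,
                                          int64_t n_kv, int64_t n_tokens, int64_t n_seq) {
    // ne[0] = n_kv, ne[1] = padded number of query rows, ne[2] = q->ne[3] sequences
    return ggml_new_tensor_3d(ctx, GGML_TYPE_F16,
                              n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_seq);
}
```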
src/llama-batch.cpp (27 additions & 1 deletion)

@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(

// note: tracking the other way around is not necessary for now
//seq_cpl[s0][s1] = true;

has_cpl = true;
}
}
}
@@ -403,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}

uint32_t llama_batch_allocr::get_n_used() const {
return n_used;
}

std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
return out_ids;
}
@@ -418,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
void llama_batch_allocr::split_reset() {
out_ids.clear();

n_used = 0;

used.clear();
used.resize(get_n_tokens(), false);

@@ -442,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
idxs.push_back(cur_idx);

used[cur_idx] = true;
++n_used;

++cur_idx;

@@ -457,9 +466,17 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
return ubatch_add(idxs, idxs.size(), false);
}

llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
if (sequential && has_cpl) {
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);

return {};
}

std::vector<seq_set_t> cur_seq_set;

llama_seq_id last_seq_id = -1;

// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) {
@@ -476,9 +493,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
}
}

// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}

if (add) {
cur_seq_set.push_back(seq_set[i]);

last_seq_id = batch.seq_id[i][0];

if (cur_seq_set.size() > n_ubatch) {
break;
}
@@ -527,6 +551,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
idxs_per_seq[s].push_back(idx);

used[idx] = true;
++n_used;

++cur_idx[s];
}
@@ -568,6 +593,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
idxs.push_back(cur_idx);

used[cur_idx] = true;
++n_used;

if (idxs.size() >= n_ubatch) {
break;
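The sequential mode only admits a new sequence set when its first seq_id is exactly last_seq_id + 1, so an ubatch built this way covers a gap-free, increasing run of sequence ids, and it is rejected outright when the batch contains coupled sequences. A toy model of that acceptance rule, not code from the PR:

```cpp
// Toy illustration of the last_seq_id + 1 check in split_equal(..., /*sequential=*/true).
#include <vector>

static std::vector<int> sequential_prefix(const std::vector<int> & first_seq_ids) {
    std::vector<int> accepted;
    int last_seq_id = -1;
    for (int s : first_seq_ids) {
        if (s != last_seq_id + 1) {
            break;                 // a gap or reordering ends the sequential run
        }
        accepted.push_back(s);
        last_seq_id = s;
    }
    return accepted;               // e.g. {0,1,2,4,3} -> {0,1,2}
}
```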
src/llama-batch.h (8 additions & 1 deletion)

@@ -54,6 +54,7 @@ class llama_batch_allocr {

uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
uint32_t get_n_used() const;

// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
@@ -69,7 +70,8 @@ class llama_batch_allocr {
llama_ubatch split_simple(uint32_t n_ubatch);

// make ubatches of equal-length sequence sets
llama_ubatch split_equal(uint32_t n_ubatch);
// if sequential == true, the tokens in the ubatch will have strictly increasing, consecutive sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);

// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
@@ -112,6 +114,9 @@ class llama_batch_allocr {
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;

// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl;

std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

@@ -125,6 +130,8 @@ class llama_batch_allocr {
// batch indices of the output
std::vector<int32_t> out_ids;

uint32_t n_used;

// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;

src/llama-context.cpp (3 additions & 0 deletions)

@@ -33,6 +33,9 @@ llama_context::llama_context(
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
}

const char * LLAMA_HT = getenv("LLAMA_HT");
cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;

cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
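As wired up here, the high-throughput path is opt-in via the LLAMA_HT environment variable (any value enables it; its content is not inspected) and simply turns every sequence slot into its own virtual KV-cache sequence. A minimal usage sketch, assuming the gate above and the standard context parameters (helper name and values are hypothetical):

```cpp
// Sketch only: the variable is read inside the llama_context constructor, so
// it must be set before the context is created. setenv is POSIX.
#include <cstdlib>
#include "llama.h"

void enable_high_throughput(llama_context_params & cparams, uint32_t n_parallel) {
    setenv("LLAMA_HT", "1", /*overwrite =*/ 1);
    cparams.n_seq_max = n_parallel;   // becomes n_seq_virt when LLAMA_HT is set
}
```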
src/llama-cparams.h (3 additions & 2 deletions)

@@ -11,8 +11,9 @@ struct llama_cparams {
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing
uint32_t n_seq_virt;
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing

float rope_freq_base;
float rope_freq_scale;