From ee599f901a3f4ebbc6e42f2273e3e08dfc5b2646 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Tue, 22 Oct 2024 19:57:15 +0200 Subject: [PATCH] llama: correct reverting of the entire batch. also updates `llama_kv_cache_find_slot`, will correctly count the number of `used` cells for recurrent models --- src/llama.cpp | 122 ++++++++++++++++++++++++++------------------------ 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 85e613a63fef5..48f1f254b35c5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2811,22 +2811,6 @@ struct llama_kv_cache { } }; -// saves the kv_cache state for future recovery -// used to preserve the kv_cache state before searching for a slot -struct llama_kv_slot_restorer { - struct llama_kv_cache_state { - uint32_t head = 0; - uint32_t size = 0; - uint32_t used = 0; - uint32_t n = 0; - } old_state; - - std::vector recurrent_cells; // for recurrent models only - std::pair slot_boundaries; // for non-recurrent models only - - bool restore = false; -}; - struct llama_control_vector { std::vector tensors; // per layer std::vector ctxs; @@ -3522,21 +3506,24 @@ static bool llama_kv_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_kv_cache_find_slot( +struct llama_kv_cache_slot_info { + std::pair boundaries; + bool found = false; + + explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} + llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} + + operator bool() const { return found; } +}; +static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; + +static struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_ubatch & batch, - struct llama_kv_slot_restorer * slot_restorer = nullptr) { + const struct llama_ubatch & batch) { const uint32_t n_tokens = batch.n_tokens; const uint32_t n_seqs = batch.n_seqs; const uint32_t n_seq_tokens = batch.n_seq_tokens; - if (slot_restorer != nullptr) { - slot_restorer->old_state.head = cache.head; - slot_restorer->old_state.size = cache.size; - slot_restorer->old_state.used = cache.used; - slot_restorer->old_state.n = cache.n; - } - if (cache.recurrent) { // For recurrent state architectures (like Mamba or RWKV), // each cache cell can store the state for a whole sequence. @@ -3545,11 +3532,6 @@ static bool llama_kv_cache_find_slot( // can only process batches with an equal number of new tokens in each sequence GGML_ASSERT(batch.equal_seqs); - if (slot_restorer != nullptr) { - slot_restorer->recurrent_cells = cache.cells; - slot_restorer->restore = true; - } - int32_t min = cache.size - 1; int32_t max = 0; @@ -3563,7 +3545,7 @@ static bool llama_kv_cache_find_slot( // too big seq_id // TODO: would it be possible to resize the cache instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); - return false; + return llama_kv_cache_slot_info_failed; } if (j > 0) { llama_kv_cell & seq = cache.cells[seq_id]; @@ -3698,15 +3680,17 @@ static bool llama_kv_cache_find_slot( // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; + cache.used = std::count_if(cache.cells.begin(), cache.cells.end(), + [](const llama_kv_cell& cell){ return !cell.is_empty(); }); // sanity check - return cache.n >= n_seqs; + return llama_kv_cache_slot_info(cache.n >= n_seqs); } // otherwise, one cell per token. if (n_tokens > cache.size) { LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); - return false; + return llama_kv_cache_slot_info_failed; } uint32_t n_tested = 0; @@ -3734,15 +3718,10 @@ static bool llama_kv_cache_find_slot( if (n_tested >= cache.size) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + return llama_kv_cache_slot_info_failed; } } - if (slot_restorer != nullptr) { - slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens); - slot_restorer->restore = true; - } - for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; @@ -3756,7 +3735,7 @@ static bool llama_kv_cache_find_slot( cache.used += n_tokens; - return true; + return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens); } // find how many cells are currently in use @@ -4032,22 +4011,47 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) return cparams.flash_attn ? 256u : 32u; } -static void llama_kv_cache_slot_restore( - const struct llama_kv_slot_restorer & restorer, - struct llama_kv_cache & cache) { - if (restorer.restore) { - cache.head = restorer.old_state.head; - cache.size = restorer.old_state.size; - cache.used = restorer.old_state.used; - cache.n = restorer.old_state.n; - - if (cache.recurrent) { - cache.cells = restorer.recurrent_cells; - } else { - llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1); +// saves the kv_cache state for future recovery. +// used to rollback llama_kv_cache_find_slot changes. +struct llama_kv_slot_restorer { + struct llama_kv_cache_state { + uint32_t head = 0; + uint32_t n = 0; + } old_state; + + std::vector> slot_boundaries; // for non-recurrent models only + + bool do_restore = false; + + explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { + old_state.head = cache.head; + old_state.n = cache.n; + } + + void save(const struct llama_kv_cache_slot_info& slot) { + if (slot) { + do_restore = true; + if (slot.boundaries.first != slot.boundaries.second) { + slot_boundaries.push_back(slot.boundaries); + } } } -} + + void restore(struct llama_kv_cache & cache) { + if (do_restore) { + cache.head = old_state.head; + cache.n = old_state.n; + + if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased + llama_kv_cache_seq_rm(cache, -1, -1, -1); + } else { + for (auto & slot : slot_boundaries) { + llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second); + } + } + } + } +}; // // model loading and saving @@ -17307,7 +17311,7 @@ static int llama_decode_internal( lctx.n_queued_tokens += n_tokens_all; auto & kv_self = lctx.kv_self; - llama_kv_slot_restorer kv_slot_restorer; + llama_kv_slot_restorer kv_slot_restorer(kv_self); const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -17392,9 +17396,11 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) { + const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); + if (!slot) { return 1; } + kv_slot_restorer.save(slot); if (!kv_self.recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized @@ -17443,7 +17449,7 @@ static int llama_decode_internal( const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); if (compute_status != GGML_STATUS_SUCCESS) { - llama_kv_cache_slot_restore(kv_slot_restorer, kv_self); + kv_slot_restorer.restore(kv_self); switch (compute_status) { case GGML_STATUS_ABORTED: return 2;