Skip to content

Commit

Permalink
refactor some old code with batching
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins committed Feb 5, 2024
1 parent 38863a3 commit 35c32fd
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 88 deletions.
164 changes: 76 additions & 88 deletions gpttype_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,8 @@ static llama_context * llama_ctx_v4;
static gpt_params * kcpp_params = nullptr;
static int max_context_limit_at_load = 0;
static int n_past = 0;
static int n_threads = 4;
static int n_blasthreads = 4;
static int n_batch = 8;
static bool useSmartContext = false;
static bool useContextShift = false;
static int blasbatchsize = 512;
static int smallbatchsize = 16;
static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
static std::string modelname;
static std::vector<gpt_vocab::id> last_n_tokens;
Expand Down Expand Up @@ -686,26 +681,38 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t

}

static int GetBatchSize(int desiredBlasBatchSize,FileFormat in_file_format)
{
if(desiredBlasBatchSize<=0)
{
desiredBlasBatchSize = 16;
}
if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT && file_format != FileFormat::GGJT_2 && file_format != FileFormat::GGJT_3 && file_format != FileFormat::GGUF_GENERIC)
{
desiredBlasBatchSize = (desiredBlasBatchSize > 256 ? 256 : desiredBlasBatchSize);
}
if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
desiredBlasBatchSize = 1;
}
return desiredBlasBatchSize;
}

ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta)
{
ggml_time_init();
kcpp_params = new gpt_params(); //allocate on heap to avoid linux segfault. yes this leaks memory.

file_format = in_file_format;
n_threads = kcpp_params->n_threads = inputs.threads;
n_blasthreads = kcpp_params->n_threads_batch = inputs.blasthreads;
kcpp_params->n_threads = inputs.threads;
kcpp_params->n_threads_batch = inputs.blasthreads;
bool isGguf = (file_format == FileFormat::GGUF_GENERIC);

n_batch = kcpp_params->n_batch = smallbatchsize;
kcpp_params->n_batch = GetBatchSize(inputs.blasbatchsize, in_file_format);
modelname = kcpp_params->model = inputs.model_filename;
useSmartContext = inputs.use_smartcontext;
useContextShift = inputs.use_contextshift;
debugmode = inputs.debugmode;
blasbatchsize = inputs.blasbatchsize;
if(blasbatchsize<=0)
{
blasbatchsize = smallbatchsize;
}


auto clamped_max_context_length = inputs.max_context_length;

Expand Down Expand Up @@ -796,7 +803,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
SetQuantsUnshuffled(file_format == FileFormat::GGJT_2);
llama_v2_context_params llama_ctx_params_v2 = llama_v2_context_default_params();
llama_ctx_params_v2.n_ctx = clamped_max_context_length;
//llama_ctx_params.n_parts = -1;
llama_ctx_params_v2.seed = -1;
llama_ctx_params_v2.f16_kv = true;
llama_ctx_params_v2.logits_all = false;
Expand Down Expand Up @@ -827,7 +833,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
int err = llama_v2_apply_lora_from_file(llama_ctx_v2,
lora_filename.c_str(),
lora_base_arg,
n_threads);
kcpp_params->n_threads);
if (err != 0)
{
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
Expand All @@ -846,7 +852,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
{
llama_v3_context_params llama_ctx_params = llama_v3_context_default_params();
llama_ctx_params.n_ctx = clamped_max_context_length;
//llama_ctx_paran_parts = -1;
llama_ctx_params.seed = -1;
llama_ctx_params.f16_kv = true;
llama_ctx_params.low_vram = inputs.low_vram;
Expand All @@ -858,7 +863,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
llama_ctx_params.main_gpu = cu_parseinfo_maindevice;
llama_ctx_params.rope_freq_base = rope_freq_base;
llama_ctx_params.rope_freq_scale = rope_freq_scale;
llama_ctx_params.n_batch = blasbatchsize;
llama_ctx_params.n_batch = kcpp_params->n_batch;

#if defined(GGML_USE_CUBLAS)
bool ts_all_zero = true;
Expand Down Expand Up @@ -894,7 +899,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
int err = llama_v3_apply_lora_from_file(llama_ctx_v3,
lora_filename.c_str(),
lora_base_arg,
n_threads);
kcpp_params->n_threads);
if (err != 0)
{
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
Expand All @@ -915,6 +920,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
else if(file_format==FileFormat::GGUF_GENERIC)
{
llama_backend_init(false);

llama_model_params model_params = llama_model_default_params();
llama_context_params llama_ctx_params = llama_context_default_params();
llama_ctx_params.n_ctx = clamped_max_context_length;
Expand Down Expand Up @@ -955,9 +962,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
model_params.main_gpu = cu_parseinfo_maindevice;
model_params.split_mode = llama_split_mode::LLAMA_SPLIT_ROW;

llama_ctx_params.n_batch = blasbatchsize;
llama_ctx_params.n_threads = n_threads;
llama_ctx_params.n_threads_batch = n_blasthreads;
llama_ctx_params.n_batch = kcpp_params->n_batch;
llama_ctx_params.n_threads = kcpp_params->n_threads;
llama_ctx_params.n_threads_batch = kcpp_params->n_threads_batch;

#if defined(GGML_USE_CUBLAS)
bool ts_all_zero = true;
Expand Down Expand Up @@ -994,20 +1001,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
llamamodel->hparams.rope_freq_scale_train!=1.0f ||
llamamodel->hparams.rope_scaling_type_train==2)
{
// float ropemultiplier = 1.0f;
// if(llamamodel->hparams.rope_scaling_type_train!=2 &&
// llamamodel->hparams.n_ctx_train > 2048 && clamped_max_context_length > llamamodel->hparams.n_ctx_train &&
// llamamodel->hparams.rope_freq_scale_train==1.0f)
// {
// ropemultiplier = (float)llamamodel->hparams.n_ctx_train / (float)clamped_max_context_length;
// llama_ctx_params.rope_freq_base = rope_freq_base = llamamodel->hparams.rope_freq_base_train;
// llama_ctx_params.rope_freq_scale = rope_freq_scale = ropemultiplier * llamamodel->hparams.rope_freq_scale_train;
// printf("Automatic RoPE Scaling: Using (scale:%.3f, base:%.1f).\n", rope_freq_scale, rope_freq_base);
// }
// else
// {
printf("Automatic RoPE Scaling: Using model internal value.\n");
// }
printf("Automatic RoPE Scaling: Using model internal value.\n");
}
else
{
Expand Down Expand Up @@ -1038,7 +1032,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
lora_filename.c_str(),
1.0f,
lora_base_arg,
n_threads);
kcpp_params->n_threads);
if (err != 0)
{
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
Expand All @@ -1064,11 +1058,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
bool useWorldTokenizer = false;
if (file_format == FileFormat::RWKV_1)
{
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), kcpp_params->n_threads);
}
else //rwkv_2
{
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), kcpp_params->n_threads);

if(inputs.gpulayers>0)
{
Expand Down Expand Up @@ -1110,7 +1104,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in

if (file_format == FileFormat::RWKV_1)
{
n_batch = 1;

//setup buffers for rwkv state
auto padding = 512u;
Expand Down Expand Up @@ -1138,8 +1131,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
else
{
n_batch = 1; //do not use sequence mode to speedup until it is fixed

//setup buffers for rwkv state
auto padding = 512u;
auto statebufsiz = rwkv_get_state_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
Expand Down Expand Up @@ -1472,6 +1463,22 @@ const std::string & gpttype_get_pending_output()
return concat_output_reader_copy;
}

bool GetThreadsToUse(bool blasmode)
{
if (blasmode)
{
if(!ggml_cpu_has_gpublas())
{
return 1;
}
else
{
return kcpp_params->n_threads_batch;
}
}
return kcpp_params->n_threads;
}

generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
{
if(kcpp_params==nullptr)
Expand All @@ -1482,6 +1489,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
generation_finished = true;
return output;
}

if(debugmode==1 && file_format == FileFormat::GGUF_GENERIC)
{
llama_reset_timings(llama_ctx_v4);
}

concat_output_mtx.lock();
concat_output = "";
concat_output_reader_copy = "";
Expand Down Expand Up @@ -1528,9 +1541,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
kcpp_params->dynatemp_range = inputs.dynatemp_range;
kcpp_params->dynatemp_exponent = inputs.dynatemp_exponent;
kcpp_params->n_ctx = inputs.max_context_length;
kcpp_params->n_batch = n_batch;
kcpp_params->n_threads = n_threads;
kcpp_params->n_threads_batch = n_blasthreads;
kcpp_params->smoothing_factor = inputs.smoothing_factor;

bool stream_sse = inputs.stream_sse;
Expand Down Expand Up @@ -1674,33 +1684,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::RWKV_1 ||
file_format==FileFormat::RWKV_2);
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas() && blasbatchsize>=32);
// bool blasmode = false;
int original_batch = kcpp_params->n_batch;
int original_threads = kcpp_params->n_threads;
if (blasmode)
{
//for non llama, limit to 256
int bbs = blasbatchsize;
if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT && file_format != FileFormat::GGJT_2 && file_format != FileFormat::GGJT_3 && file_format != FileFormat::GGUF_GENERIC)
{
bbs = (blasbatchsize > 256 ? 256 : blasbatchsize);
}

kcpp_params->n_batch = bbs; //received reports of 1024 and above crashing on some models
if(!ggml_cpu_has_gpublas())
{
//does not limit here for gguf anymore. this is kept for older models.
//new models will override threads inside decode fn.
kcpp_params->n_threads = 1;
kcpp_params->n_threads_batch = 1;
}
else
{
kcpp_params->n_threads = n_blasthreads;
kcpp_params->n_threads_batch = n_blasthreads;
}
}
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas() && kcpp_params->n_batch>=32);

current_context_tokens.resize(n_past);

Expand Down Expand Up @@ -1828,11 +1812,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o

if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
evalres = (llama_v2_eval(llama_ctx_v2, embd.data(), embdsize, n_past, kcpp_params->n_threads)==0);
evalres = (llama_v2_eval(llama_ctx_v2, embd.data(), embdsize, n_past, GetThreadsToUse(blasmode))==0);
}
else if(file_format == FileFormat::GGJT_3)
{
evalres = (llama_v3_eval(llama_ctx_v3, embd.data(), embdsize, n_past, kcpp_params->n_threads)==0);
evalres = (llama_v3_eval(llama_ctx_v3, embd.data(), embdsize, n_past, GetThreadsToUse(blasmode))==0);
}
else if(file_format == FileFormat::GGUF_GENERIC)
{
Expand All @@ -1850,12 +1834,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{
if(embd.size()>1)
{
evalres = rwkv_eval_sequence(rwkv_ctx_v3, kcpp_params->n_threads, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
evalres = rwkv_eval_sequence(rwkv_ctx_v3, GetThreadsToUse(blasmode), (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
}
else
{
bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
evalres = rwkv_eval(rwkv_ctx_v3, kcpp_params->n_threads, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
evalres = rwkv_eval(rwkv_ctx_v3, GetThreadsToUse(blasmode), embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
}

memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
Expand All @@ -1864,39 +1848,39 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
else if(file_format==FileFormat::GPT2_1)
{
evalres = legacy_gpt2_eval(gpt2_ctx_v1, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, file_format);
evalres = legacy_gpt2_eval(gpt2_ctx_v1, GetThreadsToUse(blasmode), n_past, embd, logits, mem_per_token, file_format);
}
else if(file_format==FileFormat::GPT2_2 || file_format==FileFormat::GPT2_3)
{
evalres = gpt2_v2_eval(gpt2_ctx_v2, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, file_format);
evalres = gpt2_v2_eval(gpt2_ctx_v2, GetThreadsToUse(blasmode), n_past, embd, logits, mem_per_token, file_format);
}
else if(file_format==FileFormat::GPT2_4)
{
evalres = gpt2_eval(gpt2_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, v3_use_scratch);
evalres = gpt2_eval(gpt2_ctx_v3, GetThreadsToUse(blasmode), n_past, embd, logits, mem_per_token, v3_use_scratch);
}
else if(file_format==FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5)
{
evalres = gpt_neox_v2_eval(neox_ctx_v2, kcpp_params->n_threads, n_past, embd, logits, mem_per_token);
evalres = gpt_neox_v2_eval(neox_ctx_v2, GetThreadsToUse(blasmode), n_past, embd, logits, mem_per_token);
}
else if(file_format==FileFormat::NEOX_6|| file_format==FileFormat::NEOX_7)
{
evalres = gpt_neox_eval(neox_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, v3_use_scratch);
evalres = gpt_neox_eval(neox_ctx_v3, GetThreadsToUse(blasmode), n_past, embd, logits, mem_per_token, v3_use_scratch);
}
else if(file_format==FileFormat::GPTJ_1 || file_format==FileFormat::GPTJ_2)
{
evalres = legacy_gptj_eval(gptj_ctx_v1, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, file_format);
evalres = legacy_gptj_eval(gptj_ctx_v1, GetThreadsToUse(blasmode), n_past, embd, logits, mem_per_token, file_format);
}
else if(file_format==FileFormat::GPTJ_3 || file_format==FileFormat::GPTJ_4)
{
evalres = gptj_v2_eval(gptj_ctx_v2, kcpp_params->n_threads, n_past, embd, logits, mem_per_token);
evalres = gptj_v2_eval(gptj_ctx_v2, GetThreadsToUse(blasmode), n_past, embd, logits, mem_per_token);
}
else if(file_format==FileFormat::GPTJ_5)
{
evalres = gptj_eval(gptj_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, v3_use_scratch);
evalres = gptj_eval(gptj_ctx_v3, GetThreadsToUse(blasmode), n_past, embd, logits, mem_per_token, v3_use_scratch);
}
else if(file_format==FileFormat::MPT_1)
{
evalres = mpt_eval(mpt_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, false, mem_per_token, v3_use_scratch);
evalres = mpt_eval(mpt_ctx_v3, GetThreadsToUse(blasmode), n_past, embd, logits, false, mem_per_token, v3_use_scratch);
}
else
{
Expand Down Expand Up @@ -1934,8 +1918,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if (!startedsampling)
{
startedsampling = true;
kcpp_params->n_batch = original_batch;
kcpp_params->n_threads = original_threads;
time1 = timer_check();
timer_start();
if(allow_regular_prints)
Expand Down Expand Up @@ -2081,6 +2063,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
}

if(debugmode==1 && file_format == FileFormat::GGUF_GENERIC)
{
llama_print_timings(llama_ctx_v4);
}

time2 = timer_check();
float pt1 = (time1*1000.0/(embd_inp.size()==0?1:embd_inp.size()));
float ts1 = (1000.0/pt1);
Expand All @@ -2100,4 +2088,4 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());

return output;
}
}
Loading

0 comments on commit 35c32fd

Please sign in to comment.