diff --git a/.gitignore b/.gitignore index abc801b252e27..32742106ced83 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ models-mnt /embedding /gguf /gguf-llama-simple +/imatrix /infill /libllama.so /llama-bench diff --git a/common/common.cpp b/common/common.cpp index a0378a215674a..aaa31895d8684 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -631,6 +631,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.ppl_stride = std::stoi(argv[i]); + } else if (arg == "-ptc" || arg == "--print-token-count") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_print = std::stoi(argv[i]); } else if (arg == "--ppl-output-type") { if (++i >= argc) { invalid_param = true; @@ -813,7 +819,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf("\n"); printf("options:\n"); printf(" -h, --help show this help message and exit\n"); - printf(" --version show version and build info\n"); + printf(" --version show version and build info\n"); printf(" -i, --interactive run in interactive mode\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); @@ -910,7 +916,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" number of layers to store in VRAM\n"); printf(" -ngld N, --n-gpu-layers-draft N\n"); printf(" number of layers to store in VRAM for the draft model\n"); - printf(" -ts SPLIT --tensor-split SPLIT\n"); + printf(" -ts SPLIT, --tensor-split SPLIT\n"); printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); #ifdef GGML_USE_CUBLAS @@ -945,6 +951,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); + printf(" -ptc N, --print-token-count N\n"); + printf(" print token count every N tokens (default: %d)\n", params.n_print); printf("\n"); #ifndef LOG_DISABLE_LOGS log_print_usage(); @@ -1048,6 +1056,9 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & } static ggml_type kv_cache_type_from_str(const std::string & s) { + if (s == "f32") { + return GGML_TYPE_F32; + } if (s == "f16") { return GGML_TYPE_F16; } diff --git a/common/common.h b/common/common.h index 4a19688920e9a..1dfdaa7adbfca 100644 --- a/common/common.h +++ b/common/common.h @@ -58,6 +58,7 @@ struct gpt_params { int32_t n_beams = 0; // if non-zero then use beam search of given width. int32_t grp_attn_n = 1; // group-attention factor int32_t grp_attn_w = 512; // group-attention width + int32_t n_print = -1; // print token count every n tokens (-1 = disabled) float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor @@ -254,4 +255,3 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); - diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 203eaf64b3fc3..a1c79fd478c22 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -817,10 +817,17 @@ def set_gguf_parameters(self): hidden_size = self.hparams["hidden_size"] self.gguf_writer.add_name('persimmon-8b-chat') + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hidden_size) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - self.gguf_writer.add_rope_dimension_count(hidden_size // head_count) + + # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller + # than the head size? + # ref: https://github.com/ggerganov/llama.cpp/pull/4889 + # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count) + self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2) + self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 0c71cbdf72a65..fa127a3aa7c9e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -36,6 +36,7 @@ else() add_subdirectory(lookahead) add_subdirectory(lookup) add_subdirectory(train-text-from-scratch) + add_subdirectory(imatrix) if (LLAMA_METAL) add_subdirectory(metal) endif() diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index 58fbe204d3bbb..4cd5d99bb21ec 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -245,9 +245,8 @@ static struct lora_data * load_lora(struct lora_info * info) { params_ggml.no_alloc = true; result->ctx = ggml_init(params_ggml); - uint32_t LLAMA_FILE_MAGIC_LORA = 0x67676C61; // 'ggla' uint32_t magic = file.read_u32(); - if (magic != LLAMA_FILE_MAGIC_LORA) { + if (magic != LLAMA_FILE_MAGIC_GGLA) { die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str()); } uint32_t version = file.read_u32(); diff --git a/examples/imatrix/CMakeLists.txt b/examples/imatrix/CMakeLists.txt new file mode 100644 index 0000000000000..d688a16209049 --- /dev/null +++ b/examples/imatrix/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET imatrix) +add_executable(${TARGET} imatrix.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp new file mode 100644 index 0000000000000..1461bc96376a7 --- /dev/null +++ b/examples/imatrix/imatrix.cpp @@ -0,0 +1,380 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +struct Stats { + std::vector values; + int ncall = 0; +}; + +struct StatParams { + std::string ofile = "imatrix.dat"; + int n_output_frequency = 10; + int verbosity = 1; + bool collect_output_weight = false; +}; + +class IMatrixCollector { +public: + IMatrixCollector() = default; + void set_parameters(StatParams&& params) { m_params = std::move(params); } + void collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1); + void save_imatrix() const; +private: + std::unordered_map m_stats; + StatParams m_params; + std::mutex m_mutex; + int m_last_call = 0; +}; + +void IMatrixCollector::collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) { + if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return; + if (!(strncmp(src0->name, "blk.", 4) == 0 || (m_params.collect_output_weight && strcmp(src0->name, "output.weight") == 0))) return; + std::lock_guard lock(m_mutex); + auto& e = m_stats[src0->name]; + if (e.values.empty()) { + e.values.resize(src1->ne[0], 0); + } + else if (e.values.size() != (size_t)src1->ne[0]) { + fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", src0->name, (int)e.values.size(), (int)src1->ne[0]); + exit(1); //GGML_ASSERT(false); + } + ++e.ncall; + if (m_params.verbosity > 1) { + printf("%s[%d]: %s, %d x %d, %d\n",__func__,m_last_call,src0->name,(int)src1->ne[0],(int)src1->ne[1],(int)src1->type); + } + for (int row = 0; row < (int)src1->ne[1]; ++row) { + const float * x = (const float *)src1->data + row * src1->ne[0]; + for (int j = 0; j < (int)src1->ne[0]; ++j) { + e.values[j] += x[j]*x[j]; + } + } + if (e.ncall > m_last_call) { + m_last_call = e.ncall; + if (m_last_call % m_params.n_output_frequency == 0) { + save_imatrix(); + } + } +} + +void IMatrixCollector::save_imatrix() const { + const char * fname = m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str(); + std::ofstream out(fname, std::ios::binary); + int n_entries = m_stats.size(); + out.write((const char*)&n_entries, sizeof(n_entries)); + for (auto& p : m_stats) { + int len = p.first.size(); + out.write((const char*)&len, sizeof(len)); + out.write(p.first.c_str(), len); + out.write((const char*)&p.second.ncall, sizeof(p.second.ncall)); + int nval = p.second.values.size(); + out.write((const char*)&nval, sizeof(nval)); + if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float)); + } + if (m_params.verbosity > 0) { + fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n",__func__,m_last_call,fname); + } +} + +static IMatrixCollector g_collector; + +static void ik_collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) { + g_collector.collect_imatrix(src0, src1); +} + + +struct results_log_softmax { + double log_softmax; + float logit; + float prob; +}; + +static std::vector softmax(const std::vector& logits) { + std::vector probs(logits.size()); + float max_logit = logits[0]; + for (float v : logits) { + max_logit = std::max(max_logit, v); + } + double sum_exp = 0.0; + for (size_t i = 0; i < logits.size(); i++) { + // Subtract the maximum logit value from the current logit value for numerical stability + const float logit = logits[i] - max_logit; + const float exp_logit = expf(logit); + sum_exp += exp_logit; + probs[i] = exp_logit; + } + for (size_t i = 0; i < probs.size(); i++) { + probs[i] /= sum_exp; + } + return probs; +} + +static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) { + float max_logit = logits[0]; + for (int i = 1; i < n_vocab; ++i) { + max_logit = std::max(max_logit, logits[i]); + } + double sum_exp = 0.0; + for (int i = 0; i < n_vocab; ++i) { + sum_exp += expf(logits[i] - max_logit); + } + return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp}; +} + +static void process_logits( + int n_vocab, const float * logits, const int * tokens, int n_token, std::vector & workers, + double & nll, double & nll2, float * logit_history, float * prob_history +) { + std::mutex mutex; + int counter = 0; + auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () { + double local_nll = 0; + double local_nll2 = 0; + while (true) { + std::unique_lock lock(mutex); + int i = counter++; + if (i >= n_token) { + nll += local_nll; nll2 += local_nll2; + break; + } + lock.unlock(); + const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]); + const double v = -results.log_softmax; + local_nll += v; + local_nll2 += v*v; + + logit_history[i] = results.logit; + prob_history[i] = results.prob; + } + }; + for (auto & w : workers) { + w = std::thread(compute); + } + compute(); + for (auto & w : workers) { + w.join(); + } +} + +static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { + + const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + const int n_ctx = llama_n_ctx(ctx); + + auto tim1 = std::chrono::high_resolution_clock::now(); + fprintf(stderr, "%s: tokenizing the input ..\n", __func__); + + std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); + + auto tim2 = std::chrono::high_resolution_clock::now(); + fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); + + if (int(tokens.size()) < 2*n_ctx) { + fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx, + n_ctx); + fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); + return false; + } + + std::vector logit_history; + logit_history.resize(tokens.size()); + + std::vector prob_history; + prob_history.resize(tokens.size()); + + const int n_chunk_max = tokens.size() / n_ctx; + + const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_batch = params.n_batch; + + int count = 0; + double nll = 0.0; + double nll2 = 0.0; + + fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch); + + std::vector workers(std::thread::hardware_concurrency() - 1); + + for (int i = 0; i < n_chunk; ++i) { + const int start = i * n_ctx; + const int end = start + n_ctx; + + const int num_batches = (n_ctx + n_batch - 1) / n_batch; + + std::vector logits; + + const auto t_start = std::chrono::high_resolution_clock::now(); + + // clear the KV cache + llama_kv_cache_clear(ctx); + + for (int j = 0; j < num_batches; ++j) { + const int batch_start = start + j * n_batch; + const int batch_size = std::min(end - batch_start, n_batch); + + // save original token and restore it after eval + const auto token_org = tokens[batch_start]; + + // add BOS token for the first batch of each chunk + if (add_bos && j == 0) { + tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); + } + + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return false; + } + + // restore the original token in case it was set to BOS + tokens[batch_start] = token_org; + + const auto * batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + } + + const auto t_end = std::chrono::high_resolution_clock::now(); + + if (i == 0) { + const float t_total = std::chrono::duration(t_end - t_start).count(); + fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + int total_seconds = (int)(t_total * n_chunk); + if (total_seconds >= 60*60) { + fprintf(stderr, "%d hours ", total_seconds / (60*60)); + total_seconds = total_seconds % (60*60); + } + fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); + } + + const int first = n_ctx/2; + process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); + count += n_ctx - first - 1; + + printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + fflush(stdout); + } + printf("\n"); + + nll2 /= count; + nll /= count; + const double ppl = exp(nll); + nll2 -= nll * nll; + if (nll2 > 0) { + nll2 = sqrt(nll2/(count-1)); + printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl); + } else { + printf("Unexpected negative standard deviation of log(prob)\n"); + } + + return true; +} + +int main(int argc, char ** argv) { + + StatParams sparams; + std::vector args; + args.push_back(argv[0]); + int iarg = 1; + for (; iarg < argc-1; ++iarg) { + std::string arg{argv[iarg]}; + if (arg == "-o" || arg == "--output-file") { + sparams.ofile = argv[++iarg]; + } + else if (arg == "-ofreq" || arg == "--output-frequency") { + sparams.n_output_frequency = std::stoi(argv[++iarg]); + } + else if (arg == "-ow" || arg == "--output-weight") { + sparams.collect_output_weight = std::stoi(argv[++iarg]); + } + else if (arg == "--verbosity") { + sparams.verbosity = std::stoi(argv[++iarg]); + } else { + args.push_back(argv[iarg]); + } + } + if (iarg < argc) { + args.push_back(argv[iarg]); + } + + gpt_params params; + params.n_batch = 512; + if (!gpt_params_parse(args.size(), args.data(), params)) { + return 1; + } + + g_collector.set_parameters(std::move(sparams)); + + ggml_set_imatrix_collection(ik_collect_imatrix); + + params.logits_all = true; + params.n_batch = std::min(params.n_batch, params.n_ctx); + + print_build_info(); + + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + + fprintf(stderr, "%s: seed = %u\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.random_prompt) { + params.prompt = gpt_random_prompt(rng); + } + + llama_backend_init(params.numa); + + llama_model * model; + llama_context * ctx; + + // load the model and apply lora adapter, if any + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == NULL) { + fprintf(stderr, "%s: error: unable to load model\n", __func__); + return 1; + } + + const int n_ctx_train = llama_n_ctx_train(model); + if (params.n_ctx > n_ctx_train) { + fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", + __func__, n_ctx_train, params.n_ctx); + } + + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "%s\n", get_system_info(params).c_str()); + } + + bool OK = compute_imatrix(ctx, params); + if (!OK) { + return 1; + } + + g_collector.save_imatrix(); + + llama_print_timings(ctx); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} diff --git a/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj b/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj index a8848a49fce6d..3950b9e9df843 100644 --- a/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj +++ b/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj @@ -8,6 +8,7 @@ /* Begin PBXBuildFile section */ 549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 549479CA2AC9E16000E0F78B /* Metal.framework */; }; + 79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */; }; 7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */; }; 8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */; }; 8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1C83782AC328BD0096AF73 /* ContentView.swift */; }; @@ -22,6 +23,7 @@ /* Begin PBXFileReference section */ 549479CA2AC9E16000E0F78B /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; }; + 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InputButton.swift; sourceTree = ""; }; 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DownloadButton.swift; sourceTree = ""; }; 8A1C83732AC328BD0096AF73 /* llama.swiftui.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = llama.swiftui.app; sourceTree = BUILT_PRODUCTS_DIR; }; 8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = llama_swiftuiApp.swift; sourceTree = ""; }; @@ -119,6 +121,7 @@ 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */, 8A1C83782AC328BD0096AF73 /* ContentView.swift */, F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */, + 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */, ); path = UI; sourceTree = ""; @@ -213,6 +216,7 @@ 8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */, 8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */, 7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */, + 79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -345,7 +349,7 @@ CLANG_ENABLE_MODULES = YES; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_TEAM = STLSG3FG8Q; + DEVELOPMENT_TEAM = K5UQJPP73A; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; @@ -377,7 +381,7 @@ CLANG_ENABLE_MODULES = YES; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_TEAM = STLSG3FG8Q; + DEVELOPMENT_TEAM = K5UQJPP73A; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; diff --git a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift index 17cb5b9dde942..5bde1891727ce 100644 --- a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift +++ b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift @@ -1,9 +1,19 @@ import Foundation +struct Model: Identifiable { + var id = UUID() + var name: String + var url: String + var filename: String + var status: String? +} + @MainActor class LlamaState: ObservableObject { @Published var messageLog = "" @Published var cacheCleared = false + @Published var downloadedModels: [Model] = [] + @Published var undownloadedModels: [Model] = [] let NS_PER_S = 1_000_000_000.0 private var llamaContext: LlamaContext? @@ -13,23 +23,102 @@ class LlamaState: ObservableObject { } init() { + loadModelsFromDisk() + loadDefaultModels() + } + + private func loadModelsFromDisk() { + do { + let documentsURL = getDocumentsDirectory() + let modelURLs = try FileManager.default.contentsOfDirectory(at: documentsURL, includingPropertiesForKeys: nil, options: [.skipsHiddenFiles, .skipsSubdirectoryDescendants]) + for modelURL in modelURLs { + let modelName = modelURL.deletingPathExtension().lastPathComponent + downloadedModels.append(Model(name: modelName, url: "", filename: modelURL.lastPathComponent, status: "downloaded")) + } + } catch { + print("Error loading models from disk: \(error)") + } + } + + private func loadDefaultModels() { do { try loadModel(modelUrl: defaultModelUrl) } catch { messageLog += "Error!\n" } + + for model in defaultModels { + let fileURL = getDocumentsDirectory().appendingPathComponent(model.filename) + if FileManager.default.fileExists(atPath: fileURL.path) { + + } else { + var undownloadedModel = model + undownloadedModel.status = "download" + undownloadedModels.append(undownloadedModel) + } + } } + func getDocumentsDirectory() -> URL { + let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask) + return paths[0] + } + private let defaultModels: [Model] = [ + Model(name: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf", status: "download"), + Model( + name: "TinyLlama-1.1B Chat (Q8_0, 1.1 GiB)", + url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true", + filename: "tinyllama-1.1b-chat-v1.0.Q8_0.gguf", status: "download" + ), + + Model( + name: "TinyLlama-1.1B (F16, 2.2 GiB)", + url: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true", + filename: "tinyllama-1.1b-f16.gguf", status: "download" + ), + + Model( + name: "Phi-2.7B (Q4_0, 1.6 GiB)", + url: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true", + filename: "phi-2-q4_0.gguf", status: "download" + ), + + Model( + name: "Phi-2.7B (Q8_0, 2.8 GiB)", + url: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true", + filename: "phi-2-q8_0.gguf", status: "download" + ), + + Model( + name: "Mistral-7B-v0.1 (Q4_0, 3.8 GiB)", + url: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true", + filename: "mistral-7b-v0.1.Q4_0.gguf", status: "download" + ), + Model( + name: "OpenHermes-2.5-Mistral-7B (Q3_K_M, 3.52 GiB)", + url: "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q3_K_M.gguf?download=true", + filename: "openhermes-2.5-mistral-7b.Q3_K_M.gguf", status: "download" + ) + ] func loadModel(modelUrl: URL?) throws { if let modelUrl { messageLog += "Loading model...\n" llamaContext = try LlamaContext.create_context(path: modelUrl.path()) messageLog += "Loaded model \(modelUrl.lastPathComponent)\n" + + // Assuming that the model is successfully loaded, update the downloaded models + updateDownloadedModels(modelName: modelUrl.lastPathComponent, status: "downloaded") } else { messageLog += "Load a model from the list below\n" } } + + private func updateDownloadedModels(modelName: String, status: String) { + undownloadedModels.removeAll { $0.name == modelName } + } + + func complete(text: String) async { guard let llamaContext else { return diff --git a/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift b/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift index 7c81ea256ffd7..30c2dc4310210 100644 --- a/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift +++ b/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift @@ -2,115 +2,57 @@ import SwiftUI struct ContentView: View { @StateObject var llamaState = LlamaState() - @State private var multiLineText = "" - - private static func cleanupModelCaches() { - // Delete all models (*.gguf) - let fileManager = FileManager.default - let documentsUrl = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0] - do { - let fileURLs = try fileManager.contentsOfDirectory(at: documentsUrl, includingPropertiesForKeys: nil) - for fileURL in fileURLs { - if fileURL.pathExtension == "gguf" { - try fileManager.removeItem(at: fileURL) - } - } - } catch { - print("Error while enumerating files \(documentsUrl.path): \(error.localizedDescription)") - } - } + @State private var showingHelp = false // To track if Help Sheet should be shown var body: some View { - VStack { - ScrollView(.vertical, showsIndicators: true) { - Text(llamaState.messageLog) - .font(.system(size: 12)) - .frame(maxWidth: .infinity, alignment: .leading) - .padding() - .onTapGesture { - UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil) + NavigationView { + VStack { + ScrollView(.vertical, showsIndicators: true) { + Text(llamaState.messageLog) + .font(.system(size: 12)) + .frame(maxWidth: .infinity, alignment: .leading) + .padding() + .onTapGesture { + UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil) + } } - } - TextEditor(text: $multiLineText) - .frame(height: 80) - .padding() - .border(Color.gray, width: 0.5) + TextEditor(text: $multiLineText) + .frame(height: 80) + .padding() + .border(Color.gray, width: 0.5) - HStack { - Button("Send") { - sendText() - } + HStack { + Button("Send") { + sendText() + } - Button("Bench") { - bench() - } + Button("Bench") { + bench() + } - Button("Clear") { - clear() - } + Button("Clear") { + clear() + } - Button("Copy") { - UIPasteboard.general.string = llamaState.messageLog + Button("Copy") { + UIPasteboard.general.string = llamaState.messageLog + } } - }.buttonStyle(.bordered) - - VStack(alignment: .leading) { - DownloadButton( - llamaState: llamaState, - modelName: "TinyLlama-1.1B (Q4_0, 0.6 GiB)", - modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true", - filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf" - ) - - DownloadButton( - llamaState: llamaState, - modelName: "TinyLlama-1.1B (Q8_0, 1.1 GiB)", - modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true", - filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf" - ) - - DownloadButton( - llamaState: llamaState, - modelName: "TinyLlama-1.1B (F16, 2.2 GiB)", - modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true", - filename: "tinyllama-1.1b-f16.gguf" - ) - - DownloadButton( - llamaState: llamaState, - modelName: "Phi-2.7B (Q4_0, 1.6 GiB)", - modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true", - filename: "phi-2-q4_0.gguf" - ) - - DownloadButton( - llamaState: llamaState, - modelName: "Phi-2.7B (Q8_0, 2.8 GiB)", - modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true", - filename: "phi-2-q8_0.gguf" - ) - - DownloadButton( - llamaState: llamaState, - modelName: "Mistral-7B-v0.1 (Q4_0, 3.8 GiB)", - modelUrl: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true", - filename: "mistral-7b-v0.1.Q4_0.gguf" - ) - - Button("Clear downloaded models") { - ContentView.cleanupModelCaches() - llamaState.cacheCleared = true + .buttonStyle(.bordered) + .padding() + + NavigationLink(destination: DrawerView(llamaState: llamaState)) { + Text("View Models") } + .padding() - LoadCustomButton(llamaState: llamaState) } - .padding(.top, 4) - .font(.system(size: 12)) - .frame(maxWidth: .infinity, alignment: .leading) + .padding() + .navigationBarTitle("Model Settings", displayMode: .inline) + } - .padding() } func sendText() { @@ -131,8 +73,73 @@ struct ContentView: View { await llamaState.clear() } } + struct DrawerView: View { + + @ObservedObject var llamaState: LlamaState + @State private var showingHelp = false + func delete(at offsets: IndexSet) { + offsets.forEach { offset in + let model = llamaState.downloadedModels[offset] + let fileURL = getDocumentsDirectory().appendingPathComponent(model.filename) + do { + try FileManager.default.removeItem(at: fileURL) + } catch { + print("Error deleting file: \(error)") + } + } + + // Remove models from downloadedModels array + llamaState.downloadedModels.remove(atOffsets: offsets) + } + + func getDocumentsDirectory() -> URL { + let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask) + return paths[0] + } + var body: some View { + List { + Section(header: Text("Download Models From Hugging Face")) { + HStack { + InputButton(llamaState: llamaState) + } + } + Section(header: Text("Downloaded Models")) { + ForEach(llamaState.downloadedModels) { model in + DownloadButton(llamaState: llamaState, modelName: model.name, modelUrl: model.url, filename: model.filename) + } + .onDelete(perform: delete) + } + Section(header: Text("Default Models")) { + ForEach(llamaState.undownloadedModels) { model in + DownloadButton(llamaState: llamaState, modelName: model.name, modelUrl: model.url, filename: model.filename) + } + } + + } + .listStyle(GroupedListStyle()) + .navigationBarTitle("Model Settings", displayMode: .inline).toolbar { + ToolbarItem(placement: .navigationBarTrailing) { + Button("Help") { + showingHelp = true + } + } + }.sheet(isPresented: $showingHelp) { // Sheet for help modal + VStack(alignment: .leading) { + VStack(alignment: .leading) { + Text("1. Make sure the model is in GGUF Format") + .padding() + Text("2. Copy the download link of the quantized model") + .padding() + } + Spacer() + } + } + } + } } -//#Preview { -// ContentView() -//} +struct ContentView_Previews: PreviewProvider { + static var previews: some View { + ContentView() + } +} diff --git a/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift b/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift index c9f322ca14e72..4584d6eaa3d32 100644 --- a/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift +++ b/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift @@ -53,6 +53,8 @@ struct DownloadButton: View { llamaState.cacheCleared = false + let model = Model(name: modelName, url: modelUrl, filename: filename, status: "downloaded") + llamaState.downloadedModels.append(model) status = "downloaded" } } catch let err { diff --git a/examples/llama.swiftui/llama.swiftui/UI/InputButton.swift b/examples/llama.swiftui/llama.swiftui/UI/InputButton.swift new file mode 100644 index 0000000000000..c5ffbad4ec331 --- /dev/null +++ b/examples/llama.swiftui/llama.swiftui/UI/InputButton.swift @@ -0,0 +1,131 @@ +import SwiftUI + +struct InputButton: View { + @ObservedObject var llamaState: LlamaState + @State private var inputLink: String = "" + @State private var status: String = "download" + @State private var filename: String = "" + + @State private var downloadTask: URLSessionDownloadTask? + @State private var progress = 0.0 + @State private var observation: NSKeyValueObservation? + + private static func extractModelInfo(from link: String) -> (modelName: String, filename: String)? { + guard let url = URL(string: link), + let lastPathComponent = url.lastPathComponent.components(separatedBy: ".").first, + let modelName = lastPathComponent.components(separatedBy: "-").dropLast().joined(separator: "-").removingPercentEncoding, + let filename = lastPathComponent.removingPercentEncoding else { + return nil + } + + return (modelName, filename) + } + + private static func getFileURL(filename: String) -> URL { + FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0].appendingPathComponent(filename) + } + + private func download() { + guard let extractedInfo = InputButton.extractModelInfo(from: inputLink) else { + // Handle invalid link or extraction failure + return + } + + let (modelName, filename) = extractedInfo + self.filename = filename // Set the state variable + + status = "downloading" + print("Downloading model \(modelName) from \(inputLink)") + guard let url = URL(string: inputLink) else { return } + let fileURL = InputButton.getFileURL(filename: filename) + + downloadTask = URLSession.shared.downloadTask(with: url) { temporaryURL, response, error in + if let error = error { + print("Error: \(error.localizedDescription)") + return + } + + guard let response = response as? HTTPURLResponse, (200...299).contains(response.statusCode) else { + print("Server error!") + return + } + + do { + if let temporaryURL = temporaryURL { + try FileManager.default.copyItem(at: temporaryURL, to: fileURL) + print("Writing to \(filename) completed") + + llamaState.cacheCleared = false + + let model = Model(name: modelName, url: self.inputLink, filename: filename, status: "downloaded") + llamaState.downloadedModels.append(model) + status = "downloaded" + } + } catch let err { + print("Error: \(err.localizedDescription)") + } + } + + observation = downloadTask?.progress.observe(\.fractionCompleted) { progress, _ in + self.progress = progress.fractionCompleted + } + + downloadTask?.resume() + } + + var body: some View { + VStack { + HStack { + TextField("Paste Quantized Download Link", text: $inputLink) + .textFieldStyle(RoundedBorderTextFieldStyle()) + + Button(action: { + downloadTask?.cancel() + status = "download" + }) { + Text("Cancel") + } + } + + if status == "download" { + Button(action: download) { + Text("Download Custom Model") + } + } else if status == "downloading" { + Button(action: { + downloadTask?.cancel() + status = "download" + }) { + Text("Downloading \(Int(progress * 100))%") + } + } else if status == "downloaded" { + Button(action: { + let fileURL = InputButton.getFileURL(filename: self.filename) + if !FileManager.default.fileExists(atPath: fileURL.path) { + download() + return + } + do { + try llamaState.loadModel(modelUrl: fileURL) + } catch let err { + print("Error: \(err.localizedDescription)") + } + }) { + Text("Load Custom Model") + } + } else { + Text("Unknown status") + } + } + .onDisappear() { + downloadTask?.cancel() + } + .onChange(of: llamaState.cacheCleared) { newValue in + if newValue { + downloadTask?.cancel() + let fileURL = InputButton.getFileURL(filename: self.filename) + status = FileManager.default.fileExists(atPath: fileURL.path) ? "downloaded" : "download" + } + } + } +} diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index cfb79e78940a7..2ae8853d3d5da 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -126,24 +126,7 @@ static struct ggml_tensor * get_tensor(struct ggml_context * ctx, const std::str } static std::string get_ftype(int ftype) { - switch (ftype) { - case 0: - return "f32"; - case 1: - return "f16"; - case 2: - return "q4_0"; - case 3: - return "q4_1"; - case 6: - return "q5_0"; - case 7: - return "q5_1"; - case 8: - return "q8_0"; - default: - throw std::runtime_error(format("%s: Unrecognized file type: %d\n", __func__, ftype)); - } + return ggml_type_name(static_cast(ftype)); } // @@ -533,6 +516,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { buffer_size += n_tensors * 128 /* CLIP PADDING */; clip_ctx * new_clip = new clip_ctx; + #ifdef GGML_USE_CUBLAS new_clip->backend = ggml_backend_cuda_init(0); printf("%s: CLIP using CUDA backend\n", __func__); @@ -543,6 +527,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { printf("%s: CLIP using Metal backend\n", __func__); #endif + if (!new_clip->backend) { new_clip->backend = ggml_backend_cpu_init(); printf("%s: CLIP using CPU backend\n", __func__); @@ -931,26 +916,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i ggml_type type = GGML_TYPE_Q4_1; - switch (itype) { - case 2: - type = GGML_TYPE_Q4_0; - break; - case 3: - type = GGML_TYPE_Q4_1; - break; - case 6: - type = GGML_TYPE_Q5_0; - break; - case 7: - type = GGML_TYPE_Q5_1; - break; - case 8: - type = GGML_TYPE_Q8_0; - break; - default: - fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); - return false; - }; + assert(itype < GGML_TYPE_COUNT); + type = static_cast(itype); auto * ctx_clip = clip_model_load(fname_inp, 2); @@ -1010,6 +977,10 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i if (quantize) { new_type = type; + if (new_type >= GGML_TYPE_Q2_K && name.find("embd") != std::string::npos) { + new_type = GGML_TYPE_Q8_0; // ggml_get_rows needs non K type + // fprintf(stderr, "%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type)); + } const size_t n_elms = ggml_nelements(cur); float * f32_data; @@ -1054,6 +1025,21 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i case GGML_TYPE_Q8_0: { new_size = ggml_quantize_q8_0(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data()); } break; + case GGML_TYPE_Q2_K: { + new_size = ggml_quantize_q2_K(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q3_K: { + new_size = ggml_quantize_q3_K(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q4_K: { + new_size = ggml_quantize_q4_K(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q5_K: { + new_size = ggml_quantize_q5_K(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q6_K: { + new_size = ggml_quantize_q6_K(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data()); + } break; default: { fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, new_type); return false; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1f5fcff937fc2..a046f9cbedfbd 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -501,7 +501,7 @@ int main(int argc, char ** argv) { while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict if (!embd.empty()) { - // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via + // Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via // --prompt or --file which uses the same value. int max_embd_size = n_ctx - 4; @@ -651,6 +651,10 @@ int main(int argc, char ** argv) { n_past += n_eval; LOG("n_past = %d\n", n_past); + // Display total tokens alongside total time + if (params.n_print > 0 && n_past % params.n_print == 0) { + LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx); + } } if (!embd.empty() && !path_session.empty()) { diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index be0b2fe1eb963..5b1415d21913c 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -19,6 +19,7 @@ static const std::vector QUANT_OPTIONS = { { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", }, { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, + { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" }, { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", }, { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", }, diff --git a/examples/server/README.md b/examples/server/README.md index d85a14f891bc4..fd3034b99c3d2 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -23,7 +23,8 @@ Command line options: - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`. - `--port`: Set the port to listen. Default: `8080`. - `--path`: path from which to serve static files (default examples/server/public) -- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. +- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys. +- `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s. - `--embedding`: Enable embedding extraction, Default: disabled. - `-np N`, `--parallel N`: Set the number of slots for process requests (default: 1) - `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled) @@ -110,6 +111,10 @@ node index.js ``` ## API Endpoints +- **GET** `/health`: Returns the current state of the server: + - `{"status": "loading model"}` if the model is still being loaded. + - `{"status": "error"}` if the model failed to load. + - `{"status": "ok"}` if the model is successfully loaded and the server is ready for further requests mentioned below. - **POST** `/completion`: Given a `prompt`, it returns the predicted completion. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d99dc75cf6f00..74dfccb360744 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #ifndef SERVER_VERBOSE #define SERVER_VERBOSE 1 @@ -39,7 +40,7 @@ using json = nlohmann::json; struct server_params { std::string hostname = "127.0.0.1"; - std::string api_key; + std::vector api_keys; std::string public_path = "examples/server/public"; int32_t port = 8080; int32_t read_timeout = 600; @@ -147,9 +148,15 @@ static std::vector base64_decode(const std::string & encoded_string) // parallel // +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded + SERVER_STATE_ERROR // An error occurred, load_model failed +}; + enum task_type { - COMPLETION_TASK, - CANCEL_TASK + TASK_TYPE_COMPLETION, + TASK_TYPE_CANCEL, }; struct task_server { @@ -1396,11 +1403,11 @@ struct llama_server_context task.data = std::move(data); task.infill_mode = infill; task.embedding_mode = embedding; - task.type = COMPLETION_TASK; + task.type = TASK_TYPE_COMPLETION; task.multitask_id = multitask_id; // when a completion task's prompt array is not a singleton, we split it into multiple requests - if (task.data.at("prompt").size() > 1) + if (task.data.count("prompt") && task.data.at("prompt").size() > 1) { lock.unlock(); // entering new func scope return split_multiprompt_task(task); @@ -1518,7 +1525,7 @@ struct llama_server_context std::unique_lock lock(mutex_tasks); task_server task; task.id = id_gen++; - task.type = CANCEL_TASK; + task.type = TASK_TYPE_CANCEL; task.target_id = task_id; queue_tasks.push_back(task); condition_tasks.notify_one(); @@ -1554,7 +1561,7 @@ struct llama_server_context queue_tasks.erase(queue_tasks.begin()); switch (task.type) { - case COMPLETION_TASK: { + case TASK_TYPE_COMPLETION: { llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); if (slot == nullptr) { @@ -1571,9 +1578,9 @@ struct llama_server_context slot->reset(); - slot->infill = task.infill_mode; - slot->embedding = task.embedding_mode; - slot->task_id = task.id; + slot->infill = task.infill_mode; + slot->embedding = task.embedding_mode; + slot->task_id = task.id; slot->multitask_id = task.multitask_id; if (!launch_slot_with_data(slot, task.data)) @@ -1583,7 +1590,7 @@ struct llama_server_context break; } } break; - case CANCEL_TASK: { // release slot linked with the task id + case TASK_TYPE_CANCEL: { // release slot linked with the task id for (auto & slot : slots) { if (slot.task_id == task.target_id) @@ -1725,7 +1732,8 @@ struct llama_server_context const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get().empty()) || !slot.images.empty(); // empty prompt passed -> release the slot and send empty response - if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt) + // note: infill mode allows empty prompt + if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill) { slot.release(); slot.print_timings(); @@ -2015,6 +2023,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); + printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); @@ -2075,7 +2084,28 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - sparams.api_key = argv[i]; + sparams.api_keys.push_back(argv[i]); + } + else if (arg == "--api-key-file") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + std::ifstream key_file(argv[i]); + if (!key_file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::string key; + while (std::getline(key_file, key)) { + if (key.size() > 0) { + sparams.api_keys.push_back(key); + } + } + key_file.close(); } else if (arg == "--timeout" || arg == "-to") { @@ -2454,7 +2484,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } } - static std::string random_string() { static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); @@ -2510,7 +2539,7 @@ json oaicompat_completion_params_parse( // // https://platform.openai.com/docs/api-reference/chat/create llama_sampling_params default_sparams; - llama_params["model"] = json_value(body, "model", std::string("uknown")); + llama_params["model"] = json_value(body, "model", std::string("unknown")); llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); llama_params["temperature"] = json_value(body, "temperature", 0.0); @@ -2582,8 +2611,8 @@ static json format_final_response_oaicompat(const json &request, const task_resu {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, {"usage", json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, {"id", gen_chatcmplid()}}; if (server_verbose) { @@ -2791,20 +2820,131 @@ int main(int argc, char **argv) {"system_info", llama_print_system_info()}, }); - // load the model - if (!llama.load_model(params)) + httplib::Server svr; + + std::atomic state{SERVER_STATE_LOADING_MODEL}; + + svr.set_default_headers({{"Server", "llama.cpp"}}); + + // CORS preflight + svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + }); + + svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) { + server_state current_state = state.load(); + switch(current_state) { + case SERVER_STATE_READY: + res.set_content(R"({"status": "ok"})", "application/json"); + res.status = 200; // HTTP OK + break; + case SERVER_STATE_LOADING_MODEL: + res.set_content(R"({"status": "loading model"})", "application/json"); + res.status = 503; // HTTP Service Unavailable + break; + case SERVER_STATE_ERROR: + res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); + res.status = 500; // HTTP Internal Server Error + break; + } + }); + + svr.set_logger(log_server_request); + + svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) + { + const char fmt[] = "500 Internal Server Error\n%s"; + char buf[BUFSIZ]; + try + { + std::rethrow_exception(std::move(ep)); + } + catch (std::exception &e) + { + snprintf(buf, sizeof(buf), fmt, e.what()); + } + catch (...) + { + snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); + } + res.set_content(buf, "text/plain; charset=utf-8"); + res.status = 500; + }); + + svr.set_error_handler([](const httplib::Request &, httplib::Response &res) + { + if (res.status == 401) + { + res.set_content("Unauthorized", "text/plain; charset=utf-8"); + } + if (res.status == 400) + { + res.set_content("Invalid request", "text/plain; charset=utf-8"); + } + else if (res.status == 404) + { + res.set_content("File Not Found", "text/plain; charset=utf-8"); + res.status = 404; + } + }); + + // set timeouts and change hostname and port + svr.set_read_timeout (sparams.read_timeout); + svr.set_write_timeout(sparams.write_timeout); + + if (!svr.bind_to_port(sparams.hostname, sparams.port)) { + fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); return 1; } - llama.initialize(); + // Set the base directory for serving static files + svr.set_base_dir(sparams.public_path); - httplib::Server svr; + // to make it ctrl+clickable: + LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); + + std::unordered_map log_data; + log_data["hostname"] = sparams.hostname; + log_data["port"] = std::to_string(sparams.port); + + if (sparams.api_keys.size() == 1) { + log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); + } else if (sparams.api_keys.size() > 1) { + log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; + } + + LOG_INFO("HTTP server listening", log_data); + // run the HTTP server in a thread - see comment below + std::thread t([&]() + { + if (!svr.listen_after_bind()) + { + state.store(SERVER_STATE_ERROR); + return 1; + } + + return 0; + }); + + // load the model + if (!llama.load_model(params)) + { + state.store(SERVER_STATE_ERROR); + return 1; + } else { + llama.initialize(); + state.store(SERVER_STATE_READY); + LOG_INFO("model loaded", {}); + } // Middleware for API key validation auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { // If API key is not set, skip validation - if (sparams.api_key.empty()) { + if (sparams.api_keys.empty()) { return true; } @@ -2813,7 +2953,7 @@ int main(int argc, char **argv) std::string prefix = "Bearer "; if (auth_header.substr(0, prefix.size()) == prefix) { std::string received_api_key = auth_header.substr(prefix.size()); - if (received_api_key == sparams.api_key) { + if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) { return true; // API key is valid } } @@ -2827,10 +2967,6 @@ int main(int argc, char **argv) return false; }; - svr.set_default_headers({{"Server", "llama.cpp"}, - {"Access-Control-Allow-Origin", "*"}, - {"Access-Control-Allow-Headers", "content-type"}}); - // this is only called if no index.html is found in the public --path svr.Get("/", [](const httplib::Request &, httplib::Response &res) { @@ -2859,9 +2995,9 @@ int main(int argc, char **argv) return false; }); - svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res) + svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", "*"); + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "user_name", llama.name_user.c_str() }, { "assistant_name", llama.name_assistant.c_str() } @@ -2871,6 +3007,7 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -2938,10 +3075,9 @@ int main(int argc, char **argv) } }); - - - svr.Get("/v1/models", [¶ms](const httplib::Request&, httplib::Response& res) + svr.Get("/v1/models", [¶ms](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::time_t t = std::time(0); json models = { @@ -2959,9 +3095,11 @@ int main(int argc, char **argv) res.set_content(models.dump(), "application/json; charset=utf-8"); }); + // TODO: add mount point without "/v1" prefix -- how? svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -3035,6 +3173,7 @@ int main(int argc, char **argv) svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -3107,6 +3246,7 @@ int main(int argc, char **argv) svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; if (body.count("content") != 0) @@ -3119,6 +3259,7 @@ int main(int argc, char **argv) svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; if (body.count("tokens") != 0) @@ -3133,6 +3274,7 @@ int main(int argc, char **argv) svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); json prompt; if (body.count("content") != 0) @@ -3158,81 +3300,6 @@ int main(int argc, char **argv) return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); }); - svr.set_logger(log_server_request); - - svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) - { - const char fmt[] = "500 Internal Server Error\n%s"; - char buf[BUFSIZ]; - try - { - std::rethrow_exception(std::move(ep)); - } - catch (std::exception &e) - { - snprintf(buf, sizeof(buf), fmt, e.what()); - } - catch (...) - { - snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); - } - res.set_content(buf, "text/plain; charset=utf-8"); - res.status = 500; - }); - - svr.set_error_handler([](const httplib::Request &, httplib::Response &res) - { - if (res.status == 401) - { - res.set_content("Unauthorized", "text/plain; charset=utf-8"); - } - if (res.status == 400) - { - res.set_content("Invalid request", "text/plain; charset=utf-8"); - } - else if (res.status == 404) - { - res.set_content("File Not Found", "text/plain; charset=utf-8"); - res.status = 404; - } - }); - - // set timeouts and change hostname and port - svr.set_read_timeout (sparams.read_timeout); - svr.set_write_timeout(sparams.write_timeout); - - if (!svr.bind_to_port(sparams.hostname, sparams.port)) - { - fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); - return 1; - } - - // Set the base directory for serving static files - svr.set_base_dir(sparams.public_path); - - // to make it ctrl+clickable: - LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); - - std::unordered_map log_data; - log_data["hostname"] = sparams.hostname; - log_data["port"] = std::to_string(sparams.port); - - if (!sparams.api_key.empty()) { - log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); - } - - LOG_INFO("HTTP server listening", log_data); - // run the HTTP server in a thread - see comment below - std::thread t([&]() - { - if (!svr.listen_after_bind()) - { - return 1; - } - - return 0; - }); - // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // "Bus error: 10" - this is on macOS, it does not crash on Linux //std::thread t2([&]() diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d3d519454f715..a0410b0fc2128 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -116,7 +116,7 @@ #include "ggml.h" #include "ggml-backend-impl.h" -#define CUDART_CI 11070 // CUDA 11.7, version used for the Github CI +#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products @@ -488,6 +488,15 @@ typedef struct { } block_iq2_xxs; static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); +#define QR2_XS 8 +#define QI2_XS (QK_K / (4*QR2_XS)) +typedef struct { + half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); + #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -620,7 +629,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { } static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); @@ -629,7 +638,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #else (void) x; bad_arch(); -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } static __device__ __forceinline__ float op_repeat(const float a, const float b) { @@ -1331,7 +1340,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t #endif } -static const __device__ uint64_t kgrid_iq2xxs[256] = { +static const __device__ uint64_t iq2xxs_grid[256] = { 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, @@ -1398,6 +1407,137 @@ static const __device__ uint64_t kgrid_iq2xxs[256] = { 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, }; +static const __device__ uint64_t iq2xs_grid[512] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +}; + static const __device__ uint8_t ksigns_iq2xs[128] = { 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, @@ -1442,7 +1582,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * aux8 = (const uint8_t *)q2; - const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]); + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]); const uint32_t aux32 = q2[2] | (q2[3] << 16); const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f; const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; @@ -1453,6 +1593,28 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds } +template +static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + const int tid = threadIdx.x; +#if QK_K == 256 + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); + const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); +#else + assert(false); +#endif + +} + static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -3999,7 +4161,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( uint32_t aux32 = q2[2] | (q2[3] << 16); int sumi = 0; for (int l = 0; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]); + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); const uint8_t signs = ksigns_iq2xs[aux32 & 127]; for (int j = 0; j < 8; ++j) { sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); @@ -4015,8 +4177,8 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( const int il = iqs%2; const uint16_t * q2 = bq2->qs + 4*ib32; const uint8_t * aux8 = (const uint8_t *)q2; - const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]); - const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]); + const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]); const uint32_t aux32 = q2[2] | (q2[3] << 16); const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f; const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127]; @@ -4035,6 +4197,42 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( #endif } +static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if QK_K == 256 + const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq; + + const int ib32 = iqs; + const uint16_t * q2 = bq2->qs + 4*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + const uint8_t ls1 = bq2->scales[ib32] & 0xf; + const uint8_t ls2 = bq2->scales[ib32] >> 4; + int sumi1 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + int sumi2 = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + const float d = (float)bq2->d * (float)bq8_1[ib32].ds.x * 0.25f; + return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); +#else + assert(false); + return 0.f; +#endif +} + template static __device__ __forceinline__ void mul_mat_q( @@ -5418,7 +5616,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int template static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; @@ -5543,7 +5741,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds #else (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale; bad_arch(); -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } template @@ -6038,6 +6236,12 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, dequantize_block_iq2_xxs<<>>(vx, y); } +template +static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_xs<<>>(vx, y); +} + template static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; @@ -6068,6 +6272,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_q6_K_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_cuda; case GGML_TYPE_F32: return convert_unary_cuda; default: @@ -6099,6 +6305,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q6_K_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_cuda; case GGML_TYPE_F16: return convert_unary_cuda; default: @@ -6302,6 +6510,15 @@ static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, floa <<>>(vx, vy, dst, ncols, nrows); } +static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + static void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -7865,6 +8082,7 @@ static int64_t get_row_rounding(ggml_type type) { case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: return max_compute_capability >= CC_RDNA2 ? 128 : 64; default: GGML_ASSERT(false); @@ -7886,6 +8104,7 @@ static int64_t get_row_rounding(ggml_type type) { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: return max_compute_capability >= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q6_K: return 64; @@ -7939,6 +8158,9 @@ static void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ2_XXS: mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; default: GGML_ASSERT(false); break; @@ -8346,7 +8568,7 @@ static void ggml_cuda_op_soft_max( float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float)); -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_CI +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX #ifdef GGML_CUDA_F16 const bool use_f16_soft_max = true; #else @@ -8354,7 +8576,7 @@ static void ggml_cuda_op_soft_max( #endif // GGML_CUDA_F16 #else const bool use_f16_soft_max = false; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_CI +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX if (use_f16_soft_max) { soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); @@ -10187,8 +10409,8 @@ static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, gg ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaDeviceSynchronize()); - CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); } static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { diff --git a/ggml-metal.m b/ggml-metal.m index 0ce3559d8599f..19c27883f66bd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -89,6 +89,7 @@ GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(get_rows_i32); GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs); + GGML_METAL_DECL_KERNEL(get_rows_iq2_xs); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(group_norm); GGML_METAL_DECL_KERNEL(norm); @@ -108,6 +109,7 @@ GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32); + GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32); GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32); //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16); GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32); @@ -124,6 +126,7 @@ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32); GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); @@ -137,6 +140,7 @@ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32); + GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32); @@ -150,6 +154,7 @@ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32); GGML_METAL_DECL_KERNEL(rope_f32); GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); @@ -385,6 +390,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(get_rows_i32); GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs); + GGML_METAL_ADD_KERNEL(get_rows_iq2_xs); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(group_norm); GGML_METAL_ADD_KERNEL(norm); @@ -404,6 +410,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32); + GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32); GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32); //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16); GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32); @@ -420,6 +427,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32); if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); @@ -434,6 +442,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32); + GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32); @@ -447,6 +456,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32); } GGML_METAL_ADD_KERNEL(rope_f32); GGML_METAL_ADD_KERNEL(rope_f16); @@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_q6_K); GGML_METAL_DEL_KERNEL(get_rows_i32); GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs); + GGML_METAL_DEL_KERNEL(get_rows_iq2_xs); GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(group_norm); GGML_METAL_DEL_KERNEL(norm); @@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32); + GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32); GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32); //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16); GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32); @@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32); if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); @@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32); + GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32); @@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32); } GGML_METAL_DEL_KERNEL(rope_f32); GGML_METAL_DEL_KERNEL(rope_f16); @@ -1067,6 +1082,10 @@ bool ggml_metal_graph_compute( GGML_ASSERT(!"unsupported op"); } +#ifndef GGML_METAL_NDEBUG + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; +#endif + const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; const int64_t ne02 = src0 ? src0->ne[2] : 0; @@ -1557,6 +1576,7 @@ bool ggml_metal_graph_compute( case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break; + case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1675,6 +1695,12 @@ bool ggml_metal_graph_compute( nth1 = 16; [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32]; } break; + case GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32]; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -1708,12 +1734,12 @@ bool ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || - //src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ2_XXS) { - [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0]; + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q4_K) { @@ -1806,6 +1832,7 @@ bool ggml_metal_graph_compute( case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break; case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break; + case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1927,6 +1954,12 @@ bool ggml_metal_graph_compute( nth1 = 16; [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32]; } break; + case GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32]; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); @@ -1976,12 +2009,12 @@ bool ggml_metal_graph_compute( if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || - //src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_IQ2_XXS) { - [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0]; + else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) { + const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src2t == GGML_TYPE_Q4_K) { @@ -2022,6 +2055,7 @@ bool ggml_metal_graph_compute( case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break; case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break; + case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break; default: GGML_ASSERT(false && "not implemented"); } @@ -2423,6 +2457,10 @@ bool ggml_metal_graph_compute( GGML_ASSERT(false); } } + +#ifndef GGML_METAL_NDEBUG + [encoder popDebugGroup]; +#endif } if (encoder != nil) { diff --git a/ggml-metal.metal b/ggml-metal.metal index 229efb8b69db1..029578dc54dbd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2452,6 +2452,13 @@ typedef struct { } block_iq2_xxs; // 66 bytes / block for QK_K = 256, so 2.0625 bpw +typedef struct { + half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +// 74 bytes / block for QK_K = 256, so 2.3125 bpw + //====================================== dot products ========================= void kernel_mul_mv_q2_K_f32_impl( @@ -3476,7 +3483,7 @@ kernel void kernel_mul_mv_q6_K_f32( // ======================= "True" 2-bit -constexpr constant static uint64_t kgrid_iq2xxs[256] = { +constexpr constant static uint64_t iq2xxs_grid[256] = { 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, @@ -3543,6 +3550,137 @@ constexpr constant static uint64_t kgrid_iq2xxs[256] = { 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, }; +constexpr constant static uint64_t iq2xs_grid[512] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +}; + constexpr constant static uint8_t ksigns_iq2xs[128] = { 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, @@ -3600,7 +3738,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i]; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; @@ -3689,6 +3827,149 @@ kernel void kernel_mul_mv_iq2_xxs_f32( kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } +void kernel_mul_mv_iq2_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + +#if QK_K == 256 + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (int l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += nb*sizeof(block_iq2_xs)/2; + q2 += nb*sizeof(block_iq2_xs)/2; + sc += nb*sizeof(block_iq2_xs); + } + + y4 += 32 * 32; + } +#else + // TODO +#endif + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template @@ -3973,18 +4254,39 @@ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x const uint32_t aux32_s = q2[2] | (q2[3] << 16); thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; - constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]); + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; for (int i = 0; i < 8; ++i) { reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); } - grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]); + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; for (int i = 0; i < 8; ++i) { reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); } } +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + template kernel void kernel_get_rows( device const void * src0, @@ -4525,6 +4827,7 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; // // matrix-matrix multiplication @@ -4562,6 +4865,7 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -4611,6 +4915,7 @@ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mu template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; // // matrix-vector multiplication @@ -5448,3 +5753,68 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( tiisg, sgitg); } + +[[host_name("kernel_mul_mv_id_iq2_xs_f32")]] +kernel void kernel_mul_mv_id_iq2_xs_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_iq2_xs_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + shared_values, + tgpig, + tiisg, + sgitg); +} diff --git a/ggml-quants.c b/ggml-quants.c index 5965729f58dca..29d259737f52f 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -2344,15 +2344,7 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * // ====================== "True" 2-bit (de)-quantization -void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) { - (void)x; - (void)y; - (void)k; - assert(k % QK_K == 0); - //fprintf(stderr, "=========================== %s: not implemented\n", __func__); -} - -static const uint64_t iq2xxs_grid[256] = { +static const uint64_t iq2xxs_grid[256] = { 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, @@ -2419,6 +2411,137 @@ static const uint64_t iq2xxs_grid[256] = { 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, }; +static const uint64_t iq2xs_grid[512] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +}; + static const uint8_t ksigns_iq2xs[128] = { 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, @@ -2429,8 +2552,17 @@ static const uint8_t ksigns_iq2xs[128] = { 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, }; + static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; +void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) { + (void)x; + (void)y; + (void)k; + assert(k % QK_K == 0); + //fprintf(stderr, "=========================== %s: not implemented\n", __func__); +} + void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2474,6 +2606,58 @@ size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_ return (n/QK_K*sizeof(block_iq2_xxs)); } +// ====================== 2.3125 bpw (de)-quantization + +void quantize_row_iq2_xs_reference(const float * restrict x, block_iq2_xs * restrict y, int k) { + (void)x; + (void)y; + (void)k; + assert(k % QK_K == 0); + //fprintf(stderr, "=========================== %s: not implemented\n", __func__); +} + +void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + float db[2]; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511)); + const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9]; + for (int j = 0; j < 8; ++j) { + y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + } + } +} + +void quantize_row_iq2_xs(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_iq2_xs * restrict y = vy; + quantize_row_iq2_xs_reference(x, y, k); +} + +size_t ggml_quantize_iq2_xs(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_iq2_xs * restrict y = (block_iq2_xs *)dst + j/QK_K; + quantize_row_iq2_xs_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_iq2_xs)); +} + //===================================== Q8_K ============================================== void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { @@ -7359,3 +7543,161 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res *s = 0.125f * sumf; #endif } + +void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_iq2_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + int8x16x4_t q2u; + int8x16x4_t q2s; + int8x16x4_t q8b; + + int32x4x4_t scales32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8x8_t scales8 = vld1_u8(x[i].scales); + const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); + const uint8x8_t scales_h = vshr_n_u8(scales8, 4); + uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); + const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); + const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); + scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); + scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); + scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); + scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); + int32x4_t sumi = vdupq_n_s32(0); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + q8b = vld1q_s8_x4(q8); q8 += 64; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); + const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); + sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); + q2 += 8; + } + sumf += d*vaddvq_s32(sumi); + } + *s = 0.125f * sumf; + +#elif defined(__AVX2__) + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + const __m128i m511 = _mm_set1_epi16(511); + const __m128i m127 = _mm_set1_epi16(127); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m128i aux_gindex, aux_sindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + const uint16_t * sindex = (const uint16_t *)&aux_sindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8; + aux_gindex = _mm_and_si128(q2_data, m511); + aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127); + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + + const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} diff --git a/ggml-quants.h b/ggml-quants.h index 8dd911d4182fa..df5e7ae807f5f 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -174,6 +174,14 @@ typedef struct { } block_iq2_xxs; static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); +// 2.3125 bpw quants +typedef struct { + ggml_fp16_t d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); + // Quantization void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k); @@ -189,6 +197,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k); +void quantize_row_iq2_xs_reference (const float * restrict x, block_iq2_xs * restrict y, int k); void quantize_row_q4_0(const float * restrict x, void * restrict y, int k); void quantize_row_q4_1(const float * restrict x, void * restrict y, int k); @@ -204,6 +213,7 @@ void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k); +void quantize_row_iq2_xs (const float * restrict x, void * restrict y, int k); // Dequantization void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k); @@ -220,6 +230,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k); +void dequantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k); // Dot product void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); @@ -234,3 +245,4 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy); diff --git a/ggml.c b/ggml.c index 03ec5349b907f..be2b6725481b2 100644 --- a/ggml.c +++ b/ggml.c @@ -132,7 +132,7 @@ void ggml_print_backtrace(void) { "-ex", "bt -frame-info source-and-location", "-ex", "detach", "-ex", "quit", - NULL); + (char *) NULL); } else { waitpid(pid, NULL, 0); } @@ -394,6 +394,12 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); +ggml_collect_imatrix_t g_imatrix_collect = NULL; + +void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect) { + g_imatrix_collect = imatrix_collect; +} + static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { .type_name = "i8", @@ -584,6 +590,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, }, + [GGML_TYPE_IQ2_XS] = { + .type_name = "iq2_xs", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_xs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_xs, + .from_float = quantize_row_iq2_xs, + .from_float_reference = (ggml_from_float_t) quantize_row_iq2_xs_reference, + .vec_dot = ggml_vec_dot_iq2_xs_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, @@ -2123,6 +2140,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; + case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -4311,13 +4329,13 @@ struct ggml_tensor * ggml_set_2d_inplace( static struct ggml_tensor * ggml_cpy_impl( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * b) { GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); bool is_node = false; - if (!inplace && (a->grad || b->grad)) { + if (a->grad || b->grad) { + // inplace is false and either one have a grad is_node = true; } @@ -4341,29 +4359,21 @@ struct ggml_tensor * ggml_cpy( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - return ggml_cpy_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_cpy_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_cpy_impl(ctx, a, b, true); + return ggml_cpy_impl(ctx, a, b); } // ggml_cont static struct ggml_tensor * ggml_cont_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { + struct ggml_tensor * a) { bool is_node = false; - if (!inplace && a->grad) { + if (a->grad) { is_node = true; } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_format_name(result, "%s (cont)", a->name); result->op = GGML_OP_CONT; @@ -4376,13 +4386,7 @@ static struct ggml_tensor * ggml_cont_impl( struct ggml_tensor * ggml_cont( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_cont_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_cont_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_cont_impl(ctx, a, true); + return ggml_cont_impl(ctx, a); } // make contiguous, with new shape @@ -7449,6 +7453,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -7714,6 +7719,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: { ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -7829,6 +7835,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: default: { GGML_ASSERT(false); @@ -9762,6 +9769,10 @@ static void ggml_compute_forward_mul_mat( const int ith = params->ith; const int nth = params->nth; + if (ith == 1 && g_imatrix_collect) { + g_imatrix_collect(src0, src1); + } + const enum ggml_type type = src0->type; const bool src1_cont = ggml_is_contiguous(src1); @@ -10065,6 +10076,10 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_tensor * src0_cur = dst->src[cur_a + 2]; + if (ith == 1 && g_imatrix_collect) { + g_imatrix_collect(src0_cur, src1); + } + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -10471,6 +10486,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: { ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); } break; @@ -10646,6 +10662,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: default: { GGML_ASSERT(false); @@ -10841,6 +10858,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -11478,6 +11496,7 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -11553,6 +11572,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -18674,6 +18694,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K; result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist); } break; + case GGML_TYPE_IQ2_XS: + { + GGML_ASSERT(start % QK_K == 0); + block_iq2_xs * block = (block_iq2_xs*)dst + start / QK_K; + result = ggml_quantize_iq2_xs(src + start, block, n, n, hist); + } break; case GGML_TYPE_F16: { int elemsize = sizeof(ggml_fp16_t); @@ -19074,8 +19100,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p (int64_t) info->ne[3]; if (ne % ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", - __func__, info->name.data, ne, ggml_blck_size(info->type)); + fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", + __func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); fclose(file); gguf_free(ctx); return NULL; diff --git a/ggml.h b/ggml.h index a78ce460066f7..8e02d7cbe2bba 100644 --- a/ggml.h +++ b/ggml.h @@ -218,7 +218,9 @@ #define GGML_MAX_PARAMS 2048 #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 10 +#ifndef GGML_MAX_NAME #define GGML_MAX_NAME 64 +#endif #define GGML_MAX_OP_PARAMS 64 #define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_GRAPH_SIZE 2048 @@ -347,6 +349,7 @@ extern "C" { GGML_TYPE_Q6_K = 14, GGML_TYPE_Q8_K = 15, GGML_TYPE_IQ2_XXS = 16, + GGML_TYPE_IQ2_XS = 17, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -382,6 +385,7 @@ extern "C" { GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors }; // available tensor operations: @@ -1168,22 +1172,11 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); - // a -> b, in-place, return view(b) - GGML_API struct ggml_tensor * ggml_cpy_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - // make contiguous GGML_API struct ggml_tensor * ggml_cont( struct ggml_context * ctx, struct ggml_tensor * a); - // make contiguous, in-place - GGML_API struct ggml_tensor * ggml_cont_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - // make contiguous, with new shape GGML_API struct ggml_tensor * ggml_cont_1d( struct ggml_context * ctx, @@ -2077,9 +2070,16 @@ extern "C" { GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_iq2_xs (const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + // + // Importance matrix + // + typedef void(*ggml_collect_imatrix_t)(const struct ggml_tensor * src0, const struct ggml_tensor * src1); + GGML_API void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect); + // // gguf // diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 80c1d5449cc74..24a0890378496 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -57,6 +57,7 @@ class TensorNameMap: "transformer.norm_f", # mpt "ln_f", # refact bloom qwen gpt2 "language_model.encoder.final_layernorm", # persimmon + "model.final_layernorm", # persimmon "lm_head.ln", # phi2 ), @@ -98,6 +99,7 @@ class TensorNameMap: "transformer.h.{bid}.self_attention.query_key_value", # falcon "h.{bid}.self_attention.query_key_value", # bloom "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon + "model.layers.{bid}.self_attn.query_key_value", # persimmon "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 ), @@ -141,6 +143,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon + "model.layers.{bid}.self_attn.dense", # persimmon "h.{bid}.attn.c_proj", # gpt2 "transformer.h.{bid}.mixer.out_proj", # phi2 "model.layers.layers.{bid}.self_attn.o_proj", # plamo @@ -184,6 +187,7 @@ class TensorNameMap: "encoder.layer.{bid}.intermediate.dense", # bert "transformer.h.{bid}.mlp.fc_in", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon + "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon "transformer.h.{bid}.mlp.w1", # qwen "h.{bid}.mlp.c_fc", # gpt2 "transformer.h.{bid}.mlp.fc1", # phi2 @@ -225,6 +229,7 @@ class TensorNameMap: "encoder.layer.{bid}.output.dense", # bert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon + "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon "h.{bid}.mlp.c_proj", # gpt2 "transformer.h.{bid}.mlp.fc2", # phi2 "model.layers.layers.{bid}.mlp.down_proj", # plamo @@ -237,10 +242,12 @@ class TensorNameMap: MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", + "model.layers.{bid}.self_attn.q_layernorm", # persimmon ), MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", + "model.layers.{bid}.self_attn.k_layernorm", # persimmon ), MODEL_TENSOR.ROPE_FREQS: ( diff --git a/klite.embd b/klite.embd index 2f328e60bdd7c..e2c0430473433 100644 --- a/klite.embd +++ b/klite.embd @@ -142,7 +142,7 @@ Current version: 105 display: flex; flex-wrap: wrap; background-color: #4d4d4d; - padding: 10px; + padding: 6px; } body.connected .settingsmenu, @@ -1135,8 +1135,8 @@ Current version: 105 .settingitem { width: 50%; - padding-left: 8px; - padding-right: 8px; + padding-left: 6px; + padding-right: 6px; padding-bottom: 5px; padding-top: 5px; display: inline-block; @@ -3378,6 +3378,7 @@ Current version: 105 img_autogen: false, img_allownsfw: true, img_cfgscale: 7, + img_allowhd: false, img_steps: 20, save_images: true, prompt_for_savename: false, @@ -7383,6 +7384,7 @@ Current version: 105 update_horde_sdmodels(); document.getElementById("tokenstreammode").value = localsettings.tokenstreammode; + document.getElementById("img_allowhd").checked = localsettings.img_allowhd; document.getElementById("img_autogen").checked = localsettings.img_autogen; document.getElementById("save_images").checked = localsettings.save_images; document.getElementById("img_cfgscale").value = localsettings.img_cfgscale; @@ -7564,6 +7566,7 @@ Current version: 105 localsettings.image_styles = pendingstyle; localsettings.grammar = pendinggrammar; localsettings.tokenstreammode = document.getElementById("tokenstreammode").value; + localsettings.img_allowhd = (document.getElementById("img_allowhd").checked ? true : false); localsettings.img_autogen = (document.getElementById("img_autogen").checked ? true : false); localsettings.save_images = (document.getElementById("save_images").checked ? true : false); localsettings.prompt_for_savename = (document.getElementById("prompt_for_savename").checked ? true : false); @@ -9439,10 +9442,11 @@ Current version: 105 { //console.log(outputimg); let origImg = "data:image/jpeg;base64," + outputimg; + let imgres = localsettings.img_allowhd?380:256; compressImage(origImg, (newDataUri) => { image_db[imgid].done = true; image_db[imgid].result = newDataUri; - }, true); + }, true, true, imgres); }else{ image_db[imgid].queue = "Failed"; msgbox("Image Generation Failed!\n\nPlease make sure A1111 is running and properly configured!\nIn your local install of Automatic1111 WebUi, modify webui-user.bat and add these flags to enable API access:\n\nset COMMANDLINE_ARGS= --api --listen --cors-allow-origins=*\n"); @@ -9466,10 +9470,11 @@ Current version: 105 { //console.log(outputimg); let origImg = "data:image/jpeg;base64," + outputimg; + let imgres = localsettings.img_allowhd?380:256; compressImage(origImg, (newDataUri) => { image_db[imgid].done = true; image_db[imgid].result = newDataUri; - }, true); + }, true, true, imgres); }else{ image_db[imgid].queue = "Failed"; msgbox("Image Generation Failed!\n\nPlease make sure your OpenAI key is set correctly and you are allowed to use DALL-E.\n"); @@ -9518,6 +9523,7 @@ Current version: 105 function render_image_html(data, pend_txt = "", float=true) { var dim = (localsettings.opmode == 2 ? 160 : 200); //adventure mode has smaller pictures let siclass = (float?"storyimgfloat":"storyimg"); + let reinvertcolor = localsettings.invert_colors?" invert_colors":""; let alttxt = ""; if (!data || data == "") { let waittime = "Unavailable"; @@ -9529,13 +9535,13 @@ Current version: 105 console.log("Cannot render " + pend_txt); } - return `
` + pend_txt + `
` + waittime + `
`; + return `
` + pend_txt + `
` + waittime + `
`; } else { let imghash = cyrb_hash(data); if (completed_imgs_meta[imghash] != null) { alttxt = completed_imgs_meta[imghash].alt?escapeHtml(completed_imgs_meta[imghash].alt):""; } - return `
`; + return `
`; } } @@ -9775,7 +9781,8 @@ Current version: 105 img.queue = 0; let origImg = "data:image/jpeg;base64," + finalimg.generations[0].img; //console.log("Original image: " + origImg); - compressImage(origImg, (newDataUri) => { img.result = newDataUri; }, true); + let imgres = localsettings.img_allowhd?380:256; + compressImage(origImg, (newDataUri) => { img.result = newDataUri; }, true, true, imgres); } }) .catch((error) => { @@ -11732,7 +11739,8 @@ Current version: 105 } function image(role) { if (!as[`${role}_portrait`] || as.border_style == 'None' || role == 'sys') { return ''; } - return `
`; + let reinvertcolor = localsettings.invert_colors?" invert_colors":""; + return `
`; } function applyStylizedCodeBlocks() { let blocks = newbodystr.split(/(```[\s\S]*?\n[\s\S]*?```)/g); @@ -12487,7 +12495,7 @@ Current version: 105 - +
@@ -12512,7 +12520,7 @@ Current version: 105
Advanced Sampler Config ?These settings control alternative samplers configurations. They are inactive by default, you usually do not need to change them.
-
Start Seq.?The sequence to start an instruction prompt End Seq.?The sequence to end an instruction prompt
@@ -12560,7 +12568,7 @@ Current version: 105
Mirostat (If supported) ?Replaces your samplers with mirostat, an alternative sampling method. May not be available depending on backend, not supported on Horde.
-
Top-K
@@ -12586,7 +12594,7 @@ Current version: 105
-
Mode
@@ -12977,7 +12985,7 @@ Current version: 105
- +
@@ -13039,6 +13047,13 @@ Current version: 105
Cfg. Scale:
+
+
Save Higher-Res ? + This option will result in larger save files which may be slower. Changing this setting only applies to NEW images. + :
+ +
+
diff --git a/koboldcpp.py b/koboldcpp.py index 4919af6b63fa5..062c0392125d0 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -427,7 +427,7 @@ def bring_terminal_to_foreground(): modelbusy = threading.Lock() requestsinqueue = 0 defaultport = 5001 -KcppVersion = "1.55.1" +KcppVersion = "1.56" showdebug = True showsamplerwarning = True showmaxctxwarning = True diff --git a/llama.cpp b/llama.cpp index 39c26534f6839..db834936d4b93 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2234,6 +2234,7 @@ struct llama_model_loader { case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; + case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -2596,7 +2597,8 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; @@ -2606,6 +2608,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XSS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; default: return "unknown, may not work"; } @@ -2840,6 +2843,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; default: model.type = e_model::MODEL_UNKNOWN; } @@ -3175,7 +3179,15 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); + if (ml.n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, ml.n_elements*1e-12); + } else if (ml.n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); + } else if (ml.n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, ml.n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, ml.n_elements*1e-3); + } if (ml.n_bytes < GiB) { LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); } else { @@ -4121,7 +4133,6 @@ static void llm_build_k_shift( struct ggml_cgraph * graph, llm_rope_type type, int64_t n_ctx, - int n_rot, float freq_base, float freq_scale, const llm_build_cb & cb) { @@ -4129,14 +4140,13 @@ static void llm_build_k_shift( const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int32_t n_rot = hparams.n_rot; const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx; const float ext_factor = cparams.yarn_ext_factor; const float attn_factor = cparams.yarn_attn_factor; const float beta_fast = cparams.yarn_beta_fast; const float beta_slow = cparams.yarn_beta_slow; - GGML_ASSERT(n_embd_head_k % n_rot == 0); - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx); cb(K_shift, "K_shift", -1); @@ -4540,7 +4550,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -4578,14 +4588,14 @@ struct llm_build_context { Qcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -4708,6 +4718,7 @@ struct llm_build_context { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -4725,7 +4736,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -4751,12 +4762,12 @@ struct llm_build_context { case MODEL_7B: Qcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); break; @@ -4829,6 +4840,7 @@ struct llm_build_context { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -4846,7 +4858,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -4887,13 +4899,13 @@ struct llm_build_context { // using mode = 2 for neox mode Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, + ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, + ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -5050,15 +5062,14 @@ struct llm_build_context { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - - const int64_t n_rot = n_embd_head_k / 2; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head/2 == hparams.n_rot); struct ggml_tensor * cur; struct ggml_tensor * inpL; inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); - cb(inpL, "imp_embd", -1); + cb(inpL, "inp_embd", -1); // inp_pos - contains the positions struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); @@ -5069,7 +5080,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5129,7 +5140,7 @@ struct llm_build_context { // RoPE the first n_rot of q/k, pass the other half, and concat. struct ggml_tensor * qrot = ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, + ctx0, tmpq, hparams.n_rot, n_head, n_tokens, ggml_element_size(tmpq) * n_embd_head, ggml_element_size(tmpq) * n_embd_head * n_head, 0 @@ -5137,7 +5148,7 @@ struct llm_build_context { cb(qrot, "qrot", il); struct ggml_tensor * krot = ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, + ctx0, tmpk, hparams.n_rot, n_head, n_tokens, ggml_element_size(tmpk) * n_embd_head, ggml_element_size(tmpk) * n_embd_head * n_head, 0 @@ -5146,29 +5157,29 @@ struct llm_build_context { // get the second half of tmpq, e.g tmpq[n_rot:, :, :] struct ggml_tensor * qpass = ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, + ctx0, tmpq, hparams.n_rot, n_head, n_tokens, ggml_element_size(tmpq) * n_embd_head, ggml_element_size(tmpq) * n_embd_head * n_head, - ggml_element_size(tmpq) * n_rot + ggml_element_size(tmpq) * hparams.n_rot ); cb(qpass, "qpass", il); struct ggml_tensor * kpass = ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, + ctx0, tmpk, hparams.n_rot, n_head, n_tokens, ggml_element_size(tmpk) * n_embd_head, ggml_element_size(tmpk) * n_embd_head * n_head, - ggml_element_size(tmpk) * n_rot + ggml_element_size(tmpk) * hparams.n_rot ); cb(kpass, "kpass", il); struct ggml_tensor * qrotated = ggml_rope_custom( - ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx, + ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(qrotated, "qrotated", il); struct ggml_tensor * krotated = ggml_rope_custom( - ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx, + ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(krotated, "krotated", il); @@ -5565,7 +5576,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5678,7 +5689,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5710,13 +5721,13 @@ struct llm_build_context { // using mode = 2 for neox mode Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, + ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, + ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -5795,7 +5806,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5891,6 +5902,7 @@ struct llm_build_context { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -5908,7 +5920,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5934,13 +5946,13 @@ struct llm_build_context { cb(Vcur, "Vcur", il); Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + ctx0, ggml_reshape_3d(ctx0, Qcur, hparams.n_rot, n_head, n_tokens), inp_pos, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + ctx0, ggml_reshape_3d(ctx0, Kcur, hparams.n_rot, n_head_kv, n_tokens), inp_pos, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -9277,10 +9289,13 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty // TODO: explore better strategies new_type = GGML_TYPE_Q8_0; } - } else if (name.find("ffn_down.weight") != std::string::npos) { + } else if (name.find("ffn_down") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) { + if (qs.i_feed_forward_w2 < qs.n_feed_forward_w2/8) new_type = GGML_TYPE_Q4_K; + } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { - new_type = qs.i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K + new_type = qs.i_feed_forward_w2 < qs.n_feed_forward_w2/16 ? GGML_TYPE_Q5_K : arch != LLM_ARCH_FALCON || use_more_bits(qs.i_feed_forward_w2, qs.n_feed_forward_w2) ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } @@ -9289,14 +9304,14 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty } else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { if (arch == LLM_ARCH_FALCON) { - new_type = qs.i_feed_forward_w2 < 2 ? GGML_TYPE_Q6_K : + new_type = qs.i_feed_forward_w2 < qs.n_feed_forward_w2/16 ? GGML_TYPE_Q6_K : use_more_bits(qs.i_feed_forward_w2, qs.n_feed_forward_w2) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; } else { if (use_more_bits(qs.i_feed_forward_w2, qs.n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; } } else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(qs.i_feed_forward_w2, qs.n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && qs.i_feed_forward_w2 < 4) { + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && qs.i_feed_forward_w2 < qs.n_feed_forward_w2/8) { new_type = GGML_TYPE_Q5_K; } ++qs.i_feed_forward_w2; @@ -9314,9 +9329,10 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; } - else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - } + // IK: let's remove this, else Q2_K is almost the same as Q3_K_S + //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { + // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + //} // This can be used to reduce the size of the Q5_K_S model. // The associated PPL increase is fully in line with the size reduction //else { @@ -9365,6 +9381,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: quantized_type = GGML_TYPE_Q2_K; break; case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_M: case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; @@ -9374,6 +9391,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS:quantized_type = GGML_TYPE_IQ2_XXS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_XS :quantized_type = GGML_TYPE_IQ2_XS; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } @@ -9422,7 +9440,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { ++qs.n_attention_wv; } - else if (name.find("ffn_down.weight") != std::string::npos) { + else if (name.find("ffn_down") != std::string::npos) { ++qs.n_feed_forward_w2; } } @@ -11262,7 +11280,7 @@ void llama_print_timings(struct llama_context * ctx) { __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval); - LLAMA_LOG_INFO("%s: total time = %10.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); + LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } void llama_reset_timings(struct llama_context * ctx) { diff --git a/llama.h b/llama.h index 454445b889f62..e5f4bb1d3a8f1 100644 --- a/llama.h +++ b/llama.h @@ -105,6 +105,8 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index fe7f3202f4bb6..3e2c579d575cd 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -f96711108d55bdbbd277e6be07204dce6a94fb93 +979cc23b345006504cfc1f67c0fdf627805e3319
RpRng.