Skip to content

Commit

Permalink
llama: Refactor string_split to use template specialization, fixes pa…
Browse files Browse the repository at this point in the history
…rsing strings with spaces
  • Loading branch information
Xarbirus committed Oct 23, 2024
1 parent 190a37d commit e41cfc7
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 21 deletions.
10 changes: 5 additions & 5 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ static void common_params_handle_model_default(common_params & params) {
}
params.hf_file = params.model;
} else if (params.model.empty()) {
params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
}
} else if (!params.model_url.empty()) {
if (params.model.empty()) {
auto f = string_split(params.model_url, '#').front();
f = string_split(f, '?').front();
params.model = fs_get_cache_file(string_split(f, '/').back());
auto f = string_split<std::string>(params.model_url, '#').front();
f = string_split<std::string>(f, '?').front();
params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
} else if (params.model.empty()) {
params.model = DEFAULT_MODEL_PATH;
Expand Down Expand Up @@ -879,7 +879,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--samplers"}, "SAMPLERS",
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split(value, ';');
const auto sampler_names = string_split<std::string>(value, ';');
params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
}
).set_sparam());
Expand Down
13 changes: 0 additions & 13 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,19 +416,6 @@ std::string string_format(const char * fmt, ...) {
return std::string(buf.data(), size);
}

std::vector<std::string> string_split(std::string input, char separator) {
std::vector<std::string> parts;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(0, separator_pos);
parts.emplace_back(part);
input = input.substr(separator_pos + 1);
separator_pos = input.find(separator);
}
parts.emplace_back(input);
return parts;
}

std::string string_strip(const std::string & str) {
size_t start = 0;
size_t end = str.size();
Expand Down
18 changes: 16 additions & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,6 @@ bool set_process_priority(enum ggml_sched_priority prio);
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
std::string string_format(const char * fmt, ...);

std::vector<std::string> string_split(std::string input, char separator);

std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp();

Expand All @@ -401,6 +399,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
return values;
}

template<>
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
{
std::vector<std::string> parts;
size_t begin_pos = 0;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
parts.emplace_back(part);
begin_pos = separator_pos + 1;
separator_pos = input.find(separator, begin_pos);
}
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
return parts;
}

bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);

Expand Down
2 changes: 1 addition & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2612,7 +2612,7 @@ int main(int argc, char ** argv) {
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
auto tmp = string_split(req.path, '.');
auto tmp = string_split<std::string>(req.path, '.');
if (req.path == "/" || tmp.back() == "html") {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503;
Expand Down

0 comments on commit e41cfc7

Please sign in to comment.