From f43f09366dfd018e4568e23a232aaa8c4f7cfc78 Mon Sep 17 00:00:00 2001 From: Ziad Ben Hadj-Alouane Date: Thu, 30 Nov 2023 17:25:04 -0500 Subject: [PATCH] server : add single-client multi-prompt support (#4232) * * add multiprompt support * * cleanup * * more cleanup * * remove atomicity of id_gen, and change lock_guard to unique_lock on completion requests * * remove all references to mutex_multitasks * Update examples/server/server.cpp Co-authored-by: Jared Van Bortel * Update examples/server/server.cpp Co-authored-by: Jared Van Bortel * Update examples/server/server.cpp Co-authored-by: Jared Van Bortel * Update examples/server/server.cpp Co-authored-by: Jared Van Bortel * * change to set --------- Co-authored-by: Jared Van Bortel --- examples/server/server.cpp | 139 ++++++++++++++++++++++++++++++++++--- 1 file changed, 128 insertions(+), 11 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 50f124b13e849..5edb3678efe09 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -155,15 +155,23 @@ struct task_server { json data; bool infill_mode = false; bool embedding_mode = false; + int multitask_id = -1; }; struct task_result { int id; + int multitask_id = -1; bool stop; bool error; json result_json; }; +struct task_multi { + int id; + std::set subtasks_remaining{}; + std::vector results{}; +}; + // TODO: can become bool if we can't find use of more states enum slot_state { @@ -406,6 +414,9 @@ struct llama_client_slot double t_prompt_processing; // ms double t_token_generation; // ms + // multitasks + int multitask_id = -1; + void reset() { num_prompt_tokens = 0; generated_text = ""; @@ -529,7 +540,8 @@ struct llama_server_context std::vector queue_tasks; std::vector queue_results; - std::mutex mutex_tasks; + std::vector queue_multitasks; + std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks std::mutex mutex_results; ~llama_server_context() @@ -1112,17 +1124,40 @@ struct llama_server_context return slot.images.size() > 0; } - void send_error(int id, std::string error) + void send_error(task_server& task, std::string error) { std::lock_guard lock(mutex_results); task_result res; - res.id = id; + res.id = task.id; + res.multitask_id = task.multitask_id; res.stop = false; res.error = true; res.result_json = { { "content", error } }; queue_results.push_back(res); } + void add_multi_task(int id, std::vector& sub_ids) + { + std::lock_guard lock(mutex_tasks); + task_multi multi; + multi.id = id; + std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); + queue_multitasks.push_back(multi); + } + + void update_multi_task(int multitask_id, int subtask_id, task_result& result) + { + std::lock_guard lock(mutex_tasks); + for (auto& multitask : queue_multitasks) + { + if (multitask.id == multitask_id) + { + multitask.subtasks_remaining.erase(subtask_id); + multitask.results.push_back(result); + } + } + } + json get_model_props() { return get_formated_generation(slots[0]); @@ -1167,6 +1202,7 @@ struct llama_server_context std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; + res.multitask_id = slot.multitask_id; res.error = false; res.stop = false; @@ -1206,6 +1242,7 @@ struct llama_server_context std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; + res.multitask_id = slot.multitask_id; res.error = false; res.stop = true; @@ -1251,6 +1288,12 @@ struct llama_server_context res.result_json["model"] = slot.oaicompat_model; } + // parent multitask, if any, needs to be updated + if (slot.multitask_id != -1) + { + update_multi_task(slot.multitask_id, slot.task_id, res); + } + queue_results.push_back(res); } @@ -1259,6 +1302,7 @@ struct llama_server_context std::lock_guard lock(mutex_results); task_result res; res.id = slot.task_id; + res.multitask_id = slot.multitask_id; res.error = false; res.stop = true; @@ -1285,9 +1329,9 @@ struct llama_server_context queue_results.push_back(res); } - int request_completion(json data, bool infill, bool embedding) + int request_completion(json data, bool infill, bool embedding, int multitask_id) { - std::lock_guard lock(mutex_tasks); + std::unique_lock lock(mutex_tasks); task_server task; task.id = id_gen++; task.target_id = 0; @@ -1295,6 +1339,16 @@ struct llama_server_context task.infill_mode = infill; task.embedding_mode = embedding; task.type = COMPLETION_TASK; + task.multitask_id = multitask_id; + + // when a completion task's prompt array is not a singleton, we split it into multiple requests + if (task.data.at("prompt").size() > 1) + { + lock.unlock(); // entering new func scope + return split_multiprompt_task(task); + } + + // otherwise, it's a single-prompt task, we actually queue it queue_tasks.push_back(task); return task.id; } @@ -1313,8 +1367,17 @@ struct llama_server_context for (int i = 0; i < (int) queue_results.size(); i++) { + // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result + if (queue_results[i].multitask_id == task_id) + { + update_multi_task(task_id, queue_results[i].id, queue_results[i]); + queue_results.erase(queue_results.begin() + i); + continue; + } + if (queue_results[i].id == task_id) { + assert(queue_results[i].multitask_id == -1); task_result res = queue_results[i]; queue_results.erase(queue_results.begin() + i); return res; @@ -1404,6 +1467,27 @@ struct llama_server_context queue_tasks.push_back(task); } + int split_multiprompt_task(task_server& multiprompt_task) + { + auto prompt_count = multiprompt_task.data.at("prompt").size(); + assert(prompt_count > 1); + + int multitask_id = id_gen++; + std::vector subtask_ids(prompt_count); + for (int i = 0; i < prompt_count; i++) + { + json subtask_data = multiprompt_task.data; + subtask_data["prompt"] = subtask_data["prompt"][i]; + + // subtasks inherit everything else (infill mode, embedding mode, etc.) + subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id); + } + + // queue up the multitask so we can track its subtask progression + add_multi_task(multitask_id, subtask_ids); + return multitask_id; + } + void process_tasks() { std::lock_guard lock(mutex_tasks); @@ -1419,7 +1503,7 @@ struct llama_server_context { LOG_TEE("slot unavailable\n"); // send error result - send_error(task.id, "slot unavailable"); + send_error(task, "slot unavailable"); return; } @@ -1433,11 +1517,12 @@ struct llama_server_context slot->infill = task.infill_mode; slot->embedding = task.embedding_mode; slot->task_id = task.id; + slot->multitask_id = task.multitask_id; if (!launch_slot_with_data(slot, task.data)) { // send error result - send_error(task.id, "internal_error"); + send_error(task, "internal_error"); break; } } break; @@ -1453,6 +1538,38 @@ struct llama_server_context } break; } } + + // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue + auto queue_iterator = queue_multitasks.begin(); + while (queue_iterator != queue_multitasks.end()) + { + if (queue_iterator->subtasks_remaining.empty()) + { + // all subtasks done == multitask is done + task_result aggregate_result; + aggregate_result.id = queue_iterator->id; + aggregate_result.stop = true; + aggregate_result.error = false; + + // collect json results into one json result + std::vector result_jsons; + for (auto& subres : queue_iterator->results) + { + result_jsons.push_back(subres.result_json); + aggregate_result.error = aggregate_result.error && subres.error; + } + aggregate_result.result_json = json{ "results", result_jsons }; + + std::lock_guard lock(mutex_results); + queue_results.push_back(aggregate_result); + + queue_iterator = queue_multitasks.erase(queue_iterator); + } + else + { + ++queue_iterator; + } + } } bool update_slots() { @@ -2596,7 +2713,7 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body); - const int task_id = llama.request_completion(data, false, false); + const int task_id = llama.request_completion(data, false, false, -1); if (!json_value(data, "stream", false)) { std::string completion_text; task_result result = llama.next_result(task_id); @@ -2685,7 +2802,7 @@ int main(int argc, char **argv) { json data = oaicompat_completion_params_parse(json::parse(req.body)); - const int task_id = llama.request_completion(data, false, false); + const int task_id = llama.request_completion(data, false, false, -1); if (!json_value(data, "stream", false)) { std::string completion_text; @@ -2754,7 +2871,7 @@ int main(int argc, char **argv) svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body); - const int task_id = llama.request_completion(data, true, false); + const int task_id = llama.request_completion(data, true, false, -1); if (!json_value(data, "stream", false)) { std::string completion_text; task_result result = llama.next_result(task_id); @@ -2858,7 +2975,7 @@ int main(int argc, char **argv) { prompt = ""; } - const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true); + const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1); task_result result = llama.next_result(task_id); return res.set_content(result.result_json.dump(), "application/json"); });