// utils.hpp
  1. // MIT License
  2. // Copyright (c) 2023 Georgi Gerganov
  3. // Permission is hereby granted, free of charge, to any person obtaining a copy
  4. // of this software and associated documentation files (the "Software"), to deal
  5. // in the Software without restriction, including without limitation the rights
  6. // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  7. // copies of the Software, and to permit persons to whom the Software is
  8. // furnished to do so, subject to the following conditions:
  9. // The above copyright notice and this permission notice shall be included in all
  10. // copies or substantial portions of the Software.
  11. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  12. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  13. // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  14. // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  15. // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  16. // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  17. // SOFTWARE.
#pragma once

#include <algorithm>
#include <cassert>
#include <cctype>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <functional>
#include <iostream>
#include <iterator>
#include <mutex>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "json.hpp"
#include "../llava/clip.h"

using json = nlohmann::json;
// Runtime logging configuration; the flag variables themselves are defined in
// the server's main translation unit.
extern bool server_verbose;
extern bool server_log_json;

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif

// LOG_VERBOSE compiles to nothing when SERVER_VERBOSE != 1; otherwise it is
// additionally gated at runtime on the server_verbose flag.
#if SERVER_VERBOSE != 1
#define LOG_VERBOSE(MSG, ...)
#else
#define LOG_VERBOSE(MSG, ...)                                            \
    do                                                                   \
    {                                                                    \
        if (server_verbose)                                              \
        {                                                                \
            server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__);    \
        }                                                                \
    } while (0)
#endif

// Unconditional log levels; each forwards to server_log with the call site
// captured via __func__/__LINE__.
#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_DEBUG( MSG, ...) server_log("DEBUG", __func__, __LINE__, MSG, __VA_ARGS__)
// Lifecycle state of the server process.
enum server_state {
    SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
    SERVER_STATE_READY,         // Server is ready and model is loaded
    SERVER_STATE_ERROR          // An error occurred, load_model failed
};
// Kinds of tasks that can be posted to the server queue.
enum task_type {
    TASK_TYPE_COMPLETION,    // a generation request (see task_server's infill/embedding flags)
    TASK_TYPE_CANCEL,        // cancel a previously posted task (identified by target_id)
    TASK_TYPE_NEXT_RESPONSE, // NOTE(review): presumably prompts the loop for the next streamed response — confirm at call sites
    TASK_TYPE_METRICS        // NOTE(review): presumably requests server metrics — confirm at call sites
};
// One unit of work flowing through llama_server_queue.
struct task_server {
    int id = -1; // to be filled by llama_server_queue
    // id of the task this task refers to (e.g. the task being cancelled)
    // NOTE(review): not default-initialized — confirm every creation site sets it
    int target_id;
    task_type type;               // what kind of task this is
    json data;                    // task payload / request parameters
    bool infill_mode = false;     // completion runs in infill mode
    bool embedding_mode = false;  // completion computes embeddings instead of text
    int multitask_id = -1;        // id of the parent multitask, or -1 if standalone
};
// Result message produced for a task and delivered via llama_server_response.
struct task_result {
    int id;                // id of the task this result answers
    int multitask_id = -1; // id of the parent multitask, or -1 if standalone
    // NOTE(review): stop/error are not default-initialized — producers must set both.
    bool stop;             // presumably "final result of this task" — confirm at producers
    bool error;            // result_json describes an error
    json result_json;      // result payload
};
// Book-keeping for a multitask: a parent task fanned out into subtasks.
struct task_multi {
    int id;                             // multitask id
    std::set<int> subtasks_remaining{}; // ids of subtasks that have not finished yet
    std::vector<task_result> results{}; // results collected from finished subtasks
};
// completion token output with probabilities
struct completion_token_output {
    // one candidate token and its probability
    struct token_prob
    {
        llama_token tok;
        float prob;
    };

    std::vector<token_prob> probs; // candidate tokens with probabilities for this position
    llama_token tok;               // the token that was actually sampled
    std::string text_to_send;      // text piece emitted to the client for this token
};
// Functor turning tokens (or completion outputs) into their text pieces via
// llama_token_to_piece on the wrapped context.
struct token_translator {
    llama_context * ctx;
    // detokenize a single token id
    std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
    // convenience overload: translate the sampled token of a completion output
    std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); }
};
  97. static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) {
  98. std::stringstream ss_tid;
  99. ss_tid << std::this_thread::get_id();
  100. json log = nlohmann::ordered_json{
  101. {"tid", ss_tid.str()},
  102. {"timestamp", time(nullptr)},
  103. };
  104. if (strncmp("DEBUG", level, strlen(level)) == 0 && !server_verbose) {
  105. return;
  106. }
  107. if (server_log_json) {
  108. log.merge_patch(
  109. {
  110. {"level", level},
  111. {"function", function},
  112. {"line", line},
  113. {"msg", message},
  114. });
  115. if (!extra.empty()) {
  116. log.merge_patch(extra);
  117. }
  118. std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
  119. } else {
  120. if (!extra.empty()) {
  121. log.merge_patch(extra);
  122. }
  123. std::stringstream ss;
  124. ss << level << " [" << function << "] " << message << " |";
  125. for (const auto& el : log.items())
  126. {
  127. const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
  128. ss << " " << el.key() << "=" << value;
  129. }
  130. const std::string str = ss.str();
  131. printf("%.*s\n", (int)str.size(), str.data());
  132. fflush(stdout);
  133. }
  134. }
  135. //
  136. // server utils
  137. //
  138. template <typename T>
  139. static T json_value(const json &body, const std::string &key, const T &default_value) {
  140. // Fallback null to default value
  141. return body.contains(key) && !body.at(key).is_null()
  142. ? body.value(key, default_value)
  143. : default_value;
  144. }
  145. // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
  146. inline bool verify_custom_template(const std::string & tmpl) {
  147. llama_chat_message chat[] = {{"user", "test"}};
  148. std::vector<char> buf(1);
  149. int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
  150. return res >= 0;
  151. }
  152. // Format given chat. If tmpl is empty, we take the template from model metadata
  153. inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
  154. size_t alloc_size = 0;
  155. // vector holding all allocated string to be passed to llama_chat_apply_template
  156. std::vector<std::string> str(messages.size() * 2);
  157. std::vector<llama_chat_message> chat(messages.size());
  158. for (size_t i = 0; i < messages.size(); ++i) {
  159. auto &curr_msg = messages[i];
  160. str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
  161. str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
  162. alloc_size += str[i*2 + 1].length();
  163. chat[i].role = str[i*2 + 0].c_str();
  164. chat[i].content = str[i*2 + 1].c_str();
  165. }
  166. const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
  167. std::vector<char> buf(alloc_size * 2);
  168. // run the first time to get the total output length
  169. int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
  170. // if it turns out that our buffer is too small, we resize it
  171. if ((size_t) res > buf.size()) {
  172. buf.resize(res);
  173. res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
  174. }
  175. std::string formatted_chat(buf.data(), res);
  176. LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
  177. return formatted_chat;
  178. }
  179. //
  180. // work queue utils
  181. //
  182. struct llama_server_queue {
  183. int id = 0;
  184. std::mutex mutex_tasks;
  185. bool running;
  186. // queues
  187. std::vector<task_server> queue_tasks;
  188. std::vector<task_server> queue_tasks_deferred;
  189. std::vector<task_multi> queue_multitasks;
  190. std::condition_variable condition_tasks;
  191. // callback functions
  192. std::function<void(task_server&)> callback_new_task;
  193. std::function<void(task_multi&)> callback_finish_multitask;
  194. std::function<void(void)> callback_run_slots;
  195. // Add a new task to the end of the queue
  196. int post(task_server task) {
  197. std::unique_lock<std::mutex> lock(mutex_tasks);
  198. if (task.id == -1) {
  199. task.id = id++;
  200. LOG_VERBOSE("new task id", {{"new_id", task.id}});
  201. }
  202. queue_tasks.push_back(std::move(task));
  203. condition_tasks.notify_one();
  204. return task.id;
  205. }
  206. // Add a new task, but defer until one slot is available
  207. void defer(task_server task) {
  208. std::unique_lock<std::mutex> lock(mutex_tasks);
  209. queue_tasks_deferred.push_back(std::move(task));
  210. }
  211. // Get the next id for creating anew task
  212. int get_new_id() {
  213. std::unique_lock<std::mutex> lock(mutex_tasks);
  214. int new_id = id++;
  215. LOG_VERBOSE("new task id", {{"new_id", new_id}});
  216. return new_id;
  217. }
  218. // Register function to process a new task
  219. void on_new_task(std::function<void(task_server&)> callback) {
  220. callback_new_task = callback;
  221. }
  222. // Register function to process a multitask when it is finished
  223. void on_finish_multitask(std::function<void(task_multi&)> callback) {
  224. callback_finish_multitask = callback;
  225. }
  226. // Register the function to be called when all slots data is ready to be processed
  227. void on_run_slots(std::function<void(void)> callback) {
  228. callback_run_slots = callback;
  229. }
  230. // Call when the state of one slot is changed
  231. void notify_slot_changed() {
  232. // move deferred tasks back to main loop
  233. std::unique_lock<std::mutex> lock(mutex_tasks);
  234. for (auto & task : queue_tasks_deferred) {
  235. queue_tasks.push_back(std::move(task));
  236. }
  237. queue_tasks_deferred.clear();
  238. }
  239. // end the start_loop routine
  240. void terminate() {
  241. {
  242. std::unique_lock<std::mutex> lock(mutex_tasks);
  243. running = false;
  244. }
  245. condition_tasks.notify_all();
  246. }
  247. /**
  248. * Main loop consists of these steps:
  249. * - Wait until a new task arrives
  250. * - Process the task (i.e. maybe copy data into slot)
  251. * - Check if multitask is finished
  252. * - Run all slots
  253. */
  254. void start_loop() {
  255. running = true;
  256. while (true) {
  257. LOG_VERBOSE("new task may arrive", {});
  258. {
  259. while (true)
  260. {
  261. std::unique_lock<std::mutex> lock(mutex_tasks);
  262. if (queue_tasks.empty()) {
  263. lock.unlock();
  264. break;
  265. }
  266. task_server task = queue_tasks.front();
  267. queue_tasks.erase(queue_tasks.begin());
  268. lock.unlock();
  269. LOG_VERBOSE("callback_new_task", {{"task_id", task.id}});
  270. callback_new_task(task);
  271. }
  272. LOG_VERBOSE("update_multitasks", {});
  273. // check if we have any finished multitasks
  274. auto queue_iterator = queue_multitasks.begin();
  275. while (queue_iterator != queue_multitasks.end())
  276. {
  277. if (queue_iterator->subtasks_remaining.empty())
  278. {
  279. // all subtasks done == multitask is done
  280. task_multi current_multitask = *queue_iterator;
  281. callback_finish_multitask(current_multitask);
  282. // remove this multitask
  283. queue_iterator = queue_multitasks.erase(queue_iterator);
  284. }
  285. else
  286. {
  287. ++queue_iterator;
  288. }
  289. }
  290. // all tasks in the current loop is processed, slots data is now ready
  291. LOG_VERBOSE("callback_run_slots", {});
  292. callback_run_slots();
  293. }
  294. LOG_VERBOSE("wait for new task", {});
  295. // wait for new task
  296. {
  297. std::unique_lock<std::mutex> lock(mutex_tasks);
  298. if (queue_tasks.empty()) {
  299. if (!running) {
  300. LOG_VERBOSE("ending start_loop", {});
  301. return;
  302. }
  303. condition_tasks.wait(lock, [&]{
  304. return (!queue_tasks.empty() || !running);
  305. });
  306. }
  307. }
  308. }
  309. }
  310. //
  311. // functions to manage multitasks
  312. //
  313. // add a multitask by specifying the id of all subtask (subtask is a task_server)
  314. void add_multitask(int multitask_id, std::vector<int>& sub_ids)
  315. {
  316. std::lock_guard<std::mutex> lock(mutex_tasks);
  317. task_multi multi;
  318. multi.id = multitask_id;
  319. std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
  320. queue_multitasks.push_back(multi);
  321. }
  322. // updatethe remaining subtasks, while appending results to multitask
  323. void update_multitask(int multitask_id, int subtask_id, task_result& result)
  324. {
  325. std::lock_guard<std::mutex> lock(mutex_tasks);
  326. for (auto& multitask : queue_multitasks)
  327. {
  328. if (multitask.id == multitask_id)
  329. {
  330. multitask.subtasks_remaining.erase(subtask_id);
  331. multitask.results.push_back(result);
  332. }
  333. }
  334. }
  335. };
  336. struct llama_server_response {
  337. typedef std::function<void(int, int, task_result&)> callback_multitask_t;
  338. callback_multitask_t callback_update_multitask;
  339. // for keeping track of all tasks waiting for the result
  340. std::set<int> waiting_task_ids;
  341. // the main result queue
  342. std::vector<task_result> queue_results;
  343. std::mutex mutex_results;
  344. std::condition_variable condition_results;
  345. // add the task_id to the list of tasks waiting for response
  346. void add_waiting_task_id(int task_id) {
  347. LOG_VERBOSE("waiting for task id", {{"task_id", task_id}});
  348. std::unique_lock<std::mutex> lock(mutex_results);
  349. waiting_task_ids.insert(task_id);
  350. }
  351. // when the request is finished, we can remove task associated with it
  352. void remove_waiting_task_id(int task_id) {
  353. LOG_VERBOSE("remove waiting for task id", {{"task_id", task_id}});
  354. std::unique_lock<std::mutex> lock(mutex_results);
  355. waiting_task_ids.erase(task_id);
  356. }
  357. // This function blocks the thread until there is a response for this task_id
  358. task_result recv(int task_id) {
  359. while (true)
  360. {
  361. std::unique_lock<std::mutex> lock(mutex_results);
  362. condition_results.wait(lock, [&]{
  363. return !queue_results.empty();
  364. });
  365. for (int i = 0; i < (int) queue_results.size(); i++)
  366. {
  367. if (queue_results[i].id == task_id)
  368. {
  369. assert(queue_results[i].multitask_id == -1);
  370. task_result res = queue_results[i];
  371. queue_results.erase(queue_results.begin() + i);
  372. return res;
  373. }
  374. }
  375. }
  376. // should never reach here
  377. }
  378. // Register the function to update multitask
  379. void on_multitask_update(callback_multitask_t callback) {
  380. callback_update_multitask = callback;
  381. }
  382. // Send a new result to a waiting task_id
  383. void send(task_result result) {
  384. std::unique_lock<std::mutex> lock(mutex_results);
  385. LOG_VERBOSE("send new result", {{"task_id", result.id}});
  386. for (auto& task_id : waiting_task_ids) {
  387. // LOG_TEE("waiting task id %i \n", task_id);
  388. // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
  389. if (result.multitask_id == task_id)
  390. {
  391. LOG_VERBOSE("callback_update_multitask", {{"task_id", task_id}});
  392. callback_update_multitask(task_id, result.id, result);
  393. continue;
  394. }
  395. if (result.id == task_id)
  396. {
  397. LOG_VERBOSE("queue_results.push_back", {{"task_id", task_id}});
  398. queue_results.push_back(result);
  399. condition_results.notify_all();
  400. return;
  401. }
  402. }
  403. }
  404. };
  405. //
  406. // base64 utils (TODO: move to common in the future)
  407. //
  408. static const std::string base64_chars =
  409. "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  410. "abcdefghijklmnopqrstuvwxyz"
  411. "0123456789+/";
  412. static inline bool is_base64(uint8_t c)
  413. {
  414. return (isalnum(c) || (c == '+') || (c == '/'));
  415. }
  416. static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string)
  417. {
  418. int i = 0;
  419. int j = 0;
  420. int in_ = 0;
  421. int in_len = encoded_string.size();
  422. uint8_t char_array_4[4];
  423. uint8_t char_array_3[3];
  424. std::vector<uint8_t> ret;
  425. while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
  426. {
  427. char_array_4[i++] = encoded_string[in_]; in_++;
  428. if (i == 4)
  429. {
  430. for (i = 0; i <4; i++)
  431. {
  432. char_array_4[i] = base64_chars.find(char_array_4[i]);
  433. }
  434. char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
  435. char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
  436. char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
  437. for (i = 0; (i < 3); i++)
  438. {
  439. ret.push_back(char_array_3[i]);
  440. }
  441. i = 0;
  442. }
  443. }
  444. if (i)
  445. {
  446. for (j = i; j <4; j++)
  447. {
  448. char_array_4[j] = 0;
  449. }
  450. for (j = 0; j <4; j++)
  451. {
  452. char_array_4[j] = base64_chars.find(char_array_4[j]);
  453. }
  454. char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
  455. char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
  456. char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
  457. for (j = 0; (j < i - 1); j++)
  458. {
  459. ret.push_back(char_array_3[j]);
  460. }
  461. }
  462. return ret;
  463. }
  464. //
  465. // random string / id
  466. //
  467. static std::string random_string()
  468. {
  469. static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
  470. std::random_device rd;
  471. std::mt19937 generator(rd());
  472. std::string result(32, ' ');
  473. for (int i = 0; i < 32; ++i) {
  474. result[i] = str[generator() % str.size()];
  475. }
  476. return result;
  477. }
  478. static std::string gen_chatcmplid()
  479. {
  480. std::stringstream chatcmplid;
  481. chatcmplid << "chatcmpl-" << random_string();
  482. return chatcmplid.str();
  483. }
  484. //
  485. // other common utils
  486. //
  487. static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
  488. {
  489. size_t i;
  490. for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++)
  491. {
  492. }
  493. return i;
  494. }
  495. static bool ends_with(const std::string &str, const std::string &suffix)
  496. {
  497. return str.size() >= suffix.size() &&
  498. 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
  499. }
  500. static size_t find_partial_stop_string(const std::string &stop,
  501. const std::string &text)
  502. {
  503. if (!text.empty() && !stop.empty())
  504. {
  505. const char text_last_char = text.back();
  506. for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--)
  507. {
  508. if (stop[char_index] == text_last_char)
  509. {
  510. const std::string current_partial = stop.substr(0, char_index + 1);
  511. if (ends_with(text, current_partial))
  512. {
  513. return text.size() - char_index - 1;
  514. }
  515. }
  516. }
  517. }
  518. return std::string::npos;
  519. }
  520. // TODO: reuse llama_detokenize
  521. template <class Iter>
  522. static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
  523. {
  524. std::string ret;
  525. for (; begin != end; ++begin)
  526. {
  527. ret += llama_token_to_piece(ctx, *begin);
  528. }
  529. return ret;
  530. }
  531. // format incomplete utf-8 multibyte character for output
  532. static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
  533. {
  534. std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
  535. // if the size is 1 and first bit is 1, meaning it's a partial character
  536. // (size > 1 meaning it's already a known token)
  537. if (out.size() == 1 && (out[0] & 0x80) == 0x80)
  538. {
  539. std::stringstream ss;
  540. ss << std::hex << (out[0] & 0xff);
  541. std::string res(ss.str());
  542. out = "byte: \\x" + res;
  543. }
  544. return out;
  545. }
  546. // convert a vector of completion_token_output to json
  547. static json probs_vector_to_json(const llama_context *ctx, const std::vector<completion_token_output> &probs)
  548. {
  549. json out = json::array();
  550. for (const auto &prob : probs)
  551. {
  552. json probs_for_token = json::array();
  553. for (const auto &p : prob.probs)
  554. {
  555. std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
  556. probs_for_token.push_back(json
  557. {
  558. {"tok_str", tok_str},
  559. {"prob", p.prob},
  560. });
  561. }
  562. std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
  563. out.push_back(json{
  564. {"content", tok_str},
  565. {"probs", probs_for_token},
  566. });
  567. }
  568. return out;
  569. }