using namespace inferences;
using json = nlohmann::json;

+ /**
+  * Holds the state of an ongoing inference for a handler, so its status
+  * (stopped/streaming) can be tracked across asynchronous callbacks.
+  *
+  * @param inst Pointer to the llamaCPP instance this inference task is
+  * associated with.
+  */
struct inferenceState {
  bool is_stopped = false;
  bool is_streaming = false;
@@ -15,15 +22,20 @@ struct inferenceState {
  inferenceState(llamaCPP *inst) : instance(inst) {}
};

+ /**
+  * Creates a shared pointer to an inferenceState so that the state persists
+  * even after the streaming lambda goes out of scope and the handler has
+  * moved on.
+  */
std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
  return std::make_shared<inferenceState>(instance);
}

- // --------------------------------------------
-
- // Function to check if the model is loaded
- void check_model_loaded(
-     llama_server_context &llama, const HttpRequestPtr &req,
+ /**
+  * Checks whether a model is already loaded; if not, returns a message to
+  * the user.
+  * @param callback the function used to return the message to the user
+  */
+ void llamaCPP::checkModelLoaded(
    std::function<void(const HttpResponsePtr &)> &callback) {
  if (!llama.model_loaded_external) {
    Json::Value jsonResp;
@@ -136,7 +148,7 @@ void llamaCPP::warmupModel() {
  return;
}

- void llamaCPP::chatCompletionPrelight(
+ void llamaCPP::handlePrelight(
    const HttpRequestPtr &req,
    std::function<void(const HttpResponsePtr &)> &&callback) {
  auto resp = drogon::HttpResponse::newHttpResponse();
@@ -151,10 +163,17 @@ void llamaCPP::chatCompletion(
    const HttpRequestPtr &req,
    std::function<void(const HttpResponsePtr &)> &&callback) {

+   const auto &jsonBody = req->getJsonObject();
  // Check if model is loaded
-   check_model_loaded(llama, req, callback);
+   checkModelLoaded(callback);
+
+   chatCompletionImpl(jsonBody, callback);
+ }
+
+ void llamaCPP::chatCompletionImpl(
+     std::shared_ptr<Json::Value> jsonBody,
+     std::function<void(const HttpResponsePtr &)> &callback) {

-   const auto &jsonBody = req->getJsonObject();
  std::string formatted_output = pre_prompt;

  json data;
@@ -402,17 +421,23 @@ void llamaCPP::chatCompletion(
    }
  }
}
+
void llamaCPP::embedding(
    const HttpRequestPtr &req,
    std::function<void(const HttpResponsePtr &)> &&callback) {
-   check_model_loaded(llama, req, callback);
+   checkModelLoaded(callback);
+   const auto &jsonBody = req->getJsonObject();

-   auto state = create_inference_state(this);
+   embeddingImpl(jsonBody, callback);
+   return;
+ }

-   const auto &jsonBody = req->getJsonObject();
+ void llamaCPP::embeddingImpl(
+     std::shared_ptr<Json::Value> jsonBody,
+     std::function<void(const HttpResponsePtr &)> &callback) {

  Json::Value responseData(Json::arrayValue);
-
+   auto state = create_inference_state(this);
  if (jsonBody->isMember("input")) {
    // If single queue is busy, we will wait if not we will just go ahead and
    // process and make it busy, and yet i'm aware not DRY, i have the same
@@ -464,7 +489,6 @@ void llamaCPP::embedding(
  resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
  resp->setContentTypeString("application/json");
  callback(resp);
-   return;
}

void llamaCPP::unloadModel(
@@ -501,31 +525,61 @@ void llamaCPP::modelStatus(
  callback(resp);
  return;
}
+ void llamaCPP::loadModel(
+     const HttpRequestPtr &req,
+     std::function<void(const HttpResponsePtr &)> &&callback) {

- bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
+   if (llama.model_loaded_external) {
+     LOG_INFO << "model loaded";
+     Json::Value jsonResp;
+     jsonResp["message"] = "Model already loaded";
+     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+     resp->setStatusCode(drogon::k409Conflict);
+     callback(resp);
+     return;
+   }

-   gpt_params params;
+   const auto &jsonBody = req->getJsonObject();
+   if (!loadModelImpl(jsonBody)) {
+     // Error occurred during model loading
+     Json::Value jsonResp;
+     jsonResp["message"] = "Failed to load model";
+     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+     resp->setStatusCode(drogon::k500InternalServerError);
+     callback(resp);
+   } else {
+     // Model loaded successfully
+     Json::Value jsonResp;
+     jsonResp["message"] = "Model loaded successfully";
+     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+     callback(resp);
+   }
+ }
+
+ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {

+   gpt_params params;
  // By default will setting based on number of handlers
  if (jsonBody) {
-     if (!jsonBody["mmproj"].isNull()) {
+     if (!jsonBody->operator[]("mmproj").isNull()) {
      LOG_INFO << "MMPROJ FILE detected, multi-model enabled!";
-       params.mmproj = jsonBody["mmproj"].asString();
+       params.mmproj = jsonBody->operator[]("mmproj").asString();
    }
-     if (!jsonBody["grp_attn_n"].isNull()) {
+     if (!jsonBody->operator[]("grp_attn_n").isNull()) {

-       params.grp_attn_n = jsonBody["grp_attn_n"].asInt();
+       params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
    }
-     if (!jsonBody["grp_attn_w"].isNull()) {
+     if (!jsonBody->operator[]("grp_attn_w").isNull()) {

-       params.grp_attn_w = jsonBody["grp_attn_w"].asInt();
+       params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
    }
-     if (!jsonBody["mlock"].isNull()) {
-       params.use_mlock = jsonBody["mlock"].asBool();
+     if (!jsonBody->operator[]("mlock").isNull()) {
+       params.use_mlock = jsonBody->operator[]("mlock").asBool();
    }

-     if (!jsonBody["grammar_file"].isNull()) {
-       std::string grammar_file = jsonBody["grammar_file"].asString();
+     if (!jsonBody->operator[]("grammar_file").isNull()) {
+       std::string grammar_file =
+           jsonBody->operator[]("grammar_file").asString();
      std::ifstream file(grammar_file);
      if (!file) {
        LOG_ERROR << "Grammar file not found";
@@ -536,30 +590,31 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
      }
    };

-     params.model = jsonBody["llama_model_path"].asString();
-     params.n_gpu_layers = jsonBody.get("ngl", 100).asInt();
-     params.n_ctx = jsonBody.get("ctx_len", 2048).asInt();
-     params.embedding = jsonBody.get("embedding", true).asBool();
+     params.model = jsonBody->operator[]("llama_model_path").asString();
+     params.n_gpu_layers = jsonBody->get("ngl", 100).asInt();
+     params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
+     params.embedding = jsonBody->get("embedding", true).asBool();
    // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
-     params.n_batch = jsonBody.get("n_batch", 512).asInt();
-     params.n_parallel = jsonBody.get("n_parallel", 1).asInt();
+     params.n_batch = jsonBody->get("n_batch", 512).asInt();
+     params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
    params.n_threads =
-         jsonBody.get("cpu_threads", std::thread::hardware_concurrency())
+         jsonBody->get("cpu_threads", std::thread::hardware_concurrency())
            .asInt();
-     params.cont_batching = jsonBody.get("cont_batching", false).asBool();
+     params.cont_batching = jsonBody->get("cont_batching", false).asBool();
    this->clean_cache_threshold =
-         jsonBody.get("clean_cache_threshold", 5).asInt();
-     this->caching_enabled = jsonBody.get("caching_enabled", false).asBool();
-     this->user_prompt = jsonBody.get("user_prompt", "USER: ").asString();
-     this->ai_prompt = jsonBody.get("ai_prompt", "ASSISTANT: ").asString();
+         jsonBody->get("clean_cache_threshold", 5).asInt();
+     this->caching_enabled = jsonBody->get("caching_enabled", false).asBool();
+     this->user_prompt = jsonBody->get("user_prompt", "USER: ").asString();
+     this->ai_prompt = jsonBody->get("ai_prompt", "ASSISTANT: ").asString();
    this->system_prompt =
-         jsonBody.get("system_prompt", "ASSISTANT's RULE: ").asString();
-     this->pre_prompt = jsonBody.get("pre_prompt", "").asString();
-     this->repeat_last_n = jsonBody.get("repeat_last_n", 32).asInt();
+         jsonBody->get("system_prompt", "ASSISTANT's RULE: ").asString();
+     this->pre_prompt = jsonBody->get("pre_prompt", "").asString();
+     this->repeat_last_n = jsonBody->get("repeat_last_n", 32).asInt();

-     if (!jsonBody["llama_log_folder"].isNull()) {
+     if (!jsonBody->operator[]("llama_log_folder").isNull()) {
      log_enable();
-       std::string llama_log_folder = jsonBody["llama_log_folder"].asString();
+       std::string llama_log_folder =
+           jsonBody->operator[]("llama_log_folder").asString();
      log_set_target(llama_log_folder + "llama.log");
    } // Set folder for llama log
  }
@@ -597,37 +652,6 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
  return true;
}

- void llamaCPP::loadModel(
-     const HttpRequestPtr &req,
-     std::function<void(const HttpResponsePtr &)> &&callback) {
-
-   if (llama.model_loaded_external) {
-     LOG_INFO << "model loaded";
-     Json::Value jsonResp;
-     jsonResp["message"] = "Model already loaded";
-     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-     resp->setStatusCode(drogon::k409Conflict);
-     callback(resp);
-     return;
-   }
-
-   const auto &jsonBody = req->getJsonObject();
-   if (!loadModelImpl(*jsonBody)) {
-     // Error occurred during model loading
-     Json::Value jsonResp;
-     jsonResp["message"] = "Failed to load model";
-     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-     resp->setStatusCode(drogon::k500InternalServerError);
-     callback(resp);
-   } else {
-     // Model loaded successfully
-     Json::Value jsonResp;
-     jsonResp["message"] = "Model loaded successfully";
-     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-     callback(resp);
-   }
- }
-
void llamaCPP::backgroundTask() {
  while (llama.model_loaded_external) {
    // model_loaded =
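
For illustration only (not part of the commit above): the new create_inference_state helper keeps per-request state alive through std::shared_ptr reference counting, so the streaming callback can still read it after the handler has returned. A minimal sketch of that lifetime pattern, with the llamaCPP pointer swapped for a plain int so it compiles on its own:

#include <functional>
#include <iostream>
#include <memory>
#include <thread>

struct DemoState {
  bool is_stopped = false;
  int task_id;  // stands in for the llamaCPP *instance member
  explicit DemoState(int id) : task_id(id) {}
};

std::shared_ptr<DemoState> create_demo_state(int id) {
  return std::make_shared<DemoState>(id);
}

std::function<void()> make_stream_callback() {
  auto state = create_demo_state(42);
  // The lambda copies the shared_ptr, so DemoState outlives this function,
  // just as inferenceState outlives the handler that created it.
  return [state]() { std::cout << "task " << state->task_id << " alive\n"; };
}

int main() {
  auto cb = make_stream_callback();  // creating scope has already ended
  std::thread t(cb);                 // the callback still sees valid state
  t.join();
  return 0;
}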
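
Also for illustration (again not from the commit; the values are invented): loadModelImpl now receives the std::shared_ptr<Json::Value> that drogon's req->getJsonObject() returns, which is why field access is written as jsonBody->operator[](...) and defaults come from jsonBody->get(key, default). A rough standalone sketch of those two jsoncpp access styles:

#include <json/json.h>

#include <iostream>
#include <memory>

int main() {
  auto jsonBody = std::make_shared<Json::Value>();
  (*jsonBody)["ctx_len"] = 4096;

  // Member access through the shared_ptr, as in the diff, and the
  // equivalent dereferenced spelling.
  int ctx_a = jsonBody->operator[]("ctx_len").asInt();
  int ctx_b = (*jsonBody)["ctx_len"].asInt();

  // get() falls back to the default when the key is missing ("ngl" here).
  int ngl = jsonBody->get("ngl", 100).asInt();

  std::cout << ctx_a << " " << ctx_b << " " << ngl << "\n";
  return 0;
}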