@@ -1,20 +1,12 @@
 #include "llamaCPP.h"
 
-#include <trantor/utils/SerialTaskQueue.h>
-
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
 
 using namespace inferences;
 using json = nlohmann::json;
 
-/**
- * Queue to handle the inference task, this is to ensure that the inference
- * task is handled in a sequential manner
- */
-static trantor::SerialTaskQueue queue("worker");
-
 /**
  * The state of the inference task
  */
@@ -32,7 +24,6 @@ enum InferenceStatus {
  * associated with.
  */
 struct inferenceState {
-  bool is_stopped = false;
   int task_id;
   InferenceStatus inferenceStatus = PENDING;
   llamaCPP *instance;
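
The per-request inferenceState drops the is_stopped flag; progress is now tracked only through the InferenceStatus enum. A minimal sketch of the lifecycle this patch implements, assuming the enum has just the three values visible in this diff (PENDING, RUNNING, FINISHED); the advance() helper is hypothetical and exists only to document the transitions:

    #include <cassert>

    enum InferenceStatus { PENDING, RUNNING, FINISHED };

    // Hypothetical helper, not part of the patch: encodes the transitions the diff uses.
    inline InferenceStatus advance(InferenceStatus s) {
      switch (s) {
      case PENDING:
        return RUNNING;   // first call of the chunked_content_provider
      case RUNNING:
        return FINISHED;  // stop token, null buffer, zero-length read, or error
      default:
        return FINISHED;  // FINISHED is terminal: the provider then returns 0
      }
    }

    int main() {
      InferenceStatus s = PENDING;
      s = advance(s);
      assert(s == RUNNING);
      s = advance(s);
      assert(s == FINISHED);
      return 0;
    }
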
@@ -150,7 +141,7 @@ std::string create_return_json(const std::string &id, const std::string &model,
   return Json::writeString(writer, root);
 }
 
-llamaCPP::llamaCPP() {
+llamaCPP::llamaCPP(): queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP")) {
   // Some default values for now below
   log_disable(); // Disable the log to file feature, reduce bloat for
                  // target
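
The constructor change above replaces the single global SerialTaskQueue with a per-instance trantor::ConcurrentTaskQueue sized by llama.params.n_parallel, so several requests can be dispatched at once. A small standalone sketch of how that queue type is used; the worker count of 4 and the task bodies are illustrative assumptions, not taken from the patch:

    #include <trantor/utils/ConcurrentTaskQueue.h>

    #include <atomic>
    #include <chrono>
    #include <thread>

    int main() {
      // One named queue with four worker threads; up to four tasks run at once.
      trantor::ConcurrentTaskQueue queue(4, "llamaCPP");

      std::atomic<int> done{0};
      for (int i = 0; i < 8; ++i) {
        queue.runTaskInQueue([&done] { done.fetch_add(1); });
      }

      // runTaskInQueue only enqueues the closure; completion has to be observed
      // through shared state, which is what inferenceStatus is used for below.
      while (done.load() < 8) {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
      }
      return 0;
    }
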
@@ -341,18 +332,17 @@ void llamaCPP::inferenceImpl(
 
       if (state->inferenceStatus == PENDING) {
         state->inferenceStatus = RUNNING;
+      } else if (state->inferenceStatus == FINISHED) {
+        return 0;
       }
 
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
         state->instance->llama.request_cancel(state->task_id);
-        state->instance->single_queue_is_busy = false;
-        return 0;
-      }
-      if (state->is_stopped) {
-        state->instance->single_queue_is_busy = false;
+        state->inferenceStatus = FINISHED;
         return 0;
       }
+
 
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
@@ -377,31 +367,27 @@ void llamaCPP::inferenceImpl(
           std::size_t nRead = std::min(str.size(), nBuffSize);
           memcpy(pBuffer, str.data(), nRead);
           LOG_INFO << "reached result stop";
-          state->is_stopped = true;
           state->instance->llama.request_cancel(state->task_id);
-          state->instance->single_queue_is_busy = false;
+          state->inferenceStatus = FINISHED;
         }
 
         // Make sure nBufferSize is not zero
         // Otherwise it stop streaming
         if (!nRead) {
-          state->instance->single_queue_is_busy = false;
+          state->inferenceStatus = FINISHED;
         }
 
         return nRead;
       }
-      state->instance->single_queue_is_busy = false;
+      state->inferenceStatus = FINISHED;
       return 0;
     };
-
-    // Run task in serial queue
-    queue.runTaskInQueue([callback, state, data,
+    // Queued task
+    state->instance->queue->runTaskInQueue([callback, state, data,
                                             chunked_content_provider]() {
       state->task_id =
           state->instance->llama.request_completion(data, false, false, -1);
 
-      state->instance->single_queue_is_busy = true;
-
       // Start streaming response
       auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
                                                    "chat_completions.txt");
@@ -410,16 +396,14 @@ void llamaCPP::inferenceImpl(
       int retries = 0;
 
       // Since this is an async task, we will wait for the task to be completed
-      while (state->instance->single_queue_is_busy && retries < 10) {
+      while (state->inferenceStatus != FINISHED && retries < 10) {
         // Should wait chunked_content_provider lambda to be called within 3s
         if (state->inferenceStatus == PENDING) {
           retries += 1;
         }
         LOG_INFO << "Wait for task to be released:" << state->task_id;
         std::this_thread::sleep_for(std::chrono::milliseconds(300));
       }
-
-      state->inferenceStatus = FINISHED;
     });
     return;
   } else {
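
After queuing the streaming task, the handler no longer spins on a global single_queue_is_busy flag; it polls the per-request status until the chunked provider marks it FINISHED, counting retries only while the stream is still PENDING. A standalone sketch of that bounded-polling pattern with hypothetical names; the sketch uses std::atomic to keep the cross-thread example well-defined, whereas the patch writes the field directly:

    #include <atomic>
    #include <chrono>
    #include <thread>

    enum Status { PENDING, RUNNING, FINISHED };

    int main() {
      std::atomic<Status> status{PENDING};

      // Stand-in for the streaming callback, which flips the shared status.
      std::thread worker([&status] {
        status = RUNNING;   // first chunk produced
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        status = FINISHED;  // streaming done
      });

      int retries = 0;
      while (status != FINISHED && retries < 10) {
        if (status == PENDING) {
          retries += 1;     // only count retries while nothing has started yet
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(300));
      }

      worker.join();
      return status == FINISHED ? 0 : 1;
    }
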
@@ -466,59 +450,51 @@ void llamaCPP::embeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
 
-  Json::Value responseData(Json::arrayValue);
+  // Queue embedding task
   auto state = create_inference_state(this);
-  if (jsonBody->isMember("input")) {
-    // If single queue is busy, we will wait if not we will just go ahead and
-    // process and make it busy, and yet i'm aware not DRY, i have the same
-    // stuff on chatcompletion as well
-    if (state->instance->llama.params.n_parallel == 1) {
-      while (state->instance->single_queue_is_busy) {
-        LOG_INFO << "Waiting for task to be released status:"
-                 << state->instance->single_queue_is_busy;
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(500)); // Waiting in 500 miliseconds step
-      }
-    }
-    const Json::Value &input = (*jsonBody)["input"];
-    if (input.isString()) {
-      // Process the single string input
-      state->task_id = llama.request_completion(
-          {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
-      state->instance->single_queue_is_busy = true;
-      task_result result = llama.next_result(state->task_id);
-      std::vector<float> embedding_result = result.result_json["embedding"];
-      responseData.append(create_embedding_payload(embedding_result, 0));
-    } else if (input.isArray()) {
-      // Process each element in the array input
-      for (const auto &elem : input) {
-        if (elem.isString()) {
-          const int task_id = llama.request_completion(
-              {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
-          task_result result = llama.next_result(task_id);
-          std::vector<float> embedding_result = result.result_json["embedding"];
-          responseData.append(create_embedding_payload(embedding_result, 0));
+
+  state->instance->queue->runTaskInQueue([this, state, jsonBody, callback]() {
+    Json::Value responseData(Json::arrayValue);
+
+    if (jsonBody->isMember("input")) {
+      const Json::Value &input = (*jsonBody)["input"];
+      if (input.isString()) {
+        // Process the single string input
+        state->task_id = llama.request_completion(
+            {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
+        task_result result = llama.next_result(state->task_id);
+        std::vector<float> embedding_result = result.result_json["embedding"];
+        responseData.append(create_embedding_payload(embedding_result, 0));
+      } else if (input.isArray()) {
+        // Process each element in the array input
+        for (const auto &elem : input) {
+          if (elem.isString()) {
+            const int task_id = llama.request_completion(
+                {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true,
+                -1);
+            task_result result = llama.next_result(task_id);
+            std::vector<float> embedding_result =
+                result.result_json["embedding"];
+            responseData.append(create_embedding_payload(embedding_result, 0));
+          }
         }
       }
     }
-  }
-
-  // We already got result of the embedding so no longer busy
-  state->instance->single_queue_is_busy = false;
 
-  auto resp = nitro_utils::nitroHttpResponse();
-  Json::Value root;
-  root["data"] = responseData;
-  root["model"] = "_";
-  root["object"] = "list";
-  Json::Value usage;
-  usage["prompt_tokens"] = 0;
-  usage["total_tokens"] = 0;
-  root["usage"] = usage;
-
-  resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
-  resp->setContentTypeString("application/json");
-  callback(resp);
+    auto resp = nitro_utils::nitroHttpResponse();
+    Json::Value root;
+    root["data"] = responseData;
+    root["model"] = "_";
+    root["object"] = "list";
+    Json::Value usage;
+    usage["prompt_tokens"] = 0;
+    usage["total_tokens"] = 0;
+    root["usage"] = usage;
+
+    resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
+    resp->setContentTypeString("application/json");
+    callback(resp);
+  });
 }
 
 void llamaCPP::unloadModel(
@@ -539,6 +515,7 @@ void llamaCPP::unloadModel(
   callback(resp);
   return;
 }
+
 void llamaCPP::modelStatus(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -555,6 +532,7 @@ void llamaCPP::modelStatus(
   callback(resp);
   return;
 }
+
 void llamaCPP::loadModel(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -674,6 +652,12 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   }
   llama.initialize();
 
+  if (queue != nullptr) {
+    delete queue;
+  }
+
+  queue = new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP");
+
   llama.model_loaded_external = true;
 
   LOG_INFO << "Started background task here!";
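
loadModelImpl now discards any existing queue and rebuilds it so the worker count follows the freshly loaded model's n_parallel. As a design note only, since the patch keeps a raw trantor::ConcurrentTaskQueue* and deletes it by hand, the same reset could be expressed with std::unique_ptr; a sketch under that assumption, relying on trantor's queue destructor stopping its workers:

    #include <trantor/utils/ConcurrentTaskQueue.h>

    #include <memory>

    // Hypothetical alternative, not what the patch does: own the queue with a
    // unique_ptr so reloading a model swaps it without a manual delete.
    struct QueueHolder {
      std::unique_ptr<trantor::ConcurrentTaskQueue> queue;

      void resetQueue(int n_parallel) {
        // The previous queue, if any, is destroyed when the assignment replaces it.
        queue = std::make_unique<trantor::ConcurrentTaskQueue>(n_parallel, "llamaCPP");
      }
    };

    int main() {
      QueueHolder holder;
      holder.resetQueue(4); // e.g. after loading a model with n_parallel = 4
      holder.resetQueue(2); // reload: the old workers go away with the old queue
      return 0;
    }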