This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit fb7bc74

Merge pull request #433 from janhq/refactor/simplify-state-with-queue-system
refactor: simplify state with queued system
2 parents (9c1d8b6 + 5a3432f) · commit fb7bc74
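
For context, this diff drops the file-scope static trantor::SerialTaskQueue and the single_queue_is_busy / is_stopped flags in favour of a per-controller trantor::ConcurrentTaskQueue (sized from llama.params.n_parallel) plus a single InferenceStatus value per request. A minimal standalone sketch of that pattern follows: trantor::ConcurrentTaskQueue and runTaskInQueue are the real Trantor API used by the diff, while Status, RequestState and the literal values are illustrative stand-ins, not code from this PR.

#include <trantor/utils/ConcurrentTaskQueue.h>

#include <atomic>
#include <future>
#include <iostream>
#include <memory>

// Illustrative stand-in for the PR's per-request state: one status value
// replaces the old is_stopped / single_queue_is_busy booleans.
enum class Status { Pending, Running, Finished };

struct RequestState {
  std::atomic<Status> status{Status::Pending};  // atomic only for this sketch
  int task_id = -1;
};

int main() {
  // One queue per controller, sized the way the PR sizes it from n_parallel
  // (4 is an arbitrary value for this sketch).
  trantor::ConcurrentTaskQueue queue(4, "llamaCPP");

  auto state = std::make_shared<RequestState>();
  std::promise<void> done;

  // Work is submitted to the queue instead of toggling a global busy flag.
  queue.runTaskInQueue([state, &done] {
    state->status = Status::Running;
    state->task_id = 42;  // placeholder for llama.request_completion(...)
    state->status = Status::Finished;
    done.set_value();
  });

  done.get_future().wait();  // keep the queue alive until the task finishes
  std::cout << "task " << state->task_id << " finished\n";
  return 0;
}

The point of the change is that ordering and back-pressure come from the queue itself, so callers no longer poll a shared busy flag.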

2 files changed (+65, -77 lines): controllers/llamaCPP.cc, controllers/llamaCPP.h

controllers/llamaCPP.cc

Lines changed: 59 additions & 75 deletions
@@ -1,20 +1,12 @@
 #include "llamaCPP.h"
 
-#include <trantor/utils/SerialTaskQueue.h>
-
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
 
 using namespace inferences;
 using json = nlohmann::json;
 
-/**
- * Queue to handle the inference task, this is to ensure that the inference
- * task is handled in a sequential manner
- */
-static trantor::SerialTaskQueue queue("worker");
-
 /**
  * The state of the inference task
  */
@@ -32,7 +24,6 @@ enum InferenceStatus {
  * associated with.
  */
 struct inferenceState {
-  bool is_stopped = false;
   int task_id;
   InferenceStatus inferenceStatus = PENDING;
   llamaCPP *instance;
@@ -150,7 +141,7 @@ std::string create_return_json(const std::string &id, const std::string &model,
   return Json::writeString(writer, root);
 }
 
-llamaCPP::llamaCPP() {
+llamaCPP::llamaCPP(): queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP")) {
   // Some default values for now below
   log_disable(); // Disable the log to file feature, reduce bloat for
                  // target
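
A note on the new constructor line above: the queue is built in the member-initializer list from llama.params.n_parallel, that is, from another member, which relies on the usual C++ rule that members are initialized in declaration order (not initializer-list order). A small illustration of that rule with hypothetical Params/Engine/Controller types and a unique_ptr instead of the PR's raw pointer:

#include <trantor/utils/ConcurrentTaskQueue.h>

#include <memory>

// Hypothetical types for illustration only.
struct Params { int n_parallel = 4; };
struct Engine { Params params; };

class Controller {
 public:
  // Mirrors the shape of the PR's initializer: the queue's thread count comes
  // from a sibling member. Safe only because 'engine' is declared, and thus
  // initialized, before 'queue' below.
  Controller()
      : queue(std::make_unique<trantor::ConcurrentTaskQueue>(
            engine.params.n_parallel, "llamaCPP")) {}

 private:
  Engine engine;  // declared first, initialized first
  std::unique_ptr<trantor::ConcurrentTaskQueue> queue;  // sketch-only ownership
};

int main() { Controller c; return 0; }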
@@ -341,18 +332,17 @@ void llamaCPP::inferenceImpl(
 
       if(state->inferenceStatus == PENDING) {
         state->inferenceStatus = RUNNING;
+      } else if (state->inferenceStatus == FINISHED) {
+        return 0;
       }
 
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
         state->instance->llama.request_cancel(state->task_id);
-        state->instance->single_queue_is_busy = false;
-        return 0;
-      }
-      if (state->is_stopped) {
-        state->instance->single_queue_is_busy = false;
+        state->inferenceStatus = FINISHED;
         return 0;
       }
+
 
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
@@ -377,31 +367,27 @@ void llamaCPP::inferenceImpl(
           std::size_t nRead = std::min(str.size(), nBuffSize);
           memcpy(pBuffer, str.data(), nRead);
           LOG_INFO << "reached result stop";
-          state->is_stopped = true;
           state->instance->llama.request_cancel(state->task_id);
-          state->instance->single_queue_is_busy = false;
+          state->inferenceStatus = FINISHED;
         }
 
         // Make sure nBufferSize is not zero
         // Otherwise it stop streaming
         if(!nRead) {
-          state->instance->single_queue_is_busy = false;
+          state->inferenceStatus = FINISHED;
         }
 
         return nRead;
       }
-      state->instance->single_queue_is_busy = false;
+      state->inferenceStatus = FINISHED;
       return 0;
     };
-
-    // Run task in serial queue
-    queue.runTaskInQueue([callback, state, data,
+    // Queued task
+    state->instance->queue->runTaskInQueue([callback, state, data,
                          chunked_content_provider]() {
      state->task_id =
          state->instance->llama.request_completion(data, false, false, -1);
 
-     state->instance->single_queue_is_busy = true;
-
      // Start streaming response
      auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
                                                   "chat_completions.txt");
@@ -410,16 +396,14 @@ void llamaCPP::inferenceImpl(
      int retries = 0;
 
      // Since this is an async task, we will wait for the task to be completed
-     while (state->instance->single_queue_is_busy && retries < 10) {
+     while (state->inferenceStatus != FINISHED && retries < 10) {
        // Should wait chunked_content_provider lambda to be called within 3s
        if(state->inferenceStatus == PENDING) {
          retries += 1;
        }
        LOG_INFO << "Wait for task to be released:" << state->task_id;
        std::this_thread::sleep_for(std::chrono::milliseconds(300));
      }
-
-     state->inferenceStatus = FINISHED;
    });
    return;
  } else {
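
After queuing the request, the task above waits on the per-request status instead of the old busy flag: only PENDING iterations consume one of the 10 retries (roughly 10 * 300 ms = 3 s for the stream to start), and once the provider is RUNNING the loop simply waits for FINISHED. A standalone sketch of that wait pattern follows; the names and timings mirror the diff, but the std::atomic status and the producer thread are illustrative additions, not the PR's code.

#include <atomic>
#include <chrono>
#include <thread>

enum class Status { Pending, Running, Finished };

// Bounded wait: Pending iterations count toward the retry limit, Running ones
// do not, so the loop gives the stream a few seconds to start and then waits
// for it to finish.
void wait_for_stream(const std::atomic<Status> &status) {
  int retries = 0;
  while (status != Status::Finished && retries < 10) {
    if (status == Status::Pending) {
      retries += 1;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(300));
  }
}

int main() {
  std::atomic<Status> status{Status::Pending};
  std::thread producer([&status] {
    std::this_thread::sleep_for(std::chrono::milliseconds(600));
    status = Status::Running;   // stream started
    std::this_thread::sleep_for(std::chrono::milliseconds(600));
    status = Status::Finished;  // stream done
  });
  wait_for_stream(status);
  producer.join();
  return 0;
}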
@@ -466,59 +450,51 @@ void llamaCPP::embeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
 
-  Json::Value responseData(Json::arrayValue);
+  // Queue embedding task
   auto state = create_inference_state(this);
-  if (jsonBody->isMember("input")) {
-    // If single queue is busy, we will wait if not we will just go ahead and
-    // process and make it busy, and yet i'm aware not DRY, i have the same
-    // stuff on chatcompletion as well
-    if (state->instance->llama.params.n_parallel == 1) {
-      while (state->instance->single_queue_is_busy) {
-        LOG_INFO << "Waiting for task to be released status:"
-                 << state->instance->single_queue_is_busy;
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(500)); // Waiting in 500 miliseconds step
-      }
-    }
-    const Json::Value &input = (*jsonBody)["input"];
-    if (input.isString()) {
-      // Process the single string input
-      state->task_id = llama.request_completion(
-          {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
-      state->instance->single_queue_is_busy = true;
-      task_result result = llama.next_result(state->task_id);
-      std::vector<float> embedding_result = result.result_json["embedding"];
-      responseData.append(create_embedding_payload(embedding_result, 0));
-    } else if (input.isArray()) {
-      // Process each element in the array input
-      for (const auto &elem : input) {
-        if (elem.isString()) {
-          const int task_id = llama.request_completion(
-              {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
-          task_result result = llama.next_result(task_id);
-          std::vector<float> embedding_result = result.result_json["embedding"];
-          responseData.append(create_embedding_payload(embedding_result, 0));
+
+  state->instance->queue->runTaskInQueue([this, state, jsonBody, callback]() {
+    Json::Value responseData(Json::arrayValue);
+
+    if (jsonBody->isMember("input")) {
+      const Json::Value &input = (*jsonBody)["input"];
+      if (input.isString()) {
+        // Process the single string input
+        state->task_id = llama.request_completion(
+            {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
+        task_result result = llama.next_result(state->task_id);
+        std::vector<float> embedding_result = result.result_json["embedding"];
+        responseData.append(create_embedding_payload(embedding_result, 0));
+      } else if (input.isArray()) {
+        // Process each element in the array input
+        for (const auto &elem : input) {
+          if (elem.isString()) {
+            const int task_id = llama.request_completion(
+                {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true,
+                -1);
+            task_result result = llama.next_result(task_id);
+            std::vector<float> embedding_result =
+                result.result_json["embedding"];
+            responseData.append(create_embedding_payload(embedding_result, 0));
+          }
        }
      }
    }
-  }
-
-  // We already got result of the embedding so no longer busy
-  state->instance->single_queue_is_busy = false;
 
-  auto resp = nitro_utils::nitroHttpResponse();
-  Json::Value root;
-  root["data"] = responseData;
-  root["model"] = "_";
-  root["object"] = "list";
-  Json::Value usage;
-  usage["prompt_tokens"] = 0;
-  usage["total_tokens"] = 0;
-  root["usage"] = usage;
-
-  resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
-  resp->setContentTypeString("application/json");
-  callback(resp);
+    auto resp = nitro_utils::nitroHttpResponse();
+    Json::Value root;
+    root["data"] = responseData;
+    root["model"] = "_";
+    root["object"] = "list";
+    Json::Value usage;
+    usage["prompt_tokens"] = 0;
+    usage["total_tokens"] = 0;
+    root["usage"] = usage;
+
+    resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
+    resp->setContentTypeString("application/json");
+    callback(resp);
+  });
 }
 
 void llamaCPP::unloadModel(
@@ -539,6 +515,7 @@ void llamaCPP::unloadModel(
   callback(resp);
   return;
 }
+
 void llamaCPP::modelStatus(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -555,6 +532,7 @@ void llamaCPP::modelStatus(
   callback(resp);
   return;
 }
+
 void llamaCPP::loadModel(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -674,6 +652,12 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   }
   llama.initialize();
 
+  if (queue != nullptr) {
+    delete queue;
+  }
+
+  queue = new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP");
+
   llama.model_loaded_external = true;
 
   LOG_INFO << "Started background task here!";
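
Finally, loadModelImpl now deletes and re-creates the queue after llama.initialize(), so the worker count follows the n_parallel of the model that was just loaded. A hedged sketch of that delete-and-recreate step with a hypothetical Controller type; the destructor and the deleted copy operations are added for the sketch only and are not part of the diff.

#include <trantor/utils/ConcurrentTaskQueue.h>

// Hypothetical controller mirroring the PR's loadModelImpl step: the queue is
// rebuilt on every (re)load so its thread count follows n_parallel.
class Controller {
 public:
  Controller() = default;
  Controller(const Controller &) = delete;
  Controller &operator=(const Controller &) = delete;
  ~Controller() { delete queue; }  // sketch-only cleanup

  void loadModel(int n_parallel) {
    if (queue != nullptr) {
      delete queue;  // drop the old worker pool before resizing
    }
    queue = new trantor::ConcurrentTaskQueue(n_parallel, "llamaCPP");
  }

 private:
  trantor::ConcurrentTaskQueue *queue = nullptr;
};

int main() {
  Controller c;
  c.loadModel(4);  // first load, illustrative parallelism
  c.loadModel(2);  // reloading with a different n_parallel rebuilds the queue
  return 0;
}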

controllers/llamaCPP.h

Lines changed: 6 additions & 2 deletions
@@ -26,6 +26,7 @@
 
 #include "common/base.h"
 #include "utils/json.hpp"
+#include <trantor/utils/ConcurrentTaskQueue.h>
 
 // auto generated files (update with ./deps.sh)
 
@@ -2562,10 +2563,13 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
   bool caching_enabled;
   std::atomic<int> no_of_chats = 0;
   int clean_cache_threshold;
-  std::atomic<bool> single_queue_is_busy; // This value only used under the
-                                          // condition n_parallel is 1
   std::string grammar_file_content;
 
+  /**
+   * Queue to handle the inference tasks
+   */
+  trantor::ConcurrentTaskQueue *queue;
+
   bool loadModelImpl(std::shared_ptr<Json::Value> jsonBody);
   void inferenceImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr &)> &callback);
