This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit a305d60

Merge pull request #419 from janhq/417-feat-refactor-some-parts-of-the-code
417 feat refactor some parts of the code
2 parents 8f6d281 + 96deb0e commit a305d60

File tree: 2 files changed, +112 -87 lines


controllers/llamaCPP.cc

Lines changed: 97 additions & 73 deletions
@@ -6,6 +6,13 @@
 using namespace inferences;
 using json = nlohmann::json;
 
+/**
+ * There is a need to save state of current ongoing inference status of a
+ * handler, this struct is to solve that issue
+ *
+ * @param inst Pointer to the llamaCPP instance this inference task is
+ * associated with.
+ */
 struct inferenceState {
   bool is_stopped = false;
   bool is_streaming = false;
@@ -15,15 +22,20 @@ struct inferenceState {
   inferenceState(llamaCPP *inst) : instance(inst) {}
 };
 
+/**
+ * This function is to create the smart pointer to inferenceState, hence the
+ * inferenceState will be persisting even tho the lambda in streaming might go
+ * out of scope and the handler already moved on
+ */
 std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
   return std::make_shared<inferenceState>(instance);
 }
 
-// --------------------------------------------
-
-// Function to check if the model is loaded
-void check_model_loaded(
-    llama_server_context &llama, const HttpRequestPtr &req,
+/**
+ * Check if model already loaded if not return message to user
+ * @param callback the function to return message to user
+ */
+void llamaCPP::checkModelLoaded(
     std::function<void(const HttpResponsePtr &)> &callback) {
   if (!llama.model_loaded_external) {
     Json::Value jsonResp;
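
The new doc comment above explains why create_inference_state returns a std::shared_ptr: the streaming lambda captures the pointer and co-owns the state, so the state survives after the handler itself has returned. A minimal standalone sketch of that ownership pattern (demo types only, not the project's code):

#include <functional>
#include <iostream>
#include <memory>

struct demo_state {        // stand-in for inferenceState
  bool is_stopped = false;
  int chunks_sent = 0;
};

std::function<void()> start_streaming() {
  auto state = std::make_shared<demo_state>();
  // The lambda copies the shared_ptr; the state stays alive even after this
  // function (the "handler") has already returned.
  return [state]() {
    if (!state->is_stopped) {
      std::cout << "chunk " << ++state->chunks_sent << '\n';
    }
  };
}

int main() {
  auto send_next_chunk = start_streaming();  // handler has moved on
  send_next_chunk();                         // state is still valid here
  send_next_chunk();
}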
@@ -136,7 +148,7 @@ void llamaCPP::warmupModel() {
   return;
 }
 
-void llamaCPP::chatCompletionPrelight(
+void llamaCPP::handlePrelight(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
   auto resp = drogon::HttpResponse::newHttpResponse();
@@ -151,10 +163,17 @@ void llamaCPP::chatCompletion(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
 
+  const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
-  check_model_loaded(llama, req, callback);
+  checkModelLoaded(callback);
+
+  chatCompletionImpl(jsonBody, callback);
+}
+
+void llamaCPP::chatCompletionImpl(
+    std::shared_ptr<Json::Value> jsonBody,
+    std::function<void(const HttpResponsePtr &)> &callback) {
 
-  const auto &jsonBody = req->getJsonObject();
   std::string formatted_output = pre_prompt;
 
   json data;
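
For context on the new signatures: Drogon's HttpRequest::getJsonObject() hands back a std::shared_ptr<Json::Value>, so the thin handler can forward that pointer straight into chatCompletionImpl and the parsed body stays alive for as long as any copy of the pointer does. A hedged, standalone sketch of that hand-off (function name and field below are illustrative, not from the repository):

#include <json/json.h>
#include <iostream>
#include <memory>

// Illustrative stand-in for llamaCPP::chatCompletionImpl: taking the
// shared_ptr by value copies it, so this function co-owns the JSON body.
void chat_completion_impl_demo(std::shared_ptr<Json::Value> jsonBody) {
  std::cout << "stream: " << jsonBody->get("stream", false).asBool() << '\n';
}

int main() {
  // Stand-in for req->getJsonObject().
  auto jsonBody = std::make_shared<Json::Value>();
  (*jsonBody)["stream"] = true;
  chat_completion_impl_demo(jsonBody);  // use_count briefly becomes 2
}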
@@ -402,17 +421,23 @@ void llamaCPP::chatCompletion(
     }
   }
 }
+
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-  check_model_loaded(llama, req, callback);
+  checkModelLoaded(callback);
+  const auto &jsonBody = req->getJsonObject();
 
-  auto state = create_inference_state(this);
+  embeddingImpl(jsonBody, callback);
+  return;
+}
 
-  const auto &jsonBody = req->getJsonObject();
+void llamaCPP::embeddingImpl(
+    std::shared_ptr<Json::Value> jsonBody,
+    std::function<void(const HttpResponsePtr &)> &callback) {
 
   Json::Value responseData(Json::arrayValue);
-
+  auto state = create_inference_state(this);
   if (jsonBody->isMember("input")) {
     // If single queue is busy, we will wait if not we will just go ahead and
     // process and make it busy, and yet i'm aware not DRY, i have the same
@@ -464,7 +489,6 @@ void llamaCPP::embedding(
   resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
   resp->setContentTypeString("application/json");
   callback(resp);
-  return;
 }
 
 void llamaCPP::unloadModel(
@@ -501,31 +525,61 @@ void llamaCPP::modelStatus(
   callback(resp);
   return;
 }
+void llamaCPP::loadModel(
+    const HttpRequestPtr &req,
+    std::function<void(const HttpResponsePtr &)> &&callback) {
 
-bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
+  if (llama.model_loaded_external) {
+    LOG_INFO << "model loaded";
+    Json::Value jsonResp;
+    jsonResp["message"] = "Model already loaded";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k409Conflict);
+    callback(resp);
+    return;
+  }
 
-  gpt_params params;
+  const auto &jsonBody = req->getJsonObject();
+  if (!loadModelImpl(jsonBody)) {
+    // Error occurred during model loading
+    Json::Value jsonResp;
+    jsonResp["message"] = "Failed to load model";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k500InternalServerError);
+    callback(resp);
+  } else {
+    // Model loaded successfully
+    Json::Value jsonResp;
+    jsonResp["message"] = "Model loaded successfully";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    callback(resp);
+  }
+}
+
+bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
 
+  gpt_params params;
   // By default will setting based on number of handlers
   if (jsonBody) {
-    if (!jsonBody["mmproj"].isNull()) {
+    if (!jsonBody->operator[]("mmproj").isNull()) {
       LOG_INFO << "MMPROJ FILE detected, multi-model enabled!";
-      params.mmproj = jsonBody["mmproj"].asString();
+      params.mmproj = jsonBody->operator[]("mmproj").asString();
     }
-    if (!jsonBody["grp_attn_n"].isNull()) {
+    if (!jsonBody->operator[]("grp_attn_n").isNull()) {
 
-      params.grp_attn_n = jsonBody["grp_attn_n"].asInt();
+      params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
     }
-    if (!jsonBody["grp_attn_w"].isNull()) {
+    if (!jsonBody->operator[]("grp_attn_w").isNull()) {
 
-      params.grp_attn_w = jsonBody["grp_attn_w"].asInt();
+      params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
     }
-    if (!jsonBody["mlock"].isNull()) {
-      params.use_mlock = jsonBody["mlock"].asBool();
+    if (!jsonBody->operator[]("mlock").isNull()) {
+      params.use_mlock = jsonBody->operator[]("mlock").asBool();
     }
 
-    if (!jsonBody["grammar_file"].isNull()) {
-      std::string grammar_file = jsonBody["grammar_file"].asString();
+    if (!jsonBody->operator[]("grammar_file").isNull()) {
+      std::string grammar_file =
+          jsonBody->operator[]("grammar_file").asString();
       std::ifstream file(grammar_file);
       if (!file) {
         LOG_ERROR << "Grammar file not found";
@@ -536,30 +590,31 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
       }
     };
 
-    params.model = jsonBody["llama_model_path"].asString();
-    params.n_gpu_layers = jsonBody.get("ngl", 100).asInt();
-    params.n_ctx = jsonBody.get("ctx_len", 2048).asInt();
-    params.embedding = jsonBody.get("embedding", true).asBool();
+    params.model = jsonBody->operator[]("llama_model_path").asString();
+    params.n_gpu_layers = jsonBody->get("ngl", 100).asInt();
+    params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
+    params.embedding = jsonBody->get("embedding", true).asBool();
     // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
-    params.n_batch = jsonBody.get("n_batch", 512).asInt();
-    params.n_parallel = jsonBody.get("n_parallel", 1).asInt();
+    params.n_batch = jsonBody->get("n_batch", 512).asInt();
+    params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
     params.n_threads =
-        jsonBody.get("cpu_threads", std::thread::hardware_concurrency())
+        jsonBody->get("cpu_threads", std::thread::hardware_concurrency())
             .asInt();
-    params.cont_batching = jsonBody.get("cont_batching", false).asBool();
+    params.cont_batching = jsonBody->get("cont_batching", false).asBool();
     this->clean_cache_threshold =
-        jsonBody.get("clean_cache_threshold", 5).asInt();
-    this->caching_enabled = jsonBody.get("caching_enabled", false).asBool();
-    this->user_prompt = jsonBody.get("user_prompt", "USER: ").asString();
-    this->ai_prompt = jsonBody.get("ai_prompt", "ASSISTANT: ").asString();
+        jsonBody->get("clean_cache_threshold", 5).asInt();
+    this->caching_enabled = jsonBody->get("caching_enabled", false).asBool();
+    this->user_prompt = jsonBody->get("user_prompt", "USER: ").asString();
+    this->ai_prompt = jsonBody->get("ai_prompt", "ASSISTANT: ").asString();
     this->system_prompt =
-        jsonBody.get("system_prompt", "ASSISTANT's RULE: ").asString();
-    this->pre_prompt = jsonBody.get("pre_prompt", "").asString();
-    this->repeat_last_n = jsonBody.get("repeat_last_n", 32).asInt();
+        jsonBody->get("system_prompt", "ASSISTANT's RULE: ").asString();
+    this->pre_prompt = jsonBody->get("pre_prompt", "").asString();
+    this->repeat_last_n = jsonBody->get("repeat_last_n", 32).asInt();
 
-    if (!jsonBody["llama_log_folder"].isNull()) {
+    if (!jsonBody->operator[]("llama_log_folder").isNull()) {
       log_enable();
-      std::string llama_log_folder = jsonBody["llama_log_folder"].asString();
+      std::string llama_log_folder =
+          jsonBody->operator[]("llama_log_folder").asString();
       log_set_target(llama_log_folder + "llama.log");
     } // Set folder for llama log
   }
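
Because jsonBody is now a pointer, the patch spells member access as jsonBody->operator[]("key") and jsonBody->get("key", default). A small standalone JsonCpp sketch of the two styles (the values are made up for illustration): (*jsonBody)["key"] is an equivalent, shorter spelling of the explicit operator[] call, and get() supplies a default when the key is missing.

#include <json/json.h>
#include <iostream>
#include <memory>

int main() {
  auto jsonBody = std::make_shared<Json::Value>();
  (*jsonBody)["llama_model_path"] = "/tmp/demo-model.gguf";  // made-up path

  // Style used in the patch: explicit operator[] through the pointer.
  std::string model = jsonBody->operator[]("llama_model_path").asString();

  // Equivalent spelling: dereference first, then index.
  std::string same = (*jsonBody)["llama_model_path"].asString();

  // get() returns the supplied default when the key is absent.
  int ngl = jsonBody->get("ngl", 100).asInt();

  std::cout << model << " == " << same << ", ngl=" << ngl << '\n';
}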
@@ -597,37 +652,6 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
   return true;
 }
 
-void llamaCPP::loadModel(
-    const HttpRequestPtr &req,
-    std::function<void(const HttpResponsePtr &)> &&callback) {
-
-  if (llama.model_loaded_external) {
-    LOG_INFO << "model loaded";
-    Json::Value jsonResp;
-    jsonResp["message"] = "Model already loaded";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    resp->setStatusCode(drogon::k409Conflict);
-    callback(resp);
-    return;
-  }
-
-  const auto &jsonBody = req->getJsonObject();
-  if (!loadModelImpl(*jsonBody)) {
-    // Error occurred during model loading
-    Json::Value jsonResp;
-    jsonResp["message"] = "Failed to load model";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    resp->setStatusCode(drogon::k500InternalServerError);
-    callback(resp);
-  } else {
-    // Model loaded successfully
-    Json::Value jsonResp;
-    jsonResp["message"] = "Model loaded successfully";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    callback(resp);
-  }
-}
-
 void llamaCPP::backgroundTask() {
   while (llama.model_loaded_external) {
     // model_loaded =

controllers/llamaCPP.h

Lines changed: 15 additions & 14 deletions
@@ -2530,36 +2530,26 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
 
   // Openai compatible path
   ADD_METHOD_TO(llamaCPP::chatCompletion, "/v1/chat/completions", Post);
-  ADD_METHOD_TO(llamaCPP::chatCompletionPrelight, "/v1/chat/completions",
-                Options);
+  ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/chat/completions", Options);
 
   ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post);
+  ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/embeddings", Options);
 
   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
   void chatCompletion(const HttpRequestPtr &req,
                       std::function<void(const HttpResponsePtr &)> &&callback);
-  void chatCompletionPrelight(
-      const HttpRequestPtr &req,
-      std::function<void(const HttpResponsePtr &)> &&callback);
+  void handlePrelight(const HttpRequestPtr &req,
+                      std::function<void(const HttpResponsePtr &)> &&callback);
   void embedding(const HttpRequestPtr &req,
                  std::function<void(const HttpResponsePtr &)> &&callback);
   void loadModel(const HttpRequestPtr &req,
                  std::function<void(const HttpResponsePtr &)> &&callback);
   void unloadModel(const HttpRequestPtr &req,
                    std::function<void(const HttpResponsePtr &)> &&callback);
-
   void modelStatus(const HttpRequestPtr &req,
                    std::function<void(const HttpResponsePtr &)> &&callback);
 
-  bool loadModelImpl(const Json::Value &jsonBody);
-
-  void warmupModel();
-
-  void backgroundTask();
-
-  void stopBackgroundTask();
-
 private:
   llama_server_context llama;
   // std::atomic<bool> model_loaded = false;
@@ -2577,5 +2567,16 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
   std::atomic<bool> single_queue_is_busy; // This value only used under the
                                           // condition n_parallel is 1
   std::string grammar_file_content;
+
+  bool loadModelImpl(std::shared_ptr<Json::Value> jsonBody);
+  void
+  chatCompletionImpl(std::shared_ptr<Json::Value> jsonBody,
+                     std::function<void(const HttpResponsePtr &)> &callback);
+  void embeddingImpl(std::shared_ptr<Json::Value> jsonBody,
+                     std::function<void(const HttpResponsePtr &)> &callback);
+  void checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
+  void warmupModel();
+  void backgroundTask();
+  void stopBackgroundTask();
 };
 }; // namespace inferences
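
The header now routes OPTIONS on both endpoints to the shared handlePrelight method, which is how browsers' CORS preflight requests are answered before the POST goes through. A hedged sketch of that registration pattern in a standalone Drogon controller (the class name and the exact headers sent are illustrative assumptions, not the project's actual handler body):

#include <drogon/HttpController.h>
#include <functional>

class PrelightDemo : public drogon::HttpController<PrelightDemo> {
 public:
  METHOD_LIST_BEGIN
  // One handler can serve OPTIONS for several paths, as in the patch.
  ADD_METHOD_TO(PrelightDemo::handlePrelight, "/v1/chat/completions",
                drogon::Options);
  ADD_METHOD_TO(PrelightDemo::handlePrelight, "/v1/embeddings",
                drogon::Options);
  METHOD_LIST_END

  void handlePrelight(
      const drogon::HttpRequestPtr &req,
      std::function<void(const drogon::HttpResponsePtr &)> &&callback) {
    auto resp = drogon::HttpResponse::newHttpResponse();
    // Typical CORS preflight headers; adjust to the deployment's needs.
    resp->addHeader("Access-Control-Allow-Origin", "*");
    resp->addHeader("Access-Control-Allow-Methods", "POST, OPTIONS");
    resp->addHeader("Access-Control-Allow-Headers", "Content-Type, Authorization");
    callback(resp);
  }
};

int main() {
  // Serve the routes registered above on a local port.
  drogon::app().addListener("127.0.0.1", 8080).run();
}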
