This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 514fd2e

Merge pull request #143 from janhq/141-feat-non-stream-chat-completion
feat: make non stream completion possible to be fully compatible with…
2 parents: e0cef1e + 2be0d28

File tree

1 file changed: +108 -44 lines changed

controllers/llamaCPP.cc

Lines changed: 108 additions & 44 deletions
@@ -8,7 +8,6 @@
 #include <regex>
 #include <string>
 #include <thread>
-#include <trantor/utils/Logger.h>

 using namespace inferences;
 using json = nlohmann::json;
@@ -28,6 +27,45 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {

 // --------------------------------------------

+std::string create_full_return_json(const std::string &id,
+                                    const std::string &model,
+                                    const std::string &content,
+                                    const std::string &system_fingerprint,
+                                    int prompt_tokens, int completion_tokens,
+                                    Json::Value finish_reason = Json::Value()) {
+
+  Json::Value root;
+
+  root["id"] = id;
+  root["model"] = model;
+  root["created"] = static_cast<int>(std::time(nullptr));
+  root["object"] = "chat.completion";
+  root["system_fingerprint"] = system_fingerprint;
+
+  Json::Value choicesArray(Json::arrayValue);
+  Json::Value choice;
+
+  choice["index"] = 0;
+  Json::Value message;
+  message["role"] = "assistant";
+  message["content"] = content;
+  choice["message"] = message;
+  choice["finish_reason"] = finish_reason;
+
+  choicesArray.append(choice);
+  root["choices"] = choicesArray;
+
+  Json::Value usage;
+  usage["prompt_tokens"] = prompt_tokens;
+  usage["completion_tokens"] = completion_tokens;
+  usage["total_tokens"] = prompt_tokens + completion_tokens;
+  root["usage"] = usage;
+
+  Json::StreamWriterBuilder writer;
+  writer["indentation"] = ""; // Compact output
+  return Json::writeString(writer, root);
+}
+
 std::string create_return_json(const std::string &id, const std::string &model,
                                const std::string &content,
                                Json::Value finish_reason = Json::Value()) {
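
The helper added above assembles an OpenAI-style "chat.completion" object and relies on jsoncpp's StreamWriterBuilder to serialize the Json::Value tree compactly. Below is a minimal, self-contained sketch of that serialization pattern; it is not part of the commit, and the field values are made up for illustration:

#include <ctime>
#include <iostream>
#include <json/json.h>  // jsoncpp, the library the controller uses

int main() {
  // Build a small tree the same way create_full_return_json does.
  Json::Value root;
  root["object"] = "chat.completion";
  root["created"] = static_cast<int>(std::time(nullptr));

  Json::Value usage;
  usage["prompt_tokens"] = 12;
  usage["completion_tokens"] = 7;
  usage["total_tokens"] = 12 + 7;
  root["usage"] = usage;

  // An empty "indentation" setting yields compact, single-line JSON,
  // which is the form the controller returns as a response body.
  Json::StreamWriterBuilder writer;
  writer["indentation"] = "";
  std::cout << Json::writeString(writer, root) << std::endl;
}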
@@ -82,9 +120,9 @@ void llamaCPP::chatCompletion(
   json data;
   json stopWords;
   // To set default value
-  data["stream"] = true;

   if (jsonBody) {
+    data["stream"] = (*jsonBody).get("stream", false).asBool();
     data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
     data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
     data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
@@ -119,62 +157,87 @@ void llamaCPP::chatCompletion(
     data["stop"] = stopWords;
   }

+  bool is_streamed = data["stream"];
+
   const int task_id = llama.request_completion(data, false, false);
   LOG_INFO << "Resolved request for task_id:" << task_id;

-  auto state = createState(task_id, this);
+  if (is_streamed) {
+    auto state = createState(task_id, this);

-  auto chunked_content_provider =
-      [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-    if (!pBuffer) {
-      LOG_INFO << "Connection closed or buffer is null. Reset context";
-      state->instance->llama.request_cancel(state->task_id);
-      return 0;
-    }
-    if (state->isStopped) {
-      return 0;
-    }
-
-    task_result result = state->instance->llama.next_result(state->task_id);
-    if (!result.error) {
-      const std::string to_send = result.result_json["content"];
-      const std::string str =
-          "data: " +
-          create_return_json(nitro_utils::generate_random_string(20), "_",
-                             to_send) +
-          "\n\n";
-
-      std::size_t nRead = std::min(str.size(), nBuffSize);
-      memcpy(pBuffer, str.data(), nRead);
+    auto chunked_content_provider =
+        [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!pBuffer) {
+        LOG_INFO << "Connection closed or buffer is null. Reset context";
+        state->instance->llama.request_cancel(state->task_id);
+        return 0;
+      }
+      if (state->isStopped) {
+        return 0;
+      }

-    if (result.stop) {
+      task_result result = state->instance->llama.next_result(state->task_id);
+      if (!result.error) {
+        const std::string to_send = result.result_json["content"];
         const std::string str =
             "data: " +
-            create_return_json(nitro_utils::generate_random_string(20), "_", "",
-                               "stop") +
-            "\n\n" + "data: [DONE]" + "\n\n";
+            create_return_json(nitro_utils::generate_random_string(20), "_",
+                               to_send) +
+            "\n\n";

-        LOG_VERBOSE("data stream", {{"to_send", str}});
         std::size_t nRead = std::min(str.size(), nBuffSize);
         memcpy(pBuffer, str.data(), nRead);
-        LOG_INFO << "reached result stop";
-        state->isStopped = true;
-        state->instance->llama.request_cancel(state->task_id);
+
+        if (result.stop) {
+          const std::string str =
+              "data: " +
+              create_return_json(nitro_utils::generate_random_string(20), "_",
+                                 "", "stop") +
+              "\n\n" + "data: [DONE]" + "\n\n";
+
+          LOG_VERBOSE("data stream", {{"to_send", str}});
+          std::size_t nRead = std::min(str.size(), nBuffSize);
+          memcpy(pBuffer, str.data(), nRead);
+          LOG_INFO << "reached result stop";
+          state->isStopped = true;
+          state->instance->llama.request_cancel(state->task_id);
+          return nRead;
+        }
         return nRead;
+      } else {
+        return 0;
       }
-      return nRead;
-    } else {
       return 0;
-    }
-    return 0;
-  };
-  auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
-                                               "chat_completions.txt");
-  callback(resp);
+    };
+    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
+                                                 "chat_completions.txt");
+    callback(resp);

-  return;
+    return;
+  } else {
+    Json::Value respData;
+    auto resp = nitro_utils::nitroHttpResponse();
+    respData["testing"] = "thunghiem value moi";
+    if (!json_value(data, "stream", false)) {
+      std::string completion_text;
+      task_result result = llama.next_result(task_id);
+      if (!result.error && result.stop) {
+        int prompt_tokens = result.result_json["tokens_evaluated"];
+        int predicted_tokens = result.result_json["tokens_predicted"];
+        std::string full_return =
+            create_full_return_json(nitro_utils::generate_random_string(20),
+                                    "_", result.result_json["content"], "_",
+                                    prompt_tokens, predicted_tokens);
+        resp->setBody(full_return);
+      } else {
+        resp->setBody("internal error during inference");
+        return;
+      }
+      callback(resp);
+      return;
+    }
+  }
 }
-
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
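
In the streaming branch above, the chunked_content_provider lambda follows a pull-style contract: the HTTP layer calls it repeatedly with a buffer, the lambda copies at most nBuffSize bytes of the next "data: ..." frame into it, and a return value of 0 ends the stream. The self-contained sketch below (not from the commit; the frame payloads are illustrative, the real ones come from create_return_json) mimics that contract with a plain lambda and a driver loop standing in for the server:

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Illustrative SSE frames ending with the [DONE] sentinel.
  std::vector<std::string> frames = {
      "data: {\"content\":\"Hel\"}\n\n",
      "data: {\"content\":\"lo\"}\n\n",
      "data: [DONE]\n\n"};
  std::size_t next = 0;

  // Same shape as the controller's chunked_content_provider.
  auto provider = [&](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
    if (!pBuffer || next >= frames.size())
      return 0;  // closed connection or nothing left: end the stream
    const std::string &str = frames[next++];
    std::size_t nRead = std::min(str.size(), nBuffSize);
    std::memcpy(pBuffer, str.data(), nRead);
    return nRead;  // number of bytes placed into the buffer
  };

  // Stand-in for the HTTP layer: keep pulling until the provider returns 0.
  char buf[256];
  for (std::size_t n; (n = provider(buf, sizeof(buf))) > 0;)
    std::cout.write(buf, static_cast<std::streamsize>(n));
}

The non-streaming branch needs no provider at all: it blocks on llama.next_result(task_id) until result.stop, then returns the single body built by create_full_return_json, including the prompt/completion token usage.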
@@ -262,7 +325,8 @@ void llamaCPP::loadModel(
     this->pre_prompt =
         (*jsonBody)
             .get("pre_prompt",
-                 "A chat between a curious user and an artificial intelligence "
+                 "A chat between a curious user and an artificial "
+                 "intelligence "
                  "assistant. The assistant follows the given rules no matter "
                  "what.\\n")
             .asString();
