 #include <regex>
 #include <string>
 #include <thread>
-#include <trantor/utils/Logger.h>

 using namespace inferences;
 using json = nlohmann::json;
@@ -28,6 +27,45 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {

 // --------------------------------------------

+std::string create_full_return_json(const std::string &id,
+                                    const std::string &model,
+                                    const std::string &content,
+                                    const std::string &system_fingerprint,
+                                    int prompt_tokens, int completion_tokens,
+                                    Json::Value finish_reason = Json::Value()) {
+
+  Json::Value root;
+
+  root["id"] = id;
+  root["model"] = model;
+  root["created"] = static_cast<int>(std::time(nullptr));
+  root["object"] = "chat.completion";
+  root["system_fingerprint"] = system_fingerprint;
+
+  Json::Value choicesArray(Json::arrayValue);
+  Json::Value choice;
+
+  choice["index"] = 0;
+  Json::Value message;
+  message["role"] = "assistant";
+  message["content"] = content;
+  choice["message"] = message;
+  choice["finish_reason"] = finish_reason;
+
+  choicesArray.append(choice);
+  root["choices"] = choicesArray;
+
+  Json::Value usage;
+  usage["prompt_tokens"] = prompt_tokens;
+  usage["completion_tokens"] = completion_tokens;
+  usage["total_tokens"] = prompt_tokens + completion_tokens;
+  root["usage"] = usage;
+
+  Json::StreamWriterBuilder writer;
+  writer["indentation"] = ""; // Compact output
+  return Json::writeString(writer, root);
+}
+
 std::string create_return_json(const std::string &id, const std::string &model,
                                const std::string &content,
                                Json::Value finish_reason = Json::Value()) {
@@ -82,9 +120,9 @@ void llamaCPP::chatCompletion(
   json data;
   json stopWords;
   // To set default value
-  data["stream"] = true;

   if (jsonBody) {
+    data["stream"] = (*jsonBody).get("stream", false).asBool();
     data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
     data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
     data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
@@ -119,62 +157,87 @@ void llamaCPP::chatCompletion(
     data["stop"] = stopWords;
   }

+  bool is_streamed = data["stream"];
+
   const int task_id = llama.request_completion(data, false, false);
   LOG_INFO << "Resolved request for task_id:" << task_id;

-  auto state = createState(task_id, this);
+  if (is_streamed) {
+    auto state = createState(task_id, this);

-  auto chunked_content_provider =
-      [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-    if (!pBuffer) {
-      LOG_INFO << "Connection closed or buffer is null. Reset context";
-      state->instance->llama.request_cancel(state->task_id);
-      return 0;
-    }
-    if (state->isStopped) {
-      return 0;
-    }
-
-    task_result result = state->instance->llama.next_result(state->task_id);
-    if (!result.error) {
-      const std::string to_send = result.result_json["content"];
-      const std::string str =
-          "data: " +
-          create_return_json(nitro_utils::generate_random_string(20), "_",
-                             to_send) +
-          "\n\n";
-
-      std::size_t nRead = std::min(str.size(), nBuffSize);
-      memcpy(pBuffer, str.data(), nRead);
+    auto chunked_content_provider =
+        [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!pBuffer) {
+        LOG_INFO << "Connection closed or buffer is null. Reset context";
+        state->instance->llama.request_cancel(state->task_id);
+        return 0;
+      }
+      if (state->isStopped) {
+        return 0;
+      }

-      if (result.stop) {
+      task_result result = state->instance->llama.next_result(state->task_id);
+      if (!result.error) {
+        const std::string to_send = result.result_json["content"];
         const std::string str =
             "data: " +
-            create_return_json(nitro_utils::generate_random_string(20), "_", "",
-                               "stop") +
-            "\n\n" + "data: [DONE]" + "\n\n";
+            create_return_json(nitro_utils::generate_random_string(20), "_",
+                               to_send) +
+            "\n\n";

-        LOG_VERBOSE("data stream", {{"to_send", str}});
         std::size_t nRead = std::min(str.size(), nBuffSize);
         memcpy(pBuffer, str.data(), nRead);
-        LOG_INFO << "reached result stop";
-        state->isStopped = true;
-        state->instance->llama.request_cancel(state->task_id);
+
+        if (result.stop) {
+          const std::string str =
+              "data: " +
+              create_return_json(nitro_utils::generate_random_string(20), "_",
+                                 "", "stop") +
+              "\n\n" + "data: [DONE]" + "\n\n";
+
+          LOG_VERBOSE("data stream", {{"to_send", str}});
+          std::size_t nRead = std::min(str.size(), nBuffSize);
+          memcpy(pBuffer, str.data(), nRead);
+          LOG_INFO << "reached result stop";
+          state->isStopped = true;
+          state->instance->llama.request_cancel(state->task_id);
+          return nRead;
+        }
         return nRead;
+      } else {
+        return 0;
+      }
-      return nRead;
-    } else {
       return 0;
-    }
-    return 0;
-  };
-  auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
-                                               "chat_completions.txt");
-  callback(resp);
+    };
+    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
+                                                 "chat_completions.txt");
+    callback(resp);

-  return;
+    return;
+  } else {
+    Json::Value respData;
+    auto resp = nitro_utils::nitroHttpResponse();
+    respData["testing"] = "thunghiem value moi";
+    if (!json_value(data, "stream", false)) {
+      std::string completion_text;
+      task_result result = llama.next_result(task_id);
+      if (!result.error && result.stop) {
+        int prompt_tokens = result.result_json["tokens_evaluated"];
+        int predicted_tokens = result.result_json["tokens_predicted"];
+        std::string full_return =
+            create_full_return_json(nitro_utils::generate_random_string(20),
+                                    "_", result.result_json["content"], "_",
+                                    prompt_tokens, predicted_tokens);
+        resp->setBody(full_return);
+      } else {
+        resp->setBody("internal error during inference");
+        return;
+      }
+      callback(resp);
+      return;
+    }
+  }
 }
-
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -262,7 +325,8 @@ void llamaCPP::loadModel(
   this->pre_prompt =
       (*jsonBody)
           .get("pre_prompt",
-               "A chat between a curious user and an artificial intelligence "
+               "A chat between a curious user and an artificial "
+               "intelligence "
                "assistant. The assistant follows the given rules no matter "
                "what.\\n")
           .asString();