 using namespace inferences;
 using json = nlohmann::json;
 
-struct State {
-  bool isStopped = false;
+struct inferenceState {
+  bool is_stopped = false;
+  bool is_streaming = false;
   int task_id;
   llamaCPP *instance;
 
-  State(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {}
+  inferenceState(llamaCPP *inst) : instance(inst) {}
 };
 
-std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
-  return std::make_shared<State>(task_id, instance);
+std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
+  return std::make_shared<inferenceState>(instance);
 }
 
 // --------------------------------------------
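The new inferenceState replaces the old per-request State: request_completion now happens inside the streaming callback, so the stop/streaming flags and the task id must live in a heap-allocated object shared between the callback and the llamaCPP instance. A minimal, self-contained sketch of that pattern follows; StreamState, provider, and the hard-coded task id are hypothetical stand-ins, not the real nitro API.

// Sketch only: the provider callback is invoked repeatedly by the HTTP layer,
// so per-request flags live in a shared_ptr-owned state object that the lambda
// copies by value and that therefore outlives the enclosing scope.
#include <functional>
#include <iostream>
#include <memory>

struct StreamState {            // stands in for inferenceState
  bool is_stopped = false;
  bool is_streaming = false;
  int task_id = -1;
};

int main() {
  auto state = std::make_shared<StreamState>();
  std::function<std::size_t()> provider = [state]() -> std::size_t {
    if (!state->is_streaming) {  // first invocation: submit the task
      state->task_id = 42;       // request_completion(...) in the real code
      state->is_streaming = true;
    }
    if (state->is_stopped)       // a later invocation already saw the stop token
      return 0;
    state->is_stopped = true;    // pretend the stop token arrived
    return 1;                    // bytes "written" this round
  };
  while (provider() > 0) {
  }
  std::cout << "task " << state->task_id << " finished\n";
  return 0;
}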
@@ -295,41 +296,35 @@ void llamaCPP::chatCompletion(
 #endif
   int task_id;
 
-  if (llama.params.n_parallel == 1) {
-    while (true) {
-      if (!single_queue_is_busy) {
-        task_id = llama.request_completion(data, false, false, -1);
-        single_queue_is_busy = true;
-        break;
-      } else {
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(500)); // Sleep for 500 milliseconds
-      }
-    }
-  } else {
-    task_id = llama.request_completion(data, false, false, -1);
-  }
-
   LOG_INFO << "Resolved request for task_id:" << task_id;
 
   if (is_streamed) {
-    auto state = createState(task_id, this);
-
+    auto state = create_inference_state(this);
+    state->task_id = task_id;
     auto chunked_content_provider =
-        [this, state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+        [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!state->is_streaming) {
+        state->task_id =
+            state->instance->llama.request_completion(data, false, false, -1);
+        state->instance->single_queue_is_busy = true;
+      }
       if (!pBuffer) {
        LOG_INFO << "Connection closed or buffer is null. Reset context";
        state->instance->llama.request_cancel(state->task_id);
-       single_queue_is_busy = false;
+       state->is_streaming = false;
+       state->instance->single_queue_is_busy = false;
        return 0;
      }
-      if (state->isStopped) {
-        single_queue_is_busy = false;
+      if (state->is_stopped) {
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
        return 0;
      }
 
      task_result result = state->instance->llama.next_result(state->task_id);
      if (!result.error) {
+        // Update streaming state to being streamed
+        state->is_streaming = true;
        const std::string to_send = result.result_json["content"];
        const std::string str =
            "data: " +
@@ -351,16 +346,30 @@ void llamaCPP::chatCompletion(
          std::size_t nRead = std::min(str.size(), nBuffSize);
          memcpy(pBuffer, str.data(), nRead);
          LOG_INFO << "reached result stop";
-          state->isStopped = true;
+          state->is_stopped = true;
          state->instance->llama.request_cancel(state->task_id);
+          state->is_streaming = false;
+          state->instance->single_queue_is_busy = false;
+
          return nRead;
        }
        return nRead;
      } else {
-        single_queue_is_busy = false;
-        return 0;
+        if (state->instance->llama.params.n_parallel == 1) {
+          while (state->instance->single_queue_is_busy) {
+            LOG_INFO << "Waiting for task to be released status:"
+                     << state->instance->single_queue_is_busy;
+            std::this_thread::sleep_for(
+                std::chrono::milliseconds(500)); // Wait in 500 millisecond steps
+          }
+        }
+        std::string str = "\n\n";
+        std::size_t nRead = str.size();
+        memcpy(pBuffer, str.data(), nRead);
+        LOG_INFO << "Failing, retrying now";
+        return nRead;
      }
-      single_queue_is_busy = false;
+      state->is_streaming = false;
+      state->instance->single_queue_is_busy = false;
      return 0;
    };
    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
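In the error branch above, a request that arrives while the single slot is occupied (n_parallel == 1) now polls single_queue_is_busy and streams a blank chunk instead of dropping the connection. A rough, self-contained sketch of that wait-for-slot behaviour follows; the atomic flag, wait_for_slot helper, and worker thread are illustrative assumptions, not nitro's actual types.

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

// Illustrative stand-in for llamaCPP::single_queue_is_busy.
std::atomic<bool> single_queue_is_busy{true};

// Poll the busy flag in 500 ms steps, as the provider does before retrying.
void wait_for_slot() {
  while (single_queue_is_busy.load()) {
    std::cout << "Waiting for task to be released\n";
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
  }
}

int main() {
  // A "previous request" releases the slot after roughly one second.
  std::thread previous([] {
    std::this_thread::sleep_for(std::chrono::seconds(1));
    single_queue_is_busy = false;
  });
  wait_for_slot();
  std::cout << "slot free, resubmitting the completion request\n";
  previous.join();
  return 0;
}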