server : remove n_past #16818
Conversation
This should not affect mtmd as most of the code paths for mtmd expect to use …
Force-pushed from 1709bfa to 545df93
```diff
-llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard);
 
 // add generated tokens to cache
 {
     llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
         new_tokens[i - n_discard] = new_tokens[i];
     }
 
     new_tokens.resize(slot.prompt.tokens.size() - n_discard);
     slot.prompt.tokens.clear();
     slot.prompt.tokens.insert(new_tokens);
 }
+llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
 
-slot.n_past -= n_discard;
```
Attention here. I think this is simply equivalent to keep_first(), but not sure if there was a reason to implement it like this.
Hmm I think it's more like "remove middle" rather than keep_first(). Example:
```
# n_keep=2, n_discard=3
    seq: 0 1 2 3 4 5 6 7 8 9
   keep: x x
     rm:     x x x
new seq: 0 1 5 6 7 8 9
```
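For reference, here is that "remove middle" operation as a standalone sketch on a plain std::vector (a simplified stand-in for illustration only, not the actual server code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Drop n_discard tokens starting at index n_keep: [0, n_keep) stays,
// [n_keep, n_keep + n_discard) is removed, and the tail shifts left.
static void remove_middle(std::vector<int> & tokens, size_t n_keep, size_t n_discard) {
    for (size_t i = n_keep + n_discard; i < tokens.size(); i++) {
        tokens[i - n_discard] = tokens[i];
    }
    tokens.resize(tokens.size() - n_discard);
}

int main() {
    std::vector<int> tokens = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    remove_middle(tokens, /*n_keep =*/ 2, /*n_discard =*/ 3);
    assert((tokens == std::vector<int>{0, 1, 5, 6, 7, 8, 9})); // matches the example above
    return 0;
}
```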
Also please note that when migrating from std::vector to server_tokens, I tried to stick to the original code while polyfilling server_tokens with .clear() and .insert(), so that's why the code may look a bit strange. But yes, it's better to refactor this code.
Yes, good catch - will revisit this.
Fixed in 7e60d1c
```diff
 if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
     // process the image
     int32_t new_n_past;
-    int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+    int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.id, new_n_past);
     if (res != 0) {
         SLT_ERR(slot, "failed to process image, res = %d\n", res);
         send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
         slot.release();
         continue;
     }
 
     slot.n_prompt_tokens_processed += new_n_past - slot.prompt.n_tokens();
 
     // add the image chunk to cache
     {
-        const auto & chunk = input_tokens.find_chunk(slot.n_past);
+        const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
         slot.prompt.tokens.push_back(chunk.get()); // copy
     }
```
Attention here - I'm not sure I understand the logic 100%, but I think the new version should behave the same as the old one.
Please ignore my last comment. Here n_past is used for context shifting only. It is safe to replace it with prompt.n_tokens(), which should reflect exactly the number of positions (and not the number of cells in the KV cache). Actually, I recommend renaming n_tokens() --> n_pos() for this reason.
(Edit: no, we actually want n_tokens(), as we want to keep track of the cells in the KV cache)
While we're not using context shifting with mtmd, I think the difference is worth noting here:
If context shifting takes place, for example when we have 10 tokens and we remove the first 5, n_past will remain 10 in this case, while prompt.n_tokens() will return 10-5=5.
The alternative solution is to:
- Make server_tokens::tokens reflect the actual number of tokens in the KV cache
- Calculate the pos for the next token dynamically, using a combination of llama_memory_seq_pos_max() and mtmd_input_chunk_get_n_pos() (see the sketch below)
The main question is: do you think it's OK to call llama_memory_seq_pos_max() for each new token? Does it have a significant impact on performance?
If this approach seems OK, I will push a commit for it.
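For illustration, the second point could look roughly like this (a sketch only, not a tested patch; it assumes llama_memory_seq_pos_max() returns the largest position stored for the sequence, or -1 if it is empty):

```cpp
#include "llama.h"

// Sketch of the proposed alternative: derive the next position from the KV cache
// instead of tracking it in server_tokens. (As noted below, this breaks when a
// batch has been submitted but not processed yet.)
static llama_pos next_pos_from_kv(llama_context * ctx, llama_seq_id seq_id) {
    const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id);
    return pos_max + 1; // assuming -1 for an empty sequence, this starts at 0
}
```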
Nevermind, I started working on the idea but it doesn't work, because sometimes the batch is not yet processed, so llama_memory_seq_pos_max() does not reflect the correct number of positions.
I think the more intuitive way now is to simply add a std::vector<llama_pos> to server_tokens. I'll push a commit to illustrate the idea (we can always revert it if it's not the desired solution).
Maybe we can add a new const method pos_next() to server_tokens that loops over all chunks and determines the next position based on the contents. If !mtmd, then pos_next() simply returns n_tokens(). Otherwise, it has to analyze all mtmd chunks and take mtmd_input_chunk_get_n_pos() into account.
This way, we don't have to add extra state to server_tokens.
Please take a look --> 1e1b4af
(Tested with Gemma 3 and Qwen 2.5 VL)
I think we can avoid the pos state in server_tokens. The logic is quite simple (see the sketch below):
- If we don't have mtmd chunks, then pos_next() returns n_tokens()
- If we have a single mtmd chunk that occupies T tokens and uses P positions, then pos_next() returns n_tokens() - T + P
- For N + 1 mtmd chunks the logic is similar: n_tokens() - T0 - T1 - ... - TN + P0 + P1 + ... + PN
This should significantly simplify the logic and state management in server_tokens.
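A minimal sketch of that computation (chunk_info and the mtmd_chunks argument are illustrative names, assuming server_tokens can enumerate its stored mtmd chunks; the real implementation may differ):

```cpp
#include <cstdint>
#include <vector>

// Illustrative summary of one stored mtmd chunk: how many entries it occupies
// in the token list (T_i) and how many positions it uses (P_i, i.e. the value
// reported by mtmd_input_chunk_get_n_pos()).
struct chunk_info {
    int32_t n_tokens; // T_i
    int32_t n_pos;    // P_i
};

// pos_next() = n_tokens() - (T0 + ... + TN) + (P0 + ... + PN)
// With no mtmd chunks the loop is a no-op, so this reduces to n_tokens().
static int32_t pos_next(int32_t n_tokens_total, const std::vector<chunk_info> & mtmd_chunks) {
    int32_t pos = n_tokens_total;
    for (const auto & c : mtmd_chunks) {
        pos += c.n_pos - c.n_tokens;
    }
    return pos;
}
```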
I implemented this idea - hopefully it works correctly. I think I now understand the logic of the mtmd chunks better. There are a few ideas to refactor and improve the implementation; will do so in follow-up PRs.
@ngxson For mtmd, I mainly tested this with Gemma and LightOnOCR - it would be nice to do some extra testing.
Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Actually, I suggest this patch, which also fixes an issue with …
I tested it with Qwen 2.5 VL; the test contains 2 steps: …
If …
Yes, applying this change.
Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Force-pushed from 80e3672 to ef1646e
I think this should be good to merge now. Let me know if you spot any additional issues. If CI is green, will merge a bit later today.
I did some more tests with Qwen 2.5 VL and Gemma 3. I can confirm that it's working correctly now. Thanks!
The n_past counter is redundant and complicates the logic, so this PR removes it. The number of previous tokens in the sequence is already expressed with slot.prompt.n_tokens(), so there is no need for a second way of expressing this information.
Haven't tested yet how this affects mtmd use cases, but the plan is to incorporate the ideas from #15474 (comment) in order to have a correct positional representation. If there is a need for additional token position tracking for mrope, it should be implemented in addition to the existing sequence-position logic.
HTTP API Changes
/metrics: Rename llamacpp:n_past_max -> llamacpp:n_tokens_max