This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 653059f

Merge pull request #498 from janhq/fix-embed
fix: make embedding work again
2 parents 3211173 + 09d7a09

File tree: 3 files changed (+73, -45 lines)


context/llama_server_context.h

Lines changed: 68 additions & 40 deletions
@@ -1,15 +1,15 @@
+#include <mutex>
+#include <set>
 #include <string>
 #include <vector>
-#include <set>
-#include <mutex>
 
 // External
 #include "clip.h"
 #include "common.h"
 #include "llama.h"
-#include "utils/json.hpp"
-#include "stb_image.h"
 #include "llava.h"
+#include "stb_image.h"
+#include "utils/json.hpp"
 
 #if defined(_WIN32)
 #define NOMINMAX
@@ -532,7 +532,8 @@ struct llama_server_context {
 
   std::tie(model, ctx) = llama_init_from_gpt_params(params);
   if (model == nullptr) {
-    LOG_ERROR_LLAMA("llama.cpp unable to load model", {{"model", params.model}});
+    LOG_ERROR_LLAMA("llama.cpp unable to load model",
+                    {{"model", params.model}});
     return false;
   }
 
@@ -585,7 +586,11 @@ struct llama_server_context {
   try {
     batch = llama_batch_init(n_ctx, 0, params.n_parallel);
   } catch (const std::exception& e) {
-    LOG_ERROR_LLAMA("Failed to allocate llama.cpp batch metadata", {{"exception", e.what()}, {"n_tokens_alloc", n_ctx}, {"embd", 0}, {"n_seq_max", params.n_parallel}});
+    LOG_ERROR_LLAMA("Failed to allocate llama.cpp batch metadata",
+                    {{"exception", e.what()},
+                     {"n_tokens_alloc", n_ctx},
+                     {"embd", 0},
+                     {"n_seq_max", params.n_parallel}});
   }
 
   // empty system prompt
@@ -1244,19 +1249,35 @@ struct llama_server_context {
    res.stop = true;
 
    const int n_embd = llama_n_embd(model);
-    if (!params.embedding) {
-      LOG_WARNING_LLAMA("embedding disabled",
-                        {
-                            {"params.embedding", params.embedding},
-                        });
-      res.result_json = json{
-          {"embedding", std::vector<float>(n_embd, 0.0f)},
-      };
-    } else {
-      const float* data = llama_get_embeddings(ctx);
-      std::vector<float> embedding(data, data + n_embd);
+
+    std::vector<float> embd_res(n_embd, 0.0f);
+
+    for (int i = 0; i < batch.n_tokens; ++i) {
+      if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+        continue;
+      }
+
+      const float* embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+      if (embd == NULL) {
+        embd = llama_get_embeddings_ith(ctx, i);
+      }
+
+      if (embd == NULL) {
+        LOG_ERROR << "failed to get embeddings "
+                  << "token: " << batch.token[i]
+                  << ", seq_id: " << batch.seq_id[i][0];
+
+        res.result_json = json{
+            {"embedding", std::vector<float>(n_embd, 0.0f)},
+        };
+
+        continue;
+      }
+
+      llama_embd_normalize(embd, embd_res.data(), n_embd);
+
       res.result_json = json{
-          {"embedding", embedding},
+          {"embedding", embd_res},
       };
    }
    queue_results.push_back(res);
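
The hunk above is the core of the fix: rather than reading one global vector with llama_get_embeddings(), send_embedding() now walks the decoded batch, keeps only tokens that requested output and belong to the slot's sequence (slot.id + 1), prefers the pooled per-sequence embedding, falls back to the per-token one, and L2-normalizes the result. A minimal standalone sketch of that pattern, assuming ctx, model, and batch come from an embedding-enabled llama.cpp setup that has already run llama_decode(); the helper name embedding_for_seq is invented for illustration and is not part of this codebase:

#include <vector>

#include "common.h"  // llama_embd_normalize
#include "llama.h"

// Sketch only: collect the normalized embedding for one sequence after decoding.
static std::vector<float> embedding_for_seq(llama_context* ctx,
                                            const llama_model* model,
                                            const llama_batch& batch,
                                            llama_seq_id seq_id) {
  const int n_embd = llama_n_embd(model);
  std::vector<float> out(n_embd, 0.0f);

  for (int i = 0; i < batch.n_tokens; ++i) {
    // skip tokens that produced no output or belong to another sequence
    if (!batch.logits[i] || batch.seq_id[i][0] != seq_id) {
      continue;
    }

    // pooled per-sequence embedding first, per-token embedding as fallback
    const float* embd = llama_get_embeddings_seq(ctx, seq_id);
    if (embd == NULL) {
      embd = llama_get_embeddings_ith(ctx, i);
    }
    if (embd == NULL) {
      continue;  // nothing usable for this token
    }

    // L2-normalize into the result buffer
    llama_embd_normalize(embd, out.data(), n_embd);
  }
  return out;
}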
@@ -1380,7 +1401,7 @@ struct llama_server_context {
     std::vector<llama_token> append_tokens =
         tokenize(json_prompt, false);  // has next image
     for (int i = 0; i < (int)append_tokens.size(); ++i) {
-      llama_batch_add(batch, append_tokens[i], slot.n_past, {slot.id}, true);
+      llama_batch_add(batch, append_tokens[i], slot.n_past, {slot.id + 1}, true);
       slot.n_past += 1;
     }
   }
@@ -1523,27 +1544,28 @@ struct llama_server_context {
 
   for (llama_client_slot& slot : slots) {
     if (slot.is_processing() &&
-        slot.cache_tokens.size() >= (size_t)slot.n_ctx) {
+        (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
       // Shift context
-      const int n_left = slot.n_past - slot.params.n_keep - 1;
+      const int n_keep = slot.params.n_keep + add_bos_token;
+      const int n_left = (int)system_tokens.size() + slot.n_past - n_keep;
       const int n_discard = n_left / 2;
 
       LOG_TEE(
           "slot %d: context shift - n_keep = %d, n_left = %d, n_discard "
          "= %d\n",
          slot.id, slot.params.n_keep, n_left, n_discard);
-      llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1,
-                            slot.params.n_keep + n_discard + 1);
-      llama_kv_cache_seq_add(ctx, slot.id,
-                             slot.params.n_keep + 1 + n_discard,
-                             slot.n_past, -n_discard);
-
-      for (size_t i = slot.params.n_keep + 1 + n_discard;
-           i < slot.cache_tokens.size(); i++) {
-        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-      }
+      llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_discard);
+      llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard,
+                             system_tokens.size() + slot.n_past, -n_discard);
+
+      if (slot.params.cache_prompt) {
+        for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size();
+             i++) {
+          slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+        }
 
-      slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+        slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+      }
 
       slot.n_past -= n_discard;
 
@@ -1557,6 +1579,9 @@ struct llama_server_context {
      }
    }
 
+    // start populating the batch for this iteration
+    llama_batch_clear(batch);
+
    // decode any currently ongoing sequences
    for (auto& slot : slots) {
      // release the slot
@@ -1578,14 +1603,15 @@ struct llama_server_context {
      slot.i_batch = batch.n_tokens;
 
      llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past,
-                      {slot.id}, true);
+                      {slot.id + 1}, true);
 
      slot.n_decoded += 1;
      slot.n_past += 1;
    }
 
    // process in chunks of params.n_batch
-    int32_t n_batch = params.n_batch;
+    int32_t n_batch = llama_n_batch(ctx);
+    int32_t n_ubatch = llama_n_ubatch(ctx);
 
    // assign workload to the slots
    if (params.cont_batching || batch.n_tokens == 0) {
@@ -1641,8 +1667,7 @@ struct llama_server_context {
        } else {
          prompt_tokens = tokenize(
              slot.prompt,
-              system_prompt.empty() &&
-                  add_bos_token);  // add BOS if there isn't system prompt
+              system_prompt.empty());  // add BOS if there isn't system prompt
        }
 
        slot.num_prompt_tokens = prompt_tokens.size();
@@ -1738,9 +1763,11 @@ struct llama_server_context {
        std::vector<llama_token> prefix_tokens =
            has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token)
                        : prompt_tokens;
-        for (; slot.n_past < (int)prefix_tokens.size(); ++slot.n_past) {
+        for (;
+             slot.n_past < slot.num_prompt_tokens && batch.n_tokens < n_batch;
+             ++slot.n_past) {
          llama_batch_add(batch, prefix_tokens[slot.n_past],
-                          system_tokens.size() + slot.n_past, {slot.id},
+                          system_tokens.size() + slot.n_past, {slot.id + 1},
                          false);
        }
 
@@ -1803,7 +1830,8 @@ struct llama_server_context {
    }
 
    for (auto& slot : slots) {
-      if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
+      if (slot.state != PROCESSING || slot.i_batch < (int)i ||
+          slot.i_batch >= (int)(i + n_tokens)) {
        continue;
      }
 
@@ -1812,7 +1840,7 @@ struct llama_server_context {
        send_embedding(slot);
        slot.release();
        slot.i_batch = -1;
-        return true;
+        continue;
      }
 
      completion_token_output result;

controllers/llamaCPP.cc

Lines changed: 4 additions & 4 deletions
@@ -154,14 +154,14 @@ llamaCPP::~llamaCPP() {
  StopBackgroundTask();
}
 
-void llamaCPP::WarmupModel() {
+void llamaCPP::WarmupModel(bool is_embedding) {
  json pseudo;
 
  LOG_INFO << "Warm-up model";
  pseudo["prompt"] = "Hello";
  pseudo["n_predict"] = 2;
  pseudo["stream"] = false;
-  const int task_id = llama.request_completion(pseudo, false, false, -1);
+  const int task_id = llama.request_completion(pseudo, false, is_embedding, -1);
  std::string completion_text;
  task_result result = llama.next_result(task_id);
  if (!result.error && result.stop) {
@@ -624,7 +624,7 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
 
  params.n_gpu_layers = jsonBody->get("ngl", 100).asInt();
  params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
-  params.embedding = jsonBody->get("embedding", true).asBool();
+  params.embedding = jsonBody->get("embedding", false).asBool();
  // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
  params.n_batch = jsonBody->get("n_batch", 512).asInt();
  params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
@@ -681,7 +681,7 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
 
  LOG_INFO << "Started background task here!";
  backgroundThread = std::thread(&llamaCPP::BackgroundTask, this);
-  WarmupModel();
+  WarmupModel(params.embedding);
  return true;
}
 
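
Since the default for "embedding" flips from true to false here, a client that wants /embedding to return real vectors now has to opt in when the model is loaded. A sketch of such a request body built with jsoncpp, the same library LoadModelImpl reads it with; "llama_model_path" is an assumed key name, while the other keys are the ones read in the hunks above:

#include <json/json.h>
#include <string>

// Sketch only: a load-model body that explicitly re-enables embeddings now
// that the server-side default is false.
Json::Value MakeLoadModelBody(const std::string& model_path) {
  Json::Value body;
  body["llama_model_path"] = model_path;  // assumed key for the model file path
  body["ctx_len"] = 2048;                 // read via jsonBody->get("ctx_len", 2048)
  body["ngl"] = 100;                      // read via jsonBody->get("ngl", 100)
  body["n_parallel"] = 1;                 // read via jsonBody->get("n_parallel", 1)
  body["embedding"] = true;               // must be set explicitly; the default is now false
  return body;
}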

controllers/llamaCPP.h

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
  void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                     std::function<void(const HttpResponsePtr&)>& callback);
  bool CheckModelLoaded(std::function<void(const HttpResponsePtr&)>& callback);
-  void WarmupModel();
+  void WarmupModel(bool is_embedding);
  void BackgroundTask();
  void StopBackgroundTask();
};
