
Commit 17f086b

Use smart pointers in simple-chat
Avoid manual memory cleanups. Fewer memory leaks in the code now.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 2a82891 commit 17f086b
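The core of the change is an ownership pattern: each message's text is copied into a std::unique_ptr<char[]>, the raw pointer is handed to the C-style llama_chat_message, and the smart pointer is kept in an owned_content vector so the copies are released automatically instead of via a manual free loop. Below is a minimal, self-contained sketch of that pattern, not the example itself: chat_message is a hypothetical stand-in for llama_chat_message, and the role is passed as a const char * (a string literal) rather than a std::string to keep the standalone sketch free of lifetime pitfalls.

// Minimal sketch of the ownership pattern used in this commit (hypothetical
// names; chat_message stands in for llama_chat_message, which only holds raw
// const char * views).
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

struct chat_message {
    const char * role;    // not owned (string literal)
    const char * content; // not owned (backed by owned_content below)
};

// Copy `text` into heap storage held by a unique_ptr, hand the raw pointer to
// the C-style struct, and keep the owner alive in `owned_content`.
static void add_message(const char * role, const std::string & text,
                        std::vector<chat_message> & messages,
                        std::vector<std::unique_ptr<char[]>> & owned_content) {
    auto content = std::make_unique<char[]>(text.size() + 1);
    std::strcpy(content.get(), text.c_str());
    messages.push_back({role, content.get()});
    owned_content.push_back(std::move(content));
}

int main() {
    std::vector<chat_message> messages;
    std::vector<std::unique_ptr<char[]>> owned_content;
    add_message("user", "hello", messages, owned_content);
    printf("%s: %s\n", messages[0].role, messages[0].content);
    // No manual free loop: the unique_ptr destructors release the copies
    // when owned_content goes out of scope.
}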

2 files changed: +140 -75 lines

examples/simple-chat/CMakeLists.txt

+1 -1

@@ -2,4 +2,4 @@ set(TARGET llama-simple-chat)
 add_executable(${TARGET} simple-chat.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_compile_features(${TARGET} PRIVATE cxx_std_14)
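The bump from cxx_std_11 to cxx_std_14 is presumably needed because the new helpers below use std::make_unique, which only entered the standard library in C++14. A minimal standalone illustration:

#include <memory>

int main() {
    // std::make_unique, including the array overload used in simple-chat.cpp,
    // is a C++14 library feature; this line does not compile under -std=c++11.
    auto buf = std::make_unique<char[]>(16);
    buf[0] = '\0';
    return 0;
}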

examples/simple-chat/simple-chat.cpp

+139 -74
@@ -1,10 +1,132 @@
-#include "llama.h"
 #include <cstdio>
 #include <cstring>
 #include <iostream>
+#include <memory>
 #include <string>
 #include <vector>
 
+#include "llama.h"
+
+// Add a message to `messages` and store its content in `owned_content`
+static void add_message(const std::string &role, const std::string &text,
+                        std::vector<llama_chat_message> &messages,
+                        std::vector<std::unique_ptr<char[]>> &owned_content) {
+    auto content = std::make_unique<char[]>(text.size() + 1);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role.c_str(), content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Function to apply the chat template and resize `formatted` if needed
+static int apply_chat_template(const llama_model *model,
+                               const std::vector<llama_chat_message> &messages,
+                               std::vector<char> &formatted, bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    }
+
+    return result;
+}
+
+// Function to tokenize the prompt
+static int tokenize_prompt(const llama_model *model, const std::string &prompt,
+                           std::vector<llama_token> &prompt_tokens) {
+    const int n_prompt_tokens = -llama_tokenize(
+        model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    prompt_tokens.resize(n_prompt_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(),
+                       prompt_tokens.data(), prompt_tokens.size(), true,
+                       true) < 0) {
+        GGML_ABORT("failed to tokenize the prompt\n");
+        return -1;
+    }
+
+    return n_prompt_tokens;
+}
+
+// Check if we have enough space in the context to evaluate this batch
+static int check_context_size(const llama_context *ctx,
+                              const llama_batch &batch) {
+    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+    if (n_ctx_used + batch.n_tokens > n_ctx) {
+        printf("\033[0m\n");
+        fprintf(stderr, "context size exceeded\n");
+        return 1;
+    }
+
+    return 0;
+}
+
+// convert the token to a string
+static int convert_token_to_string(const llama_model *model,
+                                   const llama_token token_id,
+                                   std::string &piece) {
+    char buf[256];
+    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
+    if (n < 0) {
+        GGML_ABORT("failed to convert token to piece\n");
+        return 1;
+    }
+
+    piece = std::string(buf, n);
+    return 0;
+}
+
+static void print_word_and_concatenate_to_response(const std::string &piece,
+                                                   std::string &response) {
+    printf("%s", piece.c_str());
+    fflush(stdout);
+    response += piece;
+}
+
+// helper function to evaluate a prompt and generate a response
+static int generate(const llama_model *model, llama_sampler *smpl,
+                    llama_context *ctx, const std::string &prompt,
+                    std::string &response) {
+    std::vector<llama_token> prompt_tokens;
+    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
+    if (n_prompt_tokens < 0) {
+        return 1;
+    }
+
+    // prepare a batch for the prompt
+    llama_batch batch =
+        llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+    llama_token new_token_id;
+    while (true) {
+        check_context_size(ctx, batch);
+        if (llama_decode(ctx, batch)) {
+            GGML_ABORT("failed to decode\n");
+            return 1;
+        }
+
+        // sample the next token, check is it an end of generation?
+        new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        if (llama_token_is_eog(model, new_token_id)) {
+            break;
+        }
+
+        std::string piece;
+        if (convert_token_to_string(model, new_token_id, piece)) {
+            return 1;
+        }
+
+        print_word_and_concatenate_to_response(piece, response);
+
+        // prepare the next batch with the sampled token
+        batch = llama_batch_get_one(&new_token_id, 1);
+    }
+
+    return 0;
+}
+
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
     printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
@@ -66,6 +188,7 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
+    // This prints ........
     llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
     if (!model) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
@@ -88,107 +211,49 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
     llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
-
-    // helper function to evaluate a prompt and generate a response
-    auto generate = [&](const std::string & prompt) {
-        std::string response;
-
-        // tokenize the prompt
-        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
-        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
-        if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
-            GGML_ABORT("failed to tokenize the prompt\n");
-        }
-
-        // prepare a batch for the prompt
-        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
-        llama_token new_token_id;
-        while (true) {
-            // check if we have enough space in the context to evaluate this batch
-            int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
-            if (n_ctx_used + batch.n_tokens > n_ctx) {
-                printf("\033[0m\n");
-                fprintf(stderr, "context size exceeded\n");
-                exit(0);
-            }
-
-            if (llama_decode(ctx, batch)) {
-                GGML_ABORT("failed to decode\n");
-            }
-
-            // sample the next token
-            new_token_id = llama_sampler_sample(smpl, ctx, -1);
-
-            // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id)) {
-                break;
-            }
-
-            // convert the token to a string, print it and add it to the response
-            char buf[256];
-            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
-            if (n < 0) {
-                GGML_ABORT("failed to convert token to piece\n");
-            }
-            std::string piece(buf, n);
-            printf("%s", piece.c_str());
-            fflush(stdout);
-            response += piece;
-
-            // prepare the next batch with the sampled token
-            batch = llama_batch_get_one(&new_token_id, 1);
-        }
-
-        return response;
-    };
-
     std::vector<llama_chat_message> messages;
+    std::vector<std::unique_ptr<char[]>> owned_content;
     std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
         // get user input
         printf("\033[32m> \033[0m");
         std::string user;
         std::getline(std::cin, user);
-
         if (user.empty()) {
             break;
         }
 
-        // add the user input to the message list and format it
-        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        if (new_len > (int)formatted.size()) {
-            formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        }
+        // Add user input to messages
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
 
-        // remove previous messages to obtain the prompt to generate the response
-        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        // remove previous messages to obtain the prompt to generate the
+        // response
+        std::string prompt(formatted.begin() + prev_len,
+                           formatted.begin() + new_len);
 
         // generate a response
         printf("\033[33m");
-        std::string response = generate(prompt);
+        std::string response;
+        if (generate(model, smpl, ctx, prompt, response)) {
+            return 1;
+        }
+
        printf("\n\033[0m");
 
-        // add the response to the messages
-        messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        // Add response to messages
+        prev_len = apply_chat_template(model, messages, formatted, false);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
     }
 
-    // free resources
-    for (auto & msg : messages) {
-        free(const_cast<char *>(msg.content));
-    }
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
