-#include "llama.h"
 #include <cstdio>
 #include <cstring>
 #include <iostream>
+#include <memory>
 #include <string>
 #include <vector>
 
+#include "llama.h"
+
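+// Note: llama_chat_message stores raw `const char *` pointers and does not own
+// its strings, so each message's text is copied into a heap buffer that is kept
+// alive in `owned_content` for as long as the messages are in use.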
+// Add a message to `messages` and store its content in `owned_content`
+static void add_message(const std::string &role, const std::string &text,
+                        std::vector<llama_chat_message> &messages,
+                        std::vector<std::unique_ptr<char[]>> &owned_content) {
+    auto content = std::make_unique<char[]>(text.size() + 1);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role.c_str(), content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Function to apply the chat template and resize `formatted` if needed
+static int apply_chat_template(const llama_model *model,
+                               const std::vector<llama_chat_message> &messages,
+                               std::vector<char> &formatted, bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
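+    // llama_chat_apply_template returns the total length required for the
+    // formatted chat; if that exceeds the current buffer, grow it and retry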
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    }
+
+    return result;
+}
+
+// Function to tokenize the prompt
+static int tokenize_prompt(const llama_model *model, const std::string &prompt,
+                           std::vector<llama_token> &prompt_tokens) {
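+    // With a NULL output buffer llama_tokenize returns the negative of the
+    // number of tokens required, so negate the result to size the vector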
+    const int n_prompt_tokens = -llama_tokenize(
+        model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    prompt_tokens.resize(n_prompt_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(),
+                       prompt_tokens.data(), prompt_tokens.size(), true,
+                       true) < 0) {
+        GGML_ABORT("failed to tokenize the prompt\n");
+        return -1;
+    }
+
+    return n_prompt_tokens;
+}
+
+// Check if we have enough space in the context to evaluate this batch
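+// (llama_get_kv_cache_used_cells reports how many KV cache cells are currently
+//  occupied, i.e. roughly how many tokens are already stored in the context)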
+static int check_context_size(const llama_context *ctx,
+                              const llama_batch &batch) {
+    const int n_ctx      = llama_n_ctx(ctx);
+    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+    if (n_ctx_used + batch.n_tokens > n_ctx) {
+        printf("\033[0m\n");
+        fprintf(stderr, "context size exceeded\n");
+        return 1;
+    }
+
+    return 0;
+}
+
+// convert the token to a string
+static int convert_token_to_string(const llama_model *model,
+                                   const llama_token token_id,
+                                   std::string &piece) {
+    char buf[256];
+    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
+    if (n < 0) {
+        GGML_ABORT("failed to convert token to piece\n");
+        return 1;
+    }
+
+    piece = std::string(buf, n);
+    return 0;
+}
+
+static void print_word_and_concatenate_to_response(const std::string &piece,
+                                                   std::string &response) {
+    printf("%s", piece.c_str());
+    fflush(stdout);
+    response += piece;
+}
+
+// helper function to evaluate a prompt and generate a response
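+// The prompt is decoded as one batch, then tokens are sampled and decoded one
+// at a time until an end-of-generation token is produced.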
+static int generate(const llama_model *model, llama_sampler *smpl,
+                    llama_context *ctx, const std::string &prompt,
+                    std::string &response) {
+    std::vector<llama_token> prompt_tokens;
+    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
+    if (n_prompt_tokens < 0) {
+        return 1;
+    }
+
+    // prepare a batch for the prompt
+    llama_batch batch =
+        llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+    llama_token new_token_id;
+    while (true) {
+        // stop if this batch would no longer fit into the context
+        if (check_context_size(ctx, batch)) {
+            return 1;
+        }
+        if (llama_decode(ctx, batch)) {
+            GGML_ABORT("failed to decode\n");
+            return 1;
+        }
+
+        // sample the next token and check whether it is an end of generation
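+        // (an index of -1 samples from the logits of the last token in the batch)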
+        new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        if (llama_token_is_eog(model, new_token_id)) {
+            break;
+        }
+
+        std::string piece;
+        if (convert_token_to_string(model, new_token_id, piece)) {
+            return 1;
+        }
+
+        print_word_and_concatenate_to_response(piece, response);
+
+        // prepare the next batch with the sampled token
+        batch = llama_batch_get_one(&new_token_id, 1);
+    }
+
+    return 0;
+}
+
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
     printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
@@ -66,6 +188,7 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
+    // This prints ........
     llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
     if (!model) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
@@ -88,107 +211,49 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
     llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
-
-    // helper function to evaluate a prompt and generate a response
-    auto generate = [&](const std::string & prompt) {
-        std::string response;
-
-        // tokenize the prompt
-        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
-        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
-        if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
-            GGML_ABORT("failed to tokenize the prompt\n");
-        }
-
-        // prepare a batch for the prompt
-        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
-        llama_token new_token_id;
-        while (true) {
-            // check if we have enough space in the context to evaluate this batch
-            int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
-            if (n_ctx_used + batch.n_tokens > n_ctx) {
-                printf("\033[0m\n");
-                fprintf(stderr, "context size exceeded\n");
-                exit(0);
-            }
-
-            if (llama_decode(ctx, batch)) {
-                GGML_ABORT("failed to decode\n");
-            }
-
-            // sample the next token
-            new_token_id = llama_sampler_sample(smpl, ctx, -1);
-
-            // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id)) {
-                break;
-            }
-
-            // convert the token to a string, print it and add it to the response
-            char buf[256];
-            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
-            if (n < 0) {
-                GGML_ABORT("failed to convert token to piece\n");
-            }
-            std::string piece(buf, n);
-            printf("%s", piece.c_str());
-            fflush(stdout);
-            response += piece;
-
-            // prepare the next batch with the sampled token
-            batch = llama_batch_get_one(&new_token_id, 1);
-        }
-
-        return response;
-    };
-
     std::vector<llama_chat_message> messages;
+    std::vector<std::unique_ptr<char[]>> owned_content;
     std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
         // get user input
         printf("\033[32m> \033[0m");
         std::string user;
         std::getline(std::cin, user);
-
         if (user.empty()) {
             break;
         }
 
-        // add the user input to the message list and format it
-        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        if (new_len > (int)formatted.size()) {
-            formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        }
+        // Add user input to messages
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
 
-        // remove previous messages to obtain the prompt to generate the response
-        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        // remove previous messages to obtain the prompt to generate the
+        // response
+        std::string prompt(formatted.begin() + prev_len,
+                           formatted.begin() + new_len);
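+        // (formatted[prev_len:new_len] is the portion of the templated
+        //  conversation that this user turn appended)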
 
         // generate a response
         printf("\033[33m");
-        std::string response = generate(prompt);
+        std::string response;
+        if (generate(model, smpl, ctx, prompt, response)) {
+            return 1;
+        }
+
         printf("\n\033[0m");
 
-        // add the response to the messages
-        messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        // Add response to messages
+        add_message("assistant", response, messages, owned_content);
+        prev_len = apply_chat_template(model, messages, formatted, false);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
     }
 
-    // free resources
-    for (auto & msg : messages) {
-        free(const_cast<char *>(msg.content));
-    }
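+    // (no manual frees needed: the unique_ptrs in owned_content release the
+    //  copied message strings automatically)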
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);