
Commit 70d71f5

EAddario authored and jwcolin committed
quantize: Handle user-defined quantization levels for additional tensors (ggml-org#12511)
* Add llama_model_quantize_params parameters
* Add new quantize parameters parsing and validation
* Update usage
* Add new parameters defaults
* Add new quantization parameters logic
* Add llama_model_quantize_params parameters
* Add new quantize parameters parsing and validation
* Update usage
* Add new parameters defaults
* Add new quantization parameters logic
* Minor refactoring as per the contributors' coding guidelines
* Update descriptions to match existing style
* Add llama_model_quantize_params parameters
* Add new quantize parameters parsing and validation
* Update usage
* Add new parameters defaults
* Add new quantization parameters logic
* Minor refactoring as per the contributors' guidelines
* Implement general --tensor-type instead of tensor-specific command option
* Fix implied type bug
* Restore missing #includes
* Add regex capability for tensor selection
* Refactor function name and update ALLOWED_TENSOR_TYPE
* Add missing #include
* Handle edge case when tensor name is cls.output
* Minor logging improvement
1 parent ac92a9d commit 70d71f5
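
As a usage illustration (not part of the diff): assuming the tool is built as llama-quantize and the existing Q4_K_M preset from QUANT_OPTIONS, the new flag can be repeated to pin individual tensor families to a type, for example:

    ./llama-quantize --tensor-type attn_q=q8_0 --tensor-type ffn_down=q8_0 model-f32.gguf model-quant.gguf Q4_K_M

Everything else follows the Q4_K_M recipe, while tensors whose names match attn_q or ffn_down are written as q8_0. Per the parser added in quantize.cpp below, the part before '=' must end in one of the names in ALLOWED_TENSOR_TYPE (cls.output is special-cased) and the part after '=' must be a ggml type name such as q8_0; otherwise usage() is printed. Since the stored name is later applied as a regex in src/llama-quant.cpp, layer-scoped patterns such as blk\.2[0-9]\.ffn_down=q8_0 appear to be accepted as well, though that is a reading of the code rather than documented behaviour.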

File tree

3 files changed: +155 −20 lines

examples/quantize/quantize.cpp

+115-2
@@ -9,14 +9,15 @@
 #include <fstream>
 #include <cmath>
 #include <cctype>
+#include <algorithm>
 
 struct quant_option {
     std::string name;
     llama_ftype ftype;
     std::string desc;
 };
 
-static const std::vector<struct quant_option> QUANT_OPTIONS = {
+static const std::vector<quant_option> QUANT_OPTIONS = {
     { "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
     { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", },
     { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 5.21G, +0.1316 ppl @ Llama-3-8B", },
@@ -105,7 +106,8 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 //
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
+    printf("       [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@@ -114,6 +116,8 @@ static void usage(const char * executable) {
     printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
     printf("  --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
+    printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
+    printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -244,6 +248,107 @@ static ggml_type parse_ggml_type(const char * arg) {
     return GGML_TYPE_COUNT;
 }
 
+// Allowed tensors for arbitrary quantization with --tensor-type option
+static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
+    "attn_k",
+    "attn_kv_a_mqa",
+    "attn_kv_b",
+    "attn_o",
+    "attn_output",
+    "attn_q",
+    "attn_q_a",
+    "attn_q_b",
+    "attn_qkv",
+    "attn_v",
+    "channel_mix_key",
+    "channel_mix_receptance",
+    "channel_mix_value",
+    "cls",
+    "cls.output",
+    "cross_attn_k",
+    "cross_attn_o",
+    "cross_attn_q",
+    "cross_attn_v",
+    "ffn_act",
+    "ffn_down",
+    "ffn_down_exps",
+    "ffn_down_shexp",
+    "ffn_gate",
+    "ffn_gate_exps",
+    "ffn_gate_shexp",
+    "ffn_up",
+    "ffn_up_exps",
+    "ffn_up_shexp",
+    "ssm_in",
+    "ssm_out",
+    "time_mix_gate",
+    "time_mix_key",
+    "time_mix_output",
+    "time_mix_receptance",
+    "time_mix_value",
+};
+
+// changes to this struct must be replicated in llama-quant.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
+static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
+    const char * sep = strchr(data, '=');
+    if (sep == nullptr) {
+        printf("\n%s: malformed tensor type '%s'\n\n", __func__, data);
+        return false;
+    }
+
+    const size_t tn_len = sep - data;
+    if (tn_len == 0) {
+        printf("\n%s: missing tensor name\n\n", __func__);
+        return false;
+    }
+
+    if (const size_t qt_len = strlen(sep); qt_len == 1) {
+        printf("\n%s: missing quantization type\n\n", __func__);
+        return false;
+    }
+
+    std::string tn(data, tn_len);
+    std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
+    sep++;
+    const std::string qt(sep);
+
+    bool found = false;
+    for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
+        std::string tensor;
+        tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
+        // handle special case of cls.output
+        std::string cls_output = "cls.output";
+        if (tn.find(cls_output) != std::string::npos) {
+            tensor = "cls.output";
+        }
+        // check if an allowed tensor exists and it's at the end of the kv string
+        if (tensor == allowed) {
+            found = true;
+            break;
+        }
+    }
+    if (!found) {
+        printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
+        return false;
+    }
+
+    if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
+        printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
+        return false;
+    }
+
+    tensor_quantization tqz;
+    tqz.name = tn;
+    tqz.quant = parse_ggml_type(qt.c_str());
+    tensor_type.emplace_back(std::move(tqz));
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -255,6 +360,7 @@ int main(int argc, char ** argv) {
     std::string imatrix_file;
     std::vector<std::string> included_weights, excluded_weights;
     std::vector<llama_model_kv_override> kv_overrides;
+    std::vector<tensor_quantization> tensor_types;
 
     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -277,6 +383,10 @@ int main(int argc, char ** argv) {
             } else {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--tensor-type") == 0) {
+            if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
                 usage(argv[0]);
@@ -361,6 +471,9 @@ int main(int argc, char ** argv) {
         kv_overrides.back().key[0] = 0;
         params.kv_overrides = &kv_overrides;
     }
+    if (!tensor_types.empty()) {
+        params.tensor_types = &tensor_types;
+    }
 
     llama_backend_init();
 
include/llama.h

+12-11
@@ -367,17 +367,18 @@ extern "C" {
 
     // model quantization parameters
     typedef struct llama_model_quantize_params {
-        int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
-        enum llama_ftype ftype; // quantize to this llama_ftype
-        enum ggml_type output_tensor_type; // output tensor type
-        enum ggml_type token_embedding_type; // token embeddings tensor type
-        bool allow_requantize; // allow quantizing non-f32/f16 tensors
-        bool quantize_output_tensor; // quantize output.weight
-        bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
-        bool pure; // quantize all tensors to the default type
-        bool keep_split; // quantize to the same number of shards
-        void * imatrix; // pointer to importance matrix data
-        void * kv_overrides; // pointer to vector containing overrides
+        int32_t nthread;                     // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
+        enum llama_ftype ftype;              // quantize to this llama_ftype
+        enum ggml_type output_tensor_type;   // output tensor type
+        enum ggml_type token_embedding_type; // token embeddings tensor type
+        bool allow_requantize;               // allow quantizing non-f32/f16 tensors
+        bool quantize_output_tensor;         // quantize output.weight
+        bool only_copy;                      // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
+        bool pure;                           // quantize all tensors to the default type
+        bool keep_split;                     // quantize to the same number of shards
+        void * imatrix;                      // pointer to importance matrix data
+        void * kv_overrides;                 // pointer to vector containing overrides
+        void * tensor_types;                 // pointer to vector containing tensor types
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {

src/llama-quant.cpp

+28-7
@@ -10,6 +10,7 @@
 #include <cinttypes>
 #include <fstream>
 #include <mutex>
+#include <regex>
 #include <thread>
 #include <unordered_map>
 
@@ -47,8 +48,14 @@ struct quantize_state_impl {
     {}
 };
 
+// changes to this struct must be replicated in quantize.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
 static void llama_tensor_dequantize_impl(
-    struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
+    ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
     const size_t nelements, const int nthread
 ) {
     if (output.size() < nelements) {
@@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     model.load_hparams(ml);
     model.load_stats (ml);
 
-    struct quantize_state_impl qs(model, params);
+    quantize_state_impl qs(model, params);
 
     if (params->only_copy) {
         ftype = ml.ftype;
@@ -661,7 +668,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     // populate the original tensors so we get an initial meta data
     for (const auto * it : tensors) {
         uint16_t i_split = params->keep_split ? it->idx : 0;
-        struct ggml_tensor * tensor = it->tensor;
+        ggml_tensor * tensor = it->tensor;
         if (!ctx_outs[i_split]) {
             ctx_outs[i_split].reset(gguf_init_empty());
         }
@@ -710,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     new_ofstream(0);
     for (const auto * it : tensors) {
         const auto & weight = *it;
-        struct ggml_tensor * tensor = weight.tensor;
+        ggml_tensor * tensor = weight.tensor;
         if (weight.idx != cur_split && params->keep_split) {
             close_ofstream();
             new_ofstream(weight.idx);
@@ -776,7 +783,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
 
-        enum ggml_type new_type;
+        ggml_type new_type;
         void * new_data;
         size_t new_size;
 
@@ -786,6 +793,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // get more optimal quantization type based on the tensor shape, layer, etc.
         if (!params->pure && ggml_is_quantized(default_type)) {
             new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
+            // unless the user specifies a type
+            if (params->tensor_types) {
+                const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
+                for (const auto & [tname, qtype] : tensor_types) {
+                    if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
+                        if (qtype != new_type) {
+                            LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
+                        }
+                        new_type = qtype;
+                        break;
+                    }
+                }
+            }
         }
         if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
             new_type = params->token_embedding_type;
@@ -910,8 +930,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 // interface implementation
 //
 
-struct llama_model_quantize_params llama_model_quantize_default_params() {
-    struct llama_model_quantize_params result = {
+llama_model_quantize_params llama_model_quantize_default_params() {
+    llama_model_quantize_params result = {
         /*.nthread =*/ 0,
         /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
         /*.output_tensor_type =*/ GGML_TYPE_COUNT,
@@ -923,6 +943,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
         /*.keep_split =*/ false,
         /*.imatrix =*/ nullptr,
         /*.kv_overrides =*/ nullptr,
+        /*.tensor_type =*/ nullptr,
     };
 
     return result;
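
Aside (not part of the commit): the override hunk above compiles the user-supplied name into a std::regex and runs std::regex_search against each tensor's full name, so a plain name like ffn_down matches blk.0.ffn_down.weight, blk.1.ffn_down.weight, and so on. A minimal standalone sketch of that matching rule, using made-up but typical tensor names:

    #include <cstdio>
    #include <regex>
    #include <string>
    #include <vector>

    int main() {
        // pattern as stored by parse_tensor_type for "--tensor-type ffn_down=q8_0"
        const std::string pattern = "ffn_down";
        // illustrative tensor names only
        const std::vector<std::string> names = { "blk.0.ffn_down.weight", "blk.0.attn_q.weight", "output.weight" };

        const std::regex re(pattern);
        for (const auto & name : names) {
            // regex_search matches anywhere in the name, so a bare family name hits every layer
            std::printf("%-24s %s\n", name.c_str(), std::regex_search(name, re) ? "override applies" : "default type kept");
        }
        return 0;
    }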

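
Finally, for programs that call the C API directly instead of going through the quantize example: tensor_types is a type-erased pointer, like imatrix and kv_overrides, and llama-quant.cpp casts it back to a std::vector<tensor_quantization>. Below is a minimal caller-side sketch mirroring what examples/quantize/quantize.cpp does; the local tensor_quantization definition is the caller's own copy and must stay identical to the one in src/llama-quant.cpp, and the file names are placeholders:

    #include "llama.h"

    #include <cstdint>
    #include <string>
    #include <vector>

    // must match the struct defined in src/llama-quant.cpp (duplicated in quantize.cpp as well)
    struct tensor_quantization {
        std::string name;
        ggml_type quant = GGML_TYPE_COUNT;
    };

    int main() {
        std::vector<tensor_quantization> overrides;
        tensor_quantization tq;
        tq.name  = "ffn_down";          // keep ffn_down tensors at q8_0
        tq.quant = GGML_TYPE_Q8_0;
        overrides.push_back(tq);

        llama_model_quantize_params params = llama_model_quantize_default_params();
        params.ftype        = LLAMA_FTYPE_MOSTLY_Q4_K_M; // default recipe for everything else
        params.tensor_types = &overrides;                // consumed inside llama_model_quantize_impl

        llama_backend_init();
        const uint32_t rc = llama_model_quantize("model-f32.gguf", "model-quant.gguf", &params);
        llama_backend_free();
        return rc == 0 ? 0 : 1;
    }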