
Commit dec66d2

[Kernel] GGUF MMVQ kernel for multiple input vectors (#18754)
Signed-off-by: SzymonOzog <[email protected]>
1 parent 8d12070 commit dec66d2

4 files changed: +95 -87 lines changed


csrc/quantization/gguf/gguf_kernel.cu

Lines changed: 24 additions & 23 deletions
@@ -92,111 +92,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                    torch::Tensor X,  // input
                                    int64_t type, int64_t row) {
   int col = X.sizes()[1];
+  int vecs = X.sizes()[0];
   const int padded = (col + 512 - 1) / 512 * 512;
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
   auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
-  at::Tensor Y = torch::empty({1, row}, options);
+  at::Tensor Y = torch::empty({vecs, row}, options);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
-  at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
+  at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options);
   VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
-    quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
-                                     (void*)quant_X.data_ptr(), col, 1, stream);
+    quantize_row_q8_1_cuda<scalar_t>(
+        (scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream);
     switch (type) {
       case 2:
         mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 3:
         mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 6:
         mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 7:
         mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 8:
         mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 10:
         mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 11:
         mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 12:
         mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 13:
         mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 14:
         mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 16:
         mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 17:
         mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 18:
         mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 19:
         mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 20:
         mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 21:
         mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 22:
         mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 23:
         mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
       case 29:
         mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
             (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-            (scalar_t*)Y.data_ptr(), col, row, stream);
+            (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
         break;
     }
   });
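
The hunk above turns the single-vector MMVQ entry point into a batched one: the q8_1 scratch buffer quant_X and the output Y are now allocated with a leading dimension of vecs = X.sizes()[0] instead of a hard-coded 1, the activation quantization runs over all vecs rows in one call, and every mul_mat_vec_*_q8_1_cuda launch receives that count as an extra argument. The per-vector padded / 32 * 9 sizing is unchanged: in the ggml q8_1 layout each block covers 32 values and occupies 36 bytes (32 int8 quants plus two fp16 scale/sum fields), i.e. 9 int32 elements, so only the leading dimension grows.

Below is a minimal host-side caller sketch. The batched_mmvq helper, the example type value, and the shape check are illustrative assumptions; the ggml_mul_mat_vec_a8 prototype and the {vecs, row} result shape come from the diff itself.

#include <torch/torch.h>

// Declaration as shown in the hunk header of csrc/quantization/gguf/gguf_kernel.cu.
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                  torch::Tensor X,  // input
                                  int64_t type, int64_t row);

// Hypothetical helper: with this commit, X may carry a small batch of input
// vectors (one per row) instead of exactly one.
torch::Tensor batched_mmvq(const torch::Tensor& W, const torch::Tensor& X,
                           int64_t row) {
  TORCH_CHECK(X.dim() == 2, "X must be [vecs, col]");
  // type == 2 selects the q4_0 branch of the switch above. The kernel
  // quantizes all X.sizes()[0] rows to q8_1 in one pass and returns a
  // {vecs, row} tensor rather than {1, row}.
  return ggml_mul_mat_vec_a8(W, X, /*type=*/2, row);
}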
