@@ -92,111 +92,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
92
92
torch::Tensor X, // input
93
93
int64_t type, int64_t row) {
94
94
int col = X.sizes ()[1 ];
95
+ int vecs = X.sizes ()[0 ];
95
96
const int padded = (col + 512 - 1 ) / 512 * 512 ;
96
97
const at::cuda::OptionalCUDAGuard device_guard (device_of (X));
97
98
auto options = torch::TensorOptions ().dtype (X.dtype ()).device (W.device ());
98
- at::Tensor Y = torch::empty ({1 , row}, options);
99
+ at::Tensor Y = torch::empty ({vecs , row}, options);
99
100
cudaStream_t stream = at::cuda::getCurrentCUDAStream ().stream ();
100
101
options = torch::TensorOptions ().dtype (torch::kInt32 ).device (W.device ());
101
- at::Tensor quant_X = torch::empty ({1 , padded / 32 * 9 }, options);
102
+ at::Tensor quant_X = torch::empty ({vecs , padded / 32 * 9 }, options);
102
103
VLLM_DISPATCH_FLOATING_TYPES (X.scalar_type (), " ggml_mul_mat_vec_a8" , [&] {
103
- quantize_row_q8_1_cuda<scalar_t >(( scalar_t *)X. data_ptr (),
104
- (void *)quant_X.data_ptr (), col, 1 , stream);
104
+ quantize_row_q8_1_cuda<scalar_t >(
105
+ ( scalar_t *)X. data_ptr (), (void *)quant_X.data_ptr (), col, vecs , stream);
105
106
switch (type) {
106
107
case 2 :
107
108
mul_mat_vec_q4_0_q8_1_cuda<scalar_t >(
108
109
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
109
- (scalar_t *)Y.data_ptr (), col, row, stream);
110
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
110
111
break ;
111
112
case 3 :
112
113
mul_mat_vec_q4_1_q8_1_cuda<scalar_t >(
113
114
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
114
- (scalar_t *)Y.data_ptr (), col, row, stream);
115
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
115
116
break ;
116
117
case 6 :
117
118
mul_mat_vec_q5_0_q8_1_cuda<scalar_t >(
118
119
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
119
- (scalar_t *)Y.data_ptr (), col, row, stream);
120
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
120
121
break ;
121
122
case 7 :
122
123
mul_mat_vec_q5_1_q8_1_cuda<scalar_t >(
123
124
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
124
- (scalar_t *)Y.data_ptr (), col, row, stream);
125
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
125
126
break ;
126
127
case 8 :
127
128
mul_mat_vec_q8_0_q8_1_cuda<scalar_t >(
128
129
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
129
- (scalar_t *)Y.data_ptr (), col, row, stream);
130
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
130
131
break ;
131
132
case 10 :
132
133
mul_mat_vec_q2_K_q8_1_cuda<scalar_t >(
133
134
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
134
- (scalar_t *)Y.data_ptr (), col, row, stream);
135
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
135
136
break ;
136
137
case 11 :
137
138
mul_mat_vec_q3_K_q8_1_cuda<scalar_t >(
138
139
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
139
- (scalar_t *)Y.data_ptr (), col, row, stream);
140
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
140
141
break ;
141
142
case 12 :
142
143
mul_mat_vec_q4_K_q8_1_cuda<scalar_t >(
143
144
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
144
- (scalar_t *)Y.data_ptr (), col, row, stream);
145
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
145
146
break ;
146
147
case 13 :
147
148
mul_mat_vec_q5_K_q8_1_cuda<scalar_t >(
148
149
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
149
- (scalar_t *)Y.data_ptr (), col, row, stream);
150
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
150
151
break ;
151
152
case 14 :
152
153
mul_mat_vec_q6_K_q8_1_cuda<scalar_t >(
153
154
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
154
- (scalar_t *)Y.data_ptr (), col, row, stream);
155
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
155
156
break ;
156
157
case 16 :
157
158
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t >(
158
159
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
159
- (scalar_t *)Y.data_ptr (), col, row, stream);
160
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
160
161
break ;
161
162
case 17 :
162
163
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t >(
163
164
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
164
- (scalar_t *)Y.data_ptr (), col, row, stream);
165
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
165
166
break ;
166
167
case 18 :
167
168
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t >(
168
169
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
169
- (scalar_t *)Y.data_ptr (), col, row, stream);
170
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
170
171
break ;
171
172
case 19 :
172
173
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t >(
173
174
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
174
- (scalar_t *)Y.data_ptr (), col, row, stream);
175
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
175
176
break ;
176
177
case 20 :
177
178
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t >(
178
179
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
179
- (scalar_t *)Y.data_ptr (), col, row, stream);
180
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
180
181
break ;
181
182
case 21 :
182
183
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t >(
183
184
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
184
- (scalar_t *)Y.data_ptr (), col, row, stream);
185
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
185
186
break ;
186
187
case 22 :
187
188
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t >(
188
189
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
189
- (scalar_t *)Y.data_ptr (), col, row, stream);
190
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
190
191
break ;
191
192
case 23 :
192
193
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t >(
193
194
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
194
- (scalar_t *)Y.data_ptr (), col, row, stream);
195
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
195
196
break ;
196
197
case 29 :
197
198
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t >(
198
199
(void *)W.data_ptr (), (void *)quant_X.data_ptr (),
199
- (scalar_t *)Y.data_ptr (), col, row, stream);
200
+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
200
201
break ;
201
202
}
202
203
});
0 commit comments