// Portable OpenMP and OpenACC pragma macros.
// ACCELS()   -> bare parallel-for (OpenMP) / parallel-loop (OpenACC) pragma
// ACCEL(...) -> same, with the listed variables made private per thread/gang
// NOTE: the macro name must be immediately followed by the parameter list
// (no space), otherwise the preprocessor treats these as object-like macros
// and the "()"/"(...)" becomes part of the replacement text.
#ifdef OPENMP
#define ACCELS() MK_PRAGMA(omp parallel for)
#define ACCEL(...) MK_PRAGMA(omp parallel for private(__VA_ARGS__))
#elif defined(OPENACC)
#define ACCELS() MK_PRAGMA(acc parallel loop)
#define ACCEL(...) MK_PRAGMA(acc parallel loop private(__VA_ARGS__))
#endif
@@ -154,7 +156,13 @@ __static_yoink("zipos");
154
156
#endif
155
157
// ----------------------------------------------------------------------------
// Globals

// L2E Addition
// Quantization group size. Under CAT ("Cheap Acceleration Tech") it is a
// compile-time constant (64) so the compiler can constant-fold divisions by
// GS; otherwise it stays 0 here and is filled in from the checkpoint header
// when the model is loaded.
#if defined CAT
const int GS = 64; // group size 64 for Cheap Acceleration Tech :)
#else
int GS = 0; // group size global for quantization of the weights
#endif
// END L2E Addition
158
166
159
167
// ----------------------------------------------------------------------------
160
168
// Transformer model
@@ -275,6 +283,11 @@ void free_run_state(RunState* s) {
275
283
// Quantization functions
276
284
277
285
void dequantize (QuantizedTensor * qx , float * x , int n ) {
286
+ // L2E Addition
287
+ #ifdef ACCEL
288
+ ACCELS () // OMP/OACC Macro
289
+ #endif
290
+ // END L2E Addition
278
291
for (int i = 0 ; i < n ; i ++ ) {
279
292
x [i ] = qx -> q [i ] * qx -> s [i / GS ];
280
293
}
@@ -284,6 +297,11 @@ void quantize(QuantizedTensor *qx, float* x, int n) {
284
297
int num_groups = n / GS ;
285
298
float Q_MAX = 127.0f ;
286
299
300
+ // L2E Addition
301
+ #ifdef ACCEL
302
+ ACCELS () // OMP/OACC Macro
303
+ #endif
304
+ // END L2E Addition
287
305
for (int group = 0 ; group < num_groups ; group ++ ) {
288
306
289
307
// find the max absolute value in the current group
@@ -391,7 +409,11 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
391
409
int group_size = * (int * ) ptr ;
392
410
ptr += sizeof (int );
393
411
412
+ // L2E Addition
413
+ #ifndef CAT
394
414
GS = group_size ; // set as global, as it will be used in many places
415
+ #endif
416
+ // END L2E Addition
395
417
396
418
void * weights_ptr = ((char * )* data ) + header_size ; // skip header bytes
397
419
memory_map_weights (weights , config , weights_ptr , shared_classifier );
@@ -419,7 +441,13 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
419
441
if (fread (& shared_classifier , sizeof (uint8_t ), 1 , file ) != 1 ) { exit (EXIT_FAILURE ); }
420
442
int group_size ; // the group size used in quantization
421
443
if (fread (& group_size , sizeof (int ), 1 , file ) != 1 ) { exit (EXIT_FAILURE ); }
444
+
445
+ // L2E Addition
446
+ #ifndef CAT
422
447
GS = group_size ; // set as global, as it will be used in many places
448
+ #endif
449
+ // END L2E Addition
450
+
423
451
// figure out the file size
424
452
fseek (file , 0 , SEEK_END ); // move file pointer to end of file
425
453
* file_size = ftell (file ); // get the file size, in bytes
@@ -508,64 +536,77 @@ void softmax(float* x, int size) {
508
536
}
509
537
}
510
538
539
+ // L2E Addition
540
+ #ifdef CAT
541
+
511
542
void matmul (float * xout , QuantizedTensor * x , QuantizedTensor * w , int n , int d ) {
512
543
// W (d,n) @ x (n,) -> xout (d,)
513
544
// by far the most amount of time is spent inside this little function
514
545
// inputs to this function are both quantized
515
546
516
- // L2E Addition
517
-
518
- #ifdef BLAS
519
547
int i ;
520
- int j ;
521
-
522
- // Convert quantized tensors to floating point
523
- float * w_fp = malloc (d * n * sizeof (float ));
524
- float * x_fp = malloc (n * sizeof (float ));
525
-
526
548
#ifdef ACCEL
527
- ACCEL (i , j ) // OMP/OACC Macro
528
- #endif
549
+ ACCEL (i ) // OMP/OACC Macro
550
+ #endif
529
551
for (i = 0 ; i < d ; i ++ ) {
530
- for (j = 0 ; j < n ; j ++ ) {
531
- w_fp [i * n + j ] = ((float ) w -> q [i * n + j ]) * w -> s [i / GS ];
552
+
553
+ float val = 0.0f ;
554
+ int32_t ival = 0 ;
555
+ int in = i * n ;
556
+
557
+ // do the matmul in groups of GS
558
+ int j ;
559
+ for (j = 0 ; j <= n - GS ; j += GS ) {
560
+ // unroll the inner loop by a factor of 4
561
+ for (int k = 0 ; k < GS ; k += 4 ) {
562
+ ival += ((int32_t ) x -> q [j + k ]) * ((int32_t ) w -> q [in + j + k ]);
563
+ ival += ((int32_t ) x -> q [j + k + 1 ]) * ((int32_t ) w -> q [in + j + k + 1 ]);
564
+ ival += ((int32_t ) x -> q [j + k + 2 ]) * ((int32_t ) w -> q [in + j + k + 2 ]);
565
+ ival += ((int32_t ) x -> q [j + k + 3 ]) * ((int32_t ) w -> q [in + j + k + 3 ]);
566
+ }
567
+ val += ((float ) ival ) * w -> s [(in + j ) / GS ] * x -> s [j / GS ];
568
+ ival = 0 ;
532
569
}
533
- }
534
570
535
- #ifdef ACCEL
536
- ACCEL (j ) // OMP/OACC Macro
537
- #endif
538
- for (j = 0 ; j < n ; j ++ ) {
539
- x_fp [j ] = ((float ) x -> q [j ]) * x -> s [j / GS ];
571
+ xout [i ] = val ;
540
572
}
573
+ }
541
574
542
- cblas_sgemv (CblasRowMajor , CblasNoTrans , d , n , 1.0f , w_fp , n , x_fp , 1 , 0.0f , xout , 1 );
543
-
544
- // Free memory
545
- free (w_fp );
546
- free (x_fp );
575
+ #else
576
+ // END L2E Addition
577
+ void matmul (float * xout , QuantizedTensor * x , QuantizedTensor * w , int n , int d ) {
578
+ // W (d,n) @ x (n,) -> xout (d,)
579
+ // by far the most amount of time is spent inside this little function
580
+ // inputs to this function are both quantized
547
581
548
- #else
582
+ int i ;
583
+ // L2E Addition
584
+ #ifdef ACCEL
585
+ ACCEL (i ) // OMP/OACC Macro
586
+ #endif
549
587
// END L2E Addition
550
- for (int i = 0 ; i < d ; i ++ ) {
588
+ for (i = 0 ; i < d ; i ++ ) {
589
+
551
590
float val = 0.0f ;
552
591
int32_t ival = 0 ;
553
592
int in = i * n ;
554
593
555
594
// do the matmul in groups of GS
556
- for (int j = 0 ; j <= n - GS ; j += GS ) {
595
+ int j ;
596
+ for (j = 0 ; j <= n - GS ; j += GS ) {
557
597
for (int k = 0 ; k < GS ; k ++ ) {
558
598
ival += ((int32_t ) x -> q [j + k ]) * ((int32_t ) w -> q [in + j + k ]);
559
599
}
560
600
val += ((float ) ival ) * w -> s [(in + j ) / GS ] * x -> s [j / GS ];
561
601
ival = 0 ;
562
602
}
603
+
563
604
xout [i ] = val ;
564
605
}
606
+ }
565
607
// L2E Addition
566
- #endif
608
+ #endif
567
609
// END L2E Addition
568
- }
569
610
570
611
float * forward (Transformer * transformer , int token , int pos ) {
571
612
0 commit comments