Skip to content

Commit 036d7cb

Browse files
committed
runq - remove blas & optimize
runq - optimize matmul and quantization functions with OpenMP
1 parent 8458b68 commit 036d7cb

File tree

2 files changed: +80 -39 lines changed

Diff for: Makefile

+10 -10
Original file line number | Diff line number | Diff line change
@@ -90,55 +90,55 @@ run_cc_openmp: ## - OpenMP accelerated build
9090

9191
.PHONY: runq_cc_openmp
9292
runq_cc_openmp: ## - Same for quantized build
93-
$(CC) -D OPENMP -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -o run
93+
$(CC) -D OPENMP -D CAT -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -o run
9494

9595
.PHONY: run_cc_openacc
9696
run_cc_openacc: ## - OpenACC accelerated build
9797
$(CC) -D OPENACC -Ofast -fopenacc -march=native -mtune=native run.c $(BOLT) -lm -o run
9898

9999
.PHONY: runq_cc_openacc
100100
runq_cc_openacc: ## - Same for quantized build
101-
$(CC) -D OPENACC -Ofast -fopenacc -march=native -mtune=native runq.c $(BOLT) -lm -o run
101+
$(CC) -D OPENACC -D CAT -Ofast -fopenacc -march=native -mtune=native runq.c $(BOLT) -lm -o run
102102

103103
.PHONY: run_cc_omp_gnu
104104
run_cc_omp_gnu: ## - Generic linux distro + OpenMP build
105105
$(CC) -D OPENMP -Ofast -fopenmp -march=native -mtune=native -std=gnu11 run.c $(BOLT) -lm -o run
106106

107107
.PHONY: runq_cc_omp_gnu
108108
runq_cc_omp_gnu: ## - Same for quantized build
109-
$(CC) -D OPENMP -Ofast -fopenmp -march=native -mtune=native -std=gnu11 runq.c $(BOLT) -lm -o run
109+
$(CC) -D OPENMP -D CAT -Ofast -fopenmp -march=native -mtune=native -std=gnu11 runq.c $(BOLT) -lm -o run
110110

111111
.PHONY: run_cc_clblast
112112
run_cc_clblast: ## - CLBlast OpenCL CBLAS GPU accelerated build
113113
$(CC) -D OPENMP -D CLBLAST -Ofast -fopenmp -march=native -mtune=native run.c $(BOLT) -lm -lclblast -o run
114114

115115
.PHONY: runq_cc_clblast
116116
runq_cc_clblast: ## - Same for quantized build
117-
$(CC) -D OPENMP -D CLBLAST -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -lclblast -o run
117+
$(CC) -D OPENMP -D CAT -D CLBLAST -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -lclblast -o run
118118

119119
.PHONY: run_cc_openblas
120120
run_cc_openblas: ## - Openblas CBLAS accelerated build
121121
$(CC) -D OPENMP -D OPENBLAS -Ofast -fopenmp -march=native -mtune=native -I$(OPENBLAS_INC) run.c $(BOLT) -lm -lopenblas -o run
122122

123123
.PHONY: runq_cc_openblas
124124
runq_cc_openblas: ## - Same for quantized build
125-
$(CC) -D OPENMP -D OPENBLAS -Ofast -fopenmp -march=native -mtune=native -I$(OPENBLAS_INC) runq.c $(BOLT) -lm -lopenblas -o run
125+
$(CC) -D OPENMP -D CAT -D OPENBLAS -Ofast -fopenmp -march=native -mtune=native -I$(OPENBLAS_INC) runq.c $(BOLT) -lm -lopenblas -o run
126126

127127
.PHONY: run_cc_cblas
128128
run_cc_cblas: ## - Generic CBLAS accelerated build
129129
$(CC) -D OPENMP -D CBLAS -Ofast -fopenmp -march=native -mtune=native run.c $(BOLT) -lm -lcblas -o run
130130

131131
.PHONY: runq_cc_cblas
132132
runq_cc_cblas: ## - Same for quantized build
133-
$(CC) -D OPENMP -D CBLAS -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -lcblas -o run
133+
$(CC) -D OPENMP -D CAT -D CBLAS -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -lcblas -o run
134134

135135
.PHONY: run_cc_blis
136136
run_cc_blis: ## - BLIS accelerated build
137137
$(CC) -D OPENMP -D BLIS -Ofast -fopenmp -march=native -mtune=native -I$(BLIS_INC) run.c $(BOLT) -lm -lblis -o run
138138

139139
.PHONY: runq_cc_blis
140140
runq_cc_blis: ## - Same for quantized build
141-
$(CC) -D OPENMP -D BLIS -Ofast -fopenmp -march=native -mtune=native -I$(BLIS_INC) runq.c $(BOLT) -lm -lblis -o run
141+
$(CC) -D OPENMP -D CAT -D BLIS -Ofast -fopenmp -march=native -mtune=native -I$(BLIS_INC) runq.c $(BOLT) -lm -lblis -o run
142142

143143
##@ Special Builds
144144
##@ ---> x86_64
@@ -149,7 +149,7 @@ run_cc_mkl: ## - ***NEW*** OpenMP + Intel MKL CBLAS build (x86_64 / intel Mac)
149149

150150
.PHONY: runq_cc_mkl
151151
runq_cc_mkl: ## - Same for quantized build
152-
$(CC) -D MKL -D OPENMP -Ofast -fopenmp -march=native -mtune=native -I$(MKL_INC) -L$(MKL_LIB) runq.c -lmkl_rt -lpthread $(BOLT) -lm -o run
152+
$(CC) -D MKL -D OPENMP -D CAT -Ofast -fopenmp -march=native -mtune=native -I$(MKL_INC) -L$(MKL_LIB) runq.c -lmkl_rt -lpthread $(BOLT) -lm -o run
153153

154154
##@ ---> ARM64 / aarch64
155155
.PHONY: run_cc_armpl
@@ -158,7 +158,7 @@ run_cc_armpl: ## - ARM PL BLAS accelerated build (aarch64)
158158

159159
.PHONY: runq_cc_armpl
160160
runq_cc_armpl: ## - Same for quantized build
161-
$(CC) -D ARMPL -D OPENMP -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -larmpl_lp64_mp -o run
161+
$(CC) -D ARMPL -D OPENMP -D CAT -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -larmpl_lp64_mp -o run
162162

163163
##@ ---> Macintosh
164164
.PHONY: run_cc_mac_accel
@@ -167,7 +167,7 @@ run_cc_mac_accel: ## - Mac OS OPENMP + CBLAS via Accelerate Framework build (WI
167167

168168
.PHONY: runq_cc_mac_accel
169169
runq_cc_mac_accel: ## - Same for quantized build
170-
$(CC) -D AAF -D OPENMP -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -framework Accelerate -o run
170+
$(CC) -D AAF -D OPENMP -D CAT -Ofast -fopenmp -march=native -mtune=native runq.c $(BOLT) -lm -framework Accelerate -o run
171171

172172
##@ ---> Windows
173173
.PHONY: run_win64

Diff for: runq.c

+70 -29
Original file line number | Diff line number | Diff line change
@@ -129,8 +129,10 @@ __static_yoink("zipos");
129129

130130
// Portable OpenMP and OpenACC pragma macros
131131
#ifdef OPENMP
132+
#define ACCELS() MK_PRAGMA(omp parallel for)
132133
#define ACCEL(...) MK_PRAGMA(omp parallel for private(__VA_ARGS__))
133134
#elif defined(OPENACC)
135+
#define ACCELS() MK_PRAGMA(acc parallel loop)
134136
#define ACCEL(...) MK_PRAGMA(acc parallel loop private(__VA_ARGS__))
135137
#endif
136138

@@ -154,7 +156,13 @@ __static_yoink("zipos");
154156
#endif
155157
// ----------------------------------------------------------------------------
156158
// Globals
159+
// L2E Addition
160+
#if defined CAT
161+
const int GS = 64; // group size 64 for Cheap Acceleration Tech :)
162+
#else
157163
int GS = 0; // group size global for quantization of the weights
164+
#endif
165+
// END L2E Addition
158166

159167
// ----------------------------------------------------------------------------
160168
// Transformer model
@@ -275,6 +283,11 @@ void free_run_state(RunState* s) {
275283
// Quantization functions
276284

277285
void dequantize(QuantizedTensor *qx, float* x, int n) {
286+
// L2E Addition
287+
#ifdef ACCEL
288+
ACCELS() // OMP/OACC Macro
289+
#endif
290+
// END L2E Addition
278291
for (int i = 0; i < n; i++) {
279292
x[i] = qx->q[i] * qx->s[i / GS];
280293
}
@@ -284,6 +297,11 @@ void quantize(QuantizedTensor *qx, float* x, int n) {
284297
int num_groups = n / GS;
285298
float Q_MAX = 127.0f;
286299

300+
// L2E Addition
301+
#ifdef ACCEL
302+
ACCELS() // OMP/OACC Macro
303+
#endif
304+
// END L2E Addition
287305
for (int group = 0; group < num_groups; group++) {
288306

289307
// find the max absolute value in the current group
@@ -391,7 +409,11 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
391409
int group_size = *(int*) ptr;
392410
ptr += sizeof(int);
393411

412+
// L2E Addition
413+
#ifndef CAT
394414
GS = group_size; // set as global, as it will be used in many places
415+
#endif
416+
// END L2E Addition
395417

396418
void* weights_ptr = ((char*)*data) + header_size; // skip header bytes
397419
memory_map_weights(weights, config, weights_ptr, shared_classifier);
@@ -419,7 +441,13 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
419441
if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { exit(EXIT_FAILURE); }
420442
int group_size; // the group size used in quantization
421443
if (fread(&group_size, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
444+
445+
// L2E Addition
446+
#ifndef CAT
422447
GS = group_size; // set as global, as it will be used in many places
448+
#endif
449+
// END L2E Addition
450+
423451
// figure out the file size
424452
fseek(file, 0, SEEK_END); // move file pointer to end of file
425453
*file_size = ftell(file); // get the file size, in bytes
@@ -508,64 +536,77 @@ void softmax(float* x, int size) {
508536
}
509537
}
510538

539+
// L2E Addition
540+
#ifdef CAT
541+
511542
void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
512543
// W (d,n) @ x (n,) -> xout (d,)
513544
// by far the most amount of time is spent inside this little function
514545
// inputs to this function are both quantized
515546

516-
// L2E Addition
517-
518-
#ifdef BLAS
519547
int i;
520-
int j;
521-
522-
// Convert quantized tensors to floating point
523-
float* w_fp = malloc(d * n * sizeof(float));
524-
float* x_fp = malloc(n * sizeof(float));
525-
526548
#ifdef ACCEL
527-
ACCEL(i, j) // OMP/OACC Macro
528-
#endif
549+
ACCEL(i) // OMP/OACC Macro
550+
#endif
529551
for (i = 0; i < d; i++) {
530-
for (j = 0; j < n; j++) {
531-
w_fp[i * n + j] = ((float) w->q[i * n + j]) * w->s[i / GS];
552+
553+
float val = 0.0f;
554+
int32_t ival = 0;
555+
int in = i * n;
556+
557+
// do the matmul in groups of GS
558+
int j;
559+
for (j = 0; j <= n - GS; j += GS) {
560+
// unroll the inner loop by a factor of 4
561+
for (int k = 0; k < GS; k += 4) {
562+
ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
563+
ival += ((int32_t) x->q[j + k + 1]) * ((int32_t) w->q[in + j + k + 1]);
564+
ival += ((int32_t) x->q[j + k + 2]) * ((int32_t) w->q[in + j + k + 2]);
565+
ival += ((int32_t) x->q[j + k + 3]) * ((int32_t) w->q[in + j + k + 3]);
566+
}
567+
val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
568+
ival = 0;
532569
}
533-
}
534570

535-
#ifdef ACCEL
536-
ACCEL(j) // OMP/OACC Macro
537-
#endif
538-
for (j = 0; j < n; j++) {
539-
x_fp[j] = ((float) x->q[j]) * x->s[j / GS];
571+
xout[i] = val;
540572
}
573+
}
541574

542-
cblas_sgemv(CblasRowMajor, CblasNoTrans, d, n, 1.0f, w_fp, n, x_fp, 1, 0.0f, xout, 1);
543-
544-
// Free memory
545-
free(w_fp);
546-
free(x_fp);
575+
#else
576+
// END L2E Addition
577+
void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
578+
// W (d,n) @ x (n,) -> xout (d,)
579+
// by far the most amount of time is spent inside this little function
580+
// inputs to this function are both quantized
547581

548-
#else
582+
int i;
583+
// L2E Addition
584+
#ifdef ACCEL
585+
ACCEL(i) // OMP/OACC Macro
586+
#endif
549587
// END L2E Addition
550-
for (int i = 0; i < d; i++) {
588+
for (i = 0; i < d; i++) {
589+
551590
float val = 0.0f;
552591
int32_t ival = 0;
553592
int in = i * n;
554593

555594
// do the matmul in groups of GS
556-
for (int j = 0; j <= n - GS; j += GS) {
595+
int j;
596+
for (j = 0; j <= n - GS; j += GS) {
557597
for (int k = 0; k < GS; k++) {
558598
ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
559599
}
560600
val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
561601
ival = 0;
562602
}
603+
563604
xout[i] = val;
564605
}
606+
}
565607
// L2E Addition
566-
#endif
608+
#endif
567609
// END L2E Addition
568-
}
569610

570611
float* forward(Transformer* transformer, int token, int pos) {
571612

0 commit comments

Comments
 (0)