@@ -66,7 +66,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
     print(
-        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
+        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s",
     )
 
     # torch.cuda.synchronize()
@@ -88,22 +88,24 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
-    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
 
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
         bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
     torch.cuda.synchronize()
-    print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(
+        f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
+    )
 
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
         bnb.matmul(A, B)
     torch.cuda.synchronize()
     print(
-        f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )
 
     torch.cuda.synchronize()
@@ -112,7 +114,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         bnb.matmul(A, B, threshold=6.0)
     torch.cuda.synchronize()
     print(
-        f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )
 
     CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0)
@@ -124,7 +126,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         out32 = F.int8_linear_matmul(CA, CB)
     torch.cuda.synchronize()
     print(
-        f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
    )
 
     # C32A, SA = F.transform(CA, "col32")
@@ -183,7 +185,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         linear8bit(A)
     torch.cuda.synchronize()
     print(
-        f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )
 
     linearMixedBit(A)
@@ -193,7 +195,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         linearMixedBit(A)
     torch.cuda.synchronize()
     print(
-        f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )
 
     # linear8bit_train(A)
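
For context, the setup for these operands lives outside the hunks shown: `B_nf4`/`state_nf4` and `B_nf4_c`/`state_nf4_c` are the NF4-quantized weight and its quantization state (the `_c` pair presumably quantized with compressed, i.e. double-quantized, statistics), and every benchmarked section times a fixed number of iterations between two `torch.cuda.synchronize()` calls so the measurement covers the queued GPU work. Below is a minimal standalone sketch of that pattern, with illustrative shapes and an assumed `F.quantize_4bit` setup rather than the test's actual parametrization:

import time

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

# Illustrative sizes; the real test parametrizes batch, seq, model, hidden.
batch, seq, model, hidden = 1, 128, 1024, 4096
iters = 100

A = torch.randn(batch, seq, model, device="cuda", dtype=torch.float16)
B = torch.randn(hidden, model, device="cuda", dtype=torch.float16)

# Assumed setup: NF4-quantize the weight once; compress_statistics=True is
# presumably the "nf4+DQ" (double-quantized statistics) variant timed above.
B_nf4, state_nf4 = F.quantize_4bit(B, quant_type="nf4")
B_nf4_c, state_nf4_c = F.quantize_4bit(B, quant_type="nf4", compress_statistics=True)

# Timing pattern used throughout the benchmark: drain pending kernels, time a
# fixed iteration count, then synchronize again before reading the clock.
torch.cuda.synchronize()
t0 = time.time()
for _ in range(iters):
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
torch.cuda.synchronize()
print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")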