-
Notifications
You must be signed in to change notification settings - Fork 410
/
Copy pathswish.py
97 lines (87 loc) · 2.58 KB
/
swish.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import time
from typing import Callable, Optional

import torch
from torch.utils.cpp_extension import load
# Every timed call below is forward-only, so autograd tracking is disabled globally.
torch.set_grad_enabled(False)

# JIT-compile swish.cu into an importable extension module.
# The source path is resolved relative to this script rather than the current
# working directory, so the benchmark can be launched from any directory.
lib = load(
    name="swish_lib",
    sources=[os.path.join(os.path.dirname(os.path.abspath(__file__)), "swish.cu")],
    extra_cuda_cflags=[
        "-O3",
        # torch's default nvcc flags define the __CUDA_NO_HALF*/BFLOAT16* macros;
        # undefining them re-enables the half/half2/bfloat16 operator overloads
        # and implicit conversions that the fp16 kernels rely on.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        # Allow constexpr host functions and __device__ lambdas in device code.
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        # Fast approximate math intrinsics (trades some precision for speed).
        "--use_fast_math",
    ],
    extra_cflags=["-std=c++17"],
)
def run_benchmark(
    perf_func: Callable[..., Optional[torch.Tensor]],
    x: torch.Tensor,
    tag: str,
    out: Optional[torch.Tensor] = None,
    warmup: int = 10,
    iters: int = 1000,
    show_all: bool = False,
):
    """Time ``perf_func`` on ``x`` and print the first two output values.

    The calling convention is chosen once, from ``out`` as passed in:

    * ``out`` given: ``out`` is zeroed, then ``perf_func(x, out)`` is called and
      is expected to write its result into ``out``.
    * ``out`` is None: ``out = perf_func(x)`` is called and the returned tensor
      is taken as the result.

    Args:
        perf_func: kernel or reference implementation to time.
        x: input tensor passed to every call.
        tag: label used in the printed line (``out_<tag>``).
        out: optional preallocated output tensor.
        warmup: number of untimed calls made before timing starts.
        iters: number of timed calls; the reported time is their mean.
        show_all: if True, also print the full output tensor.

    Returns:
        ``(out, mean_time)`` where ``mean_time`` is milliseconds per call.

    Raises:
        ValueError: if ``iters < 1`` (no mean time and no result to report).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # Fix the mode up front: the warmup loop below assigns `out` in the
    # functional mode, which must not switch the timed loop to `perf_func(x, out)`.
    inplace = out is not None
    # Kernel launches are asynchronous; synchronize so the wall-clock interval
    # covers the queued GPU work. Skipped on machines without CUDA.
    use_cuda = torch.cuda.is_available()

    # warmup
    if inplace:
        out.fill_(0)
        for _ in range(warmup):
            perf_func(x, out)
    else:
        for _ in range(warmup):
            out = perf_func(x)
    if use_cuda:
        torch.cuda.synchronize()

    # timed iterations (perf_counter: monotonic, high-resolution clock)
    start = time.perf_counter()
    if inplace:
        for _ in range(iters):
            perf_func(x, out)
    else:
        for _ in range(iters):
            out = perf_func(x)
    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"out_{tag}"
    # Slice before copying so only two elements are transferred to the host.
    out_val = [f"{round(v, 8):<12}" for v in out.detach().flatten()[:2].tolist()]
    print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
    if show_all:
        print(out)
    return out, mean_time
def torch_swish(x, out=None):
    """PyTorch reference for swish (SiLU): ``x * sigmoid(x)``.

    Without ``out`` a new tensor is returned. With ``out`` the result is
    written into it (sigmoid first, then an in-place multiply by ``x``), and
    ``out`` itself is returned, so no output tensor is allocated.

    NOTE(review): the ``out`` path assumes ``out`` does not alias ``x`` —
    the sigmoid would overwrite ``x`` before the multiply reads it.
    """
    if out is None:
        gate = torch.sigmoid(x)
        return x * gate
    torch.sigmoid(x, out=out)
    # In-place ops return the tensor they modified, i.e. `out` itself.
    return out.mul_(x)
# Sweep every (S, K) input shape: S values from Ss, K values from Ks.
Ss = [1024, 2048, 4096]
Ks = [1024, 2048, 4096]
SKs = [(S, K) for S in Ss for K in Ks]
for S, K in SKs:
    print("-" * 85)
    print(" " * 40 + f"S={S}, K={K}")

    # fp32 input on the GPU plus a zeroed output buffer; zeros_like already
    # matches x's device, dtype and (contiguous) layout.
    x = torch.randn((S, K)).cuda().float().contiguous()
    y = torch.zeros_like(x)
    for kernel, label in (
        (lib.swish_f32, "f32"),
        (lib.swish_f32x4, "f32x4"),
        (torch_swish, "f32_th"),
    ):
        run_benchmark(kernel, x, label, y)
    print("-" * 85)

    # The same values converted to half precision for the fp16 variants.
    x_f16 = x.half().contiguous()
    y_f16 = y.half().contiguous()
    for kernel, label in (
        (lib.swish_f16, "f16"),
        (lib.swish_f16x2, "f16x2"),
        (lib.swish_f16x8, "f16x8"),
        (lib.swish_f16x8_pack, "f16x8pack"),
        (torch_swish, "f16_th"),
    ):
        run_benchmark(kernel, x_f16, label, y_f16)
    print("-" * 85)