Skip to content

Commit 8b6fe9e

Browse files
Test cleanup (#1576)
* Testing cleanup * More test cleanup * Additional deprecations/removals. * Skip benchmark, deprecated, slow tests by default
1 parent 677ff40 commit 8b6fe9e

12 files changed

+462
-684
lines changed

benchmarking/optimizer_benchmark.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
Extracted from tests/test_optim.py
3+
4+
Usage: pytest benchmarking/optimizer_benchmark.py
5+
"""
6+
7+
import time
8+
9+
import pytest
10+
from tests.helpers import describe_dtype, id_formatter
11+
import torch
12+
13+
import bitsandbytes as bnb
14+
15+
# Maps a benchmark name to a (reference torch optimizer, bitsandbytes paged optimizer) pair.
str2optimizers = {"paged_adamw": (torch.optim.AdamW, bnb.optim.PagedAdamW)}
16+
17+
18+
@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
    """Benchmark ``optimizer.step()`` latency under GPU memory pressure.

    Builds a stack of 10 ``Linear(dim1, dim1)`` layers on CUDA and runs
    ``num_batches`` forward/backward/step iterations. In the non-"torch"
    ("bnb") mode, a large filler tensor is allocated up front so that the
    paged optimizer state is forced to page between host and device.
    The timer starts at batch index 2 (first two batches are warm-up) and
    the elapsed time for the remaining batches is printed.

    Args:
        dim1: width of each Linear layer (in features == out features).
        gtype: parameter/activation dtype for the model.
        optim_name: key into ``str2optimizers`` selecting the optimizer pair.
        mode: "torch" for the reference optimizer, anything else for bnb.
    """
    # ModuleList wrapper around the comprehension was redundant; Sequential
    # accepts the unpacked modules directly.
    layers1 = torch.nn.Sequential(*[torch.nn.Linear(dim1, dim1) for _ in range(10)])
    layers1 = layers1.to(gtype)
    layers1 = layers1.cuda()

    large_tensor = None
    torch_optim_cls, bnb_optim_cls = str2optimizers[optim_name]
    if mode == "torch":
        optim = torch_optim_cls(layers1.parameters())
    else:
        optim = bnb_optim_cls(layers1.parameters())
        # Occupy most of device memory so the paged optimizer actually pages.
        # Keep the reference alive (large_tensor) for the whole benchmark.
        # NOTE(review): the original comment said "12 GB", but 4.5e9 float32
        # elements is ~18 GB -- confirm the intended allocation size.
        large_tensor = torch.empty((int(4.5e9),), device="cuda")

    torch.cuda.synchronize()
    time.sleep(5)  # let the allocation/paging settle before timing

    num_batches = 5
    batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()

    for i in range(num_batches):
        print(i)
        b = batches[i]
        if i == 2:
            # Warm-up done (num_batches >= 3, so t0 is always bound below).
            torch.cuda.synchronize()
            t0 = time.time()

        out1 = layers1(b)

        # cross_entropy reduces to a scalar mean by default, so the
        # original's trailing .mean() was a no-op and has been dropped.
        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i])
        loss1.backward()
        # NOTE(review): gradients are never zeroed between batches, so each
        # step sees accumulated grads -- harmless for timing, but confirm
        # this is the intended workload.
        optim.step()
    torch.cuda.synchronize()
    print(mode, time.time() - t0)

0 commit comments

Comments
 (0)