"""
Extracted from tests/test_optim.py.

Usage: pytest benchmarking/optimizer_benchmark.py
"""
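# Note (not part of the original file): the timings below are emitted with ``print``,
# so running pytest with output capture disabled makes them visible, e.g.:
#
#     pytest benchmarking/optimizer_benchmark.py -s
#
# The ``benchmark`` marker used below is assumed to be registered in the project's
# pytest configuration; if it is not, pytest only emits an unknown-marker warning.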

import time

import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import describe_dtype, id_formatter

# Map each optimizer name to (PyTorch reference implementation, bitsandbytes paged implementation).
str2optimizers = {"paged_adamw": (torch.optim.AdamW, bnb.optim.PagedAdamW)}


@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
    # A simple stack of 10 Linear layers as the benchmark workload.
    layers1 = torch.nn.Sequential(*[torch.nn.Linear(dim1, dim1) for _ in range(10)])
    layers1 = layers1.to(gtype)
    layers1 = layers1.cuda()

    large_tensor = None
    if mode == "torch":
        optim = str2optimizers[optim_name][0](layers1.parameters())
    else:
        optim = str2optimizers[optim_name][1](layers1.parameters())
        # ~18 GB of float32 (4.5e9 elements); keeping a reference occupies GPU memory
        # for the whole run so the paged optimizer cannot keep all of its state resident.
        large_tensor = torch.empty((int(4.5e9),), device="cuda")

    torch.cuda.synchronize()
    time.sleep(5)

    # Synthetic data: num_batches batches of 128 samples each.
    num_batches = 5
    batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()

    for i in range(num_batches):
        print(i)
        b = batches[i]
        if i == 2:
            # The first two batches serve as warm-up; timing starts with the third.
            torch.cuda.synchronize()
            t0 = time.time()

        out1 = layers1(b)

        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i])
        loss1.backward()
        optim.step()

    # Total wall-clock time for the timed (non-warm-up) batches.
    torch.cuda.synchronize()
    print(mode, time.time() - t0)
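# Only the "bnb" mode is exercised above, so the torch.optim.AdamW entry in
# ``str2optimizers`` is currently unused. A minimal way to also time the PyTorch
# baseline (a suggestion, not part of the original file) is to widen the ``mode``
# parametrization and compare the printed timings:
#
#     @pytest.mark.parametrize("mode", ["torch", "bnb"], ids=id_formatter("mode"))
#
# Note that the large memory-pressure tensor is only allocated in the "bnb" branch,
# so the two modes do not run under identical GPU memory conditions.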