Commit 0eb0616

update float8 training readme to include time measurement
Summary: Update the float8 training example code snippet to include a time measurement that properly excludes the one-time torch.compile warmup. Also, use larger shapes to demonstrate the speedup from float8.

Test Plan: Copy-paste the snippet and run it; it works. Commenting out the float8 conversion shows a slowdown, as expected.

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent dd43f16 commit 0eb0616

File tree: 1 file changed (+20 −3 lines)

torchao/float8/README.md

Lines changed: 20 additions & 3 deletions
@@ -17,6 +17,8 @@ and composable with key systems such as autograd, ```torch.compile``` and distri
 This is the default recipe, with a good balance of performance and accuracy.
 
 ```python
+import time
+
 import torch
 import torch.nn as nn
 from torchao.float8 import convert_to_float8_training
@@ -26,11 +28,12 @@ if not TORCH_VERSION_AT_LEAST_2_5:
     raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
 
 # create model and sample input
+M, K, N = 4096, 8192, 4096
 m = nn.Sequential(
-    nn.Linear(2048, 4096),
-    nn.Linear(4096, 128),
+    nn.Linear(K, N, bias=False),
+    nn.Linear(N, 128, bias=False),
 ).bfloat16().cuda()
-x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
 # optional: filter modules from being eligible for float8 conversion
@@ -50,12 +53,26 @@ convert_to_float8_training(m, module_filter_fn=module_filter_fn)
 # enable torch.compile for competitive performance
 m = torch.compile(m)
 
+# warm up torch.compile for a clean training time measurement
+for _ in range(1):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+
+torch.cuda.synchronize()
+start_time = time.time()
+
 # toy training loop
 for _ in range(10):
     optimizer.zero_grad()
     y = m(x)
     y.sum().backward()
     optimizer.step()
+
+torch.cuda.synchronize()
+end_time = time.time()
+print("Training time:", end_time - start_time)
 ```
 
 ## float8 linear with rowwise scaling
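For context: the `module_filter_fn` passed to `convert_to_float8_training` in the last hunk is defined earlier in the README and elided from this diff. A minimal sketch of such a filter, assuming the callback receives each module and its fully qualified name and returns whether that module should be converted:

```python
import torch

# Hypothetical filter, reconstructed for illustration only: assumes
# convert_to_float8_training calls it as filter(module, fqn) and converts
# only the modules for which it returns True.
def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    # skip the second nn.Linear in the nn.Sequential above ("1" is its
    # fully qualified name); its 128-wide output is too small to benefit
    return fqn != "1"
```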
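A note on the measurement pattern the diff adds: `time.time()` bracketed by `torch.cuda.synchronize()` measures host wall-clock time across the loop. An equivalent sketch using CUDA events, which record timestamps on the GPU stream itself, reusing `m`, `x`, and `optimizer` from the snippet above (not part of the README change, just an alternative):

```python
# Device-side timing with torch.cuda.Event instead of host time.time().
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
end.record()

torch.cuda.synchronize()  # ensure both events have completed before reading
print("Training time (ms):", start.elapsed_time(end))
```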
