
update float8 training readme to include time measurement #2291


Merged 1 commit on Jun 4, 2025
23 changes: 20 additions & 3 deletions torchao/float8/README.md
@@ -17,6 +17,8 @@ and composable with key systems such as autograd, ```torch.compile``` and distri
 This is the default recipe, with a good balance of performance and accuracy.
 
 ```python
+import time
+
 import torch
 import torch.nn as nn
 from torchao.float8 import convert_to_float8_training
@@ -26,11 +28,12 @@ if not TORCH_VERSION_AT_LEAST_2_5:
     raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
 
 # create model and sample input
+M, K, N = 4096, 8192, 4096
 m = nn.Sequential(
-    nn.Linear(2048, 4096),
-    nn.Linear(4096, 128),
+    nn.Linear(K, N, bias=False),
+    nn.Linear(N, 128, bias=False),
 ).bfloat16().cuda()
-x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
 # optional: filter modules from being eligible for float8 conversion
@@ -50,12 +53,26 @@ convert_to_float8_training(m, module_filter_fn=module_filter_fn)
 # enable torch.compile for competitive performance
 m = torch.compile(m)
 
+# warm up torch.compile for a clean training time measurement
+for _ in range(1):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+
+torch.cuda.synchronize()
+start_time = time.time()
+
 # toy training loop
 for _ in range(10):
     optimizer.zero_grad()
     y = m(x)
     y.sum().backward()
     optimizer.step()
+
+torch.cuda.synchronize()
+end_time = time.time()
+print("Training time:", end_time - start_time)
 ```
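The warm-up / synchronize / measure pattern the diff adds generalizes beyond this README: run a few untimed iterations so `torch.compile` finishes compiling, flush pending GPU work before starting and stopping the clock, then time a fixed number of steps. As a minimal sketch of that pattern (the `time_loop` helper and the stand-in workload below are hypothetical, not part of torchao), the structure can be factored into a reusable function; on GPU you would pass `torch.cuda.synchronize` as the `sync` argument:

```python
import time

def time_loop(step, warmup=1, iters=10, sync=None):
    """Time `iters` calls to `step`, after `warmup` untimed calls.

    `sync` is an optional callable (e.g. torch.cuda.synchronize) invoked
    before starting and after stopping the clock, so asynchronously
    queued work is flushed and not misattributed.
    """
    for _ in range(warmup):
        step()
    if sync is not None:
        sync()
    start = time.time()
    for _ in range(iters):
        step()
    if sync is not None:
        sync()
    return time.time() - start

# usage with a CPU stand-in workload in place of a training step
elapsed = time_loop(lambda: sum(i * i for i in range(10_000)))
print(f"10 iterations took {elapsed:.4f}s")
```

Without the warm-up, the first timed iteration would include one-time compilation cost; without the synchronize calls, `time.time()` would return before queued CUDA kernels actually finish.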

## float8 linear with rowwise scaling