[float8] add _auto_filter_for_recipe for float8 training #1319
base: main
Conversation
Sounds good to me. Thank you for the studies and efforts!
Let's also modify the helper message to reflect this change:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L504
```diff
@@ -25,9 +24,9 @@ class Float8Converter(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False

-        float8_config: Float8 = job_config.float8
+        self.float8_config: Float8 = job_config.float8
```
Having both `self.float8_config` and `self.config` sounds confusing. Can we define `self.filter_fn` in `__init__()` so that we don't need `self.float8_config` or `self.filter_fqns`?
Makes sense, updated.
```python
from torchao.float8 import _auto_filter_for_recipe

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
filter_fn = _auto_filter_for_recipe(
```
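For reference, a minimal sketch of how these pieces could fit together inside the converter. The field names (`job_config.float8`, `recipe_name`, `filter_fqns`) and the use of `Float8LinearConfig.from_recipe_name` / `convert_to_float8_training` are assumptions based on the fragments shown in this review, not the exact torchtitan code:

```python
# Sketch only: field names and helper signatures are assumed, not verbatim.
from torchao.float8 import (
    Float8LinearConfig,
    _auto_filter_for_recipe,
    convert_to_float8_training,
)


class Float8Converter:
    def __init__(self, job_config, parallel_dims):
        self.enabled = False
        float8_config = job_config.float8  # assumed torchtitan config section

        # Assumed: from_recipe_name maps a recipe name like "rowwise" or
        # "tensorwise" to a Float8LinearConfig.
        self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)

        # Build the filter once here (per the review discussion) so the
        # converter doesn't need to keep float8_config or filter_fqns around.
        # Assumed signature: _auto_filter_for_recipe(recipe, filter_fqns)
        # returning a (module, fqn) -> bool callable.
        self.filter_fn = _auto_filter_for_recipe(
            float8_config.recipe_name,
            filter_fqns=float8_config.filter_fqns,
        )
        self.enabled = True

    def convert(self, model):
        # Mutates the model in place, replacing nn.Linear with Float8Linear
        # wherever self.filter_fn returns True.
        convert_to_float8_training(
            model,
            config=self.config,
            module_filter_fn=self.filter_fn,
        )
```

Building the filter once in `__init__` follows the reviewer's suggestion above, so the converter only needs to keep `self.config` and `self.filter_fn`.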
how about mx quantization? would it also suffer from the issue / benefit from auto filtering?
Probably, but we don't have finalized perf numbers to reference for building an auto-filter function for it (like the one added here https://github.com/pytorch/ao/pull/2312/files). We should add an auto-filter option like this for mxfp8 once we can, though.
I think it's better to have this off by default and make it easy to enable, to keep the defaults dead simple. One challenge with this filtering is that it is not aware of …
Makes sense. How about this API to enable the auto filter?

CLI: `torchtitan/train.py ... --float8.filter_fqns="auto_filter"`

toml:

```toml
[float8]
filter_fqns = ["auto_filter"]
```

What do you think? This string could theoretically be part of an FQN, but I think it's unlikely and we could document it clearly.
docs/float8.md
```diff
@@ -17,6 +17,8 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai
 * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
 * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
 * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
+* `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
+* **Auto-filter**: use `--float8.filter_fqns="auto_filter"` to enable automatic module filtering, which will automatically not convert linear layers that are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
```
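To make "not large enough to benefit" concrete, here is a purely illustrative size check of the kind the auto filter applies; the `min_dim` threshold below is a placeholder, not one of torchao's actual recipe-specific values:

```python
import torch.nn as nn


def skip_small_kn(module: nn.Module, fqn: str, min_dim: int = 1024) -> bool:
    """Illustrative filter: return True to convert a module to float8.

    Small K/N GEMMs don't amortize the quantization overhead, so they stay
    in high precision. The real thresholds are recipe-specific and measured
    on H100s; min_dim here is a placeholder.
    """
    if not isinstance(module, nn.Linear):
        return False
    k, n = module.in_features, module.out_features
    return k >= min_dim and n >= min_dim
```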
nit 1: would be good to enable the user to filter out module `foo` and then filter out other modules with the auto filter.
nit 2: would be good to make the flag name more specific, for example `auto_filter_low_kn` instead of `auto_filter`. I guess this applies to torchao as well, sorry for not catching in initial review.
> nit 2: would be good to make the flag name more specific, for example auto_filter_low_kn instead of auto_filter. I guess this applies to torchao as well, sorry for not catching in initial review.

Made the name more explicit: `auto_filter_small_kn`

> nit 1: would be good to enable the user to filter out module foo and then filter out other modules with the auto filter

I agree. I updated it so the API is to just include the `"auto_filter_small_kn"` flag as one of the FQNs, instead of the only one. This way, the rest of the FQNs specified are processed as usual for filtering.
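For illustration, a minimal sketch of how that sentinel could be handled when building the filter; `build_module_filter` and `AUTO_FILTER_FLAG` are hypothetical names, not the actual implementation:

```python
from torchao.float8 import _auto_filter_for_recipe

# Hypothetical sentinel handling: not the actual torchtitan implementation.
AUTO_FILTER_FLAG = "auto_filter_small_kn"


def build_module_filter(recipe_name: str, filter_fqns: list[str]):
    """Turn the filter_fqns config value into a (module, fqn) -> bool callable."""
    if AUTO_FILTER_FLAG in filter_fqns:
        # The sentinel enables recipe-aware size filtering; the remaining
        # entries are still treated as ordinary FQNs to skip.
        user_fqns = [fqn for fqn in filter_fqns if fqn != AUTO_FILTER_FLAG]
        return _auto_filter_for_recipe(recipe_name, filter_fqns=user_fqns)

    # Default behavior: only skip modules whose FQN matches a listed entry.
    return lambda module, fqn: not any(skip in fqn for skip in filter_fqns)
```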
Fixes #1207
Problem
`filter_fqns` for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.

Solution
This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria:
* the user-defined `filter_fqns`
* whether the linear layer is large enough to benefit from float8 for the given recipe (thresholds based on microbenchmarks measured on NVIDIA H100 GPUs)

It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific FQNs.
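As a quick end-to-end illustration on a toy model (the recipe name, layer sizes, and `gate` FQN below are placeholders, and the conversion outcomes are expected behavior under those assumptions, not verified output):

```python
import torch.nn as nn
from torchao.float8 import (
    Float8LinearConfig,
    _auto_filter_for_recipe,
    convert_to_float8_training,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffn = nn.Linear(8192, 8192)   # large: expected to be converted
        self.head = nn.Linear(8192, 128)   # small output dim: expected to be skipped by the size check
        self.gate = nn.Linear(8192, 8192)  # large, but explicitly excluded via filter_fqns below

    def forward(self, x):
        return self.head(self.ffn(x) + self.gate(x))


model = ToyModel()

# Recipe-aware filter that also respects the user-listed FQN.
filter_fn = _auto_filter_for_recipe("rowwise", filter_fqns=["gate"])

# Mutates the model in place, swapping eligible nn.Linear modules for Float8Linear.
convert_to_float8_training(
    model,
    config=Float8LinearConfig.from_recipe_name("rowwise"),
    module_filter_fn=filter_fn,
)
print(model)  # expected: ffn -> Float8Linear; head and gate remain nn.Linear
```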
Results
Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over bf16 TP baseline).
Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16: