Skip to content

[float8] add _auto_filter_for_recipe for float8 training #1319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

danielvegamyhre
Copy link
Contributor

@danielvegamyhre danielvegamyhre commented Jun 18, 2025

Fixes #1207

Problem

  • float8 rowwise + vanilla TP in torchtitan had flat perf with respect to bfloat16 (see float8 rowwise vanilla TP low throughput #1207).
  • RCA In float8 rowwise vanilla TP low throughput #1207 found attention.wk and attention.wv layers were so small that float8 rowwise conversion resulted in approx ~40% slowdown for those layers, which nullified the perf benefits from fp8 rowwise conversion on larger linears.
  • This is because the default filter_fqns for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.

Solution

This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria:

  1. dims not divisible by 16 (hardware requirement for float8)
  2. dim sizes below thresholds that may result in worse perf for that given recipe, using simple heuristics based on the linked recipe perf tables above.
  3. fqn matches one of the user defined filter_fqns

It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns.

Results

Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over bf16 TP baseline).

Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16:

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 18, 2025
@danielvegamyhre danielvegamyhre changed the title [WIP] [float8] add float auto_filter_for_recipe [float8] add float auto_filter_for_recipe Jun 18, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft June 18, 2025 22:13
@danielvegamyhre danielvegamyhre changed the title [float8] add float auto_filter_for_recipe [WIP] [float8] add float auto_filter_for_recipe Jun 18, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review June 23, 2025 21:14
@danielvegamyhre danielvegamyhre changed the title [WIP] [float8] add float auto_filter_for_recipe [float8] add float auto_filter_for_recipe Jun 24, 2025
@danielvegamyhre danielvegamyhre changed the title [float8] add float auto_filter_for_recipe [float8] add float8 _auto_filter_for_recipe Jun 24, 2025
@danielvegamyhre danielvegamyhre changed the title [float8] add float8 _auto_filter_for_recipe [float8] add _auto_filter_for_recipe for float8 training Jun 24, 2025
@danielvegamyhre
Copy link
Contributor Author

cc @tianyu @vkuzo for review + thoughts on if this would be useful to add as the default module filter for float8 in torchtitan

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me. Thank you for the studies and efforts!

Let's also modify helper message to reflect this change
https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L504

@@ -25,9 +24,9 @@ class Float8Converter(ModelConverter):
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False

float8_config: Float8 = job_config.float8
self.float8_config: Float8 = job_config.float8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having both self.float8_config and self.config sounds confusing.
Can we define self.filter_fn in __init__() so that we don't need self.float8_config or self.filter_fqns?

Copy link
Contributor Author

@danielvegamyhre danielvegamyhre Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, updated.

from torchao.float8 import _auto_filter_for_recipe

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
filter_fn = _auto_filter_for_recipe(
Copy link
Contributor

@tianyu-l tianyu-l Jun 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about mx quantization? would it also suffer from the issue / benefit from auto filtering?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably, but we don't have finalized perf numbers to reference to make an autofilter function for it (like the one added here https://github.com/pytorch/ao/pull/2312/files). We should add an auto filter option like this for mxfp8 once we can though.

@vkuzo
Copy link
Contributor

vkuzo commented Jun 25, 2025

I think it's better to have this off by default and make it easy to enable, to keep the defaults dead simple. Some challenges with this filtering is that it is not aware of M, it is not aware of the underlying hardware, and it will behave unexpectedly on the debug model. How about we just make this easy to enable and add documentation recommending to enable it?

@danielvegamyhre
Copy link
Contributor Author

danielvegamyhre commented Jun 27, 2025

I think it's better to have this off by default and make it easy to enable, to keep the defaults dead simple. Some challenges with this filtering is that it is not aware of M, it is not aware of the underlying hardware, and it will behave unexpectedly on the debug model. How about we just make this easy to enable and add documentation recommending to enable it?

Makes sense. How about this API to enable the auto filter:

torchtitan/train.py ... --float8.filter_fqns="auto_filter"

toml:

[float8]
filter_fqns = ["auto_filter"]

What do you think? This string could theoretically be part of a FQN but I think it's unlikely and we could document it clearly.

@danielvegamyhre danielvegamyhre force-pushed the auto_filter branch 4 times, most recently from c8d811f to 78d91f3 Compare June 27, 2025 06:41
docs/float8.md Outdated
@@ -17,6 +17,8 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
* `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
* `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
* **Auto-filter**: use `--float8.filter_fqns="auto_filter"` to enable automatic module filtering, which will automatically not convert linear layers that are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit 1: would be good to enable the user to filter out module foo and then filter out other modules with the auto filter
nit 2: would be good to make the flag name more specific, for example auto_filter_low_kn instead of auto_filter. I guess this applies to torchao as well, sorry for not catching in initial review.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit 2: would be good to make the flag name more specific, for example auto_filter_low_kn instead of auto_filter. I guess this applies to torchao as well, sorry for not catching in initial review.

Made the name more explicit: auto_filter_small_kn

nit 1: would be good to enable the user to filter out module foo and then filter out other modules with the auto filter

I agree, I updated it so the API is to just include "auto_filter_small_kn" flag as one of the FQNs, instead of the only one. This way, the rest of the FQNs specified are processed as usual for filtering.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

float8 rowwise vanilla TP low throughput
4 participants