[float8] add _auto_filter_for_recipe for float8 training #1319
@@ -17,6 +17,8 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduces, one per parameter.
* `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during the backward pass, preventing unsharded fp8 weights from being saved for backward.
* `--float8.filter_fqns="..."` (optional): a comma-separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
* **Auto-filter**: add `"auto_filter_small_kn"` to the `--float8.filter_fqns=...` list to enable automatic module filtering, which skips converting linear layers whose K and N dimensions are too small to benefit from float8 training. The conversion thresholds are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
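For reference, here is a minimal sketch of the kind of callback that `--float8.filter_fqns` feeds into, assuming the `(module, fqn) -> bool` convention that torchao's `convert_to_float8_training` uses for `module_filter_fn`. This is an illustration, not the exact torchtitan implementation; the extra `filter_fqns` argument would be bound with `functools.partial` before the function is passed in, as the converter code below does.

```python
import torch.nn as nn

def example_filter_fn(module: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
    """Return True to convert `module` to float8, False to skip it."""
    if not isinstance(module, nn.Linear):
        return False
    # skip any module whose fully qualified name matches a user-provided fqn,
    # e.g. "attention.wk" matches "layers.0.attention.wk"
    return not any(filter_fqn in fqn for filter_fqn in filter_fqns)
```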
For float8 with rowwise scaling, launch the training job with the following command (or alternatively set the configs in toml files)
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 from functools import partial
 
 import torch
@@ -20,6 +19,8 @@
 
 from .utils import module_filter_fn
 
+AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"
+
 
 class Float8Converter(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
@@ -52,15 +53,18 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             return
 
         self.enabled = True
         self.filter_fqns = float8_config.filter_fqns
 
         if float8_config.recipe_name is not None:
-            assert (
-                not float8_config.enable_fsdp_float8_all_gather
-            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
-            assert (
-                not float8_config.force_recompute_fp8_weight_in_bwd
-            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
+            assert not float8_config.enable_fsdp_float8_all_gather, (
+                "using `float8_config.enable_fsdp_float8_all_gather` together "
+                "with `float8_config.recipe_name` is not supported"
+            )
+
+            assert not float8_config.force_recompute_fp8_weight_in_bwd, (
+                "using `float8_config.force_recompute_fp8_weight_in_bwd` together "
+                "with `float8_config.recipe_name` is not supported"
+            )
 
             self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
             self.precompute_scale = False
             logger.info(
@@ -73,7 +77,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 logger.debug(
                     "Set torch._inductor.config.emulate_precision_casts to True"
                 )
 
         else:
             # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
             enable_fsdp_float8_all_gather = (
@@ -92,6 +95,50 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             )
             logger.info("Float8 tensorwise scaled training active")
 
+        # configure the module filter function
+        self.filter_fn = self._init_filter_fn(float8_config)
+
+    def _init_filter_fn(self, float8_config: Float8):
+        # use the auto filter if "auto_filter_small_kn" is one of the given filter fqns
+        use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in float8_config.filter_fqns
+        if use_auto_filter:
+            try:
+                from torchao.float8 import _auto_filter_for_recipe
+
+                logger.info(
+                    "Using automatic module filter for float8 model conversion."
+                )
+
+                recipe_name = (
+                    float8_config.recipe_name
+                    if float8_config.recipe_name
+                    else "tensorwise"
+                )
+
+                # remove the auto filter flag from filter_fqns before passing them to _auto_filter_for_recipe
+                fqns = [
+                    fqn
+                    for fqn in float8_config.filter_fqns
+                    if fqn != AUTO_FILTER_SMALL_KN_FLAG
+                ]
+
+                filter_fn = _auto_filter_for_recipe(
+                    recipe_name,
+                    filter_fqns=fqns,
+                )
+                return filter_fn
+            except ImportError:
+                logger.warning(
+                    "Using default module_filter_fn for float8 model conversion. "
+                    "To use _auto_filter_for_recipe, please install torchao nightly build."
+                )
+
+        # use default filter func
+        filter_fn = partial(module_filter_fn, filter_fqns=float8_config.filter_fqns)
+        return filter_fn
Review comment on lines +139 to +140: maybe merge the two lines
+
     def convert(self, model: nn.Module):
         """
         This function converts the linear layers of `model` to `Float8Linear`.
@@ -103,11 +150,10 @@ def convert(self, model: nn.Module):
 
         from torchao.float8 import convert_to_float8_training
 
         # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
         convert_to_float8_training(
             model,
             config=self.config,
-            module_filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns),
+            module_filter_fn=self.filter_fn,
         )
         logger.info(
             "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
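To make the new filter selection concrete, here is a small, self-contained sketch of how the `"auto_filter_small_kn"` flag is detected and stripped before the remaining fqns are handed to `_auto_filter_for_recipe`. The config values are hypothetical, for illustration only.

```python
AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"

# hypothetical values, e.g. from --float8.filter_fqns="auto_filter_small_kn,output"
filter_fqns = ["auto_filter_small_kn", "output"]

use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in filter_fqns
fqns = [fqn for fqn in filter_fqns if fqn != AUTO_FILTER_SMALL_KN_FLAG]

print(use_auto_filter)  # True  -> _auto_filter_for_recipe(recipe_name, filter_fqns=fqns)
print(fqns)             # ['output']
```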
Could you educate me more on what the K and N dimensions are, and why float8 doesn't benefit much if K and N are not large enough?
Users might also have such doubts, so it might be good to explain a bit more in the short manual.
Sure, I'll add some more info to this doc, but basically the K and N dimensions refer to the GEMM operation between the inputs and weights of a linear layer: (M,K) @ (K,N) = (M,N). So in this context, the linear layer has shape (K,N). (Technically, the weight is stored as (N,K) row-major and is transposed for the matmul X @ W^T.)
Our microbenchmarking shows there are certain size thresholds for the linear layer's K and N below which the performance of the fp8 linear was always worse than bf16. Basically, the GEMMs have to be big enough that the speedup from using FP8 tensor cores outweighs the overhead of creating dynamically quantized inputs.
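To make the shape bookkeeping concrete, here is a small PyTorch example (the sizes are arbitrary): for `nn.Linear(K, N)`, K is `in_features`, N is `out_features`, and M is the number of input rows (batch * sequence).

```python
import torch
import torch.nn as nn

M, K, N = 8192, 4096, 1024
layer = nn.Linear(K, N, bias=False)

x = torch.randn(M, K)        # activations: (M, K)
y = layer(x)                 # computes x @ weight.T

print(layer.weight.shape)    # torch.Size([1024, 4096]) -> stored as (N, K), row-major
print(y.shape)               # torch.Size([8192, 1024]) -> (M, N)
```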
The thresholds are different for tensorwise scaling vs rowwise scaling - you can check out these performance tables to get an idea of when it makes sense to convert a linear layer to float8: https://github.com/pytorch/ao/tree/main/torchao/float8#performance
For example, for tensorwise scaling, if K <= 4096 and N <= 1024, all of our benchmarks showed worse performance than bf16, for all tested values of M (from 1024 to 16384).
It's possible that for very large values of M, beyond what we tested, the perf change could be positive. However, this auto filter is not intended to be universally optimal - it's just a simple way for users to avoid a common footgun that makes fp8 seem to perform worse than bf16, without needing to do manual layer analysis and cross-reference our performance tables to filter out layers by hand.
For the best results, users should still do layer analysis and not rely on this heuristic-based auto filter, which doesn't account for M.
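For illustration only, here is a filter in the spirit of the auto filter described in this thread, hard-coding the tensorwise example threshold mentioned above (K <= 4096 and N <= 1024). These are not necessarily torchao's actual cutoffs, and M is ignored, as noted.

```python
import torch.nn as nn

def skip_small_kn(module: nn.Module, fqn: str) -> bool:
    """Return True to convert `module` to float8, False to keep it in bf16."""
    if not isinstance(module, nn.Linear):
        return False
    K, N = module.in_features, module.out_features
    # regime where the tensorwise benchmarks cited above showed fp8 slower than bf16
    if K <= 4096 and N <= 1024:
        return False
    return True
```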