diff --git a/docs/float8.md b/docs/float8.md
index 63a029e60..5d8937429 100644
--- a/docs/float8.md
+++ b/docs/float8.md
@@ -17,6 +17,8 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai
 * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
 * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
 * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
+* `--float8.filter_fqns="..."` (optional): a comma-separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
+  * **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to enable automatic module filtering, which skips converting linear layers whose K,N dimensions are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
 * `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
 
 For float8 with rowwise scaling, launch training job with the following command (or alternatively set configs in toml files)
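The auto-filter described above decides per linear layer based on its K (`in_features`) and N (`out_features`) dimensions. A minimal sketch of the idea, assuming a single illustrative threshold (torchao's actual cutoffs are recipe-specific and tuned from H100 microbenchmarks, so the value below is hypothetical):

```python
import torch.nn as nn


def keep_for_float8(module: nn.Module, fqn: str, min_dim: int = 4096) -> bool:
    """Hypothetical dimension-based filter: convert only linear layers whose
    K and N are large enough that a float8 matmul is likely to beat bf16.
    `min_dim` is illustrative, not torchao's measured threshold."""
    if not isinstance(module, nn.Linear):
        return False
    return module.in_features >= min_dim and module.out_features >= min_dim
```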
diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py
index 4a7e2a651..65357c6cf 100644
--- a/torchtitan/components/quantization/float8.py
+++ b/torchtitan/components/quantization/float8.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 from functools import partial
 
 import torch
@@ -20,6 +19,8 @@
 
 from .utils import module_filter_fn
 
+AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"
+
 
 class Float8Converter(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
@@ -52,15 +53,18 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             return
 
         self.enabled = True
-        self.filter_fqns = float8_config.filter_fqns
 
         if float8_config.recipe_name is not None:
-            assert (
-                not float8_config.enable_fsdp_float8_all_gather
-            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
-            assert (
-                not float8_config.force_recompute_fp8_weight_in_bwd
-            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
+            assert not float8_config.enable_fsdp_float8_all_gather, (
+                "using `float8_config.enable_fsdp_float8_all_gather` together "
+                "with `float8_config.recipe_name` is not supported"
+            )
+
+            assert not float8_config.force_recompute_fp8_weight_in_bwd, (
+                "using `float8_config.force_recompute_fp8_weight_in_bwd` together "
+                "with `float8_config.recipe_name` is not supported"
+            )
+
             self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
             self.precompute_scale = False
             logger.info(
@@ -73,7 +77,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             logger.debug(
                 "Set torch._inductor.config.emulate_precision_casts to True"
             )
-
         else:
             # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
             enable_fsdp_float8_all_gather = (
@@ -92,6 +95,50 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             )
             logger.info("Float8 tensorwise scaled training active")
 
+        # configure the module filter function
+        self.filter_fn = self._init_filter_fn(float8_config)
+
+    def _init_filter_fn(self, float8_config: Float8):
+        # use the auto filter if "auto_filter_small_kn" is one of the given filter fqns
+        use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in float8_config.filter_fqns
+        if use_auto_filter:
+            try:
+                from torchao.float8 import _auto_filter_for_recipe
+
+                logger.info(
+                    "Using automatic module filter for float8 model conversion."
+                )
+
+                recipe_name = (
+                    float8_config.recipe_name
+                    if float8_config.recipe_name
+                    else "tensorwise"
+                )
+
+                # remove the auto filter flag from filter_fqns before passing to _auto_filter_for_recipe
+                fqns = [
+                    fqn
+                    for fqn in float8_config.filter_fqns
+                    if fqn != AUTO_FILTER_SMALL_KN_FLAG
+                ]
+
+                filter_fn = _auto_filter_for_recipe(
+                    recipe_name,
+                    filter_fqns=fqns,
+                )
+                return filter_fn
+            except ImportError:
+                logger.warning(
+                    (
+                        "Using default module_filter_fn for float8 model conversion. "
+                        "To use _auto_filter_for_recipe, please install torchao nightly build."
+                    )
+                )
+
+        # fall back to the default filter function
+        filter_fn = partial(module_filter_fn, filter_fqns=float8_config.filter_fqns)
+        return filter_fn
+
     def convert(self, model: nn.Module):
         """
         This function converts the linear layers of `model` to `Float8Linear`.
@@ -103,11 +150,10 @@ def convert(self, model: nn.Module):
 
         from torchao.float8 import convert_to_float8_training
 
-        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
         convert_to_float8_training(
             model,
             config=self.config,
-            module_filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns),
+            module_filter_fn=self.filter_fn,
         )
         logger.info(
             "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
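For reference, `convert_to_float8_training` expects `module_filter_fn` to be a callable `(module, fqn) -> bool` saying whether a given module should be converted. A minimal sketch of the name-based fallback that `partial(module_filter_fn, filter_fqns=...)` sets up above; the real `module_filter_fn` lives in `torchtitan/components/quantization/utils.py`, so this stand-in is an assumed approximation, not copied from it:

```python
from functools import partial

import torch.nn as nn


def fqn_filter(module: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
    """Hypothetical stand-in for torchtitan's module_filter_fn: convert a
    linear layer unless its fully qualified name matches a filtered fqn."""
    if not isinstance(module, nn.Linear):
        return False
    return not any(filtered in fqn for filtered in filter_fqns)


# Bind the configured fqns up front, mirroring the fallback in _init_filter_fn:
filter_fn = partial(fqn_filter, filter_fqns=["attention.wk", "attention.wv"])
```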