FBGEMM GenAI Overview
=====================

High Level Overview
-------------------

FBGEMM FP8 rowwise quantization kernels have been officially adopted in the
`Llama3.1 release <https://fb.workplace.com/groups/221503021668016/permalink/1900301927121442/>`__.
FP8 has been applied across the Llama3 8B, 70B, and 405B models.
Notably, for the 405B model, FP8 enables inference on a single node,
achieving a 2x throughput improvement over the BF16 baseline running on two
nodes with pipeline parallelism. Externally, it has been mentioned in the
`Llama3 paper <https://ai.meta.com/research/publications/the-llama-3-herd-of-models/>`__ and
`repo <https://github.com/meta-llama/llama-stack/blob/main/llama_stack/models/llama/quantize_impls.py>`__,
as well as in `HuggingFace <https://huggingface.co/docs/transformers/main/quantization/fbgemm_fp8>`__,
`vLLM <https://blog.vllm.ai/2024/07/23/llama31.html>`__, and
`TensorRT-LLM <https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms/>`__.

FBGEMM GenAI FP8 supports a variety of configurations:

* GEMM operators: {CUTLASS, CK, Triton} x {BF16, FP8} x {tensor-wise, row-wise, block-wise} x {Nvidia H100, AMD MI300x}.
* High/low precision conversion kernels (FP32 / BF16 <-> FP8) with scaling options {tensor-wise, row-wise, block-wise}, hardware platforms {Nvidia H100, AMD MI300x}, and programming options {Triton, CUDA/HIP}; a conceptual sketch of row-wise scaling follows this list.

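For intuition, the row-wise scaling that these conversion kernels apply can be
sketched in plain PyTorch: each row gets its own scale so that the row's maximum
magnitude maps onto the FP8 (e4m3) range. This is a conceptual sketch only, with
illustrative function names; it is not the FBGEMM implementation, whose fused
Triton/CUDA/HIP kernels also handle details such as scale upper bounds.

.. code:: python

    import torch

    # Largest representable magnitude in the FP8 e4m3 format (448.0).
    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

    def quantize_fp8_rowwise_reference(x: torch.Tensor):
        # One scale per row: map each row's max magnitude onto the FP8 range.
        row_absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
        scale = row_absmax / FP8_MAX
        xq = (x / scale).to(torch.float8_e4m3fn)
        return xq, scale.squeeze(-1)

    def dequantize_fp8_rowwise_reference(xq: torch.Tensor, scale: torch.Tensor):
        # Undo the per-row scaling in higher precision.
        return xq.to(torch.float32) * scale.unsqueeze(-1)

Row-wise scaling of weights is commonly described as channel-wise and row-wise
scaling of activations as token-wise; that is the terminology used in the API
example further below.
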
Besides FP8 support, FBGEMM GenAI operators also support:

* Customized AllReduce communications (reduce latency for small message sizes).
* GQA: optimized specifically for decoding cases, as detailed in PyTorch's blog on `INT4 decoding <https://pytorch.org/blog/int4-decoding/>`__.
* KV cache quantizations.
* Rotary Positional Embedding (RoPE).

FP8 Core API Functions
----------------------

.. code:: python

    import torch  # the FBGEMM ops are exposed under torch.ops.fbgemm

    # Rowwise (channel-wise) quantize the weight w from BF16 to FP8
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
    # Rowwise (token-wise) quantize the activation x from BF16 to FP8;
    # num_tokens and activation_scale_ub are caller-provided arguments
    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
        x, num_tokens, activation_scale_ub
    )
    # Rowwise-scaled GEMM with FP8 inputs and BF16 output
    y = torch.ops.fbgemm.f8f8bf16_rowwise(
        xq,
        wq,
        x_scale,
        w_scale,
        use_fast_accum=True,
    )

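The following expands the calls above into a self-contained usage sketch. It is
illustrative only: the tensor shapes, the CUDA device, the linear-layer-style
(N, K) weight layout, the ``fbgemm_gpu.experimental.gen_ai`` import path, and
the BF16 reference comparison are assumptions for demonstration rather than part
of the documented API.

.. code:: python

    import torch
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # registers torch.ops.fbgemm ops

    # Hypothetical shapes for one linear layer: x is (M, K), w is (N, K).
    M, N, K = 128, 4096, 4096
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # activations
    w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")  # weights

    # Rowwise quantize both operands from BF16 to FP8.
    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)

    # FP8 x FP8 -> BF16 rowwise GEMM; y has shape (M, N).
    y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, use_fast_accum=True)

    # Rough sanity check against the BF16 reference; expect a small
    # relative error from FP8 rounding.
    y_ref = x @ w.t()
    rel_err = (y.float() - y_ref.float()).abs().mean() / y_ref.float().abs().mean()
    print(f"mean relative error: {rel_err.item():.4f}")

In general, ``use_fast_accum=True`` selects the faster FP8 accumulation mode,
trading a small amount of accuracy for throughput.
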
See :ref:`genai-quantize-ops-stable-api` for more details.