This repo contains the official implementation for Training LLMs with MXFP4. Our MXFP4 training recipe achieves near-lossless training by computing unbiased gradient estimates (with stochastic rounding and a random Hadamard transform) using MXFP4-accelerated GEMMs. This allows us to compute the backward pass in MXFP4, which constitutes the majority of the training FLOPs.
We support training with NVIDIA/Megatron-LM and NVIDIA/TransformerEngine. Due to the lack of hardware with MXFP4 support (Blackwell GPUs), we use microsoft/microxcaling to emulate MXFP4 GEMMs (following the OCP MX specification).
We recommend using NGC's PyTorch container with release tag pytorch:24.04-py3:
docker pull nvcr.io/nvidia/pytorch:24.04-py3
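A minimal sketch of launching an interactive session with this container; the mount path and shared-memory settings below are illustrative and should be adapted to your setup:
# start an interactive shell with all GPUs visible and the repo mounted
docker run --gpus all -it --rm \
  --shm-size=8g --ulimit memlock=-1 \
  -v $PWD:/workspace/mxfp4-llm \
  nvcr.io/nvidia/pytorch:24.04-py3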
We support MXFP4 backward passes with both BF16 and FP8 forward passes, leveraging TransformerEngine for the latter. Currently, we only support FP8 + MXFP4 training with TransformerEngine version 1.5.0+6a9edc3, which comes pre-installed in the pytorch:24.04-py3 container.
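As a quick sanity check (the package may be reported by pip as transformer_engine or transformer-engine), you can verify the pre-installed version inside the container with:
pip list 2>/dev/null | grep -i transformer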
We used the GPT2BPETokenizer-preprocessed Wikipedia dataset (around 3.28 billion tokens). Please follow AWS-Neuron-Examples-Megatron-LM-GPT to download it from S3:
export DATA_DIR=./examples_datasets/gpt2
mkdir -p ${DATA_DIR} && cd ${DATA_DIR}
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
aws s3 cp s3://neuron-s3/training_datasets/gpt/wikipedia/my-gpt2_text_document.bin . --no-sign-request
aws s3 cp s3://neuron-s3/training_datasets/gpt/wikipedia/my-gpt2_text_document.idx . --no-sign-request
aws s3 cp s3://neuron-s3/training_datasets/gpt/wikipedia/license.txt . --no-sign-request
Alternatively, you can prepare a custom dataset with Megatron-LM/tools/preprocess_data.py to build Megatron-compatible mmap-format .bin and .idx files from scratch, as illustrated in preparing-wikipedia-dataset-from-scratch (see the example invocation below).
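A representative invocation is sketched below; flag names vary across Megatron-LM versions (e.g. older releases use --vocab instead of --vocab-file) and the input file name is illustrative, so consult tools/preprocess_data.py --help for your checkout:
# my_corpus.json is a placeholder: one JSON object per line with a "text" field
python third_party/Megatron-LM/tools/preprocess_data.py \
  --input my_corpus.json \
  --output-prefix my-gpt2_text \
  --vocab-file ${DATA_DIR}/gpt2-vocab.json \
  --merge-file ${DATA_DIR}/gpt2-merges.txt \
  --tokenizer-type GPT2BPETokenizer \
  --append-eod \
  --workers 8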
Clone the repository and its submodules with the following commands:
git clone https://github.com/amazon-science/mxfp4-llm
cd mxfp4-llm
git submodule update --init --recursive
We made changes to the official NVIDIA/Megatron-LM v0.2.0 and microsoft/microxcaling v1.1.1.dev0 and packaged them into patches. Apply these patches to third_party/* with:
cd third_party/Megatron-LM
git apply ../../patch_override_scripts/Megatron-LM.patch
cd ../microxcaling
git apply ../../patch_override_scripts/microxcaling.patch
cd ../../scripts
A detailed description of the changes can be found in patch_override_scripts. Check our paper for more information.
We provide scripts to train GPT3-345M, 1.3B, and 6.7B parameter models in scripts/gpt3, with a guideline on configurable precision options in scripts; an example launch is sketched below. This code is well-tested on Ampere (A100), Ada (L40S), and Hopper (H100) GPUs.
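A minimal launch sketch follows; the script name is hypothetical, so check scripts/gpt3 for the actual filenames and scripts for the precision-option guideline before running:
# hypothetical filename for illustration; substitute the model size / precision variant you need
bash scripts/gpt3/pretrain_gpt3_345m.sh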
- Migrate to newer third_party package versions, e.g. Megatron-v0.13.0rc0 and TransformerEngine-v1.14.0, in order to support --tp-comm-overlap for the forward-pass GEMM
- Add LLaMA pretraining scripts
- Utilize MXFP4 GEMM support on Blackwell GPUs
This project welcomes contributions and suggestions; see CONTRIBUTING.md for details. This project has adopted the Amazon Open Source Code of Conduct. For more information, see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.
This project is licensed under the Apache 2.0 license.
If you find our work helpful in your research, please consider citing the following paper:
@inproceedings{
tseng2025training,
title={Training {LLM}s with {MXFP}4},
author={Albert Tseng and Tao Yu and Youngsuk Park},
booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
year={2025},
url={https://openreview.net/forum?id=a8z5Q0WSPL}
}